Browse Source

Change to poll-based model.

Methods to connect, request credit and so on won't block until the other
end responds to them, but rather will return immediately. The client
should then call poll_recv, which will return one of several possible
events either for data being received on the connection, or connection,
disconnection.

This also lays the groundwork for handling more than one connection at
once.
Andrew Walbran 2 years ago
parent
commit
d8164d6e2b
1 changed files with 106 additions and 119 deletions
  1. 106 119
      src/device/socket/vsock.rs

+ 106 - 119
src/device/socket/vsock.rs

@@ -9,8 +9,8 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::transport::Transport;
 use crate::volatile::volread;
 use crate::volatile::volread;
 use crate::Result;
 use crate::Result;
+use core::mem::size_of;
 use core::ptr::NonNull;
 use core::ptr::NonNull;
-use core::{convert::TryFrom, mem::size_of};
 use log::{debug, info};
 use log::{debug, info};
 use zerocopy::{AsBytes, FromBytes};
 use zerocopy::{AsBytes, FromBytes};
 
 
@@ -195,6 +195,10 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     }
     }
 
 
     /// Sends a request to connect to the given destination.
     /// Sends a request to connect to the given destination.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+    /// before sending data.
     pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
     pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
         if self.connection_info.is_some() {
         if self.connection_info.is_some() {
             return Err(SocketError::ConnectionExists.into());
             return Err(SocketError::ConnectionExists.into());
@@ -216,31 +220,26 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         self.send_packet_to_tx_queue(&header, &[])?;
         self.send_packet_to_tx_queue(&header, &[])?;
 
 
         self.connection_info = Some(new_connection_info);
         self.connection_info = Some(new_connection_info);
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Response], &mut [], |header| {
-            header.check_data_is_empty().map_err(|e| e.into())
-        })?;
-        debug!("Connection established: {:?}", self.connection_info);
+        debug!("Connection requested: {:?}", self.connection_info);
         Ok(())
         Ok(())
     }
     }
 
 
-    /// Requests the credit and updates the peer credit in the current connection info.
+    /// Requests the peer to send us a credit update for the current connection.
     fn request_credit(&mut self) -> Result {
     fn request_credit(&mut self) -> Result {
         let connection_info = self.connection_info()?;
         let connection_info = self.connection_info()?;
         let header = VirtioVsockHdr {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::CreditRequest.into(),
             op: VirtioVsockOp::CreditRequest.into(),
             ..connection_info.new_header(self.guest_cid)
             ..connection_info.new_header(self.guest_cid)
         };
         };
-        self.send_packet_to_tx_queue(&header, &[])?;
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::CreditUpdate], &mut [], |_| {
-            Ok(())
-        })
+        self.send_packet_to_tx_queue(&header, &[])
     }
     }
 
 
     /// Sends the buffer to the destination.
     /// Sends the buffer to the destination.
     pub fn send(&mut self, buffer: &[u8]) -> Result {
     pub fn send(&mut self, buffer: &[u8]) -> Result {
-        self.check_peer_buffer_is_sufficient(buffer.len())?;
-
         let connection_info = self.connection_info()?;
         let connection_info = self.connection_info()?;
+
+        self.check_peer_buffer_is_sufficient(&connection_info, buffer.len())?;
+
         let len = buffer.len() as u32;
         let len = buffer.len() as u32;
         let header = VirtioVsockHdr {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Rw.into(),
             op: VirtioVsockOp::Rw.into(),
@@ -252,23 +251,17 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         self.send_packet_to_tx_queue(&header, buffer)
         self.send_packet_to_tx_queue(&header, buffer)
     }
     }
 
 
-    fn check_peer_buffer_is_sufficient(&mut self, buffer_len: usize) -> Result {
-        if usize::try_from(self.connection_info()?.peer_free())
-            .map_err(|_| SocketError::InvalidNumber)?
-            >= buffer_len
-        {
+    fn check_peer_buffer_is_sufficient(
+        &mut self,
+        connection_info: &ConnectionInfo,
+        buffer_len: usize,
+    ) -> Result {
+        if connection_info.peer_free() as usize >= buffer_len {
             Ok(())
             Ok(())
         } else {
         } else {
-            // Update cached peer credit and try again.
+            // Request an update of the cached peer credit, and tell the caller to try again later.
             self.request_credit()?;
             self.request_credit()?;
-            if usize::try_from(self.connection_info()?.peer_free())
-                .map_err(|_| SocketError::InvalidNumber)?
-                >= buffer_len
-            {
-                Ok(())
-            } else {
-                Err(SocketError::InsufficientBufferSpaceInPeer.into())
-            }
+            Err(SocketError::InsufficientBufferSpaceInPeer.into())
         }
         }
     }
     }
 
 
@@ -276,7 +269,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     ///
     ///
     /// A buffer must be provided to put the data in if there is some to
     /// A buffer must be provided to put the data in if there is some to
     /// receive.
     /// receive.
-    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
+    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
         let connection_info = self.connection_info()?;
         let connection_info = self.connection_info()?;
 
 
         // Tell the peer that we have space to receive some data.
         // Tell the peer that we have space to receive some data.
@@ -287,44 +280,31 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         };
         };
         self.send_packet_to_tx_queue(&header, &[])?;
         self.send_packet_to_tx_queue(&header, &[])?;
 
 
-        // Wait to receive some data.
-        let mut length: u32 = 0;
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Rw], buffer, |header| {
-            length = header.len();
-            Ok(())
-        })?;
-        self.connection_info.as_mut().unwrap().fwd_cnt += length;
-
-        Ok(VsockEvent {
-            source: connection_info.dst,
-            destination: VsockAddr {
-                cid: self.guest_cid,
-                port: connection_info.src_port,
-            },
-            event_type: VsockEventType::Received {
-                length: length as usize,
-            },
-        })
+        // Handle entries from the RX virtqueue until we find one that generates an event.
+        let event = self.poll_rx_queue(buffer)?;
+
+        if self.rx.should_notify() {
+            self.transport.notify(RX_QUEUE_IDX);
+        }
+
+        Ok(event)
     }
     }
 
 
-    /// Shuts down the connection.
+    /// Request to shut down the connection.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+    /// shutdown.
     pub fn shutdown(&mut self) -> Result {
     pub fn shutdown(&mut self) -> Result {
         let connection_info = self.connection_info()?;
         let connection_info = self.connection_info()?;
         let header = VirtioVsockHdr {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Shutdown.into(),
             op: VirtioVsockOp::Shutdown.into(),
             ..connection_info.new_header(self.guest_cid)
             ..connection_info.new_header(self.guest_cid)
         };
         };
-        self.send_packet_to_tx_queue(&header, &[])?;
-        self.poll_and_filter_packet_from_rx_queue(
-            &[VirtioVsockOp::Rst, VirtioVsockOp::Shutdown],
-            &mut [],
-            |_| Ok(()),
-        )?;
-        Ok(())
+        self.send_packet_to_tx_queue(&header, &[])
     }
     }
 
 
     fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
     fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
-        // TODO: Virtio v1.1 5.10.6.1.1 The rx virtqueue MUST be processed even when the tx virtqueue is full.
         let _len = self.tx.add_notify_wait_pop(
         let _len = self.tx.add_notify_wait_pop(
             &[header.as_bytes(), buffer],
             &[header.as_bytes(), buffer],
             &mut [],
             &mut [],
@@ -333,26 +313,23 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(())
         Ok(())
     }
     }
 
 
-    fn poll_and_filter_packet_from_rx_queue<F>(
-        &mut self,
-        accepted_ops: &[VirtioVsockOp],
-        body: &mut [u8],
-        f: F,
-    ) -> Result
-    where
-        F: FnOnce(&VirtioVsockHdr) -> Result,
-    {
-        let our_cid = self.guest_cid;
-        let mut result = Ok(());
+    /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet
+    /// which generates a `VsockEvent`.
+    ///
+    /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
+    /// don't result in any events to return.
+    fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
         loop {
         loop {
-            self.wait_one_in_rx_queue();
             let mut connection_info = self.connection_info.clone().unwrap_or_default();
             let mut connection_info = self.connection_info.clone().unwrap_or_default();
-            let header = self.pop_packet_from_rx_queue(body)?;
+            let Some(header) = self.pop_packet_from_rx_queue(body)? else{
+                return Ok(None);
+            };
+
             let op = header.op()?;
             let op = header.op()?;
 
 
             // Skip packets which don't match our current connection.
             // Skip packets which don't match our current connection.
             if header.source() != connection_info.dst
             if header.source() != connection_info.dst
-                || header.dst_cid.get() != our_cid
+                || header.dst_cid.get() != self.guest_cid
                 || header.dst_port.get() != connection_info.src_port
                 || header.dst_port.get() != connection_info.src_port
             {
             {
                 debug!(
                 debug!(
@@ -362,65 +339,75 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
                 continue;
                 continue;
             }
             }
 
 
+            connection_info.peer_buf_alloc = header.buf_alloc.into();
+            connection_info.peer_fwd_cnt = header.fwd_cnt.into();
+            if self.connection_info.is_some() {
+                self.connection_info = Some(connection_info.clone());
+                debug!("Connection info updated: {:?}", self.connection_info);
+            }
+
             match op {
             match op {
+                VirtioVsockOp::Request => {
+                    // TODO: Send a Rst, or support listening.
+                }
+                VirtioVsockOp::Response => {
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Connected,
+                    }));
+                }
                 VirtioVsockOp::CreditUpdate => {
                 VirtioVsockOp::CreditUpdate => {
                     header.check_data_is_empty()?;
                     header.check_data_is_empty()?;
 
 
-                    connection_info.peer_buf_alloc = header.buf_alloc.into();
-                    connection_info.peer_fwd_cnt = header.fwd_cnt.into();
-                    self.connection_info.replace(connection_info);
-                    debug!("Connection info updated: {:?}", self.connection_info);
-
-                    if accepted_ops.contains(&op) {
-                        break;
-                    } else {
-                        // Virtio v1.1 5.10.6.3
-                        // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
-                        // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
-                        // any time a change in buffer space occurs.
-                        continue;
-                    }
+                    // Virtio v1.1 5.10.6.3
+                    // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
+                    // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
+                    // any time a change in buffer space occurs.
+                    continue;
                 }
                 }
                 VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
                 VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
                     header.check_data_is_empty()?;
                     header.check_data_is_empty()?;
 
 
-                    self.connection_info.take();
+                    self.connection_info = None;
                     info!("Disconnected from the peer");
                     info!("Disconnected from the peer");
-                    if accepted_ops.contains(&op) {
-                    } else if op == VirtioVsockOp::Rst {
-                        result = Err(SocketError::ConnectionFailed.into());
+
+                    let reason = if op == VirtioVsockOp::Rst {
+                        DisconnectReason::Reset
                     } else {
                     } else {
-                        assert_eq!(VirtioVsockOp::Shutdown, op);
-                        result = Err(SocketError::PeerSocketShutdown.into());
-                    }
-                    break;
+                        DisconnectReason::Shutdown
+                    };
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Disconnected { reason },
+                    }));
                 }
                 }
-                // TODO: Update peer_buf_alloc and peer_fwd_cnt for other packets too.
-                x if accepted_ops.contains(&x) => {
-                    f(&header)?;
-                    break;
+                VirtioVsockOp::Rw => {
+                    self.connection_info.as_mut().unwrap().fwd_cnt += header.len();
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Received {
+                            length: header.len() as usize,
+                        },
+                    }));
                 }
                 }
-                _ => {
-                    result = Err(SocketError::InvalidOperation.into());
-                    break;
+                VirtioVsockOp::CreditRequest => {
+                    // TODO: Send a credit update.
+                }
+                VirtioVsockOp::Invalid => {
+                    return Err(SocketError::InvalidOperation.into());
                 }
                 }
-            };
-        }
-
-        if self.rx.should_notify() {
-            self.transport.notify(RX_QUEUE_IDX);
-        }
-        result
-    }
-
-    /// Waits until there is at least one element to pop in rx queue.
-    fn wait_one_in_rx_queue(&mut self) {
-        const TIMEOUT_N: usize = 1000000;
-        for _ in 0..TIMEOUT_N {
-            if self.rx.can_pop() {
-                break;
-            } else {
-                core::hint::spin_loop();
             }
             }
         }
         }
     }
     }
@@ -428,11 +415,11 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
     /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
     /// the body into the given buffer.
     /// the body into the given buffer.
     ///
     ///
-    /// Returns an error if there is no pending packet, or the body is bigger than the buffer
-    /// supplied.
-    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<VirtioVsockHdr> {
+    /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
+    /// buffer supplied.
+    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
         let Some(token) = self.rx.peek_used() else {
         let Some(token) = self.rx.peek_used() else {
-            return Err(SocketError::NoResponseReceived.into());
+            return Ok(None);
         };
         };
 
 
         // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
         // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
@@ -456,7 +443,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         }?;
         }?;
 
 
         debug!("Received packet {:?}. Op {:?}", header, header.op());
         debug!("Received packet {:?}. Op {:?}", header, header.op());
-        Ok(header)
+        Ok(Some(header))
     }
     }
 
 
     fn connection_info(&self) -> Result<ConnectionInfo> {
     fn connection_info(&self) -> Result<ConnectionInfo> {