Browse Source

Change to poll-based model.

Methods to connect, request credit and so on won't block until the other
end responds to them, but rather will return immediately. The client
should then call poll_recv, which will return one of several possible
events either for data being received on the connection, or connection,
disconnection.

This also lays the groundwork for handling more than one connection at
once.
Andrew Walbran 2 years ago
parent
commit
d8164d6e2b
1 changed files with 106 additions and 119 deletions
  1. 106 119
      src/device/socket/vsock.rs

+ 106 - 119
src/device/socket/vsock.rs

@@ -9,8 +9,8 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::volread;
 use crate::Result;
+use core::mem::size_of;
 use core::ptr::NonNull;
-use core::{convert::TryFrom, mem::size_of};
 use log::{debug, info};
 use zerocopy::{AsBytes, FromBytes};
 
@@ -195,6 +195,10 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     }
 
     /// Sends a request to connect to the given destination.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+    /// before sending data.
     pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
         if self.connection_info.is_some() {
             return Err(SocketError::ConnectionExists.into());
@@ -216,31 +220,26 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         self.send_packet_to_tx_queue(&header, &[])?;
 
         self.connection_info = Some(new_connection_info);
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Response], &mut [], |header| {
-            header.check_data_is_empty().map_err(|e| e.into())
-        })?;
-        debug!("Connection established: {:?}", self.connection_info);
+        debug!("Connection requested: {:?}", self.connection_info);
         Ok(())
     }
 
-    /// Requests the credit and updates the peer credit in the current connection info.
+    /// Requests the peer to send us a credit update for the current connection.
     fn request_credit(&mut self) -> Result {
         let connection_info = self.connection_info()?;
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::CreditRequest.into(),
             ..connection_info.new_header(self.guest_cid)
         };
-        self.send_packet_to_tx_queue(&header, &[])?;
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::CreditUpdate], &mut [], |_| {
-            Ok(())
-        })
+        self.send_packet_to_tx_queue(&header, &[])
     }
 
     /// Sends the buffer to the destination.
     pub fn send(&mut self, buffer: &[u8]) -> Result {
-        self.check_peer_buffer_is_sufficient(buffer.len())?;
-
         let connection_info = self.connection_info()?;
+
+        self.check_peer_buffer_is_sufficient(&connection_info, buffer.len())?;
+
         let len = buffer.len() as u32;
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Rw.into(),
@@ -252,23 +251,17 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         self.send_packet_to_tx_queue(&header, buffer)
     }
 
-    fn check_peer_buffer_is_sufficient(&mut self, buffer_len: usize) -> Result {
-        if usize::try_from(self.connection_info()?.peer_free())
-            .map_err(|_| SocketError::InvalidNumber)?
-            >= buffer_len
-        {
+    fn check_peer_buffer_is_sufficient(
+        &mut self,
+        connection_info: &ConnectionInfo,
+        buffer_len: usize,
+    ) -> Result {
+        if connection_info.peer_free() as usize >= buffer_len {
             Ok(())
         } else {
-            // Update cached peer credit and try again.
+            // Request an update of the cached peer credit, and tell the caller to try again later.
             self.request_credit()?;
-            if usize::try_from(self.connection_info()?.peer_free())
-                .map_err(|_| SocketError::InvalidNumber)?
-                >= buffer_len
-            {
-                Ok(())
-            } else {
-                Err(SocketError::InsufficientBufferSpaceInPeer.into())
-            }
+            Err(SocketError::InsufficientBufferSpaceInPeer.into())
         }
     }
 
@@ -276,7 +269,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     ///
     /// A buffer must be provided to put the data in if there is some to
     /// receive.
-    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
+    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
         let connection_info = self.connection_info()?;
 
         // Tell the peer that we have space to receive some data.
@@ -287,44 +280,31 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         };
         self.send_packet_to_tx_queue(&header, &[])?;
 
-        // Wait to receive some data.
-        let mut length: u32 = 0;
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Rw], buffer, |header| {
-            length = header.len();
-            Ok(())
-        })?;
-        self.connection_info.as_mut().unwrap().fwd_cnt += length;
-
-        Ok(VsockEvent {
-            source: connection_info.dst,
-            destination: VsockAddr {
-                cid: self.guest_cid,
-                port: connection_info.src_port,
-            },
-            event_type: VsockEventType::Received {
-                length: length as usize,
-            },
-        })
+        // Handle entries from the RX virtqueue until we find one that generates an event.
+        let event = self.poll_rx_queue(buffer)?;
+
+        if self.rx.should_notify() {
+            self.transport.notify(RX_QUEUE_IDX);
+        }
+
+        Ok(event)
     }
 
-    /// Shuts down the connection.
+    /// Request to shut down the connection.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+    /// shutdown.
     pub fn shutdown(&mut self) -> Result {
         let connection_info = self.connection_info()?;
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Shutdown.into(),
             ..connection_info.new_header(self.guest_cid)
         };
-        self.send_packet_to_tx_queue(&header, &[])?;
-        self.poll_and_filter_packet_from_rx_queue(
-            &[VirtioVsockOp::Rst, VirtioVsockOp::Shutdown],
-            &mut [],
-            |_| Ok(()),
-        )?;
-        Ok(())
+        self.send_packet_to_tx_queue(&header, &[])
     }
 
     fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
-        // TODO: Virtio v1.1 5.10.6.1.1 The rx virtqueue MUST be processed even when the tx virtqueue is full.
         let _len = self.tx.add_notify_wait_pop(
             &[header.as_bytes(), buffer],
             &mut [],
@@ -333,26 +313,23 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(())
     }
 
-    fn poll_and_filter_packet_from_rx_queue<F>(
-        &mut self,
-        accepted_ops: &[VirtioVsockOp],
-        body: &mut [u8],
-        f: F,
-    ) -> Result
-    where
-        F: FnOnce(&VirtioVsockHdr) -> Result,
-    {
-        let our_cid = self.guest_cid;
-        let mut result = Ok(());
+    /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet
+    /// which generates a `VsockEvent`.
+    ///
+    /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
+    /// don't result in any events to return.
+    fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
         loop {
-            self.wait_one_in_rx_queue();
             let mut connection_info = self.connection_info.clone().unwrap_or_default();
-            let header = self.pop_packet_from_rx_queue(body)?;
+            let Some(header) = self.pop_packet_from_rx_queue(body)? else{
+                return Ok(None);
+            };
+
             let op = header.op()?;
 
             // Skip packets which don't match our current connection.
             if header.source() != connection_info.dst
-                || header.dst_cid.get() != our_cid
+                || header.dst_cid.get() != self.guest_cid
                 || header.dst_port.get() != connection_info.src_port
             {
                 debug!(
@@ -362,65 +339,75 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
                 continue;
             }
 
+            connection_info.peer_buf_alloc = header.buf_alloc.into();
+            connection_info.peer_fwd_cnt = header.fwd_cnt.into();
+            if self.connection_info.is_some() {
+                self.connection_info = Some(connection_info.clone());
+                debug!("Connection info updated: {:?}", self.connection_info);
+            }
+
             match op {
+                VirtioVsockOp::Request => {
+                    // TODO: Send a Rst, or support listening.
+                }
+                VirtioVsockOp::Response => {
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Connected,
+                    }));
+                }
                 VirtioVsockOp::CreditUpdate => {
                     header.check_data_is_empty()?;
 
-                    connection_info.peer_buf_alloc = header.buf_alloc.into();
-                    connection_info.peer_fwd_cnt = header.fwd_cnt.into();
-                    self.connection_info.replace(connection_info);
-                    debug!("Connection info updated: {:?}", self.connection_info);
-
-                    if accepted_ops.contains(&op) {
-                        break;
-                    } else {
-                        // Virtio v1.1 5.10.6.3
-                        // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
-                        // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
-                        // any time a change in buffer space occurs.
-                        continue;
-                    }
+                    // Virtio v1.1 5.10.6.3
+                    // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
+                    // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
+                    // any time a change in buffer space occurs.
+                    continue;
                 }
                 VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
                     header.check_data_is_empty()?;
 
-                    self.connection_info.take();
+                    self.connection_info = None;
                     info!("Disconnected from the peer");
-                    if accepted_ops.contains(&op) {
-                    } else if op == VirtioVsockOp::Rst {
-                        result = Err(SocketError::ConnectionFailed.into());
+
+                    let reason = if op == VirtioVsockOp::Rst {
+                        DisconnectReason::Reset
                     } else {
-                        assert_eq!(VirtioVsockOp::Shutdown, op);
-                        result = Err(SocketError::PeerSocketShutdown.into());
-                    }
-                    break;
+                        DisconnectReason::Shutdown
+                    };
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Disconnected { reason },
+                    }));
                 }
-                // TODO: Update peer_buf_alloc and peer_fwd_cnt for other packets too.
-                x if accepted_ops.contains(&x) => {
-                    f(&header)?;
-                    break;
+                VirtioVsockOp::Rw => {
+                    self.connection_info.as_mut().unwrap().fwd_cnt += header.len();
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Received {
+                            length: header.len() as usize,
+                        },
+                    }));
                 }
-                _ => {
-                    result = Err(SocketError::InvalidOperation.into());
-                    break;
+                VirtioVsockOp::CreditRequest => {
+                    // TODO: Send a credit update.
+                }
+                VirtioVsockOp::Invalid => {
+                    return Err(SocketError::InvalidOperation.into());
                 }
-            };
-        }
-
-        if self.rx.should_notify() {
-            self.transport.notify(RX_QUEUE_IDX);
-        }
-        result
-    }
-
-    /// Waits until there is at least one element to pop in rx queue.
-    fn wait_one_in_rx_queue(&mut self) {
-        const TIMEOUT_N: usize = 1000000;
-        for _ in 0..TIMEOUT_N {
-            if self.rx.can_pop() {
-                break;
-            } else {
-                core::hint::spin_loop();
             }
         }
     }
@@ -428,11 +415,11 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
     /// the body into the given buffer.
     ///
-    /// Returns an error if there is no pending packet, or the body is bigger than the buffer
-    /// supplied.
-    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<VirtioVsockHdr> {
+    /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
+    /// buffer supplied.
+    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
         let Some(token) = self.rx.peek_used() else {
-            return Err(SocketError::NoResponseReceived.into());
+            return Ok(None);
         };
 
         // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
@@ -456,7 +443,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         }?;
 
         debug!("Received packet {:?}. Op {:?}", header, header.op());
-        Ok(header)
+        Ok(Some(header))
     }
 
     fn connection_info(&self) -> Result<ConnectionInfo> {