Browse Source

Merge pull request #73 from rcore-os/vsock_poll

Change vsock driver to poll model
Andrew Walbran 2 years ago
parent
commit
8bd86eb19d
3 changed files with 229 additions and 135 deletions
  1. 13 4
      examples/aarch64/src/main.rs
  2. 1 1
      src/device/socket/mod.rs
  3. 215 130
      src/device/socket/vsock.rs

+ 13 - 4
examples/aarch64/src/main.rs

@@ -23,7 +23,12 @@ use hal::HalImpl;
 use log::{debug, error, info, trace, warn, LevelFilter};
 use psci::system_off;
 use virtio_drivers::{
-    device::{blk::VirtIOBlk, console::VirtIOConsole, gpu::VirtIOGpu, socket::VirtIOSocket},
+    device::{
+        blk::VirtIOBlk,
+        console::VirtIOConsole,
+        gpu::VirtIOGpu,
+        socket::{VirtIOSocket, VsockEventType},
+    },
     transport::{
         mmio::{MmioTransport, VirtIOHeader},
         pci::{
@@ -189,18 +194,22 @@ fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
     let port = 1221;
     info!("Connecting to host on port {port}...");
     socket.connect(host_cid, port, port)?;
+    socket.wait_for_connect()?;
     info!("Connected to the host");
 
     const EXCHANGE_NUM: usize = 2;
     let messages = ["0-Ack. Hello from guest.", "1-Ack. Received again."];
     for k in 0..EXCHANGE_NUM {
         let mut buffer = [0u8; 24];
-        let len = socket.recv(&mut buffer)?;
+        let socket_event = socket.wait_for_recv(&mut buffer)?;
+        let VsockEventType::Received {length, ..} = socket_event.event_type else {
+            panic!("Received unexpected socket event {:?}", socket_event);
+        };
         info!(
             "Received message: {:?}({:?}), len: {:?}",
             buffer,
-            core::str::from_utf8(&buffer[..len]),
-            len
+            core::str::from_utf8(&buffer[..length]),
+            length
         );
 
         let message = messages[k % messages.len()];

+ 1 - 1
src/device/socket/mod.rs

@@ -5,4 +5,4 @@ mod protocol;
 mod vsock;
 
 pub use error::SocketError;
-pub use vsock::VirtIOSocket;
+pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};

+ 215 - 130
src/device/socket/vsock.rs

@@ -9,8 +9,9 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::volread;
 use crate::Result;
+use core::hint::spin_loop;
+use core::mem::size_of;
 use core::ptr::NonNull;
-use core::{convert::TryFrom, mem::size_of};
 use log::{debug, info};
 use zerocopy::{AsBytes, FromBytes};
 
@@ -34,6 +35,11 @@ struct ConnectionInfo {
     tx_cnt: u32,
     /// The number of bytes of packet bodies which we have received from the peer and handled.
     fwd_cnt: u32,
+    /// Whether we have recently requested credit from the peer.
+    ///
+    /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
+    /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
+    has_pending_credit_request: bool,
 }
 
 impl ConnectionInfo {
@@ -53,6 +59,44 @@ impl ConnectionInfo {
     }
 }
 
+/// An event received from a VirtIO socket device.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct VsockEvent {
+    /// The source of the event, i.e. the peer who sent it.
+    pub source: VsockAddr,
+    /// The destination of the event, i.e. the CID and port on our side.
+    pub destination: VsockAddr,
+    /// The type of event.
+    pub event_type: VsockEventType,
+}
+
+/// The reason why a vsock connection was closed.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum DisconnectReason {
+    /// The peer has either closed the connection in response to our shutdown request, or forcibly
+    /// closed it of its own accord.
+    Reset,
+    /// The peer asked to shut down the connection.
+    Shutdown,
+}
+
+/// Details of the type of an event received from a VirtIO socket.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum VsockEventType {
+    /// The connection was successfully established.
+    Connected,
+    /// The connection was closed.
+    Disconnected {
+        /// The reason for the disconnection.
+        reason: DisconnectReason,
+    },
+    /// Data was received on the connection.
+    Received {
+        /// The length of the data in bytes.
+        length: usize,
+    },
+}
+
 /// Driver for a VirtIO socket device.
 pub struct VirtIOSocket<H: Hal, T: Transport> {
     transport: T,
@@ -156,61 +200,65 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(())
     }
 
-    /// Connects to the destination.
+    /// Sends a request to connect to the given destination.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+    /// before sending data.
     pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
         if self.connection_info.is_some() {
             return Err(SocketError::ConnectionExists.into());
         }
+        let new_connection_info = ConnectionInfo {
+            dst: VsockAddr {
+                cid: dst_cid,
+                port: dst_port,
+            },
+            src_port,
+            ..Default::default()
+        };
         let header = VirtioVsockHdr {
-            src_cid: self.guest_cid.into(),
-            dst_cid: dst_cid.into(),
-            src_port: src_port.into(),
-            dst_port: dst_port.into(),
             op: VirtioVsockOp::Request.into(),
-            ..Default::default()
+            ..new_connection_info.new_header(self.guest_cid)
         };
         // Sends a header only packet to the tx queue to connect the device to the listening
         // socket at the given destination.
         self.send_packet_to_tx_queue(&header, &[])?;
 
-        let dst = VsockAddr {
-            cid: dst_cid,
-            port: dst_port,
-        };
-        self.connection_info.replace(ConnectionInfo {
-            dst,
-            src_port,
-            ..Default::default()
-        });
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Response], &mut [], |header| {
-            header.check_data_is_empty().map_err(|e| e.into())
-        })?;
-        debug!("Connection established: {:?}", self.connection_info);
+        self.connection_info = Some(new_connection_info);
+        debug!("Connection requested: {:?}", self.connection_info);
         Ok(())
     }
 
-    /// Requests the credit and updates the peer credit in the current connection info.
+    /// Blocks until the peer either accepts our connection request (with a
+    /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
+    /// `VIRTIO_VSOCK_OP_RST`).
+    pub fn wait_for_connect(&mut self) -> Result {
+        match self.wait_for_recv(&mut [])?.event_type {
+            VsockEventType::Connected => Ok(()),
+            VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
+            VsockEventType::Received { .. } => Err(SocketError::InvalidOperation.into()),
+        }
+    }
+
+    /// Requests the peer to send us a credit update for the current connection.
     fn request_credit(&mut self) -> Result {
         let connection_info = self.connection_info()?;
         let header = VirtioVsockHdr {
-            src_cid: self.guest_cid.into(),
-            dst_cid: connection_info.dst.cid.into(),
-            src_port: connection_info.src_port.into(),
-            dst_port: connection_info.dst.port.into(),
             op: VirtioVsockOp::CreditRequest.into(),
-            ..Default::default()
+            ..connection_info.new_header(self.guest_cid)
         };
-        self.send_packet_to_tx_queue(&header, &[])?;
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::CreditUpdate], &mut [], |_| {
-            Ok(())
-        })
+        self.send_packet_to_tx_queue(&header, &[])
     }
 
     /// Sends the buffer to the destination.
     pub fn send(&mut self, buffer: &[u8]) -> Result {
-        self.check_peer_buffer_is_sufficient(buffer.len())?;
+        let mut connection_info = self.connection_info()?;
+
+        let result = self.check_peer_buffer_is_sufficient(&mut connection_info, buffer.len());
+        self.connection_info = Some(connection_info.clone());
+        result?;
 
-        let connection_info = self.connection_info()?;
         let len = buffer.len() as u32;
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Rw.into(),
@@ -222,32 +270,32 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         self.send_packet_to_tx_queue(&header, buffer)
     }
 
-    fn check_peer_buffer_is_sufficient(&mut self, buffer_len: usize) -> Result {
-        if usize::try_from(self.connection_info()?.peer_free())
-            .map_err(|_| SocketError::InvalidNumber)?
-            >= buffer_len
-        {
+    fn check_peer_buffer_is_sufficient(
+        &mut self,
+        connection_info: &mut ConnectionInfo,
+        buffer_len: usize,
+    ) -> Result {
+        if connection_info.peer_free() as usize >= buffer_len {
             Ok(())
         } else {
-            // Update cached peer credit and try again.
-            self.request_credit()?;
-            if usize::try_from(self.connection_info()?.peer_free())
-                .map_err(|_| SocketError::InvalidNumber)?
-                >= buffer_len
-            {
-                Ok(())
-            } else {
-                Err(SocketError::InsufficientBufferSpaceInPeer.into())
+            // Request an update of the cached peer credit, if we haven't already done so, and tell
+            // the caller to try again later.
+            if !connection_info.has_pending_credit_request {
+                self.request_credit()?;
+                connection_info.has_pending_credit_request = true;
             }
+            Err(SocketError::InsufficientBufferSpaceInPeer.into())
         }
     }
 
-    /// Receives the buffer from the destination.
-    /// Returns the actual size of the message.
-    pub fn recv(&mut self, buffer: &mut [u8]) -> Result<usize> {
+    /// Polls the vsock device to receive data or other updates.
+    ///
+    /// A buffer must be provided to put the data in if there is some to
+    /// receive.
+    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
         let connection_info = self.connection_info()?;
 
-        // Tell the peer that we have space to recieve some data.
+        // Tell the peer that we have space to receive some data.
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::CreditUpdate.into(),
             buf_alloc: (buffer.len() as u32).into(),
@@ -255,34 +303,57 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         };
         self.send_packet_to_tx_queue(&header, &[])?;
 
-        // Wait to receive some data.
-        let mut len: u32 = 0;
-        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Rw], buffer, |header| {
-            len = header.len();
-            Ok(())
-        })?;
-        self.connection_info.as_mut().unwrap().fwd_cnt += len;
-        Ok(len as usize)
+        // Handle entries from the RX virtqueue until we find one that generates an event.
+        let event = self.poll_rx_queue(buffer)?;
+
+        if self.rx.should_notify() {
+            self.transport.notify(RX_QUEUE_IDX);
+        }
+
+        Ok(event)
+    }
+
+    /// Blocks until we get some event from the vsock device.
+    ///
+    /// A buffer must be provided to put the data in if there is some to
+    /// receive.
+    pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
+        loop {
+            if let Some(event) = self.poll_recv(buffer)? {
+                return Ok(event);
+            } else {
+                spin_loop();
+            }
+        }
     }
 
-    /// Shuts down the connection.
+    /// Request to shut down the connection cleanly.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+    /// shutdown.
     pub fn shutdown(&mut self) -> Result {
         let connection_info = self.connection_info()?;
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Shutdown.into(),
             ..connection_info.new_header(self.guest_cid)
         };
+        self.send_packet_to_tx_queue(&header, &[])
+    }
+
+    /// Forcibly closes the connection without waiting for the peer.
+    pub fn force_close(&mut self) -> Result {
+        let connection_info = self.connection_info()?;
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Rst.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
         self.send_packet_to_tx_queue(&header, &[])?;
-        self.poll_and_filter_packet_from_rx_queue(
-            &[VirtioVsockOp::Rst, VirtioVsockOp::Shutdown],
-            &mut [],
-            |_| Ok(()),
-        )?;
+        self.connection_info = None;
         Ok(())
     }
 
     fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
-        // TODO: Virtio v1.1 5.10.6.1.1 The rx virtqueue MUST be processed even when the tx virtqueue is full.
         let _len = self.tx.add_notify_wait_pop(
             &[header.as_bytes(), buffer],
             &mut [],
@@ -291,26 +362,23 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(())
     }
 
-    fn poll_and_filter_packet_from_rx_queue<F>(
-        &mut self,
-        accepted_ops: &[VirtioVsockOp],
-        body: &mut [u8],
-        f: F,
-    ) -> Result
-    where
-        F: FnOnce(&VirtioVsockHdr) -> Result,
-    {
-        let our_cid = self.guest_cid;
-        let mut result = Ok(());
+    /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet
+    /// which generates a `VsockEvent`.
+    ///
+    /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
+    /// don't result in any events to return.
+    fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
         loop {
-            self.wait_one_in_rx_queue();
             let mut connection_info = self.connection_info.clone().unwrap_or_default();
-            let header = self.pop_packet_from_rx_queue(body)?;
+            let Some(header) = self.pop_packet_from_rx_queue(body)? else{
+                return Ok(None);
+            };
+
             let op = header.op()?;
 
             // Skip packets which don't match our current connection.
             if header.source() != connection_info.dst
-                || header.dst_cid.get() != our_cid
+                || header.dst_cid.get() != self.guest_cid
                 || header.dst_port.get() != connection_info.src_port
             {
                 debug!(
@@ -320,65 +388,82 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
                 continue;
             }
 
+            connection_info.peer_buf_alloc = header.buf_alloc.into();
+            connection_info.peer_fwd_cnt = header.fwd_cnt.into();
+            if self.connection_info.is_some() {
+                self.connection_info = Some(connection_info.clone());
+                debug!("Connection info updated: {:?}", self.connection_info);
+            }
+
             match op {
+                VirtioVsockOp::Request => {
+                    header.check_data_is_empty()?;
+                    // TODO: Send a Rst, or support listening.
+                }
+                VirtioVsockOp::Response => {
+                    header.check_data_is_empty()?;
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Connected,
+                    }));
+                }
                 VirtioVsockOp::CreditUpdate => {
                     header.check_data_is_empty()?;
-
-                    connection_info.peer_buf_alloc = header.buf_alloc.into();
-                    connection_info.peer_fwd_cnt = header.fwd_cnt.into();
-                    self.connection_info.replace(connection_info);
-                    debug!("Connection info updated: {:?}", self.connection_info);
-
-                    if accepted_ops.contains(&op) {
-                        break;
-                    } else {
-                        // Virtio v1.1 5.10.6.3
-                        // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
-                        // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
-                        // any time a change in buffer space occurs.
-                        continue;
+                    connection_info.has_pending_credit_request = false;
+                    if self.connection_info.is_some() {
+                        self.connection_info = Some(connection_info.clone());
                     }
+
+                    // Virtio v1.1 5.10.6.3
+                    // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
+                    // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
+                    // any time a change in buffer space occurs.
+                    continue;
                 }
                 VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
                     header.check_data_is_empty()?;
 
-                    self.connection_info.take();
+                    self.connection_info = None;
                     info!("Disconnected from the peer");
-                    if accepted_ops.contains(&op) {
-                    } else if op == VirtioVsockOp::Rst {
-                        result = Err(SocketError::ConnectionFailed.into());
+
+                    let reason = if op == VirtioVsockOp::Rst {
+                        DisconnectReason::Reset
                     } else {
-                        assert_eq!(VirtioVsockOp::Shutdown, op);
-                        result = Err(SocketError::PeerSocketShutdown.into());
-                    }
-                    break;
+                        DisconnectReason::Shutdown
+                    };
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Disconnected { reason },
+                    }));
+                }
+                VirtioVsockOp::Rw => {
+                    self.connection_info.as_mut().unwrap().fwd_cnt += header.len();
+                    return Ok(Some(VsockEvent {
+                        source: connection_info.dst,
+                        destination: VsockAddr {
+                            cid: self.guest_cid,
+                            port: connection_info.src_port,
+                        },
+                        event_type: VsockEventType::Received {
+                            length: header.len() as usize,
+                        },
+                    }));
                 }
-                // TODO: Update peer_buf_alloc and peer_fwd_cnt for other packets too.
-                x if accepted_ops.contains(&x) => {
-                    f(&header)?;
-                    break;
+                VirtioVsockOp::CreditRequest => {
+                    header.check_data_is_empty()?;
+                    // TODO: Send a credit update.
                 }
-                _ => {
-                    result = Err(SocketError::InvalidOperation.into());
-                    break;
+                VirtioVsockOp::Invalid => {
+                    return Err(SocketError::InvalidOperation.into());
                 }
-            };
-        }
-
-        if self.rx.should_notify() {
-            self.transport.notify(RX_QUEUE_IDX);
-        }
-        result
-    }
-
-    /// Waits until there is at least one element to pop in rx queue.
-    fn wait_one_in_rx_queue(&mut self) {
-        const TIMEOUT_N: usize = 1000000;
-        for _ in 0..TIMEOUT_N {
-            if self.rx.can_pop() {
-                break;
-            } else {
-                core::hint::spin_loop();
             }
         }
     }
@@ -386,11 +471,11 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
     /// the body into the given buffer.
     ///
-    /// Returns an error if there is no pending packet, or the body is bigger than the buffer
-    /// supplied.
-    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<VirtioVsockHdr> {
+    /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
+    /// buffer supplied.
+    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
         let Some(token) = self.rx.peek_used() else {
-            return Err(SocketError::NoResponseReceived.into());
+            return Ok(None);
         };
 
         // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
@@ -414,7 +499,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         }?;
 
         debug!("Received packet {:?}. Op {:?}", header, header.op());
-        Ok(header)
+        Ok(Some(header))
     }
 
     fn connection_info(&self) -> Result<ConnectionInfo> {