浏览代码

Separate connection tracking from low level vsock driver.

Andrew Walbran 1 年之前
父节点
当前提交
1db89bd7de
共有 4 个文件被更改,包括 591 次插入421 次删除
  1. 4 3
      examples/aarch64/src/main.rs
  2. 4 0
      src/device/socket/mod.rs
  3. 434 0
      src/device/socket/singleconnectionmanager.rs
  4. 149 418
      src/device/socket/vsock.rs

+ 4 - 3
examples/aarch64/src/main.rs

@@ -30,7 +30,7 @@ use virtio_drivers::{
         blk::VirtIOBlk,
         console::VirtIOConsole,
         gpu::VirtIOGpu,
-        socket::{VirtIOSocket, VsockAddr, VsockEventType},
+        socket::{SingleConnectionManager, VirtIOSocket, VsockAddr, VsockEventType},
     },
     transport::{
         mmio::{MmioTransport, VirtIOHeader},
@@ -204,8 +204,9 @@ fn virtio_console<T: Transport>(transport: T) {
 }
 
 fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
-    let mut socket =
-        VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver");
+    let mut socket = SingleConnectionManager::new(
+        VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver"),
+    );
     let host_cid = 2;
     let port = 1221;
     info!("Connecting to host on port {port}...");

+ 4 - 0
src/device/socket/mod.rs

@@ -3,9 +3,13 @@
 mod error;
 mod protocol;
 #[cfg(feature = "alloc")]
+mod singleconnectionmanager;
+#[cfg(feature = "alloc")]
 mod vsock;
 
 pub use error::SocketError;
 pub use protocol::VsockAddr;
 #[cfg(feature = "alloc")]
+pub use singleconnectionmanager::SingleConnectionManager;
+#[cfg(feature = "alloc")]
 pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};

+ 434 - 0
src/device/socket/singleconnectionmanager.rs

@@ -0,0 +1,434 @@
+use super::{
+    protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
+    VsockEventType,
+};
+use crate::{transport::Transport, Hal, Result};
+use core::hint::spin_loop;
+use log::debug;
+
+/// A higher level interface for vsock devices.
+///
+/// This keeps track of a single vsock connection.
+pub struct SingleConnectionManager<H: Hal, T: Transport> {
+    driver: VirtIOSocket<H, T>,
+    connection_info: Option<ConnectionInfo>,
+}
+
+impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
+    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
+    pub fn new(driver: VirtIOSocket<H, T>) -> Self {
+        Self {
+            driver,
+            connection_info: None,
+        }
+    }
+
+    /// Returns the CID which has been assigned to this guest.
+    pub fn guest_cid(&self) -> u64 {
+        self.driver.guest_cid()
+    }
+
+    /// Sends a request to connect to the given destination.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+    /// before sending data.
+    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+        if self.connection_info.is_some() {
+            return Err(SocketError::ConnectionExists.into());
+        }
+
+        let new_connection_info = ConnectionInfo::new(destination, src_port);
+
+        self.driver.connect(destination, src_port)?;
+        debug!("Connection requested: {:?}", new_connection_info);
+        self.connection_info = Some(new_connection_info);
+        Ok(())
+    }
+
+    /// Sends the buffer to the destination.
+    pub fn send(&mut self, buffer: &[u8]) -> Result {
+        self.driver.send(
+            buffer,
+            self.connection_info
+                .as_mut()
+                .ok_or(SocketError::NotConnected)?,
+        )
+    }
+
+    /// Polls the vsock device to receive data or other updates.
+    ///
+    /// A buffer must be provided to put the data in if there is some to
+    /// receive.
+    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
+        let Some(connection_info) = &self.connection_info else {
+            return Err(SocketError::NotConnected.into());
+        };
+
+        // Tell the peer that we have space to receive some data.
+        self.driver
+            .credit_update(connection_info, buffer.len() as u32)?;
+
+        self.poll_rx_queue(buffer)
+    }
+
+    /// Blocks until we get some event from the vsock device.
+    ///
+    /// A buffer must be provided to put the data in if there is some to
+    /// receive.
+    pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
+        loop {
+            if let Some(event) = self.poll_recv(buffer)? {
+                return Ok(event);
+            } else {
+                spin_loop();
+            }
+        }
+    }
+
+    fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
+        loop {
+            let Some(event) = self.driver.poll_recv(body)? else {
+                return Ok(None)
+            };
+
+            let Some(connection_info) = &mut self.connection_info else {
+                continue;
+            };
+
+            // Skip packets which don't match our current connection.
+            if !event.matches_connection(connection_info, self.driver.guest_cid()) {
+                debug!(
+                    "Skipping {:?} as connection is {:?}",
+                    event, connection_info
+                );
+                continue;
+            }
+
+            // Update stored connection info.
+            connection_info.update_for_event(&event);
+
+            match event.event_type {
+                VsockEventType::Connected => {}
+                VsockEventType::Disconnected { .. } => {
+                    self.connection_info = None;
+                }
+                VsockEventType::Received { length } => {
+                    connection_info.done_forwarding(length);
+                }
+                VsockEventType::CreditRequest => {
+                    // TODO: Send a credit update.
+                }
+                VsockEventType::CreditUpdate => {}
+            }
+
+            return Ok(Some(event));
+        }
+    }
+
+    /// Requests to shut down the connection cleanly.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+    /// shutdown.
+    pub fn shutdown(&mut self) -> Result {
+        self.driver.shutdown(
+            self.connection_info
+                .as_ref()
+                .ok_or(SocketError::NotConnected)?,
+        )
+    }
+
+    /// Forcibly closes the connection without waiting for the peer.
+    pub fn force_close(&mut self) -> Result {
+        self.driver.force_close(
+            self.connection_info
+                .as_ref()
+                .ok_or(SocketError::NotConnected)?,
+        )?;
+        self.connection_info = None;
+        Ok(())
+    }
+
+    /// Blocks until the peer either accepts our connection request (with a
+    /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
+    /// `VIRTIO_VSOCK_OP_RST`).
+    pub fn wait_for_connect(&mut self) -> Result {
+        loop {
+            match self.wait_for_recv(&mut [])?.event_type {
+                VsockEventType::Connected => return Ok(()),
+                VsockEventType::Disconnected { .. } => {
+                    return Err(SocketError::ConnectionFailed.into())
+                }
+                VsockEventType::Received { .. } => return Err(SocketError::InvalidOperation.into()),
+                VsockEventType::CreditRequest | VsockEventType::CreditUpdate => {}
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        device::socket::{
+            protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
+            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
+        },
+        hal::fake::FakeHal,
+        transport::{
+            fake::{FakeTransport, QueueStatus, State},
+            DeviceStatus, DeviceType,
+        },
+        volatile::ReadOnly,
+    };
+    use alloc::{sync::Arc, vec};
+    use core::{mem::size_of, ptr::NonNull};
+    use std::{sync::Mutex, thread};
+    use zerocopy::{AsBytes, FromBytes};
+
+    #[test]
+    fn send_recv() {
+        let host_cid = 2;
+        let guest_cid = 66;
+        let host_port = 1234;
+        let guest_port = 4321;
+        let host_address = VsockAddr {
+            cid: host_cid,
+            port: host_port,
+        };
+        let hello_from_guest = "Hello from guest";
+        let hello_from_host = "Hello from host";
+
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut socket = SingleConnectionManager::new(
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
+        );
+
+        // Start a thread to simulate the device.
+        let handle = thread::spawn(move || {
+            // Wait for connection request.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Accept connection and give the peer enough credit to send the message.
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Response.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect a credit update.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::CreditUpdate.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Expect the guest to send some data.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            let request = state
+                .lock()
+                .unwrap()
+                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
+            assert_eq!(
+                request.len(),
+                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
+            );
+            assert_eq!(
+                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Rw.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: (hello_from_guest.len() as u32).into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+            assert_eq!(
+                &request[size_of::<VirtioVsockHdr>()..],
+                hello_from_guest.as_bytes()
+            );
+
+            // Send a response.
+            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
+            VirtioVsockHdr {
+                op: VirtioVsockOp::Rw.into(),
+                src_cid: host_cid.into(),
+                dst_cid: guest_cid.into(),
+                src_port: host_port.into(),
+                dst_port: guest_port.into(),
+                len: (hello_from_host.len() as u32).into(),
+                socket_type: SocketType::Stream.into(),
+                flags: 0.into(),
+                buf_alloc: 50.into(),
+                fwd_cnt: (hello_from_guest.len() as u32).into(),
+            }
+            .write_to_prefix(response.as_mut_slice());
+            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
+            state
+                .lock()
+                .unwrap()
+                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
+
+            // Expect a credit update.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::CreditUpdate.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 64.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Expect a shutdown.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Shutdown.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: (hello_from_host.len() as u32).into(),
+                }
+            );
+        });
+
+        socket.connect(host_address, guest_port).unwrap();
+        socket.wait_for_connect().unwrap();
+        socket.send(hello_from_guest.as_bytes()).unwrap();
+        let mut buffer = [0u8; 64];
+        let event = socket.wait_for_recv(&mut buffer).unwrap();
+        assert_eq!(
+            event,
+            VsockEvent {
+                source: VsockAddr {
+                    cid: host_cid,
+                    port: host_port,
+                },
+                destination: VsockAddr {
+                    cid: guest_cid,
+                    port: guest_port,
+                },
+                event_type: VsockEventType::Received {
+                    length: hello_from_host.len()
+                },
+                buffer_status: VsockBufferStatus {
+                    buffer_allocation: 50,
+                    forward_count: hello_from_guest.len() as u32,
+                },
+            }
+        );
+        assert_eq!(
+            &buffer[0..hello_from_host.len()],
+            hello_from_host.as_bytes()
+        );
+        socket.shutdown().unwrap();
+
+        handle.join().unwrap();
+    }
+}

+ 149 - 418
src/device/socket/vsock.rs

@@ -9,25 +9,24 @@ use crate::transport::Transport;
 use crate::volatile::volread;
 use crate::Result;
 use alloc::boxed::Box;
-use core::hint::spin_loop;
 use core::mem::size_of;
 use core::ptr::{null_mut, NonNull};
 use log::{debug, info};
 use zerocopy::{AsBytes, FromBytes};
 
-const RX_QUEUE_IDX: u16 = 0;
-const TX_QUEUE_IDX: u16 = 1;
+pub(crate) const RX_QUEUE_IDX: u16 = 0;
+pub(crate) const TX_QUEUE_IDX: u16 = 1;
 const EVENT_QUEUE_IDX: u16 = 2;
 
-const QUEUE_SIZE: usize = 8;
+pub(crate) const QUEUE_SIZE: usize = 8;
 
 /// The size in bytes of each buffer used in the RX virtqueue.
 const RX_BUFFER_SIZE: usize = 512;
 
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
-struct ConnectionInfo {
-    dst: VsockAddr,
-    src_port: u32,
+pub struct ConnectionInfo {
+    pub dst: VsockAddr,
+    pub src_port: u32,
     /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
     /// bytes it has allocated for packet bodies.
     peer_buf_alloc: u32,
@@ -46,6 +45,33 @@ struct ConnectionInfo {
 }
 
 impl ConnectionInfo {
+    pub fn new(destination: VsockAddr, src_port: u32) -> Self {
+        Self {
+            dst: destination,
+            src_port,
+            ..Default::default()
+        }
+    }
+
+    /// Updates this connection info with the peer buffer allocation and forwarded count from the
+    /// given event.
+    pub fn update_for_event(&mut self, event: &VsockEvent) {
+        self.peer_buf_alloc = event.buffer_status.buffer_allocation;
+        self.peer_fwd_cnt = event.buffer_status.forward_count;
+
+        if let VsockEventType::CreditUpdate = event.event_type {
+            self.has_pending_credit_request = false;
+        }
+    }
+
+    /// Increases the forwarded count recorded for this connection by the given number of bytes.
+    ///
+    /// This should be called once received data has been passed to the client, so there is buffer
+    /// space available for more.
+    pub fn done_forwarding(&mut self, length: usize) {
+        self.fwd_cnt += length as u32;
+    }
+
     fn peer_free(&self) -> u32 {
         self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
     }
@@ -69,10 +95,27 @@ pub struct VsockEvent {
     pub source: VsockAddr,
     /// The destination of the event, i.e. the CID and port on our side.
     pub destination: VsockAddr,
+    /// The peer's buffer status for the connection.
+    pub buffer_status: VsockBufferStatus,
     /// The type of event.
     pub event_type: VsockEventType,
 }
 
+impl VsockEvent {
+    /// Returns whether the event matches the given connection.
+    pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
+        self.source == connection_info.dst
+            && self.destination.cid == guest_cid
+            && self.destination.port == connection_info.src_port
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct VsockBufferStatus {
+    pub buffer_allocation: u32,
+    pub forward_count: u32,
+}
+
 /// The reason why a vsock connection was closed.
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum DisconnectReason {
@@ -98,6 +141,10 @@ pub enum VsockEventType {
         /// The length of the data in bytes.
         length: usize,
     },
+    /// The peer requests us to send a credit update.
+    CreditRequest,
+    /// The peer just sent us a credit update with nothing else.
+    CreditUpdate,
 }
 
 /// Driver for a VirtIO socket device.
@@ -112,9 +159,6 @@ pub struct VirtIOSocket<H: Hal, T: Transport> {
     /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
     guest_cid: u64,
     rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
-
-    /// Currently the device is only allowed to be connected to one destination at a time.
-    connection_info: Option<ConnectionInfo>,
 }
 
 impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
@@ -180,7 +224,6 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             event,
             guest_cid,
             rx_queue_buffers,
-            connection_info: None,
         })
     }
 
@@ -195,41 +238,23 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
     /// before sending data.
     pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
-        if self.connection_info.is_some() {
-            return Err(SocketError::ConnectionExists.into());
-        }
-        let new_connection_info = ConnectionInfo {
-            dst: destination,
-            src_port,
-            ..Default::default()
-        };
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Request.into(),
-            ..new_connection_info.new_header(self.guest_cid)
+            src_cid: self.guest_cid.into(),
+            dst_cid: destination.cid.into(),
+            src_port: src_port.into(),
+            dst_port: destination.port.into(),
+            ..Default::default()
         };
         // Sends a header only packet to the tx queue to connect the device to the listening
         // socket at the given destination.
         self.send_packet_to_tx_queue(&header, &[])?;
 
-        self.connection_info = Some(new_connection_info);
-        debug!("Connection requested: {:?}", self.connection_info);
         Ok(())
     }
 
-    /// Blocks until the peer either accepts our connection request (with a
-    /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
-    /// `VIRTIO_VSOCK_OP_RST`).
-    pub fn wait_for_connect(&mut self) -> Result {
-        match self.wait_for_recv(&mut [])?.event_type {
-            VsockEventType::Connected => Ok(()),
-            VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
-            VsockEventType::Received { .. } => Err(SocketError::InvalidOperation.into()),
-        }
-    }
-
-    /// Requests the peer to send us a credit update for the current connection.
-    fn request_credit(&mut self) -> Result {
-        let connection_info = self.connection_info()?;
+    /// Requests the peer to send us a credit update for the given connection.
+    fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::CreditRequest.into(),
             ..connection_info.new_header(self.guest_cid)
@@ -238,12 +263,8 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     }
 
     /// Sends the buffer to the destination.
-    pub fn send(&mut self, buffer: &[u8]) -> Result {
-        let mut connection_info = self.connection_info()?;
-
-        let result = self.check_peer_buffer_is_sufficient(&mut connection_info, buffer.len());
-        self.connection_info = Some(connection_info.clone());
-        result?;
+    pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
+        self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
 
         let len = buffer.len() as u32;
         let header = VirtioVsockHdr {
@@ -252,7 +273,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             buf_alloc: 0.into(),
             ..connection_info.new_header(self.guest_cid)
         };
-        self.connection_info.as_mut().unwrap().tx_cnt += len;
+        connection_info.tx_cnt += len;
         self.send_packet_to_tx_queue(&header, buffer)
     }
 
@@ -267,28 +288,28 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             // Request an update of the cached peer credit, if we haven't already done so, and tell
             // the caller to try again later.
             if !connection_info.has_pending_credit_request {
-                self.request_credit()?;
+                self.request_credit(connection_info)?;
                 connection_info.has_pending_credit_request = true;
             }
             Err(SocketError::InsufficientBufferSpaceInPeer.into())
         }
     }
 
-    /// Polls the vsock device to receive data or other updates.
-    ///
-    /// A buffer must be provided to put the data in if there is some to
-    /// receive.
-    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
-        let connection_info = self.connection_info()?;
-
-        // Tell the peer that we have space to receive some data.
+    /// Tells the peer how much buffer space we have to receive data.
+    pub fn credit_update(&mut self, connection_info: &ConnectionInfo, buffer_size: u32) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::CreditUpdate.into(),
-            buf_alloc: (buffer.len() as u32).into(),
+            buf_alloc: buffer_size.into(),
             ..connection_info.new_header(self.guest_cid)
         };
-        self.send_packet_to_tx_queue(&header, &[])?;
+        self.send_packet_to_tx_queue(&header, &[])
+    }
 
+    /// Polls the vsock device to receive data or other updates.
+    ///
+    /// A buffer must be provided to put the data in if there is some to
+    /// receive.
+    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
         // Handle entries from the RX virtqueue until we find one that generates an event.
         let event = self.poll_rx_queue(buffer)?;
 
@@ -299,27 +320,12 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(event)
     }
 
-    /// Blocks until we get some event from the vsock device.
-    ///
-    /// A buffer must be provided to put the data in if there is some to
-    /// receive.
-    pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
-        loop {
-            if let Some(event) = self.poll_recv(buffer)? {
-                return Ok(event);
-            } else {
-                spin_loop();
-            }
-        }
-    }
-
-    /// Request to shut down the connection cleanly.
+    /// Requests to shut down the connection cleanly.
     ///
     /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
     /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
     /// shutdown.
-    pub fn shutdown(&mut self) -> Result {
-        let connection_info = self.connection_info()?;
+    pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Shutdown.into(),
             ..connection_info.new_header(self.guest_cid)
@@ -328,14 +334,12 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     }
 
     /// Forcibly closes the connection without waiting for the peer.
-    pub fn force_close(&mut self) -> Result {
-        let connection_info = self.connection_info()?;
+    pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Rst.into(),
             ..connection_info.new_header(self.guest_cid)
         };
         self.send_packet_to_tx_queue(&header, &[])?;
-        self.connection_info = None;
         Ok(())
     }
 
@@ -348,109 +352,83 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(())
     }
 
-    /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet
-    /// which generates a `VsockEvent`.
+    /// Polls the RX virtqueue for the next event.
     ///
     /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
     /// don't result in any events to return.
     fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
-        loop {
-            let mut connection_info = self.connection_info.clone().unwrap_or_default();
-            let Some(header) = self.pop_packet_from_rx_queue(body)? else{
-                return Ok(None);
-            };
-
-            let op = header.op()?;
-
-            // Skip packets which don't match our current connection.
-            if header.source() != connection_info.dst
-                || header.dst_cid.get() != self.guest_cid
-                || header.dst_port.get() != connection_info.src_port
-            {
-                debug!(
-                    "Skipping {:?} as connection is {:?}",
-                    header, connection_info
-                );
-                continue;
-            }
+        let Some(header) = self.pop_packet_from_rx_queue(body)? else {
+            return Ok(None);
+        };
 
-            connection_info.peer_buf_alloc = header.buf_alloc.into();
-            connection_info.peer_fwd_cnt = header.fwd_cnt.into();
-            if self.connection_info.is_some() {
-                self.connection_info = Some(connection_info.clone());
-                debug!("Connection info updated: {:?}", self.connection_info);
-            }
+        let op = header.op()?;
 
-            match op {
-                VirtioVsockOp::Request => {
-                    header.check_data_is_empty()?;
-                    // TODO: Send a Rst, or support listening.
-                }
-                VirtioVsockOp::Response => {
-                    header.check_data_is_empty()?;
-                    return Ok(Some(VsockEvent {
-                        source: connection_info.dst,
-                        destination: VsockAddr {
-                            cid: self.guest_cid,
-                            port: connection_info.src_port,
-                        },
-                        event_type: VsockEventType::Connected,
-                    }));
-                }
-                VirtioVsockOp::CreditUpdate => {
-                    header.check_data_is_empty()?;
-                    connection_info.has_pending_credit_request = false;
-                    if self.connection_info.is_some() {
-                        self.connection_info = Some(connection_info.clone());
-                    }
-
-                    // Virtio v1.1 5.10.6.3
-                    // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
-                    // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
-                    // any time a change in buffer space occurs.
-                    continue;
-                }
-                VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
-                    header.check_data_is_empty()?;
-
-                    self.connection_info = None;
-                    info!("Disconnected from the peer");
-
-                    let reason = if op == VirtioVsockOp::Rst {
-                        DisconnectReason::Reset
-                    } else {
-                        DisconnectReason::Shutdown
-                    };
-                    return Ok(Some(VsockEvent {
-                        source: connection_info.dst,
-                        destination: VsockAddr {
-                            cid: self.guest_cid,
-                            port: connection_info.src_port,
-                        },
-                        event_type: VsockEventType::Disconnected { reason },
-                    }));
-                }
-                VirtioVsockOp::Rw => {
-                    self.connection_info.as_mut().unwrap().fwd_cnt += header.len();
-                    return Ok(Some(VsockEvent {
-                        source: connection_info.dst,
-                        destination: VsockAddr {
-                            cid: self.guest_cid,
-                            port: connection_info.src_port,
-                        },
-                        event_type: VsockEventType::Received {
-                            length: header.len() as usize,
-                        },
-                    }));
-                }
-                VirtioVsockOp::CreditRequest => {
-                    header.check_data_is_empty()?;
-                    // TODO: Send a credit update.
-                }
-                VirtioVsockOp::Invalid => {
-                    return Err(SocketError::InvalidOperation.into());
-                }
+        let buffer_status = VsockBufferStatus {
+            buffer_allocation: header.buf_alloc.into(),
+            forward_count: header.fwd_cnt.into(),
+        };
+        let source = header.source();
+        let destination = header.destination();
+
+        match op {
+            VirtioVsockOp::Request => {
+                header.check_data_is_empty()?;
+                // TODO: Send a Rst, or support listening.
+                Ok(None)
+            }
+            VirtioVsockOp::Response => {
+                header.check_data_is_empty()?;
+                Ok(Some(VsockEvent {
+                    source,
+                    destination,
+                    buffer_status,
+                    event_type: VsockEventType::Connected,
+                }))
+            }
+            VirtioVsockOp::CreditUpdate => {
+                header.check_data_is_empty()?;
+                Ok(Some(VsockEvent {
+                    source,
+                    destination,
+                    buffer_status,
+                    event_type: VsockEventType::CreditUpdate,
+                }))
             }
+            VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
+                header.check_data_is_empty()?;
+
+                info!("Disconnected from the peer");
+
+                let reason = if op == VirtioVsockOp::Rst {
+                    DisconnectReason::Reset
+                } else {
+                    DisconnectReason::Shutdown
+                };
+                Ok(Some(VsockEvent {
+                    source,
+                    destination,
+                    buffer_status,
+                    event_type: VsockEventType::Disconnected { reason },
+                }))
+            }
+            VirtioVsockOp::Rw => Ok(Some(VsockEvent {
+                source,
+                destination,
+                buffer_status,
+                event_type: VsockEventType::Received {
+                    length: header.len() as usize,
+                },
+            })),
+            VirtioVsockOp::CreditRequest => {
+                header.check_data_is_empty()?;
+                Ok(Some(VsockEvent {
+                    source,
+                    destination,
+                    buffer_status,
+                    event_type: VsockEventType::CreditRequest,
+                }))
+            }
+            VirtioVsockOp::Invalid => Err(SocketError::InvalidOperation.into()),
         }
     }
 
@@ -487,12 +465,6 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         debug!("Received packet {:?}. Op {:?}", header, header.op());
         Ok(Some(header))
     }
-
-    fn connection_info(&self) -> Result<ConnectionInfo> {
-        self.connection_info
-            .clone()
-            .ok_or(SocketError::NotConnected.into())
-    }
 }
 
 fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
@@ -514,7 +486,6 @@ fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr
 mod tests {
     use super::*;
     use crate::{
-        device::socket::protocol::SocketType,
         hal::fake::FakeHal,
         transport::{
             fake::{FakeTransport, QueueStatus, State},
@@ -524,7 +495,7 @@ mod tests {
     };
     use alloc::{sync::Arc, vec};
     use core::ptr::NonNull;
-    use std::{sync::Mutex, thread};
+    use std::sync::Mutex;
 
     #[test]
     fn config() {
@@ -554,244 +525,4 @@ mod tests {
             VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
         assert_eq!(socket.guest_cid(), 0x00_0000_0042);
     }
-
-    #[test]
-    fn send_recv() {
-        let host_cid = 2;
-        let guest_cid = 66;
-        let host_port = 1234;
-        let guest_port = 4321;
-        let host_address = VsockAddr {
-            cid: host_cid,
-            port: host_port,
-        };
-        let hello_from_guest = "Hello from guest";
-        let hello_from_host = "Hello from host";
-
-        let mut config_space = VirtioVsockConfig {
-            guest_cid_low: ReadOnly::new(66),
-            guest_cid_high: ReadOnly::new(0),
-        };
-        let state = Arc::new(Mutex::new(State {
-            status: DeviceStatus::empty(),
-            driver_features: 0,
-            guest_page_size: 0,
-            interrupt_pending: false,
-            queues: vec![
-                QueueStatus::default(),
-                QueueStatus::default(),
-                QueueStatus::default(),
-            ],
-        }));
-        let transport = FakeTransport {
-            device_type: DeviceType::Socket,
-            max_queue_size: 32,
-            device_features: 0,
-            config_space: NonNull::from(&mut config_space),
-            state: state.clone(),
-        };
-        let mut socket =
-            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
-
-        // Start a thread to simulate the device.
-        let handle = thread::spawn(move || {
-            // Wait for connection request.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            assert_eq!(
-                VirtioVsockHdr::read_from(
-                    state
-                        .lock()
-                        .unwrap()
-                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
-                        .as_slice()
-                )
-                .unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::Request.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 0.into(),
-                    fwd_cnt: 0.into(),
-                }
-            );
-
-            // Accept connection and give the peer enough credit to send the message.
-            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
-                RX_QUEUE_IDX,
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::Response.into(),
-                    src_cid: host_cid.into(),
-                    dst_cid: guest_cid.into(),
-                    src_port: host_port.into(),
-                    dst_port: guest_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 50.into(),
-                    fwd_cnt: 0.into(),
-                }
-                .as_bytes(),
-            );
-
-            // Expect a credit update.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            assert_eq!(
-                VirtioVsockHdr::read_from(
-                    state
-                        .lock()
-                        .unwrap()
-                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
-                        .as_slice()
-                )
-                .unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::CreditUpdate.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 0.into(),
-                    fwd_cnt: 0.into(),
-                }
-            );
-
-            // Expect the guest to send some data.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            let request = state
-                .lock()
-                .unwrap()
-                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
-            assert_eq!(
-                request.len(),
-                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
-            );
-            assert_eq!(
-                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::Rw.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: (hello_from_guest.len() as u32).into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 0.into(),
-                    fwd_cnt: 0.into(),
-                }
-            );
-            assert_eq!(
-                &request[size_of::<VirtioVsockHdr>()..],
-                hello_from_guest.as_bytes()
-            );
-
-            // Send a response.
-            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
-            VirtioVsockHdr {
-                op: VirtioVsockOp::Rw.into(),
-                src_cid: host_cid.into(),
-                dst_cid: guest_cid.into(),
-                src_port: host_port.into(),
-                dst_port: guest_port.into(),
-                len: (hello_from_host.len() as u32).into(),
-                socket_type: SocketType::Stream.into(),
-                flags: 0.into(),
-                buf_alloc: 50.into(),
-                fwd_cnt: (hello_from_guest.len() as u32).into(),
-            }
-            .write_to_prefix(response.as_mut_slice());
-            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
-            state
-                .lock()
-                .unwrap()
-                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
-
-            // Expect a credit update.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            assert_eq!(
-                VirtioVsockHdr::read_from(
-                    state
-                        .lock()
-                        .unwrap()
-                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
-                        .as_slice()
-                )
-                .unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::CreditUpdate.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 64.into(),
-                    fwd_cnt: 0.into(),
-                }
-            );
-
-            // Expect a shutdown.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            assert_eq!(
-                VirtioVsockHdr::read_from(
-                    state
-                        .lock()
-                        .unwrap()
-                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
-                        .as_slice()
-                )
-                .unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::Shutdown.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 0.into(),
-                    fwd_cnt: (hello_from_host.len() as u32).into(),
-                }
-            );
-        });
-
-        socket.connect(host_address, guest_port).unwrap();
-        socket.wait_for_connect().unwrap();
-        socket.send(hello_from_guest.as_bytes()).unwrap();
-        let mut buffer = [0u8; 64];
-        let event = socket.wait_for_recv(&mut buffer).unwrap();
-        assert_eq!(
-            event,
-            VsockEvent {
-                source: VsockAddr {
-                    cid: host_cid,
-                    port: host_port,
-                },
-                destination: VsockAddr {
-                    cid: guest_cid,
-                    port: guest_port,
-                },
-                event_type: VsockEventType::Received {
-                    length: hello_from_host.len()
-                }
-            }
-        );
-        assert_eq!(
-            &buffer[0..hello_from_host.len()],
-            hello_from_host.as_bytes()
-        );
-        socket.shutdown().unwrap();
-
-        handle.join().unwrap();
-    }
 }