Przeglądaj źródła

Merge pull request #91 from rcore-os/vsock

Support multiple and incoming connections in vsock driver
Andrew Walbran 1 rok temu
rodzic
commit
b003840720

+ 18 - 14
examples/aarch64/src/main.rs

@@ -30,7 +30,7 @@ use virtio_drivers::{
         blk::VirtIOBlk,
         console::VirtIOConsole,
         gpu::VirtIOGpu,
-        socket::{VirtIOSocket, VsockAddr, VsockEventType},
+        socket::{VirtIOSocket, VsockAddr, VsockConnectionManager, VsockEventType},
     },
     transport::{
         mmio::{MmioTransport, VirtIOHeader},
@@ -204,29 +204,33 @@ fn virtio_console<T: Transport>(transport: T) {
 }
 
 fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
-    let mut socket =
-        VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver");
+    let mut socket = VsockConnectionManager::new(
+        VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver"),
+    );
     let host_cid = 2;
     let port = 1221;
-    info!("Connecting to host on port {port}...");
-    socket.connect(
-        VsockAddr {
-            cid: host_cid,
-            port,
-        },
+    let host_address = VsockAddr {
+        cid: host_cid,
         port,
-    )?;
-    socket.wait_for_connect()?;
+    };
+    info!("Connecting to host on port {port}...");
+    socket.connect(host_address, port)?;
+    let event = socket.wait_for_event()?;
+    assert_eq!(event.source, host_address);
+    assert_eq!(event.destination.port, port);
+    assert_eq!(event.event_type, VsockEventType::Connected);
     info!("Connected to the host");
 
     const EXCHANGE_NUM: usize = 2;
     let messages = ["0-Ack. Hello from guest.", "1-Ack. Received again."];
     for k in 0..EXCHANGE_NUM {
         let mut buffer = [0u8; 24];
-        let socket_event = socket.wait_for_recv(&mut buffer)?;
+        let socket_event = socket.wait_for_event()?;
         let VsockEventType::Received {length, ..} = socket_event.event_type else {
             panic!("Received unexpected socket event {:?}", socket_event);
         };
+        let read_length = socket.recv(host_address, port, &mut buffer)?;
+        assert_eq!(length, read_length);
         info!(
             "Received message: {:?}({:?}), len: {:?}",
             buffer,
@@ -235,10 +239,10 @@ fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
         );
 
         let message = messages[k % messages.len()];
-        socket.send(message.as_bytes())?;
+        socket.send(host_address, port, message.as_bytes())?;
         info!("Sent message: {:?}", message);
     }
-    socket.shutdown()?;
+    socket.shutdown(host_address, port)?;
     info!("Shutdown the connection");
     Ok(())
 }

+ 16 - 1
src/device/socket/mod.rs

@@ -1,11 +1,26 @@
-//! This module implements the virtio vsock device.
+//! Driver for VirtIO socket devices.
+//!
+//! To use the driver, you should first create a [`VirtIOSocket`] instance with your VirtIO
+//! transport, and then create a [`VsockConnectionManager`] wrapping it to keep track of
+//! connections. If you only want to have a single outgoing vsock connection at once, you can use
+//! [`SingleConnectionManager`] for a slightly simpler interface.
+//!
+//! See [`VsockConnectionManager`] for a usage example.
 
 mod error;
+#[cfg(feature = "alloc")]
+mod multiconnectionmanager;
 mod protocol;
 #[cfg(feature = "alloc")]
+mod singleconnectionmanager;
+#[cfg(feature = "alloc")]
 mod vsock;
 
 pub use error::SocketError;
+#[cfg(feature = "alloc")]
+pub use multiconnectionmanager::VsockConnectionManager;
 pub use protocol::VsockAddr;
 #[cfg(feature = "alloc")]
+pub use singleconnectionmanager::SingleConnectionManager;
+#[cfg(feature = "alloc")]
 pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};

+ 763 - 0
src/device/socket/multiconnectionmanager.rs

@@ -0,0 +1,763 @@
+use super::{
+    protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
+    VsockEvent, VsockEventType,
+};
+use crate::{transport::Transport, Hal, Result};
+use alloc::{boxed::Box, vec::Vec};
+use core::cmp::min;
+use core::convert::TryInto;
+use core::hint::spin_loop;
+use log::debug;
+use zerocopy::FromBytes;
+
+const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
+
+/// A higher level interface for VirtIO socket (vsock) devices.
+///
+/// This keeps track of multiple vsock connections.
+///
+/// # Example
+///
+/// ```
+/// # use virtio_drivers::{Error, Hal};
+/// # use virtio_drivers::transport::Transport;
+/// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
+///
+/// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
+/// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
+///
+/// // Start a thread to call `socket.poll()` and handle events.
+///
+/// let remote_address = VsockAddr { cid: 2, port: 42 };
+/// let local_port = 1234;
+/// socket.connect(remote_address, local_port)?;
+///
+/// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
+///
+/// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
+///
+/// socket.shutdown(remote_address, local_port)?;
+/// # Ok(())
+/// # }
+/// ```
+pub struct VsockConnectionManager<H: Hal, T: Transport> {
+    driver: VirtIOSocket<H, T>,
+    connections: Vec<Connection>,
+    listening_ports: Vec<u32>,
+}
+
+#[derive(Debug)]
+struct Connection {
+    info: ConnectionInfo,
+    buffer: RingBuffer,
+    /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
+    /// still data in the buffer.
+    peer_requested_shutdown: bool,
+}
+
+impl Connection {
+    fn new(peer: VsockAddr, local_port: u32) -> Self {
+        let mut info = ConnectionInfo::new(peer, local_port);
+        info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
+        Self {
+            info,
+            buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
+            peer_requested_shutdown: false,
+        }
+    }
+}
+
+impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
+    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
+    pub fn new(driver: VirtIOSocket<H, T>) -> Self {
+        Self {
+            driver,
+            connections: Vec::new(),
+            listening_ports: Vec::new(),
+        }
+    }
+
+    /// Returns the CID which has been assigned to this guest.
+    pub fn guest_cid(&self) -> u64 {
+        self.driver.guest_cid()
+    }
+
+    /// Allows incoming connections on the given port number.
+    pub fn listen(&mut self, port: u32) {
+        if !self.listening_ports.contains(&port) {
+            self.listening_ports.push(port);
+        }
+    }
+
+    /// Stops allowing incoming connections on the given port number.
+    pub fn unlisten(&mut self, port: u32) {
+        self.listening_ports.retain(|p| *p != port);
+    }
+
+    /// Sends a request to connect to the given destination.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll` returns a
+    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+    /// before sending data.
+    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+        if self.connections.iter().any(|connection| {
+            connection.info.dst == destination && connection.info.src_port == src_port
+        }) {
+            return Err(SocketError::ConnectionExists.into());
+        }
+
+        let new_connection = Connection::new(destination, src_port);
+
+        self.driver.connect(&new_connection.info)?;
+        debug!("Connection requested: {:?}", new_connection.info);
+        self.connections.push(new_connection);
+        Ok(())
+    }
+
+    /// Sends the buffer to the destination.
+    pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
+        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
+
+        self.driver.send(buffer, &mut connection.info)
+    }
+
+    /// Polls the vsock device to receive data or other updates.
+    pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
+        let guest_cid = self.driver.guest_cid();
+        let connections = &mut self.connections;
+
+        let result = self.driver.poll(|event, body| {
+            let connection = get_connection_for_event(connections, &event, guest_cid);
+
+            // Skip events which don't match any connection we know about, unless they are a
+            // connection request.
+            let connection = if let Some((_, connection)) = connection {
+                connection
+            } else if let VsockEventType::ConnectionRequest = event.event_type {
+                // If the requested connection already exists or the CID isn't ours, ignore it.
+                if connection.is_some() || event.destination.cid != guest_cid {
+                    return Ok(None);
+                }
+                // Add the new connection to our list, at least for now. It will be removed again
+                // below if we weren't listening on the port.
+                connections.push(Connection::new(event.source, event.destination.port));
+                connections.last_mut().unwrap()
+            } else {
+                return Ok(None);
+            };
+
+            // Update stored connection info.
+            connection.info.update_for_event(&event);
+
+            if let VsockEventType::Received { length } = event.event_type {
+                // Copy to buffer
+                if !connection.buffer.add(body) {
+                    return Err(SocketError::OutputBufferTooShort(length).into());
+                }
+            }
+
+            Ok(Some(event))
+        })?;
+
+        let Some(event) = result else {
+            return Ok(None);
+        };
+
+        // The connection must exist because we found it above in the callback.
+        let (connection_index, connection) =
+            get_connection_for_event(connections, &event, guest_cid).unwrap();
+
+        match event.event_type {
+            VsockEventType::ConnectionRequest => {
+                if self.listening_ports.contains(&event.destination.port) {
+                    self.driver.accept(&connection.info)?;
+                } else {
+                    // Reject the connection request and remove it from our list.
+                    self.driver.force_close(&connection.info)?;
+                    self.connections.swap_remove(connection_index);
+
+                    // No need to pass the request on to the client, as we've already rejected it.
+                    return Ok(None);
+                }
+            }
+            VsockEventType::Connected => {}
+            VsockEventType::Disconnected { reason } => {
+                // Wait until client reads all data before removing connection.
+                if connection.buffer.is_empty() {
+                    if reason == DisconnectReason::Shutdown {
+                        self.driver.force_close(&connection.info)?;
+                    }
+                    self.connections.swap_remove(connection_index);
+                } else {
+                    connection.peer_requested_shutdown = true;
+                }
+            }
+            VsockEventType::Received { .. } => {
+                // Already copied the buffer in the callback above.
+            }
+            VsockEventType::CreditRequest => {
+                // If the peer requested credit, send an update.
+                self.driver.credit_update(&connection.info)?;
+                // No need to pass the request on to the client, we've already handled it.
+                return Ok(None);
+            }
+            VsockEventType::CreditUpdate => {}
+        }
+
+        Ok(Some(event))
+    }
+
+    /// Reads data received from the given connection.
+    pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
+        let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
+
+        // Copy from ring buffer
+        let bytes_read = connection.buffer.drain(buffer);
+
+        connection.info.done_forwarding(bytes_read);
+
+        // If buffer is now empty and the peer requested shutdown, finish shutting down the
+        // connection.
+        if connection.peer_requested_shutdown && connection.buffer.is_empty() {
+            self.driver.force_close(&connection.info)?;
+            self.connections.swap_remove(connection_index);
+        }
+
+        Ok(bytes_read)
+    }
+
+    /// Blocks until we get some event from the vsock device.
+    pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
+        loop {
+            if let Some(event) = self.poll()? {
+                return Ok(event);
+            } else {
+                spin_loop();
+            }
+        }
+    }
+
+    /// Requests to shut down the connection cleanly.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll` returns a
+    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+    /// shutdown.
+    pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
+
+        self.driver.shutdown(&connection.info)
+    }
+
+    /// Forcibly closes the connection without waiting for the peer.
+    pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+        let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
+
+        self.driver.force_close(&connection.info)?;
+
+        self.connections.swap_remove(index);
+        Ok(())
+    }
+}
+
+/// Returns the connection from the given list matching the given peer address and local port, and
+/// its index.
+///
+/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
+fn get_connection(
+    connections: &mut [Connection],
+    peer: VsockAddr,
+    local_port: u32,
+) -> core::result::Result<(usize, &mut Connection), SocketError> {
+    connections
+        .iter_mut()
+        .enumerate()
+        .find(|(_, connection)| {
+            connection.info.dst == peer && connection.info.src_port == local_port
+        })
+        .ok_or(SocketError::NotConnected)
+}
+
+/// Returns the connection from the given list matching the event, if any, and its index.
+fn get_connection_for_event<'a>(
+    connections: &'a mut [Connection],
+    event: &VsockEvent,
+    local_cid: u64,
+) -> Option<(usize, &'a mut Connection)> {
+    connections
+        .iter_mut()
+        .enumerate()
+        .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
+}
+
+#[derive(Debug)]
+struct RingBuffer {
+    buffer: Box<[u8]>,
+    /// The number of bytes currently in the buffer.
+    used: usize,
+    /// The index of the first used byte in the buffer.
+    start: usize,
+}
+
+impl RingBuffer {
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            buffer: FromBytes::new_box_slice_zeroed(capacity),
+            used: 0,
+            start: 0,
+        }
+    }
+
+    /// Returns the number of bytes currently used in the buffer.
+    pub fn used(&self) -> usize {
+        self.used
+    }
+
+    /// Returns true iff there are currently no bytes in the buffer.
+    pub fn is_empty(&self) -> bool {
+        self.used == 0
+    }
+
+    /// Returns the number of bytes currently free in the buffer.
+    pub fn available(&self) -> usize {
+        self.buffer.len() - self.used
+    }
+
+    /// Adds the given bytes to the buffer if there is enough capacity for them all.
+    ///
+    /// Returns true if they were added, or false if they were not.
+    pub fn add(&mut self, bytes: &[u8]) -> bool {
+        if bytes.len() > self.available() {
+            return false;
+        }
+
+        // The index of the first available position in the buffer.
+        let first_available = (self.start + self.used) % self.buffer.len();
+        // The number of bytes to copy from `bytes` to `buffer` between `first_available` and
+        // `buffer.len()`.
+        let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
+        self.buffer[first_available..first_available + copy_length_before_wraparound]
+            .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
+        if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
+            self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
+        }
+        self.used += bytes.len();
+
+        true
+    }
+
+    /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
+    /// buffer.
+    pub fn drain(&mut self, out: &mut [u8]) -> usize {
+        let bytes_read = min(self.used, out.len());
+
+        // The number of bytes to copy out between `start` and the end of the buffer.
+        let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
+        // The number of bytes to copy out from the beginning of the buffer after wrapping around.
+        let read_after_wraparound = bytes_read
+            .checked_sub(read_before_wraparound)
+            .unwrap_or_default();
+
+        out[0..read_before_wraparound]
+            .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
+        out[read_before_wraparound..bytes_read]
+            .copy_from_slice(&self.buffer[0..read_after_wraparound]);
+
+        self.used -= bytes_read;
+        self.start = (self.start + bytes_read) % self.buffer.len();
+
+        bytes_read
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        device::socket::{
+            protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
+            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
+        },
+        hal::fake::FakeHal,
+        transport::{
+            fake::{FakeTransport, QueueStatus, State},
+            DeviceStatus, DeviceType,
+        },
+        volatile::ReadOnly,
+    };
+    use alloc::{sync::Arc, vec};
+    use core::{mem::size_of, ptr::NonNull};
+    use std::{sync::Mutex, thread};
+    use zerocopy::{AsBytes, FromBytes};
+
+    #[test]
+    fn send_recv() {
+        let host_cid = 2;
+        let guest_cid = 66;
+        let host_port = 1234;
+        let guest_port = 4321;
+        let host_address = VsockAddr {
+            cid: host_cid,
+            port: host_port,
+        };
+        let hello_from_guest = "Hello from guest";
+        let hello_from_host = "Hello from host";
+
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut socket = VsockConnectionManager::new(
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
+        );
+
+        // Start a thread to simulate the device.
+        let handle = thread::spawn(move || {
+            // Wait for connection request.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 1024.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Accept connection and give the peer enough credit to send the message.
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Response.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect the guest to send some data.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            let request = state
+                .lock()
+                .unwrap()
+                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
+            assert_eq!(
+                request.len(),
+                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
+            );
+            assert_eq!(
+                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Rw.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: (hello_from_guest.len() as u32).into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 1024.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+            assert_eq!(
+                &request[size_of::<VirtioVsockHdr>()..],
+                hello_from_guest.as_bytes()
+            );
+
+            println!("Host sending");
+
+            // Send a response.
+            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
+            VirtioVsockHdr {
+                op: VirtioVsockOp::Rw.into(),
+                src_cid: host_cid.into(),
+                dst_cid: guest_cid.into(),
+                src_port: host_port.into(),
+                dst_port: guest_port.into(),
+                len: (hello_from_host.len() as u32).into(),
+                socket_type: SocketType::Stream.into(),
+                flags: 0.into(),
+                buf_alloc: 50.into(),
+                fwd_cnt: (hello_from_guest.len() as u32).into(),
+            }
+            .write_to_prefix(response.as_mut_slice());
+            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
+            state
+                .lock()
+                .unwrap()
+                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
+
+            // Expect a shutdown.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Shutdown.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 1024.into(),
+                    fwd_cnt: (hello_from_host.len() as u32).into(),
+                }
+            );
+        });
+
+        socket.connect(host_address, guest_port).unwrap();
+        assert_eq!(
+            socket.wait_for_event().unwrap(),
+            VsockEvent {
+                source: host_address,
+                destination: VsockAddr {
+                    cid: guest_cid,
+                    port: guest_port,
+                },
+                event_type: VsockEventType::Connected,
+                buffer_status: VsockBufferStatus {
+                    buffer_allocation: 50,
+                    forward_count: 0,
+                },
+            }
+        );
+        println!("Guest sending");
+        socket
+            .send(host_address, guest_port, "Hello from guest".as_bytes())
+            .unwrap();
+        println!("Guest waiting to receive.");
+        assert_eq!(
+            socket.wait_for_event().unwrap(),
+            VsockEvent {
+                source: host_address,
+                destination: VsockAddr {
+                    cid: guest_cid,
+                    port: guest_port,
+                },
+                event_type: VsockEventType::Received {
+                    length: hello_from_host.len()
+                },
+                buffer_status: VsockBufferStatus {
+                    buffer_allocation: 50,
+                    forward_count: hello_from_guest.len() as u32,
+                },
+            }
+        );
+        println!("Guest getting received data.");
+        let mut buffer = [0u8; 64];
+        assert_eq!(
+            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
+            hello_from_host.len()
+        );
+        assert_eq!(
+            &buffer[0..hello_from_host.len()],
+            hello_from_host.as_bytes()
+        );
+        socket.shutdown(host_address, guest_port).unwrap();
+
+        handle.join().unwrap();
+    }
+
+    #[test]
+    fn incoming_connection() {
+        let host_cid = 2;
+        let guest_cid = 66;
+        let host_port = 1234;
+        let guest_port = 4321;
+        let wrong_guest_port = 4444;
+        let host_address = VsockAddr {
+            cid: host_cid,
+            port: host_port,
+        };
+
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut socket = VsockConnectionManager::new(
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
+        );
+
+        socket.listen(guest_port);
+
+        // Start a thread to simulate the device.
+        let handle = thread::spawn(move || {
+            // Send a connection request for a port the guest isn't listening on.
+            println!("Host sending connection request to wrong port");
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: wrong_guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect a rejection.
+            println!("Host waiting for rejection");
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Rst.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: wrong_guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 1024.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Send a connection request for a port the guest is listening on.
+            println!("Host sending connection request to right port");
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect a response.
+            println!("Host waiting for response");
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Response.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 1024.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            println!("Host finished");
+        });
+
+        // Expect an incoming connection.
+        println!("Guest expecting incoming connection.");
+        assert_eq!(
+            socket.wait_for_event().unwrap(),
+            VsockEvent {
+                source: host_address,
+                destination: VsockAddr {
+                    cid: guest_cid,
+                    port: guest_port,
+                },
+                event_type: VsockEventType::ConnectionRequest,
+                buffer_status: VsockBufferStatus {
+                    buffer_allocation: 50,
+                    forward_count: 0,
+                },
+            }
+        );
+
+        handle.join().unwrap();
+    }
+}

+ 447 - 0
src/device/socket/singleconnectionmanager.rs

@@ -0,0 +1,447 @@
+use super::{
+    protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
+    VsockEventType,
+};
+use crate::{transport::Transport, Hal, Result};
+use core::hint::spin_loop;
+use log::debug;
+
+/// A higher level interface for VirtIO socket (vsock) devices.
+///
+/// This can only keep track of a single vsock connection. If you want to support multiple
+/// simultaneous connections, try [`VsockConnectionManager`](super::VsockConnectionManager).
+pub struct SingleConnectionManager<H: Hal, T: Transport> {
+    driver: VirtIOSocket<H, T>,
+    connection_info: Option<ConnectionInfo>,
+}
+
+impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
+    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
+    pub fn new(driver: VirtIOSocket<H, T>) -> Self {
+        Self {
+            driver,
+            connection_info: None,
+        }
+    }
+
+    /// Returns the CID which has been assigned to this guest.
+    pub fn guest_cid(&self) -> u64 {
+        self.driver.guest_cid()
+    }
+
+    /// Sends a request to connect to the given destination.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+    /// before sending data.
+    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+        if self.connection_info.is_some() {
+            return Err(SocketError::ConnectionExists.into());
+        }
+
+        let new_connection_info = ConnectionInfo::new(destination, src_port);
+
+        self.driver.connect(&new_connection_info)?;
+        debug!("Connection requested: {:?}", new_connection_info);
+        self.connection_info = Some(new_connection_info);
+        Ok(())
+    }
+
+    /// Sends the buffer to the destination.
+    pub fn send(&mut self, buffer: &[u8]) -> Result {
+        let connection_info = self
+            .connection_info
+            .as_mut()
+            .ok_or(SocketError::NotConnected)?;
+        connection_info.buf_alloc = 0;
+        self.driver.send(buffer, connection_info)
+    }
+
+    /// Polls the vsock device to receive data or other updates.
+    ///
+    /// A buffer must be provided to put the data in if there is some to
+    /// receive.
+    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
+        let Some(connection_info) = &mut self.connection_info else {
+            return Err(SocketError::NotConnected.into());
+        };
+
+        // Tell the peer that we have space to receive some data.
+        connection_info.buf_alloc = buffer.len() as u32;
+        self.driver.credit_update(connection_info)?;
+
+        self.poll_rx_queue(buffer)
+    }
+
+    /// Blocks until we get some event from the vsock device.
+    ///
+    /// A buffer must be provided to put the data in if there is some to
+    /// receive.
+    pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
+        loop {
+            if let Some(event) = self.poll_recv(buffer)? {
+                return Ok(event);
+            } else {
+                spin_loop();
+            }
+        }
+    }
+
+    fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
+        let guest_cid = self.driver.guest_cid();
+        let self_connection_info = &mut self.connection_info;
+
+        self.driver.poll(|event, borrowed_body| {
+            let Some(connection_info) = self_connection_info else {
+                return Ok(None);
+            };
+
+            // Skip packets which don't match our current connection.
+            if !event.matches_connection(connection_info, guest_cid) {
+                debug!(
+                    "Skipping {:?} as connection is {:?}",
+                    event, connection_info
+                );
+                return Ok(None);
+            }
+
+            // Update stored connection info.
+            connection_info.update_for_event(&event);
+
+            match event.event_type {
+                VsockEventType::ConnectionRequest => {
+                    // TODO: Send Rst or handle incoming connections.
+                }
+                VsockEventType::Connected => {}
+                VsockEventType::Disconnected { .. } => {
+                    *self_connection_info = None;
+                }
+                VsockEventType::Received { length } => {
+                    body.get_mut(0..length)
+                        .ok_or(SocketError::OutputBufferTooShort(length))?
+                        .copy_from_slice(borrowed_body);
+                    connection_info.done_forwarding(length);
+                }
+                VsockEventType::CreditRequest => {
+                    // No point sending a credit update until `poll_recv` is called with a buffer,
+                    // as otherwise buf_alloc would just be 0 anyway.
+                }
+                VsockEventType::CreditUpdate => {}
+            }
+
+            Ok(Some(event))
+        })
+    }
+
+    /// Requests to shut down the connection cleanly.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+    /// shutdown.
+    pub fn shutdown(&mut self) -> Result {
+        let connection_info = self
+            .connection_info
+            .as_mut()
+            .ok_or(SocketError::NotConnected)?;
+        connection_info.buf_alloc = 0;
+
+        self.driver.shutdown(connection_info)
+    }
+
+    /// Forcibly closes the connection without waiting for the peer.
+    pub fn force_close(&mut self) -> Result {
+        let connection_info = self
+            .connection_info
+            .as_mut()
+            .ok_or(SocketError::NotConnected)?;
+        connection_info.buf_alloc = 0;
+
+        self.driver.force_close(connection_info)?;
+        self.connection_info = None;
+        Ok(())
+    }
+
+    /// Blocks until the peer either accepts our connection request (with a
+    /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
+    /// `VIRTIO_VSOCK_OP_RST`).
+    pub fn wait_for_connect(&mut self) -> Result {
+        loop {
+            match self.wait_for_recv(&mut [])?.event_type {
+                VsockEventType::Connected => return Ok(()),
+                VsockEventType::Disconnected { .. } => {
+                    return Err(SocketError::ConnectionFailed.into())
+                }
+                VsockEventType::Received { .. } => return Err(SocketError::InvalidOperation.into()),
+                VsockEventType::ConnectionRequest
+                | VsockEventType::CreditRequest
+                | VsockEventType::CreditUpdate => {}
+            }
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        device::socket::{
+            protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
+            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
+        },
+        hal::fake::FakeHal,
+        transport::{
+            fake::{FakeTransport, QueueStatus, State},
+            DeviceStatus, DeviceType,
+        },
+        volatile::ReadOnly,
+    };
+    use alloc::{sync::Arc, vec};
+    use core::{mem::size_of, ptr::NonNull};
+    use std::{sync::Mutex, thread};
+    use zerocopy::{AsBytes, FromBytes};
+
+    #[test]
+    fn send_recv() {
+        let host_cid = 2;
+        let guest_cid = 66;
+        let host_port = 1234;
+        let guest_port = 4321;
+        let host_address = VsockAddr {
+            cid: host_cid,
+            port: host_port,
+        };
+        let hello_from_guest = "Hello from guest";
+        let hello_from_host = "Hello from host";
+
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut socket = SingleConnectionManager::new(
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
+        );
+
+        // Start a thread to simulate the device.
+        let handle = thread::spawn(move || {
+            // Wait for connection request.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Accept connection and give the peer enough credit to send the message.
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Response.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect a credit update.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::CreditUpdate.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Expect the guest to send some data.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            let request = state
+                .lock()
+                .unwrap()
+                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
+            assert_eq!(
+                request.len(),
+                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
+            );
+            assert_eq!(
+                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Rw.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: (hello_from_guest.len() as u32).into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+            assert_eq!(
+                &request[size_of::<VirtioVsockHdr>()..],
+                hello_from_guest.as_bytes()
+            );
+
+            // Send a response.
+            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
+            VirtioVsockHdr {
+                op: VirtioVsockOp::Rw.into(),
+                src_cid: host_cid.into(),
+                dst_cid: guest_cid.into(),
+                src_port: host_port.into(),
+                dst_port: guest_port.into(),
+                len: (hello_from_host.len() as u32).into(),
+                socket_type: SocketType::Stream.into(),
+                flags: 0.into(),
+                buf_alloc: 50.into(),
+                fwd_cnt: (hello_from_guest.len() as u32).into(),
+            }
+            .write_to_prefix(response.as_mut_slice());
+            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
+            state
+                .lock()
+                .unwrap()
+                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
+
+            // Expect a credit update.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::CreditUpdate.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 64.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Expect a shutdown.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Shutdown.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: (hello_from_host.len() as u32).into(),
+                }
+            );
+        });
+
+        socket.connect(host_address, guest_port).unwrap();
+        socket.wait_for_connect().unwrap();
+        socket.send(hello_from_guest.as_bytes()).unwrap();
+        let mut buffer = [0u8; 64];
+        let event = socket.wait_for_recv(&mut buffer).unwrap();
+        assert_eq!(
+            event,
+            VsockEvent {
+                source: VsockAddr {
+                    cid: host_cid,
+                    port: host_port,
+                },
+                destination: VsockAddr {
+                    cid: guest_cid,
+                    port: guest_port,
+                },
+                event_type: VsockEventType::Received {
+                    length: hello_from_host.len()
+                },
+                buffer_status: VsockBufferStatus {
+                    buffer_allocation: 50,
+                    forward_count: hello_from_guest.len() as u32,
+                },
+            }
+        );
+        assert_eq!(
+            &buffer[0..hello_from_host.len()],
+            hello_from_host.as_bytes()
+        );
+        socket.shutdown().unwrap();
+
+        handle.join().unwrap();
+    }
+}

+ 209 - 464
src/device/socket/vsock.rs

@@ -7,27 +7,26 @@ use crate::hal::Hal;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::volread;
-use crate::Result;
+use crate::{Error, Result};
 use alloc::boxed::Box;
-use core::hint::spin_loop;
 use core::mem::size_of;
 use core::ptr::{null_mut, NonNull};
-use log::{debug, info};
+use log::debug;
 use zerocopy::{AsBytes, FromBytes};
 
-const RX_QUEUE_IDX: u16 = 0;
-const TX_QUEUE_IDX: u16 = 1;
+pub(crate) const RX_QUEUE_IDX: u16 = 0;
+pub(crate) const TX_QUEUE_IDX: u16 = 1;
 const EVENT_QUEUE_IDX: u16 = 2;
 
-const QUEUE_SIZE: usize = 8;
+pub(crate) const QUEUE_SIZE: usize = 8;
 
-/// The size in bytes of each buffer used in the RX virtqueue.
+/// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>().
 const RX_BUFFER_SIZE: usize = 512;
 
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
-struct ConnectionInfo {
-    dst: VsockAddr,
-    src_port: u32,
+pub struct ConnectionInfo {
+    pub dst: VsockAddr,
+    pub src_port: u32,
     /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
     /// bytes it has allocated for packet bodies.
     peer_buf_alloc: u32,
@@ -36,6 +35,9 @@ struct ConnectionInfo {
     peer_fwd_cnt: u32,
     /// The number of bytes of packet bodies which we have sent to the peer.
     tx_cnt: u32,
+    /// The number of bytes of buffer space we have allocated to receive packet bodies from the
+    /// peer.
+    pub buf_alloc: u32,
     /// The number of bytes of packet bodies which we have received from the peer and handled.
     fwd_cnt: u32,
     /// Whether we have recently requested credit from the peer.
@@ -46,6 +48,35 @@ struct ConnectionInfo {
 }
 
 impl ConnectionInfo {
+    pub fn new(destination: VsockAddr, src_port: u32) -> Self {
+        Self {
+            dst: destination,
+            src_port,
+            ..Default::default()
+        }
+    }
+
+    /// Updates this connection info with the peer buffer allocation and forwarded count from the
+    /// given event.
+    pub fn update_for_event(&mut self, event: &VsockEvent) {
+        self.peer_buf_alloc = event.buffer_status.buffer_allocation;
+        self.peer_fwd_cnt = event.buffer_status.forward_count;
+
+        if let VsockEventType::CreditUpdate = event.event_type {
+            self.has_pending_credit_request = false;
+        }
+    }
+
+    /// Increases the forwarded count recorded for this connection by the given number of bytes.
+    ///
+    /// This should be called once received data has been passed to the client, so there is buffer
+    /// space available for more.
+    pub fn done_forwarding(&mut self, length: usize) {
+        self.fwd_cnt += length as u32;
+    }
+
+    /// Returns the number of bytes of RX buffer space the peer has available to receive packet body
+    /// data from us.
     fn peer_free(&self) -> u32 {
         self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
     }
@@ -56,6 +87,7 @@ impl ConnectionInfo {
             dst_cid: self.dst.cid.into(),
             src_port: self.src_port.into(),
             dst_port: self.dst.port.into(),
+            buf_alloc: self.buf_alloc.into(),
             fwd_cnt: self.fwd_cnt.into(),
             ..Default::default()
         }
@@ -69,10 +101,77 @@ pub struct VsockEvent {
     pub source: VsockAddr,
     /// The destination of the event, i.e. the CID and port on our side.
     pub destination: VsockAddr,
+    /// The peer's buffer status for the connection.
+    pub buffer_status: VsockBufferStatus,
     /// The type of event.
     pub event_type: VsockEventType,
 }
 
+impl VsockEvent {
+    /// Returns whether the event matches the given connection.
+    pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
+        self.source == connection_info.dst
+            && self.destination.cid == guest_cid
+            && self.destination.port == connection_info.src_port
+    }
+
+    fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
+        let op = header.op()?;
+        let buffer_status = VsockBufferStatus {
+            buffer_allocation: header.buf_alloc.into(),
+            forward_count: header.fwd_cnt.into(),
+        };
+        let source = header.source();
+        let destination = header.destination();
+
+        let event_type = match op {
+            VirtioVsockOp::Request => {
+                header.check_data_is_empty()?;
+                VsockEventType::ConnectionRequest
+            }
+            VirtioVsockOp::Response => {
+                header.check_data_is_empty()?;
+                VsockEventType::Connected
+            }
+            VirtioVsockOp::CreditUpdate => {
+                header.check_data_is_empty()?;
+                VsockEventType::CreditUpdate
+            }
+            VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
+                header.check_data_is_empty()?;
+                debug!("Disconnected from the peer");
+                let reason = if op == VirtioVsockOp::Rst {
+                    DisconnectReason::Reset
+                } else {
+                    DisconnectReason::Shutdown
+                };
+                VsockEventType::Disconnected { reason }
+            }
+            VirtioVsockOp::Rw => VsockEventType::Received {
+                length: header.len() as usize,
+            },
+            VirtioVsockOp::CreditRequest => {
+                header.check_data_is_empty()?;
+                VsockEventType::CreditRequest
+            }
+            VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
+        };
+
+        Ok(VsockEvent {
+            source,
+            destination,
+            buffer_status,
+            event_type,
+        })
+    }
+}
+
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct VsockBufferStatus {
+    pub buffer_allocation: u32,
+    pub forward_count: u32,
+}
+
 /// The reason why a vsock connection was closed.
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum DisconnectReason {
@@ -86,6 +185,8 @@ pub enum DisconnectReason {
 /// Details of the type of an event received from a VirtIO socket.
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub enum VsockEventType {
+    /// The peer requests to establish a connection with us.
+    ConnectionRequest,
     /// The connection was successfully established.
     Connected,
     /// The connection was closed.
@@ -98,9 +199,16 @@ pub enum VsockEventType {
         /// The length of the data in bytes.
         length: usize,
     },
+    /// The peer requests us to send a credit update.
+    CreditRequest,
+    /// The peer just sent us a credit update with nothing else.
+    CreditUpdate,
 }
 
-/// Driver for a VirtIO socket device.
+/// Low-level driver for a VirtIO socket device.
+///
+/// You probably want to use [`VsockConnectionManager`](super::VsockConnectionManager) rather than
+/// using this directly.
 pub struct VirtIOSocket<H: Hal, T: Transport> {
     transport: T,
     /// Virtqueue to receive packets.
@@ -112,9 +220,6 @@ pub struct VirtIOSocket<H: Hal, T: Transport> {
     /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
     guest_cid: u64,
     rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
-
-    /// Currently the device is only allowed to be connected to one destination at a time.
-    connection_info: Option<ConnectionInfo>,
 }
 
 impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
@@ -138,19 +243,19 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     pub fn new(mut transport: T) -> Result<Self> {
         transport.begin_init(|features| {
             let features = Feature::from_bits_truncate(features);
-            info!("Device features: {:?}", features);
+            debug!("Device features: {:?}", features);
             // negotiate these flags only
             let supported_features = Feature::empty();
             (features & supported_features).bits()
         });
 
         let config = transport.config_space::<VirtioVsockConfig>()?;
-        info!("config: {:?}", config);
+        debug!("config: {:?}", config);
         // Safe because config is a valid pointer to the device configuration space.
         let guest_cid = unsafe {
             volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
         };
-        info!("guest cid: {guest_cid:?}");
+        debug!("guest cid: {guest_cid:?}");
 
         let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
         let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
@@ -158,13 +263,13 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
 
         // Allocate and add buffers for the RX queue.
         let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
-        for i in 0..QUEUE_SIZE {
+        for (i, rx_queue_buffer) in rx_queue_buffers.iter_mut().enumerate() {
             let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed();
             // Safe because the buffer lives as long as the queue, as specified in the function
             // safety requirement, and we don't access it until it is popped.
             let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
             assert_eq!(i, token.into());
-            rx_queue_buffers[i] = Box::into_raw(buffer);
+            *rx_queue_buffer = Box::into_raw(buffer);
         }
         let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
 
@@ -180,7 +285,6 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             event,
             guest_cid,
             rx_queue_buffers,
-            connection_info: None,
         })
     }
 
@@ -191,45 +295,30 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
 
     /// Sends a request to connect to the given destination.
     ///
-    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// This returns as soon as the request is sent; you should wait until `poll` returns a
     /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
     /// before sending data.
-    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
-        if self.connection_info.is_some() {
-            return Err(SocketError::ConnectionExists.into());
-        }
-        let new_connection_info = ConnectionInfo {
-            dst: destination,
-            src_port,
-            ..Default::default()
-        };
+    pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Request.into(),
-            ..new_connection_info.new_header(self.guest_cid)
+            ..connection_info.new_header(self.guest_cid)
         };
-        // Sends a header only packet to the tx queue to connect the device to the listening
-        // socket at the given destination.
-        self.send_packet_to_tx_queue(&header, &[])?;
-
-        self.connection_info = Some(new_connection_info);
-        debug!("Connection requested: {:?}", self.connection_info);
-        Ok(())
+        // Sends a header only packet to the TX queue to connect the device to the listening socket
+        // at the given destination.
+        self.send_packet_to_tx_queue(&header, &[])
     }
 
-    /// Blocks until the peer either accepts our connection request (with a
-    /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
-    /// `VIRTIO_VSOCK_OP_RST`).
-    pub fn wait_for_connect(&mut self) -> Result {
-        match self.wait_for_recv(&mut [])?.event_type {
-            VsockEventType::Connected => Ok(()),
-            VsockEventType::Disconnected { .. } => Err(SocketError::ConnectionFailed.into()),
-            VsockEventType::Received { .. } => Err(SocketError::InvalidOperation.into()),
-        }
+    /// Accepts the given connection from a peer.
+    pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Response.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, &[])
     }
 
-    /// Requests the peer to send us a credit update for the current connection.
-    fn request_credit(&mut self) -> Result {
-        let connection_info = self.connection_info()?;
+    /// Requests the peer to send us a credit update for the given connection.
+    fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::CreditRequest.into(),
             ..connection_info.new_header(self.guest_cid)
@@ -238,21 +327,16 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     }
 
     /// Sends the buffer to the destination.
-    pub fn send(&mut self, buffer: &[u8]) -> Result {
-        let mut connection_info = self.connection_info()?;
-
-        let result = self.check_peer_buffer_is_sufficient(&mut connection_info, buffer.len());
-        self.connection_info = Some(connection_info.clone());
-        result?;
+    pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
+        self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
 
         let len = buffer.len() as u32;
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Rw.into(),
             len: len.into(),
-            buf_alloc: 0.into(),
             ..connection_info.new_header(self.guest_cid)
         };
-        self.connection_info.as_mut().unwrap().tx_cnt += len;
+        connection_info.tx_cnt += len;
         self.send_packet_to_tx_queue(&header, buffer)
     }
 
@@ -267,59 +351,48 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             // Request an update of the cached peer credit, if we haven't already done so, and tell
             // the caller to try again later.
             if !connection_info.has_pending_credit_request {
-                self.request_credit()?;
+                self.request_credit(connection_info)?;
                 connection_info.has_pending_credit_request = true;
             }
             Err(SocketError::InsufficientBufferSpaceInPeer.into())
         }
     }
 
-    /// Polls the vsock device to receive data or other updates.
-    ///
-    /// A buffer must be provided to put the data in if there is some to
-    /// receive.
-    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
-        let connection_info = self.connection_info()?;
-
-        // Tell the peer that we have space to receive some data.
+    /// Tells the peer how much buffer space we have to receive data.
+    pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::CreditUpdate.into(),
-            buf_alloc: (buffer.len() as u32).into(),
             ..connection_info.new_header(self.guest_cid)
         };
-        self.send_packet_to_tx_queue(&header, &[])?;
+        self.send_packet_to_tx_queue(&header, &[])
+    }
 
-        // Handle entries from the RX virtqueue until we find one that generates an event.
-        let event = self.poll_rx_queue(buffer)?;
+    /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
+    /// it.
+    pub fn poll(
+        &mut self,
+        handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
+    ) -> Result<Option<VsockEvent>> {
+        let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else {
+            return Ok(None);
+        };
 
-        if self.rx.should_notify() {
-            self.transport.notify(RX_QUEUE_IDX);
-        }
+        let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));
 
-        Ok(event)
-    }
-
-    /// Blocks until we get some event from the vsock device.
-    ///
-    /// A buffer must be provided to put the data in if there is some to
-    /// receive.
-    pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
-        loop {
-            if let Some(event) = self.poll_recv(buffer)? {
-                return Ok(event);
-            } else {
-                spin_loop();
-            }
+        unsafe {
+            // TODO: What about if both handler and this give errors?
+            self.add_buffer_to_rx_queue(token)?;
         }
+
+        result
     }
 
-    /// Request to shut down the connection cleanly.
+    /// Requests to shut down the connection cleanly.
     ///
-    /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
+    /// This returns as soon as the request is sent; you should wait until `poll` returns a
     /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
     /// shutdown.
-    pub fn shutdown(&mut self) -> Result {
-        let connection_info = self.connection_info()?;
+    pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Shutdown.into(),
             ..connection_info.new_header(self.guest_cid)
@@ -328,14 +401,12 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     }
 
     /// Forcibly closes the connection without waiting for the peer.
-    pub fn force_close(&mut self) -> Result {
-        let connection_info = self.connection_info()?;
+    pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Rst.into(),
             ..connection_info.new_header(self.guest_cid)
         };
         self.send_packet_to_tx_queue(&header, &[])?;
-        self.connection_info = None;
         Ok(())
     }
 
@@ -348,118 +419,39 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(())
     }
 
-    /// Polls the RX virtqueue until either it is empty, there is an error, or we find a packet
-    /// which generates a `VsockEvent`.
+    /// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue.
     ///
-    /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
-    /// don't result in any events to return.
-    fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
-        loop {
-            let mut connection_info = self.connection_info.clone().unwrap_or_default();
-            let Some(header) = self.pop_packet_from_rx_queue(body)? else{
-                return Ok(None);
-            };
-
-            let op = header.op()?;
-
-            // Skip packets which don't match our current connection.
-            if header.source() != connection_info.dst
-                || header.dst_cid.get() != self.guest_cid
-                || header.dst_port.get() != connection_info.src_port
-            {
-                debug!(
-                    "Skipping {:?} as connection is {:?}",
-                    header, connection_info
-                );
-                continue;
-            }
-
-            connection_info.peer_buf_alloc = header.buf_alloc.into();
-            connection_info.peer_fwd_cnt = header.fwd_cnt.into();
-            if self.connection_info.is_some() {
-                self.connection_info = Some(connection_info.clone());
-                debug!("Connection info updated: {:?}", self.connection_info);
-            }
+    /// # Safety
+    ///
+    /// The buffer must not currently be in the RX queue, and no other references to it must exist
+    /// between when this method is called and when it is popped from the queue.
+    unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result {
+        // Safe because the buffer lives as long as the queue, and the caller guarantees that it's
+        // not currently in the queue or referred to anywhere else until it is popped.
+        unsafe {
+            let buffer = self
+                .rx_queue_buffers
+                .get_mut(usize::from(index))
+                .ok_or(Error::WrongToken)?
+                .as_mut();
+            let new_token = self.rx.add(&[], &mut [buffer])?;
+            // If the RX buffer somehow gets assigned a different token, then our safety assumptions
+            // are broken and we can't safely continue to do anything with the device.
+            assert_eq!(new_token, index);
+        }
 
-            match op {
-                VirtioVsockOp::Request => {
-                    header.check_data_is_empty()?;
-                    // TODO: Send a Rst, or support listening.
-                }
-                VirtioVsockOp::Response => {
-                    header.check_data_is_empty()?;
-                    return Ok(Some(VsockEvent {
-                        source: connection_info.dst,
-                        destination: VsockAddr {
-                            cid: self.guest_cid,
-                            port: connection_info.src_port,
-                        },
-                        event_type: VsockEventType::Connected,
-                    }));
-                }
-                VirtioVsockOp::CreditUpdate => {
-                    header.check_data_is_empty()?;
-                    connection_info.has_pending_credit_request = false;
-                    if self.connection_info.is_some() {
-                        self.connection_info = Some(connection_info.clone());
-                    }
-
-                    // Virtio v1.1 5.10.6.3
-                    // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
-                    // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
-                    // any time a change in buffer space occurs.
-                    continue;
-                }
-                VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
-                    header.check_data_is_empty()?;
-
-                    self.connection_info = None;
-                    info!("Disconnected from the peer");
-
-                    let reason = if op == VirtioVsockOp::Rst {
-                        DisconnectReason::Reset
-                    } else {
-                        DisconnectReason::Shutdown
-                    };
-                    return Ok(Some(VsockEvent {
-                        source: connection_info.dst,
-                        destination: VsockAddr {
-                            cid: self.guest_cid,
-                            port: connection_info.src_port,
-                        },
-                        event_type: VsockEventType::Disconnected { reason },
-                    }));
-                }
-                VirtioVsockOp::Rw => {
-                    self.connection_info.as_mut().unwrap().fwd_cnt += header.len();
-                    return Ok(Some(VsockEvent {
-                        source: connection_info.dst,
-                        destination: VsockAddr {
-                            cid: self.guest_cid,
-                            port: connection_info.src_port,
-                        },
-                        event_type: VsockEventType::Received {
-                            length: header.len() as usize,
-                        },
-                    }));
-                }
-                VirtioVsockOp::CreditRequest => {
-                    header.check_data_is_empty()?;
-                    // TODO: Send a credit update.
-                }
-                VirtioVsockOp::Invalid => {
-                    return Err(SocketError::InvalidOperation.into());
-                }
-            }
+        if self.rx.should_notify() {
+            self.transport.notify(RX_QUEUE_IDX);
         }
+
+        Ok(())
     }
 
-    /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
-    /// the body into the given buffer.
+    /// Pops one packet from the RX queue, if there is one pending. Returns the header, and a
+    /// reference to the buffer containing the body.
     ///
-    /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
-    /// buffer supplied.
-    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
+    /// Returns `None` if there is no pending packet.
+    fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> {
         let Some(token) = self.rx.peek_used() else {
             return Ok(None);
         };
@@ -467,54 +459,47 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
         // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
         // buffer back to the RX queue then we don't access it again until next time it is popped.
-        let header = unsafe {
+        let (header, body) = unsafe {
             let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
             let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
 
             // Read the header and body from the buffer. Don't check the result yet, because we need
             // to add the buffer back to the queue either way.
-            let header_result = read_header_and_body(buffer, body);
-
-            // Add the buffer back to the RX queue.
-            let new_token = self.rx.add(&[], &mut [buffer])?;
-            // If the RX buffer somehow gets assigned a different token, then our safety assumptions
-            // are broken and we can't safely continue to do anything with the device.
-            assert_eq!(new_token, token);
+            let header_result = read_header_and_body(buffer);
+            if header_result.is_err() {
+                // If there was an error, add the buffer back immediately. Ignore any errors, as we
+                // need to return the first error.
+                let _ = self.add_buffer_to_rx_queue(token);
+            }
 
             header_result
         }?;
 
         debug!("Received packet {:?}. Op {:?}", header, header.op());
-        Ok(Some(header))
-    }
-
-    fn connection_info(&self) -> Result<ConnectionInfo> {
-        self.connection_info
-            .clone()
-            .ok_or(SocketError::NotConnected.into())
+        Ok(Some((header, body, token)))
     }
 }
 
-fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
-    let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
+fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
+    // Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
+    let header = VirtioVsockHdr::read_from_prefix(buffer).unwrap();
     let body_length = header.len() as usize;
+
+    // This could fail if the device returns an unreasonably long body length.
     let data_end = size_of::<VirtioVsockHdr>()
         .checked_add(body_length)
         .ok_or(SocketError::InvalidNumber)?;
+    // This could fail if the device returns a body length longer than the buffer we gave it.
     let data = buffer
         .get(size_of::<VirtioVsockHdr>()..data_end)
         .ok_or(SocketError::BufferTooShort)?;
-    body.get_mut(0..body_length)
-        .ok_or(SocketError::OutputBufferTooShort(body_length))?
-        .copy_from_slice(data);
-    Ok(header)
+    Ok((header, data))
 }
 
 #[cfg(test)]
 mod tests {
     use super::*;
     use crate::{
-        device::socket::protocol::SocketType,
         hal::fake::FakeHal,
         transport::{
             fake::{FakeTransport, QueueStatus, State},
@@ -524,7 +509,7 @@ mod tests {
     };
     use alloc::{sync::Arc, vec};
     use core::ptr::NonNull;
-    use std::{sync::Mutex, thread};
+    use std::sync::Mutex;
 
     #[test]
     fn config() {
@@ -554,244 +539,4 @@ mod tests {
             VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
         assert_eq!(socket.guest_cid(), 0x00_0000_0042);
     }
-
-    #[test]
-    fn send_recv() {
-        let host_cid = 2;
-        let guest_cid = 66;
-        let host_port = 1234;
-        let guest_port = 4321;
-        let host_address = VsockAddr {
-            cid: host_cid,
-            port: host_port,
-        };
-        let hello_from_guest = "Hello from guest";
-        let hello_from_host = "Hello from host";
-
-        let mut config_space = VirtioVsockConfig {
-            guest_cid_low: ReadOnly::new(66),
-            guest_cid_high: ReadOnly::new(0),
-        };
-        let state = Arc::new(Mutex::new(State {
-            status: DeviceStatus::empty(),
-            driver_features: 0,
-            guest_page_size: 0,
-            interrupt_pending: false,
-            queues: vec![
-                QueueStatus::default(),
-                QueueStatus::default(),
-                QueueStatus::default(),
-            ],
-        }));
-        let transport = FakeTransport {
-            device_type: DeviceType::Socket,
-            max_queue_size: 32,
-            device_features: 0,
-            config_space: NonNull::from(&mut config_space),
-            state: state.clone(),
-        };
-        let mut socket =
-            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
-
-        // Start a thread to simulate the device.
-        let handle = thread::spawn(move || {
-            // Wait for connection request.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            assert_eq!(
-                VirtioVsockHdr::read_from(
-                    state
-                        .lock()
-                        .unwrap()
-                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
-                        .as_slice()
-                )
-                .unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::Request.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 0.into(),
-                    fwd_cnt: 0.into(),
-                }
-            );
-
-            // Accept connection and give the peer enough credit to send the message.
-            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
-                RX_QUEUE_IDX,
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::Response.into(),
-                    src_cid: host_cid.into(),
-                    dst_cid: guest_cid.into(),
-                    src_port: host_port.into(),
-                    dst_port: guest_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 50.into(),
-                    fwd_cnt: 0.into(),
-                }
-                .as_bytes(),
-            );
-
-            // Expect a credit update.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            assert_eq!(
-                VirtioVsockHdr::read_from(
-                    state
-                        .lock()
-                        .unwrap()
-                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
-                        .as_slice()
-                )
-                .unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::CreditUpdate.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 0.into(),
-                    fwd_cnt: 0.into(),
-                }
-            );
-
-            // Expect the guest to send some data.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            let request = state
-                .lock()
-                .unwrap()
-                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
-            assert_eq!(
-                request.len(),
-                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
-            );
-            assert_eq!(
-                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::Rw.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: (hello_from_guest.len() as u32).into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 0.into(),
-                    fwd_cnt: 0.into(),
-                }
-            );
-            assert_eq!(
-                &request[size_of::<VirtioVsockHdr>()..],
-                hello_from_guest.as_bytes()
-            );
-
-            // Send a response.
-            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
-            VirtioVsockHdr {
-                op: VirtioVsockOp::Rw.into(),
-                src_cid: host_cid.into(),
-                dst_cid: guest_cid.into(),
-                src_port: host_port.into(),
-                dst_port: guest_port.into(),
-                len: (hello_from_host.len() as u32).into(),
-                socket_type: SocketType::Stream.into(),
-                flags: 0.into(),
-                buf_alloc: 50.into(),
-                fwd_cnt: (hello_from_guest.len() as u32).into(),
-            }
-            .write_to_prefix(response.as_mut_slice());
-            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
-            state
-                .lock()
-                .unwrap()
-                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
-
-            // Expect a credit update.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            assert_eq!(
-                VirtioVsockHdr::read_from(
-                    state
-                        .lock()
-                        .unwrap()
-                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
-                        .as_slice()
-                )
-                .unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::CreditUpdate.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 64.into(),
-                    fwd_cnt: 0.into(),
-                }
-            );
-
-            // Expect a shutdown.
-            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
-            assert_eq!(
-                VirtioVsockHdr::read_from(
-                    state
-                        .lock()
-                        .unwrap()
-                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
-                        .as_slice()
-                )
-                .unwrap(),
-                VirtioVsockHdr {
-                    op: VirtioVsockOp::Shutdown.into(),
-                    src_cid: guest_cid.into(),
-                    dst_cid: host_cid.into(),
-                    src_port: guest_port.into(),
-                    dst_port: host_port.into(),
-                    len: 0.into(),
-                    socket_type: SocketType::Stream.into(),
-                    flags: 0.into(),
-                    buf_alloc: 0.into(),
-                    fwd_cnt: (hello_from_host.len() as u32).into(),
-                }
-            );
-        });
-
-        socket.connect(host_address, guest_port).unwrap();
-        socket.wait_for_connect().unwrap();
-        socket.send(hello_from_guest.as_bytes()).unwrap();
-        let mut buffer = [0u8; 64];
-        let event = socket.wait_for_recv(&mut buffer).unwrap();
-        assert_eq!(
-            event,
-            VsockEvent {
-                source: VsockAddr {
-                    cid: host_cid,
-                    port: host_port,
-                },
-                destination: VsockAddr {
-                    cid: guest_cid,
-                    port: guest_port,
-                },
-                event_type: VsockEventType::Received {
-                    length: hello_from_host.len()
-                }
-            }
-        );
-        assert_eq!(
-            &buffer[0..hello_from_host.len()],
-            hello_from_host.as_bytes()
-        );
-        socket.shutdown().unwrap();
-
-        handle.join().unwrap();
-    }
 }