Quellcode durchsuchen

[vsock] Implement credit request/shutdown/send/recv (#71)

* [vsock] Implement credit request and shutdown

* [vsock] Implement send and recv

recv works; send doesn't work yet.

* Update send()

* Send works OK

* Add vsock_server README.md

* Adjust shutdown()

* Add wait_one()

* Fix clippy

* Adjust error in recv/send

* Change pop to poll_and_filter

* Adjust vsock_server

* Remove cargo clean

* Minor adjustment

* Post the local buffer space to host when connecting

* Fix cargo fmt

* Do not check local buffer when transmitting data

* Make Addr derive Eq, PartialEq

* Don't expose request_credit.

* Fix flow control.

Connection RX buffers are not the same as VirtIO RX queue buffers. We
also need to keep track of the bytes we have sent and received.

* Only request credit if we don't already have enough.

If there's still no space then return an error.

* ConnectionInfo shouldn't be Copy.

* Skip packets which don't match our current connection.

* Add helper method to make header for connection.

* Recycle rx buffer

* Remove rx_buf size from new() as recv takes additional buffer

* Fix clippy suggestion

* Recycle the rx_buffer even if the packet doesn't match

* Don't keep references to buffers from the RX queue around.

This lets us recycle RX buffers soundly in pop_packet_from_rx_queue, and
simplifies lifetimes in other places too.

* Simplify error handling.

* Use debug rather than trace for warning about skipping.

* TODO

---------

Co-authored-by: Alice Wang <[email protected]>
Co-authored-by: Andrew Walbran <[email protected]>
Alice Wang vor 1 Jahr
Ursprung
Commit
8a19fdffc2

+ 15 - 0
examples/aarch64/Makefile

@@ -56,6 +56,21 @@ header: kernel
 clean:
 	cargo clean
 
+# This target is used to test the vsock driver manually. See vsock_server/README.md
+# for more information.
+qemu-vsock: $(kernel_qemu_bin) $(img)
+	qemu-system-aarch64 \
+	  $(QEMU_ARGS) \
+		-machine virt \
+		-cpu max \
+		-serial chardev:char0 \
+		-kernel $(kernel_qemu_bin) \
+		-global virtio-mmio.force-legacy=false \
+		-nic none \
+		-drive file=$(img),if=none,format=raw,id=x0 \
+		-device vhost-vsock-device,id=virtiosocket0,guest-cid=102 \
+		-chardev stdio,id=char0,mux=on
+
 qemu: $(kernel_qemu_bin) $(img) $(vsock_server_bin)
 	$(vsock_server_bin) &
 	qemu-system-aarch64 \

+ 26 - 7
examples/aarch64/src/main.rs

@@ -116,7 +116,10 @@ fn virtio_device(transport: impl Transport) {
         DeviceType::GPU => virtio_gpu(transport),
         // DeviceType::Network => virtio_net(transport), // currently is unsupported without alloc
         DeviceType::Console => virtio_console(transport),
-        DeviceType::Socket => virtio_socket(transport),
+        DeviceType::Socket => match virtio_socket(transport) {
+            Ok(()) => info!("virtio-socket test finished successfully"),
+            Err(e) => error!("virtio-socket test finished with error '{e:?}'"),
+        },
         t => warn!("Unrecognized virtio device: {:?}", t),
     }
 }
@@ -179,18 +182,34 @@ fn virtio_console<T: Transport>(transport: T) {
     info!("virtio-console test finished");
 }
 
-fn virtio_socket<T: Transport>(transport: T) {
+fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
     let mut socket =
         VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver");
     let host_cid = 2;
     let port = 1221;
     info!("Connecting to host on port {port}...");
-    if let Err(e) = socket.connect(host_cid, port, port) {
-        error!("Failed to connect to host: {:?}", e);
-    } else {
-        info!("Connected to host on port {port} successfully.")
+    socket.connect(host_cid, port, port)?;
+    info!("Connected to the host");
+
+    const EXCHANGE_NUM: usize = 2;
+    let messages = ["0-Ack. Hello from guest.", "1-Ack. Received again."];
+    for k in 0..EXCHANGE_NUM {
+        let mut buffer = [0u8; 24];
+        let len = socket.recv(&mut buffer)?;
+        info!(
+            "Received message: {:?}({:?}), len: {:?}",
+            buffer,
+            core::str::from_utf8(&buffer[..len]),
+            len
+        );
+
+        let message = messages[k % messages.len()];
+        socket.send(message.as_bytes())?;
+        info!("Sent message: {:?}", message);
     }
-    info!("VirtIO socket test finished");
+    socket.shutdown()?;
+    info!("Shutdown the connection");
+    Ok(())
 }
 
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]

+ 48 - 0
examples/vsock_server/README.md

@@ -0,0 +1,48 @@
+# Running Virtio Vsock Example
+
+The binary `vsock_server` sets up a vsock server on the host. It can be used to run the virtio vsock example in `examples/aarch64`.
+
+## Build and Run the Example
+
+Run the server on the host:
+```bash
+examples/vsock_server$ cargo run
+```
+
+Run the guest:
+```bash
+examples/aarch64$ make qemu-vsock
+```
+
+## Sample Log
+
+The example demonstrates two rounds of message exchange between the host and the guest.
+
+Host:
+```
+[Host] Setting up listening socket on port 1221
+[Host] Accept connection: VsockStream { socket: 4 }, peer addr: Ok(cid: 102 port: 1221), local addr: Ok(cid: 2 port: 1221)
+[Host] Sent message: "0-Hello from host".
+[Host] Flushed.
+[Host] Received message: [48, 45, 65, 99, 107, 46, 32, 72, 101, 108, 108, 111, 32, 102, 114, 111, 109, 32, 103, 117, 101, 115, 116, 46, 0, 0, 0, 0, 0, 0](Ok("0-Ack. Hello from guest.")), len: 24
+[Host] Sent message: "1-Hello from host".
+[Host] Flushed.
+[Host] Received message: [49, 45, 65, 99, 107, 46, 32, 82, 101, 99, 101, 105, 118, 101, 100, 32, 97, 103, 97, 105, 110, 46, 0, 0, 0, 0, 0, 0, 0, 0](Ok("1-Ack. Received again.")), len: 22
+[Host] End.
+```
+
+Guest:
+```
+[INFO] guest cid: 102
+[INFO] Connecting to host on port 1221...
+[DEBUG] Connection established: Some(ConnectionInfo { dst: VsockAddr { cid: 2, port: 1221 }, src_port: 1221, peer_buf_alloc: 0, peer_fwd_cnt: 0 })
+[INFO] Connected to the host
+...
+[INFO] Received message: [48, 45, 72, 101, 108, 108, 111, 32, 102, 114, 111, 109, 32, 104, 111, 115, 116, 0, 0, 0, 0, 0, 0, 0](Ok("0-Hello from host")), len: 17
+[DEBUG] Connection info updated: Some(ConnectionInfo { dst: VsockAddr { cid: 2, port: 1221 }, src_port: 1221, peer_buf_alloc: 262144, peer_fwd_cnt: 0 })
+[INFO] Sent message: "0-Ack. Hello from guest."
+[INFO] Received message: [49, 45, 72, 101, 108, 108, 111, 32, 102, 114, 111, 109, 32, 104, 111, 115, 116, 0, 0, 0, 0, 0, 0, 0](Ok("1-Hello from host")), len: 17
+[INFO] Sent message: "1-Ack. Received again."
+[INFO] Disconnected from the peer
+[INFO] Shutdown the connection
+```

+ 50 - 4
examples/vsock_server/src/main.rs

@@ -1,13 +1,59 @@
-// Sets a listening socket on host.
+//! Sets up a listening socket on host.
+use std::{
+    io::{Read, Write},
+    time::Duration,
+};
 use vsock::{VsockAddr, VsockListener, VMADDR_CID_HOST};
 
 const PORT: u32 = 1221;
 
 fn main() {
-    println!("Setting up listening socket on port {PORT}");
+    println!("[Host] Setting up listening socket on port {PORT}");
     let listener = VsockListener::bind(&VsockAddr::new(VMADDR_CID_HOST, PORT))
         .expect("Failed to set up listening port");
-    for incoming in listener.incoming() {
-        println!("Accept connection: {incoming:?}");
+
+    let Some(Ok(mut vsock_stream)) = listener.incoming().next() else {
+        println!("[Host] Failed to get vsock_stream");
+        return;
+    };
+    println!(
+        "[Host] Accept connection: {:?}, peer addr: {:?}, local addr: {:?}",
+        vsock_stream,
+        vsock_stream.peer_addr(),
+        vsock_stream.local_addr()
+    );
+
+    const EXCHANGE_NUM: usize = 2;
+    for k in 0..EXCHANGE_NUM {
+        let message = &format!("{k}-Hello from host");
+        vsock_stream
+            .write_all(message.as_bytes())
+            .expect("write_all");
+        println!("[Host] Sent message: {:?}.", message);
+        vsock_stream.flush().expect("flush");
+        println!("[Host] Flushed.");
+
+        let mut message = vec![0u8; 30];
+        vsock_stream
+            .set_read_timeout(Some(Duration::from_millis(3_000)))
+            .expect("set_read_timeout");
+        for i in 0..10 {
+            match vsock_stream.read(&mut message) {
+                Ok(len) => {
+                    println!(
+                        "[Host] Received message: {:?}({:?}), len: {:?}",
+                        message,
+                        std::str::from_utf8(&message[..len]),
+                        len,
+                    );
+                    break;
+                }
+                Err(e) => {
+                    println!("{i} {e:?}");
+                    std::thread::sleep(Duration::from_millis(200))
+                }
+            }
+        }
     }
+    println!("[Host] End.");
 }

+ 36 - 1
src/device/socket/error.rs

@@ -5,28 +5,63 @@ use core::{fmt, result};
 /// The error type of VirtIO socket driver.
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum SocketError {
+    /// There is an existing connection.
+    ConnectionExists,
     /// Failed to establish the connection.
     ConnectionFailed,
+    /// The device is not connected to any peer.
+    NotConnected,
+    /// Peer socket is shut down.
+    PeerSocketShutdown,
     /// No response received.
     NoResponseReceived,
     /// The given buffer is shorter than expected.
     BufferTooShort,
+    /// The given buffer for output is shorter than expected.
+    OutputBufferTooShort(usize),
+    /// The given buffer has exceeded the maximum buffer size.
+    BufferTooLong(usize, usize),
     /// Unknown operation.
     UnknownOperation(u16),
     /// Invalid operation.
     InvalidOperation,
+    /// Invalid number.
+    InvalidNumber,
+    /// Unexpected data in packet.
+    UnexpectedDataInPacket,
+    /// Peer has insufficient buffer space, try again later.
+    InsufficientBufferSpaceInPeer,
+    /// Recycled a wrong buffer.
+    RecycledWrongBuffer,
 }
 
 impl fmt::Display for SocketError {
     fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
         match self {
-            Self::ConnectionFailed => write!(f, "Failed to establish the connection"),
+            Self::ConnectionExists => write!(
+                f,
+                "There is an existing connection. Please close the current connection before attempting to connect again."),
+            Self::ConnectionFailed => write!(
+                f, "Failed to establish the connection. The packet sent may have an unknown type value"
+            ),
+            Self::NotConnected => write!(f, "The device is not connected to any peer. Please connect it to a peer first."),
+            Self::PeerSocketShutdown => write!(f, "The peer socket is shutdown."),
             Self::NoResponseReceived => write!(f, "No response received"),
             Self::BufferTooShort => write!(f, "The given buffer is shorter than expected"),
+            Self::BufferTooLong(actual, max) => {
+                write!(f, "The given buffer length '{actual}' has exceeded the maximum allowed buffer length '{max}'")
+            }
+            Self::OutputBufferTooShort(expected) => {
+                write!(f, "The given output buffer is too short. '{expected}' bytes is needed for the output buffer.")
+            }
             Self::UnknownOperation(op) => {
                 write!(f, "The operation code '{op}' is unknown")
             }
             Self::InvalidOperation => write!(f, "Invalid operation"),
+            Self::InvalidNumber => write!(f, "Invalid number"),
+            Self::UnexpectedDataInPacket => write!(f, "No data is expected in the packet"),
+            Self::InsufficientBufferSpaceInPeer => write!(f, "Peer has insufficient buffer space, try again later"),
+            Self::RecycledWrongBuffer => write!(f, "Recycled a wrong buffer"),
         }
     }
 }

+ 38 - 16
src/device/socket/protocol.rs

@@ -5,7 +5,6 @@ use crate::volatile::ReadOnly;
 use core::{
     convert::{TryFrom, TryInto},
     fmt,
-    mem::size_of,
 };
 use zerocopy::{
     byteorder::{LittleEndian, U16, U32, U64},
@@ -51,7 +50,9 @@ pub struct VirtioVsockHdr {
     pub socket_type: U16<LittleEndian>,
     pub op: U16<LittleEndian>,
     pub flags: U32<LittleEndian>,
+    /// Total receive buffer space for this socket. This includes both free and in-use buffers.
     pub buf_alloc: U32<LittleEndian>,
+    /// Free-running bytes received counter.
     pub fwd_cnt: U32<LittleEndian>,
 }
 
@@ -72,25 +73,46 @@ impl Default for VirtioVsockHdr {
     }
 }
 
-#[derive(Clone, Debug)]
-pub struct VirtioVsockPacket<'a> {
-    pub hdr: VirtioVsockHdr,
-    pub data: &'a [u8],
-}
-
-impl<'a> VirtioVsockPacket<'a> {
-    pub fn read_from(buffer: &'a [u8]) -> error::Result<Self> {
-        let hdr = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
-        let data_end = size_of::<VirtioVsockHdr>() + (hdr.len.get() as usize);
-        let data = buffer
-            .get(size_of::<VirtioVsockHdr>()..data_end)
-            .ok_or(SocketError::BufferTooShort)?;
-        Ok(Self { hdr, data })
+impl VirtioVsockHdr {
+    /// Returns the length of the data.
+    pub fn len(&self) -> u32 {
+        u32::from(self.len)
     }
 
     pub fn op(&self) -> error::Result<VirtioVsockOp> {
-        self.hdr.op.try_into()
+        self.op.try_into()
     }
+
+    pub fn source(&self) -> VsockAddr {
+        VsockAddr {
+            cid: self.src_cid.get(),
+            port: self.src_port.get(),
+        }
+    }
+
+    pub fn destination(&self) -> VsockAddr {
+        VsockAddr {
+            cid: self.dst_cid.get(),
+            port: self.dst_port.get(),
+        }
+    }
+
+    pub fn check_data_is_empty(&self) -> error::Result<()> {
+        if self.len() == 0 {
+            Ok(())
+        } else {
+            Err(SocketError::UnexpectedDataInPacket)
+        }
+    }
+}
+
+/// Socket address.
+#[derive(Copy, Clone, Debug, Default, PartialEq, Eq)]
+pub struct VsockAddr {
+    /// Context Identifier.
+    pub cid: u64,
+    /// Port number.
+    pub port: u32,
 }
 
 /// An event sent to the event queue

+ 385 - 37
src/device/socket/vsock.rs

@@ -1,24 +1,60 @@
 //! Driver for VirtIO socket devices.
+#![deny(unsafe_op_in_unsafe_fn)]
 
 use super::error::SocketError;
-use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VirtioVsockPacket};
+use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
 use crate::device::common::Feature;
 use crate::hal::{BufferDirection, Dma, Hal};
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::volread;
 use crate::Result;
-use log::{info, trace};
-use zerocopy::AsBytes;
+use core::ptr::NonNull;
+use core::{convert::TryFrom, mem::size_of};
+use log::{debug, info};
+use zerocopy::{AsBytes, FromBytes};
 
 const RX_QUEUE_IDX: u16 = 0;
 const TX_QUEUE_IDX: u16 = 1;
 const EVENT_QUEUE_IDX: u16 = 2;
 
-const QUEUE_SIZE: usize = 2;
+const QUEUE_SIZE: usize = 8;
+
+/// State kept for the single active connection: the peer address, our source port and the byte
+/// counters used for credit-based flow control.
+#[derive(Clone, Debug, Default, PartialEq, Eq)]
+struct ConnectionInfo {
+    /// Address (CID and port) of the peer; used as the destination of packets we send.
+    dst: VsockAddr,
+    /// Local port, used as the source port of packets we send on this connection.
+    src_port: u32,
+    /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
+    /// bytes it has allocated for packet bodies.
+    peer_buf_alloc: u32,
+    /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
+    /// has finished processing.
+    peer_fwd_cnt: u32,
+    /// The number of bytes of packet bodies which we have sent to the peer.
+    tx_cnt: u32,
+    /// The number of bytes of packet bodies which we have received from the peer and handled.
+    fwd_cnt: u32,
+}
+
+impl ConnectionInfo {
+    /// Returns how many more bytes of packet bodies the peer can currently accept, according to
+    /// the last credit information it sent us.
+    fn peer_free(&self) -> u32 {
+        // `tx_cnt` and `peer_fwd_cnt` are free-running counters which are expected to wrap, so the
+        // number of bytes still in flight must be computed with wrapping arithmetic; a plain `-`
+        // would panic in debug builds once `tx_cnt` has wrapped past `peer_fwd_cnt`.
+        let in_flight = self.tx_cnt.wrapping_sub(self.peer_fwd_cnt);
+        // If the peer has shrunk its buffer below what is already in flight there is simply no
+        // free space, rather than an arithmetic underflow.
+        self.peer_buf_alloc.saturating_sub(in_flight)
+    }
+
+    /// Builds a packet header addressed to the peer of this connection.
+    ///
+    /// `src_cid` is our own CID. The header carries our current `fwd_cnt`; every other field is
+    /// left at its default so that callers can fill in `op`, `len` and `buf_alloc` as needed.
+    fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
+        VirtioVsockHdr {
+            src_cid: src_cid.into(),
+            dst_cid: self.dst.cid.into(),
+            src_port: self.src_port.into(),
+            dst_port: self.dst.port.into(),
+            fwd_cnt: self.fwd_cnt.into(),
+            ..Default::default()
+        }
+    }
+}
 
 /// Driver for a VirtIO socket device.
-pub struct VirtIOSocket<'a, H: Hal, T: Transport> {
+pub struct VirtIOSocket<H: Hal, T: Transport> {
     transport: T,
     /// Virtqueue to receive packets.
     rx: VirtQueue<H, { QUEUE_SIZE }>,
@@ -28,11 +64,13 @@ pub struct VirtIOSocket<'a, H: Hal, T: Transport> {
     /// The guest_cid field contains the guest’s context ID, which uniquely identifies
     /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
     guest_cid: u64,
-    queue_buf_dma: Dma<H>,
-    queue_buf_rx: &'a mut [u8],
+    rx_buf_dma: Dma<H>,
+
+    /// Currently the device is only allowed to be connected to one destination at a time.
+    connection_info: Option<ConnectionInfo>,
 }
 
-impl<'a, H: Hal, T: Transport> Drop for VirtIOSocket<'a, H, T> {
+impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
     fn drop(&mut self) {
         // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
         // after they have been freed.
@@ -42,7 +80,7 @@ impl<'a, H: Hal, T: Transport> Drop for VirtIOSocket<'a, H, T> {
     }
 }
 
-impl<'a, H: Hal, T: Transport> VirtIOSocket<'a, H, T> {
+impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     /// Create a new VirtIO Vsock driver.
     pub fn new(mut transport: T) -> Result<Self> {
         transport.begin_init(|features| {
@@ -65,18 +103,15 @@ impl<'a, H: Hal, T: Transport> VirtIOSocket<'a, H, T> {
         let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
         let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
 
-        let queue_buf_dma = Dma::new(1, BufferDirection::DeviceToDriver)?;
-
-        // Safe because no alignment or initialisation is required for [u8], the DMA buffer is
-        // dereferenceable, and the lifetime of the reference matches the lifetime of the DMA buffer
-        // (which we don't otherwise access).
-        let queue_buf_rx = unsafe { queue_buf_dma.raw_slice().as_mut() };
-
-        // Safe because the buffer lives as long as the queue.
-        let _token = unsafe { rx.add(&[], &mut [queue_buf_rx])? };
-
-        if rx.should_notify() {
-            transport.notify(RX_QUEUE_IDX);
+        // Allocates 4 KiB memory as the rx buffer.
+        let rx_buf_dma = Dma::new(
+            1, // pages
+            BufferDirection::DeviceToDriver,
+        )?;
+        let rx_buf = rx_buf_dma.raw_slice();
+        // Safe because `rx_buf` lives as long as the `rx` queue.
+        unsafe {
+            Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?;
         }
         transport.finish_init();
 
@@ -86,13 +121,46 @@ impl<'a, H: Hal, T: Transport> VirtIOSocket<'a, H, T> {
             tx,
             event,
             guest_cid,
-            queue_buf_dma,
-            queue_buf_rx,
+            rx_buf_dma,
+            connection_info: None,
         })
     }
 
-    /// Connect to the destination.
+    /// Adds one sub-buffer of `rx_buf` to the `rx` queue for each of its `QUEUE_SIZE` slots, then
+    /// notifies the device if the queue asks for it.
+    ///
+    /// Returns `SocketError::BufferTooShort` if `rx_buf` cannot hold at least one packet header
+    /// per queue slot.
+    ///
+    /// # Safety
+    ///
+    /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are
+    /// in the queue must not be used anywhere else at the same time.
+    unsafe fn fill_rx_queue(
+        rx: &mut VirtQueue<H, { QUEUE_SIZE }>,
+        rx_buf: NonNull<[u8]>,
+        transport: &mut T,
+    ) -> Result {
+        let minimum_len = size_of::<VirtioVsockHdr>() * QUEUE_SIZE;
+        if rx_buf.len() < minimum_len {
+            return Err(SocketError::BufferTooShort.into());
+        }
+
+        for index in 0..QUEUE_SIZE {
+            // SAFETY: Our caller guarantees that the buffer lives as long as the queue, and we
+            // don't access the sub-buffer again until it is popped.
+            let token = unsafe {
+                let sub_buffer = Self::as_mut_sub_rx_buffer(rx_buf, index)?;
+                rx.add(&[], &mut [sub_buffer])?
+            };
+            // The rest of the driver relies on slot `index` being identified by token `index`.
+            assert_eq!(index, token.into());
+        }
+
+        if rx.should_notify() {
+            transport.notify(RX_QUEUE_IDX);
+        }
+        Ok(())
+    }
+
+    /// Connects to the destination.
     pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
+        if self.connection_info.is_some() {
+            return Err(SocketError::ConnectionExists.into());
+        }
         let header = VirtioVsockHdr {
             src_cid: self.guest_cid.into(),
             dst_cid: dst_cid.into(),
@@ -101,26 +169,306 @@ impl<'a, H: Hal, T: Transport> VirtIOSocket<'a, H, T> {
             op: VirtioVsockOp::Request.into(),
             ..Default::default()
         };
-        self.tx
-            .add_notify_wait_pop(&[header.as_bytes(), &[]], &mut [], &mut self.transport)?;
-        let token = if let Some(token) = self.rx.peek_used() {
-            token // TODO: Use let else after updating Rust
+        // Sends a header only packet to the tx queue to connect the device to the listening
+        // socket at the given destination.
+        self.send_packet_to_tx_queue(&header, &[])?;
+
+        let dst = VsockAddr {
+            cid: dst_cid,
+            port: dst_port,
+        };
+        self.connection_info.replace(ConnectionInfo {
+            dst,
+            src_port,
+            ..Default::default()
+        });
+        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Response], &mut [], |header| {
+            header.check_data_is_empty().map_err(|e| e.into())
+        })?;
+        debug!("Connection established: {:?}", self.connection_info);
+        Ok(())
+    }
+
+    /// Requests the credit and updates the peer credit in the current connection info.
+    ///
+    /// Sends a `CreditRequest` packet and then waits for the peer's `CreditUpdate`; the cached
+    /// peer credit is refreshed while that reply is being processed.
+    ///
+    /// Returns `SocketError::NotConnected` if there is no current connection.
+    fn request_credit(&mut self) -> Result {
+        let connection_info = self.connection_info()?;
+        // Build the header with the shared helper, like the other packet types do, so that the
+        // request also carries our current `fwd_cnt` instead of resetting it to zero.
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::CreditRequest.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, &[])?;
+        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::CreditUpdate], &mut [], |_| {
+            Ok(())
+        })
+    }
+
+    /// Sends the buffer to the destination.
+    ///
+    /// The whole buffer is sent as the body of a single `Rw` packet.
+    ///
+    /// # Errors
+    ///
+    /// Returns `SocketError::NotConnected` if there is no current connection,
+    /// `SocketError::InsufficientBufferSpaceInPeer` if the peer does not have room for the whole
+    /// buffer even after its credit has been refreshed, and `SocketError::BufferTooLong` if the
+    /// buffer length does not fit in the packet header.
+    pub fn send(&mut self, buffer: &[u8]) -> Result {
+        self.check_peer_buffer_is_sufficient(buffer.len())?;
+
+        let connection_info = self.connection_info()?;
+        // The header stores the body length as a `u32`; fail explicitly rather than silently
+        // truncating the length with an `as` cast.
+        let len = u32::try_from(buffer.len())
+            .map_err(|_| SocketError::BufferTooLong(buffer.len(), u32::MAX as usize))?;
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Rw.into(),
+            len: len.into(),
+            buf_alloc: 0.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, buffer)?;
+
+        // Only count the bytes once the packet has actually been handed to the device, so that a
+        // failed send doesn't corrupt the credit accounting. `tx_cnt` is a free-running counter,
+        // so it is allowed to wrap.
+        if let Some(info) = self.connection_info.as_mut() {
+            info.tx_cnt = info.tx_cnt.wrapping_add(len);
+        }
+        Ok(())
+    }
+
+    /// Checks that the peer has enough free receive buffer space for `buffer_len` bytes.
+    ///
+    /// The cached credit is consulted first. Only if that is not enough is a credit request sent
+    /// to the peer, after which the check is repeated once.
+    ///
+    /// Returns `SocketError::InsufficientBufferSpaceInPeer` if there is still not enough space
+    /// after the credit request, and `SocketError::NotConnected` if there is no connection.
+    fn check_peer_buffer_is_sufficient(&mut self, buffer_len: usize) -> Result {
+        if usize::try_from(self.connection_info()?.peer_free())
+            .map_err(|_| SocketError::InvalidNumber)?
+            >= buffer_len
+        {
+            Ok(())
         } else {
+            // Update cached peer credit and try again.
+            self.request_credit()?;
+            if usize::try_from(self.connection_info()?.peer_free())
+                .map_err(|_| SocketError::InvalidNumber)?
+                >= buffer_len
+            {
+                Ok(())
+            } else {
+                Err(SocketError::InsufficientBufferSpaceInPeer.into())
+            }
+        }
+    }
+
+    /// Receives the buffer from the destination.
+    /// Returns the actual size of the message.
+    ///
+    /// First advertises `buffer.len()` bytes of receive space to the peer with a `CreditUpdate`
+    /// packet, then waits for an `Rw` packet and copies its body into `buffer`.
+    ///
+    /// # Errors
+    ///
+    /// Returns `SocketError::NotConnected` if there is no current connection, and propagates any
+    /// error hit while waiting for the data packet.
+    pub fn recv(&mut self, buffer: &mut [u8]) -> Result<usize> {
+        let connection_info = self.connection_info()?;
+
+        // Tell the peer that we have space to receive some data. `buf_alloc` is a `u32`, so a
+        // larger buffer is advertised as `u32::MAX` rather than being truncated by an `as` cast.
+        let buf_alloc = u32::try_from(buffer.len()).unwrap_or(u32::MAX);
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::CreditUpdate.into(),
+            buf_alloc: buf_alloc.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, &[])?;
+
+        // Wait to receive some data.
+        let mut len: u32 = 0;
+        self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Rw], buffer, |header| {
+            len = header.len();
+            Ok(())
+        })?;
+
+        // `fwd_cnt` is a free-running counter, so it is allowed to wrap.
+        if let Some(info) = self.connection_info.as_mut() {
+            info.fwd_cnt = info.fwd_cnt.wrapping_add(len);
+        }
+        Ok(len as usize)
+    }
+
+    /// Shuts down the connection.
+    ///
+    /// Sends a `Shutdown` packet to the peer and waits until it answers with either `Rst` or
+    /// `Shutdown`. Returns `SocketError::NotConnected` if there is no current connection.
+    pub fn shutdown(&mut self) -> Result {
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Shutdown.into(),
+            ..self.connection_info()?.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, &[])?;
+
+        // Either reply from the peer means that the connection is now closed.
+        let closing_ops = [VirtioVsockOp::Rst, VirtioVsockOp::Shutdown];
+        self.poll_and_filter_packet_from_rx_queue(&closing_ops, &mut [], |_| Ok(()))
+    }
+
+    /// Queues `header` followed by `buffer` as one packet on the TX queue, notifies the device
+    /// and waits for the packet to be popped again.
+    fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
+        // TODO: Virtio v1.1 5.10.6.1.1 The rx virtqueue MUST be processed even when the tx virtqueue is full.
+        let packet = [header.as_bytes(), buffer];
+        self.tx
+            .add_notify_wait_pop(&packet, &mut [], &mut self.transport)
+            .map(|_used_len| ())
+    }
+
+    /// Pops packets from the RX queue until one settles the outcome, then returns that outcome.
+    ///
+    /// For each packet:
+    /// * Packets whose source or destination doesn't match the current connection are skipped.
+    /// * `CreditUpdate` always refreshes the cached peer credit; it only ends the loop if it is
+    ///   one of `accepted_ops`.
+    /// * `Rst` and `Shutdown` always clear the current connection and end the loop; this is a
+    ///   success only if the operation is one of `accepted_ops`, otherwise
+    ///   `SocketError::ConnectionFailed` or `SocketError::PeerSocketShutdown` is returned.
+    /// * Any other operation in `accepted_ops` is passed to `f` and ends the loop.
+    /// * Anything else ends the loop with `SocketError::InvalidOperation`.
+    ///
+    /// The body of every popped packet is copied into `body`, and `f` is called at most once.
+    ///
+    /// NOTE(review): errors propagated with `?` inside the loop return before the RX queue
+    /// notification at the end of this function is reached -- confirm whether that is intended.
+    fn poll_and_filter_packet_from_rx_queue<F>(
+        &mut self,
+        accepted_ops: &[VirtioVsockOp],
+        body: &mut [u8],
+        f: F,
+    ) -> Result
+    where
+        F: FnOnce(&VirtioVsockHdr) -> Result,
+    {
+        let our_cid = self.guest_cid;
+        let mut result = Ok(());
+        loop {
+            // The wait gives up silently after a bounded number of spins; in that case the pop
+            // below fails with `NoResponseReceived`.
+            self.wait_one_in_rx_queue();
+            // Work on a copy of the connection state; it is only written back on `CreditUpdate`.
+            let mut connection_info = self.connection_info.clone().unwrap_or_default();
+            let header = self.pop_packet_from_rx_queue(body)?;
+            let op = header.op()?;
+
+            // Skip packets which don't match our current connection.
+            if header.source() != connection_info.dst
+                || header.dst_cid.get() != our_cid
+                || header.dst_port.get() != connection_info.src_port
+            {
+                debug!(
+                    "Skipping {:?} as connection is {:?}",
+                    header, connection_info
+                );
+                continue;
+            }
+
+            match op {
+                VirtioVsockOp::CreditUpdate => {
+                    header.check_data_is_empty()?;
+
+                    connection_info.peer_buf_alloc = header.buf_alloc.into();
+                    connection_info.peer_fwd_cnt = header.fwd_cnt.into();
+                    self.connection_info.replace(connection_info);
+                    debug!("Connection info updated: {:?}", self.connection_info);
+
+                    if accepted_ops.contains(&op) {
+                        break;
+                    } else {
+                        // Virtio v1.1 5.10.6.3
+                        // The driver can also receive a VIRTIO_VSOCK_OP_CREDIT_UPDATE packet without previously
+                        // sending a VIRTIO_VSOCK_OP_CREDIT_REQUEST packet. This allows communicating updates
+                        // any time a change in buffer space occurs.
+                        continue;
+                    }
+                }
+                VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
+                    header.check_data_is_empty()?;
+
+                    self.connection_info.take();
+                    info!("Disconnected from the peer");
+                    if accepted_ops.contains(&op) {
+                        // The caller was waiting for the connection to end, so `result` stays
+                        // `Ok(())`.
+                    } else if op == VirtioVsockOp::Rst {
+                        result = Err(SocketError::ConnectionFailed.into());
+                    } else {
+                        assert_eq!(VirtioVsockOp::Shutdown, op);
+                        result = Err(SocketError::PeerSocketShutdown.into());
+                    }
+                    break;
+                }
+                // TODO: Update peer_buf_alloc and peer_fwd_cnt for other packets too.
+                x if accepted_ops.contains(&x) => {
+                    f(&header)?;
+                    break;
+                }
+                _ => {
+                    result = Err(SocketError::InvalidOperation.into());
+                    break;
+                }
+            };
+        }
+
+        // Popped buffers have been added back to the RX queue, so notify the device if it asks
+        // for it.
+        if self.rx.should_notify() {
+            self.transport.notify(RX_QUEUE_IDX);
+        }
+        result
+    }
+
+    /// Spins until the RX queue has at least one element to pop.
+    ///
+    /// Gives up silently after `TIMEOUT_N` checks, so the queue may still be empty on return.
+    fn wait_one_in_rx_queue(&mut self) {
+        const TIMEOUT_N: usize = 1000000;
+        let mut attempts = 0;
+        while attempts < TIMEOUT_N && !self.rx.can_pop() {
+            core::hint::spin_loop();
+            attempts += 1;
+        }
+    }
+
+    /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
+    /// the body into the given buffer.
+    ///
+    /// Returns an error if there is no pending packet, or the body is bigger than the buffer
+    /// supplied.
+    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<VirtioVsockHdr> {
+        let Some(token) = self.rx.peek_used() else {
             return Err(SocketError::NoResponseReceived.into());
         };
-        // Safe because we are passing the same buffer as we passed to `VirtQueue::add`.
-        let _len = unsafe { self.rx.pop_used(token, &[], &mut [self.queue_buf_rx])? };
-        let packet_rx = VirtioVsockPacket::read_from(self.queue_buf_rx)?;
-        trace!("Received packet {:?}. Op {:?}", packet_rx, packet_rx.op());
-        match packet_rx.op()? {
-            VirtioVsockOp::Response => Ok(()),
-            VirtioVsockOp::Rst => Err(SocketError::ConnectionFailed.into()),
-            VirtioVsockOp::Invalid => Err(SocketError::InvalidOperation.into()),
-            _ => todo!(),
+
+        // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
+        // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
+        // buffer back to the RX queue then we don't access it again until next time it is popped.
+        let header = unsafe {
+            let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf_dma.raw_slice(), token.into())?;
+            let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
+
+            // Read the header and body from the buffer. Don't check the result yet, because we need
+            // to add the buffer back to the queue either way.
+            let header_result = read_header_and_body(buffer, body);
+
+            // Add the buffer back to the RX queue.
+            let new_token = self.rx.add(&[], &mut [buffer])?;
+            // If the RX buffer somehow gets assigned a different token, then our safety assumptions
+            // are broken and we can't safely continue to do anything with the device.
+            assert_eq!(new_token, token);
+
+            header_result
+        }?;
+
+        debug!("Received packet {:?}. Op {:?}", header, header.op());
+        Ok(header)
+    }
+
+    /// Returns a copy of the current connection state, or `SocketError::NotConnected` if the
+    /// device is not connected to any peer.
+    fn connection_info(&self) -> Result<ConnectionInfo> {
+        match &self.connection_info {
+            Some(info) => Ok(info.clone()),
+            None => Err(SocketError::NotConnected.into()),
+        }
+    }
+
+    /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX
+    /// virtqueue.
+    ///
+    /// # Safety
+    ///
+    /// `rx_buf` must be a valid dereferenceable pointer.
+    /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with
+    /// any other references to the same subslice of the RX buffer or outlive the buffer.
+    unsafe fn as_mut_sub_rx_buffer<'a>(
+        mut rx_buf: NonNull<[u8]>,
+        i: usize,
+    ) -> Result<&'a mut [u8]> {
+        let buffer_size = rx_buf.len() / QUEUE_SIZE;
+        let start = buffer_size
+            .checked_mul(i)
+            .ok_or(SocketError::InvalidNumber)?;
+        let end = start
+            .checked_add(buffer_size)
+            .ok_or(SocketError::InvalidNumber)?;
+        // Safe because no alignment or initialisation is required for [u8], and our caller assures
+        // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating
+        // won't overlap with any other references to the same slice or outlive it.
+        unsafe {
+            rx_buf
+                .as_mut()
+                .get_mut(start..end)
+                .ok_or(SocketError::BufferTooShort.into())
         }
     }
 }
 
+/// Parses the packet header at the start of `buffer` and copies the packet body which follows it
+/// into the start of `body`.
+///
+/// Returns `SocketError::BufferTooShort` if `buffer` doesn't contain a whole header plus the
+/// number of body bytes the header announces, and `SocketError::OutputBufferTooShort` if `body`
+/// is too small to hold the packet body.
+fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
+    const HEADER_SIZE: usize = size_of::<VirtioVsockHdr>();
+
+    let Some(header) = VirtioVsockHdr::read_from_prefix(buffer) else {
+        return Err(SocketError::BufferTooShort.into());
+    };
+    let payload_len = header.len() as usize;
+    let payload_end = HEADER_SIZE
+        .checked_add(payload_len)
+        .ok_or(SocketError::InvalidNumber)?;
+    let payload = buffer
+        .get(HEADER_SIZE..payload_end)
+        .ok_or(SocketError::BufferTooShort)?;
+    let destination = body
+        .get_mut(..payload_len)
+        .ok_or(SocketError::OutputBufferTooShort(payload_len))?;
+    destination.copy_from_slice(payload);
+    Ok(header)
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;