Răsfoiți Sursa

Allocate RX buffers individually rather than splitting one big buffer.

Andrew Walbran 1 an în urmă
părinte
comite
9a5cb195c6
1 a modificat fișierele cu 27 adăugiri și 74 ștergeri
  1. 27 74
      src/device/socket/vsock.rs

+ 27 - 74
src/device/socket/vsock.rs

@@ -8,11 +8,11 @@ use crate::hal::Hal;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::volread;
-use crate::{Result, PAGE_SIZE};
+use crate::Result;
 use alloc::boxed::Box;
 use core::hint::spin_loop;
 use core::mem::size_of;
-use core::ptr::NonNull;
+use core::ptr::{null_mut, NonNull};
 use log::{debug, info};
 use zerocopy::{AsBytes, FromBytes};
 
@@ -22,6 +22,9 @@ const EVENT_QUEUE_IDX: u16 = 2;
 
 const QUEUE_SIZE: usize = 8;
 
+/// The size in bytes of each buffer used in the RX virtqueue.
+const RX_BUFFER_SIZE: usize = 512;
+
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
 struct ConnectionInfo {
     dst: VsockAddr,
@@ -109,7 +112,7 @@ pub struct VirtIOSocket<H: Hal, T: Transport> {
     /// The guest_cid field contains the guest’s context ID, which uniquely identifies
     /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
     guest_cid: u64,
-    rx_buf: NonNull<[u8; PAGE_SIZE]>,
+    rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
 
     /// Currently the device is only allowed to be connected to one destination at a time.
     connection_info: Option<ConnectionInfo>,
@@ -123,9 +126,11 @@ impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
         self.transport.queue_unset(TX_QUEUE_IDX);
         self.transport.queue_unset(EVENT_QUEUE_IDX);
 
-        // Safe because we obtained the rx_buf pointer from Box::into_raw, and it won't be used
-        // anywhere else after the driver is destroyed.
-        unsafe { drop(Box::from_raw(self.rx_buf.as_ptr())) };
+        for buffer in self.rx_queue_buffers {
+            // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
+            // used anywhere else after the driver is destroyed.
+            unsafe { drop(Box::from_raw(buffer.as_ptr())) };
+        }
     }
 }
 
@@ -152,14 +157,22 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
         let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
 
-        // Allocates 4 KiB memory for the RX buffer.
-        let rx_buf: NonNull<[u8; PAGE_SIZE]> =
-            NonNull::new(Box::into_raw(FromBytes::new_box_zeroed())).unwrap();
-        // Safe because `rx_buf` lives as long as the `rx` queue.
-        unsafe {
-            Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?;
+        // Allocate and add buffers for the RX queue.
+        let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
+        for i in 0..QUEUE_SIZE {
+            let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed();
+            // Safe because the buffer lives as long as the queue, as specified in the function
+            // safety requirement, and we don't access it until it is popped.
+            let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
+            assert_eq!(i, token.into());
+            rx_queue_buffers[i] = Box::into_raw(buffer);
         }
+        let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
+
         transport.finish_init();
+        if rx.should_notify() {
+            transport.notify(RX_QUEUE_IDX);
+        }
 
         Ok(Self {
             transport,
@@ -167,41 +180,11 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             tx,
             event,
             guest_cid,
-            rx_buf,
+            rx_queue_buffers,
             connection_info: None,
         })
     }
 
-    /// Fills the `rx` queue with the buffer `rx_buf`.
-    ///
-    /// # Safety
-    ///
-    /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are
-    /// in the queue must not be used anywhere else at the same time.
-    unsafe fn fill_rx_queue(
-        rx: &mut VirtQueue<H, { QUEUE_SIZE }>,
-        rx_buf: NonNull<[u8]>,
-        transport: &mut T,
-    ) -> Result {
-        if rx_buf.len() < size_of::<VirtioVsockHdr>() * QUEUE_SIZE {
-            return Err(SocketError::BufferTooShort.into());
-        }
-        for i in 0..QUEUE_SIZE {
-            // Safe because the buffer lives as long as the queue, as specified in the function
-            // safety requirement, and we don't access it until it is popped.
-            unsafe {
-                let buffer = Self::as_mut_sub_rx_buffer(rx_buf, i)?;
-                let token = rx.add(&[], &mut [buffer])?;
-                assert_eq!(i, token.into());
-            }
-        }
-
-        if rx.should_notify() {
-            transport.notify(RX_QUEUE_IDX);
-        }
-        Ok(())
-    }
-
     /// Sends a request to connect to the given destination.
     ///
     /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
@@ -481,7 +464,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
         // buffer back to the RX queue then we don't access it again until next time it is popped.
         let header = unsafe {
-            let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf, token.into())?;
+            let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
             let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
 
             // Read the header and body from the buffer. Don't check the result yet, because we need
@@ -506,36 +489,6 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             .clone()
             .ok_or(SocketError::NotConnected.into())
     }
-
-    /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX
-    /// virtqueue.
-    ///
-    /// # Safety
-    ///
-    /// `rx_buf` must be a valid dereferenceable pointer.
-    /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with
-    /// any other references to the same subslice of the RX buffer or outlive the buffer.
-    unsafe fn as_mut_sub_rx_buffer<'a>(
-        mut rx_buf: NonNull<[u8]>,
-        i: usize,
-    ) -> Result<&'a mut [u8]> {
-        let buffer_size = rx_buf.len() / QUEUE_SIZE;
-        let start = buffer_size
-            .checked_mul(i)
-            .ok_or(SocketError::InvalidNumber)?;
-        let end = start
-            .checked_add(buffer_size)
-            .ok_or(SocketError::InvalidNumber)?;
-        // Safe because no alignment or initialisation is required for [u8], and our caller assures
-        // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating
-        // won't overlap with any other references to the same slice or outlive it.
-        unsafe {
-            rx_buf
-                .as_mut()
-                .get_mut(start..end)
-                .ok_or(SocketError::BufferTooShort.into())
-        }
-    }
 }
 
 fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {