Jelajahi Sumber

Merge pull request #90 from rcore-os/vsockcleanup

Minor improvements to vsock driver
chyyuu 1 tahun lalu
induk
melakukan
4d7038f214
3 mengubah file dengan 47 tambahan dan 81 penghapusan
  1. 8 2
      examples/aarch64/src/main.rs
  2. 1 0
      src/device/socket/mod.rs
  3. 38 79
      src/device/socket/vsock.rs

+ 8 - 2
examples/aarch64/src/main.rs

@@ -30,7 +30,7 @@ use virtio_drivers::{
         blk::VirtIOBlk,
         console::VirtIOConsole,
         gpu::VirtIOGpu,
-        socket::{VirtIOSocket, VsockEventType},
+        socket::{VirtIOSocket, VsockAddr, VsockEventType},
     },
     transport::{
         mmio::{MmioTransport, VirtIOHeader},
@@ -209,7 +209,13 @@ fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
     let host_cid = 2;
     let port = 1221;
     info!("Connecting to host on port {port}...");
-    socket.connect(host_cid, port, port)?;
+    socket.connect(
+        VsockAddr {
+            cid: host_cid,
+            port,
+        },
+        port,
+    )?;
     socket.wait_for_connect()?;
     info!("Connected to the host");
 

+ 1 - 0
src/device/socket/mod.rs

@@ -6,5 +6,6 @@ mod protocol;
 mod vsock;
 
 pub use error::SocketError;
+pub use protocol::VsockAddr;
 #[cfg(feature = "alloc")]
 pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};

+ 38 - 79
src/device/socket/vsock.rs

@@ -7,11 +7,11 @@ use crate::hal::Hal;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::volread;
-use crate::{Result, PAGE_SIZE};
+use crate::Result;
 use alloc::boxed::Box;
 use core::hint::spin_loop;
 use core::mem::size_of;
-use core::ptr::NonNull;
+use core::ptr::{null_mut, NonNull};
 use log::{debug, info};
 use zerocopy::{AsBytes, FromBytes};
 
@@ -21,6 +21,9 @@ const EVENT_QUEUE_IDX: u16 = 2;
 
 const QUEUE_SIZE: usize = 8;
 
+/// The size in bytes of each buffer used in the RX virtqueue.
+const RX_BUFFER_SIZE: usize = 512;
+
 #[derive(Clone, Debug, Default, PartialEq, Eq)]
 struct ConnectionInfo {
     dst: VsockAddr,
@@ -108,7 +111,7 @@ pub struct VirtIOSocket<H: Hal, T: Transport> {
     /// The guest_cid field contains the guest’s context ID, which uniquely identifies
     /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
     guest_cid: u64,
-    rx_buf: NonNull<[u8; PAGE_SIZE]>,
+    rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
 
     /// Currently the device is only allowed to be connected to one destination at a time.
     connection_info: Option<ConnectionInfo>,
@@ -122,9 +125,11 @@ impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
         self.transport.queue_unset(TX_QUEUE_IDX);
         self.transport.queue_unset(EVENT_QUEUE_IDX);
 
-        // Safe because we obtained the rx_buf pointer from Box::into_raw, and it won't be used
-        // anywhere else after the driver is destroyed.
-        unsafe { drop(Box::from_raw(self.rx_buf.as_ptr())) };
+        for buffer in self.rx_queue_buffers {
+            // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
+            // used anywhere else after the driver is destroyed.
+            unsafe { drop(Box::from_raw(buffer.as_ptr())) };
+        }
     }
 }
 
@@ -151,14 +156,22 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
         let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
 
-        // Allocates 4 KiB memory for the RX buffer.
-        let rx_buf: NonNull<[u8; PAGE_SIZE]> =
-            NonNull::new(Box::into_raw(FromBytes::new_box_zeroed())).unwrap();
-        // Safe because `rx_buf` lives as long as the `rx` queue.
-        unsafe {
-            Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?;
+        // Allocate and add buffers for the RX queue.
+        let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
+        for i in 0..QUEUE_SIZE {
+            let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed();
+            // Safe because the buffer lives as long as the queue, as specified in the function
+            // safety requirement, and we don't access it until it is popped.
+            let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
+            assert_eq!(i, token.into());
+            rx_queue_buffers[i] = Box::into_raw(buffer);
         }
+        let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
+
         transport.finish_init();
+        if rx.should_notify() {
+            transport.notify(RX_QUEUE_IDX);
+        }
 
         Ok(Self {
             transport,
@@ -166,39 +179,14 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             tx,
             event,
             guest_cid,
-            rx_buf,
+            rx_queue_buffers,
             connection_info: None,
         })
     }
 
-    /// Fills the `rx` queue with the buffer `rx_buf`.
-    ///
-    /// # Safety
-    ///
-    /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are
-    /// in the queue must not be used anywhere else at the same time.
-    unsafe fn fill_rx_queue(
-        rx: &mut VirtQueue<H, { QUEUE_SIZE }>,
-        rx_buf: NonNull<[u8]>,
-        transport: &mut T,
-    ) -> Result {
-        if rx_buf.len() < size_of::<VirtioVsockHdr>() * QUEUE_SIZE {
-            return Err(SocketError::BufferTooShort.into());
-        }
-        for i in 0..QUEUE_SIZE {
-            // Safe because the buffer lives as long as the queue, as specified in the function
-            // safety requirement, and we don't access it until it is popped.
-            unsafe {
-                let buffer = Self::as_mut_sub_rx_buffer(rx_buf, i)?;
-                let token = rx.add(&[], &mut [buffer])?;
-                assert_eq!(i, token.into());
-            }
-        }
-
-        if rx.should_notify() {
-            transport.notify(RX_QUEUE_IDX);
-        }
-        Ok(())
+    /// Returns the CID which has been assigned to this guest.
+    pub fn guest_cid(&self) -> u64 {
+        self.guest_cid
     }
 
     /// Sends a request to connect to the given destination.
@@ -206,15 +194,12 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
     /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
     /// before sending data.
-    pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
+    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
         if self.connection_info.is_some() {
             return Err(SocketError::ConnectionExists.into());
         }
         let new_connection_info = ConnectionInfo {
-            dst: VsockAddr {
-                cid: dst_cid,
-                port: dst_port,
-            },
+            dst: destination,
             src_port,
             ..Default::default()
         };
@@ -483,7 +468,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
         // buffer back to the RX queue then we don't access it again until next time it is popped.
         let header = unsafe {
-            let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf, token.into())?;
+            let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
             let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
 
             // Read the header and body from the buffer. Don't check the result yet, because we need
@@ -508,36 +493,6 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
             .clone()
             .ok_or(SocketError::NotConnected.into())
     }
-
-    /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX
-    /// virtqueue.
-    ///
-    /// # Safety
-    ///
-    /// `rx_buf` must be a valid dereferenceable pointer.
-    /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with
-    /// any other references to the same subslice of the RX buffer or outlive the buffer.
-    unsafe fn as_mut_sub_rx_buffer<'a>(
-        mut rx_buf: NonNull<[u8]>,
-        i: usize,
-    ) -> Result<&'a mut [u8]> {
-        let buffer_size = rx_buf.len() / QUEUE_SIZE;
-        let start = buffer_size
-            .checked_mul(i)
-            .ok_or(SocketError::InvalidNumber)?;
-        let end = start
-            .checked_add(buffer_size)
-            .ok_or(SocketError::InvalidNumber)?;
-        // Safe because no alignment or initialisation is required for [u8], and our caller assures
-        // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating
-        // won't overlap with any other references to the same slice or outlive it.
-        unsafe {
-            rx_buf
-                .as_mut()
-                .get_mut(start..end)
-                .ok_or(SocketError::BufferTooShort.into())
-        }
-    }
 }
 
 fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
@@ -597,7 +552,7 @@ mod tests {
         };
         let socket =
             VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
-        assert_eq!(socket.guest_cid, 0x00_0000_0042);
+        assert_eq!(socket.guest_cid(), 0x00_0000_0042);
     }
 
     #[test]
@@ -606,6 +561,10 @@ mod tests {
         let guest_cid = 66;
         let host_port = 1234;
         let guest_port = 4321;
+        let host_address = VsockAddr {
+            cid: host_cid,
+            port: host_port,
+        };
         let hello_from_guest = "Hello from guest";
         let hello_from_host = "Hello from host";
 
@@ -806,7 +765,7 @@ mod tests {
             );
         });
 
-        socket.connect(host_cid, guest_port, host_port).unwrap();
+        socket.connect(host_address, guest_port).unwrap();
         socket.wait_for_connect().unwrap();
         socket.send(hello_from_guest.as_bytes()).unwrap();
         let mut buffer = [0u8; 64];