
Add support for indirect descriptors in VirtQueue.

Co-authored-by: Arsering <[email protected]>
Andrew Walbran, 1 year ago
parent commit 5934187dd0
7 changed files with 252 additions and 61 deletions
  1. src/device/blk.rs (+1 -1)
  2. src/device/console.rs (+2 -2)
  3. src/device/gpu.rs (+2 -2)
  4. src/device/input.rs (+2 -2)
  5. src/device/net.rs (+2 -2)
  6. src/device/socket/vsock.rs (+3 -3)
  7. src/queue.rs (+240 -49)
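
For callers, the only API change is the new boolean indirect flag on VirtQueue::new; the device drivers above all pass false to keep the existing direct-descriptor behaviour. The sketch below is not part of the commit: it mirrors the new add_buffers_indirect unit test at the end of this diff, reusing its test-only helpers (FakeHal, MmioTransport, VirtIOHeader::make_fake_header, MODERN_VERSION); a real driver would supply its own Hal and Transport instead.

    // Assumes the test-only FakeHal, MmioTransport and VirtIOHeader::make_fake_header
    // helpers shown in the queue.rs diff below; indirect lists need the `alloc` feature.
    let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
    let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();

    // Passing `true` opts this queue into indirect descriptors: a multi-buffer chain is
    // written to a separately allocated descriptor table and occupies a single slot in
    // the main ring, so available_desc() stays at SIZE until the ring itself is full.
    let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true).unwrap();

    // Two device-readable buffers followed by two device-writable buffers; with
    // indirect descriptors the whole chain consumes one direct descriptor.
    let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();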

+ 1 - 1
src/device/blk.rs

@@ -68,7 +68,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         };
         info!("found a block device of size {}KB", capacity / 2);
 
-        let queue = VirtQueue::new(&mut transport, QUEUE)?;
+        let queue = VirtQueue::new(&mut transport, QUEUE, false)?;
         transport.finish_init();
 
         Ok(VirtIOBlk {

+ 2 - 2
src/device/console.rs

@@ -72,8 +72,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<H, T> {
             (features & supported_features).bits()
         });
         let config_space = transport.config_space::<Config>()?;
-        let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0)?;
-        let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0)?;
+        let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0, false)?;
+        let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0, false)?;
 
         // Safe because no alignment or initialisation is required for [u8], the DMA buffer is
         // dereferenceable, and the lifetime of the reference matches the lifetime of the DMA buffer

+ 2 - 2
src/device/gpu.rs

@@ -57,8 +57,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<H, T> {
             );
         }
 
-        let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?;
-        let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR)?;
+        let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, false)?;
+        let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR, false)?;
 
         let queue_buf_send = FromBytes::new_box_slice_zeroed(PAGE_SIZE);
         let queue_buf_recv = FromBytes::new_box_slice_zeroed(PAGE_SIZE);

+ 2 - 2
src/device/input.rs

@@ -38,8 +38,8 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
 
         let config = transport.config_space::<Config>()?;
 
-        let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT)?;
-        let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS)?;
+        let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT, false)?;
+        let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, false)?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
             // Safe because the buffer lasts as long as the queue.
             let token = unsafe { event_queue.add(&[], &mut [event.as_bytes_mut()])? };

+ 2 - 2
src/device/net.rs

@@ -139,8 +139,8 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE>
             return Err(Error::InvalidParam);
         }
 
-        let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?;
-        let mut recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?;
+        let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, false)?;
+        let mut recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE, false)?;
 
         const NONE_BUF: Option<RxBuffer> = None;
         let mut rx_buffers = [NONE_BUF; QUEUE_SIZE];

+ 3 - 3
src/device/socket/vsock.rs

@@ -257,9 +257,9 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         };
         debug!("guest cid: {guest_cid:?}");
 
-        let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
-        let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
-        let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
+        let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX, false)?;
+        let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX, false)?;
+        let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX, false)?;
 
         // Allocate and add buffers for the RX queue.
         let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];

+ 240 - 49
src/queue.rs

@@ -3,6 +3,8 @@
 use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
 use crate::transport::Transport;
 use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE};
+#[cfg(feature = "alloc")]
+use alloc::boxed::Box;
 use bitflags::bitflags;
 #[cfg(test)]
 use core::cmp::min;
@@ -12,7 +14,7 @@ use core::mem::{size_of, take};
 use core::ptr;
 use core::ptr::NonNull;
 use core::sync::atomic::{fence, Ordering};
-use zerocopy::FromBytes;
+use zerocopy::{AsBytes, FromBytes};
 
 /// The mechanism for bulk data transport on virtio devices.
 ///
@@ -50,11 +52,15 @@ pub struct VirtQueue<H: Hal, const SIZE: usize> {
     /// Our trusted copy of `avail.idx`.
     avail_idx: u16,
     last_used_idx: u16,
+    #[cfg(feature = "alloc")]
+    indirect: bool,
+    #[cfg(feature = "alloc")]
+    indirect_lists: [Option<NonNull<[Descriptor]>>; SIZE],
 }
 
 impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     /// Create a new VirtQueue.
-    pub fn new<T: Transport>(transport: &mut T, idx: u16) -> Result<Self> {
+    pub fn new<T: Transport>(transport: &mut T, idx: u16, indirect: bool) -> Result<Self> {
         if transport.queue_used(idx) {
             return Err(Error::AlreadyUsed);
         }
@@ -96,6 +102,8 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
             }
         }
 
+        #[cfg(feature = "alloc")]
+        const NONE: Option<NonNull<[Descriptor]>> = None;
         Ok(VirtQueue {
             layout,
             desc,
@@ -107,6 +115,10 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
             desc_shadow,
             avail_idx: 0,
             last_used_idx: 0,
+            #[cfg(feature = "alloc")]
+            indirect,
+            #[cfg(feature = "alloc")]
+            indirect_lists: [NONE; SIZE],
         })
     }
 
@@ -128,10 +140,58 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         if inputs.is_empty() && outputs.is_empty() {
             return Err(Error::InvalidParam);
         }
-        if inputs.len() + outputs.len() + self.num_used as usize > SIZE {
+        let descriptors_needed = inputs.len() + outputs.len();
+        // Only consider indirect descriptors if the alloc feature is enabled, as they require
+        // allocation.
+        #[cfg(feature = "alloc")]
+        if self.num_used as usize + 1 > SIZE
+            || descriptors_needed > SIZE
+            || (!self.indirect && self.num_used as usize + descriptors_needed > SIZE)
+        {
+            return Err(Error::QueueFull);
+        }
+        #[cfg(not(feature = "alloc"))]
+        if self.num_used as usize + descriptors_needed > SIZE {
             return Err(Error::QueueFull);
         }
 
+        #[cfg(feature = "alloc")]
+        let head = if self.indirect && descriptors_needed > 1 {
+            self.add_indirect(inputs, outputs)
+        } else {
+            self.add_direct(inputs, outputs)
+        };
+        #[cfg(not(feature = "alloc"))]
+        let head = self.add_direct(inputs, outputs);
+
+        let avail_slot = self.avail_idx & (SIZE as u16 - 1);
+        // Safe because self.avail is properly aligned, dereferenceable and initialised.
+        unsafe {
+            (*self.avail.as_ptr()).ring[avail_slot as usize] = head;
+        }
+
+        // Write barrier so that device sees changes to descriptor table and available ring before
+        // change to available index.
+        fence(Ordering::SeqCst);
+
+        // increase head of avail ring
+        self.avail_idx = self.avail_idx.wrapping_add(1);
+        // Safe because self.avail is properly aligned, dereferenceable and initialised.
+        unsafe {
+            (*self.avail.as_ptr()).idx = self.avail_idx;
+        }
+
+        // Write barrier so that device can see change to available index after this method returns.
+        fence(Ordering::SeqCst);
+
+        Ok(head)
+    }
+
+    fn add_direct<'a, 'b>(
+        &mut self,
+        inputs: &'a [&'b [u8]],
+        outputs: &'a mut [&'b mut [u8]],
+    ) -> u16 {
         // allocate descriptors from free list
         let head = self.free_head;
         let mut last = self.free_head;
@@ -160,27 +220,55 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
 
         self.num_used += (inputs.len() + outputs.len()) as u16;
 
-        let avail_slot = self.avail_idx & (SIZE as u16 - 1);
-        // Safe because self.avail is properly aligned, dereferenceable and initialised.
-        unsafe {
-            (*self.avail.as_ptr()).ring[avail_slot as usize] = head;
+        head
+    }
+
+    #[cfg(feature = "alloc")]
+    fn add_indirect<'a, 'b>(
+        &mut self,
+        inputs: &'a [&'b [u8]],
+        outputs: &'a mut [&'b mut [u8]],
+    ) -> u16 {
+        let head = self.free_head;
+
+        // Allocate and fill in indirect descriptor list.
+        let mut indirect_list = Descriptor::new_box_slice_zeroed(inputs.len() + outputs.len());
+        for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
+            let desc = &mut indirect_list[i];
+            // Safe because our caller promises that the buffers live at least until `pop_used`
+            // returns them.
+            unsafe {
+                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
+            }
+            desc.next = (i + 1) as u16;
         }
+        indirect_list
+            .last_mut()
+            .unwrap()
+            .flags
+            .remove(DescFlags::NEXT);
 
-        // Write barrier so that device sees changes to descriptor table and available ring before
-        // change to available index.
-        fence(Ordering::SeqCst);
+        // Need to store pointer to indirect_list too, because direct_desc.set_buf will only store
+        // the physical DMA address which might be different.
+        assert!(self.indirect_lists[usize::from(head)].is_none());
+        self.indirect_lists[usize::from(head)] = Some(indirect_list.as_mut().into());
 
-        // increase head of avail ring
-        self.avail_idx = self.avail_idx.wrapping_add(1);
-        // Safe because self.avail is properly aligned, dereferenceable and initialised.
+        // Write a descriptor pointing to indirect descriptor list. We use Box::leak to prevent the
+        // indirect list from being freed when this function returns; recycle_descriptors is instead
+        // responsible for freeing the memory after the buffer chain is popped.
+        let direct_desc = &mut self.desc_shadow[usize::from(head)];
+        self.free_head = direct_desc.next;
         unsafe {
-            (*self.avail.as_ptr()).idx = self.avail_idx;
+            direct_desc.set_buf::<H>(
+                Box::leak(indirect_list).as_bytes().into(),
+                BufferDirection::DriverToDevice,
+                DescFlags::INDIRECT,
+            );
         }
+        self.write_desc(head);
+        self.num_used += 1;
 
-        // Write barrier so that device can see change to available index after this method returns.
-        fence(Ordering::SeqCst);
-
-        Ok(head)
+        head
     }
 
     /// Add the given buffers to the virtqueue, notifies the device, blocks until the device uses
@@ -263,7 +351,16 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
 
     /// Returns the number of free descriptors.
     pub fn available_desc(&self) -> usize {
-        SIZE - self.num_used as usize
+        #[cfg(feature = "alloc")]
+        if self.indirect {
+            return if usize::from(self.num_used) == SIZE {
+                0
+            } else {
+                SIZE
+            };
+        }
+
+        SIZE - usize::from(self.num_used)
     }
 
     /// Unshares buffers in the list starting at descriptor index `head` and adds them to the free
@@ -284,34 +381,75 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     ) {
         let original_free_head = self.free_head;
         self.free_head = head;
-        let mut next = Some(head);
 
-        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
-            assert_ne!(buffer.len(), 0);
-
-            let desc_index = next.expect("Descriptor chain was shorter than expected.");
-            let desc = &mut self.desc_shadow[usize::from(desc_index)];
+        let head_desc = &mut self.desc_shadow[usize::from(head)];
+        if head_desc.flags.contains(DescFlags::INDIRECT) {
+            #[cfg(feature = "alloc")]
+            {
+                // Find the indirect descriptor list, unshare it and move its descriptor to the free
+                // list.
+                let indirect_list = self.indirect_lists[usize::from(head)].take().unwrap();
+                // SAFETY: We allocated the indirect list in `add_indirect`, and the device has
+                // finished accessing it by this point.
+                let mut indirect_list = unsafe { Box::from_raw(indirect_list.as_ptr()) };
+                let paddr = head_desc.addr;
+                head_desc.unset_buf();
+                self.num_used -= 1;
+                head_desc.next = original_free_head;
+
+                unsafe {
+                    H::unshare(
+                        paddr as usize,
+                        indirect_list.as_bytes_mut().into(),
+                        BufferDirection::DriverToDevice,
+                    );
+                }
 
-            let paddr = desc.addr;
-            desc.unset_buf();
-            self.num_used -= 1;
-            next = desc.next();
-            if next.is_none() {
-                desc.next = original_free_head;
+                // Unshare the buffers in the indirect descriptor list, and free it.
+                assert_eq!(indirect_list.len(), inputs.len() + outputs.len());
+                for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
+                    assert_ne!(buffer.len(), 0);
+
+                    // SAFETY: The caller ensures that the buffer is valid and matches the
+                    // descriptor from which we got `paddr`.
+                    unsafe {
+                        // Unshare the buffer (and perhaps copy its contents back to the original
+                        // buffer).
+                        H::unshare(indirect_list[i].addr as usize, buffer, direction);
+                    }
+                }
+                drop(indirect_list);
             }
+        } else {
+            let mut next = Some(head);
 
-            self.write_desc(desc_index);
+            for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
+                assert_ne!(buffer.len(), 0);
 
-            // Safe because the caller ensures that the buffer is valid and matches the descriptor
-            // from which we got `paddr`.
-            unsafe {
-                // Unshare the buffer (and perhaps copy its contents back to the original buffer).
-                H::unshare(paddr as usize, buffer, direction);
+                let desc_index = next.expect("Descriptor chain was shorter than expected.");
+                let desc = &mut self.desc_shadow[usize::from(desc_index)];
+
+                let paddr = desc.addr;
+                desc.unset_buf();
+                self.num_used -= 1;
+                next = desc.next();
+                if next.is_none() {
+                    desc.next = original_free_head;
+                }
+
+                self.write_desc(desc_index);
+
+                // SAFETY: The caller ensures that the buffer is valid and matches the descriptor
+                // from which we got `paddr`.
+                unsafe {
+                    // Unshare the buffer (and perhaps copy its contents back to the original buffer).
+                    H::unshare(paddr as usize, buffer, direction);
+                }
             }
-        }
 
-        if next.is_some() {
-            panic!("Descriptor chain was longer than expected.");
+            if next.is_some() {
+                panic!("Descriptor chain was longer than expected.");
+            }
         }
     }
 
@@ -509,7 +647,7 @@ fn queue_part_sizes(queue_size: u16) -> (usize, usize, usize) {
 }
 
 #[repr(C, align(16))]
-#[derive(Clone, Debug, FromBytes)]
+#[derive(AsBytes, Clone, Debug, FromBytes)]
 pub(crate) struct Descriptor {
     addr: u64,
     len: u32,
@@ -564,7 +702,7 @@ impl Descriptor {
 }
 
 /// Descriptor flags
-#[derive(Copy, Clone, Debug, Default, Eq, FromBytes, PartialEq)]
+#[derive(AsBytes, Copy, Clone, Debug, Default, Eq, FromBytes, PartialEq)]
 #[repr(transparent)]
 struct DescFlags(u16);
 
@@ -737,7 +875,7 @@ mod tests {
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         // Size not a power of 2.
         assert_eq!(
-            VirtQueue::<FakeHal, 3>::new(&mut transport, 0).unwrap_err(),
+            VirtQueue::<FakeHal, 3>::new(&mut transport, 0, false).unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -747,7 +885,7 @@ mod tests {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         assert_eq!(
-            VirtQueue::<FakeHal, 8>::new(&mut transport, 0).unwrap_err(),
+            VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false).unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -756,9 +894,9 @@ mod tests {
     fn queue_already_used() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
+        VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap();
         assert_eq!(
-            VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap_err(),
+            VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap_err(),
             Error::AlreadyUsed
         );
     }
@@ -767,7 +905,7 @@ mod tests {
     fn add_empty() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap();
         assert_eq!(
             unsafe { queue.add(&[], &mut []) }.unwrap_err(),
             Error::InvalidParam
@@ -778,7 +916,7 @@ mod tests {
     fn add_too_many() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
             unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
@@ -790,7 +928,7 @@ mod tests {
     fn add_buffers() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap();
         assert_eq!(queue.available_desc(), 4);
 
         // Add a buffer chain consisting of two device-readable parts followed by two
@@ -845,4 +983,57 @@ mod tests {
             );
         }
     }
+
+    #[cfg(feature = "alloc")]
+    #[test]
+    fn add_buffers_indirect() {
+        use core::ptr::slice_from_raw_parts;
+
+        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
+        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true).unwrap();
+        assert_eq!(queue.available_desc(), 4);
+
+        // Add a buffer chain consisting of two device-readable parts followed by two
+        // device-writable parts.
+        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
+
+        assert_eq!(queue.available_desc(), 4);
+        assert!(!queue.can_pop());
+
+        // Safe because the various parts of the queue are properly aligned, dereferenceable and
+        // initialised, and nothing else is accessing them at the same time.
+        unsafe {
+            let indirect_descriptor_index = (*queue.avail.as_ptr()).ring[0];
+            assert_eq!(indirect_descriptor_index, token);
+            assert_eq!(
+                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].len as usize,
+                4 * size_of::<Descriptor>()
+            );
+            assert_eq!(
+                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].flags,
+                DescFlags::INDIRECT
+            );
+
+            let indirect_descriptors = slice_from_raw_parts(
+                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].addr
+                    as *const Descriptor,
+                4,
+            );
+            assert_eq!((*indirect_descriptors)[0].len, 2);
+            assert_eq!((*indirect_descriptors)[0].flags, DescFlags::NEXT);
+            assert_eq!((*indirect_descriptors)[0].next, 1);
+            assert_eq!((*indirect_descriptors)[1].len, 1);
+            assert_eq!((*indirect_descriptors)[1].flags, DescFlags::NEXT);
+            assert_eq!((*indirect_descriptors)[1].next, 2);
+            assert_eq!((*indirect_descriptors)[2].len, 2);
+            assert_eq!(
+                (*indirect_descriptors)[2].flags,
+                DescFlags::NEXT | DescFlags::WRITE
+            );
+            assert_eq!((*indirect_descriptors)[2].next, 3);
+            assert_eq!((*indirect_descriptors)[3].len, 1);
+            assert_eq!((*indirect_descriptors)[3].flags, DescFlags::WRITE);
+        }
+    }
 }