浏览代码

Make queue size a const generic parameter.

This fixes the issue where we were using the wrong size in AvailRing and
UsedRing.
Andrew Walbran 2 年之前
父节点
当前提交
72e5aaee34
共有 7 个文件被更改,包括 66 次插入71 次删除
  1. 4 3
      src/device/blk.rs
  2. 6 6
      src/device/console.rs
  3. 6 4
      src/device/gpu.rs
  4. 4 4
      src/device/input.rs
  5. 6 5
      src/device/net.rs
  6. 38 46
      src/queue.rs
  7. 2 3
      src/transport/fake.rs

+ 4 - 3
src/device/blk.rs

@@ -10,6 +10,7 @@ use log::info;
 use zerocopy::{AsBytes, FromBytes};
 
 const QUEUE: u16 = 0;
+const QUEUE_SIZE: u16 = 16;
 
 /// Driver for a VirtIO block device.
 ///
@@ -39,7 +40,7 @@ const QUEUE: u16 = 0;
 /// ```
 pub struct VirtIOBlk<H: Hal, T: Transport> {
     transport: T,
-    queue: VirtQueue<H>,
+    queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
     capacity: u64,
     readonly: bool,
 }
@@ -67,7 +68,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         };
         info!("found a block device of size {}KB", capacity / 2);
 
-        let queue = VirtQueue::new(&mut transport, QUEUE, 16)?;
+        let queue = VirtQueue::new(&mut transport, QUEUE)?;
         transport.finish_init();
 
         Ok(VirtIOBlk {
@@ -298,7 +299,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     ///
     /// This can be used to tell the caller how many channels to monitor on.
     pub fn virt_queue_size(&self) -> u16 {
-        self.queue.size()
+        QUEUE_SIZE
     }
 }
 

+ 6 - 6
src/device/console.rs

@@ -11,7 +11,7 @@ use log::info;
 
 const QUEUE_RECEIVEQ_PORT_0: u16 = 0;
 const QUEUE_TRANSMITQ_PORT_0: u16 = 1;
-const QUEUE_SIZE: u16 = 2;
+const QUEUE_SIZE: usize = 2;
 
 /// Driver for a VirtIO console device.
 ///
@@ -41,8 +41,8 @@ const QUEUE_SIZE: u16 = 2;
 pub struct VirtIOConsole<'a, H: Hal, T: Transport> {
     transport: T,
     config_space: NonNull<Config>,
-    receiveq: VirtQueue<H>,
-    transmitq: VirtQueue<H>,
+    receiveq: VirtQueue<H, QUEUE_SIZE>,
+    transmitq: VirtQueue<H, QUEUE_SIZE>,
     queue_buf_dma: Dma<H>,
     queue_buf_rx: &'a mut [u8],
     cursor: usize,
@@ -72,8 +72,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
             (features & supported_features).bits()
         });
         let config_space = transport.config_space::<Config>()?;
-        let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0, QUEUE_SIZE)?;
-        let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0, QUEUE_SIZE)?;
+        let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0)?;
+        let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0)?;
         let queue_buf_dma = Dma::new(1, BufferDirection::DeviceToDriver)?;
 
         // Safe because no alignment or initialisation is required for [u8], the DMA buffer is
@@ -270,7 +270,7 @@ mod tests {
         // Make a character available, and simulate an interrupt.
         {
             let mut state = state.lock().unwrap();
-            state.write_to_queue(QUEUE_SIZE, QUEUE_RECEIVEQ_PORT_0, &[42]);
+            state.write_to_queue::<QUEUE_SIZE>(QUEUE_RECEIVEQ_PORT_0, &[42]);
 
             state.interrupt_pending = true;
         }

+ 6 - 4
src/device/gpu.rs

@@ -8,6 +8,8 @@ use crate::{pages, Error, Result};
 use bitflags::bitflags;
 use log::info;
 
+const QUEUE_SIZE: u16 = 2;
+
 /// A virtio based graphics adapter.
 ///
 /// It can operate in 2D mode and in 3D (virgl) mode.
@@ -23,9 +25,9 @@ pub struct VirtIOGpu<'a, H: Hal, T: Transport> {
     /// DMA area of cursor image buffer.
     cursor_buffer_dma: Option<Dma<H>>,
     /// Queue for sending control commands.
-    control_queue: VirtQueue<H>,
+    control_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
     /// Queue for sending cursor commands.
-    cursor_queue: VirtQueue<H>,
+    cursor_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
     /// DMA region for sending data to the device.
     dma_send: Dma<H>,
     /// DMA region for receiving data from the device.
@@ -57,8 +59,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
             );
         }
 
-        let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, 2)?;
-        let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR, 2)?;
+        let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?;
+        let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR)?;
 
         let dma_send = Dma::new(1, BufferDirection::DriverToDevice)?;
         let dma_recv = Dma::new(1, BufferDirection::DeviceToDriver)?;

+ 4 - 4
src/device/input.rs

@@ -18,8 +18,8 @@ use zerocopy::{AsBytes, FromBytes};
 /// making pass-through implementations on top of evdev easy.
 pub struct VirtIOInput<H: Hal, T: Transport> {
     transport: T,
-    event_queue: VirtQueue<H>,
-    status_queue: VirtQueue<H>,
+    event_queue: VirtQueue<H, QUEUE_SIZE>,
+    status_queue: VirtQueue<H, QUEUE_SIZE>,
     event_buf: Box<[InputEvent; 32]>,
     config: NonNull<Config>,
 }
@@ -38,8 +38,8 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
 
         let config = transport.config_space::<Config>()?;
 
-        let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT, QUEUE_SIZE as u16)?;
-        let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, QUEUE_SIZE as u16)?;
+        let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT)?;
+        let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS)?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
             // Safe because the buffer lasts as long as the queue.
             let token = unsafe { event_queue.add(&[], &[event.as_bytes_mut()])? };

+ 6 - 5
src/device/net.rs

@@ -10,6 +10,8 @@ use core::mem::{size_of, MaybeUninit};
 use log::{debug, info};
 use zerocopy::{AsBytes, FromBytes};
 
+const QUEUE_SIZE: u16 = 2;
+
 /// The virtio network device is a virtual ethernet card.
 ///
 /// It has enhanced rapidly and demonstrates clearly how support for new
@@ -20,8 +22,8 @@ use zerocopy::{AsBytes, FromBytes};
 pub struct VirtIONet<H: Hal, T: Transport> {
     transport: T,
     mac: EthernetAddress,
-    recv_queue: VirtQueue<H>,
-    send_queue: VirtQueue<H>,
+    recv_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
+    send_queue: VirtQueue<H, { QUEUE_SIZE as usize }>,
 }
 
 impl<H: Hal, T: Transport> VirtIONet<H, T> {
@@ -42,9 +44,8 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
             debug!("Got MAC={:?}, status={:?}", mac, volread!(config, status));
         }
 
-        let queue_num = 2; // for simplicity
-        let recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE, queue_num)?;
-        let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, queue_num)?;
+        let recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?;
+        let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?;
 
         transport.finish_init();
 

+ 38 - 46
src/queue.rs

@@ -16,24 +16,22 @@ use core::sync::atomic::{fence, Ordering};
 /// The mechanism for bulk data transport on virtio devices.
 ///
 /// Each device can have zero or more virtqueues.
+///
+/// * `SIZE`: The size of the queue. This is both the number of descriptors, and the number of slots
+///   in the available and used rings.
 #[derive(Debug)]
-pub struct VirtQueue<H: Hal> {
+pub struct VirtQueue<H: Hal, const SIZE: usize> {
     /// DMA guard
     layout: VirtQueueLayout<H>,
     /// Descriptor table
     desc: NonNull<[Descriptor]>,
     /// Available ring
-    avail: NonNull<AvailRing>,
+    avail: NonNull<AvailRing<SIZE>>,
     /// Used ring
-    used: NonNull<UsedRing>,
+    used: NonNull<UsedRing<SIZE>>,
 
     /// The index of queue
     queue_idx: u16,
-    /// The size of the queue.
-    ///
-    /// This is both the number of descriptors, and the number of slots in the available and used
-    /// rings.
-    queue_size: u16,
     /// The number of descriptors currently in use.
     num_used: u16,
     /// The head desc index of the free list.
@@ -42,15 +40,19 @@ pub struct VirtQueue<H: Hal> {
     last_used_idx: u16,
 }
 
-impl<H: Hal> VirtQueue<H> {
+impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     /// Create a new VirtQueue.
-    pub fn new<T: Transport>(transport: &mut T, idx: u16, size: u16) -> Result<Self> {
+    pub fn new<T: Transport>(transport: &mut T, idx: u16) -> Result<Self> {
         if transport.queue_used(idx) {
             return Err(Error::AlreadyUsed);
         }
-        if !size.is_power_of_two() || transport.max_queue_size() < size as u32 {
+        if !SIZE.is_power_of_two()
+            || SIZE > u16::MAX.into()
+            || transport.max_queue_size() < SIZE as u32
+        {
             return Err(Error::InvalidParam);
         }
+        let size = SIZE as u16;
 
         let layout = if transport.requires_legacy_layout() {
             VirtQueueLayout::allocate_legacy(size)?
@@ -60,16 +62,14 @@ impl<H: Hal> VirtQueue<H> {
 
         transport.queue_set(
             idx,
-            size as u32,
+            size.into(),
             layout.descriptors_paddr(),
             layout.driver_area_paddr(),
             layout.device_area_paddr(),
         );
 
-        let desc = nonnull_slice_from_raw_parts(
-            layout.descriptors_vaddr().cast::<Descriptor>(),
-            size as usize,
-        );
+        let desc =
+            nonnull_slice_from_raw_parts(layout.descriptors_vaddr().cast::<Descriptor>(), SIZE);
         let avail = layout.avail_vaddr().cast();
         let used = layout.used_vaddr().cast();
 
@@ -87,7 +87,6 @@ impl<H: Hal> VirtQueue<H> {
             desc,
             avail,
             used,
-            queue_size: size,
             queue_idx: idx,
             num_used: 0,
             free_head: 0,
@@ -107,7 +106,7 @@ impl<H: Hal> VirtQueue<H> {
         if inputs.is_empty() && outputs.is_empty() {
             return Err(Error::InvalidParam);
         }
-        if inputs.len() + outputs.len() + self.num_used as usize > self.queue_size as usize {
+        if inputs.len() + outputs.len() + self.num_used as usize > SIZE {
             return Err(Error::QueueFull);
         }
 
@@ -130,7 +129,7 @@ impl<H: Hal> VirtQueue<H> {
         }
         self.num_used += (inputs.len() + outputs.len()) as u16;
 
-        let avail_slot = self.avail_idx & (self.queue_size - 1);
+        let avail_slot = self.avail_idx & (SIZE as u16 - 1);
         // Safe because self.avail is properly aligned, dereferenceable and initialised.
         unsafe {
             (*self.avail.as_ptr()).ring[avail_slot as usize] = head;
@@ -198,7 +197,7 @@ impl<H: Hal> VirtQueue<H> {
     /// `None` if the used ring is empty.
     pub fn peek_used(&self) -> Option<u16> {
         if self.can_pop() {
-            let last_used_slot = self.last_used_idx & (self.queue_size - 1);
+            let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
             // Safe because self.used points to a valid, aligned, initialised, dereferenceable,
             // readable instance of UsedRing.
             Some(unsafe { (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16 })
@@ -209,7 +208,7 @@ impl<H: Hal> VirtQueue<H> {
 
     /// Returns the number of free descriptors.
     pub fn available_desc(&self) -> usize {
-        (self.queue_size - self.num_used) as usize
+        SIZE - self.num_used as usize
     }
 
     /// Unshares buffers in the list starting at descriptor index `head` and adds them to the free
@@ -263,7 +262,7 @@ impl<H: Hal> VirtQueue<H> {
         // Read barrier not necessary, as can_pop already has one.
 
         // Get the index of the start of the descriptor chain for the next element in the used ring.
-        let last_used_slot = self.last_used_idx & (self.queue_size - 1);
+        let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
         let index;
         let len;
         // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
@@ -283,11 +282,6 @@ impl<H: Hal> VirtQueue<H> {
 
         Ok(len)
     }
-
-    /// Return size of the queue.
-    pub fn size(&self) -> u16 {
-        self.queue_size
-    }
 }
 
 /// The inner layout of a VirtQueue.
@@ -503,11 +497,11 @@ bitflags! {
 /// It is only written by the driver and read by the device.
 #[repr(C)]
 #[derive(Debug)]
-struct AvailRing {
+struct AvailRing<const SIZE: usize> {
     flags: u16,
     /// A driver MUST NOT decrement the idx.
     idx: u16,
-    ring: [u16; 32], // actual size: queue_size
+    ring: [u16; SIZE],
     used_event: u16, // unused
 }
 
@@ -515,11 +509,11 @@ struct AvailRing {
 /// it is only written to by the device, and read by the driver.
 #[repr(C)]
 #[derive(Debug)]
-struct UsedRing {
+struct UsedRing<const SIZE: usize> {
     flags: u16,
     idx: u16,
-    ring: [UsedElem; 32], // actual size: queue_size
-    avail_event: u16,     // unused
+    ring: [UsedElem; SIZE],
+    avail_event: u16, // unused
 }
 
 #[repr(C)]
@@ -533,16 +527,15 @@ struct UsedElem {
 ///
 /// The fake device always uses descriptors in order.
 #[cfg(test)]
-pub(crate) fn fake_write_to_queue(
-    queue_size: u16,
+pub(crate) fn fake_write_to_queue<const QUEUE_SIZE: usize>(
     receive_queue_descriptors: *const Descriptor,
     receive_queue_driver_area: VirtAddr,
     receive_queue_device_area: VirtAddr,
     data: &[u8],
 ) {
-    let descriptors = ptr::slice_from_raw_parts(receive_queue_descriptors, queue_size as usize);
-    let available_ring = receive_queue_driver_area as *const AvailRing;
-    let used_ring = receive_queue_device_area as *mut UsedRing;
+    let descriptors = ptr::slice_from_raw_parts(receive_queue_descriptors, QUEUE_SIZE);
+    let available_ring = receive_queue_driver_area as *const AvailRing<QUEUE_SIZE>;
+    let used_ring = receive_queue_device_area as *mut UsedRing<QUEUE_SIZE>;
     // Safe because the various pointers are properly aligned, dereferenceable, initialised, and
     // nothing else accesses them during this block.
     unsafe {
@@ -550,7 +543,7 @@ pub(crate) fn fake_write_to_queue(
         assert_ne!((*available_ring).idx, (*used_ring).idx);
         // The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
         // `used_ring.idx` marks the next descriptor we should take from the available ring.
-        let next_slot = (*used_ring).idx & (queue_size - 1);
+        let next_slot = (*used_ring).idx & (QUEUE_SIZE as u16 - 1);
         let head_descriptor_index = (*available_ring).ring[next_slot as usize];
         let mut descriptor = &(*descriptors)[head_descriptor_index as usize];
 
@@ -599,7 +592,7 @@ mod tests {
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         // Size not a power of 2.
         assert_eq!(
-            VirtQueue::<FakeHal>::new(&mut transport, 0, 3).unwrap_err(),
+            VirtQueue::<FakeHal, 3>::new(&mut transport, 0).unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -609,7 +602,7 @@ mod tests {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         assert_eq!(
-            VirtQueue::<FakeHal>::new(&mut transport, 0, 5).unwrap_err(),
+            VirtQueue::<FakeHal, 8>::new(&mut transport, 0).unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -618,9 +611,9 @@ mod tests {
     fn queue_already_used() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
+        VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(
-            VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap_err(),
+            VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap_err(),
             Error::AlreadyUsed
         );
     }
@@ -629,7 +622,7 @@ mod tests {
     fn add_empty() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(
             unsafe { queue.add(&[], &[]) }.unwrap_err(),
             Error::InvalidParam
@@ -640,7 +633,7 @@ mod tests {
     fn add_too_many() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
             unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(),
@@ -652,8 +645,7 @@ mod tests {
     fn add_buffers() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
-        assert_eq!(queue.size(), 4);
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(queue.available_desc(), 4);
 
         // Add a buffer chain consisting of two device-readable parts followed by two

+ 2 - 3
src/transport/fake.rs

@@ -108,11 +108,10 @@ impl State {
     /// Simulates the device writing to the given queue.
     ///
     /// The fake device always uses descriptors in order.
-    pub fn write_to_queue(&mut self, queue_size: u16, queue_index: u16, data: &[u8]) {
+    pub fn write_to_queue<const QUEUE_SIZE: usize>(&mut self, queue_index: u16, data: &[u8]) {
         let receive_queue = &self.queues[queue_index as usize];
         assert_ne!(receive_queue.descriptors, 0);
-        fake_write_to_queue(
-            queue_size,
+        fake_write_to_queue::<QUEUE_SIZE>(
             receive_queue.descriptors as *const Descriptor,
             receive_queue.driver_area,
             receive_queue.device_area,