Parcourir la source

Add tests for VirtQueue (#12)

* Run unit tests on GitHub Actions.

* Add some tests for VirtQueue initialisation.

This required adding a fake HAL implementation, and a method to
construct a fake header.

* Test invalid cases of adding to queue.

* Test adding buffers to queue.

* VirtQueue struct doesn't need to be repr(C).

It isn't shared with the host, only the DMA region (containing the
descriptors, AvailRing and UsedRing) is.
Andrew Walbran il y a 2 ans
Parent
commit
4ee80e50ba
5 fichiers modifiés avec 220 ajouts et 19 suppressions
  1. 19 14
      .github/workflows/main.yml
  2. 4 0
      src/hal.rs
  3. 41 0
      src/hal/fake.rs
  4. 50 1
      src/header.rs
  5. 106 4
      src/queue.rs

+ 19 - 14
.github/workflows/main.yml

@@ -25,17 +25,22 @@ jobs:
   build:
     runs-on: ubuntu-latest
     steps:
-    - uses: actions/checkout@v2
-    - uses: actions-rs/toolchain@v1
-      with:
-        profile: minimal
-        toolchain: stable
-    - name: Build
-      uses: actions-rs/cargo@v1
-      with:
-        command: build
-        args: --all-features
-    - name: Docs
-      uses: actions-rs/cargo@v1
-      with:
-        command: doc
+      - uses: actions/checkout@v2
+      - uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+      - name: Build
+        uses: actions-rs/cargo@v1
+        with:
+          command: build
+          args: --all-features
+      - name: Docs
+        uses: actions-rs/cargo@v1
+        with:
+          command: doc
+      - name: Test
+        uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --all-features

+ 4 - 0
src/hal.rs

@@ -1,3 +1,6 @@
+#[cfg(test)]
+pub mod fake;
+
 use super::*;
 use core::marker::PhantomData;
 
@@ -8,6 +11,7 @@ pub type VirtAddr = usize;
 pub type PhysAddr = usize;
 
 /// A region of contiguous physical memory used for DMA.
+#[derive(Debug)]
 pub struct DMA<H: Hal> {
     paddr: usize,
     pages: usize,

+ 41 - 0
src/hal/fake.rs

@@ -0,0 +1,41 @@
+//! Fake HAL implementation for tests.
+
+use crate::{Hal, PhysAddr, VirtAddr, PAGE_SIZE};
+use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error};
+use core::alloc::Layout;
+
+#[derive(Debug)]
+pub struct FakeHal;
+
+/// Fake HAL implementation for use in unit tests.
+impl Hal for FakeHal {
+    fn dma_alloc(pages: usize) -> PhysAddr {
+        assert_ne!(pages, 0);
+        let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
+        // Safe because the size and alignment of the layout are non-zero.
+        let ptr = unsafe { alloc_zeroed(layout) };
+        if ptr.is_null() {
+            handle_alloc_error(layout);
+        }
+        ptr as PhysAddr
+    }
+
+    fn dma_dealloc(paddr: PhysAddr, pages: usize) -> i32 {
+        assert_ne!(pages, 0);
+        let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
+        // Safe because the layout is the same as was used when the memory was allocated by
+        // `dma_alloc` above.
+        unsafe {
+            dealloc(paddr as *mut u8, layout);
+        }
+        0
+    }
+
+    fn phys_to_virt(paddr: PhysAddr) -> VirtAddr {
+        paddr
+    }
+
+    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
+        vaddr
+    }
+}

+ 50 - 1
src/header.rs

@@ -2,6 +2,8 @@ use crate::PAGE_SIZE;
 use bitflags::*;
 use volatile::{ReadOnly, Volatile, WriteOnly};
 
+const MAGIC_VALUE: u32 = 0x7472_6976;
+
 /// MMIO Device Legacy Register Interface.
 ///
 /// Ref: 4.2.4 Legacy interface
@@ -148,7 +150,7 @@ pub struct VirtIOHeader {
 impl VirtIOHeader {
     /// Verify a valid header.
     pub fn verify(&self) -> bool {
-        self.magic.read() == 0x7472_6976 && self.version.read() == 1 && self.device_id.read() != 0
+        self.magic.read() == MAGIC_VALUE && self.version.read() == 1 && self.device_id.read() != 0
     }
 
     /// Get the device type.
@@ -244,6 +246,53 @@ impl VirtIOHeader {
     pub fn config_space(&self) -> *mut u64 {
         (self as *const _ as usize + CONFIG_SPACE_OFFSET) as _
     }
+
+    /// Constructs a fake virtio header for use in unit tests.
+    #[cfg(test)]
+    pub fn make_fake_header(
+        device_id: u32,
+        vendor_id: u32,
+        device_features: u32,
+        queue_num_max: u32,
+    ) -> Self {
+        Self {
+            magic: ReadOnly::new(MAGIC_VALUE),
+            version: ReadOnly::new(1),
+            device_id: ReadOnly::new(device_id),
+            vendor_id: ReadOnly::new(vendor_id),
+            device_features: ReadOnly::new(device_features),
+            device_features_sel: WriteOnly::default(),
+            __r1: Default::default(),
+            driver_features: Default::default(),
+            driver_features_sel: Default::default(),
+            guest_page_size: Default::default(),
+            __r2: Default::default(),
+            queue_sel: Default::default(),
+            queue_num_max: ReadOnly::new(queue_num_max),
+            queue_num: Default::default(),
+            queue_align: Default::default(),
+            queue_pfn: Default::default(),
+            queue_ready: Default::default(),
+            __r3: Default::default(),
+            queue_notify: Default::default(),
+            __r4: Default::default(),
+            interrupt_status: Default::default(),
+            interrupt_ack: Default::default(),
+            __r5: Default::default(),
+            status: Volatile::new(DeviceStatus::empty()),
+            __r6: Default::default(),
+            queue_desc_low: Default::default(),
+            queue_desc_high: Default::default(),
+            __r7: Default::default(),
+            queue_avail_low: Default::default(),
+            queue_avail_high: Default::default(),
+            __r8: Default::default(),
+            queue_used_low: Default::default(),
+            queue_used_high: Default::default(),
+            __r9: Default::default(),
+            config_generation: Default::default(),
+        }
+    }
 }
 
 bitflags! {

+ 106 - 4
src/queue.rs

@@ -11,7 +11,7 @@ use volatile::Volatile;
 /// The mechanism for bulk data transport on virtio devices.
 ///
 /// Each device can have zero or more virtqueues.
-#[repr(C)]
+#[derive(Debug)]
 pub struct VirtQueue<'a, H: Hal> {
     /// DMA guard
     dma: DMA<H>,
@@ -24,7 +24,10 @@ pub struct VirtQueue<'a, H: Hal> {
 
     /// The index of queue
     queue_idx: u32,
-    /// The size of queue
+    /// The size of the queue.
+    ///
+    /// This is both the number of descriptors, and the number of slots in the available and used
+    /// rings.
     queue_size: u16,
     /// The number of used queues.
     num_used: u16,
@@ -44,7 +47,7 @@ impl<H: Hal> VirtQueue<'_, H> {
             return Err(Error::InvalidParam);
         }
         let layout = VirtQueueLayout::new(size);
-        // alloc continuous pages
+        // Allocate contiguous pages.
         let dma = DMA::new(layout.size / PAGE_SIZE)?;
 
         header.queue_set(idx as u32, size as u32, PAGE_SIZE as u32, dma.pfn());
@@ -54,7 +57,7 @@ impl<H: Hal> VirtQueue<'_, H> {
         let avail = unsafe { &mut *((dma.vaddr() + layout.avail_offset) as *mut AvailRing) };
         let used = unsafe { &mut *((dma.vaddr() + layout.used_offset) as *mut UsedRing) };
 
-        // link descriptors together
+        // Link descriptors together.
         for i in 0..(size - 1) {
             desc[i as usize].next.write(i + 1);
         }
@@ -260,3 +263,102 @@ struct UsedElem {
     id: Volatile<u32>,
     len: Volatile<u32>,
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::hal::fake::FakeHal;
+    use core::mem::zeroed;
+
+    #[test]
+    fn invalid_queue_size() {
+        let mut header = unsafe { zeroed() };
+        // Size not a power of 2.
+        assert_eq!(
+            VirtQueue::<FakeHal>::new(&mut header, 0, 3).unwrap_err(),
+            Error::InvalidParam
+        );
+    }
+
+    #[test]
+    fn queue_too_big() {
+        let mut header = VirtIOHeader::make_fake_header(0, 0, 0, 4);
+        assert_eq!(
+            VirtQueue::<FakeHal>::new(&mut header, 0, 5).unwrap_err(),
+            Error::InvalidParam
+        );
+    }
+
+    #[test]
+    fn queue_already_used() {
+        let mut header = VirtIOHeader::make_fake_header(0, 0, 0, 4);
+        VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap();
+        assert_eq!(
+            VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap_err(),
+            Error::AlreadyUsed
+        );
+    }
+
+    #[test]
+    fn add_empty() {
+        let mut header = VirtIOHeader::make_fake_header(0, 0, 0, 4);
+        let mut queue = VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap();
+        assert_eq!(queue.add(&[], &[]).unwrap_err(), Error::InvalidParam);
+    }
+
+    #[test]
+    fn add_too_big() {
+        let mut header = VirtIOHeader::make_fake_header(0, 0, 0, 4);
+        let mut queue = VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap();
+        assert_eq!(queue.available_desc(), 4);
+        assert_eq!(
+            queue
+                .add(&[&[], &[], &[]], &[&mut [], &mut []])
+                .unwrap_err(),
+            Error::BufferTooSmall
+        );
+    }
+
+    #[test]
+    fn add_buffers() {
+        let mut header = VirtIOHeader::make_fake_header(0, 0, 0, 4);
+        let mut queue = VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap();
+        assert_eq!(queue.size(), 4);
+        assert_eq!(queue.available_desc(), 4);
+
+        // Add a buffer chain consisting of two device-readable parts followed by two
+        // device-writable parts.
+        let token = queue
+            .add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]])
+            .unwrap();
+
+        assert_eq!(queue.available_desc(), 0);
+        assert!(!queue.can_pop());
+
+        let first_descriptor_index = queue.avail.ring[0].read();
+        assert_eq!(first_descriptor_index, token);
+        assert_eq!(queue.desc[first_descriptor_index as usize].len.read(), 2);
+        assert_eq!(
+            queue.desc[first_descriptor_index as usize].flags.read(),
+            DescFlags::NEXT
+        );
+        let second_descriptor_index = queue.desc[first_descriptor_index as usize].next.read();
+        assert_eq!(queue.desc[second_descriptor_index as usize].len.read(), 1);
+        assert_eq!(
+            queue.desc[second_descriptor_index as usize].flags.read(),
+            DescFlags::NEXT
+        );
+        let third_descriptor_index = queue.desc[second_descriptor_index as usize].next.read();
+        assert_eq!(queue.desc[third_descriptor_index as usize].len.read(), 2);
+        assert_eq!(
+            queue.desc[third_descriptor_index as usize].flags.read(),
+            DescFlags::NEXT | DescFlags::WRITE
+        );
+        let fourth_descriptor_index = queue.desc[third_descriptor_index as usize].next.read();
+        assert_eq!(queue.desc[fourth_descriptor_index as usize].len.read(), 1);
+        assert_eq!(
+            queue.desc[fourth_descriptor_index as usize].flags.read(),
+            DescFlags::WRITE
+        );
+    }
+}