浏览代码

Add fake transport and a test for console driver. (#17)

* Assert that token assumption holds.

* Add fake transport and a test for console driver.

* Add test helper to queue to avoid exposing types and fields.
Andrew Walbran 2 年之前
父节点
当前提交
62f3e4f262
共有 6 个文件被更改,包括 246 次插入3 次删除
  1. 60 0
      src/console.rs
  2. 5 1
      src/input.rs
  3. 1 1
      src/lib.rs
  4. 61 0
      src/queue.rs
  5. 115 0
      src/transport/fake.rs
  6. 4 1
      src/transport/mod.rs

+ 60 - 0
src/console.rs

@@ -143,3 +143,63 @@ bitflags! {
         const NOTIFICATION_DATA     = 1 << 38;
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::{
+        hal::fake::FakeHal,
+        transport::fake::{FakeTransport, QueueStatus, State},
+    };
+    use alloc::{sync::Arc, vec};
+    use core::ptr::NonNull;
+    use std::sync::Mutex;
+
+    #[test]
+    fn receive() {
+        let mut config_space = Config {
+            cols: ReadOnly::new(0),
+            rows: ReadOnly::new(0),
+            max_nr_ports: ReadOnly::new(0),
+            emerg_wr: WriteOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![QueueStatus::default(); 2],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Console,
+            max_queue_size: 2,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space).cast(),
+            state: state.clone(),
+        };
+        let mut console = VirtIOConsole::<FakeHal, FakeTransport>::new(transport).unwrap();
+
+        // Nothing is available to receive.
+        assert_eq!(console.recv(false).unwrap(), None);
+        assert_eq!(console.recv(true).unwrap(), None);
+
+        // Still nothing after a spurious interrupt.
+        assert_eq!(console.ack_interrupt(), Ok(false));
+        assert_eq!(console.recv(false).unwrap(), None);
+
+        // Make a character available, and simulate an interrupt.
+        {
+            let mut state = state.lock().unwrap();
+            state.write_to_queue(QUEUE_SIZE, QUEUE_RECEIVEQ_PORT_0, &[42]);
+
+            state.interrupt_pending = true;
+        }
+        assert_eq!(console.ack_interrupt(), Ok(true));
+        assert_eq!(state.lock().unwrap().interrupt_pending, false);
+
+        // Receive the character. If we don't pop it it is still there to read again.
+        assert_eq!(console.recv(false).unwrap(), Some(42));
+        assert_eq!(console.recv(true).unwrap(), Some(42));
+        assert_eq!(console.recv(true).unwrap(), None);
+    }
+}

+ 5 - 1
src/input.rs

@@ -56,7 +56,11 @@ impl<H: Hal, T: Transport> VirtIOInput<'_, H, T> {
         if let Ok((token, _)) = self.event_queue.pop_used() {
             let event = &mut self.event_buf[token as usize];
             // requeue
-            if self.event_queue.add(&[], &[event.as_buf_mut()]).is_ok() {
+            if let Ok(new_token) = self.event_queue.add(&[], &[event.as_buf_mut()]) {
+                // This only works because nothing happen between `pop_used` and `add` that affects
+                // the list of free descriptors in the queue, so `add` reuses the descriptor which
+                // was just freed by `pop_used`.
+                assert_eq!(new_token, token);
                 return Some(*event);
             }
         }

+ 1 - 1
src/lib.rs

@@ -1,6 +1,6 @@
 //! VirtIO guest drivers.
 
-#![no_std]
+#![cfg_attr(not(test), no_std)]
 #![deny(unused_must_use, missing_docs)]
 #![allow(clippy::identity_op)]
 #![allow(dead_code)]

+ 61 - 0
src/queue.rs

@@ -1,6 +1,8 @@
 use core::mem::size_of;
 use core::slice;
 use core::sync::atomic::{fence, Ordering};
+#[cfg(test)]
+use core::{cmp::min, ptr};
 
 use super::*;
 use crate::transport::Transport;
@@ -270,6 +272,65 @@ struct UsedElem {
     len: Volatile<u32>,
 }
 
+/// Simulates the device writing to a VirtIO queue, for use in tests.
+///
+/// The fake device always uses descriptors in order.
+#[cfg(test)]
+pub(crate) fn fake_write_to_queue(
+    queue_size: u16,
+    receive_queue_descriptors: *const Descriptor,
+    receive_queue_driver_area: VirtAddr,
+    receive_queue_device_area: VirtAddr,
+    data: &[u8],
+) {
+    let descriptors =
+        unsafe { slice::from_raw_parts(receive_queue_descriptors, queue_size as usize) };
+    let available_ring = receive_queue_driver_area as *const AvailRing;
+    let used_ring = receive_queue_device_area as *mut UsedRing;
+    unsafe {
+        // Make sure there is actually at least one descriptor available to write to.
+        assert_ne!((*available_ring).idx.read(), (*used_ring).idx.read());
+        // The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
+        // `used_ring.idx` marks the next descriptor we should take from the available ring.
+        let next_slot = (*used_ring).idx.read() & (queue_size - 1);
+        let head_descriptor_index = (*available_ring).ring[next_slot as usize].read();
+        let mut descriptor = &descriptors[head_descriptor_index as usize];
+
+        // Loop through all descriptors in the chain, writing data to them.
+        let mut remaining_data = data;
+        loop {
+            // Check the buffer and write to it.
+            let flags = descriptor.flags.read();
+            assert!(flags.contains(DescFlags::WRITE));
+            let buffer_length = descriptor.len.read() as usize;
+            let length_to_write = min(remaining_data.len(), buffer_length);
+            ptr::copy(
+                remaining_data.as_ptr(),
+                descriptor.addr.read() as *mut u8,
+                length_to_write,
+            );
+            remaining_data = &remaining_data[length_to_write..];
+
+            if flags.contains(DescFlags::NEXT) {
+                let next = descriptor.next.read() as usize;
+                descriptor = &descriptors[next];
+            } else {
+                assert_eq!(remaining_data.len(), 0);
+                break;
+            }
+        }
+
+        // Mark the buffer as used.
+        (*used_ring).ring[next_slot as usize]
+            .id
+            .write(head_descriptor_index as u32);
+        (*used_ring).ring[next_slot as usize]
+            .len
+            .write(data.len() as u32);
+        (*used_ring).idx.update(|idx| *idx += 1);
+    }
+}
+
 #[cfg(test)]
 mod tests {
     use super::*;

+ 115 - 0
src/transport/fake.rs

@@ -0,0 +1,115 @@
+use super::{DeviceStatus, Transport};
+use crate::{
+    queue::{fake_write_to_queue, Descriptor},
+    DeviceType, PhysAddr,
+};
+use alloc::{sync::Arc, vec::Vec};
+use core::ptr::NonNull;
+use std::sync::Mutex;
+
+/// A fake implementation of [`Transport`] for unit tests.
+#[derive(Debug)]
+pub struct FakeTransport {
+    pub device_type: DeviceType,
+    pub max_queue_size: u32,
+    pub device_features: u64,
+    pub config_space: NonNull<u64>,
+    pub state: Arc<Mutex<State>>,
+}
+
+impl Transport for FakeTransport {
+    fn device_type(&self) -> DeviceType {
+        self.device_type
+    }
+
+    fn read_device_features(&mut self) -> u64 {
+        self.device_features
+    }
+
+    fn write_driver_features(&mut self, driver_features: u64) {
+        self.state.lock().unwrap().driver_features = driver_features;
+    }
+
+    fn max_queue_size(&self) -> u32 {
+        self.max_queue_size
+    }
+
+    fn notify(&mut self, queue: u32) {
+        self.state.lock().unwrap().queues[queue as usize].notified = true;
+    }
+
+    fn set_status(&mut self, status: DeviceStatus) {
+        self.state.lock().unwrap().status = status;
+    }
+
+    fn set_guest_page_size(&mut self, guest_page_size: u32) {
+        self.state.lock().unwrap().guest_page_size = guest_page_size;
+    }
+
+    fn queue_set(
+        &mut self,
+        queue: u32,
+        size: u32,
+        descriptors: PhysAddr,
+        driver_area: PhysAddr,
+        device_area: PhysAddr,
+    ) {
+        let mut state = self.state.lock().unwrap();
+        state.queues[queue as usize].size = size;
+        state.queues[queue as usize].descriptors = descriptors;
+        state.queues[queue as usize].driver_area = driver_area;
+        state.queues[queue as usize].device_area = device_area;
+    }
+
+    fn queue_used(&mut self, queue: u32) -> bool {
+        self.state.lock().unwrap().queues[queue as usize].descriptors != 0
+    }
+
+    fn ack_interrupt(&mut self) -> bool {
+        let mut state = self.state.lock().unwrap();
+        let pending = state.interrupt_pending;
+        if pending {
+            state.interrupt_pending = false;
+        }
+        pending
+    }
+
+    fn config_space(&self) -> NonNull<u64> {
+        self.config_space
+    }
+}
+
+#[derive(Debug, Default)]
+pub struct State {
+    pub status: DeviceStatus,
+    pub driver_features: u64,
+    pub guest_page_size: u32,
+    pub interrupt_pending: bool,
+    pub queues: Vec<QueueStatus>,
+}
+
+impl State {
+    /// Simulates the device writing to the given queue.
+    ///
+    /// The fake device always uses descriptors in order.
+    pub fn write_to_queue(&mut self, queue_size: u16, queue_index: usize, data: &[u8]) {
+        let receive_queue = &self.queues[queue_index];
+        assert_ne!(receive_queue.descriptors, 0);
+        fake_write_to_queue(
+            queue_size,
+            receive_queue.descriptors as *const Descriptor,
+            receive_queue.driver_area,
+            receive_queue.device_area,
+            data,
+        );
+    }
+}
+
+#[derive(Clone, Debug, Default, Eq, PartialEq)]
+pub struct QueueStatus {
+    pub size: u32,
+    pub descriptors: PhysAddr,
+    pub driver_area: PhysAddr,
+    pub device_area: PhysAddr,
+    pub notified: bool,
+}

+ 4 - 1
src/transport/mod.rs

@@ -1,3 +1,5 @@
+#[cfg(test)]
+pub mod fake;
 pub mod mmio;
 
 use crate::{PhysAddr, PAGE_SIZE};
@@ -70,6 +72,7 @@ pub trait Transport {
 
 bitflags! {
     /// The device status field.
+    #[derive(Default)]
     pub struct DeviceStatus: u32 {
         /// Indicates that the guest OS has found the device and recognized it
         /// as a valid virtio device.
@@ -99,7 +102,7 @@ bitflags! {
 
 /// Types of virtio devices.
 #[repr(u8)]
-#[derive(Debug, Eq, PartialEq)]
+#[derive(Clone, Copy, Debug, Eq, PartialEq)]
 #[allow(missing_docs)]
 pub enum DeviceType {
     Invalid = 0,