浏览代码

Implement fake reads and writes in terms of a shared function.

This also allows for more complex devices to be simulated.
Andrew Walbran 2 年之前
父节点
当前提交
280e5f9889
共有 2 个文件被更改,包括 70 次插入71 次删除
  1. 35 67
      src/queue.rs
  2. 35 4
      src/transport/fake.rs

+ 35 - 67
src/queue.rs

@@ -523,70 +523,17 @@ struct UsedElem {
     len: u32,
 }
 
-/// Simulates the device writing to a VirtIO queue, for use in tests.
+/// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests.
 ///
 /// The fake device always uses descriptors in order.
 #[cfg(test)]
-pub(crate) fn fake_write_to_queue<const QUEUE_SIZE: usize>(
+pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
     queue_descriptors: *const Descriptor,
     queue_driver_area: VirtAddr,
     queue_device_area: VirtAddr,
-    data: &[u8],
+    handler: impl FnOnce(Vec<u8>) -> Vec<u8>,
 ) {
-    let descriptors = ptr::slice_from_raw_parts(queue_descriptors, QUEUE_SIZE);
-    let available_ring = queue_driver_area as *const AvailRing<QUEUE_SIZE>;
-    let used_ring = queue_device_area as *mut UsedRing<QUEUE_SIZE>;
-    // Safe because the various pointers are properly aligned, dereferenceable, initialised, and
-    // nothing else accesses them during this block.
-    unsafe {
-        // Make sure there is actually at least one descriptor available to write to.
-        assert_ne!((*available_ring).idx, (*used_ring).idx);
-        // The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
-        // `used_ring.idx` marks the next descriptor we should take from the available ring.
-        let next_slot = (*used_ring).idx & (QUEUE_SIZE as u16 - 1);
-        let head_descriptor_index = (*available_ring).ring[next_slot as usize];
-        let mut descriptor = &(*descriptors)[head_descriptor_index as usize];
-
-        // Loop through all descriptors in the chain, writing data to them.
-        let mut remaining_data = data;
-        loop {
-            // Check the buffer and write to it.
-            let flags = descriptor.flags;
-            assert!(flags.contains(DescFlags::WRITE));
-            let buffer_length = descriptor.len as usize;
-            let length_to_write = min(remaining_data.len(), buffer_length);
-            ptr::copy(
-                remaining_data.as_ptr(),
-                descriptor.addr as *mut u8,
-                length_to_write,
-            );
-            remaining_data = &remaining_data[length_to_write..];
-
-            if let Some(next) = descriptor.next() {
-                descriptor = &(*descriptors)[next as usize];
-            } else {
-                assert_eq!(remaining_data.len(), 0);
-                break;
-            }
-        }
-
-        // Mark the buffer as used.
-        (*used_ring).ring[next_slot as usize].id = head_descriptor_index as u32;
-        (*used_ring).ring[next_slot as usize].len = data.len() as u32;
-        (*used_ring).idx += 1;
-    }
-}
-
-/// Simulates the device reading from a VirtIO queue, for use in tests.
-///
-/// The fake device always uses descriptors in order.
-#[cfg(test)]
-pub(crate) fn fake_read_from_queue<const QUEUE_SIZE: usize>(
-    queue_descriptors: *const Descriptor,
-    queue_driver_area: VirtAddr,
-    queue_device_area: VirtAddr,
-) -> Vec<u8> {
-    use core::slice;
+    use core::{ops::Deref, slice};
 
     let descriptors = ptr::slice_from_raw_parts(queue_descriptors, QUEUE_SIZE);
     let available_ring = queue_driver_area as *const AvailRing<QUEUE_SIZE>;
@@ -603,16 +550,12 @@ pub(crate) fn fake_read_from_queue<const QUEUE_SIZE: usize>(
         let head_descriptor_index = (*available_ring).ring[next_slot as usize];
         let mut descriptor = &(*descriptors)[head_descriptor_index as usize];
 
-        // Loop through all descriptors in the chain, reading data from them.
+        // Loop through all input descriptors in the chain, reading data from them.
         let mut input = Vec::new();
-        loop {
-            // Check the buffer and read from it.
-            let flags = descriptor.flags;
-            assert!(!flags.contains(DescFlags::WRITE));
-            let buffer_length = descriptor.len as usize;
+        while !descriptor.flags.contains(DescFlags::WRITE) {
             input.extend_from_slice(slice::from_raw_parts(
                 descriptor.addr as *const u8,
-                buffer_length,
+                descriptor.len as usize,
             ));
 
             if let Some(next) = descriptor.next() {
@@ -621,13 +564,38 @@ pub(crate) fn fake_read_from_queue<const QUEUE_SIZE: usize>(
                 break;
             }
         }
+        let input_length = input.len();
+
+        // Let the test handle the request.
+        let output = handler(input);
+
+        // Write the response to the remaining descriptors.
+        let mut remaining_output = output.deref();
+        if descriptor.flags.contains(DescFlags::WRITE) {
+            loop {
+                assert!(descriptor.flags.contains(DescFlags::WRITE));
+
+                let length_to_write = min(remaining_output.len(), descriptor.len as usize);
+                ptr::copy(
+                    remaining_output.as_ptr(),
+                    descriptor.addr as *mut u8,
+                    length_to_write,
+                );
+                remaining_output = &remaining_output[length_to_write..];
+
+                if let Some(next) = descriptor.next() {
+                    descriptor = &(*descriptors)[next as usize];
+                } else {
+                    break;
+                }
+            }
+        }
+        assert_eq!(remaining_output.len(), 0);
 
         // Mark the buffer as used.
         (*used_ring).ring[next_slot as usize].id = head_descriptor_index as u32;
-        (*used_ring).ring[next_slot as usize].len = input.len() as u32;
+        (*used_ring).ring[next_slot as usize].len = (input_length + output.len()) as u32;
         (*used_ring).idx += 1;
-
-        input
     }
 }
 

+ 35 - 4
src/transport/fake.rs

@@ -1,6 +1,6 @@
 use super::{DeviceStatus, DeviceType, Transport};
 use crate::{
-    queue::{fake_read_from_queue, fake_write_to_queue, Descriptor},
+    queue::{fake_read_write_queue, Descriptor},
     PhysAddr, Result,
 };
 use alloc::{sync::Arc, vec::Vec};
@@ -111,11 +111,14 @@ impl State {
     pub fn write_to_queue<const QUEUE_SIZE: usize>(&mut self, queue_index: u16, data: &[u8]) {
         let queue = &self.queues[queue_index as usize];
         assert_ne!(queue.descriptors, 0);
-        fake_write_to_queue::<QUEUE_SIZE>(
+        fake_read_write_queue::<QUEUE_SIZE>(
             queue.descriptors as *const Descriptor,
             queue.driver_area,
             queue.device_area,
-            data,
+            |input| {
+                assert_eq!(input, Vec::new());
+                data.to_owned()
+            },
         );
     }
 
@@ -127,10 +130,38 @@ impl State {
     pub fn read_from_queue<const QUEUE_SIZE: usize>(&mut self, queue_index: u16) -> Vec<u8> {
         let queue = &self.queues[queue_index as usize];
         assert_ne!(queue.descriptors, 0);
-        fake_read_from_queue::<QUEUE_SIZE>(
+
+        let mut ret = None;
+
+        // Read data from the queue but don't write any response.
+        fake_read_write_queue::<QUEUE_SIZE>(
+            queue.descriptors as *const Descriptor,
+            queue.driver_area,
+            queue.device_area,
+            |input| {
+                ret = Some(input);
+                Vec::new()
+            },
+        );
+
+        ret.unwrap()
+    }
+
+    /// Simulates the device reading data from the given queue and then writing a response back.
+    ///
+    /// The fake device always uses descriptors in order.
+    pub fn read_write_queue<const QUEUE_SIZE: usize>(
+        &mut self,
+        queue_index: u16,
+        handler: impl FnOnce(Vec<u8>) -> Vec<u8>,
+    ) {
+        let queue = &self.queues[queue_index as usize];
+        assert_ne!(queue.descriptors, 0);
+        fake_read_write_queue::<QUEUE_SIZE>(
             queue.descriptors as *const Descriptor,
             queue.driver_area,
             queue.device_area,
+            handler,
         )
     }
 }