Browse Source

Keep shadow copy of descriptor table.

This means we don't have to read back from the memory shared with the
device, which is better in case a malicious device changes it.
Andrew Walbran 2 years ago
parent
commit
ea3de12e98
1 changed files with 53 additions and 31 deletions
  1. 53 31
      src/queue.rs

+ 53 - 31
src/queue.rs

@@ -10,8 +10,9 @@ use core::hint::spin_loop;
 use core::mem::size_of;
 #[cfg(test)]
 use core::ptr;
-use core::ptr::{addr_of_mut, NonNull};
+use core::ptr::NonNull;
 use core::sync::atomic::{fence, Ordering};
+use zerocopy::FromBytes;
 
 /// The mechanism for bulk data transport on virtio devices.
 ///
@@ -24,8 +25,16 @@ pub struct VirtQueue<H: Hal, const SIZE: usize> {
     /// DMA guard
     layout: VirtQueueLayout<H>,
     /// Descriptor table
+    ///
+    /// The device may be able to modify this, even though it's not supposed to, so we shouldn't
+    /// trust values read back from it. Use `desc_shadow` instead to keep track of what we wrote to
+    /// it.
     desc: NonNull<[Descriptor]>,
     /// Available ring
+    ///
+    /// The device may be able to modify this, even though it's not supposed to, so we shouldn't
+    /// trust values read back from it. The only field we need to read currently is `idx`, so we
+    /// have `avail_idx` below to use instead.
     avail: NonNull<AvailRing<SIZE>>,
     /// Used ring
     used: NonNull<UsedRing<SIZE>>,
@@ -36,6 +45,9 @@ pub struct VirtQueue<H: Hal, const SIZE: usize> {
     num_used: u16,
     /// The head desc index of the free list.
     free_head: u16,
+    /// Our trusted copy of `desc` that the device can't access.
+    desc_shadow: [Descriptor; SIZE],
+    /// Our trusted copy of `avail.idx`.
     avail_idx: u16,
     last_used_idx: u16,
 }
@@ -73,8 +85,10 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         let avail = layout.avail_vaddr().cast();
         let used = layout.used_vaddr().cast();
 
+        let mut desc_shadow: [Descriptor; SIZE] = FromBytes::new_zeroed();
         // Link descriptors together.
         for i in 0..(size - 1) {
+            desc_shadow[i as usize].next = i + 1;
             // Safe because `desc` is properly aligned, dereferenceable, initialised, and the device
             // won't access the descriptors for the duration of this unsafe block.
             unsafe {
@@ -90,6 +104,7 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
             queue_idx: idx,
             num_used: 0,
             free_head: 0,
+            desc_shadow,
             avail_idx: 0,
             last_used_idx: 0,
         })
@@ -114,19 +129,22 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         let head = self.free_head;
         let mut last = self.free_head;
 
-        // Safe because self.desc is properly aligned, dereferenceable and initialised, and nothing
-        // else reads or writes the free descriptors during this block.
-        unsafe {
-            for (buffer, direction) in input_output_iter(inputs, outputs) {
-                let desc = self.desc_ptr(self.free_head);
-                (*desc).set_buf::<H>(buffer, direction, DescFlags::NEXT);
-                last = self.free_head;
-                self.free_head = (*desc).next;
-            }
+        for (buffer, direction) in input_output_iter(inputs, outputs) {
+            // Write to desc_shadow then copy.
+            let desc = &mut self.desc_shadow[usize::from(self.free_head)];
+            desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
+            last = self.free_head;
+            self.free_head = desc.next;
 
-            // set last_elem.next = NULL
-            (*self.desc_ptr(last)).flags.remove(DescFlags::NEXT);
+            self.write_desc(last);
         }
+
+        // set last_elem.next = NULL
+        self.desc_shadow[usize::from(last)]
+            .flags
+            .remove(DescFlags::NEXT);
+        self.write_desc(last);
+
         self.num_used += (inputs.len() + outputs.len()) as u16;
 
         let avail_slot = self.avail_idx & (SIZE as u16 - 1);
@@ -192,10 +210,15 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         unsafe { (*self.used.as_ptr()).flags & 0x0001 == 0 }
     }
 
-    /// Returns a non-null pointer to the descriptor at the given index.
-    fn desc_ptr(&mut self, index: u16) -> *mut Descriptor {
-        // Safe because self.desc is properly aligned and dereferenceable.
-        unsafe { addr_of_mut!((*self.desc.as_ptr())[index as usize]) }
+    /// Copies the descriptor at the given index from `desc_shadow` to `desc`, so it can be seen by
+    /// the device.
+    fn write_desc(&mut self, index: u16) {
+        let index = usize::from(index);
+        // Safe because self.desc is properly aligned, dereferenceable and initialised, and nothing
+        // else reads or writes the descriptor during this block.
+        unsafe {
+            (*self.desc.as_ptr())[index] = self.desc_shadow[index].clone();
+        }
     }
 
     /// Returns whether there is a used element that can be popped.
@@ -237,20 +260,18 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         let mut next = Some(head);
 
         for (buffer, direction) in input_output_iter(inputs, outputs) {
-            let desc = self.desc_ptr(next.expect("Descriptor chain was shorter than expected."));
-
-            // Safe because self.desc is properly aligned, dereferenceable and initialised, and
-            // nothing else reads or writes the descriptor during this block.
-            let paddr = unsafe {
-                let paddr = (*desc).addr;
-                (*desc).unset_buf();
-                self.num_used -= 1;
-                next = (*desc).next();
-                if next.is_none() {
-                    (*desc).next = original_free_head;
-                }
-                paddr
-            };
+            let desc_index = next.expect("Descriptor chain was shorter than expected.");
+            let desc = &mut self.desc_shadow[usize::from(desc_index)];
+
+            let paddr = desc.addr;
+            desc.unset_buf();
+            self.num_used -= 1;
+            next = desc.next();
+            if next.is_none() {
+                desc.next = original_free_head;
+            }
+
+            self.write_desc(desc_index);
 
             // Unshare the buffer (and perhaps copy its contents back to the original buffer).
             H::unshare(paddr as usize, buffer, direction);
@@ -447,7 +468,7 @@ fn queue_part_sizes(queue_size: u16) -> (usize, usize, usize) {
 }
 
 #[repr(C, align(16))]
-#[derive(Debug)]
+#[derive(Clone, Debug, FromBytes)]
 pub(crate) struct Descriptor {
     addr: u64,
     len: u32,
@@ -500,6 +521,7 @@ impl Descriptor {
 
 bitflags! {
     /// Descriptor flags
+    #[derive(FromBytes)]
     struct DescFlags: u16 {
         const NEXT = 1;
         const WRITE = 2;