Browse Source

Use atomic reads and writes for available and used ring. (#2)

This is necessary to correctly synchronise with the host. The previous
fences were not correct.

Co-authored-by: Andrew Walbran <qwandor@google.com>
LoGin 1 year ago
parent
commit
448a7811b8
1 changed files with 43 additions and 33 deletions
  1. 43 33
      src/queue.rs

+ 43 - 33
src/queue.rs

@@ -13,7 +13,7 @@ use core::mem::{size_of, take};
 #[cfg(test)]
 use core::ptr;
 use core::ptr::NonNull;
-use core::sync::atomic::{fence, Ordering};
+use core::sync::atomic::{fence, AtomicU16, Ordering};
 use zerocopy::{AsBytes, FromBytes, FromZeroes};
 
 /// The mechanism for bulk data transport on virtio devices.
@@ -192,12 +192,11 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         self.avail_idx = self.avail_idx.wrapping_add(1);
         // Safe because self.avail is properly aligned, dereferenceable and initialised.
         unsafe {
-            (*self.avail.as_ptr()).idx = self.avail_idx;
+            (*self.avail.as_ptr())
+                .idx
+                .store(self.avail_idx, Ordering::Release);
         }
 
-        // Write barrier so that device can see change to available index after this method returns.
-        fence(Ordering::SeqCst);
-
         Ok(head)
     }
 
@@ -324,10 +323,12 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         if !self.event_idx {
             // Safe because self.avail points to a valid, aligned, initialised, dereferenceable, readable
             // instance of AvailRing.
-            unsafe { (*self.avail.as_ptr()).flags = avail_ring_flags }
+            unsafe {
+                (*self.avail.as_ptr())
+                    .flags
+                    .store(avail_ring_flags, Ordering::Release)
+            }
         }
-        // Write barrier so that device can see change to available index after this method returns.
-        fence(Ordering::SeqCst);
     }
 
     /// Returns whether the driver should notify the device after adding a new buffer to the
@@ -335,18 +336,15 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     ///
     /// This will be false if the device has supressed notifications.
     pub fn should_notify(&self) -> bool {
-        // Read barrier, so we read a fresh value from the device.
-        fence(Ordering::SeqCst);
-
         if self.event_idx {
             // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
             // instance of UsedRing.
-            let avail_event = unsafe { (*self.used.as_ptr()).avail_event };
+            let avail_event = unsafe { (*self.used.as_ptr()).avail_event.load(Ordering::Acquire) };
             self.avail_idx >= avail_event.wrapping_add(1)
         } else {
             // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
             // instance of UsedRing.
-            unsafe { (*self.used.as_ptr()).flags & 0x0001 == 0 }
+            unsafe { (*self.used.as_ptr()).flags.load(Ordering::Acquire) & 0x0001 == 0 }
         }
     }
 
@@ -363,12 +361,9 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
 
     /// Returns whether there is a used element that can be popped.
     pub fn can_pop(&self) -> bool {
-        // Read barrier, so we read a fresh value from the device.
-        fence(Ordering::SeqCst);
-
         // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
         // instance of UsedRing.
-        self.last_used_idx != unsafe { (*self.used.as_ptr()).idx }
+        self.last_used_idx != unsafe { (*self.used.as_ptr()).idx.load(Ordering::Acquire) }
     }
 
     /// Returns the descriptor index (a.k.a. token) of the next used element without popping it, or
@@ -506,7 +501,6 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         if !self.can_pop() {
             return Err(Error::NotReady);
         }
-        // Read barrier not necessary, as can_pop already has one.
 
         // Get the index of the start of the descriptor chain for the next element in the used ring.
         let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
@@ -532,7 +526,9 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
 
         if self.event_idx {
             unsafe {
-                (*self.avail.as_ptr()).used_event = self.last_used_idx;
+                (*self.avail.as_ptr())
+                    .used_event
+                    .store(self.last_used_idx, Ordering::Release);
             }
         }
 
@@ -761,12 +757,12 @@ bitflags! {
 #[repr(C)]
 #[derive(Debug)]
 struct AvailRing<const SIZE: usize> {
-    flags: u16,
+    flags: AtomicU16,
     /// A driver MUST NOT decrement the idx.
-    idx: u16,
+    idx: AtomicU16,
     ring: [u16; SIZE],
     /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
-    used_event: u16,
+    used_event: AtomicU16,
 }
 
 /// The used ring is where the device returns buffers once it is done with them:
@@ -774,11 +770,11 @@ struct AvailRing<const SIZE: usize> {
 #[repr(C)]
 #[derive(Debug)]
 struct UsedRing<const SIZE: usize> {
-    flags: u16,
-    idx: u16,
+    flags: AtomicU16,
+    idx: AtomicU16,
     ring: [UsedElem; SIZE],
     /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
-    avail_event: u16,
+    avail_event: AtomicU16,
 }
 
 #[repr(C)]
@@ -847,10 +843,13 @@ pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
     // nothing else accesses them during this block.
     unsafe {
         // Make sure there is actually at least one descriptor available to read from.
-        assert_ne!((*available_ring).idx, (*used_ring).idx);
+        assert_ne!(
+            (*available_ring).idx.load(Ordering::Acquire),
+            (*used_ring).idx.load(Ordering::Acquire)
+        );
         // The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
         // `used_ring.idx` marks the next descriptor we should take from the available ring.
-        let next_slot = (*used_ring).idx & (QUEUE_SIZE as u16 - 1);
+        let next_slot = (*used_ring).idx.load(Ordering::Acquire) & (QUEUE_SIZE as u16 - 1);
         let head_descriptor_index = (*available_ring).ring[next_slot as usize];
         let mut descriptor = &(*descriptors)[head_descriptor_index as usize];
 
@@ -951,7 +950,7 @@ pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
         // Mark the buffer as used.
         (*used_ring).ring[next_slot as usize].id = head_descriptor_index as u32;
         (*used_ring).ring[next_slot as usize].len = (input_length + output.len()) as u32;
-        (*used_ring).idx += 1;
+        (*used_ring).idx.fetch_add(1, Ordering::AcqRel);
     }
 }
 
@@ -1156,17 +1155,26 @@ mod tests {
         let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
 
         // Check that the avail ring's flag is zero by default.
-        assert_eq!(unsafe { (*queue.avail.as_ptr()).flags }, 0x0);
+        assert_eq!(
+            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
+            0x0
+        );
 
         queue.set_dev_notify(false);
 
         // Check that the avail ring's flag is 1 after `disable_dev_notify`.
-        assert_eq!(unsafe { (*queue.avail.as_ptr()).flags }, 0x1);
+        assert_eq!(
+            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
+            0x1
+        );
 
         queue.set_dev_notify(true);
 
         // Check that the avail ring's flag is 0 after `enable_dev_notify`.
-        assert_eq!(unsafe { (*queue.avail.as_ptr()).flags }, 0x0);
+        assert_eq!(
+            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
+            0x0
+        );
     }
 
     /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
@@ -1197,7 +1205,7 @@ mod tests {
         // initialised, and nothing else is accessing them at the same time.
         unsafe {
             // Suppress notifications.
-            (*queue.used.as_ptr()).flags = 0x01;
+            (*queue.used.as_ptr()).flags.store(0x01, Ordering::Release);
         }
 
         // Check that the transport would not be notified.
@@ -1232,7 +1240,9 @@ mod tests {
         // initialised, and nothing else is accessing them at the same time.
         unsafe {
             // Suppress notifications.
-            (*queue.used.as_ptr()).avail_event = 1;
+            (*queue.used.as_ptr())
+                .avail_event
+                .store(1, Ordering::Release);
         }
 
         // Check that the transport would not be notified.