Browse Source

VirtQueue::pop_used should also be marked as unsafe.

Andrew Walbran 2 years ago
parent
commit
2caa561a86
3 changed files with 25 additions and 9 deletions
  1. 6 3
      src/device/console.rs
  2. 7 3
      src/device/input.rs
  3. 12 3
      src/queue.rs

+ 6 - 3
src/device/console.rs

@@ -145,9 +145,12 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
         let mut flag = false;
         if let Some(receive_token) = self.receive_token {
             if self.receive_token == self.receiveq.peek_used() {
-                let len = self
-                    .receiveq
-                    .pop_used(receive_token, &[], &[self.queue_buf_rx])?;
+                // Safe because we are passing the same buffer as we passed to `VirtQueue::add` in
+                // `poll_retrieve` and it is still valid.
+                let len = unsafe {
+                    self.receiveq
+                        .pop_used(receive_token, &[], &[self.queue_buf_rx])?
+                };
                 flag = true;
                 assert_ne!(len, 0);
                 self.cursor = 0;

+ 7 - 3
src/device/input.rs

@@ -69,9 +69,13 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
     pub fn pop_pending_event(&mut self) -> Option<InputEvent> {
         if let Some(token) = self.event_queue.peek_used() {
             let event = &mut self.event_buf[token as usize];
-            self.event_queue
-                .pop_used(token, &[], &[event.as_bytes_mut()])
-                .ok()?;
+            // Safe because we are passing the same buffer as we passed to `VirtQueue::add` and it
+            // is still valid.
+            unsafe {
+                self.event_queue
+                    .pop_used(token, &[], &[event.as_bytes_mut()])
+                    .ok()?;
+            }
             // requeue
             // Safe because buffer lasts as long as the queue.
             if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } {

+ 12 - 3
src/queue.rs

@@ -114,7 +114,8 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     ///
     /// # Safety
     ///
-    /// The input and output buffers must remain valid until the token is returned by `pop_used`.
+    /// The input and output buffers must remain valid and not be accessed until a call to
+    /// `pop_used` with the returned token succeeds.
     pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> {
         if inputs.is_empty() && outputs.is_empty() {
             return Err(Error::InvalidParam);
@@ -192,7 +193,9 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
             spin_loop();
         }
 
-        self.pop_used(token, inputs, outputs)
+        // Safe because these are the same buffers as we passed to `add` above and they are still
+        // valid.
+        unsafe { self.pop_used(token, inputs, outputs) }
     }
 
     /// Returns whether the driver should notify the device after adding a new buffer to the
@@ -299,7 +302,13 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     /// length which was used (written) by the device.
     ///
     /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
-    pub fn pop_used(
+    ///
+    /// # Safety
+    ///
+    /// The buffers in `inputs` and `outputs` must be valid pointers to memory which is not accessed
+    /// by any other thread for the duration of this method call, and must match the set of buffers
+    /// originally added to the queue by `add`.
+    pub unsafe fn pop_used(
         &mut self,
         token: u16,
         inputs: &[*const [u8]],