
Unshare buffers in pop_used.

This requires callers to pass the buffers in again, which is a good
thing anyway as it ensures they are still alive at that point.
Andrew Walbran, 2 years ago · commit 90d508a8aa
4 changed files with 97 additions and 27 deletions:
  1. src/device/blk.rs      +51 −12
  2. src/device/console.rs    +4  −1
  3. src/device/input.rs      +3  −1
  4. src/queue.rs            +39 −13
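
As the commit message says, completion now needs the original buffers again. Roughly, the driver-facing call sites change as sketched below (a hedged fragment, not confirmed API: `request`, `buffer` and `response` are the buffers from the `read_block_nb` doc example in src/device/blk.rs, and an enclosing function returning `Result` is assumed).

// Old completion path: only the token came back; the queue never saw the
// buffers again.
//
//     let token = blk.pop_used()?;
//
// New completion path: peek for the finished token, then pass the exact same
// buffers back so the queue can unshare them (and copy any device-written
// data back into them) before the response is read.
if let Some(token) = blk.peek_used() {
    // SAFETY: these are the same request/buffer/response that were passed to
    // read_block_nb when it returned this token.
    unsafe {
        blk.complete_read_block(token, &request, &mut buffer, &mut response)?;
    }
}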

+ 51 - 12
src/device/blk.rs

@@ -125,7 +125,8 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// * `block_id` - The identifier of the block to read.
     /// * `req` - A buffer which the driver can use for the request to send to the device. The
     ///   contents don't matter as `read_block_nb` will initialise it, but like the other buffers it
-    ///   needs to be valid (and not otherwise used) until the corresponding `pop_used` call.
+    ///   needs to be valid (and not otherwise used) until the corresponding `complete_read_block`
+    ///   call.
     /// * `buf` - The buffer in memory into which the block should be read.
     /// * `resp` - A mutable reference to a variable provided by the caller
     ///   to contain the status of the request. The caller can safely
@@ -137,8 +138,9 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// the position of the first Descriptor in the chain. If there are not enough
     /// Descriptors to allocate, then it returns [`Error::QueueFull`].
     ///
-    /// The caller can then call `pop_used` to check whether the device has finished handling the
-    /// request. Once it has, the caller can then read the response and dispose of the buffers.
+    /// The caller can then call `peek_used` with the returned token to check whether the device has
+    /// finished handling the request. Once it has, the caller must call `complete_read_block` with
+    /// the same buffers before reading the response.
     ///
     /// ```
     /// # use virtio_drivers::{Error, Hal};
@@ -153,8 +155,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// let token = unsafe { blk.read_block_nb(42, &mut request, &mut buffer, &mut response) }?;
     ///
     /// // Wait for an interrupt to tell us that the request completed...
+    /// assert_eq!(blk.peek_used(), Some(token));
     ///
-    /// assert_eq!(blk.pop_used()?, token);
+    /// unsafe {
+    ///   blk.complete_read_block(token, &request, &mut buffer, &mut response)?;
+    /// }
     /// if response.status() == RespStatus::OK {
     ///   println!("Successfully read block.");
     /// } else {
@@ -189,6 +194,24 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         Ok(token)
     }
 
+    /// Completes a read operation which was started by `read_block_nb`.
+    ///
+    /// # Safety
+    ///
+    /// The same buffers must be passed in again as were passed to `read_block_nb` when it returned
+    /// the token.
+    pub unsafe fn complete_read_block(
+        &mut self,
+        token: u16,
+        req: &BlkReq,
+        buf: &mut [u8],
+        resp: &mut BlkResp,
+    ) -> Result<()> {
+        self.queue
+            .pop_used(token, &[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+        Ok(())
+    }
+
     /// Writes the contents of the given buffer to a block.
     ///
     /// Blocks until the write is complete or there is an error.
@@ -219,7 +242,8 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// * `block_id` - The identifier of the block to write.
     /// * `req` - A buffer which the driver can use for the request to send to the device. The
     ///   contents don't matter as `write_block_nb` will initialise it, but like the other buffers it
-    ///   needs to be valid (and not otherwise used) until the corresponding `pop_used` call.
+    ///   needs to be valid (and not otherwise used) until the corresponding `complete_write_block`
+    ///   call.
     /// * `buf` - The buffer in memory containing the data to write to the block.
     /// * `resp` - A mutable reference to a variable provided by the caller
     ///   to contain the status of the request. The caller can safely
@@ -252,13 +276,28 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         Ok(token)
     }
 
-    /// During an interrupt, it fetches a token of a completed request from the used
-    /// ring and return it. If all completed requests have already been fetched, return
-    /// Err(Error::NotReady).
-    pub fn pop_used(&mut self) -> Result<u16> {
-        let token = self.queue.peek_used().ok_or(Error::NotReady)?;
-        self.queue.pop_used(token)?;
-        Ok(token)
+    /// Completes a write operation which was started by `write_block_nb`.
+    ///
+    /// # Safety
+    ///
+    /// The same buffers must be passed in again as were passed to `write_block_nb` when it returned
+    /// the token.
+    pub unsafe fn complete_write_block(
+        &mut self,
+        token: u16,
+        req: &BlkReq,
+        buf: &[u8],
+        resp: &mut BlkResp,
+    ) -> Result<()> {
+        self.queue
+            .pop_used(token, &[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+        Ok(())
+    }
+
+    /// Fetches the token of the next completed request from the used ring and returns it, without
+    /// removing it from the used ring. If there are no pending completed requests, returns `None`.
+    pub fn peek_used(&mut self) -> Option<u16> {
+        self.queue.peek_used()
     }
 
     /// Returns the size of the device's VirtQueue.
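
For symmetry with the read example above, here is a minimal sketch of driving a non-blocking write to completion with the new `complete_write_block`. It assumes `write_block_nb` mirrors `read_block_nb` (unsafe; takes `block_id`, `req`, `buf` and `resp` and returns a token) and reuses the hidden `SomeTransport`/`HalImpl` setup from the crate's doc examples; treat the exact paths and signatures as assumptions rather than confirmed API.

fn write_one_block<HalImpl: Hal>(
    blk: &mut VirtIOBlk<HalImpl, SomeTransport>,
    data: &[u8; 512],
) -> Result<bool, Error> {
    let mut request = BlkReq::default();
    let mut response = BlkResp::default();

    // SAFETY: request, data and response stay valid and otherwise unused until
    // the complete_write_block call below.
    let token = unsafe { blk.write_block_nb(7, &mut request, data, &mut response) }?;

    // ... wait for the device's used-buffer notification ...
    assert_eq!(blk.peek_used(), Some(token));

    // SAFETY: the same buffers are passed back that write_block_nb was given.
    unsafe {
        blk.complete_write_block(token, &request, data, &mut response)?;
    }

    // Mirrors the success check from the doc example above.
    Ok(response.status() == RespStatus::OK)
}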

+ 4 - 1
src/device/console.rs

@@ -136,7 +136,10 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
     fn finish_receive(&mut self) -> bool {
         let mut flag = false;
         if let Some(receive_token) = self.receive_token {
-            if let Ok(len) = self.receiveq.pop_used(receive_token) {
+            if let Ok(len) = self
+                .receiveq
+                .pop_used(receive_token, &[], &[self.queue_buf_rx])
+            {
                 flag = true;
                 assert_ne!(len, 0);
                 self.cursor = 0;

+ 3 - 1
src/device/input.rs

@@ -65,8 +65,10 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
     /// Pop the pending event.
     pub fn pop_pending_event(&mut self) -> Option<InputEvent> {
         if let Some(token) = self.event_queue.peek_used() {
-            self.event_queue.pop_used(token).ok()?;
             let event = &mut self.event_buf[token as usize];
+            self.event_queue
+                .pop_used(token, &[], &[event.as_bytes_mut()])
+                .ok()?;
             // requeue
             // Safe because buffer lasts as long as the queue.
             if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } {

+ 39 - 13
src/queue.rs

@@ -171,7 +171,7 @@ impl<H: Hal> VirtQueue<H> {
             spin_loop();
         }
 
-        self.pop_used(token)
+        self.pop_used(token, inputs, outputs)
     }
 
     /// Returns a non-null pointer to the descriptor at the given index.
@@ -208,25 +208,38 @@ impl<H: Hal> VirtQueue<H> {
         (self.queue_size - self.num_used) as usize
     }
 
-    /// Recycle descriptors in the list specified by head.
+    /// Unshares buffers in the list starting at descriptor index `head` and adds them to the free
+    /// list. Unsharing may involve copying data back to the original buffers, so they must be
+    /// passed in too.
     ///
     /// This will push all linked descriptors at the front of the free list.
-    fn recycle_descriptors(&mut self, mut head: u16) {
+    fn recycle_descriptors(&mut self, head: u16, inputs: &[*const [u8]], outputs: &[*mut [u8]]) {
         let original_free_head = self.free_head;
         self.free_head = head;
-        loop {
-            let desc = self.desc_ptr(head);
+        let mut next = Some(head);
+
+        for (buffer, direction) in input_output_iter(inputs, outputs) {
+            let desc = self.desc_ptr(next.expect("Descriptor chain was shorter than expected."));
+
             // Safe because self.desc is properly aligned, dereferenceable and initialised, and
             // nothing else reads or writes the descriptor during this block.
-            unsafe {
+            let paddr = unsafe {
+                let paddr = (*desc).addr;
+                (*desc).unset_buf();
                 self.num_used -= 1;
-                if let Some(next) = (*desc).next() {
-                    head = next;
-                } else {
+                next = (*desc).next();
+                if next.is_none() {
                     (*desc).next = original_free_head;
-                    return;
                 }
-            }
+                paddr
+            };
+
+            // Unshare the buffer (and perhaps copy its contents back to the original buffer).
+            H::unshare(paddr as usize, buffer, direction);
+        }
+
+        if next.is_some() {
+            panic!("Descriptor chain was longer than expected.");
         }
     }
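
`recycle_descriptors` walks the chain with an `input_output_iter` helper that this diff does not show. Purely as a hypothetical sketch (the crate's real definition may differ), such a helper would pair each caller-supplied buffer with its transfer direction, inputs first, in the same order `add` chained the descriptors; the direction enum is assumed to be the crate's `BufferDirection`.

// Hypothetical sketch only, not the definition from queue.rs.
fn input_output_iter<'a>(
    inputs: &'a [*const [u8]],
    outputs: &'a [*mut [u8]],
) -> impl Iterator<Item = (*mut [u8], BufferDirection)> + 'a {
    inputs
        .iter()
        // Device-readable buffers come first, matching the order used by add.
        .map(|&input| (input as *mut [u8], BufferDirection::DriverToDevice))
        .chain(
            outputs
                .iter()
                // Device-writable buffers follow; unsharing these may copy
                // device-written data back to the caller's buffer.
                .map(|&output| (output, BufferDirection::DeviceToDriver)),
        )
}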
 
@@ -234,7 +247,12 @@ impl<H: Hal> VirtQueue<H> {
     /// length which was used (written) by the device.
     ///
     /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
-    pub fn pop_used(&mut self, token: u16) -> Result<u32> {
+    pub fn pop_used(
+        &mut self,
+        token: u16,
+        inputs: &[*const [u8]],
+        outputs: &[*mut [u8]],
+    ) -> Result<u32> {
         if !self.can_pop() {
             return Err(Error::NotReady);
         }
@@ -256,7 +274,7 @@ impl<H: Hal> VirtQueue<H> {
             return Err(Error::WrongToken);
         }
 
-        self.recycle_descriptors(index);
+        self.recycle_descriptors(index, inputs, outputs);
         self.last_used_idx = self.last_used_idx.wrapping_add(1);
 
         Ok(len)
@@ -325,6 +343,14 @@ impl Descriptor {
             };
     }
 
+    /// Sets the buffer address and length to 0.
+    ///
+    /// This must only be called once the device has finished using the descriptor.
+    fn unset_buf(&mut self) {
+        self.addr = 0;
+        self.len = 0;
+    }
+
     /// Returns the index of the next descriptor in the chain if the `NEXT` flag is set, or `None`
     /// if it is not (and thus this descriptor is the end of the chain).
     fn next(&self) -> Option<u16> {