浏览代码

Try taking buffers by reference rather than pointer.

Andrew Walbran 2 年之前
父节点
当前提交
62fbbdc164
共有 6 个文件被更改,包括 90 次插入67 次删除
  1. 6 6
      src/device/blk.rs
  2. 3 4
      src/device/console.rs
  3. 1 1
      src/device/gpu.rs
  4. 4 3
      src/device/input.rs
  5. 10 5
      src/device/net.rs
  6. 66 48
      src/queue.rs

+ 6 - 6
src/device/blk.rs

@@ -109,7 +109,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         let mut resp = BlkResp::default();
         self.queue.add_notify_wait_pop(
             &[req.as_bytes()],
-            &[buf, resp.as_bytes_mut()],
+            &mut [buf, resp.as_bytes_mut()],
             &mut self.transport,
         )?;
         resp.status.into()
@@ -187,7 +187,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         };
         let token = self
             .queue
-            .add(&[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+            .add(&[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?;
         if self.queue.should_notify() {
             self.transport.notify(QUEUE);
         }
@@ -208,7 +208,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         resp: &mut BlkResp,
     ) -> Result<()> {
         self.queue
-            .pop_used(token, &[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+            .pop_used(token, &[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?;
         resp.status.into()
     }
 
@@ -225,7 +225,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         let mut resp = BlkResp::default();
         self.queue.add_notify_wait_pop(
             &[req.as_bytes(), buf],
-            &[resp.as_bytes_mut()],
+            &mut [resp.as_bytes_mut()],
             &mut self.transport,
         )?;
         resp.status.into()
@@ -268,7 +268,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         };
         let token = self
             .queue
-            .add(&[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+            .add(&[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?;
         if self.queue.should_notify() {
             self.transport.notify(QUEUE);
         }
@@ -289,7 +289,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         resp: &mut BlkResp,
     ) -> Result<()> {
         self.queue
-            .pop_used(token, &[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+            .pop_used(token, &[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?;
         resp.status.into()
     }
 

+ 3 - 4
src/device/console.rs

@@ -118,7 +118,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
         if self.receive_token.is_none() && self.cursor == self.pending_len {
             // Safe because the buffer lasts at least as long as the queue, and there are no other
             // outstanding requests using the buffer.
-            self.receive_token = Some(unsafe { self.receiveq.add(&[], &[self.queue_buf_rx]) }?);
+            self.receive_token = Some(unsafe { self.receiveq.add(&[], &mut [self.queue_buf_rx]) }?);
             if self.receiveq.should_notify() {
                 self.transport.notify(QUEUE_RECEIVEQ_PORT_0);
             }
@@ -149,7 +149,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
                 // `poll_retrieve` and it is still valid.
                 let len = unsafe {
                     self.receiveq
-                        .pop_used(receive_token, &[], &[self.queue_buf_rx])?
+                        .pop_used(receive_token, &[], &mut [self.queue_buf_rx])?
                 };
                 flag = true;
                 assert_ne!(len, 0);
@@ -179,9 +179,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
     /// Sends a character to the console.
     pub fn send(&mut self, chr: u8) -> Result<()> {
         let buf: [u8; 1] = [chr];
-        // Safe because the buffer is valid until we pop_used below.
         self.transmitq
-            .add_notify_wait_pop(&[&buf], &[], &mut self.transport)?;
+            .add_notify_wait_pop(&[&buf], &mut [], &mut self.transport)?;
         Ok(())
     }
 }

+ 1 - 1
src/device/gpu.rs

@@ -178,7 +178,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
         self.control_queue.add_notify_wait_pop(
             &[self.queue_buf_send],
-            &[self.queue_buf_recv],
+            &mut [self.queue_buf_recv],
             &mut self.transport,
         )?;
         Ok(Rsp::read_from_prefix(&*self.queue_buf_recv).unwrap())

+ 4 - 3
src/device/input.rs

@@ -42,7 +42,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
         let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS)?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
             // Safe because the buffer lasts as long as the queue.
-            let token = unsafe { event_queue.add(&[], &[event.as_bytes_mut()])? };
+            let token = unsafe { event_queue.add(&[], &mut [event.as_bytes_mut()])? };
             assert_eq!(token, i as u16);
         }
         if event_queue.should_notify() {
@@ -73,12 +73,13 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
             // is still valid.
             unsafe {
                 self.event_queue
-                    .pop_used(token, &[], &[event.as_bytes_mut()])
+                    .pop_used(token, &[], &mut [event.as_bytes_mut()])
                     .ok()?;
             }
             // requeue
             // Safe because buffer lasts as long as the queue.
-            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } {
+            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &mut [event.as_bytes_mut()]) }
+            {
                 // This only works because nothing happen between `pop_used` and `add` that affects
                 // the list of free descriptors in the queue, so `add` reuses the descriptor which
                 // was just freed by `pop_used`.

+ 10 - 5
src/device/net.rs

@@ -81,9 +81,11 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
         let mut header = MaybeUninit::<Header>::uninit();
         let header_buf = unsafe { (*header.as_mut_ptr()).as_bytes_mut() };
-        let len =
-            self.recv_queue
-                .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
+        let len = self.recv_queue.add_notify_wait_pop(
+            &[],
+            &mut [header_buf, buf],
+            &mut self.transport,
+        )?;
         // let header = unsafe { header.assume_init() };
         Ok(len as usize - size_of::<Header>())
     }
@@ -91,8 +93,11 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     /// Send a packet.
     pub fn send(&mut self, buf: &[u8]) -> Result {
         let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
-        self.send_queue
-            .add_notify_wait_pop(&[header.as_bytes(), buf], &[], &mut self.transport)?;
+        self.send_queue.add_notify_wait_pop(
+            &[header.as_bytes(), buf],
+            &mut [],
+            &mut self.transport,
+        )?;
         Ok(())
     }
 }

+ 66 - 48
src/queue.rs

@@ -5,7 +5,7 @@ use bitflags::bitflags;
 #[cfg(test)]
 use core::cmp::min;
 use core::hint::spin_loop;
-use core::mem::size_of;
+use core::mem::{size_of, take};
 #[cfg(test)]
 use core::ptr;
 use core::ptr::NonNull;
@@ -116,7 +116,11 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     ///
     /// The input and output buffers must remain valid and not be accessed until a call to
     /// `pop_used` with the returned token succeeds.
-    pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> {
+    pub unsafe fn add<'a, 'b>(
+        &mut self,
+        inputs: &'a [&'b [u8]],
+        outputs: &'a mut [&'b mut [u8]],
+    ) -> Result<u16> {
         if inputs.is_empty() && outputs.is_empty() {
             return Err(Error::InvalidParam);
         }
@@ -128,7 +132,7 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         let head = self.free_head;
         let mut last = self.free_head;
 
-        for (buffer, direction) in input_output_iter(inputs, outputs) {
+        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
             // Write to desc_shadow then copy.
             let desc = &mut self.desc_shadow[usize::from(self.free_head)];
             desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
@@ -173,14 +177,14 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     /// them, then pops them.
     ///
     /// This assumes that the device isn't processing any other buffers at the same time.
-    pub fn add_notify_wait_pop(
+    pub fn add_notify_wait_pop<'a>(
         &mut self,
-        inputs: &[*const [u8]],
-        outputs: &[*mut [u8]],
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
         transport: &mut impl Transport,
     ) -> Result<u32> {
-        // Safe because we don't return until the same token has been popped, so they remain valid
-        // until then.
+        // Safe because we don't return until the same token has been popped, so the buffers remain
+        // valid and are not otherwise accessed until then.
         let token = unsafe { self.add(inputs, outputs) }?;
 
         // Notify the queue.
@@ -258,20 +262,19 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     ///
     /// # Safety
     ///
-    /// The buffers in `inputs` and `outputs` must be valid pointers to memory which is not accessed
-    /// by any other thread for the duration of this method call, and must match the set of buffers
-    /// originally added to the queue by `add`.
-    unsafe fn recycle_descriptors(
+    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
+    /// queue by `add`.
+    unsafe fn recycle_descriptors<'a>(
         &mut self,
         head: u16,
-        inputs: &[*const [u8]],
-        outputs: &[*mut [u8]],
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
     ) {
         let original_free_head = self.free_head;
         self.free_head = head;
         let mut next = Some(head);
 
-        for (buffer, direction) in input_output_iter(inputs, outputs) {
+        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
             let desc_index = next.expect("Descriptor chain was shorter than expected.");
             let desc = &mut self.desc_shadow[usize::from(desc_index)];
 
@@ -305,14 +308,13 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     ///
     /// # Safety
     ///
-    /// The buffers in `inputs` and `outputs` must be valid pointers to memory which is not accessed
-    /// by any other thread for the duration of this method call, and must match the set of buffers
-    /// originally added to the queue by `add`.
-    pub unsafe fn pop_used(
+    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
+    /// queue by `add` when it returned the token being passed in here.
+    pub unsafe fn pop_used<'a>(
         &mut self,
         token: u16,
-        inputs: &[*const [u8]],
-        outputs: &[*mut [u8]],
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
     ) -> Result<u32> {
         if !self.can_pop() {
             return Err(Error::NotReady);
@@ -585,6 +587,46 @@ struct UsedElem {
     len: u32,
 }
 
+struct InputOutputIter<'a, 'b> {
+    inputs: &'a [&'b [u8]],
+    outputs: &'a mut [&'b mut [u8]],
+}
+
+impl<'a, 'b> InputOutputIter<'a, 'b> {
+    fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self {
+        Self { inputs, outputs }
+    }
+}
+
+impl<'a, 'b> Iterator for InputOutputIter<'a, 'b> {
+    type Item = (NonNull<[u8]>, BufferDirection);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if let Some(input) = take_first(&mut self.inputs) {
+            Some(((*input).into(), BufferDirection::DriverToDevice))
+        } else {
+            let output = take_first_mut(&mut self.outputs)?;
+            Some(((*output).into(), BufferDirection::DeviceToDriver))
+        }
+    }
+}
+
+// TODO: Use `slice::take_first` once it is stable
+// (https://github.com/rust-lang/rust/issues/62280).
+fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> {
+    let (first, rem) = slice.split_first()?;
+    *slice = rem;
+    Some(first)
+}
+
+// TODO: Use `slice::take_first_mut` once it is stable
+// (https://github.com/rust-lang/rust/issues/62280).
+fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> {
+    let (first, rem) = take(slice).split_first_mut()?;
+    *slice = rem;
+    Some(first)
+}
+
 /// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests.
 ///
 /// The fake device always uses descriptors in order.
@@ -707,7 +749,7 @@ mod tests {
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(
-            unsafe { queue.add(&[], &[]) }.unwrap_err(),
+            unsafe { queue.add(&[], &mut []) }.unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -719,7 +761,7 @@ mod tests {
         let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
-            unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(),
+            unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
             Error::QueueFull
         );
     }
@@ -733,7 +775,7 @@ mod tests {
 
         // Add a buffer chain consisting of two device-readable parts followed by two
         // device-writable parts.
-        let token = unsafe { queue.add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]]) }.unwrap();
+        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
 
         assert_eq!(queue.available_desc(), 0);
         assert!(!queue.can_pop());
@@ -784,27 +826,3 @@ mod tests {
         }
     }
 }
-
-/// Returns an iterator over the buffers of first `inputs` and then `outputs`, paired with the
-/// corresponding `BufferDirection`.
-///
-/// Panics if any of the buffer pointers is null.
-fn input_output_iter<'a>(
-    inputs: &'a [*const [u8]],
-    outputs: &'a [*mut [u8]],
-) -> impl Iterator<Item = (NonNull<[u8]>, BufferDirection)> + 'a {
-    inputs
-        .iter()
-        .map(|input| {
-            (
-                NonNull::new(*input as *mut [u8]).unwrap(),
-                BufferDirection::DriverToDevice,
-            )
-        })
-        .chain(outputs.iter().map(|output| {
-            (
-                NonNull::new(*output).unwrap(),
-                BufferDirection::DeviceToDriver,
-            )
-        }))
-}