Browse Source

Merge pull request #60 from rcore-os/safety

Mark parts of the `Hal` trait unsafe
Andrew Walbran 2 years ago
parent
commit
67183ee262

+ 5 - 5
examples/aarch64/src/hal.rs

@@ -17,7 +17,7 @@ lazy_static! {
 
 pub struct HalImpl;
 
-impl Hal for HalImpl {
+unsafe impl Hal for HalImpl {
     fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
@@ -25,22 +25,22 @@ impl Hal for HalImpl {
         (paddr, vaddr)
     }
 
-    fn dma_dealloc(paddr: PhysAddr, _vaddr: NonNull<u8>, pages: usize) -> i32 {
+    unsafe fn dma_dealloc(paddr: PhysAddr, _vaddr: NonNull<u8>, pages: usize) -> i32 {
         trace!("dealloc DMA: paddr={:#x}, pages={}", paddr, pages);
         0
     }
 
-    fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
         NonNull::new(paddr as _).unwrap()
     }
 
-    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
         let vaddr = buffer.as_ptr() as *mut u8 as usize;
         // Nothing to do, as the host already has access to all memory.
         virt_to_phys(vaddr)
     }
 
-    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
         // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
         // anywhere else.
     }

+ 5 - 5
examples/riscv/src/virtio_impl.rs

@@ -16,7 +16,7 @@ lazy_static! {
 
 pub struct HalImpl;
 
-impl Hal for HalImpl {
+unsafe impl Hal for HalImpl {
     fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
@@ -24,22 +24,22 @@ impl Hal for HalImpl {
         (paddr, vaddr)
     }
 
-    fn dma_dealloc(paddr: PhysAddr, _vaddr: NonNull<u8>, pages: usize) -> i32 {
+    unsafe fn dma_dealloc(paddr: PhysAddr, _vaddr: NonNull<u8>, pages: usize) -> i32 {
         trace!("dealloc DMA: paddr={:#x}, pages={}", paddr, pages);
         0
     }
 
-    fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
         NonNull::new(paddr as _).unwrap()
     }
 
-    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
         let vaddr = buffer.as_ptr() as *mut u8 as usize;
         // Nothing to do, as the host already has access to all memory.
         virt_to_phys(vaddr)
     }
 
-    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
         // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
         // anywhere else.
     }

+ 6 - 6
src/device/blk.rs

@@ -109,7 +109,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         let mut resp = BlkResp::default();
         self.queue.add_notify_wait_pop(
             &[req.as_bytes()],
-            &[buf, resp.as_bytes_mut()],
+            &mut [buf, resp.as_bytes_mut()],
             &mut self.transport,
         )?;
         resp.status.into()
@@ -187,7 +187,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         };
         let token = self
             .queue
-            .add(&[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+            .add(&[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?;
         if self.queue.should_notify() {
             self.transport.notify(QUEUE);
         }
@@ -208,7 +208,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         resp: &mut BlkResp,
     ) -> Result<()> {
         self.queue
-            .pop_used(token, &[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+            .pop_used(token, &[req.as_bytes()], &mut [buf, resp.as_bytes_mut()])?;
         resp.status.into()
     }
 
@@ -225,7 +225,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         let mut resp = BlkResp::default();
         self.queue.add_notify_wait_pop(
             &[req.as_bytes(), buf],
-            &[resp.as_bytes_mut()],
+            &mut [resp.as_bytes_mut()],
             &mut self.transport,
         )?;
         resp.status.into()
@@ -268,7 +268,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         };
         let token = self
             .queue
-            .add(&[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+            .add(&[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?;
         if self.queue.should_notify() {
             self.transport.notify(QUEUE);
         }
@@ -289,7 +289,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         resp: &mut BlkResp,
     ) -> Result<()> {
         self.queue
-            .pop_used(token, &[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+            .pop_used(token, &[req.as_bytes(), buf], &mut [resp.as_bytes_mut()])?;
         resp.status.into()
     }
 

+ 8 - 6
src/device/console.rs

@@ -118,7 +118,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
         if self.receive_token.is_none() && self.cursor == self.pending_len {
             // Safe because the buffer lasts at least as long as the queue, and there are no other
             // outstanding requests using the buffer.
-            self.receive_token = Some(unsafe { self.receiveq.add(&[], &[self.queue_buf_rx]) }?);
+            self.receive_token = Some(unsafe { self.receiveq.add(&[], &mut [self.queue_buf_rx]) }?);
             if self.receiveq.should_notify() {
                 self.transport.notify(QUEUE_RECEIVEQ_PORT_0);
             }
@@ -145,9 +145,12 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
         let mut flag = false;
         if let Some(receive_token) = self.receive_token {
             if self.receive_token == self.receiveq.peek_used() {
-                let len = self
-                    .receiveq
-                    .pop_used(receive_token, &[], &[self.queue_buf_rx])?;
+                // Safe because we are passing the same buffer as we passed to `VirtQueue::add` in
+                // `poll_retrieve` and it is still valid.
+                let len = unsafe {
+                    self.receiveq
+                        .pop_used(receive_token, &[], &mut [self.queue_buf_rx])?
+                };
                 flag = true;
                 assert_ne!(len, 0);
                 self.cursor = 0;
@@ -176,9 +179,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
     /// Sends a character to the console.
     pub fn send(&mut self, chr: u8) -> Result<()> {
         let buf: [u8; 1] = [chr];
-        // Safe because the buffer is valid until we pop_used below.
         self.transmitq
-            .add_notify_wait_pop(&[&buf], &[], &mut self.transport)?;
+            .add_notify_wait_pop(&[&buf], &mut [], &mut self.transport)?;
         Ok(())
     }
 }

+ 1 - 1
src/device/gpu.rs

@@ -178,7 +178,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         req.write_to_prefix(&mut *self.queue_buf_send).unwrap();
         self.control_queue.add_notify_wait_pop(
             &[self.queue_buf_send],
-            &[self.queue_buf_recv],
+            &mut [self.queue_buf_recv],
             &mut self.transport,
         )?;
         Ok(Rsp::read_from_prefix(&*self.queue_buf_recv).unwrap())

+ 10 - 5
src/device/input.rs

@@ -42,7 +42,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
         let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS)?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
             // Safe because the buffer lasts as long as the queue.
-            let token = unsafe { event_queue.add(&[], &[event.as_bytes_mut()])? };
+            let token = unsafe { event_queue.add(&[], &mut [event.as_bytes_mut()])? };
             assert_eq!(token, i as u16);
         }
         if event_queue.should_notify() {
@@ -69,12 +69,17 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
     pub fn pop_pending_event(&mut self) -> Option<InputEvent> {
         if let Some(token) = self.event_queue.peek_used() {
             let event = &mut self.event_buf[token as usize];
-            self.event_queue
-                .pop_used(token, &[], &[event.as_bytes_mut()])
-                .ok()?;
+            // Safe because we are passing the same buffer as we passed to `VirtQueue::add` and it
+            // is still valid.
+            unsafe {
+                self.event_queue
+                    .pop_used(token, &[], &mut [event.as_bytes_mut()])
+                    .ok()?;
+            }
             // requeue
             // Safe because buffer lasts as long as the queue.
-            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } {
+            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &mut [event.as_bytes_mut()]) }
+            {
                 // This only works because nothing happen between `pop_used` and `add` that affects
                 // the list of free descriptors in the queue, so `add` reuses the descriptor which
                 // was just freed by `pop_used`.

+ 10 - 5
src/device/net.rs

@@ -81,9 +81,11 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
         let mut header = MaybeUninit::<Header>::uninit();
         let header_buf = unsafe { (*header.as_mut_ptr()).as_bytes_mut() };
-        let len =
-            self.recv_queue
-                .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
+        let len = self.recv_queue.add_notify_wait_pop(
+            &[],
+            &mut [header_buf, buf],
+            &mut self.transport,
+        )?;
         // let header = unsafe { header.assume_init() };
         Ok(len as usize - size_of::<Header>())
     }
@@ -91,8 +93,11 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     /// Send a packet.
     pub fn send(&mut self, buf: &[u8]) -> Result {
         let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
-        self.send_queue
-            .add_notify_wait_pop(&[header.as_bytes(), buf], &[], &mut self.transport)?;
+        self.send_queue.add_notify_wait_pop(
+            &[header.as_bytes(), buf],
+            &mut [],
+            &mut self.transport,
+        )?;
         Ok(())
     }
 }

+ 52 - 6
src/hal.rs

@@ -53,35 +53,81 @@ impl<H: Hal> Dma<H> {
 
 impl<H: Hal> Drop for Dma<H> {
     fn drop(&mut self) {
-        let err = H::dma_dealloc(self.paddr, self.vaddr, self.pages);
+        // Safe because the memory was previously allocated by `dma_alloc` in `Dma::new`, not yet
+        // deallocated, and we are passing the values from then.
+        let err = unsafe { H::dma_dealloc(self.paddr, self.vaddr, self.pages) };
         assert_eq!(err, 0, "failed to deallocate DMA");
     }
 }
 
 /// The interface which a particular hardware implementation must implement.
-pub trait Hal {
+///
+/// # Safety
+///
+/// Implementations of this trait must follow the "implementation safety" requirements documented
+/// for each method. Callers must follow the safety requirements documented for the unsafe methods.
+pub unsafe trait Hal {
     /// Allocates the given number of contiguous physical pages of DMA memory for VirtIO use.
     ///
     /// Returns both the physical address which the device can use to access the memory, and a
     /// pointer to the start of it which the driver can use to access it.
+    ///
+    /// # Implementation safety
+    ///
+    /// Implementations of this method must ensure that the `NonNull<u8>` returned is a
+    /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, aligned to
+    /// [`PAGE_SIZE`], and won't alias any other allocations or references in the program until it
+    /// is deallocated by `dma_dealloc`.
     fn dma_alloc(pages: usize, direction: BufferDirection) -> (PhysAddr, NonNull<u8>);
+
     /// Deallocates the given contiguous physical DMA memory pages.
-    fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32;
+    ///
+    /// # Safety
+    ///
+    /// The memory must have been allocated by `dma_alloc` on the same `Hal` implementation, and not
+    /// yet deallocated. `pages` must be the same number passed to `dma_alloc` originally, and both
+    /// `paddr` and `vaddr` must be the values returned by `dma_alloc`.
+    unsafe fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32;
+
     /// Converts a physical address used for MMIO to a virtual address which the driver can access.
     ///
     /// This is only used for MMIO addresses within BARs read from the device, for the PCI
     /// transport. It may check that the address range up to the given size is within the region
     /// expected for MMIO.
-    fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>;
+    ///
+    /// # Implementation safety
+    ///
+    /// Implementations of this method must ensure that the `NonNull<u8>` returned is a
+    /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, and won't alias any
+    /// other allocations or references in the program.
+    ///
+    /// # Safety
+    ///
+    /// The `paddr` and `size` must describe a valid MMIO region. The implementation may validate it
+    /// in some way (and panic if it is invalid) but is not guaranteed to.
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>;
+
     /// Shares the given memory range with the device, and returns the physical address that the
     /// device can use to access it.
     ///
     /// This may involve mapping the buffer into an IOMMU, giving the host permission to access the
     /// memory, or copying it to a special region where it can be accessed.
-    fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+    ///
+    /// # Safety
+    ///
+    /// The buffer must be a valid pointer to memory which will not be accessed by any other thread
+    /// for the duration of this method call.
+    unsafe fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+
     /// Unshares the given memory range from the device and (if necessary) copies it back to the
     /// original buffer.
-    fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
+    ///
+    /// # Safety
+    ///
+    /// The buffer must be a valid pointer to memory which will not be accessed by any other thread
+    /// for the duration of this method call. The `paddr` must be the value previously returned by
+    /// the corresponding `share` call.
+    unsafe fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
 }
 
 /// The direction in which a buffer is passed.

+ 5 - 5
src/hal/fake.rs

@@ -8,7 +8,7 @@ use core::{alloc::Layout, ptr::NonNull};
 pub struct FakeHal;
 
 /// Fake HAL implementation for use in unit tests.
-impl Hal for FakeHal {
+unsafe impl Hal for FakeHal {
     fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
@@ -21,7 +21,7 @@ impl Hal for FakeHal {
         }
     }
 
-    fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
+    unsafe fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
         // Safe because the layout is the same as was used when the memory was allocated by
@@ -32,17 +32,17 @@ impl Hal for FakeHal {
         0
     }
 
-    fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
         NonNull::new(paddr as _).unwrap()
     }
 
-    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
         let vaddr = buffer.as_ptr() as *mut u8 as usize;
         // Nothing to do, as the host already has access to all memory.
         virt_to_phys(vaddr)
     }
 
-    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
         // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
         // anywhere else.
     }

+ 90 - 45
src/queue.rs

@@ -5,7 +5,7 @@ use bitflags::bitflags;
 #[cfg(test)]
 use core::cmp::min;
 use core::hint::spin_loop;
-use core::mem::size_of;
+use core::mem::{size_of, take};
 #[cfg(test)]
 use core::ptr;
 use core::ptr::NonNull;
@@ -114,8 +114,13 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     ///
     /// # Safety
     ///
-    /// The input and output buffers must remain valid until the token is returned by `pop_used`.
-    pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> {
+    /// The input and output buffers must remain valid and not be accessed until a call to
+    /// `pop_used` with the returned token succeeds.
+    pub unsafe fn add<'a, 'b>(
+        &mut self,
+        inputs: &'a [&'b [u8]],
+        outputs: &'a mut [&'b mut [u8]],
+    ) -> Result<u16> {
         if inputs.is_empty() && outputs.is_empty() {
             return Err(Error::InvalidParam);
         }
@@ -127,7 +132,7 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         let head = self.free_head;
         let mut last = self.free_head;
 
-        for (buffer, direction) in input_output_iter(inputs, outputs) {
+        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
             // Write to desc_shadow then copy.
             let desc = &mut self.desc_shadow[usize::from(self.free_head)];
             desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
@@ -172,14 +177,14 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     /// them, then pops them.
     ///
     /// This assumes that the device isn't processing any other buffers at the same time.
-    pub fn add_notify_wait_pop(
+    pub fn add_notify_wait_pop<'a>(
         &mut self,
-        inputs: &[*const [u8]],
-        outputs: &[*mut [u8]],
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
         transport: &mut impl Transport,
     ) -> Result<u32> {
-        // Safe because we don't return until the same token has been popped, so they remain valid
-        // until then.
+        // Safe because we don't return until the same token has been popped, so the buffers remain
+        // valid and are not otherwise accessed until then.
         let token = unsafe { self.add(inputs, outputs) }?;
 
         // Notify the queue.
@@ -192,7 +197,9 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
             spin_loop();
         }
 
-        self.pop_used(token, inputs, outputs)
+        // Safe because these are the same buffers as we passed to `add` above and they are still
+        // valid.
+        unsafe { self.pop_used(token, inputs, outputs) }
     }
 
     /// Returns whether the driver should notify the device after adding a new buffer to the
@@ -252,12 +259,22 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     /// passed in too.
     ///
     /// This will push all linked descriptors at the front of the free list.
-    fn recycle_descriptors(&mut self, head: u16, inputs: &[*const [u8]], outputs: &[*mut [u8]]) {
+    ///
+    /// # Safety
+    ///
+    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
+    /// queue by `add`.
+    unsafe fn recycle_descriptors<'a>(
+        &mut self,
+        head: u16,
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
+    ) {
         let original_free_head = self.free_head;
         self.free_head = head;
         let mut next = Some(head);
 
-        for (buffer, direction) in input_output_iter(inputs, outputs) {
+        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
             let desc_index = next.expect("Descriptor chain was shorter than expected.");
             let desc = &mut self.desc_shadow[usize::from(desc_index)];
 
@@ -271,8 +288,12 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
 
             self.write_desc(desc_index);
 
-            // Unshare the buffer (and perhaps copy its contents back to the original buffer).
-            H::unshare(paddr as usize, buffer, direction);
+            // Safe because the caller ensures that the buffer is valid and matches the descriptor
+            // from which we got `paddr`.
+            unsafe {
+                // Unshare the buffer (and perhaps copy its contents back to the original buffer).
+                H::unshare(paddr as usize, buffer, direction);
+            }
         }
 
         if next.is_some() {
@@ -284,11 +305,16 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     /// length which was used (written) by the device.
     ///
     /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
-    pub fn pop_used(
+    ///
+    /// # Safety
+    ///
+    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
+    /// queue by `add` when it returned the token being passed in here.
+    pub unsafe fn pop_used<'a>(
         &mut self,
         token: u16,
-        inputs: &[*const [u8]],
-        outputs: &[*mut [u8]],
+        inputs: &'a [&'a [u8]],
+        outputs: &'a mut [&'a mut [u8]],
     ) -> Result<u32> {
         if !self.can_pop() {
             return Err(Error::NotReady);
@@ -311,7 +337,10 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
             return Err(Error::WrongToken);
         }
 
-        self.recycle_descriptors(index, inputs, outputs);
+        // Safe because the caller ensures the buffers are valid and match the descriptor.
+        unsafe {
+            self.recycle_descriptors(index, inputs, outputs);
+        }
         self.last_used_idx = self.last_used_idx.wrapping_add(1);
 
         Ok(len)
@@ -558,6 +587,46 @@ struct UsedElem {
     len: u32,
 }
 
+struct InputOutputIter<'a, 'b> {
+    inputs: &'a [&'b [u8]],
+    outputs: &'a mut [&'b mut [u8]],
+}
+
+impl<'a, 'b> InputOutputIter<'a, 'b> {
+    fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self {
+        Self { inputs, outputs }
+    }
+}
+
+impl<'a, 'b> Iterator for InputOutputIter<'a, 'b> {
+    type Item = (NonNull<[u8]>, BufferDirection);
+
+    fn next(&mut self) -> Option<Self::Item> {
+        if let Some(input) = take_first(&mut self.inputs) {
+            Some(((*input).into(), BufferDirection::DriverToDevice))
+        } else {
+            let output = take_first_mut(&mut self.outputs)?;
+            Some(((*output).into(), BufferDirection::DeviceToDriver))
+        }
+    }
+}
+
+// TODO: Use `slice::take_first` once it is stable
+// (https://github.com/rust-lang/rust/issues/62280).
+fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> {
+    let (first, rem) = slice.split_first()?;
+    *slice = rem;
+    Some(first)
+}
+
+// TODO: Use `slice::take_first_mut` once it is stable
+// (https://github.com/rust-lang/rust/issues/62280).
+fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> {
+    let (first, rem) = take(slice).split_first_mut()?;
+    *slice = rem;
+    Some(first)
+}
+
 /// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests.
 ///
 /// The fake device always uses descriptors in order.
@@ -680,7 +749,7 @@ mod tests {
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(
-            unsafe { queue.add(&[], &[]) }.unwrap_err(),
+            unsafe { queue.add(&[], &mut []) }.unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -692,7 +761,7 @@ mod tests {
         let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
-            unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(),
+            unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
             Error::QueueFull
         );
     }
@@ -706,7 +775,7 @@ mod tests {
 
         // Add a buffer chain consisting of two device-readable parts followed by two
         // device-writable parts.
-        let token = unsafe { queue.add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]]) }.unwrap();
+        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
 
         assert_eq!(queue.available_desc(), 0);
         assert!(!queue.can_pop());
@@ -757,27 +826,3 @@ mod tests {
         }
     }
 }
-
-/// Returns an iterator over the buffers of first `inputs` and then `outputs`, paired with the
-/// corresponding `BufferDirection`.
-///
-/// Panics if any of the buffer pointers is null.
-fn input_output_iter<'a>(
-    inputs: &'a [*const [u8]],
-    outputs: &'a [*mut [u8]],
-) -> impl Iterator<Item = (NonNull<[u8]>, BufferDirection)> + 'a {
-    inputs
-        .iter()
-        .map(|input| {
-            (
-                NonNull::new(*input as *mut [u8]).unwrap(),
-                BufferDirection::DriverToDevice,
-            )
-        })
-        .chain(outputs.iter().map(|output| {
-            (
-                NonNull::new(*output).unwrap(),
-                BufferDirection::DeviceToDriver,
-            )
-        }))
-}

+ 3 - 1
src/transport/pci.rs

@@ -395,7 +395,9 @@ fn get_bar_region<H: Hal, T>(
         return Err(VirtioPciError::BarOffsetOutOfRange);
     }
     let paddr = bar_address as PhysAddr + struct_info.offset as PhysAddr;
-    let vaddr = H::mmio_phys_to_virt(paddr, struct_info.length as usize);
+    // Safe because the paddr and size describe a valid MMIO region, at least according to the PCI
+    // bus.
+    let vaddr = unsafe { H::mmio_phys_to_virt(paddr, struct_info.length as usize) };
     if vaddr.as_ptr() as usize % align_of::<T>() != 0 {
         return Err(VirtioPciError::Misaligned {
             vaddr,