Merge pull request #41 from rcore-os/share

Add methods to Hal to share and unshare buffers with device
chyyuu 2 years ago
parent
commit 108a82501b
9 changed files with 252 additions and 70 deletions
  1. examples/aarch64/src/hal.rs (+18 -4)
  2. examples/riscv/src/virtio_impl.rs (+18 -4)
  3. src/device/blk.rs (+51 -10)
  4. src/device/console.rs (+8 -6)
  5. src/device/input.rs (+4 -1)
  6. src/hal.rs (+19 -4)
  7. src/hal/fake.rs (+15 -4)
  8. src/lib.rs (+3 -1)
  9. src/queue.rs (+116 -36)

+ 18 - 4
examples/aarch64/src/hal.rs

@@ -1,7 +1,10 @@
-use core::sync::atomic::*;
+use core::{
+    ptr::NonNull,
+    sync::atomic::{AtomicUsize, Ordering},
+};
 use lazy_static::lazy_static;
 use log::trace;
-use virtio_drivers::{Hal, PhysAddr, VirtAddr, PAGE_SIZE};
+use virtio_drivers::{BufferDirection, Hal, PhysAddr, VirtAddr, PAGE_SIZE};
 
 extern "C" {
     static dma_region: u8;
@@ -30,7 +33,18 @@ impl Hal for HalImpl {
         paddr
     }
 
-    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
-        vaddr
+    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+        let vaddr = buffer.as_ptr() as *mut u8 as usize;
+        // Nothing to do, as the host already has access to all memory.
+        virt_to_phys(vaddr)
     }
+
+    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+        // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
+        // anywhere else.
+    }
+}
+
+fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
+    vaddr
 }

+ 18 - 4
examples/riscv/src/virtio_impl.rs

@@ -1,7 +1,10 @@
-use core::sync::atomic::*;
+use core::{
+    ptr::NonNull,
+    sync::atomic::{AtomicUsize, Ordering},
+};
 use lazy_static::lazy_static;
 use log::trace;
-use virtio_drivers::{Hal, PhysAddr, VirtAddr, PAGE_SIZE};
+use virtio_drivers::{BufferDirection, Hal, PhysAddr, VirtAddr, PAGE_SIZE};
 
 extern "C" {
     fn end();
@@ -29,7 +32,18 @@ impl Hal for HalImpl {
         paddr
     }
 
-    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
-        vaddr
+    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+        let vaddr = buffer.as_ptr() as *mut u8 as usize;
+        // Nothing to do, as the host already has access to all memory.
+        virt_to_phys(vaddr)
     }
+
+    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+        // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
+        // anywhere else.
+    }
+}
+
+fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
+    vaddr
 }

+ 51 - 10
src/device/blk.rs

@@ -125,7 +125,8 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// * `block_id` - The identifier of the block to read.
     /// * `req` - A buffer which the driver can use for the request to send to the device. The
     ///   contents don't matter as `read_block_nb` will initialise it, but like the other buffers it
-    ///   needs to be valid (and not otherwise used) until the corresponding `pop_used` call.
+    ///   needs to be valid (and not otherwise used) until the corresponding `complete_read_block`
+    ///   call.
     /// * `buf` - The buffer in memory into which the block should be read.
     /// * `resp` - A mutable reference to a variable provided by the caller
     ///   to contain the status of the request. The caller can safely
@@ -137,8 +138,9 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// the position of the first Descriptor in the chain. If there are not enough
     /// Descriptors to allocate, then it returns [`Error::QueueFull`].
     ///
-    /// The caller can then call `pop_used` to check whether the device has finished handling the
-    /// request. Once it has, the caller can then read the response and dispose of the buffers.
+    /// The caller can then call `peek_used` with the returned token to check whether the device has
+    /// finished handling the request. Once it has, the caller must call `complete_read_block` with
+    /// the same buffers before reading the response.
     ///
     /// ```
     /// # use virtio_drivers::{Error, Hal};
@@ -153,8 +155,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// let token = unsafe { blk.read_block_nb(42, &mut request, &mut buffer, &mut response) }?;
     ///
     /// // Wait for an interrupt to tell us that the request completed...
+    /// assert_eq!(blk.peek_used(), Some(token));
     ///
-    /// assert_eq!(blk.pop_used()?, token);
+    /// unsafe {
+    ///   blk.complete_read_block(token, &request, &mut buffer, &mut response)?;
+    /// }
     /// if response.status() == RespStatus::OK {
     ///   println!("Successfully read block.");
     /// } else {
@@ -189,6 +194,24 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         Ok(token)
     }
 
+    /// Completes a read operation which was started by `read_block_nb`.
+    ///
+    /// # Safety
+    ///
+    /// The same buffers must be passed in again as were passed to `read_block_nb` when it returned
+    /// the token.
+    pub unsafe fn complete_read_block(
+        &mut self,
+        token: u16,
+        req: &BlkReq,
+        buf: &mut [u8],
+        resp: &mut BlkResp,
+    ) -> Result<()> {
+        self.queue
+            .pop_used(token, &[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
+        Ok(())
+    }
+
     /// Writes the contents of the given buffer to a block.
     ///
     /// Blocks until the write is complete or there is an error.
@@ -219,7 +242,8 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// * `block_id` - The identifier of the block to write.
     /// * `req` - A buffer which the driver can use for the request to send to the device. The
     ///   contents don't matter as `read_block_nb` will initialise it, but like the other buffers it
-    ///   needs to be valid (and not otherwise used) until the corresponding `pop_used` call.
+    ///   needs to be valid (and not otherwise used) until the corresponding `complete_read_block`
+    ///   call.
     /// * `buf` - The buffer in memory containing the data to write to the block.
     /// * `resp` - A mutable reference to a variable provided by the caller
     ///   to contain the status of the request. The caller can safely
@@ -252,11 +276,28 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         Ok(token)
     }
 
-    /// During an interrupt, it fetches a token of a completed request from the used
-    /// ring and return it. If all completed requests have already been fetched, return
-    /// Err(Error::NotReady).
-    pub fn pop_used(&mut self) -> Result<u16> {
-        self.queue.pop_used().map(|p| p.0)
+    /// Completes a write operation which was started by `write_block_nb`.
+    ///
+    /// # Safety
+    ///
+    /// The same buffers must be passed in again as were passed to `write_block_nb` when it returned
+    /// the token.
+    pub unsafe fn complete_write_block(
+        &mut self,
+        token: u16,
+        req: &BlkReq,
+        buf: &[u8],
+        resp: &mut BlkResp,
+    ) -> Result<()> {
+        self.queue
+            .pop_used(token, &[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
+        Ok(())
+    }
+
+    /// Fetches the token of the next completed request from the used ring and returns it, without
+    /// removing it from the used ring. If there are no pending completed requests returns `None`.
+    pub fn peek_used(&mut self) -> Option<u16> {
+        self.queue.peek_used()
     }
 
     /// Returns the size of the device's VirtQueue.

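The write path gains the matching `complete_write_block` but no doc example in this commit. A usage sketch, assuming `write_block_nb` mirrors the parameters of `read_block_nb` and that `BlkReq` and `BlkResp` implement `Default` (imports and the surrounding function are elided, as in the read example):

    let mut request = BlkReq::default();
    let buffer = [0x55u8; 512];
    let mut response = BlkResp::default();
    // Safety: the buffers stay alive and otherwise untouched until complete_write_block.
    let token = unsafe { blk.write_block_nb(42, &mut request, &buffer, &mut response) }?;

    // Wait for an interrupt to tell us that the request completed...
    assert_eq!(blk.peek_used(), Some(token));

    // Safety: these are the same buffers that were passed to write_block_nb.
    unsafe {
        blk.complete_write_block(token, &request, &buffer, &mut response)?;
    }
    if response.status() == RespStatus::OK {
        println!("Successfully wrote block.");
    }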
+ 8 - 6
src/device/console.rs

@@ -127,31 +127,33 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
             return Ok(false);
         }
 
-        Ok(self.finish_receive())
+        self.finish_receive()
     }
 
     /// If there is an outstanding receive request and it has finished, completes it.
     ///
     /// Returns true if new data has been received.
-    fn finish_receive(&mut self) -> bool {
+    fn finish_receive(&mut self) -> Result<bool> {
         let mut flag = false;
         if let Some(receive_token) = self.receive_token {
-            if let Ok((token, len)) = self.receiveq.pop_used() {
-                assert_eq!(token, receive_token);
+            if self.receive_token == self.receiveq.peek_used() {
+                let len = self
+                    .receiveq
+                    .pop_used(receive_token, &[], &[self.queue_buf_rx])?;
                 flag = true;
                 assert_ne!(len, 0);
                 self.cursor = 0;
                 self.pending_len = len as usize;
             }
         }
-        flag
+        Ok(flag)
     }
 
     /// Returns the next available character from the console, if any.
     ///
     /// If no data has been received this will not block but immediately return `Ok<None>`.
     pub fn recv(&mut self, pop: bool) -> Result<Option<u8>> {
-        self.finish_receive();
+        self.finish_receive()?;
         if self.cursor == self.pending_len {
             return Ok(None);
         }

+ 4 - 1
src/device/input.rs

@@ -64,8 +64,11 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
 
     /// Pop the pending event.
     pub fn pop_pending_event(&mut self) -> Option<InputEvent> {
-        if let Ok((token, _)) = self.event_queue.pop_used() {
+        if let Some(token) = self.event_queue.peek_used() {
             let event = &mut self.event_buf[token as usize];
+            self.event_queue
+                .pop_used(token, &[], &[event.as_bytes_mut()])
+                .ok()?;
             // requeue
             // Safe because buffer lasts as long as the queue.
             if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } {

+ 19 - 4
src/hal.rs

@@ -2,7 +2,7 @@
 pub mod fake;
 
 use crate::{Error, Result, PAGE_SIZE};
-use core::marker::PhantomData;
+use core::{marker::PhantomData, ptr::NonNull};
 
 /// A virtual memory address in the address space of the program.
 pub type VirtAddr = usize;
@@ -61,7 +61,22 @@ pub trait Hal {
     /// Converts a physical address used for virtio to a virtual address which the program can
     /// access.
     fn phys_to_virt(paddr: PhysAddr) -> VirtAddr;
-    /// Converts a virtual address which the program can access to the corresponding physical
-    /// address to use for virtio.
-    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr;
+    /// Shares the given memory range with the device, and returns the physical address that the
+    /// device can use to access it.
+    ///
+    /// This may involve mapping the buffer into an IOMMU, giving the host permission to access the
+    /// memory, or copying it to a special region where it can be accessed.
+    fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+    /// Unshares the given memory range from the device and (if necessary) copies it back to the
+    /// original buffer.
+    fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
+}
+
+/// The direction in which a buffer is passed.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum BufferDirection {
+    /// The buffer is written by the driver and read by the device.
+    DriverToDevice,
+    /// The buffer is written by the device and read by the driver.
+    DeviceToDriver,
 }

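To show why `share` takes a `BufferDirection` and why `unshare` is given the original buffer back, here is a minimal sketch of what a HAL might do on a system where the device can only reach a dedicated DMA region, so every buffer is bounced through it. The helpers `bounce_alloc`, `bounce_free` and `bounce_virt` are hypothetical and not part of this commit, and the functions are shown free-standing rather than as a full `Hal` impl; the example HALs above simply identity-map instead.

    use core::ptr::{self, NonNull};
    use virtio_drivers::{BufferDirection, PhysAddr};

    // Hypothetical helpers managing a pool of device-visible bounce buffers.
    fn bounce_alloc(_len: usize) -> PhysAddr {
        unimplemented!()
    }
    fn bounce_free(_paddr: PhysAddr, _len: usize) {
        unimplemented!()
    }
    fn bounce_virt(_paddr: PhysAddr) -> *mut u8 {
        unimplemented!()
    }

    fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr {
        let paddr = bounce_alloc(buffer.len());
        if direction == BufferDirection::DriverToDevice {
            // The device will read this buffer, so copy the driver's data into
            // the bounce buffer before handing its physical address over.
            unsafe {
                ptr::copy_nonoverlapping(
                    buffer.as_ptr() as *mut u8 as *const u8,
                    bounce_virt(paddr),
                    buffer.len(),
                );
            }
        }
        paddr
    }

    fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection) {
        if direction == BufferDirection::DeviceToDriver {
            // The device wrote into the bounce buffer; copy the result back
            // into the caller's original buffer.
            unsafe {
                ptr::copy_nonoverlapping(
                    bounce_virt(paddr) as *const u8,
                    buffer.as_ptr() as *mut u8,
                    buffer.len(),
                );
            }
        }
        bounce_free(paddr, buffer.len());
    }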
+ 15 - 4
src/hal/fake.rs

@@ -1,8 +1,8 @@
 //! Fake HAL implementation for tests.
 
-use crate::{Hal, PhysAddr, VirtAddr, PAGE_SIZE};
+use crate::{BufferDirection, Hal, PhysAddr, VirtAddr, PAGE_SIZE};
 use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error};
-use core::alloc::Layout;
+use core::{alloc::Layout, ptr::NonNull};
 
 #[derive(Debug)]
 pub struct FakeHal;
@@ -35,7 +35,18 @@ impl Hal for FakeHal {
         paddr
     }
 
-    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
-        vaddr
+    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+        let vaddr = buffer.as_ptr() as *mut u8 as usize;
+        // Nothing to do, as the host already has access to all memory.
+        virt_to_phys(vaddr)
     }
+
+    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+        // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
+        // anywhere else.
+    }
+}
+
+fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
+    vaddr
 }

+ 3 - 1
src/lib.rs

@@ -53,7 +53,7 @@ mod queue;
 pub mod transport;
 mod volatile;
 
-pub use self::hal::{Hal, PhysAddr, VirtAddr};
+pub use self::hal::{BufferDirection, Hal, PhysAddr, VirtAddr};
 
 /// The page size in bytes supported by the library (4 KiB).
 pub const PAGE_SIZE: usize = 0x1000;
@@ -68,6 +68,8 @@ pub enum Error {
     QueueFull,
     /// The device is not ready.
     NotReady,
+    /// The device used a different descriptor chain to the one we were expecting.
+    WrongToken,
     /// The queue is already in use.
     AlreadyUsed,
     /// Invalid parameter.

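The new variant matters once a driver has more than one request in flight: `pop_used` (and the `complete_*` wrappers built on it) now fails with `WrongToken` if the device finished a different descriptor chain than the one the caller asked about. A caller-side sketch, inside a function returning the crate's `Result` and reusing the placeholder names from the block-device example above:

    // Safety: these are the same buffers that were passed to read_block_nb.
    let result = unsafe { blk.complete_read_block(token, &request, &mut buffer, &mut response) };
    let done = match result {
        Ok(()) => true,
        // Some other outstanding request finished first; peek_used() says which
        // token is actually ready, so complete that one and come back to this.
        Err(Error::WrongToken) => false,
        Err(e) => return Err(e),
    };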
+ 116 - 36
src/queue.rs

@@ -1,6 +1,6 @@
 #[cfg(test)]
 use crate::hal::VirtAddr;
-use crate::hal::{Dma, Hal};
+use crate::hal::{BufferDirection, Dma, Hal};
 use crate::transport::Transport;
 use crate::{align_up, Error, Result, PAGE_SIZE};
 use bitflags::bitflags;
@@ -114,17 +114,9 @@ impl<H: Hal> VirtQueue<H> {
         // Safe because self.desc is properly aligned, dereferenceable and initialised, and nothing
         // else reads or writes the free descriptors during this block.
         unsafe {
-            for input in inputs.iter() {
-                let mut desc = self.desc_ptr(self.free_head);
-                (*desc).set_buf::<H>(NonNull::new(*input as *mut [u8]).unwrap());
-                (*desc).flags = DescFlags::NEXT;
-                last = self.free_head;
-                self.free_head = (*desc).next;
-            }
-            for output in outputs.iter() {
+            for (buffer, direction) in input_output_iter(inputs, outputs) {
                 let desc = self.desc_ptr(self.free_head);
-                (*desc).set_buf::<H>(NonNull::new(*output).unwrap());
-                (*desc).flags = DescFlags::NEXT | DescFlags::WRITE;
+                (*desc).set_buf::<H>(buffer, direction, DescFlags::NEXT);
                 last = self.free_head;
                 self.free_head = (*desc).next;
             }
@@ -174,13 +166,12 @@ impl<H: Hal> VirtQueue<H> {
         // Notify the queue.
         transport.notify(self.queue_idx);
 
+        // Wait until there is at least one element in the used ring.
         while !self.can_pop() {
             spin_loop();
         }
-        let (popped_token, length) = self.pop_used()?;
-        assert_eq!(popped_token, token);
 
-        Ok(length)
+        self.pop_used(token, inputs, outputs)
     }
 
     /// Returns a non-null pointer to the descriptor at the given index.
@@ -199,44 +190,75 @@ impl<H: Hal> VirtQueue<H> {
         self.last_used_idx != unsafe { (*self.used.as_ptr()).idx }
     }
 
+    /// Returns the descriptor index (a.k.a. token) of the next used element without popping it, or
+    /// `None` if the used ring is empty.
+    pub fn peek_used(&self) -> Option<u16> {
+        if self.can_pop() {
+            let last_used_slot = self.last_used_idx & (self.queue_size - 1);
+            // Safe because self.used points to a valid, aligned, initialised, dereferenceable,
+            // readable instance of UsedRing.
+            Some(unsafe { (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16 })
+        } else {
+            None
+        }
+    }
+
     /// Returns the number of free descriptors.
     pub fn available_desc(&self) -> usize {
         (self.queue_size - self.num_used) as usize
     }
 
-    /// Recycle descriptors in the list specified by head.
+    /// Unshares buffers in the list starting at descriptor index `head` and adds them to the free
+    /// list. Unsharing may involve copying data back to the original buffers, so they must be
+    /// passed in too.
     ///
     /// This will push all linked descriptors at the front of the free list.
-    fn recycle_descriptors(&mut self, mut head: u16) {
+    fn recycle_descriptors(&mut self, head: u16, inputs: &[*const [u8]], outputs: &[*mut [u8]]) {
         let original_free_head = self.free_head;
         self.free_head = head;
-        loop {
-            let desc = self.desc_ptr(head);
+        let mut next = Some(head);
+
+        for (buffer, direction) in input_output_iter(inputs, outputs) {
+            let desc = self.desc_ptr(next.expect("Descriptor chain was shorter than expected."));
+
             // Safe because self.desc is properly aligned, dereferenceable and initialised, and
             // nothing else reads or writes the descriptor during this block.
-            unsafe {
-                let flags = (*desc).flags;
+            let paddr = unsafe {
+                let paddr = (*desc).addr;
+                (*desc).unset_buf();
                 self.num_used -= 1;
-                if flags.contains(DescFlags::NEXT) {
-                    head = (*desc).next;
-                } else {
+                next = (*desc).next();
+                if next.is_none() {
                     (*desc).next = original_free_head;
-                    return;
                 }
-            }
+                paddr
+            };
+
+            // Unshare the buffer (and perhaps copy its contents back to the original buffer).
+            H::unshare(paddr as usize, buffer, direction);
+        }
+
+        if next.is_some() {
+            panic!("Descriptor chain was longer than expected.");
         }
     }
 
-    /// Get a token from device used buffers, return (token, len).
+    /// If the given token is next on the device used queue, pops it and returns the total buffer
+    /// length which was used (written) by the device.
     ///
     /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
-    pub fn pop_used(&mut self) -> Result<(u16, u32)> {
+    pub fn pop_used(
+        &mut self,
+        token: u16,
+        inputs: &[*const [u8]],
+        outputs: &[*mut [u8]],
+    ) -> Result<u32> {
         if !self.can_pop() {
             return Err(Error::NotReady);
         }
-        // read barrier
-        fence(Ordering::SeqCst);
+        // Read barrier not necessary, as can_pop already has one.
 
+        // Get the index of the start of the descriptor chain for the next element in the used ring.
         let last_used_slot = self.last_used_idx & (self.queue_size - 1);
         let index;
         let len;
@@ -247,10 +269,15 @@ impl<H: Hal> VirtQueue<H> {
             len = (*self.used.as_ptr()).ring[last_used_slot as usize].len;
         }
 
-        self.recycle_descriptors(index);
+        if index != token {
+            // The device used a different descriptor chain to the one we were expecting.
+            return Err(Error::WrongToken);
+        }
+
+        self.recycle_descriptors(index, inputs, outputs);
         self.last_used_idx = self.last_used_idx.wrapping_add(1);
 
-        Ok((index, len))
+        Ok(len)
     }
 
     /// Return size of the queue.
@@ -296,12 +323,42 @@ pub(crate) struct Descriptor {
 }
 
 impl Descriptor {
+    /// Sets the buffer address, length and flags, and shares it with the device.
+    ///
     /// # Safety
     ///
     /// The caller must ensure that the buffer lives at least as long as the descriptor is active.
-    unsafe fn set_buf<H: Hal>(&mut self, buf: NonNull<[u8]>) {
-        self.addr = H::virt_to_phys(buf.as_ptr() as *mut u8 as usize) as u64;
+    unsafe fn set_buf<H: Hal>(
+        &mut self,
+        buf: NonNull<[u8]>,
+        direction: BufferDirection,
+        extra_flags: DescFlags,
+    ) {
+        self.addr = H::share(buf, direction) as u64;
         self.len = buf.len() as u32;
+        self.flags = extra_flags
+            | match direction {
+                BufferDirection::DeviceToDriver => DescFlags::WRITE,
+                BufferDirection::DriverToDevice => DescFlags::empty(),
+            };
+    }
+
+    /// Sets the buffer address and length to 0.
+    ///
+    /// This must only be called once the device has finished using the descriptor.
+    fn unset_buf(&mut self) {
+        self.addr = 0;
+        self.len = 0;
+    }
+
+    /// Returns the index of the next descriptor in the chain if the `NEXT` flag is set, or `None`
+    /// if it is not (and thus this descriptor is the end of the chain).
+    fn next(&self) -> Option<u16> {
+        if self.flags.contains(DescFlags::NEXT) {
+            Some(self.next)
+        } else {
+            None
+        }
     }
 }
 
@@ -385,9 +442,8 @@ pub(crate) fn fake_write_to_queue(
             );
             remaining_data = &remaining_data[length_to_write..];
 
-            if flags.contains(DescFlags::NEXT) {
-                let next = descriptor.next as usize;
-                descriptor = &(*descriptors)[next];
+            if let Some(next) = descriptor.next() {
+                descriptor = &(*descriptors)[next as usize];
             } else {
                 assert_eq!(remaining_data.len(), 0);
                 break;
@@ -526,3 +582,27 @@ mod tests {
         }
     }
 }
+
+/// Returns an iterator over the buffers of first `inputs` and then `outputs`, paired with the
+/// corresponding `BufferDirection`.
+///
+/// Panics if any of the buffer pointers is null.
+fn input_output_iter<'a>(
+    inputs: &'a [*const [u8]],
+    outputs: &'a [*mut [u8]],
+) -> impl Iterator<Item = (NonNull<[u8]>, BufferDirection)> + 'a {
+    inputs
+        .iter()
+        .map(|input| {
+            (
+                NonNull::new(*input as *mut [u8]).unwrap(),
+                BufferDirection::DriverToDevice,
+            )
+        })
+        .chain(outputs.iter().map(|output| {
+            (
+                NonNull::new(*output).unwrap(),
+                BufferDirection::DeviceToDriver,
+            )
+        }))
+}
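A test-style sketch (not part of the commit) of the contract this iterator gives `add` and `recycle_descriptors`: input buffers come out first as `DriverToDevice`, then output buffers as `DeviceToDriver`, in the order the descriptors are chained. Placed next to the existing tests in src/queue.rs it would look roughly like this:

    #[cfg(test)]
    mod input_output_iter_sketch {
        use super::*;

        #[test]
        fn inputs_before_outputs() {
            let input = [1u8, 2, 3];
            let mut output = [0u8; 4];
            let inputs: &[*const [u8]] = &[&input[..]];
            let outputs: &[*mut [u8]] = &[&mut output[..]];

            let mut it = input_output_iter(inputs, outputs);
            // The input buffer is yielded first, marked driver-to-device...
            assert_eq!(
                it.next().map(|(_, direction)| direction),
                Some(BufferDirection::DriverToDevice)
            );
            // ...then the output buffer, marked device-to-driver.
            assert_eq!(
                it.next().map(|(_, direction)| direction),
                Some(BufferDirection::DeviceToDriver)
            );
            assert!(it.next().is_none());
        }
    }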