فهرست منبع

Add methods to Hal to share and unshare buffers with device.

These are necessary on platforms where devices don't have access to all
memory of the VM, such as if the device is behind an IOMMU or the VM
memory isn't exposed to the host.
Andrew Walbran 2 سال پیش
والد
کامیت
d0d47bcc0d
6فایلهای تغییر یافته به همراه101 افزوده شده و 35 حذف شده
  1. 18 4
      examples/aarch64/src/hal.rs
  2. 18 4
      examples/riscv/src/virtio_impl.rs
  3. 19 4
      src/hal.rs
  4. 15 4
      src/hal/fake.rs
  5. 1 1
      src/lib.rs
  6. 30 18
      src/queue.rs

+ 18 - 4
examples/aarch64/src/hal.rs

@@ -1,7 +1,10 @@
-use core::sync::atomic::*;
+use core::{
+    ptr::NonNull,
+    sync::atomic::{AtomicUsize, Ordering},
+};
 use lazy_static::lazy_static;
 use log::trace;
-use virtio_drivers::{Hal, PhysAddr, VirtAddr, PAGE_SIZE};
+use virtio_drivers::{BufferDirection, Hal, PhysAddr, VirtAddr, PAGE_SIZE};
 
 extern "C" {
     static dma_region: u8;
@@ -30,7 +33,18 @@ impl Hal for HalImpl {
         paddr
     }
 
-    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
-        vaddr
+    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+        let vaddr = buffer.as_ptr() as *mut u8 as usize;
+        // Nothing to do, as the host already has access to all memory.
+        virt_to_phys(vaddr)
     }
+
+    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+        // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
+        // anywhere else.
+    }
+}
+
+fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
+    vaddr
 }

+ 18 - 4
examples/riscv/src/virtio_impl.rs

@@ -1,7 +1,10 @@
-use core::sync::atomic::*;
+use core::{
+    ptr::NonNull,
+    sync::atomic::{AtomicUsize, Ordering},
+};
 use lazy_static::lazy_static;
 use log::trace;
-use virtio_drivers::{Hal, PhysAddr, VirtAddr, PAGE_SIZE};
+use virtio_drivers::{BufferDirection, Hal, PhysAddr, VirtAddr, PAGE_SIZE};
 
 extern "C" {
     fn end();
@@ -29,7 +32,18 @@ impl Hal for HalImpl {
         paddr
     }
 
-    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
-        vaddr
+    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+        let vaddr = buffer.as_ptr() as *mut u8 as usize;
+        // Nothing to do, as the host already has access to all memory.
+        virt_to_phys(vaddr)
     }
+
+    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+        // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
+        // anywhere else.
+    }
+}
+
+fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
+    vaddr
 }

+ 19 - 4
src/hal.rs

@@ -2,7 +2,7 @@
 pub mod fake;
 
 use crate::{Error, Result, PAGE_SIZE};
-use core::marker::PhantomData;
+use core::{marker::PhantomData, ptr::NonNull};
 
 /// A virtual memory address in the address space of the program.
 pub type VirtAddr = usize;
@@ -61,7 +61,22 @@ pub trait Hal {
     /// Converts a physical address used for virtio to a virtual address which the program can
     /// access.
     fn phys_to_virt(paddr: PhysAddr) -> VirtAddr;
-    /// Converts a virtual address which the program can access to the corresponding physical
-    /// address to use for virtio.
-    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr;
+    /// Shares the given memory range with the device, and returns the physical address that the
+    /// device can use to access it.
+    ///
+    /// This may involve mapping the buffer into an IOMMU, giving the host permission to access the
+    /// memory, or copying it to a special region where it can be accessed.
+    fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+    /// Unshares the given memory range from the device and (if necessary) copies it back to the
+    /// original buffer.
+    fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
+}
+
+/// The direction in which a buffer is passed.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum BufferDirection {
+    /// The buffer is written by the driver and read by the device.
+    DriverToDevice,
+    /// The buffer is written by the device and read by the driver.
+    DeviceToDriver,
 }

+ 15 - 4
src/hal/fake.rs

@@ -1,8 +1,8 @@
 //! Fake HAL implementation for tests.
 
-use crate::{Hal, PhysAddr, VirtAddr, PAGE_SIZE};
+use crate::{BufferDirection, Hal, PhysAddr, VirtAddr, PAGE_SIZE};
 use alloc::alloc::{alloc_zeroed, dealloc, handle_alloc_error};
-use core::alloc::Layout;
+use core::{alloc::Layout, ptr::NonNull};
 
 #[derive(Debug)]
 pub struct FakeHal;
@@ -35,7 +35,18 @@ impl Hal for FakeHal {
         paddr
     }
 
-    fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
-        vaddr
+    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+        let vaddr = buffer.as_ptr() as *mut u8 as usize;
+        // Nothing to do, as the host already has access to all memory.
+        virt_to_phys(vaddr)
     }
+
+    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+        // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
+        // anywhere else.
+    }
+}
+
+fn virt_to_phys(vaddr: VirtAddr) -> PhysAddr {
+    vaddr
 }

+ 1 - 1
src/lib.rs

@@ -53,7 +53,7 @@ mod queue;
 pub mod transport;
 mod volatile;
 
-pub use self::hal::{Hal, PhysAddr, VirtAddr};
+pub use self::hal::{BufferDirection, Hal, PhysAddr, VirtAddr};
 
 /// The page size in bytes supported by the library (4 KiB).
 pub const PAGE_SIZE: usize = 0x1000;

+ 30 - 18
src/queue.rs

@@ -1,6 +1,6 @@
 #[cfg(test)]
 use crate::hal::VirtAddr;
-use crate::hal::{Dma, Hal};
+use crate::hal::{BufferDirection, Dma, Hal};
 use crate::transport::Transport;
 use crate::{align_up, Error, Result, PAGE_SIZE};
 use bitflags::bitflags;
@@ -114,15 +114,9 @@ impl<H: Hal> VirtQueue<H> {
         // Safe because self.desc is properly aligned, dereferenceable and initialised, and nothing
         // else reads or writes the free descriptors during this block.
         unsafe {
-            for (buffer, is_output) in input_output_iter(inputs, outputs) {
+            for (buffer, direction) in input_output_iter(inputs, outputs) {
                 let desc = self.desc_ptr(self.free_head);
-                (*desc).set_buf::<H>(buffer);
-                (*desc).flags = DescFlags::NEXT
-                    | if is_output {
-                        DescFlags::WRITE
-                    } else {
-                        DescFlags::empty()
-                    };
+                (*desc).set_buf::<H>(buffer, direction, DescFlags::NEXT);
                 last = self.free_head;
                 self.free_head = (*desc).next;
             }
@@ -294,12 +288,24 @@ pub(crate) struct Descriptor {
 }
 
 impl Descriptor {
+    /// Sets the buffer address, length and flags, and shares it with the device.
+    ///
     /// # Safety
     ///
     /// The caller must ensure that the buffer lives at least as long as the descriptor is active.
-    unsafe fn set_buf<H: Hal>(&mut self, buf: NonNull<[u8]>) {
-        self.addr = H::virt_to_phys(buf.as_ptr() as *mut u8 as usize) as u64;
+    unsafe fn set_buf<H: Hal>(
+        &mut self,
+        buf: NonNull<[u8]>,
+        direction: BufferDirection,
+        extra_flags: DescFlags,
+    ) {
+        self.addr = H::share(buf, direction) as u64;
         self.len = buf.len() as u32;
+        self.flags = extra_flags
+            | match direction {
+                BufferDirection::DeviceToDriver => DescFlags::WRITE,
+                BufferDirection::DriverToDevice => DescFlags::empty(),
+            };
     }
 }
 
@@ -532,13 +538,19 @@ mod tests {
 fn input_output_iter<'a>(
     inputs: &'a [*const [u8]],
     outputs: &'a [*mut [u8]],
-) -> impl Iterator<Item = (NonNull<[u8]>, bool)> + 'a {
+) -> impl Iterator<Item = (NonNull<[u8]>, BufferDirection)> + 'a {
     inputs
         .iter()
-        .map(|input| (NonNull::new(*input as *mut [u8]).unwrap(), false))
-        .chain(
-            outputs
-                .iter()
-                .map(|output| (NonNull::new(*output).unwrap(), true)),
-        )
+        .map(|input| {
+            (
+                NonNull::new(*input as *mut [u8]).unwrap(),
+                BufferDirection::DriverToDevice,
+            )
+        })
+        .chain(outputs.iter().map(|output| {
+            (
+                NonNull::new(*output).unwrap(),
+                BufferDirection::DeviceToDriver,
+            )
+        }))
 }