浏览代码

Document safety requirements for Hal trait and mark it as unsafe.

Some methods are also unsafe.
Andrew Walbran 2 年之前
父节点
当前提交
7ea4491389
共有 6 个文件被更改,包括 92 次插入26 次删除
  1. 5 5
      examples/aarch64/src/hal.rs
  2. 5 5
      examples/riscv/src/virtio_impl.rs
  3. 52 6
      src/hal.rs
  4. 5 5
      src/hal/fake.rs
  5. 22 4
      src/queue.rs
  6. 3 1
      src/transport/pci.rs

+ 5 - 5
examples/aarch64/src/hal.rs

@@ -17,7 +17,7 @@ lazy_static! {
 
 pub struct HalImpl;
 
-impl Hal for HalImpl {
+unsafe impl Hal for HalImpl {
     fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
@@ -25,22 +25,22 @@ impl Hal for HalImpl {
         (paddr, vaddr)
     }
 
-    fn dma_dealloc(paddr: PhysAddr, _vaddr: NonNull<u8>, pages: usize) -> i32 {
+    unsafe fn dma_dealloc(paddr: PhysAddr, _vaddr: NonNull<u8>, pages: usize) -> i32 {
         trace!("dealloc DMA: paddr={:#x}, pages={}", paddr, pages);
         0
     }
 
-    fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
         NonNull::new(paddr as _).unwrap()
     }
 
-    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
         let vaddr = buffer.as_ptr() as *mut u8 as usize;
         // Nothing to do, as the host already has access to all memory.
         virt_to_phys(vaddr)
     }
 
-    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
         // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
         // anywhere else.
     }

+ 5 - 5
examples/riscv/src/virtio_impl.rs

@@ -16,7 +16,7 @@ lazy_static! {
 
 pub struct HalImpl;
 
-impl Hal for HalImpl {
+unsafe impl Hal for HalImpl {
     fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
@@ -24,22 +24,22 @@ impl Hal for HalImpl {
         (paddr, vaddr)
     }
 
-    fn dma_dealloc(paddr: PhysAddr, _vaddr: NonNull<u8>, pages: usize) -> i32 {
+    unsafe fn dma_dealloc(paddr: PhysAddr, _vaddr: NonNull<u8>, pages: usize) -> i32 {
         trace!("dealloc DMA: paddr={:#x}, pages={}", paddr, pages);
         0
     }
 
-    fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
         NonNull::new(paddr as _).unwrap()
     }
 
-    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
         let vaddr = buffer.as_ptr() as *mut u8 as usize;
         // Nothing to do, as the host already has access to all memory.
         virt_to_phys(vaddr)
     }
 
-    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
         // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
         // anywhere else.
     }

+ 52 - 6
src/hal.rs

@@ -53,35 +53,81 @@ impl<H: Hal> Dma<H> {
 
 impl<H: Hal> Drop for Dma<H> {
     fn drop(&mut self) {
-        let err = H::dma_dealloc(self.paddr, self.vaddr, self.pages);
+        // Safe because the memory was previously allocated by `dma_alloc` in `Dma::new`, not yet
+        // deallocated, and we are passing the values from then.
+        let err = unsafe { H::dma_dealloc(self.paddr, self.vaddr, self.pages) };
         assert_eq!(err, 0, "failed to deallocate DMA");
     }
 }
 
 /// The interface which a particular hardware implementation must implement.
-pub trait Hal {
+///
+/// # Safety
+///
+/// Implementations of this trait must follow the "implementation safety" requirements documented
+/// for each method. Callers must follow the safety requirements documented for the unsafe methods.
+pub unsafe trait Hal {
     /// Allocates the given number of contiguous physical pages of DMA memory for VirtIO use.
     ///
     /// Returns both the physical address which the device can use to access the memory, and a
     /// pointer to the start of it which the driver can use to access it.
+    ///
+    /// # Implementation safety
+    ///
+    /// Implementations of this method must ensure that the `NonNull<u8>` returned is a
+    /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, aligned to
+    /// [`PAGE_SIZE`], and won't alias any other allocations or references in the program until it
+    /// is deallocated by `dma_dealloc`.
     fn dma_alloc(pages: usize, direction: BufferDirection) -> (PhysAddr, NonNull<u8>);
+
     /// Deallocates the given contiguous physical DMA memory pages.
-    fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32;
+    ///
+    /// # Safety
+    ///
+    /// The memory must have been allocated by `dma_alloc` on the same `Hal` implementation, and not
+    /// yet deallocated. `pages` must be the same number passed to `dma_alloc` originally, and both
+    /// `paddr` and `vaddr` must be the values returned by `dma_alloc`.
+    unsafe fn dma_dealloc(paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32;
+
     /// Converts a physical address used for MMIO to a virtual address which the driver can access.
     ///
     /// This is only used for MMIO addresses within BARs read from the device, for the PCI
     /// transport. It may check that the address range up to the given size is within the region
     /// expected for MMIO.
-    fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>;
+    ///
+    /// # Implementation safety
+    ///
+    /// Implementations of this method must ensure that the `NonNull<u8>` returned is a
+    /// [_valid_](https://doc.rust-lang.org/std/ptr/index.html#safety) pointer, and won't alias any
+    /// other allocations or references in the program.
+    ///
+    /// # Safety
+    ///
+    /// The `paddr` and `size` must describe a valid MMIO region. The implementation may validate it
+    /// in some way (and panic if it is invalid) but is not guaranteed to.
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, size: usize) -> NonNull<u8>;
+
     /// Shares the given memory range with the device, and returns the physical address that the
     /// device can use to access it.
     ///
     /// This may involve mapping the buffer into an IOMMU, giving the host permission to access the
     /// memory, or copying it to a special region where it can be accessed.
-    fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+    ///
+    /// # Safety
+    ///
+    /// The buffer must be a valid pointer to memory which will not be accessed by any other thread
+    /// for the duration of this method call.
+    unsafe fn share(buffer: NonNull<[u8]>, direction: BufferDirection) -> PhysAddr;
+
     /// Unshares the given memory range from the device and (if necessary) copies it back to the
     /// original buffer.
-    fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
+    ///
+    /// # Safety
+    ///
+    /// The buffer must be a valid pointer to memory which will not be accessed by any other thread
+    /// for the duration of this method call. The `paddr` must be the value previously returned by
+    /// the corresponding `share` call.
+    unsafe fn unshare(paddr: PhysAddr, buffer: NonNull<[u8]>, direction: BufferDirection);
 }
 
 /// The direction in which a buffer is passed.

+ 5 - 5
src/hal/fake.rs

@@ -8,7 +8,7 @@ use core::{alloc::Layout, ptr::NonNull};
 pub struct FakeHal;
 
 /// Fake HAL implementation for use in unit tests.
-impl Hal for FakeHal {
+unsafe impl Hal for FakeHal {
     fn dma_alloc(pages: usize, _direction: BufferDirection) -> (PhysAddr, NonNull<u8>) {
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
@@ -21,7 +21,7 @@ impl Hal for FakeHal {
         }
     }
 
-    fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
+    unsafe fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
         // Safe because the layout is the same as was used when the memory was allocated by
@@ -32,17 +32,17 @@ impl Hal for FakeHal {
         0
     }
 
-    fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
+    unsafe fn mmio_phys_to_virt(paddr: PhysAddr, _size: usize) -> NonNull<u8> {
         NonNull::new(paddr as _).unwrap()
     }
 
-    fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
+    unsafe fn share(buffer: NonNull<[u8]>, _direction: BufferDirection) -> PhysAddr {
         let vaddr = buffer.as_ptr() as *mut u8 as usize;
         // Nothing to do, as the host already has access to all memory.
         virt_to_phys(vaddr)
     }
 
-    fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
+    unsafe fn unshare(_paddr: PhysAddr, _buffer: NonNull<[u8]>, _direction: BufferDirection) {
         // Nothing to do, as the host already has access to all memory and we didn't copy the buffer
         // anywhere else.
     }

+ 22 - 4
src/queue.rs

@@ -252,7 +252,18 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
     /// passed in too.
     ///
     /// This will push all linked descriptors at the front of the free list.
-    fn recycle_descriptors(&mut self, head: u16, inputs: &[*const [u8]], outputs: &[*mut [u8]]) {
+    ///
+    /// # Safety
+    ///
+    /// The buffers in `inputs` and `outputs` must be valid pointers to memory which is not accessed
+    /// by any other thread for the duration of this method call, and must match the set of buffers
+    /// originally added to the queue by `add`.
+    unsafe fn recycle_descriptors(
+        &mut self,
+        head: u16,
+        inputs: &[*const [u8]],
+        outputs: &[*mut [u8]],
+    ) {
         let original_free_head = self.free_head;
         self.free_head = head;
         let mut next = Some(head);
@@ -271,8 +282,12 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
 
             self.write_desc(desc_index);
 
-            // Unshare the buffer (and perhaps copy its contents back to the original buffer).
-            H::unshare(paddr as usize, buffer, direction);
+            // Safe because the caller ensures that the buffer is valid and matches the descriptor
+            // from which we got `paddr`.
+            unsafe {
+                // Unshare the buffer (and perhaps copy its contents back to the original buffer).
+                H::unshare(paddr as usize, buffer, direction);
+            }
         }
 
         if next.is_some() {
@@ -311,7 +326,10 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
             return Err(Error::WrongToken);
         }
 
-        self.recycle_descriptors(index, inputs, outputs);
+        // Safe because the caller ensures the buffers are valid and match the descriptor.
+        unsafe {
+            self.recycle_descriptors(index, inputs, outputs);
+        }
         self.last_used_idx = self.last_used_idx.wrapping_add(1);
 
         Ok(len)

+ 3 - 1
src/transport/pci.rs

@@ -395,7 +395,9 @@ fn get_bar_region<H: Hal, T>(
         return Err(VirtioPciError::BarOffsetOutOfRange);
     }
     let paddr = bar_address as PhysAddr + struct_info.offset as PhysAddr;
-    let vaddr = H::mmio_phys_to_virt(paddr, struct_info.length as usize);
+    // Safe because the paddr and size describe a valid MMIO region, at least according to the PCI
+    // bus.
+    let vaddr = unsafe { H::mmio_phys_to_virt(paddr, struct_info.length as usize) };
     if vaddr.as_ptr() as usize % align_of::<T>() != 0 {
         return Err(VirtioPciError::Misaligned {
             vaddr,