浏览代码

Specify direction when allocating DMA regions.

This will allow systems with an IOMMU or equivalent to prevent the
device from writing to regions which it shouldn't, such as the
descriptor table and available ring.
Andrew Walbran 2 年之前
父节点
当前提交
608cd9f05d
共有 7 个文件被更改,包括 32 次插入23 次删除
  1. 1 1
      examples/aarch64/src/hal.rs
  2. 1 1
      examples/riscv/src/virtio_impl.rs
  3. 2 2
      src/device/console.rs
  4. 14 10
      src/device/gpu.rs
  5. 7 5
      src/hal.rs
  6. 1 1
      src/hal/fake.rs
  7. 6 3
      src/queue.rs

+ 1 - 1
examples/aarch64/src/hal.rs

@@ -18,7 +18,7 @@ lazy_static! {
 pub struct HalImpl;
 
 impl Hal for HalImpl {
-    fn dma_alloc(pages: usize) -> PhysAddr {
+    fn dma_alloc(pages: usize, _direction: BufferDirection) -> PhysAddr {
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
         paddr

+ 1 - 1
examples/riscv/src/virtio_impl.rs

@@ -17,7 +17,7 @@ lazy_static! {
 pub struct HalImpl;
 
 impl Hal for HalImpl {
-    fn dma_alloc(pages: usize) -> PhysAddr {
+    fn dma_alloc(pages: usize, _direction: BufferDirection) -> PhysAddr {
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
         paddr

+ 2 - 2
src/device/console.rs

@@ -1,6 +1,6 @@
 //! Driver for VirtIO console devices.
 
-use crate::hal::{Dma, Hal};
+use crate::hal::{BufferDirection, Dma, Hal};
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, WriteOnly};
@@ -74,7 +74,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
         let config_space = transport.config_space::<Config>()?;
         let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0, QUEUE_SIZE)?;
         let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0, QUEUE_SIZE)?;
-        let queue_buf_dma = Dma::new(1)?;
+        let queue_buf_dma = Dma::new(1, BufferDirection::DeviceToDriver)?;
         let queue_buf_rx = unsafe { &mut queue_buf_dma.as_buf()[0..] };
         transport.finish_init();
         let mut console = VirtIOConsole {

+ 14 - 10
src/device/gpu.rs

@@ -1,10 +1,10 @@
 //! Driver for VirtIO GPU devices.
 
-use crate::hal::{Dma, Hal};
+use crate::hal::{BufferDirection, Dma, Hal};
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
-use crate::{pages, Error, Result, PAGE_SIZE};
+use crate::{pages, Error, Result};
 use bitflags::bitflags;
 use log::info;
 
@@ -26,8 +26,10 @@ pub struct VirtIOGpu<'a, H: Hal, T: Transport> {
     control_queue: VirtQueue<H>,
     /// Queue for sending cursor commands.
     cursor_queue: VirtQueue<H>,
-    /// Queue buffer DMA
-    queue_buf_dma: Dma<H>,
+    /// DMA region for sending data to the device.
+    dma_send: Dma<H>,
+    /// DMA region for receiving data from the device.
+    dma_recv: Dma<H>,
     /// Send buffer for queue.
     queue_buf_send: &'a mut [u8],
     /// Recv buffer for queue.
@@ -58,9 +60,10 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, 2)?;
         let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR, 2)?;
 
-        let queue_buf_dma = Dma::new(2)?;
-        let queue_buf_send = unsafe { &mut queue_buf_dma.as_buf()[..PAGE_SIZE] };
-        let queue_buf_recv = unsafe { &mut queue_buf_dma.as_buf()[PAGE_SIZE..] };
+        let dma_send = Dma::new(1, BufferDirection::DriverToDevice)?;
+        let dma_recv = Dma::new(1, BufferDirection::DeviceToDriver)?;
+        let queue_buf_send = unsafe { dma_send.as_buf() };
+        let queue_buf_recv = unsafe { dma_recv.as_buf() };
 
         transport.finish_init();
 
@@ -71,7 +74,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
             rect: None,
             control_queue,
             cursor_queue,
-            queue_buf_dma,
+            dma_send,
+            dma_recv,
             queue_buf_send,
             queue_buf_recv,
         })
@@ -104,7 +108,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
 
         // alloc continuous pages for the frame buffer
         let size = display_info.rect.width * display_info.rect.height * 4;
-        let frame_buffer_dma = Dma::new(pages(size as usize))?;
+        let frame_buffer_dma = Dma::new(pages(size as usize), BufferDirection::DriverToDevice)?;
 
         // resource_attach_backing
         self.resource_attach_backing(RESOURCE_ID_FB, frame_buffer_dma.paddr() as u64, size)?;
@@ -140,7 +144,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         if cursor_image.len() != size as usize {
             return Err(Error::InvalidParam);
         }
-        let cursor_buffer_dma = Dma::new(pages(size as usize))?;
+        let cursor_buffer_dma = Dma::new(pages(size as usize), BufferDirection::DriverToDevice)?;
         let buf = unsafe { cursor_buffer_dma.as_buf() };
         buf.copy_from_slice(cursor_image);
 

+ 7 - 5
src/hal.rs

@@ -19,8 +19,8 @@ pub struct Dma<H: Hal> {
 }
 
 impl<H: Hal> Dma<H> {
-    pub fn new(pages: usize) -> Result<Self> {
-        let paddr = H::dma_alloc(pages);
+    pub fn new(pages: usize, direction: BufferDirection) -> Result<Self> {
+        let paddr = H::dma_alloc(pages, direction);
         if paddr == 0 {
             return Err(Error::DmaError);
         }
@@ -55,7 +55,7 @@ impl<H: Hal> Drop for Dma<H> {
 /// The interface which a particular hardware implementation must implement.
 pub trait Hal {
     /// Allocates the given number of contiguous physical pages of DMA memory for virtio use.
-    fn dma_alloc(pages: usize) -> PhysAddr;
+    fn dma_alloc(pages: usize, direction: BufferDirection) -> PhysAddr;
     /// Deallocates the given contiguous physical DMA memory pages.
     fn dma_dealloc(paddr: PhysAddr, pages: usize) -> i32;
     /// Converts a physical address used for virtio to a virtual address which the program can
@@ -78,8 +78,10 @@ pub trait Hal {
 /// The direction in which a buffer is passed.
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum BufferDirection {
-    /// The buffer is written by the driver and read by the device.
+    /// The buffer may be read or written by the driver, but only read by the device.
     DriverToDevice,
-    /// The buffer is written by the device and read by the driver.
+    /// The buffer may be read or written by the device, but only read by the driver.
     DeviceToDriver,
+    /// The buffer may be read or written by both the device and the driver.
+    Both,
 }

+ 1 - 1
src/hal/fake.rs

@@ -9,7 +9,7 @@ pub struct FakeHal;
 
 /// Fake HAL implementation for use in unit tests.
 impl Hal for FakeHal {
-    fn dma_alloc(pages: usize) -> PhysAddr {
+    fn dma_alloc(pages: usize, _direction: BufferDirection) -> PhysAddr {
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
         // Safe because the size and alignment of the layout are non-zero.

+ 6 - 3
src/queue.rs

@@ -317,7 +317,7 @@ impl<H: Hal> VirtQueueLayout<H> {
         let (desc, avail, used) = queue_part_sizes(queue_size);
         let size = align_up(desc + avail) + align_up(used);
         // Allocate contiguous pages.
-        let dma = Dma::new(size / PAGE_SIZE)?;
+        let dma = Dma::new(size / PAGE_SIZE, BufferDirection::Both)?;
         Ok(Self::Legacy {
             dma,
             avail_offset: desc,
@@ -332,8 +332,8 @@ impl<H: Hal> VirtQueueLayout<H> {
     /// and allows the HAL to know which DMA regions are used in which direction.
     fn allocate_flexible(queue_size: u16) -> Result<Self> {
         let (desc, avail, used) = queue_part_sizes(queue_size);
-        let driver_to_device_dma = Dma::new(pages(desc + avail))?;
-        let device_to_driver_dma = Dma::new(pages(used))?;
+        let driver_to_device_dma = Dma::new(pages(desc + avail), BufferDirection::DriverToDevice)?;
+        let device_to_driver_dma = Dma::new(pages(used), BufferDirection::DeviceToDriver)?;
         Ok(Self::Modern {
             driver_to_device_dma,
             device_to_driver_dma,
@@ -461,6 +461,9 @@ impl Descriptor {
             | match direction {
                 BufferDirection::DeviceToDriver => DescFlags::WRITE,
                 BufferDirection::DriverToDevice => DescFlags::empty(),
+                BufferDirection::Both => {
+                    panic!("Buffer passed to device should never use BufferDirection::Both.")
+                }
             };
     }