浏览代码

Specify direction when allocating DMA regions.

This will allow systems with an IOMMU or equivalent to prevent the
device from writing to regions which it shouldn't, such as the
descriptor table and available ring.
Andrew Walbran 2 年之前
父节点
当前提交
608cd9f05d
共有 7 个文件被更改,包括 32 次插入23 次删除
  1. 1 1
      examples/aarch64/src/hal.rs
  2. 1 1
      examples/riscv/src/virtio_impl.rs
  3. 2 2
      src/device/console.rs
  4. 14 10
      src/device/gpu.rs
  5. 7 5
      src/hal.rs
  6. 1 1
      src/hal/fake.rs
  7. 6 3
      src/queue.rs

+ 1 - 1
examples/aarch64/src/hal.rs

@@ -18,7 +18,7 @@ lazy_static! {
 pub struct HalImpl;
 pub struct HalImpl;
 
 
 impl Hal for HalImpl {
 impl Hal for HalImpl {
-    fn dma_alloc(pages: usize) -> PhysAddr {
+    fn dma_alloc(pages: usize, _direction: BufferDirection) -> PhysAddr {
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
         paddr
         paddr

+ 1 - 1
examples/riscv/src/virtio_impl.rs

@@ -17,7 +17,7 @@ lazy_static! {
 pub struct HalImpl;
 pub struct HalImpl;
 
 
 impl Hal for HalImpl {
 impl Hal for HalImpl {
-    fn dma_alloc(pages: usize) -> PhysAddr {
+    fn dma_alloc(pages: usize, _direction: BufferDirection) -> PhysAddr {
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         let paddr = DMA_PADDR.fetch_add(PAGE_SIZE * pages, Ordering::SeqCst);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
         trace!("alloc DMA: paddr={:#x}, pages={}", paddr, pages);
         paddr
         paddr

+ 2 - 2
src/device/console.rs

@@ -1,6 +1,6 @@
 //! Driver for VirtIO console devices.
 //! Driver for VirtIO console devices.
 
 
-use crate::hal::{Dma, Hal};
+use crate::hal::{BufferDirection, Dma, Hal};
 use crate::queue::VirtQueue;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, WriteOnly};
 use crate::volatile::{volread, ReadOnly, WriteOnly};
@@ -74,7 +74,7 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
         let config_space = transport.config_space::<Config>()?;
         let config_space = transport.config_space::<Config>()?;
         let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0, QUEUE_SIZE)?;
         let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0, QUEUE_SIZE)?;
         let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0, QUEUE_SIZE)?;
         let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0, QUEUE_SIZE)?;
-        let queue_buf_dma = Dma::new(1)?;
+        let queue_buf_dma = Dma::new(1, BufferDirection::DeviceToDriver)?;
         let queue_buf_rx = unsafe { &mut queue_buf_dma.as_buf()[0..] };
         let queue_buf_rx = unsafe { &mut queue_buf_dma.as_buf()[0..] };
         transport.finish_init();
         transport.finish_init();
         let mut console = VirtIOConsole {
         let mut console = VirtIOConsole {

+ 14 - 10
src/device/gpu.rs

@@ -1,10 +1,10 @@
 //! Driver for VirtIO GPU devices.
 //! Driver for VirtIO GPU devices.
 
 
-use crate::hal::{Dma, Hal};
+use crate::hal::{BufferDirection, Dma, Hal};
 use crate::queue::VirtQueue;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
 use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
-use crate::{pages, Error, Result, PAGE_SIZE};
+use crate::{pages, Error, Result};
 use bitflags::bitflags;
 use bitflags::bitflags;
 use log::info;
 use log::info;
 
 
@@ -26,8 +26,10 @@ pub struct VirtIOGpu<'a, H: Hal, T: Transport> {
     control_queue: VirtQueue<H>,
     control_queue: VirtQueue<H>,
     /// Queue for sending cursor commands.
     /// Queue for sending cursor commands.
     cursor_queue: VirtQueue<H>,
     cursor_queue: VirtQueue<H>,
-    /// Queue buffer DMA
-    queue_buf_dma: Dma<H>,
+    /// DMA region for sending data to the device.
+    dma_send: Dma<H>,
+    /// DMA region for receiving data from the device.
+    dma_recv: Dma<H>,
     /// Send buffer for queue.
     /// Send buffer for queue.
     queue_buf_send: &'a mut [u8],
     queue_buf_send: &'a mut [u8],
     /// Recv buffer for queue.
     /// Recv buffer for queue.
@@ -58,9 +60,10 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, 2)?;
         let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, 2)?;
         let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR, 2)?;
         let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR, 2)?;
 
 
-        let queue_buf_dma = Dma::new(2)?;
-        let queue_buf_send = unsafe { &mut queue_buf_dma.as_buf()[..PAGE_SIZE] };
-        let queue_buf_recv = unsafe { &mut queue_buf_dma.as_buf()[PAGE_SIZE..] };
+        let dma_send = Dma::new(1, BufferDirection::DriverToDevice)?;
+        let dma_recv = Dma::new(1, BufferDirection::DeviceToDriver)?;
+        let queue_buf_send = unsafe { dma_send.as_buf() };
+        let queue_buf_recv = unsafe { dma_recv.as_buf() };
 
 
         transport.finish_init();
         transport.finish_init();
 
 
@@ -71,7 +74,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
             rect: None,
             rect: None,
             control_queue,
             control_queue,
             cursor_queue,
             cursor_queue,
-            queue_buf_dma,
+            dma_send,
+            dma_recv,
             queue_buf_send,
             queue_buf_send,
             queue_buf_recv,
             queue_buf_recv,
         })
         })
@@ -104,7 +108,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
 
 
         // alloc continuous pages for the frame buffer
         // alloc continuous pages for the frame buffer
         let size = display_info.rect.width * display_info.rect.height * 4;
         let size = display_info.rect.width * display_info.rect.height * 4;
-        let frame_buffer_dma = Dma::new(pages(size as usize))?;
+        let frame_buffer_dma = Dma::new(pages(size as usize), BufferDirection::DriverToDevice)?;
 
 
         // resource_attach_backing
         // resource_attach_backing
         self.resource_attach_backing(RESOURCE_ID_FB, frame_buffer_dma.paddr() as u64, size)?;
         self.resource_attach_backing(RESOURCE_ID_FB, frame_buffer_dma.paddr() as u64, size)?;
@@ -140,7 +144,7 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         if cursor_image.len() != size as usize {
         if cursor_image.len() != size as usize {
             return Err(Error::InvalidParam);
             return Err(Error::InvalidParam);
         }
         }
-        let cursor_buffer_dma = Dma::new(pages(size as usize))?;
+        let cursor_buffer_dma = Dma::new(pages(size as usize), BufferDirection::DriverToDevice)?;
         let buf = unsafe { cursor_buffer_dma.as_buf() };
         let buf = unsafe { cursor_buffer_dma.as_buf() };
         buf.copy_from_slice(cursor_image);
         buf.copy_from_slice(cursor_image);
 
 

+ 7 - 5
src/hal.rs

@@ -19,8 +19,8 @@ pub struct Dma<H: Hal> {
 }
 }
 
 
 impl<H: Hal> Dma<H> {
 impl<H: Hal> Dma<H> {
-    pub fn new(pages: usize) -> Result<Self> {
-        let paddr = H::dma_alloc(pages);
+    pub fn new(pages: usize, direction: BufferDirection) -> Result<Self> {
+        let paddr = H::dma_alloc(pages, direction);
         if paddr == 0 {
         if paddr == 0 {
             return Err(Error::DmaError);
             return Err(Error::DmaError);
         }
         }
@@ -55,7 +55,7 @@ impl<H: Hal> Drop for Dma<H> {
 /// The interface which a particular hardware implementation must implement.
 /// The interface which a particular hardware implementation must implement.
 pub trait Hal {
 pub trait Hal {
     /// Allocates the given number of contiguous physical pages of DMA memory for virtio use.
     /// Allocates the given number of contiguous physical pages of DMA memory for virtio use.
-    fn dma_alloc(pages: usize) -> PhysAddr;
+    fn dma_alloc(pages: usize, direction: BufferDirection) -> PhysAddr;
     /// Deallocates the given contiguous physical DMA memory pages.
     /// Deallocates the given contiguous physical DMA memory pages.
     fn dma_dealloc(paddr: PhysAddr, pages: usize) -> i32;
     fn dma_dealloc(paddr: PhysAddr, pages: usize) -> i32;
     /// Converts a physical address used for virtio to a virtual address which the program can
     /// Converts a physical address used for virtio to a virtual address which the program can
@@ -78,8 +78,10 @@ pub trait Hal {
 /// The direction in which a buffer is passed.
 /// The direction in which a buffer is passed.
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 pub enum BufferDirection {
 pub enum BufferDirection {
-    /// The buffer is written by the driver and read by the device.
+    /// The buffer may be read or written by the driver, but only read by the device.
     DriverToDevice,
     DriverToDevice,
-    /// The buffer is written by the device and read by the driver.
+    /// The buffer may be read or written by the device, but only read by the driver.
     DeviceToDriver,
     DeviceToDriver,
+    /// The buffer may be read or written by both the device and the driver.
+    Both,
 }
 }

+ 1 - 1
src/hal/fake.rs

@@ -9,7 +9,7 @@ pub struct FakeHal;
 
 
 /// Fake HAL implementation for use in unit tests.
 /// Fake HAL implementation for use in unit tests.
 impl Hal for FakeHal {
 impl Hal for FakeHal {
-    fn dma_alloc(pages: usize) -> PhysAddr {
+    fn dma_alloc(pages: usize, _direction: BufferDirection) -> PhysAddr {
         assert_ne!(pages, 0);
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();
         // Safe because the size and alignment of the layout are non-zero.
         // Safe because the size and alignment of the layout are non-zero.

+ 6 - 3
src/queue.rs

@@ -317,7 +317,7 @@ impl<H: Hal> VirtQueueLayout<H> {
         let (desc, avail, used) = queue_part_sizes(queue_size);
         let (desc, avail, used) = queue_part_sizes(queue_size);
         let size = align_up(desc + avail) + align_up(used);
         let size = align_up(desc + avail) + align_up(used);
         // Allocate contiguous pages.
         // Allocate contiguous pages.
-        let dma = Dma::new(size / PAGE_SIZE)?;
+        let dma = Dma::new(size / PAGE_SIZE, BufferDirection::Both)?;
         Ok(Self::Legacy {
         Ok(Self::Legacy {
             dma,
             dma,
             avail_offset: desc,
             avail_offset: desc,
@@ -332,8 +332,8 @@ impl<H: Hal> VirtQueueLayout<H> {
     /// and allows the HAL to know which DMA regions are used in which direction.
     /// and allows the HAL to know which DMA regions are used in which direction.
     fn allocate_flexible(queue_size: u16) -> Result<Self> {
     fn allocate_flexible(queue_size: u16) -> Result<Self> {
         let (desc, avail, used) = queue_part_sizes(queue_size);
         let (desc, avail, used) = queue_part_sizes(queue_size);
-        let driver_to_device_dma = Dma::new(pages(desc + avail))?;
-        let device_to_driver_dma = Dma::new(pages(used))?;
+        let driver_to_device_dma = Dma::new(pages(desc + avail), BufferDirection::DriverToDevice)?;
+        let device_to_driver_dma = Dma::new(pages(used), BufferDirection::DeviceToDriver)?;
         Ok(Self::Modern {
         Ok(Self::Modern {
             driver_to_device_dma,
             driver_to_device_dma,
             device_to_driver_dma,
             device_to_driver_dma,
@@ -461,6 +461,9 @@ impl Descriptor {
             | match direction {
             | match direction {
                 BufferDirection::DeviceToDriver => DescFlags::WRITE,
                 BufferDirection::DeviceToDriver => DescFlags::WRITE,
                 BufferDirection::DriverToDevice => DescFlags::empty(),
                 BufferDirection::DriverToDevice => DescFlags::empty(),
+                BufferDirection::Both => {
+                    panic!("Buffer passed to device should never use BufferDirection::Both.")
+                }
             };
             };
     }
     }