
Fix soundness issues with MMIO and shared memory (#18)

* First cut at sound volatile MMIO.

* Use new volatile module for blk too.

* Backwards compatibility for old rustc in example.

* Document volatile module.

* Use volatile module for remaining devices.

* UsedRing is only read by the driver.

* No need to use volatile for the queue.

The queue is just shared memory, not MMIO memory, so there is no need to
use volatile reads and writes.

* Add fence before reading used ring and after writing available ring; a minimal sketch of the resulting access pattern follows below.
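
To make the ordering concrete, here is a minimal standalone sketch of the shared-memory access pattern this commit settles on: plain (non-volatile) reads and writes for the rings, with fences supplying the ordering the virtio spec requires. The ring types below are simplified placeholders, not the structs in src/queue.rs:

    use core::sync::atomic::{fence, Ordering};

    // Simplified placeholder ring layouts; the real ones live in src/queue.rs.
    #[repr(C)]
    struct AvailRing {
        flags: u16,
        idx: u16,
        ring: [u16; 32],
    }

    #[repr(C)]
    struct UsedRing {
        flags: u16,
        idx: u16,
    }

    // Publish a descriptor chain head to the available ring. Plain writes are
    // enough because the rings are ordinary RAM shared with the device.
    unsafe fn publish(avail: *mut AvailRing, slot: u16, head: u16, new_idx: u16) {
        (*avail).ring[slot as usize] = head;
        // Fence so the device sees the ring entry before the index update.
        fence(Ordering::SeqCst);
        (*avail).idx = new_idx;
        // Fence so the index update is visible before the device is notified.
        fence(Ordering::SeqCst);
    }

    // Check whether the device has returned a buffer.
    unsafe fn can_pop(used: *const UsedRing, last_used_idx: u16) -> bool {
        // Fence so a fresh index value written by the device is observed.
        fence(Ordering::SeqCst);
        last_used_idx != (*used).idx
    }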
Andrew Walbran, 2 years ago
commit e28a1e05c8

10 changed files with 394 additions and 213 deletions:
  1. Cargo.toml (+0 -1)
  2. src/blk.rs (+9 -11)
  3. src/console.rs (+4 -4)
  4. src/gpu.rs (+3 -3)
  5. src/input.rs (+15 -11)
  6. src/lib.rs (+1 -0)
  7. src/net.rs (+12 -9)
  8. src/queue.rs (+165 -112)
  9. src/transport/mmio.rs (+77 -62)
  10. src/volatile.rs (+108 -0)

+ 0 - 1
Cargo.toml

@@ -8,6 +8,5 @@ description = "VirtIO guest drivers."
 # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
 
 [dependencies]
-volatile = "0.3"
 log = "0.4"
 bitflags = "1.3"

+ 9 - 11
src/blk.rs

@@ -1,22 +1,22 @@
 use super::*;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
+use crate::volatile::{volread, Volatile};
 use bitflags::*;
 use core::hint::spin_loop;
 use log::*;
-use volatile::Volatile;
 
 /// The virtio block device is a simple virtual block device (i.e. disk).
 ///
 /// Read and write requests (and other exotic requests) are placed in the queue,
 /// and serviced (probably out of order) by the device except where noted.
-pub struct VirtIOBlk<'a, H: Hal, T: Transport> {
+pub struct VirtIOBlk<H: Hal, T: Transport> {
     transport: T,
-    queue: VirtQueue<'a, H>,
+    queue: VirtQueue<H>,
     capacity: usize,
 }
 
-impl<H: Hal, T: Transport> VirtIOBlk<'_, H, T> {
+impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// Create a new VirtIO-Blk driver.
     pub fn new(mut transport: T) -> Result<Self> {
         transport.begin_init(|features| {
@@ -28,13 +28,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<'_, H, T> {
         });
 
         // read configuration space
-        let config_space = transport.config_space().cast::<BlkConfig>();
-        let config = unsafe { config_space.as_ref() };
+        let config = transport.config_space().cast::<BlkConfig>();
         info!("config: {:?}", config);
-        info!(
-            "found a block device of size {}KB",
-            config.capacity.read() / 2
-        );
+        // Safe because config is a valid pointer to the device configuration space.
+        let capacity = unsafe { volread!(config, capacity) };
+        info!("found a block device of size {}KB", capacity / 2);
 
         let queue = VirtQueue::new(&mut transport, 0, 16)?;
         transport.finish_init();
@@ -42,7 +40,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<'_, H, T> {
         Ok(VirtIOBlk {
             transport,
             queue,
-            capacity: config.capacity.read() as usize,
+            capacity: capacity as usize,
         })
     }
 

+ 4 - 4
src/console.rs

@@ -1,10 +1,10 @@
 use super::*;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
+use crate::volatile::{ReadOnly, WriteOnly};
 use bitflags::*;
 use core::{fmt, hint::spin_loop};
 use log::*;
-use volatile::{ReadOnly, WriteOnly};
 
 const QUEUE_RECEIVEQ_PORT_0: usize = 0;
 const QUEUE_TRANSMITQ_PORT_0: usize = 1;
@@ -14,8 +14,8 @@ const QUEUE_SIZE: u16 = 2;
 /// Emergency and cols/rows unimplemented.
 pub struct VirtIOConsole<'a, H: Hal, T: Transport> {
     transport: T,
-    receiveq: VirtQueue<'a, H>,
-    transmitq: VirtQueue<'a, H>,
+    receiveq: VirtQueue<H>,
+    transmitq: VirtQueue<H>,
     queue_buf_dma: DMA<H>,
     queue_buf_rx: &'a mut [u8],
     cursor: usize,
@@ -161,7 +161,7 @@ mod tests {
             cols: ReadOnly::new(0),
             rows: ReadOnly::new(0),
             max_nr_ports: ReadOnly::new(0),
-            emerg_wr: WriteOnly::new(0),
+            emerg_wr: WriteOnly::default(),
         };
         let state = Arc::new(Mutex::new(State {
             status: DeviceStatus::empty(),

+ 3 - 3
src/gpu.rs

@@ -1,10 +1,10 @@
 use super::*;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
+use crate::volatile::{ReadOnly, Volatile, WriteOnly};
 use bitflags::*;
 use core::{fmt, hint::spin_loop};
 use log::*;
-use volatile::{ReadOnly, Volatile, WriteOnly};
 
 /// A virtio based graphics adapter.
 ///
@@ -21,9 +21,9 @@ pub struct VirtIOGpu<'a, H: Hal, T: Transport> {
     /// DMA area of cursor image buffer.
     cursor_buffer_dma: Option<DMA<H>>,
     /// Queue for sending control commands.
-    control_queue: VirtQueue<'a, H>,
+    control_queue: VirtQueue<H>,
     /// Queue for sending cursor commands.
-    cursor_queue: VirtQueue<'a, H>,
+    cursor_queue: VirtQueue<H>,
     /// Queue buffer DMA
     queue_buf_dma: DMA<H>,
     /// Send buffer for queue.

+ 15 - 11
src/input.rs

@@ -1,23 +1,23 @@
 use super::*;
 use crate::transport::Transport;
+use crate::volatile::{volread, volwrite, ReadOnly, WriteOnly};
 use alloc::boxed::Box;
 use bitflags::*;
 use log::*;
-use volatile::{ReadOnly, WriteOnly};
 
 /// Virtual human interface devices such as keyboards, mice and tablets.
 ///
 /// An instance of the virtio device represents one such input device.
 /// Device behavior mirrors that of the evdev layer in Linux,
 /// making pass-through implementations on top of evdev easy.
-pub struct VirtIOInput<'a, H: Hal, T: Transport> {
+pub struct VirtIOInput<H: Hal, T: Transport> {
     transport: T,
-    event_queue: VirtQueue<'a, H>,
-    status_queue: VirtQueue<'a, H>,
+    event_queue: VirtQueue<H>,
+    status_queue: VirtQueue<H>,
     event_buf: Box<[InputEvent; 32]>,
 }
 
-impl<H: Hal, T: Transport> VirtIOInput<'_, H, T> {
+impl<H: Hal, T: Transport> VirtIOInput<H, T> {
     /// Create a new VirtIO-Input driver.
     pub fn new(mut transport: T) -> Result<Self> {
         let mut event_buf = Box::new([InputEvent::default(); QUEUE_SIZE]);
@@ -75,12 +75,16 @@ impl<H: Hal, T: Transport> VirtIOInput<'_, H, T> {
         subsel: u8,
         out: &mut [u8],
     ) -> u8 {
-        let mut config_space = self.transport.config_space().cast::<Config>();
-        let config = unsafe { config_space.as_mut() };
-        config.select.write(select as u8);
-        config.subsel.write(subsel);
-        let size = config.size.read();
-        let data = config.data.read();
+        let config = self.transport.config_space().cast::<Config>();
+        let size;
+        let data;
+        // Safe because config points to a valid MMIO region for the config space.
+        unsafe {
+            volwrite!(config, select, select as u8);
+            volwrite!(config, subsel, subsel);
+            size = volread!(config, size);
+            data = volread!(config, data);
+        }
         out[..size as usize].copy_from_slice(&data[..size as usize]);
         size
     }

+ 1 - 0
src/lib.rs

@@ -15,6 +15,7 @@ mod input;
 mod net;
 mod queue;
 mod transport;
+mod volatile;
 
 pub use self::blk::{BlkResp, RespStatus, VirtIOBlk};
 pub use self::console::VirtIOConsole;

+ 12 - 9
src/net.rs

@@ -2,10 +2,10 @@ use core::mem::{size_of, MaybeUninit};
 
 use super::*;
 use crate::transport::Transport;
+use crate::volatile::{volread, ReadOnly, Volatile};
 use bitflags::*;
 use core::hint::spin_loop;
 use log::*;
-use volatile::{ReadOnly, Volatile};
 
 /// The virtio network device is a virtual ethernet card.
 ///
@@ -14,14 +14,14 @@ use volatile::{ReadOnly, Volatile};
 /// Empty buffers are placed in one virtqueue for receiving packets, and
 /// outgoing packets are enqueued into another for transmission in that order.
 /// A third command queue is used to control advanced filtering features.
-pub struct VirtIONet<'a, H: Hal, T: Transport> {
+pub struct VirtIONet<H: Hal, T: Transport> {
     transport: T,
     mac: EthernetAddress,
-    recv_queue: VirtQueue<'a, H>,
-    send_queue: VirtQueue<'a, H>,
+    recv_queue: VirtQueue<H>,
+    send_queue: VirtQueue<H>,
 }
 
-impl<H: Hal, T: Transport> VirtIONet<'_, H, T> {
+impl<H: Hal, T: Transport> VirtIONet<H, T> {
     /// Create a new VirtIO-Net driver.
     pub fn new(mut transport: T) -> Result<Self> {
         transport.begin_init(|features| {
@@ -31,10 +31,13 @@ impl<H: Hal, T: Transport> VirtIONet<'_, H, T> {
             (features & supported_features).bits()
         });
         // read configuration space
-        let config_space = transport.config_space().cast::<Config>();
-        let config = unsafe { config_space.as_ref() };
-        let mac = config.mac.read();
-        debug!("Got MAC={:?}, status={:?}", mac, config.status.read());
+        let config = transport.config_space().cast::<Config>();
+        let mac;
+        // Safe because config points to a valid MMIO region for the config space.
+        unsafe {
+            mac = volread!(config, mac);
+            debug!("Got MAC={:?}, status={:?}", mac, volread!(config, status));
+        }
 
         let queue_num = 2; // for simplicity
         let recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE, queue_num)?;

+ 165 - 112
src/queue.rs

@@ -1,28 +1,26 @@
+#[cfg(test)]
+use core::cmp::min;
 use core::mem::size_of;
-use core::slice;
+use core::ptr::{self, addr_of_mut, NonNull};
 use core::sync::atomic::{fence, Ordering};
-#[cfg(test)]
-use core::{cmp::min, ptr};
 
 use super::*;
 use crate::transport::Transport;
 use bitflags::*;
 
-use volatile::Volatile;
-
 /// The mechanism for bulk data transport on virtio devices.
 ///
 /// Each device can have zero or more virtqueues.
 #[derive(Debug)]
-pub struct VirtQueue<'a, H: Hal> {
+pub struct VirtQueue<H: Hal> {
     /// DMA guard
     dma: DMA<H>,
     /// Descriptor table
-    desc: &'a mut [Descriptor],
+    desc: NonNull<[Descriptor]>,
     /// Available ring
-    avail: &'a mut AvailRing,
+    avail: NonNull<AvailRing>,
     /// Used ring
-    used: &'a mut UsedRing,
+    used: NonNull<UsedRing>,
 
     /// The index of queue
     queue_idx: u32,
@@ -39,7 +37,7 @@ pub struct VirtQueue<'a, H: Hal> {
     last_used_idx: u16,
 }
 
-impl<H: Hal> VirtQueue<'_, H> {
+impl<H: Hal> VirtQueue<H> {
     /// Create a new VirtQueue.
     pub fn new<T: Transport>(transport: &mut T, idx: usize, size: u16) -> Result<Self> {
         if transport.queue_used(idx as u32) {
@@ -60,14 +58,21 @@ impl<H: Hal> VirtQueue<'_, H> {
             dma.paddr() + layout.used_offset,
         );
 
-        let desc =
-            unsafe { slice::from_raw_parts_mut(dma.vaddr() as *mut Descriptor, size as usize) };
-        let avail = unsafe { &mut *((dma.vaddr() + layout.avail_offset) as *mut AvailRing) };
-        let used = unsafe { &mut *((dma.vaddr() + layout.used_offset) as *mut UsedRing) };
+        let desc = NonNull::new(ptr::slice_from_raw_parts_mut(
+            dma.vaddr() as *mut Descriptor,
+            size as usize,
+        ))
+        .unwrap();
+        let avail = NonNull::new((dma.vaddr() + layout.avail_offset) as *mut AvailRing).unwrap();
+        let used = NonNull::new((dma.vaddr() + layout.used_offset) as *mut UsedRing).unwrap();
 
         // Link descriptors together.
         for i in 0..(size - 1) {
-            desc[i as usize].next.write(i + 1);
+            // Safe because `desc` is properly aligned, dereferenceable, initialised, and the device
+            // won't access the descriptors for the duration of this unsafe block.
+            unsafe {
+                (*desc.as_ptr())[i as usize].next = i + 1;
+            }
         }
 
         Ok(VirtQueue {
@@ -98,47 +103,70 @@ impl<H: Hal> VirtQueue<'_, H> {
         // allocate descriptors from free list
         let head = self.free_head;
         let mut last = self.free_head;
-        for input in inputs.iter() {
-            let desc = &mut self.desc[self.free_head as usize];
-            desc.set_buf::<H>(input);
-            desc.flags.write(DescFlags::NEXT);
-            last = self.free_head;
-            self.free_head = desc.next.read();
-        }
-        for output in outputs.iter() {
-            let desc = &mut self.desc[self.free_head as usize];
-            desc.set_buf::<H>(output);
-            desc.flags.write(DescFlags::NEXT | DescFlags::WRITE);
-            last = self.free_head;
-            self.free_head = desc.next.read();
-        }
-        // set last_elem.next = NULL
-        {
-            let desc = &mut self.desc[last as usize];
-            let mut flags = desc.flags.read();
-            flags.remove(DescFlags::NEXT);
-            desc.flags.write(flags);
+
+        // Safe because self.desc is properly aligned, dereferenceable and initialised, and nothing
+        // else reads or writes the free descriptors during this block.
+        unsafe {
+            for input in inputs.iter() {
+                let mut desc = self.desc_ptr(self.free_head);
+                (*desc).set_buf::<H>(input);
+                (*desc).flags = DescFlags::NEXT;
+                last = self.free_head;
+                self.free_head = (*desc).next;
+            }
+            for output in outputs.iter() {
+                let desc = self.desc_ptr(self.free_head);
+                (*desc).set_buf::<H>(output);
+                (*desc).flags = DescFlags::NEXT | DescFlags::WRITE;
+                last = self.free_head;
+                self.free_head = (*desc).next;
+            }
+
+            // set last_elem.next = NULL
+            (*self.desc_ptr(last)).flags.remove(DescFlags::NEXT);
         }
         self.num_used += (inputs.len() + outputs.len()) as u16;
 
         let avail_slot = self.avail_idx & (self.queue_size - 1);
-        self.avail.ring[avail_slot as usize].write(head);
+        // Safe because self.avail is properly aligned, dereferenceable and initialised.
+        unsafe {
+            (*self.avail.as_ptr()).ring[avail_slot as usize] = head;
+        }
 
-        // write barrier
+        // Write barrier so that device sees changes to descriptor table and available ring before
+        // change to available index.
         fence(Ordering::SeqCst);
 
         // increase head of avail ring
         self.avail_idx = self.avail_idx.wrapping_add(1);
-        self.avail.idx.write(self.avail_idx);
+        // Safe because self.avail is properly aligned, dereferenceable and initialised.
+        unsafe {
+            (*self.avail.as_ptr()).idx = self.avail_idx;
+        }
+
+        // Write barrier so that device can see change to available index after this method returns.
+        fence(Ordering::SeqCst);
+
         Ok(head)
     }
 
-    /// Whether there is a used element that can pop.
+    /// Returns a non-null pointer to the descriptor at the given index.
+    fn desc_ptr(&mut self, index: u16) -> *mut Descriptor {
+        // Safe because self.desc is properly aligned and dereferenceable.
+        unsafe { addr_of_mut!((*self.desc.as_ptr())[index as usize]) }
+    }
+
+    /// Returns whether there is a used element that can be popped.
     pub fn can_pop(&self) -> bool {
-        self.last_used_idx != self.used.idx.read()
+        // Read barrier, so we read a fresh value from the device.
+        fence(Ordering::SeqCst);
+
+        // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
+        // instance of UsedRing.
+        self.last_used_idx != unsafe { (*self.used.as_ptr()).idx }
     }
 
-    /// The number of free descriptors.
+    /// Returns the number of free descriptors.
     pub fn available_desc(&self) -> usize {
         (self.queue_size - self.num_used) as usize
     }
@@ -147,17 +175,21 @@ impl<H: Hal> VirtQueue<'_, H> {
     ///
     /// This will push all linked descriptors at the front of the free list.
     fn recycle_descriptors(&mut self, mut head: u16) {
-        let origin_free_head = self.free_head;
+        let original_free_head = self.free_head;
         self.free_head = head;
         loop {
-            let desc = &mut self.desc[head as usize];
-            let flags = desc.flags.read();
-            self.num_used -= 1;
-            if flags.contains(DescFlags::NEXT) {
-                head = desc.next.read();
-            } else {
-                desc.next.write(origin_free_head);
-                return;
+            let desc = self.desc_ptr(head);
+            // Safe because self.desc is properly aligned, dereferenceable and initialised, and
+            // nothing else reads or writes the descriptor during this block.
+            unsafe {
+                let flags = (*desc).flags;
+                self.num_used -= 1;
+                if flags.contains(DescFlags::NEXT) {
+                    head = (*desc).next;
+                } else {
+                    (*desc).next = original_free_head;
+                    return;
+                }
             }
         }
     }
@@ -173,8 +205,14 @@ impl<H: Hal> VirtQueue<'_, H> {
         fence(Ordering::SeqCst);
 
         let last_used_slot = self.last_used_idx & (self.queue_size - 1);
-        let index = self.used.ring[last_used_slot as usize].id.read() as u16;
-        let len = self.used.ring[last_used_slot as usize].len.read();
+        let index;
+        let len;
+        // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
+        // instance of UsedRing.
+        unsafe {
+            index = (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16;
+            len = (*self.used.as_ptr()).ring[last_used_slot as usize].len;
+        }
 
         self.recycle_descriptors(index);
         self.last_used_idx = self.last_used_idx.wrapping_add(1);
@@ -218,17 +256,16 @@ impl VirtQueueLayout {
 #[repr(C, align(16))]
 #[derive(Debug)]
 pub(crate) struct Descriptor {
-    addr: Volatile<u64>,
-    len: Volatile<u32>,
-    flags: Volatile<DescFlags>,
-    next: Volatile<u16>,
+    addr: u64,
+    len: u32,
+    flags: DescFlags,
+    next: u16,
 }
 
 impl Descriptor {
     fn set_buf<H: Hal>(&mut self, buf: &[u8]) {
-        self.addr
-            .write(H::virt_to_phys(buf.as_ptr() as usize) as u64);
-        self.len.write(buf.len() as u32);
+        self.addr = H::virt_to_phys(buf.as_ptr() as usize) as u64;
+        self.len = buf.len() as u32;
     }
 }
 
@@ -247,11 +284,11 @@ bitflags! {
 #[repr(C)]
 #[derive(Debug)]
 struct AvailRing {
-    flags: Volatile<u16>,
+    flags: u16,
     /// A driver MUST NOT decrement the idx.
-    idx: Volatile<u16>,
-    ring: [Volatile<u16>; 32], // actual size: queue_size
-    used_event: Volatile<u16>, // unused
+    idx: u16,
+    ring: [u16; 32], // actual size: queue_size
+    used_event: u16, // unused
 }
 
 /// The used ring is where the device returns buffers once it is done with them:
@@ -259,17 +296,17 @@ struct AvailRing {
 #[repr(C)]
 #[derive(Debug)]
 struct UsedRing {
-    flags: Volatile<u16>,
-    idx: Volatile<u16>,
-    ring: [UsedElem; 32],       // actual size: queue_size
-    avail_event: Volatile<u16>, // unused
+    flags: u16,
+    idx: u16,
+    ring: [UsedElem; 32], // actual size: queue_size
+    avail_event: u16,     // unused
 }
 
 #[repr(C)]
 #[derive(Debug)]
 struct UsedElem {
-    id: Volatile<u32>,
-    len: Volatile<u32>,
+    id: u32,
+    len: u32,
 }
 
 /// Simulates the device writing to a VirtIO queue, for use in tests.
@@ -283,37 +320,38 @@ pub(crate) fn fake_write_to_queue(
     receive_queue_device_area: VirtAddr,
     data: &[u8],
 ) {
-    let descriptors =
-        unsafe { slice::from_raw_parts(receive_queue_descriptors, queue_size as usize) };
+    let descriptors = ptr::slice_from_raw_parts(receive_queue_descriptors, queue_size as usize);
     let available_ring = receive_queue_driver_area as *const AvailRing;
     let used_ring = receive_queue_device_area as *mut UsedRing;
+    // Safe because the various pointers are properly aligned, dereferenceable, initialised, and
+    // nothing else accesses them during this block.
     unsafe {
         // Make sure there is actually at least one descriptor available to write to.
-        assert_ne!((*available_ring).idx.read(), (*used_ring).idx.read());
+        assert_ne!((*available_ring).idx, (*used_ring).idx);
         // The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
         // `used_ring.idx` marks the next descriptor we should take from the available ring.
-        let next_slot = (*used_ring).idx.read() & (queue_size - 1);
-        let head_descriptor_index = (*available_ring).ring[next_slot as usize].read();
-        let mut descriptor = &descriptors[head_descriptor_index as usize];
+        let next_slot = (*used_ring).idx & (queue_size - 1);
+        let head_descriptor_index = (*available_ring).ring[next_slot as usize];
+        let mut descriptor = &(*descriptors)[head_descriptor_index as usize];
 
         // Loop through all descriptors in the chain, writing data to them.
         let mut remaining_data = data;
         loop {
             // Check the buffer and write to it.
-            let flags = descriptor.flags.read();
+            let flags = descriptor.flags;
             assert!(flags.contains(DescFlags::WRITE));
-            let buffer_length = descriptor.len.read() as usize;
+            let buffer_length = descriptor.len as usize;
             let length_to_write = min(remaining_data.len(), buffer_length);
             ptr::copy(
                 remaining_data.as_ptr(),
-                descriptor.addr.read() as *mut u8,
+                descriptor.addr as *mut u8,
                 length_to_write,
             );
             remaining_data = &remaining_data[length_to_write..];
 
             if flags.contains(DescFlags::NEXT) {
-                let next = descriptor.next.read() as usize;
-                descriptor = &descriptors[next];
+                let next = descriptor.next as usize;
+                descriptor = &(*descriptors)[next];
             } else {
                 assert_eq!(remaining_data.len(), 0);
                 break;
@@ -321,13 +359,9 @@ pub(crate) fn fake_write_to_queue(
         }
 
         // Mark the buffer as used.
-        (*used_ring).ring[next_slot as usize]
-            .id
-            .write(head_descriptor_index as u32);
-        (*used_ring).ring[next_slot as usize]
-            .len
-            .write(data.len() as u32);
-        (*used_ring).idx.update(|idx| *idx += 1);
+        (*used_ring).ring[next_slot as usize].id = head_descriptor_index as u32;
+        (*used_ring).ring[next_slot as usize].len = data.len() as u32;
+        (*used_ring).idx += 1;
     }
 }
 
@@ -408,30 +442,49 @@ mod tests {
         assert_eq!(queue.available_desc(), 0);
         assert!(!queue.can_pop());
 
-        let first_descriptor_index = queue.avail.ring[0].read();
-        assert_eq!(first_descriptor_index, token);
-        assert_eq!(queue.desc[first_descriptor_index as usize].len.read(), 2);
-        assert_eq!(
-            queue.desc[first_descriptor_index as usize].flags.read(),
-            DescFlags::NEXT
-        );
-        let second_descriptor_index = queue.desc[first_descriptor_index as usize].next.read();
-        assert_eq!(queue.desc[second_descriptor_index as usize].len.read(), 1);
-        assert_eq!(
-            queue.desc[second_descriptor_index as usize].flags.read(),
-            DescFlags::NEXT
-        );
-        let third_descriptor_index = queue.desc[second_descriptor_index as usize].next.read();
-        assert_eq!(queue.desc[third_descriptor_index as usize].len.read(), 2);
-        assert_eq!(
-            queue.desc[third_descriptor_index as usize].flags.read(),
-            DescFlags::NEXT | DescFlags::WRITE
-        );
-        let fourth_descriptor_index = queue.desc[third_descriptor_index as usize].next.read();
-        assert_eq!(queue.desc[fourth_descriptor_index as usize].len.read(), 1);
-        assert_eq!(
-            queue.desc[fourth_descriptor_index as usize].flags.read(),
-            DescFlags::WRITE
-        );
+        // Safe because the various parts of the queue are properly aligned, dereferenceable and
+        // initialised, and nothing else is accessing them at the same time.
+        unsafe {
+            let first_descriptor_index = (*queue.avail.as_ptr()).ring[0];
+            assert_eq!(first_descriptor_index, token);
+            assert_eq!(
+                (*queue.desc.as_ptr())[first_descriptor_index as usize].len,
+                2
+            );
+            assert_eq!(
+                (*queue.desc.as_ptr())[first_descriptor_index as usize].flags,
+                DescFlags::NEXT
+            );
+            let second_descriptor_index =
+                (*queue.desc.as_ptr())[first_descriptor_index as usize].next;
+            assert_eq!(
+                (*queue.desc.as_ptr())[second_descriptor_index as usize].len,
+                1
+            );
+            assert_eq!(
+                (*queue.desc.as_ptr())[second_descriptor_index as usize].flags,
+                DescFlags::NEXT
+            );
+            let third_descriptor_index =
+                (*queue.desc.as_ptr())[second_descriptor_index as usize].next;
+            assert_eq!(
+                (*queue.desc.as_ptr())[third_descriptor_index as usize].len,
+                2
+            );
+            assert_eq!(
+                (*queue.desc.as_ptr())[third_descriptor_index as usize].flags,
+                DescFlags::NEXT | DescFlags::WRITE
+            );
+            let fourth_descriptor_index =
+                (*queue.desc.as_ptr())[third_descriptor_index as usize].next;
+            assert_eq!(
+                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].len,
+                1
+            );
+            assert_eq!(
+                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].flags,
+                DescFlags::WRITE
+            );
+        }
     }
 }

+ 77 - 62
src/transport/mmio.rs

@@ -1,12 +1,16 @@
 use super::{DeviceStatus, DeviceType, Transport};
-use crate::{align_up, queue::Descriptor, PhysAddr, PAGE_SIZE};
+use crate::{
+    align_up,
+    queue::Descriptor,
+    volatile::{volread, volwrite, ReadOnly, Volatile, WriteOnly},
+    PhysAddr, PAGE_SIZE,
+};
 use core::{
     convert::{TryFrom, TryInto},
     fmt::{self, Display, Formatter},
     mem::size_of,
     ptr::NonNull,
 };
-use volatile::{ReadOnly, Volatile, WriteOnly};
 
 const MAGIC_VALUE: u32 = 0x7472_6976;
 pub(crate) const LEGACY_VERSION: u32 = 1;
@@ -281,14 +285,14 @@ impl MmioTransport {
     /// `header` must point to a properly aligned valid VirtIO MMIO region, which must remain valid
     /// for the lifetime of the transport that is returned.
     pub unsafe fn new(header: NonNull<VirtIOHeader>) -> Result<Self, MmioError> {
-        let magic = header.as_ref().magic.read();
+        let magic = volread!(header, magic);
         if magic != MAGIC_VALUE {
             return Err(MmioError::BadMagic(magic));
         }
-        if header.as_ref().device_id.read() == 0 {
+        if volread!(header, device_id) == 0 {
             return Err(MmioError::ZeroDeviceId);
         }
-        let version = header.as_ref().version.read().try_into()?;
+        let version = volread!(header, version).try_into()?;
         Ok(Self { header, version })
     }
 
@@ -299,61 +303,67 @@ impl MmioTransport {
 
     /// Gets the vendor ID.
     pub fn vendor_id(&self) -> u32 {
-        self.header().vendor_id.read()
-    }
-
-    fn header(&self) -> &VirtIOHeader {
-        unsafe { self.header.as_ref() }
-    }
-
-    fn header_mut(&mut self) -> &mut VirtIOHeader {
-        unsafe { self.header.as_mut() }
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        unsafe { volread!(self.header, vendor_id) }
     }
 }
 
 impl Transport for MmioTransport {
     fn device_type(&self) -> DeviceType {
-        match self.header().device_id.read() {
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        match unsafe { volread!(self.header, device_id) } {
             x @ 1..=13 | x @ 16..=24 => unsafe { core::mem::transmute(x as u8) },
             _ => DeviceType::Invalid,
         }
     }
 
     fn read_device_features(&mut self) -> u64 {
-        let header = self.header_mut();
-        header.device_features_sel.write(0); // device features [0, 32)
-        let mut device_features_bits = header.device_features.read().into();
-        header.device_features_sel.write(1); // device features [32, 64)
-        device_features_bits += (header.device_features.read() as u64) << 32;
-        device_features_bits
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        unsafe {
+            volwrite!(self.header, device_features_sel, 0); // device features [0, 32)
+            let mut device_features_bits = volread!(self.header, device_features).into();
+            volwrite!(self.header, device_features_sel, 1); // device features [32, 64)
+            device_features_bits += (volread!(self.header, device_features) as u64) << 32;
+            device_features_bits
+        }
     }
 
     fn write_driver_features(&mut self, driver_features: u64) {
-        let header = self.header_mut();
-        header.driver_features_sel.write(0); // driver features [0, 32)
-        header.driver_features.write(driver_features as u32);
-        header.driver_features_sel.write(1); // driver features [32, 64)
-        header.driver_features.write((driver_features >> 32) as u32);
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        unsafe {
+            volwrite!(self.header, driver_features_sel, 0); // driver features [0, 32)
+            volwrite!(self.header, driver_features, driver_features as u32);
+            volwrite!(self.header, driver_features_sel, 1); // driver features [32, 64)
+            volwrite!(self.header, driver_features, (driver_features >> 32) as u32);
+        }
     }
 
     fn max_queue_size(&self) -> u32 {
-        self.header().queue_num_max.read()
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        unsafe { volread!(self.header, queue_num_max) }
     }
 
     fn notify(&mut self, queue: u32) {
-        self.header_mut().queue_notify.write(queue);
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        unsafe {
+            volwrite!(self.header, queue_notify, queue);
+        }
     }
 
     fn set_status(&mut self, status: DeviceStatus) {
-        self.header_mut().status.write(status);
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        unsafe {
+            volwrite!(self.header, status, status);
+        }
     }
 
     fn set_guest_page_size(&mut self, guest_page_size: u32) {
         match self.version {
             MmioVersion::Legacy => {
-                self.header_mut()
-                    .legacy_guest_page_size
-                    .write(guest_page_size);
+                // Safe because self.header points to a valid VirtIO MMIO region.
+                unsafe {
+                    volwrite!(self.header, legacy_guest_page_size, guest_page_size);
+                }
             }
             MmioVersion::Modern => {
                 // No-op, modern devices don't care.
@@ -385,47 +395,52 @@ impl Transport for MmioTransport {
                 let align = PAGE_SIZE as u32;
                 let pfn = (descriptors / PAGE_SIZE) as u32;
                 assert_eq!(pfn as usize * PAGE_SIZE, descriptors);
-                self.header_mut().queue_sel.write(queue);
-                self.header_mut().queue_num.write(size);
-                self.header_mut().legacy_queue_align.write(align);
-                self.header_mut().legacy_queue_pfn.write(pfn);
+                // Safe because self.header points to a valid VirtIO MMIO region.
+                unsafe {
+                    volwrite!(self.header, queue_sel, queue);
+                    volwrite!(self.header, queue_num, size);
+                    volwrite!(self.header, legacy_queue_align, align);
+                    volwrite!(self.header, legacy_queue_pfn, pfn);
+                }
             }
             MmioVersion::Modern => {
-                self.header_mut().queue_sel.write(queue);
-                self.header_mut().queue_num.write(size);
-                self.header_mut().queue_desc_low.write(descriptors as u32);
-                self.header_mut()
-                    .queue_desc_high
-                    .write((descriptors >> 32) as u32);
-                self.header_mut().queue_driver_low.write(driver_area as u32);
-                self.header_mut()
-                    .queue_driver_high
-                    .write((driver_area >> 32) as u32);
-                self.header_mut().queue_device_low.write(device_area as u32);
-                self.header_mut()
-                    .queue_device_high
-                    .write((device_area >> 32) as u32);
-                self.header_mut().queue_ready.write(1);
+                // Safe because self.header points to a valid VirtIO MMIO region.
+                unsafe {
+                    volwrite!(self.header, queue_sel, queue);
+                    volwrite!(self.header, queue_num, size);
+                    volwrite!(self.header, queue_desc_low, descriptors as u32);
+                    volwrite!(self.header, queue_desc_high, (descriptors >> 32) as u32);
+                    volwrite!(self.header, queue_driver_low, driver_area as u32);
+                    volwrite!(self.header, queue_driver_high, (driver_area >> 32) as u32);
+                    volwrite!(self.header, queue_device_low, device_area as u32);
+                    volwrite!(self.header, queue_device_high, (device_area >> 32) as u32);
+                    volwrite!(self.header, queue_ready, 1);
+                }
             }
         }
     }
 
     fn queue_used(&mut self, queue: u32) -> bool {
-        self.header_mut().queue_sel.write(queue);
-        match self.version {
-            MmioVersion::Legacy => self.header().legacy_queue_pfn.read() != 0,
-            MmioVersion::Modern => self.header().queue_ready.read() != 0,
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        unsafe {
+            volwrite!(self.header, queue_sel, queue);
+            match self.version {
+                MmioVersion::Legacy => volread!(self.header, legacy_queue_pfn) != 0,
+                MmioVersion::Modern => volread!(self.header, queue_ready) != 0,
+            }
         }
     }
 
     fn ack_interrupt(&mut self) -> bool {
-        let header = self.header_mut();
-        let interrupt = header.interrupt_status.read();
-        if interrupt != 0 {
-            header.interrupt_ack.write(interrupt);
-            true
-        } else {
-            false
+        // Safe because self.header points to a valid VirtIO MMIO region.
+        unsafe {
+            let interrupt = volread!(self.header, interrupt_status);
+            if interrupt != 0 {
+                volwrite!(self.header, interrupt_ack, interrupt);
+                true
+            } else {
+                false
+            }
         }
     }
 

+ 108 - 0
src/volatile.rs

@@ -0,0 +1,108 @@
+/// An MMIO register which can only be read from.
+#[derive(Debug, Default)]
+#[repr(transparent)]
+pub struct ReadOnly<T: Copy>(T);
+
+impl<T: Copy> ReadOnly<T> {
+    /// Construct a new instance for testing.
+    pub fn new(value: T) -> Self {
+        Self(value)
+    }
+}
+
+/// An MMIO register which can only be written to.
+#[derive(Debug, Default)]
+#[repr(transparent)]
+pub struct WriteOnly<T: Copy>(T);
+
+/// An MMIO register which may be both read and written.
+#[derive(Debug, Default)]
+#[repr(transparent)]
+pub struct Volatile<T: Copy>(T);
+
+impl<T: Copy> Volatile<T> {
+    /// Construct a new instance for testing.
+    pub fn new(value: T) -> Self {
+        Self(value)
+    }
+}
+
+/// A trait implemented by MMIO registers which may be read from.
+pub trait VolatileReadable<T> {
+    /// Performs a volatile read from the MMIO register.
+    unsafe fn vread(self) -> T;
+}
+
+impl<T: Copy> VolatileReadable<T> for *const ReadOnly<T> {
+    unsafe fn vread(self) -> T {
+        self.read_volatile().0
+    }
+}
+
+impl<T: Copy> VolatileReadable<T> for *const Volatile<T> {
+    unsafe fn vread(self) -> T {
+        self.read_volatile().0
+    }
+}
+
+/// A trait implemented by MMIO registers which may be written to.
+pub trait VolatileWritable<T> {
+    /// Performs a volatile write to the MMIO register.
+    unsafe fn vwrite(self, value: T);
+}
+
+impl<T: Copy> VolatileWritable<T> for *mut WriteOnly<T> {
+    unsafe fn vwrite(self, value: T) {
+        (self as *mut T).write_volatile(value)
+    }
+}
+
+impl<T: Copy> VolatileWritable<T> for *mut Volatile<T> {
+    unsafe fn vwrite(self, value: T) {
+        (self as *mut T).write_volatile(value)
+    }
+}
+
+/// Performs a volatile read from the given field of a pointer to a struct representing an MMIO region.
+///
+/// # Usage
+/// ```compile_fail
+/// # use core::ptr::NonNull;
+/// # use virtio_drivers::volatile::{ReadOnly, volread};
+/// struct MmioDevice {
+///   field: ReadOnly<u32>,
+/// }
+///
+/// let device: NonNull<MmioDevice> = NonNull::new(0x1234 as *mut MmioDevice).unwrap();
+/// let value = unsafe { volread!(device, field) };
+/// ```
+macro_rules! volread {
+    ($nonnull:expr, $field:ident) => {
+        $crate::volatile::VolatileReadable::vread(core::ptr::addr_of!((*$nonnull.as_ptr()).$field))
+    };
+}
+
+/// Performs a volatile write to the given field of a pointer to a struct representing an MMIO region.
+///
+/// # Usage
+/// ```compile_fail
+/// # use core::ptr::NonNull;
+/// # use virtio_drivers::volatile::{WriteOnly, volwrite};
+/// struct MmioDevice {
+///   field: WriteOnly<u32>,
+/// }
+///
+/// let device: NonNull<MmioDevice> = NonNull::new(0x1234 as *mut MmioDevice).unwrap();
+/// unsafe { volwrite!(device, field, 42); }
+/// ```
+macro_rules! volwrite {
+    ($nonnull:expr, $field:ident, $value:expr) => {
+        $crate::volatile::VolatileWritable::vwrite(
+            core::ptr::addr_of_mut!((*$nonnull.as_ptr()).$field),
+            $value,
+        )
+    };
+}
+
+pub(crate) use volread;
+pub(crate) use volwrite;
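
The soundness point behind this module: creating a & or &mut reference to a device register asserts properties (dereferenceable, free of concurrent mutation) that MMIO memory does not satisfy, so the macros project to a field with addr_of!/addr_of_mut! and never materialise a reference, then access it with read_volatile/write_volatile. A standalone sketch of the same pattern, using a made-up register block rather than any type from this crate:

    use core::ptr::{addr_of, addr_of_mut, NonNull};

    // Hypothetical register block, for illustration only.
    #[repr(C)]
    struct Regs {
        status: u32,
        command: u32,
    }

    // Sound only if `regs` really points at a mapped MMIO register block.
    unsafe fn poke(regs: NonNull<Regs>) {
        // Field projection via raw pointers: no &Regs is ever created, and the
        // volatile accesses cannot be elided or merged by the compiler.
        let status = addr_of!((*regs.as_ptr()).status).read_volatile();
        addr_of_mut!((*regs.as_ptr()).command).write_volatile(status | 1);
    }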