Browse Source

Merge pull request #106 from rcore-os/event_idx

Add support for VIRTIO_F_EVENT_IDX buffer notification suppression.
Andrew Walbran 1 year ago
parent
commit
cfb8a80cb6
9 changed files with 220 additions and 77 deletions
  1. 1 1
      README.md
  2. 4 10
      src/device/blk.rs
  3. 14 9
      src/device/console.rs
  4. 14 8
      src/device/gpu.rs
  5. 15 10
      src/device/input.rs
  6. 17 9
      src/device/net.rs
  7. 20 10
      src/device/socket/vsock.rs
  8. 119 15
      src/queue.rs
  9. 16 5
      src/transport/mod.rs

+ 1 - 1
README.md

@@ -33,7 +33,7 @@ VirtIO guest drivers in Rust. For **no_std** environment.
 | Feature flag                 | Supported |                                         |
 | ---------------------------- | --------- | --------------------------------------- |
 | `VIRTIO_F_INDIRECT_DESC`     | ✅        | Indirect descriptors                    |
-| `VIRTIO_F_EVENT_IDX`         |         | `avail_event` and `used_event` fields   |
+| `VIRTIO_F_EVENT_IDX`         |         | `avail_event` and `used_event` fields   |
 | `VIRTIO_F_VERSION_1`         | TODO      | VirtIO version 1 compliance             |
 | `VIRTIO_F_ACCESS_PLATFORM`   | ❌        | Limited device access to memory         |
 | `VIRTIO_F_RING_PACKED`       | ❌        | Packed virtqueue layout                 |

+ 4 - 10
src/device/blk.rs

@@ -13,7 +13,8 @@ const QUEUE: u16 = 0;
 const QUEUE_SIZE: u16 = 16;
 const SUPPORTED_FEATURES: BlkFeature = BlkFeature::RO
     .union(BlkFeature::FLUSH)
-    .union(BlkFeature::RING_INDIRECT_DESC);
+    .union(BlkFeature::RING_INDIRECT_DESC)
+    .union(BlkFeature::RING_EVENT_IDX);
 
 /// Driver for a VirtIO block device.
 ///
@@ -51,15 +52,7 @@ pub struct VirtIOBlk<H: Hal, T: Transport> {
 impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
     /// Create a new VirtIO-Blk driver.
     pub fn new(mut transport: T) -> Result<Self> {
-        let mut negotiated_features = BlkFeature::empty();
-
-        transport.begin_init(|features| {
-            let features = BlkFeature::from_bits_truncate(features);
-            info!("device features: {:?}", features);
-            negotiated_features = features & SUPPORTED_FEATURES;
-            // Negotiate these features only.
-            negotiated_features.bits()
-        });
+        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
 
         // Read configuration space.
         let config = transport.config_space::<BlkConfig>()?;
@@ -74,6 +67,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             &mut transport,
             QUEUE,
             negotiated_features.contains(BlkFeature::RING_INDIRECT_DESC),
+            negotiated_features.contains(BlkFeature::RING_EVENT_IDX),
         )?;
         transport.finish_init();
 

+ 14 - 9
src/device/console.rs

@@ -8,11 +8,11 @@ use crate::{Result, PAGE_SIZE};
 use alloc::boxed::Box;
 use bitflags::bitflags;
 use core::ptr::NonNull;
-use log::info;
 
 const QUEUE_RECEIVEQ_PORT_0: u16 = 0;
 const QUEUE_TRANSMITQ_PORT_0: u16 = 1;
 const QUEUE_SIZE: usize = 2;
+const SUPPORTED_FEATURES: Features = Features::RING_EVENT_IDX;
 
 /// Driver for a VirtIO console device.
 ///
@@ -65,15 +65,20 @@ pub struct ConsoleInfo {
 impl<H: Hal, T: Transport> VirtIOConsole<H, T> {
     /// Creates a new VirtIO console driver.
     pub fn new(mut transport: T) -> Result<Self> {
-        transport.begin_init(|features| {
-            let features = Features::from_bits_truncate(features);
-            info!("Device features {:?}", features);
-            let supported_features = Features::empty();
-            (features & supported_features).bits()
-        });
+        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
         let config_space = transport.config_space::<Config>()?;
-        let receiveq = VirtQueue::new(&mut transport, QUEUE_RECEIVEQ_PORT_0, false)?;
-        let transmitq = VirtQueue::new(&mut transport, QUEUE_TRANSMITQ_PORT_0, false)?;
+        let receiveq = VirtQueue::new(
+            &mut transport,
+            QUEUE_RECEIVEQ_PORT_0,
+            false,
+            negotiated_features.contains(Features::RING_EVENT_IDX),
+        )?;
+        let transmitq = VirtQueue::new(
+            &mut transport,
+            QUEUE_TRANSMITQ_PORT_0,
+            false,
+            negotiated_features.contains(Features::RING_EVENT_IDX),
+        )?;
 
         // Safe because no alignment or initialisation is required for [u8], the DMA buffer is
         // dereferenceable, and the lifetime of the reference matches the lifetime of the DMA buffer

+ 14 - 8
src/device/gpu.rs

@@ -11,6 +11,7 @@ use log::info;
 use zerocopy::{AsBytes, FromBytes};
 
 const QUEUE_SIZE: u16 = 2;
+const SUPPORTED_FEATURES: Features = Features::RING_EVENT_IDX;
 
 /// A virtio based graphics adapter.
 ///
@@ -39,12 +40,7 @@ pub struct VirtIOGpu<H: Hal, T: Transport> {
 impl<H: Hal, T: Transport> VirtIOGpu<H, T> {
     /// Create a new VirtIO-Gpu driver.
     pub fn new(mut transport: T) -> Result<Self> {
-        transport.begin_init(|features| {
-            let features = Features::from_bits_truncate(features);
-            info!("Device features {:?}", features);
-            let supported_features = Features::empty();
-            (features & supported_features).bits()
-        });
+        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
 
         // read configuration space
         let config_space = transport.config_space::<Config>()?;
@@ -57,8 +53,18 @@ impl<H: Hal, T: Transport> VirtIOGpu<H, T> {
             );
         }
 
-        let control_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, false)?;
-        let cursor_queue = VirtQueue::new(&mut transport, QUEUE_CURSOR, false)?;
+        let control_queue = VirtQueue::new(
+            &mut transport,
+            QUEUE_TRANSMIT,
+            false,
+            negotiated_features.contains(Features::RING_EVENT_IDX),
+        )?;
+        let cursor_queue = VirtQueue::new(
+            &mut transport,
+            QUEUE_CURSOR,
+            false,
+            negotiated_features.contains(Features::RING_EVENT_IDX),
+        )?;
 
         let queue_buf_send = FromBytes::new_box_slice_zeroed(PAGE_SIZE);
         let queue_buf_recv = FromBytes::new_box_slice_zeroed(PAGE_SIZE);

+ 15 - 10
src/device/input.rs

@@ -8,7 +8,6 @@ use crate::volatile::{volread, volwrite, ReadOnly, WriteOnly};
 use crate::Result;
 use alloc::boxed::Box;
 use core::ptr::NonNull;
-use log::info;
 use zerocopy::{AsBytes, FromBytes};
 
 /// Virtual human interface devices such as keyboards, mice and tablets.
@@ -28,18 +27,23 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
     /// Create a new VirtIO-Input driver.
     pub fn new(mut transport: T) -> Result<Self> {
         let mut event_buf = Box::new([InputEvent::default(); QUEUE_SIZE]);
-        transport.begin_init(|features| {
-            let features = Feature::from_bits_truncate(features);
-            info!("Device features: {:?}", features);
-            // negotiate these flags only
-            let supported_features = Feature::empty();
-            (features & supported_features).bits()
-        });
+
+        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
 
         let config = transport.config_space::<Config>()?;
 
-        let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT, false)?;
-        let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, false)?;
+        let mut event_queue = VirtQueue::new(
+            &mut transport,
+            QUEUE_EVENT,
+            false,
+            negotiated_features.contains(Feature::RING_EVENT_IDX),
+        )?;
+        let status_queue = VirtQueue::new(
+            &mut transport,
+            QUEUE_STATUS,
+            false,
+            negotiated_features.contains(Feature::RING_EVENT_IDX),
+        )?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
             // Safe because the buffer lasts as long as the queue.
             let token = unsafe { event_queue.add(&[], &mut [event.as_bytes_mut()])? };
@@ -193,6 +197,7 @@ pub struct InputEvent {
 
 const QUEUE_EVENT: u16 = 0;
 const QUEUE_STATUS: u16 = 1;
+const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX;
 
 // a parameter that can change
 const QUEUE_SIZE: usize = 32;

+ 17 - 9
src/device/net.rs

@@ -8,7 +8,7 @@ use crate::{Error, Result};
 use alloc::{vec, vec::Vec};
 use bitflags::bitflags;
 use core::{convert::TryInto, mem::size_of};
-use log::{debug, info, warn};
+use log::{debug, warn};
 use zerocopy::{AsBytes, FromBytes};
 
 const MAX_BUFFER_LEN: usize = 65535;
@@ -112,12 +112,7 @@ pub struct VirtIONet<H: Hal, T: Transport, const QUEUE_SIZE: usize> {
 impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> {
     /// Create a new VirtIO-Net driver.
     pub fn new(mut transport: T, buf_len: usize) -> Result<Self> {
-        transport.begin_init(|features| {
-            let features = Features::from_bits_truncate(features);
-            info!("Device features {:?}", features);
-            let supported_features = Features::MAC | Features::STATUS;
-            (features & supported_features).bits()
-        });
+        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
         // read configuration space
         let config = transport.config_space::<Config>()?;
         let mac;
@@ -139,8 +134,18 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE>
             return Err(Error::InvalidParam);
         }
 
-        let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT, false)?;
-        let mut recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE, false)?;
+        let send_queue = VirtQueue::new(
+            &mut transport,
+            QUEUE_TRANSMIT,
+            false,
+            negotiated_features.contains(Features::RING_EVENT_IDX),
+        )?;
+        let mut recv_queue = VirtQueue::new(
+            &mut transport,
+            QUEUE_RECEIVE,
+            false,
+            negotiated_features.contains(Features::RING_EVENT_IDX),
+        )?;
 
         const NONE_BUF: Option<RxBuffer> = None;
         let mut rx_buffers = [NONE_BUF; QUEUE_SIZE];
@@ -403,3 +408,6 @@ impl GsoType {
 
 const QUEUE_RECEIVE: u16 = 0;
 const QUEUE_TRANSMIT: u16 = 1;
+const SUPPORTED_FEATURES: Features = Features::MAC
+    .union(Features::STATUS)
+    .union(Features::RING_EVENT_IDX);

+ 20 - 10
src/device/socket/vsock.rs

@@ -19,6 +19,7 @@ pub(crate) const TX_QUEUE_IDX: u16 = 1;
 const EVENT_QUEUE_IDX: u16 = 2;
 
 pub(crate) const QUEUE_SIZE: usize = 8;
+const SUPPORTED_FEATURES: Feature = Feature::RING_EVENT_IDX;
 
 /// The size in bytes of each buffer used in the RX virtqueue. This must be bigger than size_of::<VirtioVsockHdr>().
 const RX_BUFFER_SIZE: usize = 512;
@@ -241,13 +242,7 @@ impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
 impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     /// Create a new VirtIO Vsock driver.
     pub fn new(mut transport: T) -> Result<Self> {
-        transport.begin_init(|features| {
-            let features = Feature::from_bits_truncate(features);
-            debug!("Device features: {:?}", features);
-            // negotiate these flags only
-            let supported_features = Feature::empty();
-            (features & supported_features).bits()
-        });
+        let negotiated_features = transport.begin_init(SUPPORTED_FEATURES);
 
         let config = transport.config_space::<VirtioVsockConfig>()?;
         debug!("config: {:?}", config);
@@ -257,9 +252,24 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         };
         debug!("guest cid: {guest_cid:?}");
 
-        let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX, false)?;
-        let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX, false)?;
-        let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX, false)?;
+        let mut rx = VirtQueue::new(
+            &mut transport,
+            RX_QUEUE_IDX,
+            false,
+            negotiated_features.contains(Feature::RING_EVENT_IDX),
+        )?;
+        let tx = VirtQueue::new(
+            &mut transport,
+            TX_QUEUE_IDX,
+            false,
+            negotiated_features.contains(Feature::RING_EVENT_IDX),
+        )?;
+        let event = VirtQueue::new(
+            &mut transport,
+            EVENT_QUEUE_IDX,
+            false,
+            negotiated_features.contains(Feature::RING_EVENT_IDX),
+        )?;
 
         // Allocate and add buffers for the RX queue.
         let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];

+ 119 - 15
src/queue.rs

@@ -52,6 +52,8 @@ pub struct VirtQueue<H: Hal, const SIZE: usize> {
     /// Our trusted copy of `avail.idx`.
     avail_idx: u16,
     last_used_idx: u16,
+    /// Whether the `VIRTIO_F_EVENT_IDX` feature has been negotiated.
+    event_idx: bool,
     #[cfg(feature = "alloc")]
     indirect: bool,
     #[cfg(feature = "alloc")]
@@ -59,8 +61,19 @@ pub struct VirtQueue<H: Hal, const SIZE: usize> {
 }
 
 impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
-    /// Create a new VirtQueue.
-    pub fn new<T: Transport>(transport: &mut T, idx: u16, indirect: bool) -> Result<Self> {
+    /// Creates a new VirtQueue.
+    ///
+    /// * `indirect`: Whether to use indirect descriptors. This should be set if the
+    ///   `VIRTIO_F_INDIRECT_DESC` feature has been negotiated with the device.
+    /// * `event_idx`: Whether to use the `used_event` and `avail_event` fields for notification
+    ///   suppression. This should be set if the `VIRTIO_F_EVENT_IDX` feature has been negotiated
+    ///   with the device.
+    pub fn new<T: Transport>(
+        transport: &mut T,
+        idx: u16,
+        indirect: bool,
+        event_idx: bool,
+    ) -> Result<Self> {
         if transport.queue_used(idx) {
             return Err(Error::AlreadyUsed);
         }
@@ -115,6 +128,7 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
             desc_shadow,
             avail_idx: 0,
             last_used_idx: 0,
+            event_idx,
             #[cfg(feature = "alloc")]
             indirect,
             #[cfg(feature = "alloc")]
@@ -310,9 +324,16 @@ impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
         // Read barrier, so we read a fresh value from the device.
         fence(Ordering::SeqCst);
 
-        // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
-        // instance of UsedRing.
-        unsafe { (*self.used.as_ptr()).flags & 0x0001 == 0 }
+        if self.event_idx {
+            // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
+            // instance of UsedRing.
+            let avail_event = unsafe { (*self.used.as_ptr()).avail_event };
+            self.avail_idx >= avail_event.wrapping_add(1)
+        } else {
+            // Safe because self.used points to a valid, aligned, initialised, dereferenceable, readable
+            // instance of UsedRing.
+            unsafe { (*self.used.as_ptr()).flags & 0x0001 == 0 }
+        }
     }
 
     /// Copies the descriptor at the given index from `desc_shadow` to `desc`, so it can be seen by
@@ -735,7 +756,8 @@ struct UsedRing<const SIZE: usize> {
     flags: u16,
     idx: u16,
     ring: [UsedElem; SIZE],
-    avail_event: u16, // unused
+    /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
+    avail_event: u16,
 }
 
 #[repr(C)]
@@ -917,10 +939,16 @@ pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
 mod tests {
     use super::*;
     use crate::{
+        device::common::Feature,
         hal::fake::FakeHal,
-        transport::mmio::{MmioTransport, VirtIOHeader, MODERN_VERSION},
+        transport::{
+            fake::{FakeTransport, QueueStatus, State},
+            mmio::{MmioTransport, VirtIOHeader, MODERN_VERSION},
+            DeviceStatus, DeviceType,
+        },
     };
     use core::ptr::NonNull;
+    use std::sync::{Arc, Mutex};
 
     #[test]
     fn invalid_queue_size() {
@@ -928,7 +956,7 @@ mod tests {
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         // Size not a power of 2.
         assert_eq!(
-            VirtQueue::<FakeHal, 3>::new(&mut transport, 0, false).unwrap_err(),
+            VirtQueue::<FakeHal, 3>::new(&mut transport, 0, false, false).unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -938,7 +966,7 @@ mod tests {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         assert_eq!(
-            VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false).unwrap_err(),
+            VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false, false).unwrap_err(),
             Error::InvalidParam
         );
     }
@@ -947,9 +975,9 @@ mod tests {
     fn queue_already_used() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap();
+        VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
         assert_eq!(
-            VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap_err(),
+            VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap_err(),
             Error::AlreadyUsed
         );
     }
@@ -958,7 +986,7 @@ mod tests {
     fn add_empty() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
         assert_eq!(
             unsafe { queue.add(&[], &mut []) }.unwrap_err(),
             Error::InvalidParam
@@ -969,7 +997,7 @@ mod tests {
     fn add_too_many() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
             unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
@@ -981,7 +1009,7 @@ mod tests {
     fn add_buffers() {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
         assert_eq!(queue.available_desc(), 4);
 
         // Add a buffer chain consisting of two device-readable parts followed by two
@@ -1044,7 +1072,7 @@ mod tests {
 
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
-        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true).unwrap();
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true, false).unwrap();
         assert_eq!(queue.available_desc(), 4);
 
         // Add a buffer chain consisting of two device-readable parts followed by two
@@ -1089,4 +1117,80 @@ mod tests {
             assert_eq!((*indirect_descriptors)[3].flags, DescFlags::WRITE);
         }
     }
+
+    /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
+    /// notifications.
+    #[test]
+    fn add_notify() {
+        let mut config_space = ();
+        let state = Arc::new(Mutex::new(State {
+            queues: vec![QueueStatus::default()],
+            ..Default::default()
+        }));
+        let mut transport = FakeTransport {
+            device_type: DeviceType::Block,
+            max_queue_size: 4,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
+
+        // Add a buffer chain with a single device-readable part.
+        unsafe { queue.add(&[&[42]], &mut []) }.unwrap();
+
+        // Check that the transport would be notified.
+        assert_eq!(queue.should_notify(), true);
+
+        // SAFETY: the various parts of the queue are properly aligned, dereferenceable and
+        // initialised, and nothing else is accessing them at the same time.
+        unsafe {
+            // Suppress notifications.
+            (*queue.used.as_ptr()).flags = 0x01;
+        }
+
+        // Check that the transport would not be notified.
+        assert_eq!(queue.should_notify(), false);
+    }
+
+    /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
+    /// notifications with the `avail_event` index.
+    #[test]
+    fn add_notify_event_idx() {
+        let mut config_space = ();
+        let state = Arc::new(Mutex::new(State {
+            queues: vec![QueueStatus::default()],
+            ..Default::default()
+        }));
+        let mut transport = FakeTransport {
+            device_type: DeviceType::Block,
+            max_queue_size: 4,
+            device_features: Feature::RING_EVENT_IDX.bits(),
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, true).unwrap();
+
+        // Add a buffer chain with a single device-readable part.
+        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 0);
+
+        // Check that the transport would be notified.
+        assert_eq!(queue.should_notify(), true);
+
+        // SAFETY: the various parts of the queue are properly aligned, dereferenceable and
+        // initialised, and nothing else is accessing them at the same time.
+        unsafe {
+            // Suppress notifications.
+            (*queue.used.as_ptr()).avail_event = 1;
+        }
+
+        // Check that the transport would not be notified.
+        assert_eq!(queue.should_notify(), false);
+
+        // Add another buffer chain.
+        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 1);
+
+        // Check that the transport should be notified again now.
+        assert_eq!(queue.should_notify(), true);
+    }
 }

+ 16 - 5
src/transport/mod.rs

@@ -6,8 +6,9 @@ pub mod mmio;
 pub mod pci;
 
 use crate::{PhysAddr, Result, PAGE_SIZE};
-use bitflags::bitflags;
-use core::ptr::NonNull;
+use bitflags::{bitflags, Flags};
+use core::{fmt::Debug, ops::BitAnd, ptr::NonNull};
+use log::debug;
 
 /// A VirtIO transport layer.
 pub trait Transport {
@@ -64,17 +65,27 @@ pub trait Transport {
     /// Begins initializing the device.
     ///
     /// Ref: virtio 3.1.1 Device Initialization
-    fn begin_init(&mut self, negotiate_features: impl FnOnce(u64) -> u64) {
+    ///
+    /// Returns the negotiated set of features.
+    fn begin_init<F: Flags<Bits = u64> + BitAnd<Output = F> + Debug>(
+        &mut self,
+        supported_features: F,
+    ) -> F {
         self.set_status(DeviceStatus::empty());
         self.set_status(DeviceStatus::ACKNOWLEDGE | DeviceStatus::DRIVER);
 
-        let features = self.read_device_features();
-        self.write_driver_features(negotiate_features(features));
+        let device_features = F::from_bits_truncate(self.read_device_features());
+        debug!("Device features: {:?}", device_features);
+        let negotiated_features = device_features & supported_features;
+        self.write_driver_features(negotiated_features.bits());
+
         self.set_status(
             DeviceStatus::ACKNOWLEDGE | DeviceStatus::DRIVER | DeviceStatus::FEATURES_OK,
         );
 
         self.set_guest_page_size(PAGE_SIZE as u32);
+
+        negotiated_features
     }
 
     /// Finishes initializing the device.