فهرست منبع

Have transport wrap non-null pointer.

Andrew Walbran 2 سال پیش
والد
کامیت
c8c26e3905
2فایلهای تغییر یافته به همراه74 افزوده شده و 58 حذف شده
  1. 21 14
      src/queue.rs
  2. 53 44
      src/transport/mmio.rs

+ 21 - 14
src/queue.rs

@@ -273,48 +273,54 @@ struct UsedElem {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::hal::fake::FakeHal;
+    use crate::{hal::fake::FakeHal, transport::mmio::MODERN_VERSION};
+    use core::ptr::NonNull;
 
     #[test]
     fn invalid_queue_size() {
-        let mut header = MmioTransport::make_fake_header(0, 0, 0, 4);
+        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 0, 0, 0, 4);
+        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) };
         // Size not a power of 2.
         assert_eq!(
-            VirtQueue::<FakeHal>::new(&mut header, 0, 3).unwrap_err(),
+            VirtQueue::<FakeHal>::new(&mut transport, 0, 3).unwrap_err(),
             Error::InvalidParam
         );
     }
 
     #[test]
     fn queue_too_big() {
-        let mut header = MmioTransport::make_fake_header(0, 0, 0, 4);
+        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 0, 0, 0, 4);
+        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) };
         assert_eq!(
-            VirtQueue::<FakeHal>::new(&mut header, 0, 5).unwrap_err(),
+            VirtQueue::<FakeHal>::new(&mut transport, 0, 5).unwrap_err(),
             Error::InvalidParam
         );
     }
 
     #[test]
     fn queue_already_used() {
-        let mut header = MmioTransport::make_fake_header(0, 0, 0, 4);
-        VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap();
+        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 0, 0, 0, 4);
+        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) };
+        VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
         assert_eq!(
-            VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap_err(),
+            VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap_err(),
             Error::AlreadyUsed
         );
     }
 
     #[test]
     fn add_empty() {
-        let mut header = MmioTransport::make_fake_header(0, 0, 0, 4);
-        let mut queue = VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap();
+        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 0, 0, 0, 4);
+        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) };
+        let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
         assert_eq!(queue.add(&[], &[]).unwrap_err(), Error::InvalidParam);
     }
 
     #[test]
     fn add_too_big() {
-        let mut header = MmioTransport::make_fake_header(0, 0, 0, 4);
-        let mut queue = VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap();
+        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 0, 0, 0, 4);
+        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) };
+        let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
             queue
@@ -326,8 +332,9 @@ mod tests {
 
     #[test]
     fn add_buffers() {
-        let mut header = MmioTransport::make_fake_header(0, 0, 0, 4);
-        let mut queue = VirtQueue::<FakeHal>::new(&mut header, 0, 4).unwrap();
+        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 0, 0, 0, 4);
+        let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) };
+        let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
         assert_eq!(queue.size(), 4);
         assert_eq!(queue.available_desc(), 4);
 

+ 53 - 44
src/transport/mmio.rs

@@ -1,11 +1,11 @@
 use super::{DeviceStatus, DeviceType, Transport};
 use crate::{align_up, queue::Descriptor, PhysAddr, PAGE_SIZE};
-use core::mem::size_of;
+use core::{mem::size_of, ptr::NonNull};
 use volatile::{ReadOnly, Volatile, WriteOnly};
 
 const MAGIC_VALUE: u32 = 0x7472_6976;
-const LEGACY_VERSION: u32 = 1;
-const MODERN_VERSION: u32 = 2;
+pub(crate) const LEGACY_VERSION: u32 = 1;
+pub(crate) const MODERN_VERSION: u32 = 2;
 const CONFIG_SPACE_OFFSET: usize = 0x100;
 
 /// MMIO Device Register Interface, both legacy and modern.
@@ -213,7 +213,7 @@ impl VirtIOHeader {
 
     /// Constructs a fake virtio header for use in unit tests.
     #[cfg(test)]
-    fn make_fake_header(
+    pub fn make_fake_header(
         version: u32,
         device_id: u32,
         vendor_id: u32,
@@ -369,20 +369,31 @@ impl Transport for LegacyMmioTransport {
 /// MMIO Device Register Interface.
 ///
 /// Ref: 4.2.2 MMIO Device Register Layout
-#[repr(transparent)]
-pub struct MmioTransport(VirtIOHeader);
+#[derive(Debug)]
+pub struct MmioTransport {
+    header: NonNull<VirtIOHeader>,
+}
 
 impl MmioTransport {
+    /// Constructs a new modern VirtIO MMIO transport.
+    ///
+    /// # Safety
+    /// `header` must point to a properly aligned valid modern VirtIO MMIO region, which must remain
+    /// valid for the lifetime of the transport that is returned.
+    pub unsafe fn new(header: NonNull<VirtIOHeader>) -> Self {
+        Self { header }
+    }
+
     /// Verify a valid header.
     pub fn verify(&self) -> bool {
-        self.0.magic.read() == MAGIC_VALUE
-            && self.0.version.read() == MODERN_VERSION
-            && self.0.device_id.read() != 0
+        self.header().magic.read() == MAGIC_VALUE
+            && self.header().version.read() == MODERN_VERSION
+            && self.header().device_id.read() != 0
     }
 
     /// Get the device type.
     pub fn device_type(&self) -> DeviceType {
-        match self.0.device_id.read() {
+        match self.header().device_id.read() {
             x @ 1..=13 | x @ 16..=24 => unsafe { core::mem::transmute(x as u8) },
             _ => DeviceType::Invalid,
         }
@@ -390,49 +401,41 @@ impl MmioTransport {
 
     /// Get the vendor ID.
     pub fn vendor_id(&self) -> u32 {
-        self.0.vendor_id.read()
+        self.header().vendor_id.read()
     }
 
-    #[cfg(test)]
-    pub fn make_fake_header(
-        device_id: u32,
-        vendor_id: u32,
-        device_features: u32,
-        queue_num_max: u32,
-    ) -> Self {
-        Self(VirtIOHeader::make_fake_header(
-            MODERN_VERSION,
-            device_id,
-            vendor_id,
-            device_features,
-            queue_num_max,
-        ))
+    fn header(&self) -> &VirtIOHeader {
+        unsafe { self.header.as_ref() }
+    }
+
+    fn header_mut(&mut self) -> &mut VirtIOHeader {
+        unsafe { self.header.as_mut() }
     }
 }
 
 impl Transport for MmioTransport {
     fn device_type(&self) -> DeviceType {
-        self.0.device_type()
+        self.header().device_type()
     }
 
     fn read_device_features(&mut self) -> u64 {
-        self.0.read_device_features()
+        self.header_mut().read_device_features()
     }
 
     fn write_driver_features(&mut self, driver_features: u64) {
-        self.0.write_driver_features(driver_features)
+        self.header_mut().write_driver_features(driver_features)
     }
 
     fn max_queue_size(&self) -> u32 {
-        self.0.max_queue_size()
+        self.header().max_queue_size()
     }
 
     fn notify(&mut self, queue: u32) {
-        self.0.notify(queue)
+        self.header_mut().notify(queue)
     }
 
     fn set_status(&mut self, status: DeviceStatus) {
-        self.0.set_status(status)
+        self.header_mut().set_status(status)
     }
 
     fn set_guest_page_size(&mut self, _guest_page_size: u32) {
@@ -447,27 +450,33 @@ impl Transport for MmioTransport {
         driver_area: PhysAddr,
         device_area: PhysAddr,
     ) {
-        self.0.queue_sel.write(queue);
-        self.0.queue_num.write(size);
-        self.0.queue_desc_low.write(descriptors as u32);
-        self.0.queue_desc_high.write((descriptors >> 32) as u32);
-        self.0.queue_driver_low.write(driver_area as u32);
-        self.0.queue_driver_high.write((driver_area >> 32) as u32);
-        self.0.queue_device_low.write(device_area as u32);
-        self.0.queue_device_high.write((device_area >> 32) as u32);
-        self.0.queue_ready.write(1);
+        self.header_mut().queue_sel.write(queue);
+        self.header_mut().queue_num.write(size);
+        self.header_mut().queue_desc_low.write(descriptors as u32);
+        self.header_mut()
+            .queue_desc_high
+            .write((descriptors >> 32) as u32);
+        self.header_mut().queue_driver_low.write(driver_area as u32);
+        self.header_mut()
+            .queue_driver_high
+            .write((driver_area >> 32) as u32);
+        self.header_mut().queue_device_low.write(device_area as u32);
+        self.header_mut()
+            .queue_device_high
+            .write((device_area >> 32) as u32);
+        self.header_mut().queue_ready.write(1);
     }
 
     fn queue_used(&mut self, queue: u32) -> bool {
-        self.0.queue_sel.write(queue);
-        self.0.queue_ready.read() != 0
+        self.header_mut().queue_sel.write(queue);
+        self.header().queue_ready.read() != 0
     }
 
     fn ack_interrupt(&mut self) -> bool {
-        self.0.ack_interrupt()
+        self.header_mut().ack_interrupt()
     }
 
     fn config_space(&self) -> *mut u64 {
-        self.0.config_space()
+        self.header().config_space()
     }
 }