Преглед изворни кода

Use zerocopy rather than our AsBuf trait.

The derive macros in zerocopy check that the structs are actually valid
to convert to and from bytes. This required converting some enums to
structs, as there's no guarantee that the value we get is actually a
valid enum variant.
Andrew Walbran пре 2 година
родитељ
комит
8048887954
5 измењених фајлова са 47 додато и 51 уклоњено
  1. 1 0
      Cargo.toml
  2. 26 22
      src/blk.rs
  3. 4 5
      src/input.rs
  4. 0 11
      src/lib.rs
  5. 16 13
      src/net.rs

+ 1 - 0
Cargo.toml

@@ -17,6 +17,7 @@ categories = ["hardware-support", "no-std"]
 [dependencies]
 [dependencies]
 log = "0.4"
 log = "0.4"
 bitflags = "1.3"
 bitflags = "1.3"
+zerocopy = "0.6.1"
 
 
 [features]
 [features]
 default = ["alloc"]
 default = ["alloc"]

+ 26 - 22
src/blk.rs

@@ -4,6 +4,7 @@ use crate::transport::Transport;
 use crate::volatile::{volread, Volatile};
 use crate::volatile::{volread, Volatile};
 use bitflags::*;
 use bitflags::*;
 use log::*;
 use log::*;
+use zerocopy::{AsBytes, FromBytes};
 
 
 const QUEUE: u16 = 0;
 const QUEUE: u16 = 0;
 
 
@@ -77,12 +78,12 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         };
         };
         let mut resp = BlkResp::default();
         let mut resp = BlkResp::default();
         self.queue.add_notify_wait_pop(
         self.queue.add_notify_wait_pop(
-            &[req.as_buf()],
-            &[buf, resp.as_buf_mut()],
+            &[req.as_bytes()],
+            &[buf, resp.as_bytes_mut()],
             &mut self.transport,
             &mut self.transport,
         )?;
         )?;
         match resp.status {
         match resp.status {
-            RespStatus::Ok => Ok(()),
+            RespStatus::OK => Ok(()),
             _ => Err(Error::IoError),
             _ => Err(Error::IoError),
         }
         }
     }
     }
@@ -127,7 +128,9 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             reserved: 0,
             reserved: 0,
             sector: block_id as u64,
             sector: block_id as u64,
         };
         };
-        let token = self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])?;
+        let token = self
+            .queue
+            .add(&[req.as_bytes()], &[buf, resp.as_bytes_mut()])?;
         self.transport.notify(QUEUE);
         self.transport.notify(QUEUE);
         Ok(token)
         Ok(token)
     }
     }
@@ -142,12 +145,12 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
         };
         };
         let mut resp = BlkResp::default();
         let mut resp = BlkResp::default();
         self.queue.add_notify_wait_pop(
         self.queue.add_notify_wait_pop(
-            &[req.as_buf(), buf],
-            &[resp.as_buf_mut()],
+            &[req.as_bytes(), buf],
+            &[resp.as_bytes_mut()],
             &mut self.transport,
             &mut self.transport,
         )?;
         )?;
         match resp.status {
         match resp.status {
-            RespStatus::Ok => Ok(()),
+            RespStatus::OK => Ok(()),
             _ => Err(Error::IoError),
             _ => Err(Error::IoError),
         }
         }
     }
     }
@@ -181,7 +184,9 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             reserved: 0,
             reserved: 0,
             sector: block_id as u64,
             sector: block_id as u64,
         };
         };
-        let token = self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])?;
+        let token = self
+            .queue
+            .add(&[req.as_bytes(), buf], &[resp.as_bytes_mut()])?;
         self.transport.notify(QUEUE);
         self.transport.notify(QUEUE);
         Ok(token)
         Ok(token)
     }
     }
@@ -227,7 +232,7 @@ struct BlkConfig {
 }
 }
 
 
 #[repr(C)]
 #[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
 struct BlkReq {
 struct BlkReq {
     type_: ReqType,
     type_: ReqType,
     reserved: u32,
     reserved: u32,
@@ -236,7 +241,7 @@ struct BlkReq {
 
 
 /// Response of a VirtIOBlk request.
 /// Response of a VirtIOBlk request.
 #[repr(C)]
 #[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug, FromBytes)]
 pub struct BlkResp {
 pub struct BlkResp {
     status: RespStatus,
     status: RespStatus,
 }
 }
@@ -249,7 +254,7 @@ impl BlkResp {
 }
 }
 
 
 #[repr(u32)]
 #[repr(u32)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug)]
 enum ReqType {
 enum ReqType {
     In = 0,
     In = 0,
     Out = 1,
     Out = 1,
@@ -259,23 +264,25 @@ enum ReqType {
 }
 }
 
 
 /// Status of a VirtIOBlk request.
 /// Status of a VirtIOBlk request.
-#[repr(u8)]
-#[derive(Debug, Eq, PartialEq, Copy, Clone)]
-pub enum RespStatus {
+#[repr(transparent)]
+#[derive(AsBytes, Copy, Clone, Debug, Eq, FromBytes, PartialEq)]
+pub struct RespStatus(u8);
+
+impl RespStatus {
     /// Ok.
     /// Ok.
-    Ok = 0,
+    pub const OK: RespStatus = RespStatus(0);
     /// IoErr.
     /// IoErr.
-    IoErr = 1,
+    pub const IO_ERR: RespStatus = RespStatus(1);
     /// Unsupported yet.
     /// Unsupported yet.
-    Unsupported = 2,
+    pub const UNSUPPORTED: RespStatus = RespStatus(2);
     /// Not ready.
     /// Not ready.
-    _NotReady = 3,
+    pub const NOT_READY: RespStatus = RespStatus(3);
 }
 }
 
 
 impl Default for BlkResp {
 impl Default for BlkResp {
     fn default() -> Self {
     fn default() -> Self {
         BlkResp {
         BlkResp {
-            status: RespStatus::_NotReady,
+            status: RespStatus::NOT_READY,
         }
         }
     }
     }
 }
 }
@@ -332,6 +339,3 @@ bitflags! {
         const NOTIFICATION_DATA     = 1 << 38;
         const NOTIFICATION_DATA     = 1 << 38;
     }
     }
 }
 }
-
-unsafe impl AsBuf for BlkReq {}
-unsafe impl AsBuf for BlkResp {}

+ 4 - 5
src/input.rs

@@ -5,6 +5,7 @@ use alloc::boxed::Box;
 use bitflags::*;
 use bitflags::*;
 use core::ptr::NonNull;
 use core::ptr::NonNull;
 use log::*;
 use log::*;
+use zerocopy::{AsBytes, FromBytes};
 
 
 /// Virtual human interface devices such as keyboards, mice and tablets.
 /// Virtual human interface devices such as keyboards, mice and tablets.
 ///
 ///
@@ -37,7 +38,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
         let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, QUEUE_SIZE as u16)?;
         let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, QUEUE_SIZE as u16)?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
             // Safe because the buffer lasts as long as the queue.
             // Safe because the buffer lasts as long as the queue.
-            let token = unsafe { event_queue.add(&[], &[event.as_buf_mut()])? };
+            let token = unsafe { event_queue.add(&[], &[event.as_bytes_mut()])? };
             assert_eq!(token, i as u16);
             assert_eq!(token, i as u16);
         }
         }
 
 
@@ -63,7 +64,7 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
             let event = &mut self.event_buf[token as usize];
             let event = &mut self.event_buf[token as usize];
             // requeue
             // requeue
             // Safe because buffer lasts as long as the queue.
             // Safe because buffer lasts as long as the queue.
-            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_buf_mut()]) } {
+            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_bytes_mut()]) } {
                 // This only works because nothing happen between `pop_used` and `add` that affects
                 // This only works because nothing happen between `pop_used` and `add` that affects
                 // the list of free descriptors in the queue, so `add` reuses the descriptor which
                 // the list of free descriptors in the queue, so `add` reuses the descriptor which
                 // was just freed by `pop_used`.
                 // was just freed by `pop_used`.
@@ -161,7 +162,7 @@ struct DevIDs {
 /// Both queues use the same `virtio_input_event` struct. `type`, `code` and `value`
 /// Both queues use the same `virtio_input_event` struct. `type`, `code` and `value`
 /// are filled according to the Linux input layer (evdev) interface.
 /// are filled according to the Linux input layer (evdev) interface.
 #[repr(C)]
 #[repr(C)]
-#[derive(Clone, Copy, Debug, Default)]
+#[derive(AsBytes, Clone, Copy, Debug, Default, FromBytes)]
 pub struct InputEvent {
 pub struct InputEvent {
     /// Event type.
     /// Event type.
     pub event_type: u16,
     pub event_type: u16,
@@ -171,8 +172,6 @@ pub struct InputEvent {
     pub value: u32,
     pub value: u32,
 }
 }
 
 
-unsafe impl AsBuf for InputEvent {}
-
 bitflags! {
 bitflags! {
     struct Feature: u64 {
     struct Feature: u64 {
         // device independent
         // device independent

+ 0 - 11
src/lib.rs

@@ -30,7 +30,6 @@ use self::queue::VirtQueue;
 pub use self::transport::mmio::{MmioError, MmioTransport, MmioVersion, VirtIOHeader};
 pub use self::transport::mmio::{MmioError, MmioTransport, MmioVersion, VirtIOHeader};
 pub use self::transport::pci;
 pub use self::transport::pci;
 pub use self::transport::{DeviceStatus, DeviceType, Transport};
 pub use self::transport::{DeviceStatus, DeviceType, Transport};
-use core::mem::size_of;
 use hal::*;
 use hal::*;
 
 
 /// The page size in bytes supported by the library (4 KiB).
 /// The page size in bytes supported by the library (4 KiB).
@@ -69,13 +68,3 @@ fn align_up(size: usize) -> usize {
 fn pages(size: usize) -> usize {
 fn pages(size: usize) -> usize {
     (size + PAGE_SIZE - 1) / PAGE_SIZE
     (size + PAGE_SIZE - 1) / PAGE_SIZE
 }
 }
-
-/// Convert a struct into a byte buffer.
-unsafe trait AsBuf: Sized {
-    fn as_buf(&self) -> &[u8] {
-        unsafe { core::slice::from_raw_parts(self as *const _ as _, size_of::<Self>()) }
-    }
-    fn as_buf_mut(&mut self) -> &mut [u8] {
-        unsafe { core::slice::from_raw_parts_mut(self as *mut _ as _, size_of::<Self>()) }
-    }
-}

+ 16 - 13
src/net.rs

@@ -5,6 +5,7 @@ use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly};
 use crate::volatile::{volread, ReadOnly};
 use bitflags::*;
 use bitflags::*;
 use log::*;
 use log::*;
+use zerocopy::{AsBytes, FromBytes};
 
 
 /// The virtio network device is a virtual ethernet card.
 /// The virtio network device is a virtual ethernet card.
 ///
 ///
@@ -75,7 +76,7 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     /// Receive a packet.
     /// Receive a packet.
     pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
     pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
         let mut header = MaybeUninit::<Header>::uninit();
         let mut header = MaybeUninit::<Header>::uninit();
-        let header_buf = unsafe { (*header.as_mut_ptr()).as_buf_mut() };
+        let header_buf = unsafe { (*header.as_mut_ptr()).as_bytes_mut() };
         let len =
         let len =
             self.recv_queue
             self.recv_queue
                 .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
                 .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
@@ -87,7 +88,7 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     pub fn send(&mut self, buf: &[u8]) -> Result {
     pub fn send(&mut self, buf: &[u8]) -> Result {
         let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
         let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
         self.send_queue
         self.send_queue
-            .add_notify_wait_pop(&[header.as_buf(), buf], &[], &mut self.transport)?;
+            .add_notify_wait_pop(&[header.as_bytes(), buf], &[], &mut self.transport)?;
         Ok(())
         Ok(())
     }
     }
 }
 }
@@ -186,7 +187,7 @@ type EthernetAddress = [u8; 6];
 
 
 // virtio 5.1.6 Device Operation
 // virtio 5.1.6 Device Operation
 #[repr(C)]
 #[repr(C)]
-#[derive(Debug)]
+#[derive(AsBytes, Debug, FromBytes)]
 struct Header {
 struct Header {
     flags: Flags,
     flags: Flags,
     gso_type: GsoType,
     gso_type: GsoType,
@@ -197,9 +198,9 @@ struct Header {
     // payload starts from here
     // payload starts from here
 }
 }
 
 
-unsafe impl AsBuf for Header {}
-
 bitflags! {
 bitflags! {
+    #[repr(transparent)]
+    #[derive(AsBytes, FromBytes)]
     struct Flags: u8 {
     struct Flags: u8 {
         const NEEDS_CSUM = 1;
         const NEEDS_CSUM = 1;
         const DATA_VALID = 2;
         const DATA_VALID = 2;
@@ -207,14 +208,16 @@ bitflags! {
     }
     }
 }
 }
 
 
-#[repr(u8)]
-#[derive(Debug, Copy, Clone, Eq, PartialEq)]
-enum GsoType {
-    None = 0,
-    TcpV4 = 1,
-    Udp = 3,
-    TcpV6 = 4,
-    Ecn = 0x80,
+#[repr(transparent)]
+#[derive(AsBytes, Debug, Copy, Clone, Eq, FromBytes, PartialEq)]
+struct GsoType(u8);
+
+impl GsoType {
+    const NONE: GsoType = GsoType(0);
+    const TCPV4: GsoType = GsoType(1);
+    const UDP: GsoType = GsoType(3);
+    const TCPV6: GsoType = GsoType(4);
+    const ECN: GsoType = GsoType(0x80);
 }
 }
 
 
 const QUEUE_RECEIVE: u16 = 0;
 const QUEUE_RECEIVE: u16 = 0;