瀏覽代碼

Resolve conversation of PR #61

Yuekai Jia 2 年之前
父節點
當前提交
28ce9d6e17
共有 4 個文件被更改,包括 41 次插入35 次删除
  1. 1 1
      examples/riscv/Cargo.toml
  2. 2 0
      examples/riscv/Makefile
  3. 3 3
      examples/riscv/src/main.rs
  4. 35 31
      src/device/net.rs

+ 1 - 1
examples/riscv/Cargo.toml

@@ -8,7 +8,7 @@ edition = "2018"
 
 [features]
 tcp = ["smoltcp"]
-default = []
+default = ["tcp"]
 
 [dependencies]
 log = "0.4"

+ 2 - 0
examples/riscv/Makefile

@@ -17,6 +17,8 @@ endif
 
 ifeq ($(tcp), on)
 	BUILD_ARGS += --features tcp
+else
+	BUILD_ARGS += --no-default-features
 endif
 
 .PHONY: kernel build clean qemu run env

+ 3 - 3
examples/riscv/src/main.rs

@@ -26,6 +26,9 @@ mod virtio_impl;
 #[cfg(feature = "tcp")]
 mod tcp;
 
+const NET_BUFFER_LEN: usize = 2048;
+const NET_QUEUE_SIZE: usize = 16;
+
 #[no_mangle]
 extern "C" fn main(_hartid: usize, device_tree_paddr: usize) {
     log::set_max_level(LevelFilter::Info);
@@ -143,9 +146,6 @@ fn virtio_input<T: Transport>(transport: T) {
 }
 
 fn virtio_net<T: Transport>(transport: T) {
-    const NET_BUFFER_LEN: usize = 2048;
-    const NET_QUEUE_SIZE: usize = 16;
-
     let net = VirtIONet::<HalImpl, T, NET_QUEUE_SIZE>::new(transport, NET_BUFFER_LEN)
         .expect("failed to create net driver");
     info!("MAC address: {:02x?}", net.mac_address());

+ 35 - 31
src/device/net.rs

@@ -20,7 +20,7 @@ pub struct TxBuffer(Vec<u8>);
 
 /// A buffer used for receiving.
 pub struct RxBuffer {
-    buf: Vec<u8>,
+    buf: Vec<usize>, // for alignment
     packet_len: usize,
     idx: usize,
 }
@@ -51,7 +51,7 @@ impl RxBuffer {
     /// Allocates a new buffer with length `buf_len`.
     fn new(idx: usize, buf_len: usize) -> Self {
         Self {
-            buf: vec![0; buf_len],
+            buf: vec![0; buf_len / size_of::<usize>()],
             packet_len: 0,
             idx,
         }
@@ -69,13 +69,13 @@ impl RxBuffer {
 
     /// Returns all data in the buffer, including both the header and the packet.
     pub fn as_bytes(&self) -> &[u8] {
-        self.buf.as_slice()
+        self.buf.as_bytes()
     }
 
     /// Returns all data in the buffer with the mutable reference,
     /// including both the header and the packet.
     pub fn as_bytes_mut(&mut self) -> &mut [u8] {
-        self.buf.as_mut_slice()
+        self.buf.as_bytes_mut()
     }
 
     /// Returns the reference of the header.
@@ -85,12 +85,12 @@ impl RxBuffer {
 
     /// Returns the network packet as a slice.
     pub fn packet(&self) -> &[u8] {
-        &self.buf[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
+        &self.buf.as_bytes()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
     }
 
     /// Returns the network packet as a mutable slice.
     pub fn packet_mut(&mut self) -> &mut [u8] {
-        &mut self.buf[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
+        &mut self.buf.as_bytes_mut()[NET_HDR_SIZE..NET_HDR_SIZE + self.packet_len]
     }
 }
 
@@ -106,7 +106,7 @@ pub struct VirtIONet<H: Hal, T: Transport, const QUEUE_SIZE: usize> {
     mac: EthernetAddress,
     recv_queue: VirtQueue<H, QUEUE_SIZE>,
     send_queue: VirtQueue<H, QUEUE_SIZE>,
-    rx_buffers: Vec<Option<RxBuffer>>,
+    rx_buffers: [Option<RxBuffer>; QUEUE_SIZE],
 }
 
 impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE> {
@@ -142,12 +142,14 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE>
         let send_queue = VirtQueue::new(&mut transport, QUEUE_TRANSMIT)?;
         let mut recv_queue = VirtQueue::new(&mut transport, QUEUE_RECEIVE)?;
 
-        let mut rx_buffers = Vec::with_capacity(QUEUE_SIZE);
-        for i in 0..QUEUE_SIZE {
+        const NONE_BUF: Option<RxBuffer> = None;
+        let mut rx_buffers = [NONE_BUF; QUEUE_SIZE];
+        for (i, rx_buf_place) in rx_buffers.iter_mut().enumerate() {
             let mut rx_buf = RxBuffer::new(i, buf_len);
+            // Safe because the buffer lives as long as the queue.
             let token = unsafe { recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()])? };
             assert_eq!(token, i as u16);
-            rx_buffers.push(Some(rx_buf));
+            *rx_buf_place = Some(rx_buf);
         }
 
         if recv_queue.should_notify() {
@@ -192,14 +194,25 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE>
     /// NIC queue.
     pub fn receive(&mut self) -> Result<RxBuffer> {
         if let Some(token) = self.recv_queue.peek_used() {
-            let mut rx_buf = self.rx_buffers[token as usize].take().unwrap();
-            assert_eq!(token as usize, rx_buf.idx);
+            let mut rx_buf = self.rx_buffers[token as usize]
+                .take()
+                .ok_or(Error::WrongToken)?;
+            if token as usize != rx_buf.idx {
+                return Err(Error::WrongToken);
+            }
+
+            // Safe because `token` == `rx_buf.idx`, we are passing the same
+            // buffer as we passed to `VirtQueue::add` and it is still valid.
             let len = unsafe {
                 self.recv_queue
                     .pop_used(token, &[], &mut [rx_buf.as_bytes_mut()])?
             };
-            rx_buf.set_packet_len(len as usize - NET_HDR_SIZE);
-            Ok(rx_buf)
+            if (len as usize) < NET_HDR_SIZE {
+                Err(Error::IoError)
+            } else {
+                rx_buf.set_packet_len(len as usize - NET_HDR_SIZE);
+                Ok(rx_buf)
+            }
         } else {
             Err(Error::NotReady)
         }
@@ -210,8 +223,12 @@ impl<H: Hal, T: Transport, const QUEUE_SIZE: usize> VirtIONet<H, T, QUEUE_SIZE>
     /// It will add the buffer back to the NIC queue.
     pub fn recycle_rx_buffer(&mut self, mut rx_buf: RxBuffer) -> Result {
         let old_token = rx_buf.idx;
+        // Safe because we take the ownership of `rx_buf` back to `rx_buffers`,
+        // it lives as long as the queue.
         let new_token = unsafe { self.recv_queue.add(&[], &mut [rx_buf.as_bytes_mut()]) }?;
-        assert_eq!(new_token as usize, old_token);
+        if new_token as usize != old_token {
+            return Err(Error::WrongToken);
+        }
         self.rx_buffers[old_token] = Some(rx_buf);
         if self.recv_queue.should_notify() {
             self.transport.notify(QUEUE_RECEIVE);
@@ -337,7 +354,7 @@ type EthernetAddress = [u8; 6];
 /// and buffers for incoming packets are placed in the receiveq1. . .receiveqN.
 /// In each case, the packet itself is preceded by a header.
 #[repr(C)]
-#[derive(AsBytes, Debug, FromBytes)]
+#[derive(AsBytes, Debug, Default, FromBytes)]
 pub struct VirtioNetHdr {
     flags: Flags,
     gso_type: GsoType,
@@ -349,22 +366,9 @@ pub struct VirtioNetHdr {
     // payload starts from here
 }
 
-impl Default for VirtioNetHdr {
-    fn default() -> Self {
-        Self {
-            flags: Flags::empty(),
-            gso_type: GsoType::NONE,
-            hdr_len: 0,
-            gso_size: 0,
-            csum_start: 0,
-            csum_offset: 0,
-        }
-    }
-}
-
 bitflags! {
     #[repr(transparent)]
-    #[derive(AsBytes, FromBytes)]
+    #[derive(AsBytes, Default, FromBytes)]
     struct Flags: u8 {
         const NEEDS_CSUM = 1;
         const DATA_VALID = 2;
@@ -373,7 +377,7 @@ bitflags! {
 }
 
 #[repr(transparent)]
-#[derive(AsBytes, Debug, Copy, Clone, Eq, FromBytes, PartialEq)]
+#[derive(AsBytes, Debug, Copy, Clone, Default, Eq, FromBytes, PartialEq)]
 struct GsoType(u8);
 
 impl GsoType {