Browse Source

Add safe helper method to add buffers and wait until they are used.

Andrew Walbran 2 years ago
parent
commit
25925a31c7
5 changed files with 54 additions and 61 deletions
  1. 12 17
      src/blk.rs
  2. 2 8
      src/console.rs
  3. 8 18
      src/gpu.rs
  4. 5 18
      src/net.rs
  5. 27 0
      src/queue.rs

+ 12 - 17
src/blk.rs

@@ -3,7 +3,6 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::{volread, Volatile};
 use bitflags::*;
-use core::hint::spin_loop;
 use log::*;
 
 const QUEUE: u16 = 0;
@@ -77,13 +76,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
         };
         let mut resp = BlkResp::default();
-        let token = unsafe { self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])? };
-        self.transport.notify(0);
-        while !self.queue.can_pop() {
-            spin_loop();
-        }
-        let (popped_token, _) = self.queue.pop_used()?;
-        assert_eq!(popped_token, token);
+        self.queue.add_notify_wait_pop(
+            &[req.as_buf()],
+            &[buf, resp.as_buf_mut()],
+            &mut self.transport,
+        )?;
         match resp.status {
             RespStatus::Ok => Ok(()),
             _ => Err(Error::IoError),
@@ -131,7 +128,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
         };
         let token = self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])?;
-        self.transport.notify(0);
+        self.transport.notify(QUEUE);
         Ok(token)
     }
 
@@ -144,13 +141,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
         };
         let mut resp = BlkResp::default();
-        let token = unsafe { self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])? };
-        self.transport.notify(0);
-        while !self.queue.can_pop() {
-            spin_loop();
-        }
-        let (popped_token, _) = self.queue.pop_used()?;
-        assert_eq!(popped_token, token);
+        self.queue.add_notify_wait_pop(
+            &[req.as_buf(), buf],
+            &[resp.as_buf_mut()],
+            &mut self.transport,
+        )?;
         match resp.status {
             RespStatus::Ok => Ok(()),
             _ => Err(Error::IoError),
@@ -187,7 +182,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
         };
         let token = self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])?;
-        self.transport.notify(0);
+        self.transport.notify(QUEUE);
         Ok(token)
     }
 

+ 2 - 8
src/console.rs

@@ -3,7 +3,6 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, WriteOnly};
 use bitflags::*;
-use core::hint::spin_loop;
 use log::*;
 
 const QUEUE_RECEIVEQ_PORT_0: u16 = 0;
@@ -101,13 +100,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
     pub fn send(&mut self, chr: u8) -> Result<()> {
         let buf: [u8; 1] = [chr];
         // Safe because the buffer is valid until we pop_used below.
-        let token = unsafe { self.transmitq.add(&[&buf], &[]) }?;
-        self.transport.notify(QUEUE_TRANSMITQ_PORT_0);
-        while !self.transmitq.can_pop() {
-            spin_loop();
-        }
-        let (popped_token, _) = self.transmitq.pop_used()?;
-        assert_eq!(popped_token, token);
+        self.transmitq
+            .add_notify_wait_pop(&[&buf], &[], &mut self.transport)?;
         Ok(())
     }
 }

+ 8 - 18
src/gpu.rs

@@ -3,7 +3,7 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
 use bitflags::*;
-use core::{fmt, hint::spin_loop};
+use core::fmt;
 use log::*;
 
 /// A virtio based graphics adapter.
@@ -169,16 +169,11 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         unsafe {
             (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
         }
-        let token = unsafe {
-            self.control_queue
-                .add(&[self.queue_buf_send], &[self.queue_buf_recv])?
-        };
-        self.transport.notify(QUEUE_TRANSMIT);
-        while !self.control_queue.can_pop() {
-            spin_loop();
-        }
-        let (popped_token, _) = self.control_queue.pop_used()?;
-        assert_eq!(popped_token, token);
+        self.control_queue.add_notify_wait_pop(
+            &[self.queue_buf_send],
+            &[self.queue_buf_recv],
+            &mut self.transport,
+        )?;
         Ok(unsafe { (self.queue_buf_recv.as_ptr() as *const Rsp).read() })
     }
 
@@ -187,13 +182,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         unsafe {
             (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
         }
-        let token = unsafe { self.cursor_queue.add(&[self.queue_buf_send], &[])? };
-        self.transport.notify(QUEUE_CURSOR);
-        while !self.cursor_queue.can_pop() {
-            spin_loop();
-        }
-        let (popped_token, _) = self.cursor_queue.pop_used()?;
-        assert_eq!(popped_token, token);
+        self.cursor_queue
+            .add_notify_wait_pop(&[self.queue_buf_send], &[], &mut self.transport)?;
         Ok(())
     }
 

+ 5 - 18
src/net.rs

@@ -4,7 +4,6 @@ use super::*;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, Volatile};
 use bitflags::*;
-use core::hint::spin_loop;
 use log::*;
 
 /// The virtio network device is a virtual ethernet card.
@@ -77,15 +76,9 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
         let mut header = MaybeUninit::<Header>::uninit();
         let header_buf = unsafe { (*header.as_mut_ptr()).as_buf_mut() };
-        // Safe because the buffers are valid at least until we pop_used below.
-        let token = unsafe { self.recv_queue.add(&[], &[header_buf, buf])? };
-        self.transport.notify(QUEUE_RECEIVE);
-        while !self.recv_queue.can_pop() {
-            spin_loop();
-        }
-
-        let (popped_token, len) = self.recv_queue.pop_used()?;
-        assert_eq!(popped_token, token);
+        let len =
+            self.recv_queue
+                .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
         // let header = unsafe { header.assume_init() };
         Ok(len as usize - size_of::<Header>())
     }
@@ -93,14 +86,8 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     /// Send a packet.
     pub fn send(&mut self, buf: &[u8]) -> Result {
         let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
-        // Safe because the buffers are valid at least until we pop_used below.
-        let token = unsafe { self.send_queue.add(&[header.as_buf(), buf], &[])? };
-        self.transport.notify(QUEUE_TRANSMIT);
-        while !self.send_queue.can_pop() {
-            spin_loop();
-        }
-        let (popped_token, _) = self.send_queue.pop_used()?;
-        assert_eq!(popped_token, token);
+        self.send_queue
+            .add_notify_wait_pop(&[header.as_buf(), buf], &[], &mut self.transport)?;
         Ok(())
     }
 }

+ 27 - 0
src/queue.rs

@@ -1,5 +1,6 @@
 #[cfg(test)]
 use core::cmp::min;
+use core::hint::spin_loop;
 use core::mem::size_of;
 use core::ptr::{self, addr_of_mut, NonNull};
 use core::sync::atomic::{fence, Ordering};
@@ -154,6 +155,32 @@ impl<H: Hal> VirtQueue<H> {
         Ok(head)
     }
 
+    /// Add the given buffers to the virtqueue, notifies the device, blocks until the device uses
+    /// them, then pops them.
+    ///
+    /// This assumes that the device isn't processing any other buffers at the same time.
+    pub fn add_notify_wait_pop(
+        &mut self,
+        inputs: &[*const [u8]],
+        outputs: &[*mut [u8]],
+        transport: &mut impl Transport,
+    ) -> Result<u32> {
+        // Safe because we don't return until the same token has been popped, so they remain valid
+        // until then.
+        let token = unsafe { self.add(inputs, outputs) }?;
+
+        // Notify the queue.
+        transport.notify(self.queue_idx);
+
+        while !self.can_pop() {
+            spin_loop();
+        }
+        let (popped_token, length) = self.pop_used()?;
+        assert_eq!(popped_token, token);
+
+        Ok(length)
+    }
+
     /// Returns a non-null pointer to the descriptor at the given index.
     fn desc_ptr(&mut self, index: u16) -> *mut Descriptor {
         // Safe because self.desc is properly aligned and dereferenceable.