Browse Source

Mark Queue::add as unsafe.

Caller must ensure that buffers live as long as the device is using
them.
Andrew Walbran 2 years ago
parent
commit
6183bed94a
6 changed files with 51 additions and 30 deletions
  1. 6 4
      src/blk.rs
  2. 6 3
      src/console.rs
  3. 9 5
      src/gpu.rs
  4. 4 2
      src/input.rs
  5. 8 4
      src/net.rs
  6. 18 12
      src/queue.rs

+ 6 - 4
src/blk.rs

@@ -77,12 +77,13 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
         };
         let mut resp = BlkResp::default();
-        self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])?;
+        let token = unsafe { self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])? };
         self.transport.notify(0);
         while !self.queue.can_pop() {
             spin_loop();
         }
-        self.queue.pop_used()?;
+        let (popped_token, _) = self.queue.pop_used()?;
+        assert_eq!(popped_token, token);
         match resp.status {
             RespStatus::Ok => Ok(()),
             _ => Err(Error::IoError),
@@ -143,12 +144,13 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
         };
         let mut resp = BlkResp::default();
-        self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])?;
+        let token = unsafe { self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])? };
         self.transport.notify(0);
         while !self.queue.can_pop() {
             spin_loop();
         }
-        self.queue.pop_used()?;
+        let (popped_token, _) = self.queue.pop_used()?;
+        assert_eq!(popped_token, token);
         match resp.status {
             RespStatus::Ok => Ok(()),
             _ => Err(Error::IoError),

+ 6 - 3
src/console.rs

@@ -60,7 +60,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
     }
 
     fn poll_retrieve(&mut self) -> Result<()> {
-        self.receiveq.add(&[], &[self.queue_buf_rx])?;
+        // Safe because the buffer lasts at least as long as the queue.
+        unsafe { self.receiveq.add(&[], &[self.queue_buf_rx])? };
         Ok(())
     }
 
@@ -99,12 +100,14 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
     /// Put a char onto the device.
     pub fn send(&mut self, chr: u8) -> Result<()> {
         let buf: [u8; 1] = [chr];
-        self.transmitq.add(&[&buf], &[])?;
+        // Safe because the buffer is valid until we pop_used below.
+        let token = unsafe { self.transmitq.add(&[&buf], &[]) }?;
         self.transport.notify(QUEUE_TRANSMITQ_PORT_0);
         while !self.transmitq.can_pop() {
             spin_loop();
         }
-        self.transmitq.pop_used()?;
+        let (popped_token, _) = self.transmitq.pop_used()?;
+        assert_eq!(popped_token, token);
         Ok(())
     }
 }

+ 9 - 5
src/gpu.rs

@@ -169,13 +169,16 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         unsafe {
             (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
         }
-        self.control_queue
-            .add(&[self.queue_buf_send], &[self.queue_buf_recv])?;
+        let token = unsafe {
+            self.control_queue
+                .add(&[self.queue_buf_send], &[self.queue_buf_recv])?
+        };
         self.transport.notify(QUEUE_TRANSMIT);
         while !self.control_queue.can_pop() {
             spin_loop();
         }
-        self.control_queue.pop_used()?;
+        let (popped_token, _) = self.control_queue.pop_used()?;
+        assert_eq!(popped_token, token);
         Ok(unsafe { (self.queue_buf_recv.as_ptr() as *const Rsp).read() })
     }
 
@@ -184,12 +187,13 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         unsafe {
             (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
         }
-        self.cursor_queue.add(&[self.queue_buf_send], &[])?;
+        let token = unsafe { self.cursor_queue.add(&[self.queue_buf_send], &[])? };
         self.transport.notify(QUEUE_CURSOR);
         while !self.cursor_queue.can_pop() {
             spin_loop();
         }
-        self.cursor_queue.pop_used()?;
+        let (popped_token, _) = self.cursor_queue.pop_used()?;
+        assert_eq!(popped_token, token);
         Ok(())
     }
 

+ 4 - 2
src/input.rs

@@ -36,7 +36,8 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
         let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT, QUEUE_SIZE as u16)?;
         let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, QUEUE_SIZE as u16)?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
-            let token = event_queue.add(&[], &[event.as_buf_mut()])?;
+            // Safe because the buffer lasts as long as the queue.
+            let token = unsafe { event_queue.add(&[], &[event.as_buf_mut()])? };
             assert_eq!(token, i as u16);
         }
 
@@ -61,7 +62,8 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
         if let Ok((token, _)) = self.event_queue.pop_used() {
             let event = &mut self.event_buf[token as usize];
             // requeue
-            if let Ok(new_token) = self.event_queue.add(&[], &[event.as_buf_mut()]) {
+            // Safe because buffer lasts as long as the queue.
+            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_buf_mut()]) } {
                 // This only works because nothing happen between `pop_used` and `add` that affects
                 // the list of free descriptors in the queue, so `add` reuses the descriptor which
                 // was just freed by `pop_used`.

+ 8 - 4
src/net.rs

@@ -77,13 +77,15 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
         let mut header = MaybeUninit::<Header>::uninit();
         let header_buf = unsafe { (*header.as_mut_ptr()).as_buf_mut() };
-        self.recv_queue.add(&[], &[header_buf, buf])?;
+        // Safe because the buffers are valid at least until we pop_used below.
+        let token = unsafe { self.recv_queue.add(&[], &[header_buf, buf])? };
         self.transport.notify(QUEUE_RECEIVE);
         while !self.recv_queue.can_pop() {
             spin_loop();
         }
 
-        let (_, len) = self.recv_queue.pop_used()?;
+        let (popped_token, len) = self.recv_queue.pop_used()?;
+        assert_eq!(popped_token, token);
         // let header = unsafe { header.assume_init() };
         Ok(len as usize - size_of::<Header>())
     }
@@ -91,12 +93,14 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     /// Send a packet.
     pub fn send(&mut self, buf: &[u8]) -> Result {
         let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
-        self.send_queue.add(&[header.as_buf(), buf], &[])?;
+        // Safe because the buffers are valid at least until we pop_used below.
+        let token = unsafe { self.send_queue.add(&[header.as_buf(), buf], &[])? };
         self.transport.notify(QUEUE_TRANSMIT);
         while !self.send_queue.can_pop() {
             spin_loop();
         }
-        self.send_queue.pop_used()?;
+        let (popped_token, _) = self.send_queue.pop_used()?;
+        assert_eq!(popped_token, token);
         Ok(())
     }
 }

+ 18 - 12
src/queue.rs

@@ -92,7 +92,11 @@ impl<H: Hal> VirtQueue<H> {
     /// Add buffers to the virtqueue, return a token.
     ///
     /// Ref: linux virtio_ring.c virtqueue_add
-    pub fn add(&mut self, inputs: &[&[u8]], outputs: &[&mut [u8]]) -> Result<u16> {
+    ///
+    /// # Safety
+    ///
+    /// The input and output buffers must remain valid until the token is returned by `pop_used`.
+    pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> {
         if inputs.is_empty() && outputs.is_empty() {
             return Err(Error::InvalidParam);
         }
@@ -109,14 +113,14 @@ impl<H: Hal> VirtQueue<H> {
         unsafe {
             for input in inputs.iter() {
                 let mut desc = self.desc_ptr(self.free_head);
-                (*desc).set_buf::<H>(input);
+                (*desc).set_buf::<H>(NonNull::new(*input as *mut [u8]).unwrap());
                 (*desc).flags = DescFlags::NEXT;
                 last = self.free_head;
                 self.free_head = (*desc).next;
             }
             for output in outputs.iter() {
                 let desc = self.desc_ptr(self.free_head);
-                (*desc).set_buf::<H>(output);
+                (*desc).set_buf::<H>(NonNull::new(*output).unwrap());
                 (*desc).flags = DescFlags::NEXT | DescFlags::WRITE;
                 last = self.free_head;
                 self.free_head = (*desc).next;
@@ -263,8 +267,11 @@ pub(crate) struct Descriptor {
 }
 
 impl Descriptor {
-    fn set_buf<H: Hal>(&mut self, buf: &[u8]) {
-        self.addr = H::virt_to_phys(buf.as_ptr() as usize) as u64;
+    /// # Safety
+    ///
+    /// The caller must ensure that the buffer lives at least as long as the descriptor is active.
+    unsafe fn set_buf<H: Hal>(&mut self, buf: NonNull<[u8]>) {
+        self.addr = H::virt_to_phys(buf.as_ptr() as *mut u8 as usize) as u64;
         self.len = buf.len() as u32;
     }
 }
@@ -408,7 +415,10 @@ mod tests {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
-        assert_eq!(queue.add(&[], &[]).unwrap_err(), Error::InvalidParam);
+        assert_eq!(
+            unsafe { queue.add(&[], &[]) }.unwrap_err(),
+            Error::InvalidParam
+        );
     }
 
     #[test]
@@ -418,9 +428,7 @@ mod tests {
         let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
-            queue
-                .add(&[&[], &[], &[]], &[&mut [], &mut []])
-                .unwrap_err(),
+            unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(),
             Error::BufferTooSmall
         );
     }
@@ -435,9 +443,7 @@ mod tests {
 
         // Add a buffer chain consisting of two device-readable parts followed by two
         // device-writable parts.
-        let token = queue
-            .add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]])
-            .unwrap();
+        let token = unsafe { queue.add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]]) }.unwrap();
 
         assert_eq!(queue.available_desc(), 0);
         assert!(!queue.can_pop());