瀏覽代碼

Merge pull request #31 from rcore-os/queuesoundness

Fix soundness issues with VirtQueue
Yuekai Jia 2 年之前
父節點
當前提交
87dc2492ae
共有 6 個文件被更改,包括 79 次插入65 次删除
  1. 12 15
      src/blk.rs
  2. 5 8
      src/console.rs
  3. 8 14
      src/gpu.rs
  4. 4 2
      src/input.rs
  5. 5 14
      src/net.rs
  6. 45 12
      src/queue.rs

+ 12 - 15
src/blk.rs

@@ -3,7 +3,6 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::transport::Transport;
 use crate::volatile::{volread, Volatile};
 use crate::volatile::{volread, Volatile};
 use bitflags::*;
 use bitflags::*;
-use core::hint::spin_loop;
 use log::*;
 use log::*;
 
 
 const QUEUE: u16 = 0;
 const QUEUE: u16 = 0;
@@ -77,12 +76,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
             sector: block_id as u64,
         };
         };
         let mut resp = BlkResp::default();
         let mut resp = BlkResp::default();
-        self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])?;
-        self.transport.notify(0);
-        while !self.queue.can_pop() {
-            spin_loop();
-        }
-        self.queue.pop_used()?;
+        self.queue.add_notify_wait_pop(
+            &[req.as_buf()],
+            &[buf, resp.as_buf_mut()],
+            &mut self.transport,
+        )?;
         match resp.status {
         match resp.status {
             RespStatus::Ok => Ok(()),
             RespStatus::Ok => Ok(()),
             _ => Err(Error::IoError),
             _ => Err(Error::IoError),
@@ -130,7 +128,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
             sector: block_id as u64,
         };
         };
         let token = self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])?;
         let token = self.queue.add(&[req.as_buf()], &[buf, resp.as_buf_mut()])?;
-        self.transport.notify(0);
+        self.transport.notify(QUEUE);
         Ok(token)
         Ok(token)
     }
     }
 
 
@@ -143,12 +141,11 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
             sector: block_id as u64,
         };
         };
         let mut resp = BlkResp::default();
         let mut resp = BlkResp::default();
-        self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])?;
-        self.transport.notify(0);
-        while !self.queue.can_pop() {
-            spin_loop();
-        }
-        self.queue.pop_used()?;
+        self.queue.add_notify_wait_pop(
+            &[req.as_buf(), buf],
+            &[resp.as_buf_mut()],
+            &mut self.transport,
+        )?;
         match resp.status {
         match resp.status {
             RespStatus::Ok => Ok(()),
             RespStatus::Ok => Ok(()),
             _ => Err(Error::IoError),
             _ => Err(Error::IoError),
@@ -185,7 +182,7 @@ impl<H: Hal, T: Transport> VirtIOBlk<H, T> {
             sector: block_id as u64,
             sector: block_id as u64,
         };
         };
         let token = self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])?;
         let token = self.queue.add(&[req.as_buf(), buf], &[resp.as_buf_mut()])?;
-        self.transport.notify(0);
+        self.transport.notify(QUEUE);
         Ok(token)
         Ok(token)
     }
     }
 
 

+ 5 - 8
src/console.rs

@@ -3,7 +3,6 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, WriteOnly};
 use crate::volatile::{volread, ReadOnly, WriteOnly};
 use bitflags::*;
 use bitflags::*;
-use core::hint::spin_loop;
 use log::*;
 use log::*;
 
 
 const QUEUE_RECEIVEQ_PORT_0: u16 = 0;
 const QUEUE_RECEIVEQ_PORT_0: u16 = 0;
@@ -60,7 +59,8 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
     }
     }
 
 
     fn poll_retrieve(&mut self) -> Result<()> {
     fn poll_retrieve(&mut self) -> Result<()> {
-        self.receiveq.add(&[], &[self.queue_buf_rx])?;
+        // Safe because the buffer lasts at least as long as the queue.
+        unsafe { self.receiveq.add(&[], &[self.queue_buf_rx])? };
         Ok(())
         Ok(())
     }
     }
 
 
@@ -99,12 +99,9 @@ impl<H: Hal, T: Transport> VirtIOConsole<'_, H, T> {
     /// Put a char onto the device.
     /// Put a char onto the device.
     pub fn send(&mut self, chr: u8) -> Result<()> {
     pub fn send(&mut self, chr: u8) -> Result<()> {
         let buf: [u8; 1] = [chr];
         let buf: [u8; 1] = [chr];
-        self.transmitq.add(&[&buf], &[])?;
-        self.transport.notify(QUEUE_TRANSMITQ_PORT_0);
-        while !self.transmitq.can_pop() {
-            spin_loop();
-        }
-        self.transmitq.pop_used()?;
+        // Safe because the buffer is valid until we pop_used below.
+        self.transmitq
+            .add_notify_wait_pop(&[&buf], &[], &mut self.transport)?;
         Ok(())
         Ok(())
     }
     }
 }
 }

+ 8 - 14
src/gpu.rs

@@ -3,7 +3,7 @@ use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
 use crate::volatile::{volread, ReadOnly, Volatile, WriteOnly};
 use bitflags::*;
 use bitflags::*;
-use core::{fmt, hint::spin_loop};
+use core::fmt;
 use log::*;
 use log::*;
 
 
 /// A virtio based graphics adapter.
 /// A virtio based graphics adapter.
@@ -169,13 +169,11 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         unsafe {
         unsafe {
             (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
             (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
         }
         }
-        self.control_queue
-            .add(&[self.queue_buf_send], &[self.queue_buf_recv])?;
-        self.transport.notify(QUEUE_TRANSMIT);
-        while !self.control_queue.can_pop() {
-            spin_loop();
-        }
-        self.control_queue.pop_used()?;
+        self.control_queue.add_notify_wait_pop(
+            &[self.queue_buf_send],
+            &[self.queue_buf_recv],
+            &mut self.transport,
+        )?;
         Ok(unsafe { (self.queue_buf_recv.as_ptr() as *const Rsp).read() })
         Ok(unsafe { (self.queue_buf_recv.as_ptr() as *const Rsp).read() })
     }
     }
 
 
@@ -184,12 +182,8 @@ impl<H: Hal, T: Transport> VirtIOGpu<'_, H, T> {
         unsafe {
         unsafe {
             (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
             (self.queue_buf_send.as_mut_ptr() as *mut Req).write(req);
         }
         }
-        self.cursor_queue.add(&[self.queue_buf_send], &[])?;
-        self.transport.notify(QUEUE_CURSOR);
-        while !self.cursor_queue.can_pop() {
-            spin_loop();
-        }
-        self.cursor_queue.pop_used()?;
+        self.cursor_queue
+            .add_notify_wait_pop(&[self.queue_buf_send], &[], &mut self.transport)?;
         Ok(())
         Ok(())
     }
     }
 
 

+ 4 - 2
src/input.rs

@@ -36,7 +36,8 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
         let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT, QUEUE_SIZE as u16)?;
         let mut event_queue = VirtQueue::new(&mut transport, QUEUE_EVENT, QUEUE_SIZE as u16)?;
         let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, QUEUE_SIZE as u16)?;
         let status_queue = VirtQueue::new(&mut transport, QUEUE_STATUS, QUEUE_SIZE as u16)?;
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
         for (i, event) in event_buf.as_mut().iter_mut().enumerate() {
-            let token = event_queue.add(&[], &[event.as_buf_mut()])?;
+            // Safe because the buffer lasts as long as the queue.
+            let token = unsafe { event_queue.add(&[], &[event.as_buf_mut()])? };
             assert_eq!(token, i as u16);
             assert_eq!(token, i as u16);
         }
         }
 
 
@@ -61,7 +62,8 @@ impl<H: Hal, T: Transport> VirtIOInput<H, T> {
         if let Ok((token, _)) = self.event_queue.pop_used() {
         if let Ok((token, _)) = self.event_queue.pop_used() {
             let event = &mut self.event_buf[token as usize];
             let event = &mut self.event_buf[token as usize];
             // requeue
             // requeue
-            if let Ok(new_token) = self.event_queue.add(&[], &[event.as_buf_mut()]) {
+            // Safe because buffer lasts as long as the queue.
+            if let Ok(new_token) = unsafe { self.event_queue.add(&[], &[event.as_buf_mut()]) } {
                 // This only works because nothing happen between `pop_used` and `add` that affects
                 // This only works because nothing happen between `pop_used` and `add` that affects
                 // the list of free descriptors in the queue, so `add` reuses the descriptor which
                 // the list of free descriptors in the queue, so `add` reuses the descriptor which
                 // was just freed by `pop_used`.
                 // was just freed by `pop_used`.

+ 5 - 14
src/net.rs

@@ -4,7 +4,6 @@ use super::*;
 use crate::transport::Transport;
 use crate::transport::Transport;
 use crate::volatile::{volread, ReadOnly, Volatile};
 use crate::volatile::{volread, ReadOnly, Volatile};
 use bitflags::*;
 use bitflags::*;
-use core::hint::spin_loop;
 use log::*;
 use log::*;
 
 
 /// The virtio network device is a virtual ethernet card.
 /// The virtio network device is a virtual ethernet card.
@@ -77,13 +76,9 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
     pub fn recv(&mut self, buf: &mut [u8]) -> Result<usize> {
         let mut header = MaybeUninit::<Header>::uninit();
         let mut header = MaybeUninit::<Header>::uninit();
         let header_buf = unsafe { (*header.as_mut_ptr()).as_buf_mut() };
         let header_buf = unsafe { (*header.as_mut_ptr()).as_buf_mut() };
-        self.recv_queue.add(&[], &[header_buf, buf])?;
-        self.transport.notify(QUEUE_RECEIVE);
-        while !self.recv_queue.can_pop() {
-            spin_loop();
-        }
-
-        let (_, len) = self.recv_queue.pop_used()?;
+        let len =
+            self.recv_queue
+                .add_notify_wait_pop(&[], &[header_buf, buf], &mut self.transport)?;
         // let header = unsafe { header.assume_init() };
         // let header = unsafe { header.assume_init() };
         Ok(len as usize - size_of::<Header>())
         Ok(len as usize - size_of::<Header>())
     }
     }
@@ -91,12 +86,8 @@ impl<H: Hal, T: Transport> VirtIONet<H, T> {
     /// Send a packet.
     /// Send a packet.
     pub fn send(&mut self, buf: &[u8]) -> Result {
     pub fn send(&mut self, buf: &[u8]) -> Result {
         let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
         let header = unsafe { MaybeUninit::<Header>::zeroed().assume_init() };
-        self.send_queue.add(&[header.as_buf(), buf], &[])?;
-        self.transport.notify(QUEUE_TRANSMIT);
-        while !self.send_queue.can_pop() {
-            spin_loop();
-        }
-        self.send_queue.pop_used()?;
+        self.send_queue
+            .add_notify_wait_pop(&[header.as_buf(), buf], &[], &mut self.transport)?;
         Ok(())
         Ok(())
     }
     }
 }
 }

+ 45 - 12
src/queue.rs

@@ -1,5 +1,6 @@
 #[cfg(test)]
 #[cfg(test)]
 use core::cmp::min;
 use core::cmp::min;
+use core::hint::spin_loop;
 use core::mem::size_of;
 use core::mem::size_of;
 use core::ptr::{self, addr_of_mut, NonNull};
 use core::ptr::{self, addr_of_mut, NonNull};
 use core::sync::atomic::{fence, Ordering};
 use core::sync::atomic::{fence, Ordering};
@@ -92,7 +93,11 @@ impl<H: Hal> VirtQueue<H> {
     /// Add buffers to the virtqueue, return a token.
     /// Add buffers to the virtqueue, return a token.
     ///
     ///
     /// Ref: linux virtio_ring.c virtqueue_add
     /// Ref: linux virtio_ring.c virtqueue_add
-    pub fn add(&mut self, inputs: &[&[u8]], outputs: &[&mut [u8]]) -> Result<u16> {
+    ///
+    /// # Safety
+    ///
+    /// The input and output buffers must remain valid until the token is returned by `pop_used`.
+    pub unsafe fn add(&mut self, inputs: &[*const [u8]], outputs: &[*mut [u8]]) -> Result<u16> {
         if inputs.is_empty() && outputs.is_empty() {
         if inputs.is_empty() && outputs.is_empty() {
             return Err(Error::InvalidParam);
             return Err(Error::InvalidParam);
         }
         }
@@ -109,14 +114,14 @@ impl<H: Hal> VirtQueue<H> {
         unsafe {
         unsafe {
             for input in inputs.iter() {
             for input in inputs.iter() {
                 let mut desc = self.desc_ptr(self.free_head);
                 let mut desc = self.desc_ptr(self.free_head);
-                (*desc).set_buf::<H>(input);
+                (*desc).set_buf::<H>(NonNull::new(*input as *mut [u8]).unwrap());
                 (*desc).flags = DescFlags::NEXT;
                 (*desc).flags = DescFlags::NEXT;
                 last = self.free_head;
                 last = self.free_head;
                 self.free_head = (*desc).next;
                 self.free_head = (*desc).next;
             }
             }
             for output in outputs.iter() {
             for output in outputs.iter() {
                 let desc = self.desc_ptr(self.free_head);
                 let desc = self.desc_ptr(self.free_head);
-                (*desc).set_buf::<H>(output);
+                (*desc).set_buf::<H>(NonNull::new(*output).unwrap());
                 (*desc).flags = DescFlags::NEXT | DescFlags::WRITE;
                 (*desc).flags = DescFlags::NEXT | DescFlags::WRITE;
                 last = self.free_head;
                 last = self.free_head;
                 self.free_head = (*desc).next;
                 self.free_head = (*desc).next;
@@ -150,6 +155,32 @@ impl<H: Hal> VirtQueue<H> {
         Ok(head)
         Ok(head)
     }
     }
 
 
+    /// Add the given buffers to the virtqueue, notifies the device, blocks until the device uses
+    /// them, then pops them.
+    ///
+    /// This assumes that the device isn't processing any other buffers at the same time.
+    pub fn add_notify_wait_pop(
+        &mut self,
+        inputs: &[*const [u8]],
+        outputs: &[*mut [u8]],
+        transport: &mut impl Transport,
+    ) -> Result<u32> {
+        // Safe because we don't return until the same token has been popped, so they remain valid
+        // until then.
+        let token = unsafe { self.add(inputs, outputs) }?;
+
+        // Notify the queue.
+        transport.notify(self.queue_idx);
+
+        while !self.can_pop() {
+            spin_loop();
+        }
+        let (popped_token, length) = self.pop_used()?;
+        assert_eq!(popped_token, token);
+
+        Ok(length)
+    }
+
     /// Returns a non-null pointer to the descriptor at the given index.
     /// Returns a non-null pointer to the descriptor at the given index.
     fn desc_ptr(&mut self, index: u16) -> *mut Descriptor {
     fn desc_ptr(&mut self, index: u16) -> *mut Descriptor {
         // Safe because self.desc is properly aligned and dereferenceable.
         // Safe because self.desc is properly aligned and dereferenceable.
@@ -263,8 +294,11 @@ pub(crate) struct Descriptor {
 }
 }
 
 
 impl Descriptor {
 impl Descriptor {
-    fn set_buf<H: Hal>(&mut self, buf: &[u8]) {
-        self.addr = H::virt_to_phys(buf.as_ptr() as usize) as u64;
+    /// # Safety
+    ///
+    /// The caller must ensure that the buffer lives at least as long as the descriptor is active.
+    unsafe fn set_buf<H: Hal>(&mut self, buf: NonNull<[u8]>) {
+        self.addr = H::virt_to_phys(buf.as_ptr() as *mut u8 as usize) as u64;
         self.len = buf.len() as u32;
         self.len = buf.len() as u32;
     }
     }
 }
 }
@@ -408,7 +442,10 @@ mod tests {
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         let mut transport = unsafe { MmioTransport::new(NonNull::from(&mut header)) }.unwrap();
         let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
         let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
-        assert_eq!(queue.add(&[], &[]).unwrap_err(), Error::InvalidParam);
+        assert_eq!(
+            unsafe { queue.add(&[], &[]) }.unwrap_err(),
+            Error::InvalidParam
+        );
     }
     }
 
 
     #[test]
     #[test]
@@ -418,9 +455,7 @@ mod tests {
         let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
         let mut queue = VirtQueue::<FakeHal>::new(&mut transport, 0, 4).unwrap();
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(queue.available_desc(), 4);
         assert_eq!(
         assert_eq!(
-            queue
-                .add(&[&[], &[], &[]], &[&mut [], &mut []])
-                .unwrap_err(),
+            unsafe { queue.add(&[&[], &[], &[]], &[&mut [], &mut []]) }.unwrap_err(),
             Error::BufferTooSmall
             Error::BufferTooSmall
         );
         );
     }
     }
@@ -435,9 +470,7 @@ mod tests {
 
 
         // Add a buffer chain consisting of two device-readable parts followed by two
         // Add a buffer chain consisting of two device-readable parts followed by two
         // device-writable parts.
         // device-writable parts.
-        let token = queue
-            .add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]])
-            .unwrap();
+        let token = unsafe { queue.add(&[&[1, 2], &[3]], &[&mut [0, 0], &mut [0]]) }.unwrap();
 
 
         assert_eq!(queue.available_desc(), 0);
         assert_eq!(queue.available_desc(), 0);
         assert!(!queue.can_pop());
         assert!(!queue.can_pop());