Forráskód Böngészése

Pass closure to handle packet.

This avoids the need to know which buffer to copy the body into before
seeing the packet header, which will be necessary for handling multiple
connections.

It changes the behaviour of poll_rx_queue (and hence poll_recv) to
immediately return Ok(None) if the first packet in the RX queue is not
relevant, rather than continuing to look for others. This may mean that
poll_recv needs to be called more times than before.
Andrew Walbran 1 éve
szülő
commit
34a1a7520a
2 módosított fájl, 64 hozzáadás és 52 törlés
  1. 13 11
      src/device/socket/singleconnectionmanager.rs
  2. 51 41
      src/device/socket/vsock.rs

+ 13 - 11
src/device/socket/singleconnectionmanager.rs

@@ -87,22 +87,21 @@ impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
     }
 
     fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
-        loop {
-            let Some(event) = self.driver.poll_recv(body)? else {
-                return Ok(None)
-            };
+        let guest_cid = self.driver.guest_cid();
+        let self_connection_info = &mut self.connection_info;
 
-            let Some(connection_info) = &mut self.connection_info else {
-                continue;
+        self.driver.poll_recv(|event, borrowed_body| {
+            let Some(connection_info) = self_connection_info else {
+                return Ok(None);
             };
 
             // Skip packets which don't match our current connection.
-            if !event.matches_connection(connection_info, self.driver.guest_cid()) {
+            if !event.matches_connection(connection_info, guest_cid) {
                 debug!(
                     "Skipping {:?} as connection is {:?}",
                     event, connection_info
                 );
-                continue;
+                return Ok(None);
             }
 
             // Update stored connection info.
@@ -111,9 +110,12 @@ impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
             match event.event_type {
                 VsockEventType::Connected => {}
                 VsockEventType::Disconnected { .. } => {
-                    self.connection_info = None;
+                    *self_connection_info = None;
                 }
                 VsockEventType::Received { length } => {
+                    body.get_mut(0..length)
+                        .ok_or_else(|| SocketError::OutputBufferTooShort(length))?
+                        .copy_from_slice(borrowed_body);
                     connection_info.done_forwarding(length);
                 }
                 VsockEventType::CreditRequest => {
@@ -122,8 +124,8 @@ impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
                 VsockEventType::CreditUpdate => {}
             }
 
-            return Ok(Some(event));
-        }
+            Ok(Some(event))
+        })
     }
 
     /// Requests to shut down the connection cleanly.

+ 51 - 41
src/device/socket/vsock.rs

@@ -376,19 +376,27 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         self.send_packet_to_tx_queue(&header, &[])
     }
 
-    /// Polls the vsock device to receive data or other updates.
-    ///
-    /// A buffer must be provided to put the data in if there is some to
-    /// receive.
-    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
-        // Handle entries from the RX virtqueue until we find one that generates an event.
-        let event = self.poll_rx_queue(buffer)?;
+    /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
+    /// it.
+    pub fn poll_recv(
+        &mut self,
+        handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
+    ) -> Result<Option<VsockEvent>> {
+        let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else {
+            return Ok(None);
+        };
 
-        if self.rx.should_notify() {
-            self.transport.notify(RX_QUEUE_IDX);
+        let result = match VsockEvent::from_header(&header) {
+            Ok(Some(event)) => handler(event, body),
+            other => other,
+        };
+
+        unsafe {
+            // TODO: What about if both handler and this give errors?
+            self.add_buffer_to_rx_queue(token)?;
         }
 
-        Ok(event)
+        result
     }
 
     /// Requests to shut down the connection cleanly.
@@ -423,29 +431,35 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(())
     }
 
-    /// Polls the RX virtqueue for the next event.
+    /// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue.
     ///
-    /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
-    /// don't result in any events to return.
-    fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
-        let Some(header) = self.pop_packet_from_rx_queue(body)? else {
-            return Ok(None);
-        };
+    /// # Safety
+    ///
+    /// The buffer must not currently be in the RX queue, and no other references to it must exist
+    /// between when this method is called and when it is popped from the queue.
+    unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result {
+        // Safe because the buffer lives as long as the queue, and the caller guarantees that it's
+        // not currently in the queue or referred to anywhere else until it is popped.
+        unsafe {
+            let buffer = self.rx_queue_buffers[usize::from(index)].as_mut();
+            let new_token = self.rx.add(&[], &mut [buffer])?;
+            // If the RX buffer somehow gets assigned a different token, then our safety assumptions
+            // are broken and we can't safely continue to do anything with the device.
+            assert_eq!(new_token, index);
+        }
 
-        // TODO: Add buffer back immediately on error or None
-        if let Some(event) = VsockEvent::from_header(&header)? {
-            Ok(Some(event))
-        } else {
-            Ok(None)
+        if self.rx.should_notify() {
+            self.transport.notify(RX_QUEUE_IDX);
         }
+
+        Ok(())
     }
 
-    /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
-    /// the body into the given buffer.
+    /// Pops one packet from the RX queue, if there is one pending. Returns the header, and a
+    /// reference to the buffer containing the body.
     ///
-    /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
-    /// buffer supplied.
-    fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
+    /// Returns `None` if there is no pending packet.
+    fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> {
         let Some(token) = self.rx.peek_used() else {
             return Ok(None);
         };
@@ -453,29 +467,28 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
         // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
         // buffer back to the RX queue then we don't access it again until next time it is popped.
-        let header = unsafe {
+        let (header, body) = unsafe {
             let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
             let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
 
             // Read the header and body from the buffer. Don't check the result yet, because we need
             // to add the buffer back to the queue either way.
-            let header_result = read_header_and_body(buffer, body);
-
-            // Add the buffer back to the RX queue.
-            let new_token = self.rx.add(&[], &mut [buffer])?;
-            // If the RX buffer somehow gets assigned a different token, then our safety assumptions
-            // are broken and we can't safely continue to do anything with the device.
-            assert_eq!(new_token, token);
+            let header_result = read_header_and_body(buffer);
+            if let Err(_) = header_result {
+                // If there was an error, add the buffer back immediately. Ignore any errors, as we
+                // need to return the first error.
+                let _ = self.add_buffer_to_rx_queue(token);
+            }
 
             header_result
         }?;
 
         debug!("Received packet {:?}. Op {:?}", header, header.op());
-        Ok(Some(header))
+        Ok(Some((header, body, token)))
     }
 }
 
-fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
+fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
     let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
     let body_length = header.len() as usize;
     let data_end = size_of::<VirtioVsockHdr>()
@@ -484,10 +497,7 @@ fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr
     let data = buffer
         .get(size_of::<VirtioVsockHdr>()..data_end)
         .ok_or(SocketError::BufferTooShort)?;
-    body.get_mut(0..body_length)
-        .ok_or(SocketError::OutputBufferTooShort(body_length))?
-        .copy_from_slice(data);
-    Ok(header)
+    Ok((header, data))
 }
 
 #[cfg(test)]