|
@@ -376,19 +376,27 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
|
|
self.send_packet_to_tx_queue(&header, &[])
|
|
self.send_packet_to_tx_queue(&header, &[])
|
|
}
|
|
}
|
|
|
|
|
|
- /// Polls the vsock device to receive data or other updates.
|
|
|
|
- ///
|
|
|
|
- /// A buffer must be provided to put the data in if there is some to
|
|
|
|
- /// receive.
|
|
|
|
- pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
|
|
|
|
- // Handle entries from the RX virtqueue until we find one that generates an event.
|
|
|
|
- let event = self.poll_rx_queue(buffer)?;
|
|
|
|
|
|
+ /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
|
|
|
|
+ /// it.
|
|
|
|
+ pub fn poll_recv(
|
|
|
|
+ &mut self,
|
|
|
|
+ handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
|
|
|
|
+ ) -> Result<Option<VsockEvent>> {
|
|
|
|
+ let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else {
|
|
|
|
+ return Ok(None);
|
|
|
|
+ };
|
|
|
|
|
|
- if self.rx.should_notify() {
|
|
|
|
- self.transport.notify(RX_QUEUE_IDX);
|
|
|
|
|
|
+ let result = match VsockEvent::from_header(&header) {
|
|
|
|
+ Ok(Some(event)) => handler(event, body),
|
|
|
|
+ other => other,
|
|
|
|
+ };
|
|
|
|
+
|
|
|
|
+ unsafe {
|
|
|
|
+ // TODO: What about if both handler and this give errors?
|
|
|
|
+ self.add_buffer_to_rx_queue(token)?;
|
|
}
|
|
}
|
|
|
|
|
|
- Ok(event)
|
|
|
|
|
|
+ result
|
|
}
|
|
}
|
|
|
|
|
|
/// Requests to shut down the connection cleanly.
|
|
/// Requests to shut down the connection cleanly.
|
|
@@ -423,29 +431,35 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
|
|
Ok(())
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
- /// Polls the RX virtqueue for the next event.
|
|
|
|
|
|
+ /// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue.
|
|
///
|
|
///
|
|
- /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
|
|
|
|
- /// don't result in any events to return.
|
|
|
|
- fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
|
|
|
|
- let Some(header) = self.pop_packet_from_rx_queue(body)? else {
|
|
|
|
- return Ok(None);
|
|
|
|
- };
|
|
|
|
|
|
+ /// # Safety
|
|
|
|
+ ///
|
|
|
|
+ /// The buffer must not currently be in the RX queue, and no other references to it must exist
|
|
|
|
+ /// between when this method is called and when it is popped from the queue.
|
|
|
|
+ unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result {
|
|
|
|
+ // Safe because the buffer lives as long as the queue, and the caller guarantees that it's
|
|
|
|
+ // not currently in the queue or referred to anywhere else until it is popped.
|
|
|
|
+ unsafe {
|
|
|
|
+ let buffer = self.rx_queue_buffers[usize::from(index)].as_mut();
|
|
|
|
+ let new_token = self.rx.add(&[], &mut [buffer])?;
|
|
|
|
+ // If the RX buffer somehow gets assigned a different token, then our safety assumptions
|
|
|
|
+ // are broken and we can't safely continue to do anything with the device.
|
|
|
|
+ assert_eq!(new_token, index);
|
|
|
|
+ }
|
|
|
|
|
|
- // TODO: Add buffer back immediately on error or None
|
|
|
|
- if let Some(event) = VsockEvent::from_header(&header)? {
|
|
|
|
- Ok(Some(event))
|
|
|
|
- } else {
|
|
|
|
- Ok(None)
|
|
|
|
|
|
+ if self.rx.should_notify() {
|
|
|
|
+ self.transport.notify(RX_QUEUE_IDX);
|
|
}
|
|
}
|
|
|
|
+
|
|
|
|
+ Ok(())
|
|
}
|
|
}
|
|
|
|
|
|
- /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
|
|
|
|
- /// the body into the given buffer.
|
|
|
|
|
|
+ /// Pops one packet from the RX queue, if there is one pending. Returns the header, and a
|
|
|
|
+ /// reference to the buffer containing the body.
|
|
///
|
|
///
|
|
- /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
|
|
|
|
- /// buffer supplied.
|
|
|
|
- fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
|
|
|
|
|
|
+ /// Returns `None` if there is no pending packet.
|
|
|
|
+ fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> {
|
|
let Some(token) = self.rx.peek_used() else {
|
|
let Some(token) = self.rx.peek_used() else {
|
|
return Ok(None);
|
|
return Ok(None);
|
|
};
|
|
};
|
|
@@ -453,29 +467,28 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
|
|
// Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
|
|
// Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
|
|
// buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
|
|
// buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
|
|
// buffer back to the RX queue then we don't access it again until next time it is popped.
|
|
// buffer back to the RX queue then we don't access it again until next time it is popped.
|
|
- let header = unsafe {
|
|
|
|
|
|
+ let (header, body) = unsafe {
|
|
let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
|
|
let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
|
|
let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
|
|
let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
|
|
|
|
|
|
// Read the header and body from the buffer. Don't check the result yet, because we need
|
|
// Read the header and body from the buffer. Don't check the result yet, because we need
|
|
// to add the buffer back to the queue either way.
|
|
// to add the buffer back to the queue either way.
|
|
- let header_result = read_header_and_body(buffer, body);
|
|
|
|
-
|
|
|
|
- // Add the buffer back to the RX queue.
|
|
|
|
- let new_token = self.rx.add(&[], &mut [buffer])?;
|
|
|
|
- // If the RX buffer somehow gets assigned a different token, then our safety assumptions
|
|
|
|
- // are broken and we can't safely continue to do anything with the device.
|
|
|
|
- assert_eq!(new_token, token);
|
|
|
|
|
|
+ let header_result = read_header_and_body(buffer);
|
|
|
|
+ if let Err(_) = header_result {
|
|
|
|
+ // If there was an error, add the buffer back immediately. Ignore any errors, as we
|
|
|
|
+ // need to return the first error.
|
|
|
|
+ let _ = self.add_buffer_to_rx_queue(token);
|
|
|
|
+ }
|
|
|
|
|
|
header_result
|
|
header_result
|
|
}?;
|
|
}?;
|
|
|
|
|
|
debug!("Received packet {:?}. Op {:?}", header, header.op());
|
|
debug!("Received packet {:?}. Op {:?}", header, header.op());
|
|
- Ok(Some(header))
|
|
|
|
|
|
+ Ok(Some((header, body, token)))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
-fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
|
|
|
|
|
|
+fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
|
|
let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
|
|
let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
|
|
let body_length = header.len() as usize;
|
|
let body_length = header.len() as usize;
|
|
let data_end = size_of::<VirtioVsockHdr>()
|
|
let data_end = size_of::<VirtioVsockHdr>()
|
|
@@ -484,10 +497,7 @@ fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr
|
|
let data = buffer
|
|
let data = buffer
|
|
.get(size_of::<VirtioVsockHdr>()..data_end)
|
|
.get(size_of::<VirtioVsockHdr>()..data_end)
|
|
.ok_or(SocketError::BufferTooShort)?;
|
|
.ok_or(SocketError::BufferTooShort)?;
|
|
- body.get_mut(0..body_length)
|
|
|
|
- .ok_or(SocketError::OutputBufferTooShort(body_length))?
|
|
|
|
- .copy_from_slice(data);
|
|
|
|
- Ok(header)
|
|
|
|
|
|
+ Ok((header, data))
|
|
}
|
|
}
|
|
|
|
|
|
#[cfg(test)]
|
|
#[cfg(test)]
|