|
@@ -7,11 +7,11 @@ use crate::hal::Hal;
|
|
|
use crate::queue::VirtQueue;
|
|
|
use crate::transport::Transport;
|
|
|
use crate::volatile::volread;
|
|
|
-use crate::{Result, PAGE_SIZE};
|
|
|
+use crate::Result;
|
|
|
use alloc::boxed::Box;
|
|
|
use core::hint::spin_loop;
|
|
|
use core::mem::size_of;
|
|
|
-use core::ptr::NonNull;
|
|
|
+use core::ptr::{null_mut, NonNull};
|
|
|
use log::{debug, info};
|
|
|
use zerocopy::{AsBytes, FromBytes};
|
|
|
|
|
@@ -21,6 +21,9 @@ const EVENT_QUEUE_IDX: u16 = 2;
|
|
|
|
|
|
const QUEUE_SIZE: usize = 8;
|
|
|
|
|
|
+/// The size in bytes of each buffer used in the RX virtqueue.
|
|
|
+const RX_BUFFER_SIZE: usize = 512;
|
|
|
+
|
|
|
#[derive(Clone, Debug, Default, PartialEq, Eq)]
|
|
|
struct ConnectionInfo {
|
|
|
dst: VsockAddr,
|
|
@@ -108,7 +111,7 @@ pub struct VirtIOSocket<H: Hal, T: Transport> {
|
|
|
/// The guest_cid field contains the guest’s context ID, which uniquely identifies
|
|
|
/// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
|
|
|
guest_cid: u64,
|
|
|
- rx_buf: NonNull<[u8; PAGE_SIZE]>,
|
|
|
+ rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
|
|
|
|
|
|
/// Currently the device is only allowed to be connected to one destination at a time.
|
|
|
connection_info: Option<ConnectionInfo>,
|
|
@@ -122,9 +125,11 @@ impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
|
|
|
self.transport.queue_unset(TX_QUEUE_IDX);
|
|
|
self.transport.queue_unset(EVENT_QUEUE_IDX);
|
|
|
|
|
|
- // Safe because we obtained the rx_buf pointer from Box::into_raw, and it won't be used
|
|
|
- // anywhere else after the driver is destroyed.
|
|
|
- unsafe { drop(Box::from_raw(self.rx_buf.as_ptr())) };
|
|
|
+ for buffer in self.rx_queue_buffers {
|
|
|
+ // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
|
|
|
+ // used anywhere else after the driver is destroyed.
|
|
|
+ unsafe { drop(Box::from_raw(buffer.as_ptr())) };
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -151,14 +156,22 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
|
|
|
let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
|
|
|
let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
|
|
|
|
|
|
- // Allocates 4 KiB memory for the RX buffer.
|
|
|
- let rx_buf: NonNull<[u8; PAGE_SIZE]> =
|
|
|
- NonNull::new(Box::into_raw(FromBytes::new_box_zeroed())).unwrap();
|
|
|
- // Safe because `rx_buf` lives as long as the `rx` queue.
|
|
|
- unsafe {
|
|
|
- Self::fill_rx_queue(&mut rx, rx_buf, &mut transport)?;
|
|
|
+ // Allocate and add buffers for the RX queue.
|
|
|
+ let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
|
|
|
+ for i in 0..QUEUE_SIZE {
|
|
|
+ let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed();
|
|
|
+ // Safe because the buffer lives as long as the queue, as specified in the function
|
|
|
+ // safety requirement, and we don't access it until it is popped.
|
|
|
+ let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
|
|
|
+ assert_eq!(i, token.into());
|
|
|
+ rx_queue_buffers[i] = Box::into_raw(buffer);
|
|
|
}
|
|
|
+ let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
|
|
|
+
|
|
|
transport.finish_init();
|
|
|
+ if rx.should_notify() {
|
|
|
+ transport.notify(RX_QUEUE_IDX);
|
|
|
+ }
|
|
|
|
|
|
Ok(Self {
|
|
|
transport,
|
|
@@ -166,39 +179,14 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
|
|
|
tx,
|
|
|
event,
|
|
|
guest_cid,
|
|
|
- rx_buf,
|
|
|
+ rx_queue_buffers,
|
|
|
connection_info: None,
|
|
|
})
|
|
|
}
|
|
|
|
|
|
- /// Fills the `rx` queue with the buffer `rx_buf`.
|
|
|
- ///
|
|
|
- /// # Safety
|
|
|
- ///
|
|
|
- /// `rx_buf` must live at least as long as the `rx` queue, and the parts of the buffer which are
|
|
|
- /// in the queue must not be used anywhere else at the same time.
|
|
|
- unsafe fn fill_rx_queue(
|
|
|
- rx: &mut VirtQueue<H, { QUEUE_SIZE }>,
|
|
|
- rx_buf: NonNull<[u8]>,
|
|
|
- transport: &mut T,
|
|
|
- ) -> Result {
|
|
|
- if rx_buf.len() < size_of::<VirtioVsockHdr>() * QUEUE_SIZE {
|
|
|
- return Err(SocketError::BufferTooShort.into());
|
|
|
- }
|
|
|
- for i in 0..QUEUE_SIZE {
|
|
|
- // Safe because the buffer lives as long as the queue, as specified in the function
|
|
|
- // safety requirement, and we don't access it until it is popped.
|
|
|
- unsafe {
|
|
|
- let buffer = Self::as_mut_sub_rx_buffer(rx_buf, i)?;
|
|
|
- let token = rx.add(&[], &mut [buffer])?;
|
|
|
- assert_eq!(i, token.into());
|
|
|
- }
|
|
|
- }
|
|
|
-
|
|
|
- if rx.should_notify() {
|
|
|
- transport.notify(RX_QUEUE_IDX);
|
|
|
- }
|
|
|
- Ok(())
|
|
|
+ /// Returns the CID which has been assigned to this guest.
|
|
|
+ pub fn guest_cid(&self) -> u64 {
|
|
|
+ self.guest_cid
|
|
|
}
|
|
|
|
|
|
/// Sends a request to connect to the given destination.
|
|
@@ -206,15 +194,12 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
|
|
|
/// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
|
|
|
/// `VsockEventType::Connected` event indicating that the peer has accepted the connection
|
|
|
/// before sending data.
|
|
|
- pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
|
|
|
+ pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
|
|
|
if self.connection_info.is_some() {
|
|
|
return Err(SocketError::ConnectionExists.into());
|
|
|
}
|
|
|
let new_connection_info = ConnectionInfo {
|
|
|
- dst: VsockAddr {
|
|
|
- cid: dst_cid,
|
|
|
- port: dst_port,
|
|
|
- },
|
|
|
+ dst: destination,
|
|
|
src_port,
|
|
|
..Default::default()
|
|
|
};
|
|
@@ -483,7 +468,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
|
|
|
// buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
|
|
|
// buffer back to the RX queue then we don't access it again until next time it is popped.
|
|
|
let header = unsafe {
|
|
|
- let buffer = Self::as_mut_sub_rx_buffer(self.rx_buf, token.into())?;
|
|
|
+ let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
|
|
|
let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
|
|
|
|
|
|
// Read the header and body from the buffer. Don't check the result yet, because we need
|
|
@@ -508,36 +493,6 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
|
|
|
.clone()
|
|
|
.ok_or(SocketError::NotConnected.into())
|
|
|
}
|
|
|
-
|
|
|
- /// Gets a reference to a subslice of the RX buffer to be used for the given entry in the RX
|
|
|
- /// virtqueue.
|
|
|
- ///
|
|
|
- /// # Safety
|
|
|
- ///
|
|
|
- /// `rx_buf` must be a valid dereferenceable pointer.
|
|
|
- /// The returned reference has an arbitrary lifetime `'a`. This lifetime must not overlap with
|
|
|
- /// any other references to the same subslice of the RX buffer or outlive the buffer.
|
|
|
- unsafe fn as_mut_sub_rx_buffer<'a>(
|
|
|
- mut rx_buf: NonNull<[u8]>,
|
|
|
- i: usize,
|
|
|
- ) -> Result<&'a mut [u8]> {
|
|
|
- let buffer_size = rx_buf.len() / QUEUE_SIZE;
|
|
|
- let start = buffer_size
|
|
|
- .checked_mul(i)
|
|
|
- .ok_or(SocketError::InvalidNumber)?;
|
|
|
- let end = start
|
|
|
- .checked_add(buffer_size)
|
|
|
- .ok_or(SocketError::InvalidNumber)?;
|
|
|
- // Safe because no alignment or initialisation is required for [u8], and our caller assures
|
|
|
- // us that `rx_buf` is dereferenceable and that the lifetime of the slice we are creating
|
|
|
- // won't overlap with any other references to the same slice or outlive it.
|
|
|
- unsafe {
|
|
|
- rx_buf
|
|
|
- .as_mut()
|
|
|
- .get_mut(start..end)
|
|
|
- .ok_or(SocketError::BufferTooShort.into())
|
|
|
- }
|
|
|
- }
|
|
|
}
|
|
|
|
|
|
fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
|
|
@@ -597,7 +552,7 @@ mod tests {
|
|
|
};
|
|
|
let socket =
|
|
|
VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
|
|
|
- assert_eq!(socket.guest_cid, 0x00_0000_0042);
|
|
|
+ assert_eq!(socket.guest_cid(), 0x00_0000_0042);
|
|
|
}
|
|
|
|
|
|
#[test]
|
|
@@ -606,6 +561,10 @@ mod tests {
|
|
|
let guest_cid = 66;
|
|
|
let host_port = 1234;
|
|
|
let guest_port = 4321;
|
|
|
+ let host_address = VsockAddr {
|
|
|
+ cid: host_cid,
|
|
|
+ port: host_port,
|
|
|
+ };
|
|
|
let hello_from_guest = "Hello from guest";
|
|
|
let hello_from_host = "Hello from host";
|
|
|
|
|
@@ -806,7 +765,7 @@ mod tests {
|
|
|
);
|
|
|
});
|
|
|
|
|
|
- socket.connect(host_cid, guest_port, host_port).unwrap();
|
|
|
+ socket.connect(host_address, guest_port).unwrap();
|
|
|
socket.wait_for_connect().unwrap();
|
|
|
socket.send(hello_from_guest.as_bytes()).unwrap();
|
|
|
let mut buffer = [0u8; 64];
|