123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447 |
- use super::{
- protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
- VsockEventType,
- };
- use crate::{transport::Transport, Hal, Result};
- use core::hint::spin_loop;
- use log::debug;
- /// A higher level interface for VirtIO socket (vsock) devices.
- ///
- /// This can only keep track of a single vsock connection. If you want to support multiple
- /// simultaneous connections, try [`VsockConnectionManager`](super::VsockConnectionManager).
- pub struct SingleConnectionManager<H: Hal, T: Transport> {
- driver: VirtIOSocket<H, T>,
- connection_info: Option<ConnectionInfo>,
- }
- impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
- /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
- pub fn new(driver: VirtIOSocket<H, T>) -> Self {
- Self {
- driver,
- connection_info: None,
- }
- }
- /// Returns the CID which has been assigned to this guest.
- pub fn guest_cid(&self) -> u64 {
- self.driver.guest_cid()
- }
- /// Sends a request to connect to the given destination.
- ///
- /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
- /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
- /// before sending data.
- pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
- if self.connection_info.is_some() {
- return Err(SocketError::ConnectionExists.into());
- }
- let new_connection_info = ConnectionInfo::new(destination, src_port);
- self.driver.connect(&new_connection_info)?;
- debug!("Connection requested: {:?}", new_connection_info);
- self.connection_info = Some(new_connection_info);
- Ok(())
- }
- /// Sends the buffer to the destination.
- pub fn send(&mut self, buffer: &[u8]) -> Result {
- let connection_info = self
- .connection_info
- .as_mut()
- .ok_or(SocketError::NotConnected)?;
- connection_info.buf_alloc = 0;
- self.driver.send(buffer, connection_info)
- }
- /// Polls the vsock device to receive data or other updates.
- ///
- /// A buffer must be provided to put the data in if there is some to
- /// receive.
- pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
- let Some(connection_info) = &mut self.connection_info else {
- return Err(SocketError::NotConnected.into());
- };
- // Tell the peer that we have space to receive some data.
- connection_info.buf_alloc = buffer.len() as u32;
- self.driver.credit_update(connection_info)?;
- self.poll_rx_queue(buffer)
- }
- /// Blocks until we get some event from the vsock device.
- ///
- /// A buffer must be provided to put the data in if there is some to
- /// receive.
- pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
- loop {
- if let Some(event) = self.poll_recv(buffer)? {
- return Ok(event);
- } else {
- spin_loop();
- }
- }
- }
- fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
- let guest_cid = self.driver.guest_cid();
- let self_connection_info = &mut self.connection_info;
- self.driver.poll(|event, borrowed_body| {
- let Some(connection_info) = self_connection_info else {
- return Ok(None);
- };
- // Skip packets which don't match our current connection.
- if !event.matches_connection(connection_info, guest_cid) {
- debug!(
- "Skipping {:?} as connection is {:?}",
- event, connection_info
- );
- return Ok(None);
- }
- // Update stored connection info.
- connection_info.update_for_event(&event);
- match event.event_type {
- VsockEventType::ConnectionRequest => {
- // TODO: Send Rst or handle incoming connections.
- }
- VsockEventType::Connected => {}
- VsockEventType::Disconnected { .. } => {
- *self_connection_info = None;
- }
- VsockEventType::Received { length } => {
- body.get_mut(0..length)
- .ok_or_else(|| SocketError::OutputBufferTooShort(length))?
- .copy_from_slice(borrowed_body);
- connection_info.done_forwarding(length);
- }
- VsockEventType::CreditRequest => {
- // No point sending a credit update until `poll_recv` is called with a buffer,
- // as otherwise buf_alloc would just be 0 anyway.
- }
- VsockEventType::CreditUpdate => {}
- }
- Ok(Some(event))
- })
- }
- /// Requests to shut down the connection cleanly.
- ///
- /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
- /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
- /// shutdown.
- pub fn shutdown(&mut self) -> Result {
- let connection_info = self
- .connection_info
- .as_mut()
- .ok_or(SocketError::NotConnected)?;
- connection_info.buf_alloc = 0;
- self.driver.shutdown(connection_info)
- }
- /// Forcibly closes the connection without waiting for the peer.
- pub fn force_close(&mut self) -> Result {
- let connection_info = self
- .connection_info
- .as_mut()
- .ok_or(SocketError::NotConnected)?;
- connection_info.buf_alloc = 0;
- self.driver.force_close(connection_info)?;
- self.connection_info = None;
- Ok(())
- }
- /// Blocks until the peer either accepts our connection request (with a
- /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
- /// `VIRTIO_VSOCK_OP_RST`).
- pub fn wait_for_connect(&mut self) -> Result {
- loop {
- match self.wait_for_recv(&mut [])?.event_type {
- VsockEventType::Connected => return Ok(()),
- VsockEventType::Disconnected { .. } => {
- return Err(SocketError::ConnectionFailed.into())
- }
- VsockEventType::Received { .. } => return Err(SocketError::InvalidOperation.into()),
- VsockEventType::ConnectionRequest
- | VsockEventType::CreditRequest
- | VsockEventType::CreditUpdate => {}
- }
- }
- }
- }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        device::socket::{
            protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
        },
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceStatus, DeviceType,
        },
        volatile::ReadOnly,
    };
    use alloc::{sync::Arc, vec};
    use core::{mem::size_of, ptr::NonNull};
    use std::{sync::Mutex, thread};
    use zerocopy::{AsBytes, FromBytes};

    /// Builds a stream-socket header with no flags set.
    ///
    /// `src` and `dst` are `(cid, port)` pairs.
    fn stream_header(
        op: VirtioVsockOp,
        src: (u64, u32),
        dst: (u64, u32),
        len: u32,
        buf_alloc: u32,
        fwd_cnt: u32,
    ) -> VirtioVsockHdr {
        VirtioVsockHdr {
            op: op.into(),
            src_cid: src.0.into(),
            dst_cid: dst.0.into(),
            src_port: src.1.into(),
            dst_port: dst.1.into(),
            len: len.into(),
            socket_type: SocketType::Stream.into(),
            flags: 0.into(),
            buf_alloc: buf_alloc.into(),
            fwd_cnt: fwd_cnt.into(),
        }
    }

    /// Runs a full connect / send / receive / shutdown exchange against a fake device which is
    /// simulated on a separate thread.
    #[test]
    fn send_recv() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };
        let hello_from_guest = "Hello from guest";
        let hello_from_host = "Hello from host";

        // `(cid, port)` pairs and payload lengths in the form the header helper wants.
        let guest = (guest_cid, guest_port);
        let host = (host_cid, host_port);
        let guest_message_len = hello_from_guest.len() as u32;
        let host_message_len = hello_from_host.len() as u32;

        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        let state = Arc::new(Mutex::new(State {
            status: DeviceStatus::empty(),
            driver_features: 0,
            guest_page_size: 0,
            interrupt_pending: false,
            queues: vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: Arc::clone(&state),
        };
        let driver =
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
        let mut socket = SingleConnectionManager::new(driver);

        // Start a thread to simulate the device.
        let device = thread::spawn(move || {
            // Wait for connection request.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                stream_header(VirtioVsockOp::Request, guest, host, 0, 0, 0)
            );

            // Accept connection and give the peer enough credit to send the message.
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                stream_header(VirtioVsockOp::Response, host, guest, 0, 50, 0).as_bytes(),
            );

            // Expect a credit update.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                stream_header(VirtioVsockOp::CreditUpdate, guest, host, 0, 0, 0)
            );

            // Expect the guest to send some data.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            let request = state
                .lock()
                .unwrap()
                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
            assert_eq!(
                request.len(),
                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
            );
            assert_eq!(
                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
                stream_header(VirtioVsockOp::Rw, guest, host, guest_message_len, 0, 0)
            );
            assert_eq!(
                &request[size_of::<VirtioVsockHdr>()..],
                hello_from_guest.as_bytes()
            );

            // Send a response: the header bytes immediately followed by the payload.
            let mut response = stream_header(
                VirtioVsockOp::Rw,
                host,
                guest,
                host_message_len,
                50,
                guest_message_len,
            )
            .as_bytes()
            .to_vec();
            response.extend_from_slice(hello_from_host.as_bytes());
            state
                .lock()
                .unwrap()
                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

            // Expect a credit update.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                stream_header(VirtioVsockOp::CreditUpdate, guest, host, 0, 64, 0)
            );

            // Expect a shutdown.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                stream_header(
                    VirtioVsockOp::Shutdown,
                    guest,
                    host,
                    0,
                    0,
                    host_message_len
                )
            );
        });

        // Guest side: connect, send a greeting, then wait for the host's reply.
        socket.connect(host_address, guest_port).unwrap();
        socket.wait_for_connect().unwrap();
        socket.send(hello_from_guest.as_bytes()).unwrap();

        let mut buffer = [0u8; 64];
        let event = socket.wait_for_recv(&mut buffer).unwrap();
        let expected_event = VsockEvent {
            source: VsockAddr {
                cid: host_cid,
                port: host_port,
            },
            destination: VsockAddr {
                cid: guest_cid,
                port: guest_port,
            },
            event_type: VsockEventType::Received {
                length: hello_from_host.len(),
            },
            buffer_status: VsockBufferStatus {
                buffer_allocation: 50,
                forward_count: hello_from_guest.len() as u32,
            },
        };
        assert_eq!(event, expected_event);
        assert_eq!(
            &buffer[0..hello_from_host.len()],
            hello_from_host.as_bytes()
        );

        socket.shutdown().unwrap();

        device.join().unwrap();
    }
}
|