use super::{ protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent, VsockEventType, }; use crate::{transport::Transport, Hal, Result}; use core::hint::spin_loop; use log::debug; /// A higher level interface for VirtIO socket (vsock) devices. /// /// This can only keep track of a single vsock connection. If you want to support multiple /// simultaneous connections, try [`VsockConnectionManager`](super::VsockConnectionManager). pub struct SingleConnectionManager { driver: VirtIOSocket, connection_info: Option, } impl SingleConnectionManager { /// Construct a new connection manager wrapping the given low-level VirtIO socket driver. pub fn new(driver: VirtIOSocket) -> Self { Self { driver, connection_info: None, } } /// Returns the CID which has been assigned to this guest. pub fn guest_cid(&self) -> u64 { self.driver.guest_cid() } /// Sends a request to connect to the given destination. /// /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a /// `VsockEventType::Connected` event indicating that the peer has accepted the connection /// before sending data. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result { if self.connection_info.is_some() { return Err(SocketError::ConnectionExists.into()); } let new_connection_info = ConnectionInfo::new(destination, src_port); self.driver.connect(&new_connection_info)?; debug!("Connection requested: {:?}", new_connection_info); self.connection_info = Some(new_connection_info); Ok(()) } /// Sends the buffer to the destination. pub fn send(&mut self, buffer: &[u8]) -> Result { let connection_info = self .connection_info .as_mut() .ok_or(SocketError::NotConnected)?; connection_info.buf_alloc = 0; self.driver.send(buffer, connection_info) } /// Polls the vsock device to receive data or other updates. /// /// A buffer must be provided to put the data in if there is some to /// receive. pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result> { let Some(connection_info) = &mut self.connection_info else { return Err(SocketError::NotConnected.into()); }; // Tell the peer that we have space to receive some data. connection_info.buf_alloc = buffer.len() as u32; self.driver.credit_update(connection_info)?; self.poll_rx_queue(buffer) } /// Blocks until we get some event from the vsock device. /// /// A buffer must be provided to put the data in if there is some to /// receive. pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result { loop { if let Some(event) = self.poll_recv(buffer)? { return Ok(event); } else { spin_loop(); } } } fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result> { let guest_cid = self.driver.guest_cid(); let self_connection_info = &mut self.connection_info; self.driver.poll(|event, borrowed_body| { let Some(connection_info) = self_connection_info else { return Ok(None); }; // Skip packets which don't match our current connection. if !event.matches_connection(connection_info, guest_cid) { debug!( "Skipping {:?} as connection is {:?}", event, connection_info ); return Ok(None); } // Update stored connection info. connection_info.update_for_event(&event); match event.event_type { VsockEventType::ConnectionRequest => { // TODO: Send Rst or handle incoming connections. } VsockEventType::Connected => {} VsockEventType::Disconnected { .. } => { *self_connection_info = None; } VsockEventType::Received { length } => { body.get_mut(0..length) .ok_or_else(|| SocketError::OutputBufferTooShort(length))? .copy_from_slice(borrowed_body); connection_info.done_forwarding(length); } VsockEventType::CreditRequest => { // No point sending a credit update until `poll_recv` is called with a buffer, // as otherwise buf_alloc would just be 0 anyway. } VsockEventType::CreditUpdate => {} } Ok(Some(event)) }) } /// Requests to shut down the connection cleanly. /// /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the /// shutdown. pub fn shutdown(&mut self) -> Result { let connection_info = self .connection_info .as_mut() .ok_or(SocketError::NotConnected)?; connection_info.buf_alloc = 0; self.driver.shutdown(connection_info) } /// Forcibly closes the connection without waiting for the peer. pub fn force_close(&mut self) -> Result { let connection_info = self .connection_info .as_mut() .ok_or(SocketError::NotConnected)?; connection_info.buf_alloc = 0; self.driver.force_close(connection_info)?; self.connection_info = None; Ok(()) } /// Blocks until the peer either accepts our connection request (with a /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a /// `VIRTIO_VSOCK_OP_RST`). pub fn wait_for_connect(&mut self) -> Result { loop { match self.wait_for_recv(&mut [])?.event_type { VsockEventType::Connected => return Ok(()), VsockEventType::Disconnected { .. } => { return Err(SocketError::ConnectionFailed.into()) } VsockEventType::Received { .. } => return Err(SocketError::InvalidOperation.into()), VsockEventType::ConnectionRequest | VsockEventType::CreditRequest | VsockEventType::CreditUpdate => {} } } } } #[cfg(test)] mod tests { use super::*; use crate::{ device::socket::{ protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp}, vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX}, }, hal::fake::FakeHal, transport::{ fake::{FakeTransport, QueueStatus, State}, DeviceStatus, DeviceType, }, volatile::ReadOnly, }; use alloc::{sync::Arc, vec}; use core::{mem::size_of, ptr::NonNull}; use std::{sync::Mutex, thread}; use zerocopy::{AsBytes, FromBytes}; #[test] fn send_recv() { let host_cid = 2; let guest_cid = 66; let host_port = 1234; let guest_port = 4321; let host_address = VsockAddr { cid: host_cid, port: host_port, }; let hello_from_guest = "Hello from guest"; let hello_from_host = "Hello from host"; let mut config_space = VirtioVsockConfig { guest_cid_low: ReadOnly::new(66), guest_cid_high: ReadOnly::new(0), }; let state = Arc::new(Mutex::new(State { status: DeviceStatus::empty(), driver_features: 0, guest_page_size: 0, interrupt_pending: false, queues: vec![ QueueStatus::default(), QueueStatus::default(), QueueStatus::default(), ], })); let transport = FakeTransport { device_type: DeviceType::Socket, max_queue_size: 32, device_features: 0, config_space: NonNull::from(&mut config_space), state: state.clone(), }; let mut socket = SingleConnectionManager::new( VirtIOSocket::>::new(transport).unwrap(), ); // Start a thread to simulate the device. let handle = thread::spawn(move || { // Wait for connection request. State::wait_until_queue_notified(&state, TX_QUEUE_IDX); assert_eq!( VirtioVsockHdr::read_from( state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX) .as_slice() ) .unwrap(), VirtioVsockHdr { op: VirtioVsockOp::Request.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 0.into(), fwd_cnt: 0.into(), } ); // Accept connection and give the peer enough credit to send the message. state.lock().unwrap().write_to_queue::( RX_QUEUE_IDX, VirtioVsockHdr { op: VirtioVsockOp::Response.into(), src_cid: host_cid.into(), dst_cid: guest_cid.into(), src_port: host_port.into(), dst_port: guest_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 50.into(), fwd_cnt: 0.into(), } .as_bytes(), ); // Expect a credit update. State::wait_until_queue_notified(&state, TX_QUEUE_IDX); assert_eq!( VirtioVsockHdr::read_from( state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX) .as_slice() ) .unwrap(), VirtioVsockHdr { op: VirtioVsockOp::CreditUpdate.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 0.into(), fwd_cnt: 0.into(), } ); // Expect the guest to send some data. State::wait_until_queue_notified(&state, TX_QUEUE_IDX); let request = state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX); assert_eq!( request.len(), size_of::() + hello_from_guest.len() ); assert_eq!( VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(), VirtioVsockHdr { op: VirtioVsockOp::Rw.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: (hello_from_guest.len() as u32).into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 0.into(), fwd_cnt: 0.into(), } ); assert_eq!( &request[size_of::()..], hello_from_guest.as_bytes() ); // Send a response. let mut response = vec![0; size_of::() + hello_from_host.len()]; VirtioVsockHdr { op: VirtioVsockOp::Rw.into(), src_cid: host_cid.into(), dst_cid: guest_cid.into(), src_port: host_port.into(), dst_port: guest_port.into(), len: (hello_from_host.len() as u32).into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 50.into(), fwd_cnt: (hello_from_guest.len() as u32).into(), } .write_to_prefix(response.as_mut_slice()); response[size_of::()..].copy_from_slice(hello_from_host.as_bytes()); state .lock() .unwrap() .write_to_queue::(RX_QUEUE_IDX, &response); // Expect a credit update. State::wait_until_queue_notified(&state, TX_QUEUE_IDX); assert_eq!( VirtioVsockHdr::read_from( state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX) .as_slice() ) .unwrap(), VirtioVsockHdr { op: VirtioVsockOp::CreditUpdate.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 64.into(), fwd_cnt: 0.into(), } ); // Expect a shutdown. State::wait_until_queue_notified(&state, TX_QUEUE_IDX); assert_eq!( VirtioVsockHdr::read_from( state .lock() .unwrap() .read_from_queue::(TX_QUEUE_IDX) .as_slice() ) .unwrap(), VirtioVsockHdr { op: VirtioVsockOp::Shutdown.into(), src_cid: guest_cid.into(), dst_cid: host_cid.into(), src_port: guest_port.into(), dst_port: host_port.into(), len: 0.into(), socket_type: SocketType::Stream.into(), flags: 0.into(), buf_alloc: 0.into(), fwd_cnt: (hello_from_host.len() as u32).into(), } ); }); socket.connect(host_address, guest_port).unwrap(); socket.wait_for_connect().unwrap(); socket.send(hello_from_guest.as_bytes()).unwrap(); let mut buffer = [0u8; 64]; let event = socket.wait_for_recv(&mut buffer).unwrap(); assert_eq!( event, VsockEvent { source: VsockAddr { cid: host_cid, port: host_port, }, destination: VsockAddr { cid: guest_cid, port: guest_port, }, event_type: VsockEventType::Received { length: hello_from_host.len() }, buffer_status: VsockBufferStatus { buffer_allocation: 50, forward_count: hello_from_guest.len() as u32, }, } ); assert_eq!( &buffer[0..hello_from_host.len()], hello_from_host.as_bytes() ); socket.shutdown().unwrap(); handle.join().unwrap(); } }