|
@@ -43,6 +43,7 @@ const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
|
|
|
pub struct VsockConnectionManager<H: Hal, T: Transport> {
|
|
|
driver: VirtIOSocket<H, T>,
|
|
|
connections: Vec<Connection>,
|
|
|
+ listening_ports: Vec<u32>,
|
|
|
}
|
|
|
|
|
|
#[derive(Debug)]
|
|
@@ -68,6 +69,7 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
|
Self {
|
|
|
driver,
|
|
|
connections: Vec::new(),
|
|
|
+ listening_ports: Vec::new(),
|
|
|
}
|
|
|
}
|
|
|
|
|
@@ -76,6 +78,18 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
|
self.driver.guest_cid()
|
|
|
}
|
|
|
|
|
|
+ /// Allows incoming connections on the given port number.
|
|
|
+ pub fn listen(&mut self, port: u32) {
|
|
|
+ if !self.listening_ports.contains(&port) {
|
|
|
+ self.listening_ports.push(port);
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
+ /// Stops allowing incoming connections on the given port number.
|
|
|
+ pub fn unlisten(&mut self, port: u32) {
|
|
|
+ self.listening_ports.retain(|p| *p != port);
|
|
|
+ }
|
|
|
+
|
|
|
/// Sends a request to connect to the given destination.
|
|
|
///
|
|
|
/// This returns as soon as the request is sent; you should wait until `poll` returns a
|
|
@@ -119,57 +133,77 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
|
.iter_mut()
|
|
|
.find(|connection| event.matches_connection(&connection.info, guest_cid));
|
|
|
|
|
|
- let Some(connection) = connection else {
|
|
|
- // Skip events which don't match any connection we know about.
|
|
|
+ // Skip events which don't match any connection we know about, unless they are a
|
|
|
+ // connection request.
|
|
|
+ let connection = if let Some(connection) = connection {
|
|
|
+ connection
|
|
|
+ } else if let VsockEventType::ConnectionRequest = event.event_type {
|
|
|
+ // If the requested connection already exists or the CID isn't ours, ignore it.
|
|
|
+ if connection.is_some() || event.destination.cid != guest_cid {
|
|
|
+ return Ok(None);
|
|
|
+ }
|
|
|
+ // Add the new connection to our list, at least for now. It will be removed again
|
|
|
+ // below if we weren't listening on the port.
|
|
|
+ connections.push(Connection::new(event.source, event.destination.port));
|
|
|
+ connections.last_mut().unwrap()
|
|
|
+ } else {
|
|
|
return Ok(None);
|
|
|
};
|
|
|
|
|
|
// Update stored connection info.
|
|
|
connection.info.update_for_event(&event);
|
|
|
|
|
|
- match event.event_type {
|
|
|
- VsockEventType::ConnectionRequest => {
|
|
|
- // TODO: Send Rst or handle incoming connections.
|
|
|
- }
|
|
|
- VsockEventType::Connected => {}
|
|
|
- VsockEventType::Disconnected { .. } => {
|
|
|
- // TODO: Wait until client reads all data before removing connection.
|
|
|
- //self.connection_info = None;
|
|
|
+ if let VsockEventType::Received { length } = event.event_type {
|
|
|
+ // Copy to buffer
|
|
|
+ if !connection.buffer.write(body) {
|
|
|
+ return Err(SocketError::OutputBufferTooShort(length).into());
|
|
|
}
|
|
|
- VsockEventType::Received { length } => {
|
|
|
- // Copy to buffer
|
|
|
- if !connection.buffer.write(body) {
|
|
|
- return Err(SocketError::OutputBufferTooShort(length).into());
|
|
|
- }
|
|
|
- }
|
|
|
- VsockEventType::CreditRequest => {}
|
|
|
- VsockEventType::CreditUpdate => {}
|
|
|
}
|
|
|
|
|
|
Ok(Some(event))
|
|
|
})?;
|
|
|
|
|
|
- // If the peer requested credit, send an update.
|
|
|
- if let Some(VsockEvent {
|
|
|
- source,
|
|
|
- destination,
|
|
|
- event_type: VsockEventType::CreditRequest,
|
|
|
- ..
|
|
|
- }) = result
|
|
|
- {
|
|
|
- let connection = self
|
|
|
- .connections
|
|
|
- .iter()
|
|
|
- .find(|connection| {
|
|
|
- connection.info.dst == source && connection.info.src_port == destination.port
|
|
|
- })
|
|
|
- .unwrap();
|
|
|
- self.driver.credit_update(&connection.info)?;
|
|
|
- // No need to pass the request on to the client, we've already handled it.
|
|
|
- Ok(None)
|
|
|
- } else {
|
|
|
- Ok(result)
|
|
|
+ let Some(event) = result else {
|
|
|
+ return Ok(None);
|
|
|
+ };
|
|
|
+
|
|
|
+ // The connection must exist because we found it above in the callback.
|
|
|
+ let (connection_index, connection) = connections
|
|
|
+ .iter_mut()
|
|
|
+ .enumerate()
|
|
|
+ .find(|(_, connection)| event.matches_connection(&connection.info, guest_cid))
|
|
|
+ .unwrap();
|
|
|
+
|
|
|
+ match event.event_type {
|
|
|
+ VsockEventType::ConnectionRequest => {
|
|
|
+ if self.listening_ports.contains(&event.destination.port) {
|
|
|
+ self.driver.accept(&connection.info)?;
|
|
|
+ } else {
|
|
|
+ // Reject the connection request and remove it from our list.
|
|
|
+ self.driver.force_close(&connection.info)?;
|
|
|
+ self.connections.swap_remove(connection_index);
|
|
|
+
|
|
|
+ // No need to pass the request on to the client, as we've already rejected it.
|
|
|
+ return Ok(None);
|
|
|
+ }
|
|
|
+ }
|
|
|
+ VsockEventType::Connected => {}
|
|
|
+ VsockEventType::Disconnected { .. } => {
|
|
|
+ // TODO: Wait until client reads all data before removing connection.
|
|
|
+ }
|
|
|
+ VsockEventType::Received { .. } => {
|
|
|
+ // Already copied the buffer in the callback above.
|
|
|
+ }
|
|
|
+ VsockEventType::CreditRequest => {
|
|
|
+ // If the peer requested credit, send an update.
|
|
|
+ self.driver.credit_update(&connection.info)?;
|
|
|
+ // No need to pass the request on to the client, we've already handled it.
|
|
|
+ return Ok(None);
|
|
|
+ }
|
|
|
+ VsockEventType::CreditUpdate => {}
|
|
|
}
|
|
|
+
|
|
|
+ Ok(Some(event))
|
|
|
}
|
|
|
|
|
|
/// Reads data received from the given connection.
|
|
@@ -543,4 +577,160 @@ mod tests {
|
|
|
|
|
|
handle.join().unwrap();
|
|
|
}
|
|
|
+
|
|
|
+ #[test]
|
|
|
+ fn incoming_connection() {
|
|
|
+ let host_cid = 2;
|
|
|
+ let guest_cid = 66;
|
|
|
+ let host_port = 1234;
|
|
|
+ let guest_port = 4321;
|
|
|
+ let wrong_guest_port = 4444;
|
|
|
+ let host_address = VsockAddr {
|
|
|
+ cid: host_cid,
|
|
|
+ port: host_port,
|
|
|
+ };
|
|
|
+
|
|
|
+ let mut config_space = VirtioVsockConfig {
|
|
|
+ guest_cid_low: ReadOnly::new(66),
|
|
|
+ guest_cid_high: ReadOnly::new(0),
|
|
|
+ };
|
|
|
+ let state = Arc::new(Mutex::new(State {
|
|
|
+ status: DeviceStatus::empty(),
|
|
|
+ driver_features: 0,
|
|
|
+ guest_page_size: 0,
|
|
|
+ interrupt_pending: false,
|
|
|
+ queues: vec![
|
|
|
+ QueueStatus::default(),
|
|
|
+ QueueStatus::default(),
|
|
|
+ QueueStatus::default(),
|
|
|
+ ],
|
|
|
+ }));
|
|
|
+ let transport = FakeTransport {
|
|
|
+ device_type: DeviceType::Socket,
|
|
|
+ max_queue_size: 32,
|
|
|
+ device_features: 0,
|
|
|
+ config_space: NonNull::from(&mut config_space),
|
|
|
+ state: state.clone(),
|
|
|
+ };
|
|
|
+ let mut socket = VsockConnectionManager::new(
|
|
|
+ VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
|
|
|
+ );
|
|
|
+
|
|
|
+ socket.listen(guest_port);
|
|
|
+
|
|
|
+ // Start a thread to simulate the device.
|
|
|
+ let handle = thread::spawn(move || {
|
|
|
+ // Send a connection request for a port the guest isn't listening on.
|
|
|
+ println!("Host sending connection request to wrong port");
|
|
|
+ state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
|
|
|
+ RX_QUEUE_IDX,
|
|
|
+ VirtioVsockHdr {
|
|
|
+ op: VirtioVsockOp::Request.into(),
|
|
|
+ src_cid: host_cid.into(),
|
|
|
+ dst_cid: guest_cid.into(),
|
|
|
+ src_port: host_port.into(),
|
|
|
+ dst_port: wrong_guest_port.into(),
|
|
|
+ len: 0.into(),
|
|
|
+ socket_type: SocketType::Stream.into(),
|
|
|
+ flags: 0.into(),
|
|
|
+ buf_alloc: 50.into(),
|
|
|
+ fwd_cnt: 0.into(),
|
|
|
+ }
|
|
|
+ .as_bytes(),
|
|
|
+ );
|
|
|
+
|
|
|
+ // Expect a rejection.
|
|
|
+ println!("Host waiting for rejection");
|
|
|
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
|
|
|
+ assert_eq!(
|
|
|
+ VirtioVsockHdr::read_from(
|
|
|
+ state
|
|
|
+ .lock()
|
|
|
+ .unwrap()
|
|
|
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
|
|
|
+ .as_slice()
|
|
|
+ )
|
|
|
+ .unwrap(),
|
|
|
+ VirtioVsockHdr {
|
|
|
+ op: VirtioVsockOp::Rst.into(),
|
|
|
+ src_cid: guest_cid.into(),
|
|
|
+ dst_cid: host_cid.into(),
|
|
|
+ src_port: wrong_guest_port.into(),
|
|
|
+ dst_port: host_port.into(),
|
|
|
+ len: 0.into(),
|
|
|
+ socket_type: SocketType::Stream.into(),
|
|
|
+ flags: 0.into(),
|
|
|
+ buf_alloc: 1024.into(),
|
|
|
+ fwd_cnt: 0.into(),
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ // Send a connection request for a port the guest is listening on.
|
|
|
+ println!("Host sending connection request to right port");
|
|
|
+ state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
|
|
|
+ RX_QUEUE_IDX,
|
|
|
+ VirtioVsockHdr {
|
|
|
+ op: VirtioVsockOp::Request.into(),
|
|
|
+ src_cid: host_cid.into(),
|
|
|
+ dst_cid: guest_cid.into(),
|
|
|
+ src_port: host_port.into(),
|
|
|
+ dst_port: guest_port.into(),
|
|
|
+ len: 0.into(),
|
|
|
+ socket_type: SocketType::Stream.into(),
|
|
|
+ flags: 0.into(),
|
|
|
+ buf_alloc: 50.into(),
|
|
|
+ fwd_cnt: 0.into(),
|
|
|
+ }
|
|
|
+ .as_bytes(),
|
|
|
+ );
|
|
|
+
|
|
|
+ // Expect a response.
|
|
|
+ println!("Host waiting for response");
|
|
|
+ State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
|
|
|
+ assert_eq!(
|
|
|
+ VirtioVsockHdr::read_from(
|
|
|
+ state
|
|
|
+ .lock()
|
|
|
+ .unwrap()
|
|
|
+ .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
|
|
|
+ .as_slice()
|
|
|
+ )
|
|
|
+ .unwrap(),
|
|
|
+ VirtioVsockHdr {
|
|
|
+ op: VirtioVsockOp::Response.into(),
|
|
|
+ src_cid: guest_cid.into(),
|
|
|
+ dst_cid: host_cid.into(),
|
|
|
+ src_port: guest_port.into(),
|
|
|
+ dst_port: host_port.into(),
|
|
|
+ len: 0.into(),
|
|
|
+ socket_type: SocketType::Stream.into(),
|
|
|
+ flags: 0.into(),
|
|
|
+ buf_alloc: 1024.into(),
|
|
|
+ fwd_cnt: 0.into(),
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ println!("Host finished");
|
|
|
+ });
|
|
|
+
|
|
|
+ // Expect an incoming connection.
|
|
|
+ println!("Guest expecting incoming connection.");
|
|
|
+ assert_eq!(
|
|
|
+ socket.wait_for_event().unwrap(),
|
|
|
+ VsockEvent {
|
|
|
+ source: host_address,
|
|
|
+ destination: VsockAddr {
|
|
|
+ cid: guest_cid,
|
|
|
+ port: guest_port,
|
|
|
+ },
|
|
|
+ event_type: VsockEventType::ConnectionRequest,
|
|
|
+ buffer_status: VsockBufferStatus {
|
|
|
+ buffer_allocation: 50,
|
|
|
+ forward_count: 0,
|
|
|
+ },
|
|
|
+ }
|
|
|
+ );
|
|
|
+
|
|
|
+ handle.join().unwrap();
|
|
|
+ }
|
|
|
}
|