Browse Source

Allow incoming connections.

Andrew Walbran 1 year ago
parent
commit
79f08dd8e6
2 changed files with 237 additions and 40 deletions
  1. 228 38
      src/device/socket/multiconnectionmanager.rs
  2. 9 2
      src/device/socket/vsock.rs

+ 228 - 38
src/device/socket/multiconnectionmanager.rs

@@ -43,6 +43,7 @@ const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
 pub struct VsockConnectionManager<H: Hal, T: Transport> {
     driver: VirtIOSocket<H, T>,
     connections: Vec<Connection>,
+    listening_ports: Vec<u32>,
 }
 
 #[derive(Debug)]
@@ -68,6 +69,7 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
         Self {
             driver,
             connections: Vec::new(),
+            listening_ports: Vec::new(),
         }
     }
 
@@ -76,6 +78,18 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
         self.driver.guest_cid()
     }
 
+    /// Allows incoming connections on the given port number.
+    pub fn listen(&mut self, port: u32) {
+        if !self.listening_ports.contains(&port) {
+            self.listening_ports.push(port);
+        }
+    }
+
+    /// Stops allowing incoming connections on the given port number.
+    pub fn unlisten(&mut self, port: u32) {
+        self.listening_ports.retain(|p| *p != port);
+    }
+
     /// Sends a request to connect to the given destination.
     ///
     /// This returns as soon as the request is sent; you should wait until `poll` returns a
@@ -119,57 +133,77 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
                 .iter_mut()
                 .find(|connection| event.matches_connection(&connection.info, guest_cid));
 
-            let Some(connection) = connection else {
-                // Skip events which don't match any connection we know about.
+            // Skip events which don't match any connection we know about, unless they are a
+            // connection request.
+            let connection = if let Some(connection) = connection {
+                connection
+            } else if let VsockEventType::ConnectionRequest = event.event_type {
+                // If the requested connection already exists or the CID isn't ours, ignore it.
+                if connection.is_some() || event.destination.cid != guest_cid {
+                    return Ok(None);
+                }
+                // Add the new connection to our list, at least for now. It will be removed again
+                // below if we weren't listening on the port.
+                connections.push(Connection::new(event.source, event.destination.port));
+                connections.last_mut().unwrap()
+            } else {
                 return Ok(None);
             };
 
             // Update stored connection info.
             connection.info.update_for_event(&event);
 
-            match event.event_type {
-                VsockEventType::ConnectionRequest => {
-                    // TODO: Send Rst or handle incoming connections.
-                }
-                VsockEventType::Connected => {}
-                VsockEventType::Disconnected { .. } => {
-                    // TODO: Wait until client reads all data before removing connection.
-                    //self.connection_info = None;
+            if let VsockEventType::Received { length } = event.event_type {
+                // Copy to buffer
+                if !connection.buffer.write(body) {
+                    return Err(SocketError::OutputBufferTooShort(length).into());
                 }
-                VsockEventType::Received { length } => {
-                    // Copy to buffer
-                    if !connection.buffer.write(body) {
-                        return Err(SocketError::OutputBufferTooShort(length).into());
-                    }
-                }
-                VsockEventType::CreditRequest => {}
-                VsockEventType::CreditUpdate => {}
             }
 
             Ok(Some(event))
         })?;
 
-        // If the peer requested credit, send an update.
-        if let Some(VsockEvent {
-            source,
-            destination,
-            event_type: VsockEventType::CreditRequest,
-            ..
-        }) = result
-        {
-            let connection = self
-                .connections
-                .iter()
-                .find(|connection| {
-                    connection.info.dst == source && connection.info.src_port == destination.port
-                })
-                .unwrap();
-            self.driver.credit_update(&connection.info)?;
-            // No need to pass the request on to the client, we've already handled it.
-            Ok(None)
-        } else {
-            Ok(result)
+        let Some(event) = result else {
+            return Ok(None);
+        };
+
+        // The connection must exist because we found it above in the callback.
+        let (connection_index, connection) = connections
+            .iter_mut()
+            .enumerate()
+            .find(|(_, connection)| event.matches_connection(&connection.info, guest_cid))
+            .unwrap();
+
+        match event.event_type {
+            VsockEventType::ConnectionRequest => {
+                if self.listening_ports.contains(&event.destination.port) {
+                    self.driver.accept(&connection.info)?;
+                } else {
+                    // Reject the connection request and remove it from our list.
+                    self.driver.force_close(&connection.info)?;
+                    self.connections.swap_remove(connection_index);
+
+                    // No need to pass the request on to the client, as we've already rejected it.
+                    return Ok(None);
+                }
+            }
+            VsockEventType::Connected => {}
+            VsockEventType::Disconnected { .. } => {
+                // TODO: Wait until client reads all data before removing connection.
+            }
+            VsockEventType::Received { .. } => {
+                // Already copied the buffer in the callback above.
+            }
+            VsockEventType::CreditRequest => {
+                // If the peer requested credit, send an update.
+                self.driver.credit_update(&connection.info)?;
+                // No need to pass the request on to the client, we've already handled it.
+                return Ok(None);
+            }
+            VsockEventType::CreditUpdate => {}
         }
+
+        Ok(Some(event))
     }
 
     /// Reads data received from the given connection.
@@ -543,4 +577,160 @@ mod tests {
 
         handle.join().unwrap();
     }
+
+    #[test]
+    fn incoming_connection() {
+        let host_cid = 2;
+        let guest_cid = 66;
+        let host_port = 1234;
+        let guest_port = 4321;
+        let wrong_guest_port = 4444;
+        let host_address = VsockAddr {
+            cid: host_cid,
+            port: host_port,
+        };
+
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut socket = VsockConnectionManager::new(
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
+        );
+
+        socket.listen(guest_port);
+
+        // Start a thread to simulate the device.
+        let handle = thread::spawn(move || {
+            // Send a connection request for a port the guest isn't listening on.
+            println!("Host sending connection request to wrong port");
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: wrong_guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect a rejection.
+            println!("Host waiting for rejection");
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Rst.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: wrong_guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 1024.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Send a connection request for a port the guest is listening on.
+            println!("Host sending connection request to right port");
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect a response.
+            println!("Host waiting for response");
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Response.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 1024.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            println!("Host finished");
+        });
+
+        // Expect an incoming connection.
+        println!("Guest expecting incoming connection.");
+        assert_eq!(
+            socket.wait_for_event().unwrap(),
+            VsockEvent {
+                source: host_address,
+                destination: VsockAddr {
+                    cid: guest_cid,
+                    port: guest_port,
+                },
+                event_type: VsockEventType::ConnectionRequest,
+                buffer_status: VsockBufferStatus {
+                    buffer_allocation: 50,
+                    forward_count: 0,
+                },
+            }
+        );
+
+        handle.join().unwrap();
+    }
 }

+ 9 - 2
src/device/socket/vsock.rs

@@ -305,9 +305,16 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         };
         // Sends a header only packet to the TX queue to connect the device to the listening socket
         // at the given destination.
-        self.send_packet_to_tx_queue(&header, &[])?;
+        self.send_packet_to_tx_queue(&header, &[])
+    }
 
-        Ok(())
+    /// Accepts the given connection from a peer.
+    pub fn accept(&mut self, connection_info: &ConnectionInfo) -> Result {
+        let header = VirtioVsockHdr {
+            op: VirtioVsockOp::Response.into(),
+            ..connection_info.new_header(self.guest_cid)
+        };
+        self.send_packet_to_tx_queue(&header, &[])
     }
 
     /// Requests the peer to send us a credit update for the given connection.