Browse Source

Add VsockEvent struct.

Andrew Walbran 2 years ago
parent
commit
273a496b0f
3 changed files with 71 additions and 13 deletions
  1. 12 4
      examples/aarch64/src/main.rs
  2. 1 1
      src/device/socket/mod.rs
  3. 58 8
      src/device/socket/vsock.rs

+ 12 - 4
examples/aarch64/src/main.rs

@@ -23,7 +23,12 @@ use hal::HalImpl;
 use log::{debug, error, info, trace, warn, LevelFilter};
 use psci::system_off;
 use virtio_drivers::{
-    device::{blk::VirtIOBlk, console::VirtIOConsole, gpu::VirtIOGpu, socket::VirtIOSocket},
+    device::{
+        blk::VirtIOBlk,
+        console::VirtIOConsole,
+        gpu::VirtIOGpu,
+        socket::{VirtIOSocket, VsockEventType},
+    },
     transport::{
         mmio::{MmioTransport, VirtIOHeader},
         pci::{
@@ -195,12 +200,15 @@ fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
     let messages = ["0-Ack. Hello from guest.", "1-Ack. Received again."];
     for k in 0..EXCHANGE_NUM {
         let mut buffer = [0u8; 24];
-        let len = socket.recv(&mut buffer)?;
+        let socket_event = socket.poll_recv(&mut buffer)?;
+        let VsockEventType::Received {length, ..} = socket_event.event_type else {
+            panic!("Received unexpected socket event {:?}", socket_event);
+        };
         info!(
             "Received message: {:?}({:?}), len: {:?}",
             buffer,
-            core::str::from_utf8(&buffer[..len]),
-            len
+            core::str::from_utf8(&buffer[..length]),
+            length
         );
 
         let message = messages[k % messages.len()];

+ 1 - 1
src/device/socket/mod.rs

@@ -5,4 +5,4 @@ mod protocol;
 mod vsock;
 
 pub use error::SocketError;
-pub use vsock::VirtIOSocket;
+pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};

+ 58 - 8
src/device/socket/vsock.rs

@@ -53,6 +53,44 @@ impl ConnectionInfo {
     }
 }
 
+/// An event received from a VirtIO socket device.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub struct VsockEvent {
+    /// The source of the event, i.e. the peer who sent it.
+    pub source: VsockAddr,
+    /// The destination of the event, i.e. the CID and port on our side.
+    pub destination: VsockAddr,
+    /// The type of event.
+    pub event_type: VsockEventType,
+}
+
+/// The reason why a vsock connection was closed.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum DisconnectReason {
+    /// The peer has either closed the connection in response to our shutdown request, or forcibly
+    /// closed it of its own accord.
+    Reset,
+    /// The peer asked to shut down the connection.
+    Shutdown,
+}
+
+/// Details of the type of an event received from a VirtIO socket.
+#[derive(Clone, Debug, Eq, PartialEq)]
+pub enum VsockEventType {
+    /// The connection was successfully established.
+    Connected,
+    /// The connection was closed.
+    Disconnected {
+        /// The reason for the disconnection.
+        reason: DisconnectReason,
+    },
+    /// Data was received on the connection.
+    Received {
+        /// The length of the data in bytes.
+        length: usize,
+    },
+}
+
 /// Driver for a VirtIO socket device.
 pub struct VirtIOSocket<H: Hal, T: Transport> {
     transport: T,
@@ -156,7 +194,7 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         Ok(())
     }
 
-    /// Connects to the destination.
+    /// Sends a request to connect to the given destination.
     pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
         if self.connection_info.is_some() {
             return Err(SocketError::ConnectionExists.into());
@@ -242,9 +280,11 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         }
     }
 
-    /// Receives the buffer from the destination.
-    /// Returns the actual size of the message.
-    pub fn recv(&mut self, buffer: &mut [u8]) -> Result<usize> {
+    /// Polls the vsock device to receive data or other updates.
+    ///
+    /// A buffer must be provided to put the data in if there is some to
+    /// receive.
+    pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
         let connection_info = self.connection_info()?;
 
         // Tell the peer that we have space to receive some data.
@@ -256,13 +296,23 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         self.send_packet_to_tx_queue(&header, &[])?;
 
         // Wait to receive some data.
-        let mut len: u32 = 0;
+        let mut length: u32 = 0;
         self.poll_and_filter_packet_from_rx_queue(&[VirtioVsockOp::Rw], buffer, |header| {
-            len = header.len();
+            length = header.len();
             Ok(())
         })?;
-        self.connection_info.as_mut().unwrap().fwd_cnt += len;
-        Ok(len as usize)
+        self.connection_info.as_mut().unwrap().fwd_cnt += length;
+
+        Ok(VsockEvent {
+            source: connection_info.dst,
+            destination: VsockAddr {
+                cid: self.guest_cid,
+                port: connection_info.src_port,
+            },
+            event_type: VsockEventType::Received {
+                length: length as usize,
+            },
+        })
     }
 
     /// Shuts down the connection.