浏览代码

Add VsockConnectionManager that can handle multiple connections.

Andrew Walbran 1 年之前
父节点
当前提交
159961172c

+ 2 - 0
src/device/socket/mod.rs

@@ -1,6 +1,8 @@
 //! This module implements the virtio vsock device.
 
 mod error;
+#[cfg(feature = "alloc")]
+mod multiconnectionmanager;
 mod protocol;
 #[cfg(feature = "alloc")]
 mod singleconnectionmanager;

+ 254 - 0
src/device/socket/multiconnectionmanager.rs

@@ -0,0 +1,254 @@
+use super::{
+    protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
+    VsockEventType,
+};
+use crate::{transport::Transport, Hal, Result};
+use alloc::{boxed::Box, vec::Vec};
+use core::cmp::min;
+use core::convert::TryInto;
+use core::hint::spin_loop;
+use log::debug;
+use zerocopy::FromBytes;
+
+const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
+
+/// A higher level interface for vsock devices.
+///
+/// This keeps track of a single vsock connection.
+pub struct VsockConnectionManager<H: Hal, T: Transport> {
+    driver: VirtIOSocket<H, T>,
+    connections: Vec<Connection>,
+}
+
+#[derive(Debug)]
+struct Connection {
+    info: ConnectionInfo,
+    buffer: RingBuffer,
+}
+
+impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
+    /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
+    pub fn new(driver: VirtIOSocket<H, T>) -> Self {
+        Self {
+            driver,
+            connections: Vec::new(),
+        }
+    }
+
+    /// Returns the CID which has been assigned to this guest.
+    pub fn guest_cid(&self) -> u64 {
+        self.driver.guest_cid()
+    }
+
+    /// Sends a request to connect to the given destination.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll` returns a
+    /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
+    /// before sending data.
+    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+        if self.connections.iter().any(|connection| {
+            connection.info.dst == destination && connection.info.src_port == src_port
+        }) {
+            return Err(SocketError::ConnectionExists.into());
+        }
+
+        let mut new_connection_info = ConnectionInfo::new(destination, src_port);
+        new_connection_info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
+
+        self.driver.connect(destination, src_port)?;
+        debug!("Connection requested: {:?}", new_connection_info);
+        self.connections.push(Connection {
+            info: new_connection_info,
+            buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
+        });
+        Ok(())
+    }
+
+    /// Sends the buffer to the destination.
+    pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
+        let connection = self
+            .connections
+            .iter_mut()
+            .find(|connection| {
+                connection.info.dst == destination && connection.info.src_port == src_port
+            })
+            .ok_or(SocketError::NotConnected)?;
+
+        self.driver.send(buffer, &mut connection.info)
+    }
+
+    /// Polls the vsock device to receive data or other updates.
+    pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
+        let guest_cid = self.driver.guest_cid();
+        let connections = &mut self.connections;
+
+        self.driver.poll_recv(|event, body| {
+            let connection = connections
+                .iter_mut()
+                .find(|connection| event.matches_connection(&connection.info, guest_cid));
+
+            let Some(connection) = connection else {
+                // Skip events which don't match any connection we know about.
+                return Ok(None);
+            };
+
+            // Update stored connection info.
+            connection.info.update_for_event(&event);
+
+            match event.event_type {
+                VsockEventType::Connected => {}
+                VsockEventType::Disconnected { .. } => {
+                    // TODO: Wait until client reads all data before removing connection.
+                    //self.connection_info = None;
+                }
+                VsockEventType::Received { length } => {
+                    // Copy to buffer
+                    if !connection.buffer.write(body) {
+                        return Err(SocketError::OutputBufferTooShort(length).into());
+                    }
+                }
+                VsockEventType::CreditRequest => {
+                    // TODO: Send a credit update.
+                }
+                VsockEventType::CreditUpdate => {}
+            }
+
+            Ok(Some(event))
+        })
+    }
+
+    /// Reads data received from the given connection.
+    pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
+        let connection = self
+            .connections
+            .iter_mut()
+            .find(|connection| connection.info.dst == peer && connection.info.src_port == src_port)
+            .ok_or(SocketError::NotConnected)?;
+
+        // Copy from ring buffer
+        let bytes_read = connection.buffer.read(buffer);
+
+        connection.info.done_forwarding(bytes_read);
+
+        Ok(bytes_read)
+    }
+
+    /// Blocks until we get some event from the vsock device.
+    pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
+        loop {
+            if let Some(event) = self.poll()? {
+                return Ok(event);
+            } else {
+                spin_loop();
+            }
+        }
+    }
+
+    /// Requests to shut down the connection cleanly.
+    ///
+    /// This returns as soon as the request is sent; you should wait until `poll` returns a
+    /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
+    /// shutdown.
+    pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+        let connection = self
+            .connections
+            .iter()
+            .find(|connection| {
+                connection.info.dst == destination && connection.info.src_port == src_port
+            })
+            .ok_or(SocketError::NotConnected)?;
+
+        self.driver.shutdown(&connection.info)
+    }
+
+    /// Forcibly closes the connection without waiting for the peer.
+    pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
+        let (index, connection) = self
+            .connections
+            .iter()
+            .enumerate()
+            .find(|(_, connection)| {
+                connection.info.dst == destination && connection.info.src_port == src_port
+            })
+            .ok_or(SocketError::NotConnected)?;
+
+        self.driver.force_close(&connection.info)?;
+
+        self.connections.swap_remove(index);
+        Ok(())
+    }
+}
+
+#[derive(Debug)]
+struct RingBuffer {
+    buffer: Box<[u8]>,
+    /// The number of bytes currently in the buffer.
+    used: usize,
+    /// The index of the first used byte in the buffer.
+    start: usize,
+}
+
+impl RingBuffer {
+    pub fn new(capacity: usize) -> Self {
+        Self {
+            buffer: FromBytes::new_box_slice_zeroed(capacity),
+            used: 0,
+            start: 0,
+        }
+    }
+
+    /// Returns the number of bytes currently used in the buffer.
+    pub fn used(&self) -> usize {
+        self.used
+    }
+
+    /// Returns the number of bytes currently free in the buffer.
+    pub fn available(&self) -> usize {
+        self.buffer.len() - self.used
+    }
+
+    /// Adds the given bytes to the buffer if there is enough capacity for them all.
+    ///
+    /// Returns true if they were added, or false if they were not.
+    pub fn write(&mut self, bytes: &[u8]) -> bool {
+        if bytes.len() > self.available() {
+            return false;
+        }
+
+        let end = (self.start + self.used) % self.buffer.len();
+        let write_before_wraparound = min(bytes.len(), self.buffer.len() - end);
+        let write_after_wraparound = bytes
+            .len()
+            .checked_sub(write_before_wraparound)
+            .unwrap_or_default();
+        self.buffer[end..end + write_before_wraparound]
+            .copy_from_slice(&bytes[0..write_before_wraparound]);
+        self.buffer[0..write_after_wraparound].copy_from_slice(&bytes[write_before_wraparound..]);
+        self.used += bytes.len();
+
+        true
+    }
+
+    /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
+    /// buffer.
+    pub fn read(&mut self, out: &mut [u8]) -> usize {
+        let bytes_read = min(self.used, out.len());
+
+        // The number of bytes to copy out between `start` and the end of the buffer.
+        let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
+        // The number of bytes to copy out from the beginning of the buffer after wrapping around.
+        let read_after_wraparound = bytes_read
+            .checked_sub(read_before_wraparound)
+            .unwrap_or_default();
+
+        out[0..read_before_wraparound]
+            .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
+        out[read_before_wraparound..bytes_read]
+            .copy_from_slice(&self.buffer[0..read_after_wraparound]);
+
+        self.used -= bytes_read;
+        self.start = (self.start + bytes_read) % self.buffer.len();
+
+        bytes_read
+    }
+}

+ 23 - 19
src/device/socket/singleconnectionmanager.rs

@@ -48,12 +48,12 @@ impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
 
     /// Sends the buffer to the destination.
     pub fn send(&mut self, buffer: &[u8]) -> Result {
-        self.driver.send(
-            buffer,
-            self.connection_info
-                .as_mut()
-                .ok_or(SocketError::NotConnected)?,
-        )
+        let connection_info = self
+            .connection_info
+            .as_mut()
+            .ok_or(SocketError::NotConnected)?;
+        connection_info.buf_alloc = 0;
+        self.driver.send(buffer, connection_info)
     }
 
     /// Polls the vsock device to receive data or other updates.
@@ -61,13 +61,13 @@ impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
     /// A buffer must be provided to put the data in if there is some to
     /// receive.
     pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
-        let Some(connection_info) = &self.connection_info else {
+        let Some(connection_info) = &mut self.connection_info else {
             return Err(SocketError::NotConnected.into());
         };
 
         // Tell the peer that we have space to receive some data.
-        self.driver
-            .credit_update(connection_info, buffer.len() as u32)?;
+        connection_info.buf_alloc = buffer.len() as u32;
+        self.driver.credit_update(connection_info)?;
 
         self.poll_rx_queue(buffer)
     }
@@ -134,20 +134,24 @@ impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
     /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
     /// shutdown.
     pub fn shutdown(&mut self) -> Result {
-        self.driver.shutdown(
-            self.connection_info
-                .as_ref()
-                .ok_or(SocketError::NotConnected)?,
-        )
+        let connection_info = self
+            .connection_info
+            .as_mut()
+            .ok_or(SocketError::NotConnected)?;
+        connection_info.buf_alloc = 0;
+
+        self.driver.shutdown(connection_info)
     }
 
     /// Forcibly closes the connection without waiting for the peer.
     pub fn force_close(&mut self) -> Result {
-        self.driver.force_close(
-            self.connection_info
-                .as_ref()
-                .ok_or(SocketError::NotConnected)?,
-        )?;
+        let connection_info = self
+            .connection_info
+            .as_mut()
+            .ok_or(SocketError::NotConnected)?;
+        connection_info.buf_alloc = 0;
+
+        self.driver.force_close(connection_info)?;
         self.connection_info = None;
         Ok(())
     }

+ 5 - 3
src/device/socket/vsock.rs

@@ -35,6 +35,9 @@ pub struct ConnectionInfo {
     peer_fwd_cnt: u32,
     /// The number of bytes of packet bodies which we have sent to the peer.
     tx_cnt: u32,
+    /// The number of bytes of buffer space we have allocated to receive packet bodies from the
+    /// peer.
+    pub buf_alloc: u32,
     /// The number of bytes of packet bodies which we have received from the peer and handled.
     fwd_cnt: u32,
     /// Whether we have recently requested credit from the peer.
@@ -82,6 +85,7 @@ impl ConnectionInfo {
             dst_cid: self.dst.cid.into(),
             src_port: self.src_port.into(),
             dst_port: self.dst.port.into(),
+            buf_alloc: self.buf_alloc.into(),
             fwd_cnt: self.fwd_cnt.into(),
             ..Default::default()
         }
@@ -341,7 +345,6 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::Rw.into(),
             len: len.into(),
-            buf_alloc: 0.into(),
             ..connection_info.new_header(self.guest_cid)
         };
         connection_info.tx_cnt += len;
@@ -367,10 +370,9 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     }
 
     /// Tells the peer how much buffer space we have to receive data.
-    pub fn credit_update(&mut self, connection_info: &ConnectionInfo, buffer_size: u32) -> Result {
+    pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
         let header = VirtioVsockHdr {
             op: VirtioVsockOp::CreditUpdate.into(),
-            buf_alloc: buffer_size.into(),
             ..connection_info.new_header(self.guest_cid)
         };
         self.send_packet_to_tx_queue(&header, &[])