Browse Source

Add helper functions to find connections.

Andrew Walbran 1 năm trước cách đây
mục cha
commit
277c474c7b
1 tập tin đã thay đổi với 38 bổ sung39 xóa
  1. 38 39
      src/device/socket/multiconnectionmanager.rs

+ 38 - 39
src/device/socket/multiconnectionmanager.rs

@@ -116,13 +116,7 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
 
     /// Sends the buffer to the destination.
     pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
-        let connection = self
-            .connections
-            .iter_mut()
-            .find(|connection| {
-                connection.info.dst == destination && connection.info.src_port == src_port
-            })
-            .ok_or(SocketError::NotConnected)?;
+        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
 
         self.driver.send(buffer, &mut connection.info)
     }
@@ -133,13 +127,11 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
         let connections = &mut self.connections;
 
         let result = self.driver.poll(|event, body| {
-            let connection = connections
-                .iter_mut()
-                .find(|connection| event.matches_connection(&connection.info, guest_cid));
+            let connection = get_connection_for_event(connections, &event, guest_cid);
 
             // Skip events which don't match any connection we know about, unless they are a
             // connection request.
-            let connection = if let Some(connection) = connection {
+            let connection = if let Some((_, connection)) = connection {
                 connection
             } else if let VsockEventType::ConnectionRequest = event.event_type {
                 // If the requested connection already exists or the CID isn't ours, ignore it.
@@ -172,11 +164,8 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
         };
 
         // The connection must exist because we found it above in the callback.
-        let (connection_index, connection) = connections
-            .iter_mut()
-            .enumerate()
-            .find(|(_, connection)| event.matches_connection(&connection.info, guest_cid))
-            .unwrap();
+        let (connection_index, connection) =
+            get_connection_for_event(connections, &event, guest_cid).unwrap();
 
         match event.event_type {
             VsockEventType::ConnectionRequest => {
@@ -220,14 +209,7 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
 
     /// Reads data received from the given connection.
     pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
-        let (connection_index, connection) = self
-            .connections
-            .iter_mut()
-            .enumerate()
-            .find(|(_, connection)| {
-                connection.info.dst == peer && connection.info.src_port == src_port
-            })
-            .ok_or(SocketError::NotConnected)?;
+        let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
 
         // Copy from ring buffer
         let bytes_read = connection.buffer.read(buffer);
@@ -261,27 +243,14 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
     /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
     /// shutdown.
     pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
-        let connection = self
-            .connections
-            .iter()
-            .find(|connection| {
-                connection.info.dst == destination && connection.info.src_port == src_port
-            })
-            .ok_or(SocketError::NotConnected)?;
+        let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
 
         self.driver.shutdown(&connection.info)
     }
 
     /// Forcibly closes the connection without waiting for the peer.
     pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
-        let (index, connection) = self
-            .connections
-            .iter()
-            .enumerate()
-            .find(|(_, connection)| {
-                connection.info.dst == destination && connection.info.src_port == src_port
-            })
-            .ok_or(SocketError::NotConnected)?;
+        let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
 
         self.driver.force_close(&connection.info)?;
 
@@ -290,6 +259,36 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
     }
 }
 
+/// Returns the connection from the given list matching the given peer address and local port, and
+/// its index.
+///
+/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
+fn get_connection(
+    connections: &mut [Connection],
+    peer: VsockAddr,
+    local_port: u32,
+) -> core::result::Result<(usize, &mut Connection), SocketError> {
+    connections
+        .iter_mut()
+        .enumerate()
+        .find(|(_, connection)| {
+            connection.info.dst == peer && connection.info.src_port == local_port
+        })
+        .ok_or(SocketError::NotConnected)
+}
+
+/// Returns the connection from the given list matching the event, if any, and its index.
+fn get_connection_for_event<'a>(
+    connections: &'a mut [Connection],
+    event: &VsockEvent,
+    local_cid: u64,
+) -> Option<(usize, &'a mut Connection)> {
+    connections
+        .iter_mut()
+        .enumerate()
+        .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
+}
+
 #[derive(Debug)]
 struct RingBuffer {
     buffer: Box<[u8]>,