|
@@ -116,13 +116,7 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
|
|
|
|
/// Sends the buffer to the destination.
|
|
/// Sends the buffer to the destination.
|
|
pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
|
|
pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
|
|
- let connection = self
|
|
|
|
- .connections
|
|
|
|
- .iter_mut()
|
|
|
|
- .find(|connection| {
|
|
|
|
- connection.info.dst == destination && connection.info.src_port == src_port
|
|
|
|
- })
|
|
|
|
- .ok_or(SocketError::NotConnected)?;
|
|
|
|
|
|
+ let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
|
|
|
|
|
self.driver.send(buffer, &mut connection.info)
|
|
self.driver.send(buffer, &mut connection.info)
|
|
}
|
|
}
|
|
@@ -133,13 +127,11 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
let connections = &mut self.connections;
|
|
let connections = &mut self.connections;
|
|
|
|
|
|
let result = self.driver.poll(|event, body| {
|
|
let result = self.driver.poll(|event, body| {
|
|
- let connection = connections
|
|
|
|
- .iter_mut()
|
|
|
|
- .find(|connection| event.matches_connection(&connection.info, guest_cid));
|
|
|
|
|
|
+ let connection = get_connection_for_event(connections, &event, guest_cid);
|
|
|
|
|
|
// Skip events which don't match any connection we know about, unless they are a
|
|
// Skip events which don't match any connection we know about, unless they are a
|
|
// connection request.
|
|
// connection request.
|
|
- let connection = if let Some(connection) = connection {
|
|
|
|
|
|
+ let connection = if let Some((_, connection)) = connection {
|
|
connection
|
|
connection
|
|
} else if let VsockEventType::ConnectionRequest = event.event_type {
|
|
} else if let VsockEventType::ConnectionRequest = event.event_type {
|
|
// If the requested connection already exists or the CID isn't ours, ignore it.
|
|
// If the requested connection already exists or the CID isn't ours, ignore it.
|
|
@@ -172,11 +164,8 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
};
|
|
};
|
|
|
|
|
|
// The connection must exist because we found it above in the callback.
|
|
// The connection must exist because we found it above in the callback.
|
|
- let (connection_index, connection) = connections
|
|
|
|
- .iter_mut()
|
|
|
|
- .enumerate()
|
|
|
|
- .find(|(_, connection)| event.matches_connection(&connection.info, guest_cid))
|
|
|
|
- .unwrap();
|
|
|
|
|
|
+ let (connection_index, connection) =
|
|
|
|
+ get_connection_for_event(connections, &event, guest_cid).unwrap();
|
|
|
|
|
|
match event.event_type {
|
|
match event.event_type {
|
|
VsockEventType::ConnectionRequest => {
|
|
VsockEventType::ConnectionRequest => {
|
|
@@ -220,14 +209,7 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
|
|
|
|
/// Reads data received from the given connection.
|
|
/// Reads data received from the given connection.
|
|
pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
|
|
pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
|
|
- let (connection_index, connection) = self
|
|
|
|
- .connections
|
|
|
|
- .iter_mut()
|
|
|
|
- .enumerate()
|
|
|
|
- .find(|(_, connection)| {
|
|
|
|
- connection.info.dst == peer && connection.info.src_port == src_port
|
|
|
|
- })
|
|
|
|
- .ok_or(SocketError::NotConnected)?;
|
|
|
|
|
|
+ let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
|
|
|
|
|
|
// Copy from ring buffer
|
|
// Copy from ring buffer
|
|
let bytes_read = connection.buffer.read(buffer);
|
|
let bytes_read = connection.buffer.read(buffer);
|
|
@@ -261,27 +243,14 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
|
|
/// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
|
|
/// shutdown.
|
|
/// shutdown.
|
|
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
|
|
pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
|
|
- let connection = self
|
|
|
|
- .connections
|
|
|
|
- .iter()
|
|
|
|
- .find(|connection| {
|
|
|
|
- connection.info.dst == destination && connection.info.src_port == src_port
|
|
|
|
- })
|
|
|
|
- .ok_or(SocketError::NotConnected)?;
|
|
|
|
|
|
+ let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
|
|
|
|
|
self.driver.shutdown(&connection.info)
|
|
self.driver.shutdown(&connection.info)
|
|
}
|
|
}
|
|
|
|
|
|
/// Forcibly closes the connection without waiting for the peer.
|
|
/// Forcibly closes the connection without waiting for the peer.
|
|
pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
|
|
pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
|
|
- let (index, connection) = self
|
|
|
|
- .connections
|
|
|
|
- .iter()
|
|
|
|
- .enumerate()
|
|
|
|
- .find(|(_, connection)| {
|
|
|
|
- connection.info.dst == destination && connection.info.src_port == src_port
|
|
|
|
- })
|
|
|
|
- .ok_or(SocketError::NotConnected)?;
|
|
|
|
|
|
+ let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
|
|
|
|
|
|
self.driver.force_close(&connection.info)?;
|
|
self.driver.force_close(&connection.info)?;
|
|
|
|
|
|
@@ -290,6 +259,36 @@ impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
|
|
|
|
+/// Returns the connection from the given list matching the given peer address and local port, and
|
|
|
|
+/// its index.
|
|
|
|
+///
|
|
|
|
+/// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
|
|
|
|
+fn get_connection(
|
|
|
|
+ connections: &mut [Connection],
|
|
|
|
+ peer: VsockAddr,
|
|
|
|
+ local_port: u32,
|
|
|
|
+) -> core::result::Result<(usize, &mut Connection), SocketError> {
|
|
|
|
+ connections
|
|
|
|
+ .iter_mut()
|
|
|
|
+ .enumerate()
|
|
|
|
+ .find(|(_, connection)| {
|
|
|
|
+ connection.info.dst == peer && connection.info.src_port == local_port
|
|
|
|
+ })
|
|
|
|
+ .ok_or(SocketError::NotConnected)
|
|
|
|
+}
|
|
|
|
+
|
|
|
|
+/// Returns the connection from the given list matching the event, if any, and its index.
|
|
|
|
+fn get_connection_for_event<'a>(
|
|
|
|
+ connections: &'a mut [Connection],
|
|
|
|
+ event: &VsockEvent,
|
|
|
|
+ local_cid: u64,
|
|
|
|
+) -> Option<(usize, &'a mut Connection)> {
|
|
|
|
+ connections
|
|
|
|
+ .iter_mut()
|
|
|
|
+ .enumerate()
|
|
|
|
+ .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
|
|
|
|
+}
|
|
|
|
+
|
|
#[derive(Debug)]
|
|
#[derive(Debug)]
|
|
struct RingBuffer {
|
|
struct RingBuffer {
|
|
buffer: Box<[u8]>,
|
|
buffer: Box<[u8]>,
|