Sfoglia il codice sorgente

Take VsockAddr to connect method rather than separate CID and port.

Andrew Walbran 1 anno fa
parent
commit
f9191727c4
3 ha cambiato i file con 16 aggiunte e 8 eliminazioni
  1. 8 2
      examples/aarch64/src/main.rs
  2. 1 0
      src/device/socket/mod.rs
  3. 7 6
      src/device/socket/vsock.rs

+ 8 - 2
examples/aarch64/src/main.rs

@@ -30,7 +30,7 @@ use virtio_drivers::{
         blk::VirtIOBlk,
         console::VirtIOConsole,
         gpu::VirtIOGpu,
-        socket::{VirtIOSocket, VsockEventType},
+        socket::{VirtIOSocket, VsockAddr, VsockEventType},
     },
     transport::{
         mmio::{MmioTransport, VirtIOHeader},
@@ -209,7 +209,13 @@ fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
     let host_cid = 2;
     let port = 1221;
     info!("Connecting to host on port {port}...");
-    socket.connect(host_cid, port, port)?;
+    socket.connect(
+        VsockAddr {
+            cid: host_cid,
+            port,
+        },
+        port,
+    )?;
     socket.wait_for_connect()?;
     info!("Connected to the host");
 

+ 1 - 0
src/device/socket/mod.rs

@@ -6,5 +6,6 @@ mod protocol;
 mod vsock;
 
 pub use error::SocketError;
+pub use protocol::VsockAddr;
 #[cfg(feature = "alloc")]
 pub use vsock::{DisconnectReason, VirtIOSocket, VsockEvent, VsockEventType};

+ 7 - 6
src/device/socket/vsock.rs

@@ -207,15 +207,12 @@ impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
     /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
     /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
     /// before sending data.
-    pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
+    pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
         if self.connection_info.is_some() {
             return Err(SocketError::ConnectionExists.into());
         }
         let new_connection_info = ConnectionInfo {
-            dst: VsockAddr {
-                cid: dst_cid,
-                port: dst_port,
-            },
+            dst: destination,
             src_port,
             ..Default::default()
         };
@@ -607,6 +604,10 @@ mod tests {
         let guest_cid = 66;
         let host_port = 1234;
         let guest_port = 4321;
+        let host_address = VsockAddr {
+            cid: host_cid,
+            port: host_port,
+        };
         let hello_from_guest = "Hello from guest";
         let hello_from_host = "Hello from host";
 
@@ -807,7 +808,7 @@ mod tests {
             );
         });
 
-        socket.connect(host_cid, guest_port, host_port).unwrap();
+        socket.connect(host_address, guest_port).unwrap();
         socket.wait_for_connect().unwrap();
         socket.send(hello_from_guest.as_bytes()).unwrap();
         let mut buffer = [0u8; 64];