Эх сурвалжийг харах

Add test for vsock sending.

Andrew Walbran 1 жил өмнө
parent
commit
2b0bdca803

+ 1 - 1
src/device/socket/protocol.rs

@@ -40,7 +40,7 @@ pub struct VirtioVsockConfig {
 
 /// The message header for data packets sent on the tx/rx queues
 #[repr(packed)]
-#[derive(AsBytes, Clone, Copy, Debug, FromBytes)]
+#[derive(AsBytes, Clone, Copy, Debug, Eq, FromBytes, PartialEq)]
 pub struct VirtioVsockHdr {
     pub src_cid: U64<LittleEndian>,
     pub dst_cid: U64<LittleEndian>,

+ 178 - 3
src/device/socket/vsock.rs

@@ -559,17 +559,18 @@ fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::volatile::ReadOnly;
     use crate::{
+        device::socket::protocol::SocketType,
         hal::fake::FakeHal,
         transport::{
             fake::{FakeTransport, QueueStatus, State},
             DeviceStatus, DeviceType,
         },
+        volatile::ReadOnly,
     };
     use alloc::{sync::Arc, vec};
-    use core::ptr::NonNull;
-    use std::sync::Mutex;
+    use core::{ptr::NonNull, time::Duration};
+    use std::{sync::Mutex, thread};
 
     #[test]
     fn config() {
@@ -595,4 +596,178 @@ mod tests {
             VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
         assert_eq!(socket.guest_cid, 0x00_0000_0042);
     }
+
+    #[test]
+    fn send() {
+        let host_cid = 2;
+        let guest_cid = 66;
+        let host_port = 1234;
+        let guest_port = 4321;
+
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![QueueStatus::default(); 3],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut socket =
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
+
+        // Start a thread to simulate the device.
+        let handle = thread::spawn(move || {
+            // Wait for connection request.
+            while !state.lock().unwrap().queues[usize::from(TX_QUEUE_IDX)].notified {
+                thread::sleep(Duration::from_millis(10));
+            }
+            state.lock().unwrap().queues[usize::from(TX_QUEUE_IDX)].notified = false;
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Accept connection and give the peer enough credit to send the message.
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Response.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect a credit update.
+            while !state.lock().unwrap().queues[usize::from(TX_QUEUE_IDX)].notified {
+                thread::sleep(Duration::from_millis(10));
+            }
+            state.lock().unwrap().queues[usize::from(TX_QUEUE_IDX)].notified = false;
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::CreditUpdate.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Expect the guest to send some data.
+            while !state.lock().unwrap().queues[usize::from(TX_QUEUE_IDX)].notified {
+                thread::sleep(Duration::from_millis(10));
+            }
+            state.lock().unwrap().queues[usize::from(TX_QUEUE_IDX)].notified = false;
+            let request = state
+                .lock()
+                .unwrap()
+                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
+            assert_eq!(
+                request.len(),
+                size_of::<VirtioVsockHdr>() + "Hello from guest".len()
+            );
+            assert_eq!(
+                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Rw.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 16.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+            assert_eq!(
+                &request[size_of::<VirtioVsockHdr>()..],
+                "Hello from guest".as_bytes()
+            );
+
+            // Expect a shutdown.
+            while !state.lock().unwrap().queues[usize::from(TX_QUEUE_IDX)].notified {
+                thread::sleep(Duration::from_millis(10));
+            }
+            state.lock().unwrap().queues[usize::from(TX_QUEUE_IDX)].notified = false;
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Shutdown.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+        });
+
+        socket.connect(host_cid, guest_port, host_port).unwrap();
+        socket.wait_for_connect().unwrap();
+        socket.send("Hello from guest".as_bytes()).unwrap();
+        socket.shutdown().unwrap();
+
+        handle.join().unwrap();
+    }
 }