Browse Source

Merge pull request #87 from rcore-os/vsocktest

Add test for vsock driver
Andrew Walbran 1 year ago
parent
commit
1d79ac893e
5 changed files with 277 additions and 25 deletions
  1. 6 10
      src/device/blk.rs
  2. 4 6
      src/device/console.rs
  3. 1 1
      src/device/socket/protocol.rs
  4. 244 3
      src/device/socket/vsock.rs
  5. 22 5
      src/transport/fake.rs

+ 6 - 10
src/device/blk.rs

@@ -478,7 +478,7 @@ mod tests {
     };
     use alloc::{sync::Arc, vec};
     use core::{mem::size_of, ptr::NonNull};
-    use std::{sync::Mutex, thread, time::Duration};
+    use std::{sync::Mutex, thread};
 
     #[test]
     fn config() {
@@ -501,7 +501,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -537,7 +537,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -551,9 +551,7 @@ mod tests {
         // Start a thread to simulate the device waiting for a read request.
         let handle = thread::spawn(move || {
             println!("Device waiting for a request.");
-            while !state.lock().unwrap().queues[usize::from(QUEUE)].notified {
-                thread::sleep(Duration::from_millis(10));
-            }
+            State::wait_until_queue_notified(&state, QUEUE);
             println!("Transmit queue was notified.");
 
             state
@@ -612,7 +610,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -626,9 +624,7 @@ mod tests {
         // Start a thread to simulate the device waiting for a write request.
         let handle = thread::spawn(move || {
             println!("Device waiting for a request.");
-            while !state.lock().unwrap().queues[usize::from(QUEUE)].notified {
-                thread::sleep(Duration::from_millis(10));
-            }
+            State::wait_until_queue_notified(&state, QUEUE);
             println!("Transmit queue was notified.");
 
             state

+ 4 - 6
src/device/console.rs

@@ -246,7 +246,7 @@ mod tests {
     };
     use alloc::{sync::Arc, vec};
     use core::ptr::NonNull;
-    use std::{sync::Mutex, thread, time::Duration};
+    use std::{sync::Mutex, thread};
 
     #[test]
     fn receive() {
@@ -261,7 +261,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -309,7 +309,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -323,9 +323,7 @@ mod tests {
         // Start a thread to simulate the device waiting for characters.
         let handle = thread::spawn(move || {
             println!("Device waiting for a character.");
-            while !state.lock().unwrap().queues[usize::from(QUEUE_TRANSMITQ_PORT_0)].notified {
-                thread::sleep(Duration::from_millis(10));
-            }
+            State::wait_until_queue_notified(&state, QUEUE_TRANSMITQ_PORT_0);
             println!("Transmit queue was notified.");
 
             let data = state

+ 1 - 1
src/device/socket/protocol.rs

@@ -40,7 +40,7 @@ pub struct VirtioVsockConfig {
 
 /// The message header for data packets sent on the tx/rx queues
 #[repr(packed)]
-#[derive(AsBytes, Clone, Copy, Debug, FromBytes)]
+#[derive(AsBytes, Clone, Copy, Debug, Eq, FromBytes, PartialEq)]
 pub struct VirtioVsockHdr {
     pub src_cid: U64<LittleEndian>,
     pub dst_cid: U64<LittleEndian>,

+ 244 - 3
src/device/socket/vsock.rs

@@ -559,17 +559,18 @@ fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr
 #[cfg(test)]
 mod tests {
     use super::*;
-    use crate::volatile::ReadOnly;
     use crate::{
+        device::socket::protocol::SocketType,
         hal::fake::FakeHal,
         transport::{
             fake::{FakeTransport, QueueStatus, State},
             DeviceStatus, DeviceType,
         },
+        volatile::ReadOnly,
     };
     use alloc::{sync::Arc, vec};
     use core::ptr::NonNull;
-    use std::sync::Mutex;
+    use std::{sync::Mutex, thread};
 
     #[test]
     fn config() {
@@ -582,7 +583,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,
@@ -595,4 +600,240 @@ mod tests {
             VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
         assert_eq!(socket.guest_cid, 0x00_0000_0042);
     }
+
+    #[test]
+    fn send_recv() {
+        let host_cid = 2;
+        let guest_cid = 66;
+        let host_port = 1234;
+        let guest_port = 4321;
+        let hello_from_guest = "Hello from guest";
+        let hello_from_host = "Hello from host";
+
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let mut socket =
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
+
+        // Start a thread to simulate the device.
+        let handle = thread::spawn(move || {
+            // Wait for connection request.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Request.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Accept connection and give the peer enough credit to send the message.
+            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
+                RX_QUEUE_IDX,
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Response.into(),
+                    src_cid: host_cid.into(),
+                    dst_cid: guest_cid.into(),
+                    src_port: host_port.into(),
+                    dst_port: guest_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 50.into(),
+                    fwd_cnt: 0.into(),
+                }
+                .as_bytes(),
+            );
+
+            // Expect a credit update.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::CreditUpdate.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Expect the guest to send some data.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            let request = state
+                .lock()
+                .unwrap()
+                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
+            assert_eq!(
+                request.len(),
+                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
+            );
+            assert_eq!(
+                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Rw.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: (hello_from_guest.len() as u32).into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+            assert_eq!(
+                &request[size_of::<VirtioVsockHdr>()..],
+                hello_from_guest.as_bytes()
+            );
+
+            // Send a response.
+            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
+            VirtioVsockHdr {
+                op: VirtioVsockOp::Rw.into(),
+                src_cid: host_cid.into(),
+                dst_cid: guest_cid.into(),
+                src_port: host_port.into(),
+                dst_port: guest_port.into(),
+                len: (hello_from_host.len() as u32).into(),
+                socket_type: SocketType::Stream.into(),
+                flags: 0.into(),
+                buf_alloc: 50.into(),
+                fwd_cnt: (hello_from_guest.len() as u32).into(),
+            }
+            .write_to_prefix(response.as_mut_slice());
+            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
+            state
+                .lock()
+                .unwrap()
+                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
+
+            // Expect a credit update.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::CreditUpdate.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 64.into(),
+                    fwd_cnt: 0.into(),
+                }
+            );
+
+            // Expect a shutdown.
+            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
+            assert_eq!(
+                VirtioVsockHdr::read_from(
+                    state
+                        .lock()
+                        .unwrap()
+                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
+                        .as_slice()
+                )
+                .unwrap(),
+                VirtioVsockHdr {
+                    op: VirtioVsockOp::Shutdown.into(),
+                    src_cid: guest_cid.into(),
+                    dst_cid: host_cid.into(),
+                    src_port: guest_port.into(),
+                    dst_port: host_port.into(),
+                    len: 0.into(),
+                    socket_type: SocketType::Stream.into(),
+                    flags: 0.into(),
+                    buf_alloc: 0.into(),
+                    fwd_cnt: (hello_from_host.len() as u32).into(),
+                }
+            );
+        });
+
+        socket.connect(host_cid, guest_port, host_port).unwrap();
+        socket.wait_for_connect().unwrap();
+        socket.send(hello_from_guest.as_bytes()).unwrap();
+        let mut buffer = [0u8; 64];
+        let event = socket.wait_for_recv(&mut buffer).unwrap();
+        assert_eq!(
+            event,
+            VsockEvent {
+                source: VsockAddr {
+                    cid: host_cid,
+                    port: host_port,
+                },
+                destination: VsockAddr {
+                    cid: guest_cid,
+                    port: guest_port,
+                },
+                event_type: VsockEventType::Received {
+                    length: hello_from_host.len()
+                }
+            }
+        );
+        assert_eq!(
+            &buffer[0..hello_from_host.len()],
+            hello_from_host.as_bytes()
+        );
+        socket.shutdown().unwrap();
+
+        handle.join().unwrap();
+    }
 }

+ 22 - 5
src/transport/fake.rs

@@ -4,8 +4,13 @@ use crate::{
     PhysAddr, Result,
 };
 use alloc::{sync::Arc, vec::Vec};
-use core::{any::TypeId, ptr::NonNull};
-use std::sync::Mutex;
+use core::{
+    any::TypeId,
+    ptr::NonNull,
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
+use std::{sync::Mutex, thread};
 
 /// A fake implementation of [`Transport`] for unit tests.
 #[derive(Debug)]
@@ -35,7 +40,9 @@ impl<C> Transport for FakeTransport<C> {
     }
 
     fn notify(&mut self, queue: u16) {
-        self.state.lock().unwrap().queues[queue as usize].notified = true;
+        self.state.lock().unwrap().queues[queue as usize]
+            .notified
+            .store(true, Ordering::SeqCst);
     }
 
     fn get_status(&self) -> DeviceStatus {
@@ -168,13 +175,23 @@ impl State {
             handler,
         )
     }
+
+    /// Waits until the given queue is notified.
+    pub fn wait_until_queue_notified(state: &Mutex<Self>, queue_index: u16) {
+        while !state.lock().unwrap().queues[usize::from(queue_index)]
+            .notified
+            .swap(false, Ordering::SeqCst)
+        {
+            thread::sleep(Duration::from_millis(10));
+        }
+    }
 }
 
-#[derive(Clone, Debug, Default, Eq, PartialEq)]
+#[derive(Debug, Default)]
 pub struct QueueStatus {
     pub size: u32,
     pub descriptors: PhysAddr,
     pub driver_area: PhysAddr,
     pub device_area: PhysAddr,
-    pub notified: bool,
+    pub notified: AtomicBool,
 }