Alice Wang преди 2 години
родител
ревизия
c86d88c0b8

+ 1 - 1
.github/workflows/main.yml

@@ -78,7 +78,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Install QEMU
-        run: sudo apt update && sudo apt install ${{ matrix.packages }}
+        run: sudo apt update && sudo apt install ${{ matrix.packages }} && sudo chmod 666 /dev/vhost-vsock
       - uses: actions-rs/toolchain@v1
         with:
           profile: minimal

+ 7 - 5
examples/aarch64/Makefile

@@ -47,11 +47,12 @@ clean:
 	cargo clean
 
 qemu: $(kernel_qemu_bin) $(img)
+	(nc localhost -l 1235 -v) &
 	qemu-system-aarch64 \
 	  $(QEMU_ARGS) \
 		-machine virt \
 		-cpu max \
-		-nographic \
+		-serial chardev:char0 \
 		-kernel $(kernel_qemu_bin) \
 		-global virtio-mmio.force-legacy=false \
 		-nic none \
@@ -60,15 +61,16 @@ qemu: $(kernel_qemu_bin) $(img)
 		-device virtio-blk-device,drive=x0 \
 		-device virtio-gpu-device \
 		-device virtio-serial,id=virtio-serial0 \
-		-chardev pty,id=char0 \
+		-chardev stdio,id=char0,mux=on \
 		-device virtconsole,chardev=char0
 
 qemu-pci: $(kernel_qemu_bin) $(img)
+	(nc localhost -l 1235 -v) &
 	qemu-system-aarch64 \
-	  $(QEMU_ARGS) \
+		$(QEMU_ARGS) \
 		-machine virt \
 		-cpu max \
-		-nographic \
+		-serial chardev:char0 \
 		-kernel $(kernel_qemu_bin) \
 		-nic none \
 		-drive file=$(img),if=none,format=raw,id=x0 \
@@ -76,7 +78,7 @@ qemu-pci: $(kernel_qemu_bin) $(img)
 		-device virtio-blk-pci,drive=x0 \
 		-device virtio-gpu-pci \
 		-device virtio-serial,id=virtio-serial0 \
-		-chardev pty,id=char0 \
+		-chardev stdio,id=char0,mux=on \
 		-device virtconsole,chardev=char0
 
 crosvm: $(kernel_crosvm_bin) $(img)

+ 0 - 23
examples/aarch64/socket_test_host.py

@@ -1,23 +0,0 @@
-import socket
-
-HOST = 'localhost'
-PORT = 1234
-
-def main():
-    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
-        s.bind((HOST, PORT))
-        s.listen()
-        print(f"Server listening on {HOST}:{PORT}")
-        conn, addr = s.accept()
-        with conn:
-            print(f"Connected by {addr}")
-            while True:
-                conn.sendall(b"start vsock aaaa")
-                data = conn.recv(1024)
-                if not data:
-                    break
-                print(f"Received: {data.decode('utf-8')}")
-                conn.sendall(data)
-
-if __name__ == '__main__':
-    main()

+ 2 - 1
examples/aarch64/src/main.rs

@@ -183,7 +183,8 @@ fn virtio_socket<T: Transport>(transport: T) {
     let mut socket =
         VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver");
     let host_cid = 2;
-    let port = 1234;
+    let port = 1235;
+    info!("Connecting to host on port {port}...");
     if let Err(e) = socket.connect(host_cid, port, port) {
         error!("Failed to connect to host: {:?}", e);
     }

+ 1 - 6
src/device/socket/error.rs

@@ -11,11 +11,9 @@ pub enum SocketError {
     NoResponseReceived,
     /// The given buffer is shorter than expected.
     BufferTooShort,
-    /// Failed to parse the VirtioVsockPacket from buffer.
-    PacketParsingFailed,
     /// Unknown operation.
     UnknownOperation(u16),
-    /// Invalid opration,
+    /// Invalid operation,
     InvalidOperation,
 }
 
@@ -25,9 +23,6 @@ impl fmt::Display for SocketError {
             Self::ConnectionFailed => write!(f, "Failed to establish the connection"),
             Self::NoResponseReceived => write!(f, "No response received"),
             Self::BufferTooShort => write!(f, "The given buffer is shorter than expected"),
-            Self::PacketParsingFailed => {
-                write!(f, "Failed to parse the VirtioVsockPacket from buffer")
-            }
             Self::UnknownOperation(op) => {
                 write!(f, "The operation code '{op}' is unknown")
             }

+ 46 - 32
src/device/socket/protocol.rs

@@ -2,8 +2,11 @@
 
 use super::error::{self, SocketError};
 use crate::volatile::ReadOnly;
-use core::convert::TryInto;
-use core::{convert::TryFrom, mem::size_of};
+use core::{
+    convert::{TryFrom, TryInto},
+    fmt,
+    mem::size_of,
+};
 use zerocopy::{
     byteorder::{LittleEndian, U16, U32, U64},
     AsBytes, FromBytes,
@@ -19,7 +22,7 @@ pub struct VirtioVsockConfig {
     ///
     /// We need to split the guest_cid into two parts because VirtIO only guarantees 4 bytes alignment.
     pub guest_cid_low: ReadOnly<u32>,
-    pub _guest_cid_high: ReadOnly<u32>,
+    pub guest_cid_high: ReadOnly<u32>,
 }
 
 /// The message header for data packets sent on the tx/rx queues
@@ -31,7 +34,7 @@ pub struct VirtioVsockHdr {
     pub src_port: U32<LittleEndian>,
     pub dst_port: U32<LittleEndian>,
     pub len: U32<LittleEndian>,
-    pub r#type: U16<LittleEndian>,
+    pub socket_type: U16<LittleEndian>,
     pub op: U16<LittleEndian>,
     pub flags: U32<LittleEndian>,
     pub buf_alloc: U32<LittleEndian>,
@@ -46,7 +49,7 @@ impl Default for VirtioVsockHdr {
             src_port: 0.into(),
             dst_port: 0.into(),
             len: 0.into(),
-            r#type: TYPE_STREAM_SOCKET.into(),
+            socket_type: TYPE_STREAM_SOCKET.into(),
             op: 0.into(),
             flags: 0.into(),
             buf_alloc: 0.into(),
@@ -62,11 +65,8 @@ pub struct VirtioVsockPacket<'a> {
 }
 
 impl<'a> VirtioVsockPacket<'a> {
-    pub fn read_from(buffer: &'a [u8]) -> Result<Self, SocketError> {
-        let hdr = buffer
-            .get(0..size_of::<VirtioVsockHdr>())
-            .ok_or(SocketError::BufferTooShort)?;
-        let hdr = VirtioVsockHdr::read_from(hdr).ok_or(SocketError::PacketParsingFailed)?;
+    pub fn read_from(buffer: &'a [u8]) -> error::Result<Self> {
+        let hdr = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
         let data_end = size_of::<VirtioVsockHdr>() + (hdr.len.get() as usize);
         let data = buffer
             .get(size_of::<VirtioVsockHdr>()..data_end)
@@ -74,7 +74,7 @@ impl<'a> VirtioVsockPacket<'a> {
         Ok(Self { hdr, data })
     }
 
-    pub fn op(&self) -> error::Result<Op> {
+    pub fn op(&self) -> error::Result<VirtioVsockOp> {
         self.hdr.op.try_into()
     }
 }
@@ -87,48 +87,62 @@ pub struct VirtioVsockEvent {
     pub id: U32<LittleEndian>,
 }
 
-#[allow(non_camel_case_types)]
-#[derive(Copy, Clone, Eq, PartialEq, Debug)]
+#[derive(Copy, Clone, Eq, PartialEq)]
 #[repr(u16)]
-pub enum Op {
-    VIRTIO_VSOCK_OP_INVALID = 0,
+pub enum VirtioVsockOp {
+    Invalid = 0,
 
     /* Connect operations */
-    VIRTIO_VSOCK_OP_REQUEST = 1,
-    VIRTIO_VSOCK_OP_RESPONSE = 2,
-    VIRTIO_VSOCK_OP_RST = 3,
-    VIRTIO_VSOCK_OP_SHUTDOWN = 4,
+    Request = 1,
+    Response = 2,
+    Rst = 3,
+    Shutdown = 4,
 
     /* To send payload */
-    VIRTIO_VSOCK_OP_RW = 5,
+    Rw = 5,
 
     /* Tell the peer our credit info */
-    VIRTIO_VSOCK_OP_CREDIT_UPDATE = 6,
+    CreditUpdate = 6,
     /* Request the peer to send the credit info to us */
-    VIRTIO_VSOCK_OP_CREDIT_REQUEST = 7,
+    CreditRequest = 7,
 }
 
-impl Into<U16<LittleEndian>> for Op {
+impl Into<U16<LittleEndian>> for VirtioVsockOp {
     fn into(self) -> U16<LittleEndian> {
         (self as u16).into()
     }
 }
 
-impl TryFrom<U16<LittleEndian>> for Op {
+impl TryFrom<U16<LittleEndian>> for VirtioVsockOp {
     type Error = SocketError;
 
     fn try_from(v: U16<LittleEndian>) -> Result<Self, Self::Error> {
         let op = match u16::from(v) {
-            0 => Self::VIRTIO_VSOCK_OP_INVALID,
-            1 => Self::VIRTIO_VSOCK_OP_REQUEST,
-            2 => Self::VIRTIO_VSOCK_OP_RESPONSE,
-            3 => Self::VIRTIO_VSOCK_OP_RST,
-            4 => Self::VIRTIO_VSOCK_OP_SHUTDOWN,
-            5 => Self::VIRTIO_VSOCK_OP_RW,
-            6 => Self::VIRTIO_VSOCK_OP_CREDIT_UPDATE,
-            7 => Self::VIRTIO_VSOCK_OP_CREDIT_REQUEST,
+            0 => Self::Invalid,
+            1 => Self::Request,
+            2 => Self::Response,
+            3 => Self::Rst,
+            4 => Self::Shutdown,
+            5 => Self::Rw,
+            6 => Self::CreditUpdate,
+            7 => Self::CreditRequest,
             _ => return Err(SocketError::UnknownOperation(v.into())),
         };
         Ok(op)
     }
 }
+
+impl fmt::Debug for VirtioVsockOp {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::Invalid => write!(f, "VIRTIO_VSOCK_OP_INVALID"),
+            Self::Request => write!(f, "VIRTIO_VSOCK_OP_REQUEST"),
+            Self::Response => write!(f, "VIRTIO_VSOCK_OP_RESPONSE"),
+            Self::Rst => write!(f, "VIRTIO_VSOCK_OP_RST"),
+            Self::Shutdown => write!(f, "VIRTIO_VSOCK_OP_SHUTDOWN"),
+            Self::Rw => write!(f, "VIRTIO_VSOCK_OP_RW"),
+            Self::CreditUpdate => write!(f, "VIRTIO_VSOCK_OP_CREDIT_UPDATE"),
+            Self::CreditRequest => write!(f, "VIRTIO_VSOCK_OP_CREDIT_REQUEST"),
+        }
+    }
+}

+ 13 - 17
src/device/socket/vsock.rs

@@ -2,13 +2,13 @@
 
 use super::common::Feature;
 use super::error::SocketError;
-use super::protocol::{Op, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockPacket};
+use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VirtioVsockPacket};
 use crate::hal::{BufferDirection, Dma, Hal};
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::volread;
 use crate::Result;
-use log::{error, info};
+use log::{trace, info};
 use zerocopy::AsBytes;
 
 const RX_QUEUE_IDX: u16 = 0;
@@ -56,7 +56,9 @@ impl<'a, H: Hal, T: Transport> VirtIOSocket<'a, H, T> {
         let config = transport.config_space::<VirtioVsockConfig>()?;
         info!("config: {:?}", config);
         // Safe because config is a valid pointer to the device configuration space.
-        let guest_cid = unsafe { volread!(config, guest_cid_low) as u64 };
+        let guest_cid = unsafe {
+            volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
+        };
         info!("guest cid: {guest_cid:?}");
 
         let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
@@ -96,7 +98,7 @@ impl<'a, H: Hal, T: Transport> VirtIOSocket<'a, H, T> {
             dst_cid: dst_cid.into(),
             src_port: src_port.into(),
             dst_port: dst_port.into(),
-            op: Op::VIRTIO_VSOCK_OP_REQUEST.into(),
+            op: VirtioVsockOp::Request.into(),
             ..Default::default()
         };
         self.tx
@@ -106,25 +108,19 @@ impl<'a, H: Hal, T: Transport> VirtIOSocket<'a, H, T> {
         } else {
             return Err(SocketError::NoResponseReceived.into());
         };
+        // Safe because we are passing the same buffer as we passed to `VirtQueue::add`.
         let _len = unsafe {
             self.rx
                 .pop_used(token, &[], &mut [&mut self.queue_buf_rx])?
         };
         let packet_rx = VirtioVsockPacket::read_from(&self.queue_buf_rx)?;
-        let result = match packet_rx.op()? {
-            Op::VIRTIO_VSOCK_OP_RESPONSE => Ok(()),
-            Op::VIRTIO_VSOCK_OP_RST => Err(SocketError::ConnectionFailed.into()),
-            Op::VIRTIO_VSOCK_OP_INVALID => Err(SocketError::InvalidOperation.into()),
+        trace!("Received packet {:?}. Op {:?}", packet_rx, packet_rx.op());
+        match packet_rx.op()? {
+            VirtioVsockOp::Response => Ok(()),
+            VirtioVsockOp::Rst => Err(SocketError::ConnectionFailed.into()),
+            VirtioVsockOp::Invalid => Err(SocketError::InvalidOperation.into()),
             _ => todo!(),
-        };
-        if result.is_err() {
-            error!(
-                "Connection failed. Packet received: {:?}, op={:?}",
-                packet_rx,
-                packet_rx.op()
-            );
         }
-        result
     }
 }
 
@@ -147,7 +143,7 @@ mod tests {
     fn config() {
         let mut config_space = VirtioVsockConfig {
             guest_cid_low: ReadOnly::new(66),
-            _guest_cid_high: ReadOnly::new(0),
+            guest_cid_high: ReadOnly::new(0),
         };
         let state = Arc::new(Mutex::new(State {
             status: DeviceStatus::empty(),

+ 0 - 1
src/hal/fake.rs

@@ -21,7 +21,6 @@ unsafe impl Hal for FakeHal {
         }
     }
 
-    #[allow(unused_unsafe)]
     unsafe fn dma_dealloc(_paddr: PhysAddr, vaddr: NonNull<u8>, pages: usize) -> i32 {
         assert_ne!(pages, 0);
         let layout = Layout::from_size_align(pages * PAGE_SIZE, PAGE_SIZE).unwrap();

+ 0 - 2
src/queue.rs

@@ -1,5 +1,3 @@
-#![allow(unused_unsafe)]
-
 use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
 use crate::transport::Transport;
 use crate::{align_up, nonnull_slice_from_raw_parts, pages, Error, Result, PAGE_SIZE};