ソースを参照

Merge pull request #65 from aliciawyy/wipBasicInfra

Add protocols to support virtio socket device
Andrew Walbran 2 年 前
コミット
f1d1cbb007

+ 1 - 1
.github/workflows/main.yml

@@ -78,7 +78,7 @@ jobs:
     steps:
       - uses: actions/checkout@v2
       - name: Install QEMU
-        run: sudo apt update && sudo apt install ${{ matrix.packages }}
+        run: sudo apt update && sudo apt install ${{ matrix.packages }} && sudo chmod 666 /dev/vhost-vsock
       - uses: actions-rs/toolchain@v1
         with:
           profile: minimal

+ 16 - 2
examples/aarch64/Makefile

@@ -4,6 +4,8 @@ kernel := target/$(target)/$(mode)/aarch64
 kernel_qemu_bin := target/$(target)/$(mode)/aarch64_qemu.bin
 kernel_crosvm_bin := target/$(target)/$(mode)/aarch64_crosvm.bin
 img := target/$(target)/$(mode)/img
+vsock_server_path := ../vsock_server
+vsock_server_bin := $(vsock_server_path)/target/$(mode)/vsock_server
 
 sysroot := $(shell rustc --print sysroot)
 objdump := $(shell find $(sysroot) -name llvm-objdump) --arch-name=aarch64
@@ -14,6 +16,11 @@ ifeq ($(mode), release)
 	BUILD_ARGS += --release
 endif
 
+VSOCK_BUILD_ARGS =
+ifeq ($(mode), release)
+	VSOCK_BUILD_ARGS += --release
+endif
+
 .PHONY: kernel clean qemu run env
 
 env:
@@ -34,6 +41,9 @@ $(kernel_qemu_bin): kernel_qemu
 $(kernel_crosvm_bin): kernel_crosvm
 	aarch64-linux-gnu-objcopy -O binary $(kernel) $(kernel_crosvm_bin)
 
+$(vsock_server_bin):
+	(cd $(vsock_server_path) && cargo build $(VSOCK_BUILD_ARGS))
+
 asm: kernel
 	$(objdump) -d $(kernel) | less
 
@@ -46,7 +56,8 @@ header: kernel
 clean:
 	cargo clean
 
-qemu: $(kernel_qemu_bin) $(img)
+qemu: $(kernel_qemu_bin) $(img) $(vsock_server_bin)
+	$(vsock_server_bin) &
 	qemu-system-aarch64 \
 	  $(QEMU_ARGS) \
 		-machine virt \
@@ -56,6 +67,7 @@ qemu: $(kernel_qemu_bin) $(img)
 		-global virtio-mmio.force-legacy=false \
 		-nic none \
 		-drive file=$(img),if=none,format=raw,id=x0 \
+		-device vhost-vsock-device,id=virtiosocket0,guest-cid=102 \
 		-device virtio-blk-device,drive=x0 \
 		-device virtio-gpu-device \
 		-device virtio-serial,id=virtio-serial0 \
@@ -63,14 +75,16 @@ qemu: $(kernel_qemu_bin) $(img)
 		-device virtconsole,chardev=char0
 
 qemu-pci: $(kernel_qemu_bin) $(img)
+	$(vsock_server_bin) &
 	qemu-system-aarch64 \
-	  $(QEMU_ARGS) \
+		$(QEMU_ARGS) \
 		-machine virt \
 		-cpu max \
 		-serial chardev:char0 \
 		-kernel $(kernel_qemu_bin) \
 		-nic none \
 		-drive file=$(img),if=none,format=raw,id=x0 \
+		-device vhost-vsock-pci,id=virtiosocket0,guest-cid=103 \
 		-device virtio-blk-pci,drive=x0 \
 		-device virtio-gpu-pci \
 		-device virtio-serial,id=virtio-serial0 \

+ 16 - 1
examples/aarch64/src/main.rs

@@ -23,7 +23,7 @@ use hal::HalImpl;
 use log::{debug, error, info, trace, warn, LevelFilter};
 use psci::system_off;
 use virtio_drivers::{
-    device::{blk::VirtIOBlk, console::VirtIOConsole, gpu::VirtIOGpu},
+    device::{blk::VirtIOBlk, console::VirtIOConsole, gpu::VirtIOGpu, socket::VirtIOSocket},
     transport::{
         mmio::{MmioTransport, VirtIOHeader},
         pci::{
@@ -116,6 +116,7 @@ fn virtio_device(transport: impl Transport) {
         DeviceType::GPU => virtio_gpu(transport),
         // DeviceType::Network => virtio_net(transport), // currently is unsupported without alloc
         DeviceType::Console => virtio_console(transport),
+        DeviceType::Socket => virtio_socket(transport),
         t => warn!("Unrecognized virtio device: {:?}", t),
     }
 }
@@ -178,6 +179,20 @@ fn virtio_console<T: Transport>(transport: T) {
     info!("virtio-console test finished");
 }
 
+fn virtio_socket<T: Transport>(transport: T) {
+    let mut socket =
+        VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver");
+    let host_cid = 2;
+    let port = 1221;
+    info!("Connecting to host on port {port}...");
+    if let Err(e) = socket.connect(host_cid, port, port) {
+        error!("Failed to connect to host: {:?}", e);
+    } else {
+        info!("Connected to host on port {port} successfully.")
+    }
+    info!("VirtIO socket test finished");
+}
+
 #[derive(Copy, Clone, Debug, Eq, PartialEq)]
 enum PciRangeType {
     ConfigurationSpace,

+ 8 - 0
examples/vsock_server/Cargo.toml

@@ -0,0 +1,8 @@
+[package]
+name = "vsock_server"
+version = "0.1.0"
+authors = ["Alice Wang <aliceywang@google.com>"]
+edition = "2021"
+
+[dependencies]
+vsock = "0.3.0"

+ 13 - 0
examples/vsock_server/src/main.rs

@@ -0,0 +1,13 @@
+// Sets a listening socket on host.
+use vsock::{VsockAddr, VsockListener, VMADDR_CID_HOST};
+
+const PORT: u32 = 1221;
+
+fn main() {
+    println!("Setting up listening socket on port {PORT}");
+    let listener = VsockListener::bind(&VsockAddr::new(VMADDR_CID_HOST, PORT))
+        .expect("Failed to set up listening port");
+    for incoming in listener.incoming() {
+        println!("Accept connection: {incoming:?}");
+    }
+}

+ 23 - 0
src/device/common.rs

@@ -0,0 +1,23 @@
+//! Common part shared across all the devices.
+
+use bitflags::bitflags;
+
+bitflags! {
+    pub(crate) struct Feature: u64 {
+        // device independent
+        const NOTIFY_ON_EMPTY       = 1 << 24; // legacy
+        const ANY_LAYOUT            = 1 << 27; // legacy
+        const RING_INDIRECT_DESC    = 1 << 28;
+        const RING_EVENT_IDX        = 1 << 29;
+        const UNUSED                = 1 << 30; // legacy
+        const VERSION_1             = 1 << 32; // detect legacy
+
+        // since virtio v1.1
+        const ACCESS_PLATFORM       = 1 << 33;
+        const RING_PACKED           = 1 << 34;
+        const IN_ORDER              = 1 << 35;
+        const ORDER_PLATFORM        = 1 << 36;
+        const SR_IOV                = 1 << 37;
+        const NOTIFICATION_DATA     = 1 << 38;
+    }
+}

+ 1 - 21
src/device/input.rs

@@ -1,12 +1,12 @@
 //! Driver for VirtIO input devices.
 
+use super::common::Feature;
 use crate::hal::Hal;
 use crate::queue::VirtQueue;
 use crate::transport::Transport;
 use crate::volatile::{volread, volwrite, ReadOnly, WriteOnly};
 use crate::Result;
 use alloc::boxed::Box;
-use bitflags::bitflags;
 use core::ptr::NonNull;
 use log::info;
 use zerocopy::{AsBytes, FromBytes};
@@ -191,26 +191,6 @@ pub struct InputEvent {
     pub value: u32,
 }
 
-bitflags! {
-    struct Feature: u64 {
-        // device independent
-        const NOTIFY_ON_EMPTY       = 1 << 24; // legacy
-        const ANY_LAYOUT            = 1 << 27; // legacy
-        const RING_INDIRECT_DESC    = 1 << 28;
-        const RING_EVENT_IDX        = 1 << 29;
-        const UNUSED                = 1 << 30; // legacy
-        const VERSION_1             = 1 << 32; // detect legacy
-
-        // since virtio v1.1
-        const ACCESS_PLATFORM       = 1 << 33;
-        const RING_PACKED           = 1 << 34;
-        const IN_ORDER              = 1 << 35;
-        const ORDER_PLATFORM        = 1 << 36;
-        const SR_IOV                = 1 << 37;
-        const NOTIFICATION_DATA     = 1 << 38;
-    }
-}
-
 const QUEUE_EVENT: u16 = 0;
 const QUEUE_STATUS: u16 = 1;
 

+ 3 - 0
src/device/mod.rs

@@ -7,3 +7,6 @@ pub mod gpu;
 pub mod input;
 #[cfg(feature = "alloc")]
 pub mod net;
+pub mod socket;
+
+pub(crate) mod common;

+ 34 - 0
src/device/socket/error.rs

@@ -0,0 +1,34 @@
+//! This module contain the error from the VirtIO socket driver.
+
+use core::{fmt, result};
+
+/// The error type of VirtIO socket driver.
+#[derive(Copy, Clone, Debug, Eq, PartialEq)]
+pub enum SocketError {
+    /// Failed to establish the connection.
+    ConnectionFailed,
+    /// No response received.
+    NoResponseReceived,
+    /// The given buffer is shorter than expected.
+    BufferTooShort,
+    /// Unknown operation.
+    UnknownOperation(u16),
+    /// Invalid operation,
+    InvalidOperation,
+}
+
+impl fmt::Display for SocketError {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::ConnectionFailed => write!(f, "Failed to establish the connection"),
+            Self::NoResponseReceived => write!(f, "No response received"),
+            Self::BufferTooShort => write!(f, "The given buffer is shorter than expected"),
+            Self::UnknownOperation(op) => {
+                write!(f, "The operation code '{op}' is unknown")
+            }
+            Self::InvalidOperation => write!(f, "Invalid operation"),
+        }
+    }
+}
+
+pub type Result<T> = result::Result<T, SocketError>;

+ 8 - 0
src/device/socket/mod.rs

@@ -0,0 +1,8 @@
+//! This module implements the virtio vsock device.
+
+mod error;
+mod protocol;
+mod vsock;
+
+pub use error::SocketError;
+pub use vsock::VirtIOSocket;

+ 162 - 0
src/device/socket/protocol.rs

@@ -0,0 +1,162 @@
+//! This module defines the socket device protocol according to the virtio spec v1.1 5.10 Socket Device
+
+use super::error::{self, SocketError};
+use crate::volatile::ReadOnly;
+use core::{
+    convert::{TryFrom, TryInto},
+    fmt,
+    mem::size_of,
+};
+use zerocopy::{
+    byteorder::{LittleEndian, U16, U32, U64},
+    AsBytes, FromBytes,
+};
+
+/// Currently only stream sockets are supported. type is 1 for stream socket types.
+#[derive(Copy, Clone, Debug)]
+#[repr(u16)]
+pub enum SocketType {
+    /// Stream sockets provide in-order, guaranteed, connection-oriented delivery without message boundaries.
+    Stream = 1,
+}
+
+impl From<SocketType> for U16<LittleEndian> {
+    fn from(socket_type: SocketType) -> Self {
+        (socket_type as u16).into()
+    }
+}
+
+/// VirtioVsockConfig is the vsock device configuration space.
+#[repr(C)]
+pub struct VirtioVsockConfig {
+    /// The guest_cid field contains the guest’s context ID, which uniquely identifies
+    /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
+    ///
+    /// According to virtio spec v1.1 2.4.1 Driver Requirements: Device Configuration Space,
+    /// drivers MUST NOT assume reads from fields greater than 32 bits wide are atomic.
+    /// So we need to split the u64 guest_cid into two parts.
+    pub guest_cid_low: ReadOnly<u32>,
+    pub guest_cid_high: ReadOnly<u32>,
+}
+
+/// The message header for data packets sent on the tx/rx queues
+#[repr(packed)]
+#[derive(AsBytes, Clone, Copy, Debug, FromBytes)]
+pub struct VirtioVsockHdr {
+    pub src_cid: U64<LittleEndian>,
+    pub dst_cid: U64<LittleEndian>,
+    pub src_port: U32<LittleEndian>,
+    pub dst_port: U32<LittleEndian>,
+    pub len: U32<LittleEndian>,
+    pub socket_type: U16<LittleEndian>,
+    pub op: U16<LittleEndian>,
+    pub flags: U32<LittleEndian>,
+    pub buf_alloc: U32<LittleEndian>,
+    pub fwd_cnt: U32<LittleEndian>,
+}
+
+impl Default for VirtioVsockHdr {
+    fn default() -> Self {
+        Self {
+            src_cid: 0.into(),
+            dst_cid: 0.into(),
+            src_port: 0.into(),
+            dst_port: 0.into(),
+            len: 0.into(),
+            socket_type: SocketType::Stream.into(),
+            op: 0.into(),
+            flags: 0.into(),
+            buf_alloc: 0.into(),
+            fwd_cnt: 0.into(),
+        }
+    }
+}
+
+#[derive(Clone, Debug)]
+pub struct VirtioVsockPacket<'a> {
+    pub hdr: VirtioVsockHdr,
+    pub data: &'a [u8],
+}
+
+impl<'a> VirtioVsockPacket<'a> {
+    pub fn read_from(buffer: &'a [u8]) -> error::Result<Self> {
+        let hdr = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
+        let data_end = size_of::<VirtioVsockHdr>() + (hdr.len.get() as usize);
+        let data = buffer
+            .get(size_of::<VirtioVsockHdr>()..data_end)
+            .ok_or(SocketError::BufferTooShort)?;
+        Ok(Self { hdr, data })
+    }
+
+    pub fn op(&self) -> error::Result<VirtioVsockOp> {
+        self.hdr.op.try_into()
+    }
+}
+
+/// An event sent to the event queue
+#[derive(Copy, Clone, Debug, Default, AsBytes, FromBytes)]
+#[repr(C)]
+pub struct VirtioVsockEvent {
+    // ID from the virtio_vsock_event_id struct in the virtio spec
+    pub id: U32<LittleEndian>,
+}
+
+#[derive(Copy, Clone, Eq, PartialEq)]
+#[repr(u16)]
+pub enum VirtioVsockOp {
+    Invalid = 0,
+
+    /* Connect operations */
+    Request = 1,
+    Response = 2,
+    Rst = 3,
+    Shutdown = 4,
+
+    /* To send payload */
+    Rw = 5,
+
+    /* Tell the peer our credit info */
+    CreditUpdate = 6,
+    /* Request the peer to send the credit info to us */
+    CreditRequest = 7,
+}
+
+impl From<VirtioVsockOp> for U16<LittleEndian> {
+    fn from(op: VirtioVsockOp) -> Self {
+        (op as u16).into()
+    }
+}
+
+impl TryFrom<U16<LittleEndian>> for VirtioVsockOp {
+    type Error = SocketError;
+
+    fn try_from(v: U16<LittleEndian>) -> Result<Self, Self::Error> {
+        let op = match u16::from(v) {
+            0 => Self::Invalid,
+            1 => Self::Request,
+            2 => Self::Response,
+            3 => Self::Rst,
+            4 => Self::Shutdown,
+            5 => Self::Rw,
+            6 => Self::CreditUpdate,
+            7 => Self::CreditRequest,
+            _ => return Err(SocketError::UnknownOperation(v.into())),
+        };
+        Ok(op)
+    }
+}
+
+impl fmt::Debug for VirtioVsockOp {
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::Invalid => write!(f, "VIRTIO_VSOCK_OP_INVALID"),
+            Self::Request => write!(f, "VIRTIO_VSOCK_OP_REQUEST"),
+            Self::Response => write!(f, "VIRTIO_VSOCK_OP_RESPONSE"),
+            Self::Rst => write!(f, "VIRTIO_VSOCK_OP_RST"),
+            Self::Shutdown => write!(f, "VIRTIO_VSOCK_OP_SHUTDOWN"),
+            Self::Rw => write!(f, "VIRTIO_VSOCK_OP_RW"),
+            Self::CreditUpdate => write!(f, "VIRTIO_VSOCK_OP_CREDIT_UPDATE"),
+            Self::CreditRequest => write!(f, "VIRTIO_VSOCK_OP_CREDIT_REQUEST"),
+        }
+    }
+}

+ 163 - 0
src/device/socket/vsock.rs

@@ -0,0 +1,163 @@
+//! Driver for VirtIO socket devices.
+
+use super::error::SocketError;
+use super::protocol::{VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VirtioVsockPacket};
+use crate::device::common::Feature;
+use crate::hal::{BufferDirection, Dma, Hal};
+use crate::queue::VirtQueue;
+use crate::transport::Transport;
+use crate::volatile::volread;
+use crate::Result;
+use log::{info, trace};
+use zerocopy::AsBytes;
+
+const RX_QUEUE_IDX: u16 = 0;
+const TX_QUEUE_IDX: u16 = 1;
+const EVENT_QUEUE_IDX: u16 = 2;
+
+const QUEUE_SIZE: usize = 2;
+
+/// Driver for a VirtIO socket device.
+pub struct VirtIOSocket<'a, H: Hal, T: Transport> {
+    transport: T,
+    /// Virtqueue to receive packets.
+    rx: VirtQueue<H, { QUEUE_SIZE }>,
+    tx: VirtQueue<H, { QUEUE_SIZE }>,
+    /// Virtqueue to receive events from the device.
+    event: VirtQueue<H, { QUEUE_SIZE }>,
+    /// The guest_cid field contains the guest’s context ID, which uniquely identifies
+    /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
+    guest_cid: u64,
+    queue_buf_dma: Dma<H>,
+    queue_buf_rx: &'a mut [u8],
+}
+
+impl<'a, H: Hal, T: Transport> Drop for VirtIOSocket<'a, H, T> {
+    fn drop(&mut self) {
+        // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
+        // after they have been freed.
+        self.transport.queue_unset(RX_QUEUE_IDX);
+        self.transport.queue_unset(TX_QUEUE_IDX);
+        self.transport.queue_unset(EVENT_QUEUE_IDX);
+    }
+}
+
+impl<'a, H: Hal, T: Transport> VirtIOSocket<'a, H, T> {
+    /// Create a new VirtIO Vsock driver.
+    pub fn new(mut transport: T) -> Result<Self> {
+        transport.begin_init(|features| {
+            let features = Feature::from_bits_truncate(features);
+            info!("Device features: {:?}", features);
+            // negotiate these flags only
+            let supported_features = Feature::empty();
+            (features & supported_features).bits()
+        });
+
+        let config = transport.config_space::<VirtioVsockConfig>()?;
+        info!("config: {:?}", config);
+        // Safe because config is a valid pointer to the device configuration space.
+        let guest_cid = unsafe {
+            volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
+        };
+        info!("guest cid: {guest_cid:?}");
+
+        let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
+        let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
+        let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
+
+        let queue_buf_dma = Dma::new(1, BufferDirection::DeviceToDriver)?;
+
+        // Safe because no alignment or initialisation is required for [u8], the DMA buffer is
+        // dereferenceable, and the lifetime of the reference matches the lifetime of the DMA buffer
+        // (which we don't otherwise access).
+        let queue_buf_rx = unsafe { queue_buf_dma.raw_slice().as_mut() };
+
+        // Safe because the buffer lives as long as the queue.
+        let _token = unsafe { rx.add(&[], &mut [queue_buf_rx])? };
+
+        if rx.should_notify() {
+            transport.notify(RX_QUEUE_IDX);
+        }
+        transport.finish_init();
+
+        Ok(Self {
+            transport,
+            rx,
+            tx,
+            event,
+            guest_cid,
+            queue_buf_dma,
+            queue_buf_rx,
+        })
+    }
+
+    /// Connect to the destination.
+    pub fn connect(&mut self, dst_cid: u64, src_port: u32, dst_port: u32) -> Result {
+        let header = VirtioVsockHdr {
+            src_cid: self.guest_cid.into(),
+            dst_cid: dst_cid.into(),
+            src_port: src_port.into(),
+            dst_port: dst_port.into(),
+            op: VirtioVsockOp::Request.into(),
+            ..Default::default()
+        };
+        self.tx
+            .add_notify_wait_pop(&[header.as_bytes(), &[]], &mut [], &mut self.transport)?;
+        let token = if let Some(token) = self.rx.peek_used() {
+            token // TODO: Use let else after updating Rust
+        } else {
+            return Err(SocketError::NoResponseReceived.into());
+        };
+        // Safe because we are passing the same buffer as we passed to `VirtQueue::add`.
+        let _len = unsafe { self.rx.pop_used(token, &[], &mut [self.queue_buf_rx])? };
+        let packet_rx = VirtioVsockPacket::read_from(self.queue_buf_rx)?;
+        trace!("Received packet {:?}. Op {:?}", packet_rx, packet_rx.op());
+        match packet_rx.op()? {
+            VirtioVsockOp::Response => Ok(()),
+            VirtioVsockOp::Rst => Err(SocketError::ConnectionFailed.into()),
+            VirtioVsockOp::Invalid => Err(SocketError::InvalidOperation.into()),
+            _ => todo!(),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::volatile::ReadOnly;
+    use crate::{
+        hal::fake::FakeHal,
+        transport::{
+            fake::{FakeTransport, QueueStatus, State},
+            DeviceStatus, DeviceType,
+        },
+    };
+    use alloc::{sync::Arc, vec};
+    use core::ptr::NonNull;
+    use std::sync::Mutex;
+
+    #[test]
+    fn config() {
+        let mut config_space = VirtioVsockConfig {
+            guest_cid_low: ReadOnly::new(66),
+            guest_cid_high: ReadOnly::new(0),
+        };
+        let state = Arc::new(Mutex::new(State {
+            status: DeviceStatus::empty(),
+            driver_features: 0,
+            guest_page_size: 0,
+            interrupt_pending: false,
+            queues: vec![QueueStatus::default(); 3],
+        }));
+        let transport = FakeTransport {
+            device_type: DeviceType::Socket,
+            max_queue_size: 32,
+            device_features: 0,
+            config_space: NonNull::from(&mut config_space),
+            state: state.clone(),
+        };
+        let socket =
+            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
+        assert_eq!(socket.guest_cid, 0x00_0000_0042);
+    }
+}

+ 9 - 0
src/lib.rs

@@ -89,6 +89,8 @@ pub enum Error {
     ConfigSpaceTooSmall,
     /// The device doesn't have any config space, but the driver expects some.
     ConfigSpaceMissing,
+    /// Error from the socket device.
+    SocketDeviceError(device::socket::SocketError),
 }
 
 impl Display for Error {
@@ -115,10 +117,17 @@ impl Display for Error {
                     "The device doesn't have any config space, but the driver expects some"
                 )
             }
+            Self::SocketDeviceError(e) => write!(f, "Error from the socket device: {e:?}"),
         }
     }
 }
 
+impl From<device::socket::SocketError> for Error {
+    fn from(e: device::socket::SocketError) -> Self {
+        Self::SocketDeviceError(e)
+    }
+}
+
 /// Align `size` up to a page.
 fn align_up(size: usize) -> usize {
     (size + PAGE_SIZE) & !(PAGE_SIZE - 1)