vsock.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526
  1. //! Driver for VirtIO socket devices.
  2. #![deny(unsafe_op_in_unsafe_fn)]
  3. use super::error::SocketError;
  4. use super::protocol::{Feature, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
  5. use crate::hal::Hal;
  6. use crate::queue::VirtQueue;
  7. use crate::transport::Transport;
  8. use crate::volatile::volread;
  9. use crate::Result;
  10. use alloc::boxed::Box;
  11. use core::mem::size_of;
  12. use core::ptr::{null_mut, NonNull};
  13. use log::{debug, info};
  14. use zerocopy::{AsBytes, FromBytes};
  /// Index of the virtqueue on which the device delivers received packets to the driver.
  pub(crate) const RX_QUEUE_IDX: u16 = 0;
  /// Index of the virtqueue on which the driver submits packets for the device to send.
  pub(crate) const TX_QUEUE_IDX: u16 = 1;
  /// Index of the virtqueue used to receive events from the device.
  const EVENT_QUEUE_IDX: u16 = 2;
  /// Number of descriptors in each virtqueue; this is also the number of RX buffers allocated.
  pub(crate) const QUEUE_SIZE: usize = 8;
  /// Size in bytes of every buffer placed on the RX virtqueue.
  ///
  /// Each buffer has to hold a packet header followed by its body, so this must be bigger than
  /// `size_of::<VirtioVsockHdr>()`.
  const RX_BUFFER_SIZE: usize = 512;
  21. #[derive(Clone, Debug, Default, PartialEq, Eq)]
  22. pub struct ConnectionInfo {
  23. pub dst: VsockAddr,
  24. pub src_port: u32,
  25. /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
  26. /// bytes it has allocated for packet bodies.
  27. peer_buf_alloc: u32,
  28. /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
  29. /// has finished processing.
  30. peer_fwd_cnt: u32,
  31. /// The number of bytes of packet bodies which we have sent to the peer.
  32. tx_cnt: u32,
  33. /// The number of bytes of buffer space we have allocated to receive packet bodies from the
  34. /// peer.
  35. pub buf_alloc: u32,
  36. /// The number of bytes of packet bodies which we have received from the peer and handled.
  37. fwd_cnt: u32,
  38. /// Whether we have recently requested credit from the peer.
  39. ///
  40. /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
  41. /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
  42. has_pending_credit_request: bool,
  43. }
  44. impl ConnectionInfo {
  45. pub fn new(destination: VsockAddr, src_port: u32) -> Self {
  46. Self {
  47. dst: destination,
  48. src_port,
  49. ..Default::default()
  50. }
  51. }
  52. /// Updates this connection info with the peer buffer allocation and forwarded count from the
  53. /// given event.
  54. pub fn update_for_event(&mut self, event: &VsockEvent) {
  55. self.peer_buf_alloc = event.buffer_status.buffer_allocation;
  56. self.peer_fwd_cnt = event.buffer_status.forward_count;
  57. if let VsockEventType::CreditUpdate = event.event_type {
  58. self.has_pending_credit_request = false;
  59. }
  60. }
  61. /// Increases the forwarded count recorded for this connection by the given number of bytes.
  62. ///
  63. /// This should be called once received data has been passed to the client, so there is buffer
  64. /// space available for more.
  65. pub fn done_forwarding(&mut self, length: usize) {
  66. self.fwd_cnt += length as u32;
  67. }
  68. fn peer_free(&self) -> u32 {
  69. self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
  70. }
  71. fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
  72. VirtioVsockHdr {
  73. src_cid: src_cid.into(),
  74. dst_cid: self.dst.cid.into(),
  75. src_port: self.src_port.into(),
  76. dst_port: self.dst.port.into(),
  77. buf_alloc: self.buf_alloc.into(),
  78. fwd_cnt: self.fwd_cnt.into(),
  79. ..Default::default()
  80. }
  81. }
  82. }
  83. /// An event received from a VirtIO socket device.
  84. #[derive(Clone, Debug, Eq, PartialEq)]
  85. pub struct VsockEvent {
  86. /// The source of the event, i.e. the peer who sent it.
  87. pub source: VsockAddr,
  88. /// The destination of the event, i.e. the CID and port on our side.
  89. pub destination: VsockAddr,
  90. /// The peer's buffer status for the connection.
  91. pub buffer_status: VsockBufferStatus,
  92. /// The type of event.
  93. pub event_type: VsockEventType,
  94. }
  95. impl VsockEvent {
  96. /// Returns whether the event matches the given connection.
  97. pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
  98. self.source == connection_info.dst
  99. && self.destination.cid == guest_cid
  100. && self.destination.port == connection_info.src_port
  101. }
  102. fn from_header(header: &VirtioVsockHdr) -> Result<Self> {
  103. let op = header.op()?;
  104. let buffer_status = VsockBufferStatus {
  105. buffer_allocation: header.buf_alloc.into(),
  106. forward_count: header.fwd_cnt.into(),
  107. };
  108. let source = header.source();
  109. let destination = header.destination();
  110. let event_type = match op {
  111. VirtioVsockOp::Request => {
  112. header.check_data_is_empty()?;
  113. VsockEventType::ConnectionRequest
  114. }
  115. VirtioVsockOp::Response => {
  116. header.check_data_is_empty()?;
  117. VsockEventType::Connected
  118. }
  119. VirtioVsockOp::CreditUpdate => {
  120. header.check_data_is_empty()?;
  121. VsockEventType::CreditUpdate
  122. }
  123. VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
  124. header.check_data_is_empty()?;
  125. info!("Disconnected from the peer");
  126. let reason = if op == VirtioVsockOp::Rst {
  127. DisconnectReason::Reset
  128. } else {
  129. DisconnectReason::Shutdown
  130. };
  131. VsockEventType::Disconnected { reason }
  132. }
  133. VirtioVsockOp::Rw => VsockEventType::Received {
  134. length: header.len() as usize,
  135. },
  136. VirtioVsockOp::CreditRequest => {
  137. header.check_data_is_empty()?;
  138. VsockEventType::CreditRequest
  139. }
  140. VirtioVsockOp::Invalid => return Err(SocketError::InvalidOperation.into()),
  141. };
  142. Ok(VsockEvent {
  143. source,
  144. destination,
  145. buffer_status,
  146. event_type,
  147. })
  148. }
  149. }
  /// Credit information a peer attaches to every packet it sends.
  #[derive(Clone, Debug, Eq, PartialEq)]
  pub struct VsockBufferStatus {
      /// The peer's `buf_alloc`: bytes of receive buffer space it has allocated for packet bodies.
      pub buffer_allocation: u32,
      /// The peer's `fwd_cnt`: bytes of packet bodies it has finished processing.
      pub forward_count: u32,
  }
  /// Why a vsock connection ended.
  #[derive(Copy, Clone, Debug, Eq, PartialEq)]
  pub enum DisconnectReason {
      /// The peer reset the connection: either acknowledging a shutdown we requested, or closing
      /// it forcibly on its own initiative.
      Reset,
      /// The peer requested that the connection be shut down.
      Shutdown,
  }
  164. /// Details of the type of an event received from a VirtIO socket.
  165. #[derive(Clone, Debug, Eq, PartialEq)]
  166. pub enum VsockEventType {
  167. /// The peer requests to establish a connection with us.
  168. ConnectionRequest,
  169. /// The connection was successfully established.
  170. Connected,
  171. /// The connection was closed.
  172. Disconnected {
  173. /// The reason for the disconnection.
  174. reason: DisconnectReason,
  175. },
  176. /// Data was received on the connection.
  177. Received {
  178. /// The length of the data in bytes.
  179. length: usize,
  180. },
  181. /// The peer requests us to send a credit update.
  182. CreditRequest,
  183. /// The peer just sent us a credit update with nothing else.
  184. CreditUpdate,
  185. }
  186. /// Driver for a VirtIO socket device.
  187. pub struct VirtIOSocket<H: Hal, T: Transport> {
  188. transport: T,
  189. /// Virtqueue to receive packets.
  190. rx: VirtQueue<H, { QUEUE_SIZE }>,
  191. tx: VirtQueue<H, { QUEUE_SIZE }>,
  192. /// Virtqueue to receive events from the device.
  193. event: VirtQueue<H, { QUEUE_SIZE }>,
  194. /// The guest_cid field contains the guest’s context ID, which uniquely identifies
  195. /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
  196. guest_cid: u64,
  197. rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
  198. }
  199. impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
  200. fn drop(&mut self) {
  201. // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
  202. // after they have been freed.
  203. self.transport.queue_unset(RX_QUEUE_IDX);
  204. self.transport.queue_unset(TX_QUEUE_IDX);
  205. self.transport.queue_unset(EVENT_QUEUE_IDX);
  206. for buffer in self.rx_queue_buffers {
  207. // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
  208. // used anywhere else after the driver is destroyed.
  209. unsafe { drop(Box::from_raw(buffer.as_ptr())) };
  210. }
  211. }
  212. }
  213. impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
  214. /// Create a new VirtIO Vsock driver.
  215. pub fn new(mut transport: T) -> Result<Self> {
  216. transport.begin_init(|features| {
  217. let features = Feature::from_bits_truncate(features);
  218. info!("Device features: {:?}", features);
  219. // negotiate these flags only
  220. let supported_features = Feature::empty();
  221. (features & supported_features).bits()
  222. });
  223. let config = transport.config_space::<VirtioVsockConfig>()?;
  224. info!("config: {:?}", config);
  225. // Safe because config is a valid pointer to the device configuration space.
  226. let guest_cid = unsafe {
  227. volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
  228. };
  229. info!("guest cid: {guest_cid:?}");
  230. let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
  231. let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
  232. let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
  233. // Allocate and add buffers for the RX queue.
  234. let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
  235. for i in 0..QUEUE_SIZE {
  236. let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed();
  237. // Safe because the buffer lives as long as the queue, as specified in the function
  238. // safety requirement, and we don't access it until it is popped.
  239. let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
  240. assert_eq!(i, token.into());
  241. rx_queue_buffers[i] = Box::into_raw(buffer);
  242. }
  243. let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
  244. transport.finish_init();
  245. if rx.should_notify() {
  246. transport.notify(RX_QUEUE_IDX);
  247. }
  248. Ok(Self {
  249. transport,
  250. rx,
  251. tx,
  252. event,
  253. guest_cid,
  254. rx_queue_buffers,
  255. })
  256. }
  257. /// Returns the CID which has been assigned to this guest.
  258. pub fn guest_cid(&self) -> u64 {
  259. self.guest_cid
  260. }
  261. /// Sends a request to connect to the given destination.
  262. ///
  263. /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
  264. /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
  265. /// before sending data.
  266. pub fn connect(&mut self, connection_info: &ConnectionInfo) -> Result {
  267. let header = VirtioVsockHdr {
  268. op: VirtioVsockOp::Request.into(),
  269. ..connection_info.new_header(self.guest_cid)
  270. };
  271. // Sends a header only packet to the TX queue to connect the device to the listening socket
  272. // at the given destination.
  273. self.send_packet_to_tx_queue(&header, &[])?;
  274. Ok(())
  275. }
  276. /// Requests the peer to send us a credit update for the given connection.
  277. fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
  278. let header = VirtioVsockHdr {
  279. op: VirtioVsockOp::CreditRequest.into(),
  280. ..connection_info.new_header(self.guest_cid)
  281. };
  282. self.send_packet_to_tx_queue(&header, &[])
  283. }
  284. /// Sends the buffer to the destination.
  285. pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
  286. self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
  287. let len = buffer.len() as u32;
  288. let header = VirtioVsockHdr {
  289. op: VirtioVsockOp::Rw.into(),
  290. len: len.into(),
  291. ..connection_info.new_header(self.guest_cid)
  292. };
  293. connection_info.tx_cnt += len;
  294. self.send_packet_to_tx_queue(&header, buffer)
  295. }
  296. fn check_peer_buffer_is_sufficient(
  297. &mut self,
  298. connection_info: &mut ConnectionInfo,
  299. buffer_len: usize,
  300. ) -> Result {
  301. if connection_info.peer_free() as usize >= buffer_len {
  302. Ok(())
  303. } else {
  304. // Request an update of the cached peer credit, if we haven't already done so, and tell
  305. // the caller to try again later.
  306. if !connection_info.has_pending_credit_request {
  307. self.request_credit(connection_info)?;
  308. connection_info.has_pending_credit_request = true;
  309. }
  310. Err(SocketError::InsufficientBufferSpaceInPeer.into())
  311. }
  312. }
  313. /// Tells the peer how much buffer space we have to receive data.
  314. pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
  315. let header = VirtioVsockHdr {
  316. op: VirtioVsockOp::CreditUpdate.into(),
  317. ..connection_info.new_header(self.guest_cid)
  318. };
  319. self.send_packet_to_tx_queue(&header, &[])
  320. }
  321. /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
  322. /// it.
  323. pub fn poll_recv(
  324. &mut self,
  325. handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
  326. ) -> Result<Option<VsockEvent>> {
  327. let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else {
  328. return Ok(None);
  329. };
  330. let result = VsockEvent::from_header(&header).and_then(|event| handler(event, body));
  331. unsafe {
  332. // TODO: What about if both handler and this give errors?
  333. self.add_buffer_to_rx_queue(token)?;
  334. }
  335. result
  336. }
  337. /// Requests to shut down the connection cleanly.
  338. ///
  339. /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
  340. /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
  341. /// shutdown.
  342. pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
  343. let header = VirtioVsockHdr {
  344. op: VirtioVsockOp::Shutdown.into(),
  345. ..connection_info.new_header(self.guest_cid)
  346. };
  347. self.send_packet_to_tx_queue(&header, &[])
  348. }
  349. /// Forcibly closes the connection without waiting for the peer.
  350. pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
  351. let header = VirtioVsockHdr {
  352. op: VirtioVsockOp::Rst.into(),
  353. ..connection_info.new_header(self.guest_cid)
  354. };
  355. self.send_packet_to_tx_queue(&header, &[])?;
  356. Ok(())
  357. }
  358. fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
  359. let _len = self.tx.add_notify_wait_pop(
  360. &[header.as_bytes(), buffer],
  361. &mut [],
  362. &mut self.transport,
  363. )?;
  364. Ok(())
  365. }
  366. /// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue.
  367. ///
  368. /// # Safety
  369. ///
  370. /// The buffer must not currently be in the RX queue, and no other references to it must exist
  371. /// between when this method is called and when it is popped from the queue.
  372. unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result {
  373. // Safe because the buffer lives as long as the queue, and the caller guarantees that it's
  374. // not currently in the queue or referred to anywhere else until it is popped.
  375. unsafe {
  376. let buffer = self.rx_queue_buffers[usize::from(index)].as_mut();
  377. let new_token = self.rx.add(&[], &mut [buffer])?;
  378. // If the RX buffer somehow gets assigned a different token, then our safety assumptions
  379. // are broken and we can't safely continue to do anything with the device.
  380. assert_eq!(new_token, index);
  381. }
  382. if self.rx.should_notify() {
  383. self.transport.notify(RX_QUEUE_IDX);
  384. }
  385. Ok(())
  386. }
  387. /// Pops one packet from the RX queue, if there is one pending. Returns the header, and a
  388. /// reference to the buffer containing the body.
  389. ///
  390. /// Returns `None` if there is no pending packet.
  391. fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> {
  392. let Some(token) = self.rx.peek_used() else {
  393. return Ok(None);
  394. };
  395. // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
  396. // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
  397. // buffer back to the RX queue then we don't access it again until next time it is popped.
  398. let (header, body) = unsafe {
  399. let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
  400. let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
  401. // Read the header and body from the buffer. Don't check the result yet, because we need
  402. // to add the buffer back to the queue either way.
  403. let header_result = read_header_and_body(buffer);
  404. if let Err(_) = header_result {
  405. // If there was an error, add the buffer back immediately. Ignore any errors, as we
  406. // need to return the first error.
  407. let _ = self.add_buffer_to_rx_queue(token);
  408. }
  409. header_result
  410. }?;
  411. debug!("Received packet {:?}. Op {:?}", header, header.op());
  412. Ok(Some((header, body, token)))
  413. }
  414. }
  415. fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
  416. // Shouldn't panic, because we know `RX_BUFFER_SIZE > size_of::<VirtioVsockHdr>()`.
  417. let header = VirtioVsockHdr::read_from_prefix(buffer).unwrap();
  418. let body_length = header.len() as usize;
  419. // This could fail if the device returns an unreasonably long body length.
  420. let data_end = size_of::<VirtioVsockHdr>()
  421. .checked_add(body_length)
  422. .ok_or(SocketError::InvalidNumber)?;
  423. // This could fail if the device returns a body length longer than the buffer we gave it.
  424. let data = buffer
  425. .get(size_of::<VirtioVsockHdr>()..data_end)
  426. .ok_or(SocketError::BufferTooShort)?;
  427. Ok((header, data))
  428. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::{
          hal::fake::FakeHal,
          transport::{
              fake::{FakeTransport, QueueStatus, State},
              DeviceStatus, DeviceType,
          },
          volatile::ReadOnly,
      };
      use alloc::{sync::Arc, vec};
      use core::ptr::NonNull;
      use std::sync::Mutex;

      /// The driver should report the guest CID found in the device configuration space.
      #[test]
      fn config() {
          let mut device_config = VirtioVsockConfig {
              guest_cid_low: ReadOnly::new(66),
              guest_cid_high: ReadOnly::new(0),
          };
          // One default queue status each for the RX, TX and event queues.
          let shared_state = Arc::new(Mutex::new(State {
              status: DeviceStatus::empty(),
              driver_features: 0,
              guest_page_size: 0,
              interrupt_pending: false,
              queues: vec![
                  QueueStatus::default(),
                  QueueStatus::default(),
                  QueueStatus::default(),
              ],
          }));
          let fake_transport = FakeTransport {
              device_type: DeviceType::Socket,
              max_queue_size: 32,
              device_features: 0,
              config_space: NonNull::from(&mut device_config),
              state: Arc::clone(&shared_state),
          };
          let driver =
              VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(fake_transport)
                  .unwrap();
          assert_eq!(driver.guest_cid(), 66);
      }
  }
  469. assert_eq!(socket.guest_cid(), 0x00_0000_0042);
  470. }
  471. }