vsock.rs 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. //! Driver for VirtIO socket devices.
  2. #![deny(unsafe_op_in_unsafe_fn)]
  3. use super::error::SocketError;
  4. use super::protocol::{Feature, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
  5. use crate::hal::Hal;
  6. use crate::queue::VirtQueue;
  7. use crate::transport::Transport;
  8. use crate::volatile::volread;
  9. use crate::Result;
  10. use alloc::boxed::Box;
  11. use core::mem::size_of;
  12. use core::ptr::{null_mut, NonNull};
  13. use log::{debug, info};
  14. use zerocopy::{AsBytes, FromBytes};
  /// Virtqueue index on which the device hands received packets to the driver.
  pub(crate) const RX_QUEUE_IDX: u16 = 0;
  /// Virtqueue index on which the driver hands packets to the device for transmission.
  pub(crate) const TX_QUEUE_IDX: u16 = 1;
  /// Virtqueue index used for device events.
  const EVENT_QUEUE_IDX: u16 = 2;
  /// Size of each of the three virtqueues, which is also the number of RX buffers kept posted
  /// in the RX queue.
  pub(crate) const QUEUE_SIZE: usize = 8;
  /// Size in bytes of each RX buffer, which must hold a packet header followed by its body.
  const RX_BUFFER_SIZE: usize = 512;
  21. #[derive(Clone, Debug, Default, PartialEq, Eq)]
  22. pub struct ConnectionInfo {
  23. pub dst: VsockAddr,
  24. pub src_port: u32,
  25. /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
  26. /// bytes it has allocated for packet bodies.
  27. peer_buf_alloc: u32,
  28. /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
  29. /// has finished processing.
  30. peer_fwd_cnt: u32,
  31. /// The number of bytes of packet bodies which we have sent to the peer.
  32. tx_cnt: u32,
  33. /// The number of bytes of buffer space we have allocated to receive packet bodies from the
  34. /// peer.
  35. pub buf_alloc: u32,
  36. /// The number of bytes of packet bodies which we have received from the peer and handled.
  37. fwd_cnt: u32,
  38. /// Whether we have recently requested credit from the peer.
  39. ///
  40. /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
  41. /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
  42. has_pending_credit_request: bool,
  43. }
  44. impl ConnectionInfo {
  45. pub fn new(destination: VsockAddr, src_port: u32) -> Self {
  46. Self {
  47. dst: destination,
  48. src_port,
  49. ..Default::default()
  50. }
  51. }
  52. /// Updates this connection info with the peer buffer allocation and forwarded count from the
  53. /// given event.
  54. pub fn update_for_event(&mut self, event: &VsockEvent) {
  55. self.peer_buf_alloc = event.buffer_status.buffer_allocation;
  56. self.peer_fwd_cnt = event.buffer_status.forward_count;
  57. if let VsockEventType::CreditUpdate = event.event_type {
  58. self.has_pending_credit_request = false;
  59. }
  60. }
  61. /// Increases the forwarded count recorded for this connection by the given number of bytes.
  62. ///
  63. /// This should be called once received data has been passed to the client, so there is buffer
  64. /// space available for more.
  65. pub fn done_forwarding(&mut self, length: usize) {
  66. self.fwd_cnt += length as u32;
  67. }
  68. fn peer_free(&self) -> u32 {
  69. self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
  70. }
  71. fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
  72. VirtioVsockHdr {
  73. src_cid: src_cid.into(),
  74. dst_cid: self.dst.cid.into(),
  75. src_port: self.src_port.into(),
  76. dst_port: self.dst.port.into(),
  77. buf_alloc: self.buf_alloc.into(),
  78. fwd_cnt: self.fwd_cnt.into(),
  79. ..Default::default()
  80. }
  81. }
  82. }
  83. /// An event received from a VirtIO socket device.
  84. #[derive(Clone, Debug, Eq, PartialEq)]
  85. pub struct VsockEvent {
  86. /// The source of the event, i.e. the peer who sent it.
  87. pub source: VsockAddr,
  88. /// The destination of the event, i.e. the CID and port on our side.
  89. pub destination: VsockAddr,
  90. /// The peer's buffer status for the connection.
  91. pub buffer_status: VsockBufferStatus,
  92. /// The type of event.
  93. pub event_type: VsockEventType,
  94. }
  95. impl VsockEvent {
  96. /// Returns whether the event matches the given connection.
  97. pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
  98. self.source == connection_info.dst
  99. && self.destination.cid == guest_cid
  100. && self.destination.port == connection_info.src_port
  101. }
  102. fn from_header(header: &VirtioVsockHdr) -> Result<Option<Self>> {
  103. let op = header.op()?;
  104. let buffer_status = VsockBufferStatus {
  105. buffer_allocation: header.buf_alloc.into(),
  106. forward_count: header.fwd_cnt.into(),
  107. };
  108. let source = header.source();
  109. let destination = header.destination();
  110. match op {
  111. VirtioVsockOp::Request => {
  112. header.check_data_is_empty()?;
  113. // TODO: Send a Rst, or support listening.
  114. Ok(None)
  115. }
  116. VirtioVsockOp::Response => {
  117. header.check_data_is_empty()?;
  118. Ok(Some(VsockEvent {
  119. source,
  120. destination,
  121. buffer_status,
  122. event_type: VsockEventType::Connected,
  123. }))
  124. }
  125. VirtioVsockOp::CreditUpdate => {
  126. header.check_data_is_empty()?;
  127. Ok(Some(VsockEvent {
  128. source,
  129. destination,
  130. buffer_status,
  131. event_type: VsockEventType::CreditUpdate,
  132. }))
  133. }
  134. VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
  135. header.check_data_is_empty()?;
  136. info!("Disconnected from the peer");
  137. let reason = if op == VirtioVsockOp::Rst {
  138. DisconnectReason::Reset
  139. } else {
  140. DisconnectReason::Shutdown
  141. };
  142. Ok(Some(VsockEvent {
  143. source,
  144. destination,
  145. buffer_status,
  146. event_type: VsockEventType::Disconnected { reason },
  147. }))
  148. }
  149. VirtioVsockOp::Rw => Ok(Some(VsockEvent {
  150. source,
  151. destination,
  152. buffer_status,
  153. event_type: VsockEventType::Received {
  154. length: header.len() as usize,
  155. },
  156. })),
  157. VirtioVsockOp::CreditRequest => {
  158. header.check_data_is_empty()?;
  159. Ok(Some(VsockEvent {
  160. source,
  161. destination,
  162. buffer_status,
  163. event_type: VsockEventType::CreditRequest,
  164. }))
  165. }
  166. VirtioVsockOp::Invalid => Err(SocketError::InvalidOperation.into()),
  167. }
  168. }
  169. }
  /// Buffer space information that the sender of a packet includes in its header.
  #[derive(Clone, Debug, Eq, PartialEq)]
  pub struct VsockBufferStatus {
      /// The sender's `buf_alloc`: bytes of receive buffer space it has allocated.
      pub buffer_allocation: u32,
      /// The sender's `fwd_cnt`: bytes of received packet bodies it has finished processing.
      pub forward_count: u32,
  }
  /// Why a vsock connection ended.
  #[derive(Copy, Clone, Debug, Eq, PartialEq)]
  pub enum DisconnectReason {
      /// The peer reset the connection, either answering a shutdown we requested or closing it
      /// forcibly on its own initiative.
      Reset,
      /// The peer requested that the connection be shut down.
      Shutdown,
  }
  184. /// Details of the type of an event received from a VirtIO socket.
  185. #[derive(Clone, Debug, Eq, PartialEq)]
  186. pub enum VsockEventType {
  187. /// The connection was successfully established.
  188. Connected,
  189. /// The connection was closed.
  190. Disconnected {
  191. /// The reason for the disconnection.
  192. reason: DisconnectReason,
  193. },
  194. /// Data was received on the connection.
  195. Received {
  196. /// The length of the data in bytes.
  197. length: usize,
  198. },
  199. /// The peer requests us to send a credit update.
  200. CreditRequest,
  201. /// The peer just sent us a credit update with nothing else.
  202. CreditUpdate,
  203. }
  204. /// Driver for a VirtIO socket device.
  205. pub struct VirtIOSocket<H: Hal, T: Transport> {
  206. transport: T,
  207. /// Virtqueue to receive packets.
  208. rx: VirtQueue<H, { QUEUE_SIZE }>,
  209. tx: VirtQueue<H, { QUEUE_SIZE }>,
  210. /// Virtqueue to receive events from the device.
  211. event: VirtQueue<H, { QUEUE_SIZE }>,
  212. /// The guest_cid field contains the guest’s context ID, which uniquely identifies
  213. /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
  214. guest_cid: u64,
  215. rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
  216. }
  217. impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
  218. fn drop(&mut self) {
  219. // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
  220. // after they have been freed.
  221. self.transport.queue_unset(RX_QUEUE_IDX);
  222. self.transport.queue_unset(TX_QUEUE_IDX);
  223. self.transport.queue_unset(EVENT_QUEUE_IDX);
  224. for buffer in self.rx_queue_buffers {
  225. // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
  226. // used anywhere else after the driver is destroyed.
  227. unsafe { drop(Box::from_raw(buffer.as_ptr())) };
  228. }
  229. }
  230. }
  231. impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
  232. /// Create a new VirtIO Vsock driver.
  233. pub fn new(mut transport: T) -> Result<Self> {
  234. transport.begin_init(|features| {
  235. let features = Feature::from_bits_truncate(features);
  236. info!("Device features: {:?}", features);
  237. // negotiate these flags only
  238. let supported_features = Feature::empty();
  239. (features & supported_features).bits()
  240. });
  241. let config = transport.config_space::<VirtioVsockConfig>()?;
  242. info!("config: {:?}", config);
  243. // Safe because config is a valid pointer to the device configuration space.
  244. let guest_cid = unsafe {
  245. volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
  246. };
  247. info!("guest cid: {guest_cid:?}");
  248. let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
  249. let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
  250. let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
  251. // Allocate and add buffers for the RX queue.
  252. let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
  253. for i in 0..QUEUE_SIZE {
  254. let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed();
  255. // Safe because the buffer lives as long as the queue, as specified in the function
  256. // safety requirement, and we don't access it until it is popped.
  257. let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
  258. assert_eq!(i, token.into());
  259. rx_queue_buffers[i] = Box::into_raw(buffer);
  260. }
  261. let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
  262. transport.finish_init();
  263. if rx.should_notify() {
  264. transport.notify(RX_QUEUE_IDX);
  265. }
  266. Ok(Self {
  267. transport,
  268. rx,
  269. tx,
  270. event,
  271. guest_cid,
  272. rx_queue_buffers,
  273. })
  274. }
  275. /// Returns the CID which has been assigned to this guest.
  276. pub fn guest_cid(&self) -> u64 {
  277. self.guest_cid
  278. }
  279. /// Sends a request to connect to the given destination.
  280. ///
  281. /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
  282. /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
  283. /// before sending data.
  284. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  285. let header = VirtioVsockHdr {
  286. op: VirtioVsockOp::Request.into(),
  287. src_cid: self.guest_cid.into(),
  288. dst_cid: destination.cid.into(),
  289. src_port: src_port.into(),
  290. dst_port: destination.port.into(),
  291. ..Default::default()
  292. };
  293. // Sends a header only packet to the tx queue to connect the device to the listening
  294. // socket at the given destination.
  295. self.send_packet_to_tx_queue(&header, &[])?;
  296. Ok(())
  297. }
  298. /// Requests the peer to send us a credit update for the given connection.
  299. fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
  300. let header = VirtioVsockHdr {
  301. op: VirtioVsockOp::CreditRequest.into(),
  302. ..connection_info.new_header(self.guest_cid)
  303. };
  304. self.send_packet_to_tx_queue(&header, &[])
  305. }
  306. /// Sends the buffer to the destination.
  307. pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
  308. self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
  309. let len = buffer.len() as u32;
  310. let header = VirtioVsockHdr {
  311. op: VirtioVsockOp::Rw.into(),
  312. len: len.into(),
  313. ..connection_info.new_header(self.guest_cid)
  314. };
  315. connection_info.tx_cnt += len;
  316. self.send_packet_to_tx_queue(&header, buffer)
  317. }
  318. fn check_peer_buffer_is_sufficient(
  319. &mut self,
  320. connection_info: &mut ConnectionInfo,
  321. buffer_len: usize,
  322. ) -> Result {
  323. if connection_info.peer_free() as usize >= buffer_len {
  324. Ok(())
  325. } else {
  326. // Request an update of the cached peer credit, if we haven't already done so, and tell
  327. // the caller to try again later.
  328. if !connection_info.has_pending_credit_request {
  329. self.request_credit(connection_info)?;
  330. connection_info.has_pending_credit_request = true;
  331. }
  332. Err(SocketError::InsufficientBufferSpaceInPeer.into())
  333. }
  334. }
  335. /// Tells the peer how much buffer space we have to receive data.
  336. pub fn credit_update(&mut self, connection_info: &ConnectionInfo) -> Result {
  337. let header = VirtioVsockHdr {
  338. op: VirtioVsockOp::CreditUpdate.into(),
  339. ..connection_info.new_header(self.guest_cid)
  340. };
  341. self.send_packet_to_tx_queue(&header, &[])
  342. }
  343. /// Polls the RX virtqueue for the next event, and calls the given handler function to handle
  344. /// it.
  345. pub fn poll_recv(
  346. &mut self,
  347. handler: impl FnOnce(VsockEvent, &[u8]) -> Result<Option<VsockEvent>>,
  348. ) -> Result<Option<VsockEvent>> {
  349. let Some((header, body, token)) = self.pop_packet_from_rx_queue()? else {
  350. return Ok(None);
  351. };
  352. let result = match VsockEvent::from_header(&header) {
  353. Ok(Some(event)) => handler(event, body),
  354. other => other,
  355. };
  356. unsafe {
  357. // TODO: What about if both handler and this give errors?
  358. self.add_buffer_to_rx_queue(token)?;
  359. }
  360. result
  361. }
  362. /// Requests to shut down the connection cleanly.
  363. ///
  364. /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
  365. /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
  366. /// shutdown.
  367. pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
  368. let header = VirtioVsockHdr {
  369. op: VirtioVsockOp::Shutdown.into(),
  370. ..connection_info.new_header(self.guest_cid)
  371. };
  372. self.send_packet_to_tx_queue(&header, &[])
  373. }
  374. /// Forcibly closes the connection without waiting for the peer.
  375. pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
  376. let header = VirtioVsockHdr {
  377. op: VirtioVsockOp::Rst.into(),
  378. ..connection_info.new_header(self.guest_cid)
  379. };
  380. self.send_packet_to_tx_queue(&header, &[])?;
  381. Ok(())
  382. }
  383. fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
  384. let _len = self.tx.add_notify_wait_pop(
  385. &[header.as_bytes(), buffer],
  386. &mut [],
  387. &mut self.transport,
  388. )?;
  389. Ok(())
  390. }
  391. /// Adds the buffer at the given index in `rx_queue_buffers` back to the RX queue.
  392. ///
  393. /// # Safety
  394. ///
  395. /// The buffer must not currently be in the RX queue, and no other references to it must exist
  396. /// between when this method is called and when it is popped from the queue.
  397. unsafe fn add_buffer_to_rx_queue(&mut self, index: u16) -> Result {
  398. // Safe because the buffer lives as long as the queue, and the caller guarantees that it's
  399. // not currently in the queue or referred to anywhere else until it is popped.
  400. unsafe {
  401. let buffer = self.rx_queue_buffers[usize::from(index)].as_mut();
  402. let new_token = self.rx.add(&[], &mut [buffer])?;
  403. // If the RX buffer somehow gets assigned a different token, then our safety assumptions
  404. // are broken and we can't safely continue to do anything with the device.
  405. assert_eq!(new_token, index);
  406. }
  407. if self.rx.should_notify() {
  408. self.transport.notify(RX_QUEUE_IDX);
  409. }
  410. Ok(())
  411. }
  412. /// Pops one packet from the RX queue, if there is one pending. Returns the header, and a
  413. /// reference to the buffer containing the body.
  414. ///
  415. /// Returns `None` if there is no pending packet.
  416. fn pop_packet_from_rx_queue(&mut self) -> Result<Option<(VirtioVsockHdr, &[u8], u16)>> {
  417. let Some(token) = self.rx.peek_used() else {
  418. return Ok(None);
  419. };
  420. // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
  421. // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
  422. // buffer back to the RX queue then we don't access it again until next time it is popped.
  423. let (header, body) = unsafe {
  424. let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
  425. let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
  426. // Read the header and body from the buffer. Don't check the result yet, because we need
  427. // to add the buffer back to the queue either way.
  428. let header_result = read_header_and_body(buffer);
  429. if let Err(_) = header_result {
  430. // If there was an error, add the buffer back immediately. Ignore any errors, as we
  431. // need to return the first error.
  432. let _ = self.add_buffer_to_rx_queue(token);
  433. }
  434. header_result
  435. }?;
  436. debug!("Received packet {:?}. Op {:?}", header, header.op());
  437. Ok(Some((header, body, token)))
  438. }
  439. }
  440. fn read_header_and_body(buffer: &[u8]) -> Result<(VirtioVsockHdr, &[u8])> {
  441. let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
  442. let body_length = header.len() as usize;
  443. let data_end = size_of::<VirtioVsockHdr>()
  444. .checked_add(body_length)
  445. .ok_or(SocketError::InvalidNumber)?;
  446. let data = buffer
  447. .get(size_of::<VirtioVsockHdr>()..data_end)
  448. .ok_or(SocketError::BufferTooShort)?;
  449. Ok((header, data))
  450. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::{
          hal::fake::FakeHal,
          transport::{
              fake::{FakeTransport, QueueStatus, State},
              DeviceStatus, DeviceType,
          },
          volatile::ReadOnly,
      };
      use alloc::{sync::Arc, vec};
      use core::ptr::NonNull;
      use std::sync::Mutex;

      /// The driver must assemble the guest CID from the low and high config-space words.
      #[test]
      fn config() {
          // Device configuration advertising guest CID 66 (0x42).
          let mut config_space = VirtioVsockConfig {
              guest_cid_low: ReadOnly::new(66),
              guest_cid_high: ReadOnly::new(0),
          };
          // One status slot for each of the RX, TX and event queues.
          let queues = vec![
              QueueStatus::default(),
              QueueStatus::default(),
              QueueStatus::default(),
          ];
          let state = Arc::new(Mutex::new(State {
              status: DeviceStatus::empty(),
              driver_features: 0,
              guest_page_size: 0,
              interrupt_pending: false,
              queues,
          }));
          let transport = FakeTransport {
              device_type: DeviceType::Socket,
              max_queue_size: 32,
              device_features: 0,
              config_space: NonNull::from(&mut config_space),
              state: Arc::clone(&state),
          };

          let socket =
              VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
          assert_eq!(socket.guest_cid(), 66);
      }
  }