vsock.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536
  1. //! Driver for VirtIO socket devices.
  2. #![deny(unsafe_op_in_unsafe_fn)]
  3. use super::error::SocketError;
  4. use super::protocol::{Feature, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp, VsockAddr};
  5. use crate::hal::Hal;
  6. use crate::queue::VirtQueue;
  7. use crate::transport::Transport;
  8. use crate::volatile::volread;
  9. use crate::Result;
  10. use alloc::boxed::Box;
  11. use core::mem::size_of;
  12. use core::ptr::{null_mut, NonNull};
  13. use log::{debug, info};
  14. use zerocopy::{AsBytes, FromBytes};
/// Index of the virtqueue used to receive packets.
pub(crate) const RX_QUEUE_IDX: u16 = 0;
/// Index of the virtqueue used to transmit packets.
pub(crate) const TX_QUEUE_IDX: u16 = 1;
/// Index of the virtqueue used to receive events from the device.
const EVENT_QUEUE_IDX: u16 = 2;
/// The size of each of the driver's virtqueues, which is also the number of RX buffers the
/// driver allocates.
pub(crate) const QUEUE_SIZE: usize = 8;
/// The size in bytes of each buffer used in the RX virtqueue.
const RX_BUFFER_SIZE: usize = 512;
/// The driver-side state of a single vsock connection.
///
/// This holds the addresses which identify the connection, together with the counters used to
/// track how much data has been sent to, and received from, the peer.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct ConnectionInfo {
    /// The address of the peer, used as the destination of packets we send.
    pub dst: VsockAddr,
    /// The port on our side, used as the source port of packets we send.
    pub src_port: u32,
    /// The last `buf_alloc` value the peer sent to us, indicating how much receive buffer space in
    /// bytes it has allocated for packet bodies.
    peer_buf_alloc: u32,
    /// The last `fwd_cnt` value the peer sent to us, indicating how many bytes of packet bodies it
    /// has finished processing.
    peer_fwd_cnt: u32,
    /// The number of bytes of packet bodies which we have sent to the peer.
    tx_cnt: u32,
    /// The number of bytes of packet bodies which we have received from the peer and handled.
    fwd_cnt: u32,
    /// Whether we have recently requested credit from the peer.
    ///
    /// This is set to true when we send a `VIRTIO_VSOCK_OP_CREDIT_REQUEST`, and false when we
    /// receive a `VIRTIO_VSOCK_OP_CREDIT_UPDATE`.
    has_pending_credit_request: bool,
}
  41. impl ConnectionInfo {
  42. pub fn new(destination: VsockAddr, src_port: u32) -> Self {
  43. Self {
  44. dst: destination,
  45. src_port,
  46. ..Default::default()
  47. }
  48. }
  49. /// Updates this connection info with the peer buffer allocation and forwarded count from the
  50. /// given event.
  51. pub fn update_for_event(&mut self, event: &VsockEvent) {
  52. self.peer_buf_alloc = event.buffer_status.buffer_allocation;
  53. self.peer_fwd_cnt = event.buffer_status.forward_count;
  54. if let VsockEventType::CreditUpdate = event.event_type {
  55. self.has_pending_credit_request = false;
  56. }
  57. }
  58. /// Increases the forwarded count recorded for this connection by the given number of bytes.
  59. ///
  60. /// This should be called once received data has been passed to the client, so there is buffer
  61. /// space available for more.
  62. pub fn done_forwarding(&mut self, length: usize) {
  63. self.fwd_cnt += length as u32;
  64. }
  65. fn peer_free(&self) -> u32 {
  66. self.peer_buf_alloc - (self.tx_cnt - self.peer_fwd_cnt)
  67. }
  68. fn new_header(&self, src_cid: u64) -> VirtioVsockHdr {
  69. VirtioVsockHdr {
  70. src_cid: src_cid.into(),
  71. dst_cid: self.dst.cid.into(),
  72. src_port: self.src_port.into(),
  73. dst_port: self.dst.port.into(),
  74. fwd_cnt: self.fwd_cnt.into(),
  75. ..Default::default()
  76. }
  77. }
  78. }
/// An event received from a VirtIO socket device.
///
/// Events are built from the headers of packets received on the RX virtqueue.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockEvent {
    /// The source of the event, i.e. the peer who sent it.
    pub source: VsockAddr,
    /// The destination of the event, i.e. the CID and port on our side.
    pub destination: VsockAddr,
    /// The peer's buffer status for the connection.
    pub buffer_status: VsockBufferStatus,
    /// The type of event.
    pub event_type: VsockEventType,
}
  91. impl VsockEvent {
  92. /// Returns whether the event matches the given connection.
  93. pub fn matches_connection(&self, connection_info: &ConnectionInfo, guest_cid: u64) -> bool {
  94. self.source == connection_info.dst
  95. && self.destination.cid == guest_cid
  96. && self.destination.port == connection_info.src_port
  97. }
  98. fn from_header(header: &VirtioVsockHdr) -> Result<Option<Self>> {
  99. let op = header.op()?;
  100. let buffer_status = VsockBufferStatus {
  101. buffer_allocation: header.buf_alloc.into(),
  102. forward_count: header.fwd_cnt.into(),
  103. };
  104. let source = header.source();
  105. let destination = header.destination();
  106. match op {
  107. VirtioVsockOp::Request => {
  108. header.check_data_is_empty()?;
  109. // TODO: Send a Rst, or support listening.
  110. Ok(None)
  111. }
  112. VirtioVsockOp::Response => {
  113. header.check_data_is_empty()?;
  114. Ok(Some(VsockEvent {
  115. source,
  116. destination,
  117. buffer_status,
  118. event_type: VsockEventType::Connected,
  119. }))
  120. }
  121. VirtioVsockOp::CreditUpdate => {
  122. header.check_data_is_empty()?;
  123. Ok(Some(VsockEvent {
  124. source,
  125. destination,
  126. buffer_status,
  127. event_type: VsockEventType::CreditUpdate,
  128. }))
  129. }
  130. VirtioVsockOp::Rst | VirtioVsockOp::Shutdown => {
  131. header.check_data_is_empty()?;
  132. info!("Disconnected from the peer");
  133. let reason = if op == VirtioVsockOp::Rst {
  134. DisconnectReason::Reset
  135. } else {
  136. DisconnectReason::Shutdown
  137. };
  138. Ok(Some(VsockEvent {
  139. source,
  140. destination,
  141. buffer_status,
  142. event_type: VsockEventType::Disconnected { reason },
  143. }))
  144. }
  145. VirtioVsockOp::Rw => Ok(Some(VsockEvent {
  146. source,
  147. destination,
  148. buffer_status,
  149. event_type: VsockEventType::Received {
  150. length: header.len() as usize,
  151. },
  152. })),
  153. VirtioVsockOp::CreditRequest => {
  154. header.check_data_is_empty()?;
  155. Ok(Some(VsockEvent {
  156. source,
  157. destination,
  158. buffer_status,
  159. event_type: VsockEventType::CreditRequest,
  160. }))
  161. }
  162. VirtioVsockOp::Invalid => Err(SocketError::InvalidOperation.into()),
  163. }
  164. }
  165. }
/// The buffer status reported by the peer in a packet header.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct VsockBufferStatus {
    /// The `buf_alloc` value from the packet header.
    pub buffer_allocation: u32,
    /// The `fwd_cnt` value from the packet header.
    pub forward_count: u32,
}
/// The reason why a vsock connection was closed.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum DisconnectReason {
    /// The peer has either closed the connection in response to our shutdown request, or forcibly
    /// closed it of its own accord.
    ///
    /// This is reported when a reset (`Rst`) packet is received.
    Reset,
    /// The peer asked to shut down the connection.
    ///
    /// This is reported when a `Shutdown` packet is received.
    Shutdown,
}
/// Details of the type of an event received from a VirtIO socket.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum VsockEventType {
    /// The connection was successfully established.
    Connected,
    /// The connection was closed.
    Disconnected {
        /// The reason for the disconnection.
        reason: DisconnectReason,
    },
    /// Data was received on the connection.
    Received {
        /// The length of the data in bytes, as given by the packet header.
        length: usize,
    },
    /// The peer requests us to send a credit update.
    CreditRequest,
    /// The peer just sent us a credit update with nothing else.
    CreditUpdate,
}
/// Driver for a VirtIO socket device.
pub struct VirtIOSocket<H: Hal, T: Transport> {
    /// The transport used to communicate with the device.
    transport: T,
    /// Virtqueue to receive packets.
    rx: VirtQueue<H, { QUEUE_SIZE }>,
    /// Virtqueue to send packets.
    tx: VirtQueue<H, { QUEUE_SIZE }>,
    /// Virtqueue to receive events from the device.
    event: VirtQueue<H, { QUEUE_SIZE }>,
    /// The guest_cid field contains the guest's context ID, which uniquely identifies
    /// the device for its lifetime. The upper 32 bits of the CID are reserved and zeroed.
    guest_cid: u64,
    /// The buffers added to the RX virtqueue, indexed by the token the queue returned for each.
    ///
    /// These are obtained from `Box::into_raw` when the driver is created, and freed again when
    /// it is dropped.
    rx_queue_buffers: [NonNull<[u8; RX_BUFFER_SIZE]>; QUEUE_SIZE],
}
  213. impl<H: Hal, T: Transport> Drop for VirtIOSocket<H, T> {
  214. fn drop(&mut self) {
  215. // Clear any pointers pointing to DMA regions, so the device doesn't try to access them
  216. // after they have been freed.
  217. self.transport.queue_unset(RX_QUEUE_IDX);
  218. self.transport.queue_unset(TX_QUEUE_IDX);
  219. self.transport.queue_unset(EVENT_QUEUE_IDX);
  220. for buffer in self.rx_queue_buffers {
  221. // Safe because we obtained the RX buffer pointer from Box::into_raw, and it won't be
  222. // used anywhere else after the driver is destroyed.
  223. unsafe { drop(Box::from_raw(buffer.as_ptr())) };
  224. }
  225. }
  226. }
  227. impl<H: Hal, T: Transport> VirtIOSocket<H, T> {
  228. /// Create a new VirtIO Vsock driver.
  229. pub fn new(mut transport: T) -> Result<Self> {
  230. transport.begin_init(|features| {
  231. let features = Feature::from_bits_truncate(features);
  232. info!("Device features: {:?}", features);
  233. // negotiate these flags only
  234. let supported_features = Feature::empty();
  235. (features & supported_features).bits()
  236. });
  237. let config = transport.config_space::<VirtioVsockConfig>()?;
  238. info!("config: {:?}", config);
  239. // Safe because config is a valid pointer to the device configuration space.
  240. let guest_cid = unsafe {
  241. volread!(config, guest_cid_low) as u64 | (volread!(config, guest_cid_high) as u64) << 32
  242. };
  243. info!("guest cid: {guest_cid:?}");
  244. let mut rx = VirtQueue::new(&mut transport, RX_QUEUE_IDX)?;
  245. let tx = VirtQueue::new(&mut transport, TX_QUEUE_IDX)?;
  246. let event = VirtQueue::new(&mut transport, EVENT_QUEUE_IDX)?;
  247. // Allocate and add buffers for the RX queue.
  248. let mut rx_queue_buffers = [null_mut(); QUEUE_SIZE];
  249. for i in 0..QUEUE_SIZE {
  250. let mut buffer: Box<[u8; RX_BUFFER_SIZE]> = FromBytes::new_box_zeroed();
  251. // Safe because the buffer lives as long as the queue, as specified in the function
  252. // safety requirement, and we don't access it until it is popped.
  253. let token = unsafe { rx.add(&[], &mut [buffer.as_mut_slice()]) }?;
  254. assert_eq!(i, token.into());
  255. rx_queue_buffers[i] = Box::into_raw(buffer);
  256. }
  257. let rx_queue_buffers = rx_queue_buffers.map(|ptr| NonNull::new(ptr).unwrap());
  258. transport.finish_init();
  259. if rx.should_notify() {
  260. transport.notify(RX_QUEUE_IDX);
  261. }
  262. Ok(Self {
  263. transport,
  264. rx,
  265. tx,
  266. event,
  267. guest_cid,
  268. rx_queue_buffers,
  269. })
  270. }
  271. /// Returns the CID which has been assigned to this guest.
  272. pub fn guest_cid(&self) -> u64 {
  273. self.guest_cid
  274. }
  275. /// Sends a request to connect to the given destination.
  276. ///
  277. /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
  278. /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
  279. /// before sending data.
  280. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  281. let header = VirtioVsockHdr {
  282. op: VirtioVsockOp::Request.into(),
  283. src_cid: self.guest_cid.into(),
  284. dst_cid: destination.cid.into(),
  285. src_port: src_port.into(),
  286. dst_port: destination.port.into(),
  287. ..Default::default()
  288. };
  289. // Sends a header only packet to the tx queue to connect the device to the listening
  290. // socket at the given destination.
  291. self.send_packet_to_tx_queue(&header, &[])?;
  292. Ok(())
  293. }
  294. /// Requests the peer to send us a credit update for the given connection.
  295. fn request_credit(&mut self, connection_info: &ConnectionInfo) -> Result {
  296. let header = VirtioVsockHdr {
  297. op: VirtioVsockOp::CreditRequest.into(),
  298. ..connection_info.new_header(self.guest_cid)
  299. };
  300. self.send_packet_to_tx_queue(&header, &[])
  301. }
  302. /// Sends the buffer to the destination.
  303. pub fn send(&mut self, buffer: &[u8], connection_info: &mut ConnectionInfo) -> Result {
  304. self.check_peer_buffer_is_sufficient(connection_info, buffer.len())?;
  305. let len = buffer.len() as u32;
  306. let header = VirtioVsockHdr {
  307. op: VirtioVsockOp::Rw.into(),
  308. len: len.into(),
  309. buf_alloc: 0.into(),
  310. ..connection_info.new_header(self.guest_cid)
  311. };
  312. connection_info.tx_cnt += len;
  313. self.send_packet_to_tx_queue(&header, buffer)
  314. }
  315. fn check_peer_buffer_is_sufficient(
  316. &mut self,
  317. connection_info: &mut ConnectionInfo,
  318. buffer_len: usize,
  319. ) -> Result {
  320. if connection_info.peer_free() as usize >= buffer_len {
  321. Ok(())
  322. } else {
  323. // Request an update of the cached peer credit, if we haven't already done so, and tell
  324. // the caller to try again later.
  325. if !connection_info.has_pending_credit_request {
  326. self.request_credit(connection_info)?;
  327. connection_info.has_pending_credit_request = true;
  328. }
  329. Err(SocketError::InsufficientBufferSpaceInPeer.into())
  330. }
  331. }
  332. /// Tells the peer how much buffer space we have to receive data.
  333. pub fn credit_update(&mut self, connection_info: &ConnectionInfo, buffer_size: u32) -> Result {
  334. let header = VirtioVsockHdr {
  335. op: VirtioVsockOp::CreditUpdate.into(),
  336. buf_alloc: buffer_size.into(),
  337. ..connection_info.new_header(self.guest_cid)
  338. };
  339. self.send_packet_to_tx_queue(&header, &[])
  340. }
  341. /// Polls the vsock device to receive data or other updates.
  342. ///
  343. /// A buffer must be provided to put the data in if there is some to
  344. /// receive.
  345. pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
  346. // Handle entries from the RX virtqueue until we find one that generates an event.
  347. let event = self.poll_rx_queue(buffer)?;
  348. if self.rx.should_notify() {
  349. self.transport.notify(RX_QUEUE_IDX);
  350. }
  351. Ok(event)
  352. }
  353. /// Requests to shut down the connection cleanly.
  354. ///
  355. /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
  356. /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
  357. /// shutdown.
  358. pub fn shutdown(&mut self, connection_info: &ConnectionInfo) -> Result {
  359. let header = VirtioVsockHdr {
  360. op: VirtioVsockOp::Shutdown.into(),
  361. ..connection_info.new_header(self.guest_cid)
  362. };
  363. self.send_packet_to_tx_queue(&header, &[])
  364. }
  365. /// Forcibly closes the connection without waiting for the peer.
  366. pub fn force_close(&mut self, connection_info: &ConnectionInfo) -> Result {
  367. let header = VirtioVsockHdr {
  368. op: VirtioVsockOp::Rst.into(),
  369. ..connection_info.new_header(self.guest_cid)
  370. };
  371. self.send_packet_to_tx_queue(&header, &[])?;
  372. Ok(())
  373. }
  374. fn send_packet_to_tx_queue(&mut self, header: &VirtioVsockHdr, buffer: &[u8]) -> Result {
  375. let _len = self.tx.add_notify_wait_pop(
  376. &[header.as_bytes(), buffer],
  377. &mut [],
  378. &mut self.transport,
  379. )?;
  380. Ok(())
  381. }
  382. /// Polls the RX virtqueue for the next event.
  383. ///
  384. /// Returns `Ok(None)` if the virtqueue is empty, possibly after processing some packets which
  385. /// don't result in any events to return.
  386. fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
  387. let Some(header) = self.pop_packet_from_rx_queue(body)? else {
  388. return Ok(None);
  389. };
  390. // TODO: Add buffer back immediately on error or None
  391. if let Some(event) = VsockEvent::from_header(&header)? {
  392. Ok(Some(event))
  393. } else {
  394. Ok(None)
  395. }
  396. }
  397. /// Pops one packet from the RX queue, if there is one pending. Returns the header, and copies
  398. /// the body into the given buffer.
  399. ///
  400. /// Returns `None` if there is no pending packet, or an error if the body is bigger than the
  401. /// buffer supplied.
  402. fn pop_packet_from_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VirtioVsockHdr>> {
  403. let Some(token) = self.rx.peek_used() else {
  404. return Ok(None);
  405. };
  406. // Safe because we maintain a consistent mapping of tokens to buffers, so we pass the same
  407. // buffer to `pop_used` as we previously passed to `add` for the token. Once we add the
  408. // buffer back to the RX queue then we don't access it again until next time it is popped.
  409. let header = unsafe {
  410. let buffer = self.rx_queue_buffers[usize::from(token)].as_mut();
  411. let _len = self.rx.pop_used(token, &[], &mut [buffer])?;
  412. // Read the header and body from the buffer. Don't check the result yet, because we need
  413. // to add the buffer back to the queue either way.
  414. let header_result = read_header_and_body(buffer, body);
  415. // Add the buffer back to the RX queue.
  416. let new_token = self.rx.add(&[], &mut [buffer])?;
  417. // If the RX buffer somehow gets assigned a different token, then our safety assumptions
  418. // are broken and we can't safely continue to do anything with the device.
  419. assert_eq!(new_token, token);
  420. header_result
  421. }?;
  422. debug!("Received packet {:?}. Op {:?}", header, header.op());
  423. Ok(Some(header))
  424. }
  425. }
  426. fn read_header_and_body(buffer: &[u8], body: &mut [u8]) -> Result<VirtioVsockHdr> {
  427. let header = VirtioVsockHdr::read_from_prefix(buffer).ok_or(SocketError::BufferTooShort)?;
  428. let body_length = header.len() as usize;
  429. let data_end = size_of::<VirtioVsockHdr>()
  430. .checked_add(body_length)
  431. .ok_or(SocketError::InvalidNumber)?;
  432. let data = buffer
  433. .get(size_of::<VirtioVsockHdr>()..data_end)
  434. .ok_or(SocketError::BufferTooShort)?;
  435. body.get_mut(0..body_length)
  436. .ok_or(SocketError::OutputBufferTooShort(body_length))?
  437. .copy_from_slice(data);
  438. Ok(header)
  439. }
  440. #[cfg(test)]
  441. mod tests {
  442. use super::*;
  443. use crate::{
  444. hal::fake::FakeHal,
  445. transport::{
  446. fake::{FakeTransport, QueueStatus, State},
  447. DeviceStatus, DeviceType,
  448. },
  449. volatile::ReadOnly,
  450. };
  451. use alloc::{sync::Arc, vec};
  452. use core::ptr::NonNull;
  453. use std::sync::Mutex;
  454. #[test]
  455. fn config() {
  456. let mut config_space = VirtioVsockConfig {
  457. guest_cid_low: ReadOnly::new(66),
  458. guest_cid_high: ReadOnly::new(0),
  459. };
  460. let state = Arc::new(Mutex::new(State {
  461. status: DeviceStatus::empty(),
  462. driver_features: 0,
  463. guest_page_size: 0,
  464. interrupt_pending: false,
  465. queues: vec![
  466. QueueStatus::default(),
  467. QueueStatus::default(),
  468. QueueStatus::default(),
  469. ],
  470. }));
  471. let transport = FakeTransport {
  472. device_type: DeviceType::Socket,
  473. max_queue_size: 32,
  474. device_features: 0,
  475. config_space: NonNull::from(&mut config_space),
  476. state: state.clone(),
  477. };
  478. let socket =
  479. VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap();
  480. assert_eq!(socket.guest_cid(), 0x00_0000_0042);
  481. }
  482. }