multiconnectionmanager.rs 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546
  1. use super::{
  2. protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
  3. VsockEventType,
  4. };
  5. use crate::{transport::Transport, Hal, Result};
  6. use alloc::{boxed::Box, vec::Vec};
  7. use core::cmp::min;
  8. use core::convert::TryInto;
  9. use core::hint::spin_loop;
  10. use log::debug;
  11. use zerocopy::FromBytes;
/// Size in bytes of the receive buffer allocated for every connection.
///
/// The same value is stored in each connection's `buf_alloc`, which is what the peer sees as the
/// amount of buffer space this side has allocated for the connection.
const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
/// A higher level interface for VirtIO socket (vsock) devices.
///
/// This keeps track of multiple vsock connections. Each connection is identified by the pair
/// (peer address, local port), and data received on it is buffered until read with `recv`.
///
/// # Example
///
/// ```
/// # use virtio_drivers::{Error, Hal};
/// # use virtio_drivers::transport::Transport;
/// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
///
/// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
/// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
///
/// // Start a thread to call `socket.poll()` and handle events.
///
/// let remote_address = VsockAddr { cid: 2, port: 42 };
/// let local_port = 1234;
/// socket.connect(remote_address, local_port)?;
///
/// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
///
/// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
///
/// socket.shutdown(remote_address, local_port)?;
/// # Ok(())
/// # }
/// ```
pub struct VsockConnectionManager<H: Hal, T: Transport> {
    /// The low-level driver that actually sends and receives vsock packets.
    driver: VirtIOSocket<H, T>,
    /// Every connection currently tracked, in no particular order.
    connections: Vec<Connection>,
}
/// State kept by the connection manager for a single vsock connection.
#[derive(Debug)]
struct Connection {
    /// Peer address, local port and buffer accounting handed to the driver for this connection.
    info: ConnectionInfo,
    /// Data received from the peer that has not been consumed by `recv` yet.
    buffer: RingBuffer,
}
  50. impl Connection {
  51. fn new(peer: VsockAddr, local_port: u32) -> Self {
  52. let mut info = ConnectionInfo::new(peer, local_port);
  53. info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
  54. Self {
  55. info,
  56. buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
  57. }
  58. }
  59. }
  60. impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
  61. /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
  62. pub fn new(driver: VirtIOSocket<H, T>) -> Self {
  63. Self {
  64. driver,
  65. connections: Vec::new(),
  66. }
  67. }
  68. /// Returns the CID which has been assigned to this guest.
  69. pub fn guest_cid(&self) -> u64 {
  70. self.driver.guest_cid()
  71. }
  72. /// Sends a request to connect to the given destination.
  73. ///
  74. /// This returns as soon as the request is sent; you should wait until `poll` returns a
  75. /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
  76. /// before sending data.
  77. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  78. if self.connections.iter().any(|connection| {
  79. connection.info.dst == destination && connection.info.src_port == src_port
  80. }) {
  81. return Err(SocketError::ConnectionExists.into());
  82. }
  83. let new_connection = Connection::new(destination, src_port);
  84. self.driver.connect(&new_connection.info)?;
  85. debug!("Connection requested: {:?}", new_connection.info);
  86. self.connections.push(new_connection);
  87. Ok(())
  88. }
  89. /// Sends the buffer to the destination.
  90. pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
  91. let connection = self
  92. .connections
  93. .iter_mut()
  94. .find(|connection| {
  95. connection.info.dst == destination && connection.info.src_port == src_port
  96. })
  97. .ok_or(SocketError::NotConnected)?;
  98. self.driver.send(buffer, &mut connection.info)
  99. }
  100. /// Polls the vsock device to receive data or other updates.
  101. pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
  102. let guest_cid = self.driver.guest_cid();
  103. let connections = &mut self.connections;
  104. let result = self.driver.poll(|event, body| {
  105. let connection = connections
  106. .iter_mut()
  107. .find(|connection| event.matches_connection(&connection.info, guest_cid));
  108. let Some(connection) = connection else {
  109. // Skip events which don't match any connection we know about.
  110. return Ok(None);
  111. };
  112. // Update stored connection info.
  113. connection.info.update_for_event(&event);
  114. match event.event_type {
  115. VsockEventType::ConnectionRequest => {
  116. // TODO: Send Rst or handle incoming connections.
  117. }
  118. VsockEventType::Connected => {}
  119. VsockEventType::Disconnected { .. } => {
  120. // TODO: Wait until client reads all data before removing connection.
  121. //self.connection_info = None;
  122. }
  123. VsockEventType::Received { length } => {
  124. // Copy to buffer
  125. if !connection.buffer.write(body) {
  126. return Err(SocketError::OutputBufferTooShort(length).into());
  127. }
  128. }
  129. VsockEventType::CreditRequest => {}
  130. VsockEventType::CreditUpdate => {}
  131. }
  132. Ok(Some(event))
  133. })?;
  134. // If the peer requested credit, send an update.
  135. if let Some(VsockEvent {
  136. source,
  137. destination,
  138. event_type: VsockEventType::CreditRequest,
  139. ..
  140. }) = result
  141. {
  142. let connection = self
  143. .connections
  144. .iter()
  145. .find(|connection| {
  146. connection.info.dst == source && connection.info.src_port == destination.port
  147. })
  148. .unwrap();
  149. self.driver.credit_update(&connection.info)?;
  150. // No need to pass the request on to the client, we've already handled it.
  151. Ok(None)
  152. } else {
  153. Ok(result)
  154. }
  155. }
  156. /// Reads data received from the given connection.
  157. pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
  158. let connection = self
  159. .connections
  160. .iter_mut()
  161. .find(|connection| connection.info.dst == peer && connection.info.src_port == src_port)
  162. .ok_or(SocketError::NotConnected)?;
  163. // Copy from ring buffer
  164. let bytes_read = connection.buffer.read(buffer);
  165. connection.info.done_forwarding(bytes_read);
  166. Ok(bytes_read)
  167. }
  168. /// Blocks until we get some event from the vsock device.
  169. pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
  170. loop {
  171. if let Some(event) = self.poll()? {
  172. return Ok(event);
  173. } else {
  174. spin_loop();
  175. }
  176. }
  177. }
  178. /// Requests to shut down the connection cleanly.
  179. ///
  180. /// This returns as soon as the request is sent; you should wait until `poll` returns a
  181. /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
  182. /// shutdown.
  183. pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  184. let connection = self
  185. .connections
  186. .iter()
  187. .find(|connection| {
  188. connection.info.dst == destination && connection.info.src_port == src_port
  189. })
  190. .ok_or(SocketError::NotConnected)?;
  191. self.driver.shutdown(&connection.info)
  192. }
  193. /// Forcibly closes the connection without waiting for the peer.
  194. pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  195. let (index, connection) = self
  196. .connections
  197. .iter()
  198. .enumerate()
  199. .find(|(_, connection)| {
  200. connection.info.dst == destination && connection.info.src_port == src_port
  201. })
  202. .ok_or(SocketError::NotConnected)?;
  203. self.driver.force_close(&connection.info)?;
  204. self.connections.swap_remove(index);
  205. Ok(())
  206. }
  207. }
/// A fixed-capacity circular buffer of bytes.
///
/// Bytes are appended after the used region and consumed from `start`, wrapping around the end
/// of the backing storage.
#[derive(Debug)]
struct RingBuffer {
    /// Backing storage; its length is the buffer's capacity.
    buffer: Box<[u8]>,
    /// The number of bytes currently in the buffer.
    used: usize,
    /// The index of the first used byte in the buffer.
    start: usize,
}
  216. impl RingBuffer {
  217. pub fn new(capacity: usize) -> Self {
  218. Self {
  219. buffer: FromBytes::new_box_slice_zeroed(capacity),
  220. used: 0,
  221. start: 0,
  222. }
  223. }
  224. /// Returns the number of bytes currently used in the buffer.
  225. pub fn used(&self) -> usize {
  226. self.used
  227. }
  228. /// Returns the number of bytes currently free in the buffer.
  229. pub fn available(&self) -> usize {
  230. self.buffer.len() - self.used
  231. }
  232. /// Adds the given bytes to the buffer if there is enough capacity for them all.
  233. ///
  234. /// Returns true if they were added, or false if they were not.
  235. pub fn write(&mut self, bytes: &[u8]) -> bool {
  236. if bytes.len() > self.available() {
  237. return false;
  238. }
  239. let end = (self.start + self.used) % self.buffer.len();
  240. let write_before_wraparound = min(bytes.len(), self.buffer.len() - end);
  241. let write_after_wraparound = bytes
  242. .len()
  243. .checked_sub(write_before_wraparound)
  244. .unwrap_or_default();
  245. self.buffer[end..end + write_before_wraparound]
  246. .copy_from_slice(&bytes[0..write_before_wraparound]);
  247. self.buffer[0..write_after_wraparound].copy_from_slice(&bytes[write_before_wraparound..]);
  248. self.used += bytes.len();
  249. true
  250. }
  251. /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
  252. /// buffer.
  253. pub fn read(&mut self, out: &mut [u8]) -> usize {
  254. let bytes_read = min(self.used, out.len());
  255. // The number of bytes to copy out between `start` and the end of the buffer.
  256. let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
  257. // The number of bytes to copy out from the beginning of the buffer after wrapping around.
  258. let read_after_wraparound = bytes_read
  259. .checked_sub(read_before_wraparound)
  260. .unwrap_or_default();
  261. out[0..read_before_wraparound]
  262. .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
  263. out[read_before_wraparound..bytes_read]
  264. .copy_from_slice(&self.buffer[0..read_after_wraparound]);
  265. self.used -= bytes_read;
  266. self.start = (self.start + bytes_read) % self.buffer.len();
  267. bytes_read
  268. }
  269. }
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        device::socket::{
            protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
            vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
        },
        hal::fake::FakeHal,
        transport::{
            fake::{FakeTransport, QueueStatus, State},
            DeviceStatus, DeviceType,
        },
        volatile::ReadOnly,
    };
    use alloc::{sync::Arc, vec};
    use core::{mem::size_of, ptr::NonNull};
    use std::{sync::Mutex, thread};
    use zerocopy::{AsBytes, FromBytes};

    /// End-to-end exchange over a fake transport.
    ///
    /// The guest (this thread, via `VsockConnectionManager`) connects, sends a message, receives
    /// a reply and shuts down. A second thread plays the host/device side by reading packets from
    /// the TX queue and writing packets to the RX queue.
    #[test]
    fn send_recv() {
        let host_cid = 2;
        let guest_cid = 66;
        let host_port = 1234;
        let guest_port = 4321;
        let host_address = VsockAddr {
            cid: host_cid,
            port: host_port,
        };
        let hello_from_guest = "Hello from guest";
        let hello_from_host = "Hello from host";

        // The guest CID reported by the fake device's config space; it matches `guest_cid`.
        let mut config_space = VirtioVsockConfig {
            guest_cid_low: ReadOnly::new(66),
            guest_cid_high: ReadOnly::new(0),
        };
        // Shared fake device state with three queues. Both the driver (through the transport)
        // and the simulated device thread below access it.
        let state = Arc::new(Mutex::new(State {
            status: DeviceStatus::empty(),
            driver_features: 0,
            guest_page_size: 0,
            interrupt_pending: false,
            queues: vec![
                QueueStatus::default(),
                QueueStatus::default(),
                QueueStatus::default(),
            ],
        }));
        let transport = FakeTransport {
            device_type: DeviceType::Socket,
            max_queue_size: 32,
            device_features: 0,
            config_space: NonNull::from(&mut config_space),
            state: state.clone(),
        };
        let mut socket = VsockConnectionManager::new(
            VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
        );

        // Start a thread to simulate the device.
        let handle = thread::spawn(move || {
            // Wait for connection request.
            // The guest advertises buf_alloc 1024 (PER_CONNECTION_BUFFER_CAPACITY).
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Request.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );

            // Accept connection and give the peer enough credit to send the message.
            state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                RX_QUEUE_IDX,
                VirtioVsockHdr {
                    op: VirtioVsockOp::Response.into(),
                    src_cid: host_cid.into(),
                    dst_cid: guest_cid.into(),
                    src_port: host_port.into(),
                    dst_port: guest_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 50.into(),
                    fwd_cnt: 0.into(),
                }
                .as_bytes(),
            );

            // Expect the guest to send some data: a header followed by the payload.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            let request = state
                .lock()
                .unwrap()
                .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
            assert_eq!(
                request.len(),
                size_of::<VirtioVsockHdr>() + hello_from_guest.len()
            );
            assert_eq!(
                VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Rw.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: (hello_from_guest.len() as u32).into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: 0.into(),
                }
            );
            assert_eq!(
                &request[size_of::<VirtioVsockHdr>()..],
                hello_from_guest.as_bytes()
            );

            println!("Host sending");

            // Send a response. Its fwd_cnt reports that the host has consumed the guest's
            // message.
            let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
            VirtioVsockHdr {
                op: VirtioVsockOp::Rw.into(),
                src_cid: host_cid.into(),
                dst_cid: guest_cid.into(),
                src_port: host_port.into(),
                dst_port: guest_port.into(),
                len: (hello_from_host.len() as u32).into(),
                socket_type: SocketType::Stream.into(),
                flags: 0.into(),
                buf_alloc: 50.into(),
                fwd_cnt: (hello_from_guest.len() as u32).into(),
            }
            .write_to_prefix(response.as_mut_slice());
            response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
            state
                .lock()
                .unwrap()
                .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

            // Expect a shutdown. Its fwd_cnt reflects that the guest's `recv` consumed the
            // host's reply.
            State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
            assert_eq!(
                VirtioVsockHdr::read_from(
                    state
                        .lock()
                        .unwrap()
                        .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
                        .as_slice()
                )
                .unwrap(),
                VirtioVsockHdr {
                    op: VirtioVsockOp::Shutdown.into(),
                    src_cid: guest_cid.into(),
                    dst_cid: host_cid.into(),
                    src_port: guest_port.into(),
                    dst_port: host_port.into(),
                    len: 0.into(),
                    socket_type: SocketType::Stream.into(),
                    flags: 0.into(),
                    buf_alloc: 1024.into(),
                    fwd_cnt: (hello_from_host.len() as u32).into(),
                }
            );
        });

        // Guest side: connect, and wait for the host's response to be reported as `Connected`.
        socket.connect(host_address, guest_port).unwrap();
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Connected,
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: 0,
                },
            }
        );
        println!("Guest sending");
        socket
            .send(host_address, guest_port, "Hello from guest".as_bytes())
            .unwrap();
        println!("Guest waiting to receive.");
        assert_eq!(
            socket.wait_for_event().unwrap(),
            VsockEvent {
                source: host_address,
                destination: VsockAddr {
                    cid: guest_cid,
                    port: guest_port,
                },
                event_type: VsockEventType::Received {
                    length: hello_from_host.len()
                },
                buffer_status: VsockBufferStatus {
                    buffer_allocation: 50,
                    forward_count: hello_from_guest.len() as u32,
                },
            }
        );
        println!("Guest getting received data.");
        // The received payload was buffered by `poll`; `recv` drains it into this local buffer.
        let mut buffer = [0u8; 64];
        assert_eq!(
            socket.recv(host_address, guest_port, &mut buffer).unwrap(),
            hello_from_host.len()
        );
        assert_eq!(
            &buffer[0..hello_from_host.len()],
            hello_from_host.as_bytes()
        );

        socket.shutdown(host_address, guest_port).unwrap();

        // Propagate any assertion failure from the simulated device thread.
        handle.join().unwrap();
    }
}