singleconnectionmanager.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447
  1. use super::{
  2. protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
  3. VsockEventType,
  4. };
  5. use crate::{transport::Transport, Hal, Result};
  6. use core::hint::spin_loop;
  7. use log::debug;
  8. /// A higher level interface for VirtIO socket (vsock) devices.
  9. ///
  10. /// This can only keep track of a single vsock connection. If you want to support multiple
  11. /// simultaneous connections, try [`VsockConnectionManager`](super::VsockConnectionManager).
  12. pub struct SingleConnectionManager<H: Hal, T: Transport> {
  13. driver: VirtIOSocket<H, T>,
  14. connection_info: Option<ConnectionInfo>,
  15. }
  16. impl<H: Hal, T: Transport> SingleConnectionManager<H, T> {
  17. /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
  18. pub fn new(driver: VirtIOSocket<H, T>) -> Self {
  19. Self {
  20. driver,
  21. connection_info: None,
  22. }
  23. }
  24. /// Returns the CID which has been assigned to this guest.
  25. pub fn guest_cid(&self) -> u64 {
  26. self.driver.guest_cid()
  27. }
  28. /// Sends a request to connect to the given destination.
  29. ///
  30. /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
  31. /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
  32. /// before sending data.
  33. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  34. if self.connection_info.is_some() {
  35. return Err(SocketError::ConnectionExists.into());
  36. }
  37. let new_connection_info = ConnectionInfo::new(destination, src_port);
  38. self.driver.connect(&new_connection_info)?;
  39. debug!("Connection requested: {:?}", new_connection_info);
  40. self.connection_info = Some(new_connection_info);
  41. Ok(())
  42. }
  43. /// Sends the buffer to the destination.
  44. pub fn send(&mut self, buffer: &[u8]) -> Result {
  45. let connection_info = self
  46. .connection_info
  47. .as_mut()
  48. .ok_or(SocketError::NotConnected)?;
  49. connection_info.buf_alloc = 0;
  50. self.driver.send(buffer, connection_info)
  51. }
  52. /// Polls the vsock device to receive data or other updates.
  53. ///
  54. /// A buffer must be provided to put the data in if there is some to
  55. /// receive.
  56. pub fn poll_recv(&mut self, buffer: &mut [u8]) -> Result<Option<VsockEvent>> {
  57. let Some(connection_info) = &mut self.connection_info else {
  58. return Err(SocketError::NotConnected.into());
  59. };
  60. // Tell the peer that we have space to receive some data.
  61. connection_info.buf_alloc = buffer.len() as u32;
  62. self.driver.credit_update(connection_info)?;
  63. self.poll_rx_queue(buffer)
  64. }
  65. /// Blocks until we get some event from the vsock device.
  66. ///
  67. /// A buffer must be provided to put the data in if there is some to
  68. /// receive.
  69. pub fn wait_for_recv(&mut self, buffer: &mut [u8]) -> Result<VsockEvent> {
  70. loop {
  71. if let Some(event) = self.poll_recv(buffer)? {
  72. return Ok(event);
  73. } else {
  74. spin_loop();
  75. }
  76. }
  77. }
  78. fn poll_rx_queue(&mut self, body: &mut [u8]) -> Result<Option<VsockEvent>> {
  79. let guest_cid = self.driver.guest_cid();
  80. let self_connection_info = &mut self.connection_info;
  81. self.driver.poll(|event, borrowed_body| {
  82. let Some(connection_info) = self_connection_info else {
  83. return Ok(None);
  84. };
  85. // Skip packets which don't match our current connection.
  86. if !event.matches_connection(connection_info, guest_cid) {
  87. debug!(
  88. "Skipping {:?} as connection is {:?}",
  89. event, connection_info
  90. );
  91. return Ok(None);
  92. }
  93. // Update stored connection info.
  94. connection_info.update_for_event(&event);
  95. match event.event_type {
  96. VsockEventType::ConnectionRequest => {
  97. // TODO: Send Rst or handle incoming connections.
  98. }
  99. VsockEventType::Connected => {}
  100. VsockEventType::Disconnected { .. } => {
  101. *self_connection_info = None;
  102. }
  103. VsockEventType::Received { length } => {
  104. body.get_mut(0..length)
  105. .ok_or_else(|| SocketError::OutputBufferTooShort(length))?
  106. .copy_from_slice(borrowed_body);
  107. connection_info.done_forwarding(length);
  108. }
  109. VsockEventType::CreditRequest => {
  110. // No point sending a credit update until `poll_recv` is called with a buffer,
  111. // as otherwise buf_alloc would just be 0 anyway.
  112. }
  113. VsockEventType::CreditUpdate => {}
  114. }
  115. Ok(Some(event))
  116. })
  117. }
  118. /// Requests to shut down the connection cleanly.
  119. ///
  120. /// This returns as soon as the request is sent; you should wait until `poll_recv` returns a
  121. /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
  122. /// shutdown.
  123. pub fn shutdown(&mut self) -> Result {
  124. let connection_info = self
  125. .connection_info
  126. .as_mut()
  127. .ok_or(SocketError::NotConnected)?;
  128. connection_info.buf_alloc = 0;
  129. self.driver.shutdown(connection_info)
  130. }
  131. /// Forcibly closes the connection without waiting for the peer.
  132. pub fn force_close(&mut self) -> Result {
  133. let connection_info = self
  134. .connection_info
  135. .as_mut()
  136. .ok_or(SocketError::NotConnected)?;
  137. connection_info.buf_alloc = 0;
  138. self.driver.force_close(connection_info)?;
  139. self.connection_info = None;
  140. Ok(())
  141. }
  142. /// Blocks until the peer either accepts our connection request (with a
  143. /// `VIRTIO_VSOCK_OP_RESPONSE`) or rejects it (with a
  144. /// `VIRTIO_VSOCK_OP_RST`).
  145. pub fn wait_for_connect(&mut self) -> Result {
  146. loop {
  147. match self.wait_for_recv(&mut [])?.event_type {
  148. VsockEventType::Connected => return Ok(()),
  149. VsockEventType::Disconnected { .. } => {
  150. return Err(SocketError::ConnectionFailed.into())
  151. }
  152. VsockEventType::Received { .. } => return Err(SocketError::InvalidOperation.into()),
  153. VsockEventType::ConnectionRequest
  154. | VsockEventType::CreditRequest
  155. | VsockEventType::CreditUpdate => {}
  156. }
  157. }
  158. }
  159. }
  #[cfg(test)]
  mod tests {
      use super::*;
      use crate::{
          device::socket::{
              protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
              vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
          },
          hal::fake::FakeHal,
          transport::{
              fake::{FakeTransport, QueueStatus, State},
              DeviceStatus, DeviceType,
          },
          volatile::ReadOnly,
      };
      use alloc::{sync::Arc, vec};
      use core::{mem::size_of, ptr::NonNull};
      use std::{sync::Mutex, thread};
      use zerocopy::{AsBytes, FromBytes};

      /// Builds a zero-flag stream-socket header from `(cid, port)` `src` to `(cid, port)` `dst`.
      fn stream_header(
          op: VirtioVsockOp,
          (src_cid, src_port): (u64, u32),
          (dst_cid, dst_port): (u64, u32),
          len: u32,
          buf_alloc: u32,
          fwd_cnt: u32,
      ) -> VirtioVsockHdr {
          VirtioVsockHdr {
              op: op.into(),
              src_cid: src_cid.into(),
              dst_cid: dst_cid.into(),
              src_port: src_port.into(),
              dst_port: dst_port.into(),
              len: len.into(),
              socket_type: SocketType::Stream.into(),
              flags: 0.into(),
              buf_alloc: buf_alloc.into(),
              fwd_cnt: fwd_cnt.into(),
          }
      }

      #[test]
      fn send_recv() {
          let (host_cid, guest_cid) = (2, 66);
          let (host_port, guest_port) = (1234, 4321);
          let host = (host_cid, host_port);
          let guest = (guest_cid, guest_port);
          let host_address = VsockAddr {
              cid: host_cid,
              port: host_port,
          };
          let hello_from_guest = "Hello from guest";
          let hello_from_host = "Hello from host";

          let mut config_space = VirtioVsockConfig {
              guest_cid_low: ReadOnly::new(66),
              guest_cid_high: ReadOnly::new(0),
          };
          let state = Arc::new(Mutex::new(State {
              status: DeviceStatus::empty(),
              driver_features: 0,
              guest_page_size: 0,
              interrupt_pending: false,
              queues: vec![
                  QueueStatus::default(),
                  QueueStatus::default(),
                  QueueStatus::default(),
              ],
          }));
          let transport = FakeTransport {
              device_type: DeviceType::Socket,
              max_queue_size: 32,
              device_features: 0,
              config_space: NonNull::from(&mut config_space),
              state: state.clone(),
          };
          let mut socket = SingleConnectionManager::new(
              VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
          );

          // The fake device runs on its own thread, scripted step by step.
          let device = thread::spawn(move || {
              // Waits for the driver to notify the TX queue, then pops the next packet.
              let next_tx_packet = || {
                  State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
                  state
                      .lock()
                      .unwrap()
                      .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
              };
              // Same, for packets that consist of exactly one header.
              let next_tx_header =
                  || VirtioVsockHdr::read_from(next_tx_packet().as_slice()).unwrap();
              let header_len = size_of::<VirtioVsockHdr>();

              // Step 1: the guest asks to connect.
              assert_eq!(
                  next_tx_header(),
                  stream_header(VirtioVsockOp::Request, guest, host, 0, 0, 0)
              );

              // Step 2: accept, granting the guest enough credit for its greeting.
              state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
                  RX_QUEUE_IDX,
                  stream_header(VirtioVsockOp::Response, host, guest, 0, 50, 0).as_bytes(),
              );

              // Step 3: while waiting to connect, the guest advertises an empty buffer.
              assert_eq!(
                  next_tx_header(),
                  stream_header(VirtioVsockOp::CreditUpdate, guest, host, 0, 0, 0)
              );

              // Step 4: the guest sends its greeting.
              let request = next_tx_packet();
              assert_eq!(request.len(), header_len + hello_from_guest.len());
              assert_eq!(
                  VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
                  stream_header(
                      VirtioVsockOp::Rw,
                      guest,
                      host,
                      hello_from_guest.len() as u32,
                      0,
                      0,
                  )
              );
              assert_eq!(&request[header_len..], hello_from_guest.as_bytes());

              // Step 5: answer, acknowledging the bytes consumed so far.
              let mut response = vec![0; header_len + hello_from_host.len()];
              stream_header(
                  VirtioVsockOp::Rw,
                  host,
                  guest,
                  hello_from_host.len() as u32,
                  50,
                  hello_from_guest.len() as u32,
              )
              .write_to_prefix(response.as_mut_slice());
              response[header_len..].copy_from_slice(hello_from_host.as_bytes());
              state
                  .lock()
                  .unwrap()
                  .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);

              // Step 6: the guest advertises the 64-byte buffer it is about to receive into.
              assert_eq!(
                  next_tx_header(),
                  stream_header(VirtioVsockOp::CreditUpdate, guest, host, 0, 64, 0)
              );

              // Step 7: the guest shuts down, having forwarded our whole reply.
              assert_eq!(
                  next_tx_header(),
                  stream_header(
                      VirtioVsockOp::Shutdown,
                      guest,
                      host,
                      0,
                      0,
                      hello_from_host.len() as u32,
                  )
              );
          });

          socket.connect(host_address, guest_port).unwrap();
          socket.wait_for_connect().unwrap();
          socket.send(hello_from_guest.as_bytes()).unwrap();

          let mut buffer = [0u8; 64];
          let expected = VsockEvent {
              source: VsockAddr {
                  cid: host_cid,
                  port: host_port,
              },
              destination: VsockAddr {
                  cid: guest_cid,
                  port: guest_port,
              },
              event_type: VsockEventType::Received {
                  length: hello_from_host.len(),
              },
              buffer_status: VsockBufferStatus {
                  buffer_allocation: 50,
                  forward_count: hello_from_guest.len() as u32,
              },
          };
          assert_eq!(socket.wait_for_recv(&mut buffer).unwrap(), expected);
          assert_eq!(
              &buffer[..hello_from_host.len()],
              hello_from_host.as_bytes()
          );

          socket.shutdown().unwrap();
          device.join().unwrap();
      }
  }