multiconnectionmanager.rs 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494
  1. use super::{
  2. protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
  3. VsockEventType,
  4. };
  5. use crate::{transport::Transport, Hal, Result};
  6. use alloc::{boxed::Box, vec::Vec};
  7. use core::cmp::min;
  8. use core::convert::TryInto;
  9. use core::hint::spin_loop;
  10. use log::debug;
  11. use zerocopy::FromBytes;
  12. const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
  13. /// A higher level interface for vsock devices.
  14. ///
  15. /// This keeps track of a single vsock connection.
  16. pub struct VsockConnectionManager<H: Hal, T: Transport> {
  17. driver: VirtIOSocket<H, T>,
  18. connections: Vec<Connection>,
  19. }
  20. #[derive(Debug)]
  21. struct Connection {
  22. info: ConnectionInfo,
  23. buffer: RingBuffer,
  24. }
  25. impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
  26. /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
  27. pub fn new(driver: VirtIOSocket<H, T>) -> Self {
  28. Self {
  29. driver,
  30. connections: Vec::new(),
  31. }
  32. }
  33. /// Returns the CID which has been assigned to this guest.
  34. pub fn guest_cid(&self) -> u64 {
  35. self.driver.guest_cid()
  36. }
  37. /// Sends a request to connect to the given destination.
  38. ///
  39. /// This returns as soon as the request is sent; you should wait until `poll` returns a
  40. /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
  41. /// before sending data.
  42. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  43. if self.connections.iter().any(|connection| {
  44. connection.info.dst == destination && connection.info.src_port == src_port
  45. }) {
  46. return Err(SocketError::ConnectionExists.into());
  47. }
  48. let mut new_connection_info = ConnectionInfo::new(destination, src_port);
  49. new_connection_info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
  50. self.driver.connect(&new_connection_info)?;
  51. debug!("Connection requested: {:?}", new_connection_info);
  52. self.connections.push(Connection {
  53. info: new_connection_info,
  54. buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
  55. });
  56. Ok(())
  57. }
  58. /// Sends the buffer to the destination.
  59. pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
  60. let connection = self
  61. .connections
  62. .iter_mut()
  63. .find(|connection| {
  64. connection.info.dst == destination && connection.info.src_port == src_port
  65. })
  66. .ok_or(SocketError::NotConnected)?;
  67. self.driver.send(buffer, &mut connection.info)
  68. }
  69. /// Polls the vsock device to receive data or other updates.
  70. pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
  71. let guest_cid = self.driver.guest_cid();
  72. let connections = &mut self.connections;
  73. self.driver.poll_recv(|event, body| {
  74. let connection = connections
  75. .iter_mut()
  76. .find(|connection| event.matches_connection(&connection.info, guest_cid));
  77. let Some(connection) = connection else {
  78. // Skip events which don't match any connection we know about.
  79. return Ok(None);
  80. };
  81. // Update stored connection info.
  82. connection.info.update_for_event(&event);
  83. match event.event_type {
  84. VsockEventType::ConnectionRequest => {
  85. // TODO: Send Rst or handle incoming connections.
  86. }
  87. VsockEventType::Connected => {}
  88. VsockEventType::Disconnected { .. } => {
  89. // TODO: Wait until client reads all data before removing connection.
  90. //self.connection_info = None;
  91. }
  92. VsockEventType::Received { length } => {
  93. // Copy to buffer
  94. if !connection.buffer.write(body) {
  95. return Err(SocketError::OutputBufferTooShort(length).into());
  96. }
  97. }
  98. VsockEventType::CreditRequest => {
  99. // TODO: Send a credit update.
  100. }
  101. VsockEventType::CreditUpdate => {}
  102. }
  103. Ok(Some(event))
  104. })
  105. }
  106. /// Reads data received from the given connection.
  107. pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
  108. let connection = self
  109. .connections
  110. .iter_mut()
  111. .find(|connection| connection.info.dst == peer && connection.info.src_port == src_port)
  112. .ok_or(SocketError::NotConnected)?;
  113. // Copy from ring buffer
  114. let bytes_read = connection.buffer.read(buffer);
  115. connection.info.done_forwarding(bytes_read);
  116. Ok(bytes_read)
  117. }
  118. /// Blocks until we get some event from the vsock device.
  119. pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
  120. loop {
  121. if let Some(event) = self.poll()? {
  122. return Ok(event);
  123. } else {
  124. spin_loop();
  125. }
  126. }
  127. }
  128. /// Requests to shut down the connection cleanly.
  129. ///
  130. /// This returns as soon as the request is sent; you should wait until `poll` returns a
  131. /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
  132. /// shutdown.
  133. pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  134. let connection = self
  135. .connections
  136. .iter()
  137. .find(|connection| {
  138. connection.info.dst == destination && connection.info.src_port == src_port
  139. })
  140. .ok_or(SocketError::NotConnected)?;
  141. self.driver.shutdown(&connection.info)
  142. }
  143. /// Forcibly closes the connection without waiting for the peer.
  144. pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  145. let (index, connection) = self
  146. .connections
  147. .iter()
  148. .enumerate()
  149. .find(|(_, connection)| {
  150. connection.info.dst == destination && connection.info.src_port == src_port
  151. })
  152. .ok_or(SocketError::NotConnected)?;
  153. self.driver.force_close(&connection.info)?;
  154. self.connections.swap_remove(index);
  155. Ok(())
  156. }
  157. }
  158. #[derive(Debug)]
  159. struct RingBuffer {
  160. buffer: Box<[u8]>,
  161. /// The number of bytes currently in the buffer.
  162. used: usize,
  163. /// The index of the first used byte in the buffer.
  164. start: usize,
  165. }
  166. impl RingBuffer {
  167. pub fn new(capacity: usize) -> Self {
  168. Self {
  169. buffer: FromBytes::new_box_slice_zeroed(capacity),
  170. used: 0,
  171. start: 0,
  172. }
  173. }
  174. /// Returns the number of bytes currently used in the buffer.
  175. pub fn used(&self) -> usize {
  176. self.used
  177. }
  178. /// Returns the number of bytes currently free in the buffer.
  179. pub fn available(&self) -> usize {
  180. self.buffer.len() - self.used
  181. }
  182. /// Adds the given bytes to the buffer if there is enough capacity for them all.
  183. ///
  184. /// Returns true if they were added, or false if they were not.
  185. pub fn write(&mut self, bytes: &[u8]) -> bool {
  186. if bytes.len() > self.available() {
  187. return false;
  188. }
  189. let end = (self.start + self.used) % self.buffer.len();
  190. let write_before_wraparound = min(bytes.len(), self.buffer.len() - end);
  191. let write_after_wraparound = bytes
  192. .len()
  193. .checked_sub(write_before_wraparound)
  194. .unwrap_or_default();
  195. self.buffer[end..end + write_before_wraparound]
  196. .copy_from_slice(&bytes[0..write_before_wraparound]);
  197. self.buffer[0..write_after_wraparound].copy_from_slice(&bytes[write_before_wraparound..]);
  198. self.used += bytes.len();
  199. true
  200. }
  201. /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
  202. /// buffer.
  203. pub fn read(&mut self, out: &mut [u8]) -> usize {
  204. let bytes_read = min(self.used, out.len());
  205. // The number of bytes to copy out between `start` and the end of the buffer.
  206. let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
  207. // The number of bytes to copy out from the beginning of the buffer after wrapping around.
  208. let read_after_wraparound = bytes_read
  209. .checked_sub(read_before_wraparound)
  210. .unwrap_or_default();
  211. out[0..read_before_wraparound]
  212. .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
  213. out[read_before_wraparound..bytes_read]
  214. .copy_from_slice(&self.buffer[0..read_after_wraparound]);
  215. self.used -= bytes_read;
  216. self.start = (self.start + bytes_read) % self.buffer.len();
  217. bytes_read
  218. }
  219. }
  220. #[cfg(test)]
  221. mod tests {
  222. use super::*;
  223. use crate::{
  224. device::socket::{
  225. protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
  226. vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
  227. },
  228. hal::fake::FakeHal,
  229. transport::{
  230. fake::{FakeTransport, QueueStatus, State},
  231. DeviceStatus, DeviceType,
  232. },
  233. volatile::ReadOnly,
  234. };
  235. use alloc::{sync::Arc, vec};
  236. use core::{mem::size_of, ptr::NonNull};
  237. use std::{sync::Mutex, thread};
  238. use zerocopy::{AsBytes, FromBytes};
  239. #[test]
  240. fn send_recv() {
  241. let host_cid = 2;
  242. let guest_cid = 66;
  243. let host_port = 1234;
  244. let guest_port = 4321;
  245. let host_address = VsockAddr {
  246. cid: host_cid,
  247. port: host_port,
  248. };
  249. let hello_from_guest = "Hello from guest";
  250. let hello_from_host = "Hello from host";
  251. let mut config_space = VirtioVsockConfig {
  252. guest_cid_low: ReadOnly::new(66),
  253. guest_cid_high: ReadOnly::new(0),
  254. };
  255. let state = Arc::new(Mutex::new(State {
  256. status: DeviceStatus::empty(),
  257. driver_features: 0,
  258. guest_page_size: 0,
  259. interrupt_pending: false,
  260. queues: vec![
  261. QueueStatus::default(),
  262. QueueStatus::default(),
  263. QueueStatus::default(),
  264. ],
  265. }));
  266. let transport = FakeTransport {
  267. device_type: DeviceType::Socket,
  268. max_queue_size: 32,
  269. device_features: 0,
  270. config_space: NonNull::from(&mut config_space),
  271. state: state.clone(),
  272. };
  273. let mut socket = VsockConnectionManager::new(
  274. VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
  275. );
  276. // Start a thread to simulate the device.
  277. let handle = thread::spawn(move || {
  278. // Wait for connection request.
  279. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  280. assert_eq!(
  281. VirtioVsockHdr::read_from(
  282. state
  283. .lock()
  284. .unwrap()
  285. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  286. .as_slice()
  287. )
  288. .unwrap(),
  289. VirtioVsockHdr {
  290. op: VirtioVsockOp::Request.into(),
  291. src_cid: guest_cid.into(),
  292. dst_cid: host_cid.into(),
  293. src_port: guest_port.into(),
  294. dst_port: host_port.into(),
  295. len: 0.into(),
  296. socket_type: SocketType::Stream.into(),
  297. flags: 0.into(),
  298. buf_alloc: 1024.into(),
  299. fwd_cnt: 0.into(),
  300. }
  301. );
  302. // Accept connection and give the peer enough credit to send the message.
  303. state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
  304. RX_QUEUE_IDX,
  305. VirtioVsockHdr {
  306. op: VirtioVsockOp::Response.into(),
  307. src_cid: host_cid.into(),
  308. dst_cid: guest_cid.into(),
  309. src_port: host_port.into(),
  310. dst_port: guest_port.into(),
  311. len: 0.into(),
  312. socket_type: SocketType::Stream.into(),
  313. flags: 0.into(),
  314. buf_alloc: 50.into(),
  315. fwd_cnt: 0.into(),
  316. }
  317. .as_bytes(),
  318. );
  319. // Expect the guest to send some data.
  320. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  321. let request = state
  322. .lock()
  323. .unwrap()
  324. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
  325. assert_eq!(
  326. request.len(),
  327. size_of::<VirtioVsockHdr>() + hello_from_guest.len()
  328. );
  329. assert_eq!(
  330. VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
  331. VirtioVsockHdr {
  332. op: VirtioVsockOp::Rw.into(),
  333. src_cid: guest_cid.into(),
  334. dst_cid: host_cid.into(),
  335. src_port: guest_port.into(),
  336. dst_port: host_port.into(),
  337. len: (hello_from_guest.len() as u32).into(),
  338. socket_type: SocketType::Stream.into(),
  339. flags: 0.into(),
  340. buf_alloc: 1024.into(),
  341. fwd_cnt: 0.into(),
  342. }
  343. );
  344. assert_eq!(
  345. &request[size_of::<VirtioVsockHdr>()..],
  346. hello_from_guest.as_bytes()
  347. );
  348. println!("Host sending");
  349. // Send a response.
  350. let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
  351. VirtioVsockHdr {
  352. op: VirtioVsockOp::Rw.into(),
  353. src_cid: host_cid.into(),
  354. dst_cid: guest_cid.into(),
  355. src_port: host_port.into(),
  356. dst_port: guest_port.into(),
  357. len: (hello_from_host.len() as u32).into(),
  358. socket_type: SocketType::Stream.into(),
  359. flags: 0.into(),
  360. buf_alloc: 50.into(),
  361. fwd_cnt: (hello_from_guest.len() as u32).into(),
  362. }
  363. .write_to_prefix(response.as_mut_slice());
  364. response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
  365. state
  366. .lock()
  367. .unwrap()
  368. .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
  369. // Expect a shutdown.
  370. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  371. assert_eq!(
  372. VirtioVsockHdr::read_from(
  373. state
  374. .lock()
  375. .unwrap()
  376. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  377. .as_slice()
  378. )
  379. .unwrap(),
  380. VirtioVsockHdr {
  381. op: VirtioVsockOp::Shutdown.into(),
  382. src_cid: guest_cid.into(),
  383. dst_cid: host_cid.into(),
  384. src_port: guest_port.into(),
  385. dst_port: host_port.into(),
  386. len: 0.into(),
  387. socket_type: SocketType::Stream.into(),
  388. flags: 0.into(),
  389. buf_alloc: 1024.into(),
  390. fwd_cnt: (hello_from_host.len() as u32).into(),
  391. }
  392. );
  393. });
  394. socket.connect(host_address, guest_port).unwrap();
  395. assert_eq!(
  396. socket.wait_for_event().unwrap(),
  397. VsockEvent {
  398. source: host_address,
  399. destination: VsockAddr {
  400. cid: guest_cid,
  401. port: guest_port,
  402. },
  403. event_type: VsockEventType::Connected,
  404. buffer_status: VsockBufferStatus {
  405. buffer_allocation: 50,
  406. forward_count: 0,
  407. },
  408. }
  409. );
  410. println!("Guest sending");
  411. socket
  412. .send(host_address, guest_port, "Hello from guest".as_bytes())
  413. .unwrap();
  414. println!("Guest waiting to receive.");
  415. assert_eq!(
  416. socket.wait_for_event().unwrap(),
  417. VsockEvent {
  418. source: host_address,
  419. destination: VsockAddr {
  420. cid: guest_cid,
  421. port: guest_port,
  422. },
  423. event_type: VsockEventType::Received {
  424. length: hello_from_host.len()
  425. },
  426. buffer_status: VsockBufferStatus {
  427. buffer_allocation: 50,
  428. forward_count: hello_from_guest.len() as u32,
  429. },
  430. }
  431. );
  432. println!("Guest getting received data.");
  433. let mut buffer = [0u8; 64];
  434. assert_eq!(
  435. socket.recv(host_address, guest_port, &mut buffer).unwrap(),
  436. hello_from_host.len()
  437. );
  438. assert_eq!(
  439. &buffer[0..hello_from_host.len()],
  440. hello_from_host.as_bytes()
  441. );
  442. socket.shutdown(host_address, guest_port).unwrap();
  443. handle.join().unwrap();
  444. }
  445. }