multiconnectionmanager.rs 27 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757
  1. use super::{
  2. protocol::VsockAddr, vsock::ConnectionInfo, DisconnectReason, SocketError, VirtIOSocket,
  3. VsockEvent, VsockEventType,
  4. };
  5. use crate::{transport::Transport, Hal, Result};
  6. use alloc::{boxed::Box, vec::Vec};
  7. use core::cmp::min;
  8. use core::convert::TryInto;
  9. use core::hint::spin_loop;
  10. use log::debug;
  11. use zerocopy::FromBytes;
  12. const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
  13. /// A higher level interface for VirtIO socket (vsock) devices.
  14. ///
  15. /// This keeps track of multiple vsock connections.
  16. ///
  17. /// # Example
  18. ///
  19. /// ```
  20. /// # use virtio_drivers::{Error, Hal};
  21. /// # use virtio_drivers::transport::Transport;
  22. /// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
  23. ///
  24. /// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
  25. /// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
  26. ///
  27. /// // Start a thread to call `socket.poll()` and handle events.
  28. ///
  29. /// let remote_address = VsockAddr { cid: 2, port: 42 };
  30. /// let local_port = 1234;
  31. /// socket.connect(remote_address, local_port)?;
  32. ///
  33. /// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
  34. ///
  35. /// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
  36. ///
  37. /// socket.shutdown(remote_address, local_port)?;
  38. /// # Ok(())
  39. /// # }
  40. /// ```
  41. pub struct VsockConnectionManager<H: Hal, T: Transport> {
  42. driver: VirtIOSocket<H, T>,
  43. connections: Vec<Connection>,
  44. listening_ports: Vec<u32>,
  45. }
  46. #[derive(Debug)]
  47. struct Connection {
  48. info: ConnectionInfo,
  49. buffer: RingBuffer,
  50. /// The peer sent a SHUTDOWN request, but we haven't yet responded with a RST because there is
  51. /// still data in the buffer.
  52. peer_requested_shutdown: bool,
  53. }
  54. impl Connection {
  55. fn new(peer: VsockAddr, local_port: u32) -> Self {
  56. let mut info = ConnectionInfo::new(peer, local_port);
  57. info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
  58. Self {
  59. info,
  60. buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
  61. peer_requested_shutdown: false,
  62. }
  63. }
  64. }
  65. impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
  66. /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
  67. pub fn new(driver: VirtIOSocket<H, T>) -> Self {
  68. Self {
  69. driver,
  70. connections: Vec::new(),
  71. listening_ports: Vec::new(),
  72. }
  73. }
  74. /// Returns the CID which has been assigned to this guest.
  75. pub fn guest_cid(&self) -> u64 {
  76. self.driver.guest_cid()
  77. }
  78. /// Allows incoming connections on the given port number.
  79. pub fn listen(&mut self, port: u32) {
  80. if !self.listening_ports.contains(&port) {
  81. self.listening_ports.push(port);
  82. }
  83. }
  84. /// Stops allowing incoming connections on the given port number.
  85. pub fn unlisten(&mut self, port: u32) {
  86. self.listening_ports.retain(|p| *p != port);
  87. }
  88. /// Sends a request to connect to the given destination.
  89. ///
  90. /// This returns as soon as the request is sent; you should wait until `poll` returns a
  91. /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
  92. /// before sending data.
  93. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  94. if self.connections.iter().any(|connection| {
  95. connection.info.dst == destination && connection.info.src_port == src_port
  96. }) {
  97. return Err(SocketError::ConnectionExists.into());
  98. }
  99. let new_connection = Connection::new(destination, src_port);
  100. self.driver.connect(&new_connection.info)?;
  101. debug!("Connection requested: {:?}", new_connection.info);
  102. self.connections.push(new_connection);
  103. Ok(())
  104. }
  105. /// Sends the buffer to the destination.
  106. pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
  107. let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
  108. self.driver.send(buffer, &mut connection.info)
  109. }
  110. /// Polls the vsock device to receive data or other updates.
  111. pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
  112. let guest_cid = self.driver.guest_cid();
  113. let connections = &mut self.connections;
  114. let result = self.driver.poll(|event, body| {
  115. let connection = get_connection_for_event(connections, &event, guest_cid);
  116. // Skip events which don't match any connection we know about, unless they are a
  117. // connection request.
  118. let connection = if let Some((_, connection)) = connection {
  119. connection
  120. } else if let VsockEventType::ConnectionRequest = event.event_type {
  121. // If the requested connection already exists or the CID isn't ours, ignore it.
  122. if connection.is_some() || event.destination.cid != guest_cid {
  123. return Ok(None);
  124. }
  125. // Add the new connection to our list, at least for now. It will be removed again
  126. // below if we weren't listening on the port.
  127. connections.push(Connection::new(event.source, event.destination.port));
  128. connections.last_mut().unwrap()
  129. } else {
  130. return Ok(None);
  131. };
  132. // Update stored connection info.
  133. connection.info.update_for_event(&event);
  134. if let VsockEventType::Received { length } = event.event_type {
  135. // Copy to buffer
  136. if !connection.buffer.add(body) {
  137. return Err(SocketError::OutputBufferTooShort(length).into());
  138. }
  139. }
  140. Ok(Some(event))
  141. })?;
  142. let Some(event) = result else {
  143. return Ok(None);
  144. };
  145. // The connection must exist because we found it above in the callback.
  146. let (connection_index, connection) =
  147. get_connection_for_event(connections, &event, guest_cid).unwrap();
  148. match event.event_type {
  149. VsockEventType::ConnectionRequest => {
  150. if self.listening_ports.contains(&event.destination.port) {
  151. self.driver.accept(&connection.info)?;
  152. } else {
  153. // Reject the connection request and remove it from our list.
  154. self.driver.force_close(&connection.info)?;
  155. self.connections.swap_remove(connection_index);
  156. // No need to pass the request on to the client, as we've already rejected it.
  157. return Ok(None);
  158. }
  159. }
  160. VsockEventType::Connected => {}
  161. VsockEventType::Disconnected { reason } => {
  162. // Wait until client reads all data before removing connection.
  163. if connection.buffer.is_empty() {
  164. if reason == DisconnectReason::Shutdown {
  165. self.driver.force_close(&connection.info)?;
  166. }
  167. self.connections.swap_remove(connection_index);
  168. } else {
  169. connection.peer_requested_shutdown = true;
  170. }
  171. }
  172. VsockEventType::Received { .. } => {
  173. // Already copied the buffer in the callback above.
  174. }
  175. VsockEventType::CreditRequest => {
  176. // If the peer requested credit, send an update.
  177. self.driver.credit_update(&connection.info)?;
  178. // No need to pass the request on to the client, we've already handled it.
  179. return Ok(None);
  180. }
  181. VsockEventType::CreditUpdate => {}
  182. }
  183. Ok(Some(event))
  184. }
  185. /// Reads data received from the given connection.
  186. pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
  187. let (connection_index, connection) = get_connection(&mut self.connections, peer, src_port)?;
  188. // Copy from ring buffer
  189. let bytes_read = connection.buffer.drain(buffer);
  190. connection.info.done_forwarding(bytes_read);
  191. // If buffer is now empty and the peer requested shutdown, finish shutting down the
  192. // connection.
  193. if connection.peer_requested_shutdown && connection.buffer.is_empty() {
  194. self.driver.force_close(&connection.info)?;
  195. self.connections.swap_remove(connection_index);
  196. }
  197. Ok(bytes_read)
  198. }
  199. /// Blocks until we get some event from the vsock device.
  200. pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
  201. loop {
  202. if let Some(event) = self.poll()? {
  203. return Ok(event);
  204. } else {
  205. spin_loop();
  206. }
  207. }
  208. }
  209. /// Requests to shut down the connection cleanly.
  210. ///
  211. /// This returns as soon as the request is sent; you should wait until `poll` returns a
  212. /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
  213. /// shutdown.
  214. pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  215. let (_, connection) = get_connection(&mut self.connections, destination, src_port)?;
  216. self.driver.shutdown(&connection.info)
  217. }
  218. /// Forcibly closes the connection without waiting for the peer.
  219. pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  220. let (index, connection) = get_connection(&mut self.connections, destination, src_port)?;
  221. self.driver.force_close(&connection.info)?;
  222. self.connections.swap_remove(index);
  223. Ok(())
  224. }
  225. }
  226. /// Returns the connection from the given list matching the given peer address and local port, and
  227. /// its index.
  228. ///
  229. /// Returns `Err(SocketError::NotConnected)` if there is no matching connection in the list.
  230. fn get_connection(
  231. connections: &mut [Connection],
  232. peer: VsockAddr,
  233. local_port: u32,
  234. ) -> core::result::Result<(usize, &mut Connection), SocketError> {
  235. connections
  236. .iter_mut()
  237. .enumerate()
  238. .find(|(_, connection)| {
  239. connection.info.dst == peer && connection.info.src_port == local_port
  240. })
  241. .ok_or(SocketError::NotConnected)
  242. }
  243. /// Returns the connection from the given list matching the event, if any, and its index.
  244. fn get_connection_for_event<'a>(
  245. connections: &'a mut [Connection],
  246. event: &VsockEvent,
  247. local_cid: u64,
  248. ) -> Option<(usize, &'a mut Connection)> {
  249. connections
  250. .iter_mut()
  251. .enumerate()
  252. .find(|(_, connection)| event.matches_connection(&connection.info, local_cid))
  253. }
  254. #[derive(Debug)]
  255. struct RingBuffer {
  256. buffer: Box<[u8]>,
  257. /// The number of bytes currently in the buffer.
  258. used: usize,
  259. /// The index of the first used byte in the buffer.
  260. start: usize,
  261. }
  262. impl RingBuffer {
  263. pub fn new(capacity: usize) -> Self {
  264. Self {
  265. buffer: FromBytes::new_box_slice_zeroed(capacity),
  266. used: 0,
  267. start: 0,
  268. }
  269. }
  270. /// Returns the number of bytes currently used in the buffer.
  271. pub fn used(&self) -> usize {
  272. self.used
  273. }
  274. /// Returns true iff there are currently no bytes in the buffer.
  275. pub fn is_empty(&self) -> bool {
  276. self.used == 0
  277. }
  278. /// Returns the number of bytes currently free in the buffer.
  279. pub fn available(&self) -> usize {
  280. self.buffer.len() - self.used
  281. }
  282. /// Adds the given bytes to the buffer if there is enough capacity for them all.
  283. ///
  284. /// Returns true if they were added, or false if they were not.
  285. pub fn add(&mut self, bytes: &[u8]) -> bool {
  286. if bytes.len() > self.available() {
  287. return false;
  288. }
  289. // The index of the first available position in the buffer.
  290. let first_available = (self.start + self.used) % self.buffer.len();
  291. // The number of bytes to copy from `bytes` to `buffer` between `first_available` and
  292. // `buffer.len()`.
  293. let copy_length_before_wraparound = min(bytes.len(), self.buffer.len() - first_available);
  294. self.buffer[first_available..first_available + copy_length_before_wraparound]
  295. .copy_from_slice(&bytes[0..copy_length_before_wraparound]);
  296. if let Some(bytes_after_wraparound) = bytes.get(copy_length_before_wraparound..) {
  297. self.buffer[0..bytes_after_wraparound.len()].copy_from_slice(bytes_after_wraparound);
  298. }
  299. self.used += bytes.len();
  300. true
  301. }
  302. /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
  303. /// buffer.
  304. pub fn drain(&mut self, out: &mut [u8]) -> usize {
  305. let bytes_read = min(self.used, out.len());
  306. // The number of bytes to copy out between `start` and the end of the buffer.
  307. let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
  308. // The number of bytes to copy out from the beginning of the buffer after wrapping around.
  309. let read_after_wraparound = bytes_read
  310. .checked_sub(read_before_wraparound)
  311. .unwrap_or_default();
  312. out[0..read_before_wraparound]
  313. .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
  314. out[read_before_wraparound..bytes_read]
  315. .copy_from_slice(&self.buffer[0..read_after_wraparound]);
  316. self.used -= bytes_read;
  317. self.start = (self.start + bytes_read) % self.buffer.len();
  318. bytes_read
  319. }
  320. }
  321. #[cfg(test)]
  322. mod tests {
  323. use super::*;
  324. use crate::{
  325. device::socket::{
  326. protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
  327. vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
  328. },
  329. hal::fake::FakeHal,
  330. transport::{
  331. fake::{FakeTransport, QueueStatus, State},
  332. DeviceType,
  333. },
  334. volatile::ReadOnly,
  335. };
  336. use alloc::{sync::Arc, vec};
  337. use core::{mem::size_of, ptr::NonNull};
  338. use std::{sync::Mutex, thread};
  339. use zerocopy::{AsBytes, FromBytes};
  340. #[test]
  341. fn send_recv() {
  342. let host_cid = 2;
  343. let guest_cid = 66;
  344. let host_port = 1234;
  345. let guest_port = 4321;
  346. let host_address = VsockAddr {
  347. cid: host_cid,
  348. port: host_port,
  349. };
  350. let hello_from_guest = "Hello from guest";
  351. let hello_from_host = "Hello from host";
  352. let mut config_space = VirtioVsockConfig {
  353. guest_cid_low: ReadOnly::new(66),
  354. guest_cid_high: ReadOnly::new(0),
  355. };
  356. let state = Arc::new(Mutex::new(State {
  357. queues: vec![
  358. QueueStatus::default(),
  359. QueueStatus::default(),
  360. QueueStatus::default(),
  361. ],
  362. ..Default::default()
  363. }));
  364. let transport = FakeTransport {
  365. device_type: DeviceType::Socket,
  366. max_queue_size: 32,
  367. device_features: 0,
  368. config_space: NonNull::from(&mut config_space),
  369. state: state.clone(),
  370. };
  371. let mut socket = VsockConnectionManager::new(
  372. VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
  373. );
  374. // Start a thread to simulate the device.
  375. let handle = thread::spawn(move || {
  376. // Wait for connection request.
  377. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  378. assert_eq!(
  379. VirtioVsockHdr::read_from(
  380. state
  381. .lock()
  382. .unwrap()
  383. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  384. .as_slice()
  385. )
  386. .unwrap(),
  387. VirtioVsockHdr {
  388. op: VirtioVsockOp::Request.into(),
  389. src_cid: guest_cid.into(),
  390. dst_cid: host_cid.into(),
  391. src_port: guest_port.into(),
  392. dst_port: host_port.into(),
  393. len: 0.into(),
  394. socket_type: SocketType::Stream.into(),
  395. flags: 0.into(),
  396. buf_alloc: 1024.into(),
  397. fwd_cnt: 0.into(),
  398. }
  399. );
  400. // Accept connection and give the peer enough credit to send the message.
  401. state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
  402. RX_QUEUE_IDX,
  403. VirtioVsockHdr {
  404. op: VirtioVsockOp::Response.into(),
  405. src_cid: host_cid.into(),
  406. dst_cid: guest_cid.into(),
  407. src_port: host_port.into(),
  408. dst_port: guest_port.into(),
  409. len: 0.into(),
  410. socket_type: SocketType::Stream.into(),
  411. flags: 0.into(),
  412. buf_alloc: 50.into(),
  413. fwd_cnt: 0.into(),
  414. }
  415. .as_bytes(),
  416. );
  417. // Expect the guest to send some data.
  418. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  419. let request = state
  420. .lock()
  421. .unwrap()
  422. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
  423. assert_eq!(
  424. request.len(),
  425. size_of::<VirtioVsockHdr>() + hello_from_guest.len()
  426. );
  427. assert_eq!(
  428. VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
  429. VirtioVsockHdr {
  430. op: VirtioVsockOp::Rw.into(),
  431. src_cid: guest_cid.into(),
  432. dst_cid: host_cid.into(),
  433. src_port: guest_port.into(),
  434. dst_port: host_port.into(),
  435. len: (hello_from_guest.len() as u32).into(),
  436. socket_type: SocketType::Stream.into(),
  437. flags: 0.into(),
  438. buf_alloc: 1024.into(),
  439. fwd_cnt: 0.into(),
  440. }
  441. );
  442. assert_eq!(
  443. &request[size_of::<VirtioVsockHdr>()..],
  444. hello_from_guest.as_bytes()
  445. );
  446. println!("Host sending");
  447. // Send a response.
  448. let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
  449. VirtioVsockHdr {
  450. op: VirtioVsockOp::Rw.into(),
  451. src_cid: host_cid.into(),
  452. dst_cid: guest_cid.into(),
  453. src_port: host_port.into(),
  454. dst_port: guest_port.into(),
  455. len: (hello_from_host.len() as u32).into(),
  456. socket_type: SocketType::Stream.into(),
  457. flags: 0.into(),
  458. buf_alloc: 50.into(),
  459. fwd_cnt: (hello_from_guest.len() as u32).into(),
  460. }
  461. .write_to_prefix(response.as_mut_slice());
  462. response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
  463. state
  464. .lock()
  465. .unwrap()
  466. .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
  467. // Expect a shutdown.
  468. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  469. assert_eq!(
  470. VirtioVsockHdr::read_from(
  471. state
  472. .lock()
  473. .unwrap()
  474. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  475. .as_slice()
  476. )
  477. .unwrap(),
  478. VirtioVsockHdr {
  479. op: VirtioVsockOp::Shutdown.into(),
  480. src_cid: guest_cid.into(),
  481. dst_cid: host_cid.into(),
  482. src_port: guest_port.into(),
  483. dst_port: host_port.into(),
  484. len: 0.into(),
  485. socket_type: SocketType::Stream.into(),
  486. flags: 0.into(),
  487. buf_alloc: 1024.into(),
  488. fwd_cnt: (hello_from_host.len() as u32).into(),
  489. }
  490. );
  491. });
  492. socket.connect(host_address, guest_port).unwrap();
  493. assert_eq!(
  494. socket.wait_for_event().unwrap(),
  495. VsockEvent {
  496. source: host_address,
  497. destination: VsockAddr {
  498. cid: guest_cid,
  499. port: guest_port,
  500. },
  501. event_type: VsockEventType::Connected,
  502. buffer_status: VsockBufferStatus {
  503. buffer_allocation: 50,
  504. forward_count: 0,
  505. },
  506. }
  507. );
  508. println!("Guest sending");
  509. socket
  510. .send(host_address, guest_port, "Hello from guest".as_bytes())
  511. .unwrap();
  512. println!("Guest waiting to receive.");
  513. assert_eq!(
  514. socket.wait_for_event().unwrap(),
  515. VsockEvent {
  516. source: host_address,
  517. destination: VsockAddr {
  518. cid: guest_cid,
  519. port: guest_port,
  520. },
  521. event_type: VsockEventType::Received {
  522. length: hello_from_host.len()
  523. },
  524. buffer_status: VsockBufferStatus {
  525. buffer_allocation: 50,
  526. forward_count: hello_from_guest.len() as u32,
  527. },
  528. }
  529. );
  530. println!("Guest getting received data.");
  531. let mut buffer = [0u8; 64];
  532. assert_eq!(
  533. socket.recv(host_address, guest_port, &mut buffer).unwrap(),
  534. hello_from_host.len()
  535. );
  536. assert_eq!(
  537. &buffer[0..hello_from_host.len()],
  538. hello_from_host.as_bytes()
  539. );
  540. socket.shutdown(host_address, guest_port).unwrap();
  541. handle.join().unwrap();
  542. }
  543. #[test]
  544. fn incoming_connection() {
  545. let host_cid = 2;
  546. let guest_cid = 66;
  547. let host_port = 1234;
  548. let guest_port = 4321;
  549. let wrong_guest_port = 4444;
  550. let host_address = VsockAddr {
  551. cid: host_cid,
  552. port: host_port,
  553. };
  554. let mut config_space = VirtioVsockConfig {
  555. guest_cid_low: ReadOnly::new(66),
  556. guest_cid_high: ReadOnly::new(0),
  557. };
  558. let state = Arc::new(Mutex::new(State {
  559. queues: vec![
  560. QueueStatus::default(),
  561. QueueStatus::default(),
  562. QueueStatus::default(),
  563. ],
  564. ..Default::default()
  565. }));
  566. let transport = FakeTransport {
  567. device_type: DeviceType::Socket,
  568. max_queue_size: 32,
  569. device_features: 0,
  570. config_space: NonNull::from(&mut config_space),
  571. state: state.clone(),
  572. };
  573. let mut socket = VsockConnectionManager::new(
  574. VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
  575. );
  576. socket.listen(guest_port);
  577. // Start a thread to simulate the device.
  578. let handle = thread::spawn(move || {
  579. // Send a connection request for a port the guest isn't listening on.
  580. println!("Host sending connection request to wrong port");
  581. state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
  582. RX_QUEUE_IDX,
  583. VirtioVsockHdr {
  584. op: VirtioVsockOp::Request.into(),
  585. src_cid: host_cid.into(),
  586. dst_cid: guest_cid.into(),
  587. src_port: host_port.into(),
  588. dst_port: wrong_guest_port.into(),
  589. len: 0.into(),
  590. socket_type: SocketType::Stream.into(),
  591. flags: 0.into(),
  592. buf_alloc: 50.into(),
  593. fwd_cnt: 0.into(),
  594. }
  595. .as_bytes(),
  596. );
  597. // Expect a rejection.
  598. println!("Host waiting for rejection");
  599. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  600. assert_eq!(
  601. VirtioVsockHdr::read_from(
  602. state
  603. .lock()
  604. .unwrap()
  605. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  606. .as_slice()
  607. )
  608. .unwrap(),
  609. VirtioVsockHdr {
  610. op: VirtioVsockOp::Rst.into(),
  611. src_cid: guest_cid.into(),
  612. dst_cid: host_cid.into(),
  613. src_port: wrong_guest_port.into(),
  614. dst_port: host_port.into(),
  615. len: 0.into(),
  616. socket_type: SocketType::Stream.into(),
  617. flags: 0.into(),
  618. buf_alloc: 1024.into(),
  619. fwd_cnt: 0.into(),
  620. }
  621. );
  622. // Send a connection request for a port the guest is listening on.
  623. println!("Host sending connection request to right port");
  624. state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
  625. RX_QUEUE_IDX,
  626. VirtioVsockHdr {
  627. op: VirtioVsockOp::Request.into(),
  628. src_cid: host_cid.into(),
  629. dst_cid: guest_cid.into(),
  630. src_port: host_port.into(),
  631. dst_port: guest_port.into(),
  632. len: 0.into(),
  633. socket_type: SocketType::Stream.into(),
  634. flags: 0.into(),
  635. buf_alloc: 50.into(),
  636. fwd_cnt: 0.into(),
  637. }
  638. .as_bytes(),
  639. );
  640. // Expect a response.
  641. println!("Host waiting for response");
  642. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  643. assert_eq!(
  644. VirtioVsockHdr::read_from(
  645. state
  646. .lock()
  647. .unwrap()
  648. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  649. .as_slice()
  650. )
  651. .unwrap(),
  652. VirtioVsockHdr {
  653. op: VirtioVsockOp::Response.into(),
  654. src_cid: guest_cid.into(),
  655. dst_cid: host_cid.into(),
  656. src_port: guest_port.into(),
  657. dst_port: host_port.into(),
  658. len: 0.into(),
  659. socket_type: SocketType::Stream.into(),
  660. flags: 0.into(),
  661. buf_alloc: 1024.into(),
  662. fwd_cnt: 0.into(),
  663. }
  664. );
  665. println!("Host finished");
  666. });
  667. // Expect an incoming connection.
  668. println!("Guest expecting incoming connection.");
  669. assert_eq!(
  670. socket.wait_for_event().unwrap(),
  671. VsockEvent {
  672. source: host_address,
  673. destination: VsockAddr {
  674. cid: guest_cid,
  675. port: guest_port,
  676. },
  677. event_type: VsockEventType::ConnectionRequest,
  678. buffer_status: VsockBufferStatus {
  679. buffer_allocation: 50,
  680. forward_count: 0,
  681. },
  682. }
  683. );
  684. handle.join().unwrap();
  685. }
  686. }