multiconnectionmanager.rs 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736
  1. use super::{
  2. protocol::VsockAddr, vsock::ConnectionInfo, SocketError, VirtIOSocket, VsockEvent,
  3. VsockEventType,
  4. };
  5. use crate::{transport::Transport, Hal, Result};
  6. use alloc::{boxed::Box, vec::Vec};
  7. use core::cmp::min;
  8. use core::convert::TryInto;
  9. use core::hint::spin_loop;
  10. use log::debug;
  11. use zerocopy::FromBytes;
  12. const PER_CONNECTION_BUFFER_CAPACITY: usize = 1024;
  13. /// A higher level interface for VirtIO socket (vsock) devices.
  14. ///
  15. /// This keeps track of multiple vsock connections.
  16. ///
  17. /// # Example
  18. ///
  19. /// ```
  20. /// # use virtio_drivers::{Error, Hal};
  21. /// # use virtio_drivers::transport::Transport;
  22. /// use virtio_drivers::device::socket::{VirtIOSocket, VsockAddr, VsockConnectionManager};
  23. ///
  24. /// # fn example<HalImpl: Hal, T: Transport>(transport: T) -> Result<(), Error> {
  25. /// let mut socket = VsockConnectionManager::new(VirtIOSocket::<HalImpl, _>::new(transport)?);
  26. ///
  27. /// // Start a thread to call `socket.poll()` and handle events.
  28. ///
  29. /// let remote_address = VsockAddr { cid: 2, port: 42 };
  30. /// let local_port = 1234;
  31. /// socket.connect(remote_address, local_port)?;
  32. ///
  33. /// // Wait until `socket.poll()` returns an event indicating that the socket is connected.
  34. ///
  35. /// socket.send(remote_address, local_port, "Hello world".as_bytes())?;
  36. ///
  37. /// socket.shutdown(remote_address, local_port)?;
  38. /// # Ok(())
  39. /// # }
  40. /// ```
  41. pub struct VsockConnectionManager<H: Hal, T: Transport> {
  42. driver: VirtIOSocket<H, T>,
  43. connections: Vec<Connection>,
  44. listening_ports: Vec<u32>,
  45. }
  46. #[derive(Debug)]
  47. struct Connection {
  48. info: ConnectionInfo,
  49. buffer: RingBuffer,
  50. }
  51. impl Connection {
  52. fn new(peer: VsockAddr, local_port: u32) -> Self {
  53. let mut info = ConnectionInfo::new(peer, local_port);
  54. info.buf_alloc = PER_CONNECTION_BUFFER_CAPACITY.try_into().unwrap();
  55. Self {
  56. info,
  57. buffer: RingBuffer::new(PER_CONNECTION_BUFFER_CAPACITY),
  58. }
  59. }
  60. }
  61. impl<H: Hal, T: Transport> VsockConnectionManager<H, T> {
  62. /// Construct a new connection manager wrapping the given low-level VirtIO socket driver.
  63. pub fn new(driver: VirtIOSocket<H, T>) -> Self {
  64. Self {
  65. driver,
  66. connections: Vec::new(),
  67. listening_ports: Vec::new(),
  68. }
  69. }
  70. /// Returns the CID which has been assigned to this guest.
  71. pub fn guest_cid(&self) -> u64 {
  72. self.driver.guest_cid()
  73. }
  74. /// Allows incoming connections on the given port number.
  75. pub fn listen(&mut self, port: u32) {
  76. if !self.listening_ports.contains(&port) {
  77. self.listening_ports.push(port);
  78. }
  79. }
  80. /// Stops allowing incoming connections on the given port number.
  81. pub fn unlisten(&mut self, port: u32) {
  82. self.listening_ports.retain(|p| *p != port);
  83. }
  84. /// Sends a request to connect to the given destination.
  85. ///
  86. /// This returns as soon as the request is sent; you should wait until `poll` returns a
  87. /// `VsockEventType::Connected` event indicating that the peer has accepted the connection
  88. /// before sending data.
  89. pub fn connect(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  90. if self.connections.iter().any(|connection| {
  91. connection.info.dst == destination && connection.info.src_port == src_port
  92. }) {
  93. return Err(SocketError::ConnectionExists.into());
  94. }
  95. let new_connection = Connection::new(destination, src_port);
  96. self.driver.connect(&new_connection.info)?;
  97. debug!("Connection requested: {:?}", new_connection.info);
  98. self.connections.push(new_connection);
  99. Ok(())
  100. }
  101. /// Sends the buffer to the destination.
  102. pub fn send(&mut self, destination: VsockAddr, src_port: u32, buffer: &[u8]) -> Result {
  103. let connection = self
  104. .connections
  105. .iter_mut()
  106. .find(|connection| {
  107. connection.info.dst == destination && connection.info.src_port == src_port
  108. })
  109. .ok_or(SocketError::NotConnected)?;
  110. self.driver.send(buffer, &mut connection.info)
  111. }
  112. /// Polls the vsock device to receive data or other updates.
  113. pub fn poll(&mut self) -> Result<Option<VsockEvent>> {
  114. let guest_cid = self.driver.guest_cid();
  115. let connections = &mut self.connections;
  116. let result = self.driver.poll(|event, body| {
  117. let connection = connections
  118. .iter_mut()
  119. .find(|connection| event.matches_connection(&connection.info, guest_cid));
  120. // Skip events which don't match any connection we know about, unless they are a
  121. // connection request.
  122. let connection = if let Some(connection) = connection {
  123. connection
  124. } else if let VsockEventType::ConnectionRequest = event.event_type {
  125. // If the requested connection already exists or the CID isn't ours, ignore it.
  126. if connection.is_some() || event.destination.cid != guest_cid {
  127. return Ok(None);
  128. }
  129. // Add the new connection to our list, at least for now. It will be removed again
  130. // below if we weren't listening on the port.
  131. connections.push(Connection::new(event.source, event.destination.port));
  132. connections.last_mut().unwrap()
  133. } else {
  134. return Ok(None);
  135. };
  136. // Update stored connection info.
  137. connection.info.update_for_event(&event);
  138. if let VsockEventType::Received { length } = event.event_type {
  139. // Copy to buffer
  140. if !connection.buffer.write(body) {
  141. return Err(SocketError::OutputBufferTooShort(length).into());
  142. }
  143. }
  144. Ok(Some(event))
  145. })?;
  146. let Some(event) = result else {
  147. return Ok(None);
  148. };
  149. // The connection must exist because we found it above in the callback.
  150. let (connection_index, connection) = connections
  151. .iter_mut()
  152. .enumerate()
  153. .find(|(_, connection)| event.matches_connection(&connection.info, guest_cid))
  154. .unwrap();
  155. match event.event_type {
  156. VsockEventType::ConnectionRequest => {
  157. if self.listening_ports.contains(&event.destination.port) {
  158. self.driver.accept(&connection.info)?;
  159. } else {
  160. // Reject the connection request and remove it from our list.
  161. self.driver.force_close(&connection.info)?;
  162. self.connections.swap_remove(connection_index);
  163. // No need to pass the request on to the client, as we've already rejected it.
  164. return Ok(None);
  165. }
  166. }
  167. VsockEventType::Connected => {}
  168. VsockEventType::Disconnected { .. } => {
  169. // TODO: Wait until client reads all data before removing connection.
  170. }
  171. VsockEventType::Received { .. } => {
  172. // Already copied the buffer in the callback above.
  173. }
  174. VsockEventType::CreditRequest => {
  175. // If the peer requested credit, send an update.
  176. self.driver.credit_update(&connection.info)?;
  177. // No need to pass the request on to the client, we've already handled it.
  178. return Ok(None);
  179. }
  180. VsockEventType::CreditUpdate => {}
  181. }
  182. Ok(Some(event))
  183. }
  184. /// Reads data received from the given connection.
  185. pub fn recv(&mut self, peer: VsockAddr, src_port: u32, buffer: &mut [u8]) -> Result<usize> {
  186. let connection = self
  187. .connections
  188. .iter_mut()
  189. .find(|connection| connection.info.dst == peer && connection.info.src_port == src_port)
  190. .ok_or(SocketError::NotConnected)?;
  191. // Copy from ring buffer
  192. let bytes_read = connection.buffer.read(buffer);
  193. connection.info.done_forwarding(bytes_read);
  194. Ok(bytes_read)
  195. }
  196. /// Blocks until we get some event from the vsock device.
  197. pub fn wait_for_event(&mut self) -> Result<VsockEvent> {
  198. loop {
  199. if let Some(event) = self.poll()? {
  200. return Ok(event);
  201. } else {
  202. spin_loop();
  203. }
  204. }
  205. }
  206. /// Requests to shut down the connection cleanly.
  207. ///
  208. /// This returns as soon as the request is sent; you should wait until `poll` returns a
  209. /// `VsockEventType::Disconnected` event if you want to know that the peer has acknowledged the
  210. /// shutdown.
  211. pub fn shutdown(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  212. let connection = self
  213. .connections
  214. .iter()
  215. .find(|connection| {
  216. connection.info.dst == destination && connection.info.src_port == src_port
  217. })
  218. .ok_or(SocketError::NotConnected)?;
  219. self.driver.shutdown(&connection.info)
  220. }
  221. /// Forcibly closes the connection without waiting for the peer.
  222. pub fn force_close(&mut self, destination: VsockAddr, src_port: u32) -> Result {
  223. let (index, connection) = self
  224. .connections
  225. .iter()
  226. .enumerate()
  227. .find(|(_, connection)| {
  228. connection.info.dst == destination && connection.info.src_port == src_port
  229. })
  230. .ok_or(SocketError::NotConnected)?;
  231. self.driver.force_close(&connection.info)?;
  232. self.connections.swap_remove(index);
  233. Ok(())
  234. }
  235. }
  236. #[derive(Debug)]
  237. struct RingBuffer {
  238. buffer: Box<[u8]>,
  239. /// The number of bytes currently in the buffer.
  240. used: usize,
  241. /// The index of the first used byte in the buffer.
  242. start: usize,
  243. }
  244. impl RingBuffer {
  245. pub fn new(capacity: usize) -> Self {
  246. Self {
  247. buffer: FromBytes::new_box_slice_zeroed(capacity),
  248. used: 0,
  249. start: 0,
  250. }
  251. }
  252. /// Returns the number of bytes currently used in the buffer.
  253. pub fn used(&self) -> usize {
  254. self.used
  255. }
  256. /// Returns the number of bytes currently free in the buffer.
  257. pub fn available(&self) -> usize {
  258. self.buffer.len() - self.used
  259. }
  260. /// Adds the given bytes to the buffer if there is enough capacity for them all.
  261. ///
  262. /// Returns true if they were added, or false if they were not.
  263. pub fn write(&mut self, bytes: &[u8]) -> bool {
  264. if bytes.len() > self.available() {
  265. return false;
  266. }
  267. let end = (self.start + self.used) % self.buffer.len();
  268. let write_before_wraparound = min(bytes.len(), self.buffer.len() - end);
  269. let write_after_wraparound = bytes
  270. .len()
  271. .checked_sub(write_before_wraparound)
  272. .unwrap_or_default();
  273. self.buffer[end..end + write_before_wraparound]
  274. .copy_from_slice(&bytes[0..write_before_wraparound]);
  275. self.buffer[0..write_after_wraparound].copy_from_slice(&bytes[write_before_wraparound..]);
  276. self.used += bytes.len();
  277. true
  278. }
  279. /// Reads and removes as many bytes as possible from the buffer, up to the length of the given
  280. /// buffer.
  281. pub fn read(&mut self, out: &mut [u8]) -> usize {
  282. let bytes_read = min(self.used, out.len());
  283. // The number of bytes to copy out between `start` and the end of the buffer.
  284. let read_before_wraparound = min(bytes_read, self.buffer.len() - self.start);
  285. // The number of bytes to copy out from the beginning of the buffer after wrapping around.
  286. let read_after_wraparound = bytes_read
  287. .checked_sub(read_before_wraparound)
  288. .unwrap_or_default();
  289. out[0..read_before_wraparound]
  290. .copy_from_slice(&self.buffer[self.start..self.start + read_before_wraparound]);
  291. out[read_before_wraparound..bytes_read]
  292. .copy_from_slice(&self.buffer[0..read_after_wraparound]);
  293. self.used -= bytes_read;
  294. self.start = (self.start + bytes_read) % self.buffer.len();
  295. bytes_read
  296. }
  297. }
  298. #[cfg(test)]
  299. mod tests {
  300. use super::*;
  301. use crate::{
  302. device::socket::{
  303. protocol::{SocketType, VirtioVsockConfig, VirtioVsockHdr, VirtioVsockOp},
  304. vsock::{VsockBufferStatus, QUEUE_SIZE, RX_QUEUE_IDX, TX_QUEUE_IDX},
  305. },
  306. hal::fake::FakeHal,
  307. transport::{
  308. fake::{FakeTransport, QueueStatus, State},
  309. DeviceStatus, DeviceType,
  310. },
  311. volatile::ReadOnly,
  312. };
  313. use alloc::{sync::Arc, vec};
  314. use core::{mem::size_of, ptr::NonNull};
  315. use std::{sync::Mutex, thread};
  316. use zerocopy::{AsBytes, FromBytes};
  317. #[test]
  318. fn send_recv() {
  319. let host_cid = 2;
  320. let guest_cid = 66;
  321. let host_port = 1234;
  322. let guest_port = 4321;
  323. let host_address = VsockAddr {
  324. cid: host_cid,
  325. port: host_port,
  326. };
  327. let hello_from_guest = "Hello from guest";
  328. let hello_from_host = "Hello from host";
  329. let mut config_space = VirtioVsockConfig {
  330. guest_cid_low: ReadOnly::new(66),
  331. guest_cid_high: ReadOnly::new(0),
  332. };
  333. let state = Arc::new(Mutex::new(State {
  334. status: DeviceStatus::empty(),
  335. driver_features: 0,
  336. guest_page_size: 0,
  337. interrupt_pending: false,
  338. queues: vec![
  339. QueueStatus::default(),
  340. QueueStatus::default(),
  341. QueueStatus::default(),
  342. ],
  343. }));
  344. let transport = FakeTransport {
  345. device_type: DeviceType::Socket,
  346. max_queue_size: 32,
  347. device_features: 0,
  348. config_space: NonNull::from(&mut config_space),
  349. state: state.clone(),
  350. };
  351. let mut socket = VsockConnectionManager::new(
  352. VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
  353. );
  354. // Start a thread to simulate the device.
  355. let handle = thread::spawn(move || {
  356. // Wait for connection request.
  357. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  358. assert_eq!(
  359. VirtioVsockHdr::read_from(
  360. state
  361. .lock()
  362. .unwrap()
  363. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  364. .as_slice()
  365. )
  366. .unwrap(),
  367. VirtioVsockHdr {
  368. op: VirtioVsockOp::Request.into(),
  369. src_cid: guest_cid.into(),
  370. dst_cid: host_cid.into(),
  371. src_port: guest_port.into(),
  372. dst_port: host_port.into(),
  373. len: 0.into(),
  374. socket_type: SocketType::Stream.into(),
  375. flags: 0.into(),
  376. buf_alloc: 1024.into(),
  377. fwd_cnt: 0.into(),
  378. }
  379. );
  380. // Accept connection and give the peer enough credit to send the message.
  381. state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
  382. RX_QUEUE_IDX,
  383. VirtioVsockHdr {
  384. op: VirtioVsockOp::Response.into(),
  385. src_cid: host_cid.into(),
  386. dst_cid: guest_cid.into(),
  387. src_port: host_port.into(),
  388. dst_port: guest_port.into(),
  389. len: 0.into(),
  390. socket_type: SocketType::Stream.into(),
  391. flags: 0.into(),
  392. buf_alloc: 50.into(),
  393. fwd_cnt: 0.into(),
  394. }
  395. .as_bytes(),
  396. );
  397. // Expect the guest to send some data.
  398. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  399. let request = state
  400. .lock()
  401. .unwrap()
  402. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX);
  403. assert_eq!(
  404. request.len(),
  405. size_of::<VirtioVsockHdr>() + hello_from_guest.len()
  406. );
  407. assert_eq!(
  408. VirtioVsockHdr::read_from_prefix(request.as_slice()).unwrap(),
  409. VirtioVsockHdr {
  410. op: VirtioVsockOp::Rw.into(),
  411. src_cid: guest_cid.into(),
  412. dst_cid: host_cid.into(),
  413. src_port: guest_port.into(),
  414. dst_port: host_port.into(),
  415. len: (hello_from_guest.len() as u32).into(),
  416. socket_type: SocketType::Stream.into(),
  417. flags: 0.into(),
  418. buf_alloc: 1024.into(),
  419. fwd_cnt: 0.into(),
  420. }
  421. );
  422. assert_eq!(
  423. &request[size_of::<VirtioVsockHdr>()..],
  424. hello_from_guest.as_bytes()
  425. );
  426. println!("Host sending");
  427. // Send a response.
  428. let mut response = vec![0; size_of::<VirtioVsockHdr>() + hello_from_host.len()];
  429. VirtioVsockHdr {
  430. op: VirtioVsockOp::Rw.into(),
  431. src_cid: host_cid.into(),
  432. dst_cid: guest_cid.into(),
  433. src_port: host_port.into(),
  434. dst_port: guest_port.into(),
  435. len: (hello_from_host.len() as u32).into(),
  436. socket_type: SocketType::Stream.into(),
  437. flags: 0.into(),
  438. buf_alloc: 50.into(),
  439. fwd_cnt: (hello_from_guest.len() as u32).into(),
  440. }
  441. .write_to_prefix(response.as_mut_slice());
  442. response[size_of::<VirtioVsockHdr>()..].copy_from_slice(hello_from_host.as_bytes());
  443. state
  444. .lock()
  445. .unwrap()
  446. .write_to_queue::<QUEUE_SIZE>(RX_QUEUE_IDX, &response);
  447. // Expect a shutdown.
  448. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  449. assert_eq!(
  450. VirtioVsockHdr::read_from(
  451. state
  452. .lock()
  453. .unwrap()
  454. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  455. .as_slice()
  456. )
  457. .unwrap(),
  458. VirtioVsockHdr {
  459. op: VirtioVsockOp::Shutdown.into(),
  460. src_cid: guest_cid.into(),
  461. dst_cid: host_cid.into(),
  462. src_port: guest_port.into(),
  463. dst_port: host_port.into(),
  464. len: 0.into(),
  465. socket_type: SocketType::Stream.into(),
  466. flags: 0.into(),
  467. buf_alloc: 1024.into(),
  468. fwd_cnt: (hello_from_host.len() as u32).into(),
  469. }
  470. );
  471. });
  472. socket.connect(host_address, guest_port).unwrap();
  473. assert_eq!(
  474. socket.wait_for_event().unwrap(),
  475. VsockEvent {
  476. source: host_address,
  477. destination: VsockAddr {
  478. cid: guest_cid,
  479. port: guest_port,
  480. },
  481. event_type: VsockEventType::Connected,
  482. buffer_status: VsockBufferStatus {
  483. buffer_allocation: 50,
  484. forward_count: 0,
  485. },
  486. }
  487. );
  488. println!("Guest sending");
  489. socket
  490. .send(host_address, guest_port, "Hello from guest".as_bytes())
  491. .unwrap();
  492. println!("Guest waiting to receive.");
  493. assert_eq!(
  494. socket.wait_for_event().unwrap(),
  495. VsockEvent {
  496. source: host_address,
  497. destination: VsockAddr {
  498. cid: guest_cid,
  499. port: guest_port,
  500. },
  501. event_type: VsockEventType::Received {
  502. length: hello_from_host.len()
  503. },
  504. buffer_status: VsockBufferStatus {
  505. buffer_allocation: 50,
  506. forward_count: hello_from_guest.len() as u32,
  507. },
  508. }
  509. );
  510. println!("Guest getting received data.");
  511. let mut buffer = [0u8; 64];
  512. assert_eq!(
  513. socket.recv(host_address, guest_port, &mut buffer).unwrap(),
  514. hello_from_host.len()
  515. );
  516. assert_eq!(
  517. &buffer[0..hello_from_host.len()],
  518. hello_from_host.as_bytes()
  519. );
  520. socket.shutdown(host_address, guest_port).unwrap();
  521. handle.join().unwrap();
  522. }
  523. #[test]
  524. fn incoming_connection() {
  525. let host_cid = 2;
  526. let guest_cid = 66;
  527. let host_port = 1234;
  528. let guest_port = 4321;
  529. let wrong_guest_port = 4444;
  530. let host_address = VsockAddr {
  531. cid: host_cid,
  532. port: host_port,
  533. };
  534. let mut config_space = VirtioVsockConfig {
  535. guest_cid_low: ReadOnly::new(66),
  536. guest_cid_high: ReadOnly::new(0),
  537. };
  538. let state = Arc::new(Mutex::new(State {
  539. status: DeviceStatus::empty(),
  540. driver_features: 0,
  541. guest_page_size: 0,
  542. interrupt_pending: false,
  543. queues: vec![
  544. QueueStatus::default(),
  545. QueueStatus::default(),
  546. QueueStatus::default(),
  547. ],
  548. }));
  549. let transport = FakeTransport {
  550. device_type: DeviceType::Socket,
  551. max_queue_size: 32,
  552. device_features: 0,
  553. config_space: NonNull::from(&mut config_space),
  554. state: state.clone(),
  555. };
  556. let mut socket = VsockConnectionManager::new(
  557. VirtIOSocket::<FakeHal, FakeTransport<VirtioVsockConfig>>::new(transport).unwrap(),
  558. );
  559. socket.listen(guest_port);
  560. // Start a thread to simulate the device.
  561. let handle = thread::spawn(move || {
  562. // Send a connection request for a port the guest isn't listening on.
  563. println!("Host sending connection request to wrong port");
  564. state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
  565. RX_QUEUE_IDX,
  566. VirtioVsockHdr {
  567. op: VirtioVsockOp::Request.into(),
  568. src_cid: host_cid.into(),
  569. dst_cid: guest_cid.into(),
  570. src_port: host_port.into(),
  571. dst_port: wrong_guest_port.into(),
  572. len: 0.into(),
  573. socket_type: SocketType::Stream.into(),
  574. flags: 0.into(),
  575. buf_alloc: 50.into(),
  576. fwd_cnt: 0.into(),
  577. }
  578. .as_bytes(),
  579. );
  580. // Expect a rejection.
  581. println!("Host waiting for rejection");
  582. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  583. assert_eq!(
  584. VirtioVsockHdr::read_from(
  585. state
  586. .lock()
  587. .unwrap()
  588. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  589. .as_slice()
  590. )
  591. .unwrap(),
  592. VirtioVsockHdr {
  593. op: VirtioVsockOp::Rst.into(),
  594. src_cid: guest_cid.into(),
  595. dst_cid: host_cid.into(),
  596. src_port: wrong_guest_port.into(),
  597. dst_port: host_port.into(),
  598. len: 0.into(),
  599. socket_type: SocketType::Stream.into(),
  600. flags: 0.into(),
  601. buf_alloc: 1024.into(),
  602. fwd_cnt: 0.into(),
  603. }
  604. );
  605. // Send a connection request for a port the guest is listening on.
  606. println!("Host sending connection request to right port");
  607. state.lock().unwrap().write_to_queue::<QUEUE_SIZE>(
  608. RX_QUEUE_IDX,
  609. VirtioVsockHdr {
  610. op: VirtioVsockOp::Request.into(),
  611. src_cid: host_cid.into(),
  612. dst_cid: guest_cid.into(),
  613. src_port: host_port.into(),
  614. dst_port: guest_port.into(),
  615. len: 0.into(),
  616. socket_type: SocketType::Stream.into(),
  617. flags: 0.into(),
  618. buf_alloc: 50.into(),
  619. fwd_cnt: 0.into(),
  620. }
  621. .as_bytes(),
  622. );
  623. // Expect a response.
  624. println!("Host waiting for response");
  625. State::wait_until_queue_notified(&state, TX_QUEUE_IDX);
  626. assert_eq!(
  627. VirtioVsockHdr::read_from(
  628. state
  629. .lock()
  630. .unwrap()
  631. .read_from_queue::<QUEUE_SIZE>(TX_QUEUE_IDX)
  632. .as_slice()
  633. )
  634. .unwrap(),
  635. VirtioVsockHdr {
  636. op: VirtioVsockOp::Response.into(),
  637. src_cid: guest_cid.into(),
  638. dst_cid: host_cid.into(),
  639. src_port: guest_port.into(),
  640. dst_port: host_port.into(),
  641. len: 0.into(),
  642. socket_type: SocketType::Stream.into(),
  643. flags: 0.into(),
  644. buf_alloc: 1024.into(),
  645. fwd_cnt: 0.into(),
  646. }
  647. );
  648. println!("Host finished");
  649. });
  650. // Expect an incoming connection.
  651. println!("Guest expecting incoming connection.");
  652. assert_eq!(
  653. socket.wait_for_event().unwrap(),
  654. VsockEvent {
  655. source: host_address,
  656. destination: VsockAddr {
  657. cid: guest_cid,
  658. port: guest_port,
  659. },
  660. event_type: VsockEventType::ConnectionRequest,
  661. buffer_status: VsockBufferStatus {
  662. buffer_allocation: 50,
  663. forward_count: 0,
  664. },
  665. }
  666. );
  667. handle.join().unwrap();
  668. }
  669. }