main.rs 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451
  1. #![no_std]
  2. #![no_main]
  3. extern crate alloc;
  4. mod exceptions;
  5. mod hal;
  6. mod logger;
  7. #[cfg(platform = "qemu")]
  8. mod pl011;
  9. #[cfg(platform = "qemu")]
  10. use pl011 as uart;
  11. #[cfg(platform = "crosvm")]
  12. mod uart8250;
  13. #[cfg(platform = "crosvm")]
  14. use uart8250 as uart;
  15. use buddy_system_allocator::LockedHeap;
  16. use core::{
  17. mem::size_of,
  18. panic::PanicInfo,
  19. ptr::{self, NonNull},
  20. };
  21. use fdt::{node::FdtNode, standard_nodes::Compatible, Fdt};
  22. use hal::HalImpl;
  23. use log::{debug, error, info, trace, warn, LevelFilter};
  24. use smccc::{psci::system_off, Hvc};
  25. use virtio_drivers::{
  26. device::{
  27. blk::VirtIOBlk,
  28. console::VirtIOConsole,
  29. gpu::VirtIOGpu,
  30. socket::{
  31. VirtIOSocket, VsockAddr, VsockConnectionManager, VsockEventType, VMADDR_CID_HOST,
  32. },
  33. },
  34. transport::{
  35. mmio::{MmioTransport, VirtIOHeader},
  36. pci::{
  37. bus::{BarInfo, Cam, Command, DeviceFunction, MemoryBarType, PciRoot},
  38. virtio_device_type, PciTransport,
  39. },
  40. DeviceType, Transport,
  41. },
  42. };
/// Base memory-mapped address of the primary PL011 UART device.
///
/// Only compiled when building with `platform = "qemu"`.
#[cfg(platform = "qemu")]
pub const UART_BASE_ADDRESS: usize = 0x900_0000;
/// The base address of the first 8250 UART.
///
/// Only compiled when building with `platform = "crosvm"`.
#[cfg(platform = "crosvm")]
pub const UART_BASE_ADDRESS: usize = 0x3f8;
/// Global allocator backing the `alloc` crate; it is given its memory (`HEAP`) in `main`.
#[global_allocator]
static HEAP_ALLOCATOR: LockedHeap<32> = LockedHeap::new();
/// Backing storage for `HEAP_ALLOCATOR`: 0x1000000 bytes (16 MiB), zero-initialised.
///
/// Declared `static mut`, so every access needs `unsafe`; the only access visible in this file
/// is the one-time hand-over to the allocator in `main`.
static mut HEAP: [u8; 0x1000000] = [0; 0x1000000];
  52. #[no_mangle]
  53. extern "C" fn main(x0: u64, x1: u64, x2: u64, x3: u64) {
  54. logger::init(LevelFilter::Debug).unwrap();
  55. info!("virtio-drivers example started.");
  56. debug!(
  57. "x0={:#018x}, x1={:#018x}, x2={:#018x}, x3={:#018x}",
  58. x0, x1, x2, x3
  59. );
  60. // Safe because `HEAP` is only used here and `entry` is only called once.
  61. unsafe {
  62. // Give the allocator some memory to allocate.
  63. HEAP_ALLOCATOR
  64. .lock()
  65. .init(HEAP.as_mut_ptr() as usize, HEAP.len());
  66. }
  67. info!("Loading FDT from {:#018x}", x0);
  68. // Safe because the pointer is a valid pointer to unaliased memory.
  69. let fdt = unsafe { Fdt::from_ptr(x0 as *const u8).unwrap() };
  70. for node in fdt.all_nodes() {
  71. // Dump information about the node for debugging.
  72. trace!(
  73. "{}: {:?}",
  74. node.name,
  75. node.compatible().map(Compatible::first),
  76. );
  77. if let Some(reg) = node.reg() {
  78. for range in reg {
  79. trace!(
  80. " {:#018x?}, length {:?}",
  81. range.starting_address,
  82. range.size
  83. );
  84. }
  85. }
  86. // Check whether it is a VirtIO MMIO device.
  87. if let (Some(compatible), Some(region)) =
  88. (node.compatible(), node.reg().and_then(|mut reg| reg.next()))
  89. {
  90. if compatible.all().any(|s| s == "virtio,mmio")
  91. && region.size.unwrap_or(0) > size_of::<VirtIOHeader>()
  92. {
  93. debug!("Found VirtIO MMIO device at {:?}", region);
  94. let header = NonNull::new(region.starting_address as *mut VirtIOHeader).unwrap();
  95. match unsafe { MmioTransport::new(header) } {
  96. Err(e) => warn!("Error creating VirtIO MMIO transport: {}", e),
  97. Ok(transport) => {
  98. info!(
  99. "Detected virtio MMIO device with vendor id {:#X}, device type {:?}, version {:?}",
  100. transport.vendor_id(),
  101. transport.device_type(),
  102. transport.version(),
  103. );
  104. virtio_device(transport);
  105. }
  106. }
  107. }
  108. }
  109. }
  110. if let Some(pci_node) = fdt.find_compatible(&["pci-host-cam-generic"]) {
  111. info!("Found PCI node: {}", pci_node.name);
  112. enumerate_pci(pci_node, Cam::MmioCam);
  113. }
  114. if let Some(pcie_node) = fdt.find_compatible(&["pci-host-ecam-generic"]) {
  115. info!("Found PCIe node: {}", pcie_node.name);
  116. enumerate_pci(pcie_node, Cam::Ecam);
  117. }
  118. system_off::<Hvc>().unwrap();
  119. }
  120. fn virtio_device(transport: impl Transport) {
  121. match transport.device_type() {
  122. DeviceType::Block => virtio_blk(transport),
  123. DeviceType::GPU => virtio_gpu(transport),
  124. // DeviceType::Network => virtio_net(transport), // currently is unsupported without alloc
  125. DeviceType::Console => virtio_console(transport),
  126. DeviceType::Socket => match virtio_socket(transport) {
  127. Ok(()) => info!("virtio-socket test finished successfully"),
  128. Err(e) => error!("virtio-socket test finished with error '{e:?}'"),
  129. },
  130. t => warn!("Unrecognized virtio device: {:?}", t),
  131. }
  132. }
  133. fn virtio_blk<T: Transport>(transport: T) {
  134. let mut blk = VirtIOBlk::<HalImpl, T>::new(transport).expect("failed to create blk driver");
  135. assert!(!blk.readonly());
  136. let mut input = [0xffu8; 512];
  137. let mut output = [0; 512];
  138. for i in 0..32 {
  139. for x in input.iter_mut() {
  140. *x = i as u8;
  141. }
  142. blk.write_block(i, &input).expect("failed to write");
  143. blk.read_block(i, &mut output).expect("failed to read");
  144. assert_eq!(input, output);
  145. }
  146. info!("virtio-blk test finished");
  147. }
  148. fn virtio_gpu<T: Transport>(transport: T) {
  149. let mut gpu = VirtIOGpu::<HalImpl, T>::new(transport).expect("failed to create gpu driver");
  150. let (width, height) = gpu.resolution().expect("failed to get resolution");
  151. let width = width as usize;
  152. let height = height as usize;
  153. info!("GPU resolution is {}x{}", width, height);
  154. let fb = gpu.setup_framebuffer().expect("failed to get fb");
  155. for y in 0..height {
  156. for x in 0..width {
  157. let idx = (y * width + x) * 4;
  158. fb[idx] = x as u8;
  159. fb[idx + 1] = y as u8;
  160. fb[idx + 2] = (x + y) as u8;
  161. }
  162. }
  163. gpu.flush().expect("failed to flush");
  164. //delay some time
  165. info!("virtio-gpu show graphics....");
  166. for _ in 0..1000 {
  167. for _ in 0..100000 {
  168. unsafe {
  169. core::arch::asm!("nop");
  170. }
  171. }
  172. }
  173. info!("virtio-gpu test finished");
  174. }
  175. fn virtio_console<T: Transport>(transport: T) {
  176. let mut console =
  177. VirtIOConsole::<HalImpl, T>::new(transport).expect("Failed to create console driver");
  178. let info = console.info();
  179. info!("VirtIO console {}x{}", info.rows, info.columns);
  180. for &c in b"Hello world on console!\n" {
  181. console.send(c).expect("Failed to send character");
  182. }
  183. let c = console.recv(true).expect("Failed to read from console");
  184. info!("Read {:?}", c);
  185. info!("virtio-console test finished");
  186. }
  187. fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
  188. let mut socket = VsockConnectionManager::new(
  189. VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver"),
  190. );
  191. let port = 1221;
  192. let host_address = VsockAddr {
  193. cid: VMADDR_CID_HOST,
  194. port,
  195. };
  196. info!("Connecting to host on port {port}...");
  197. socket.connect(host_address, port)?;
  198. let event = socket.wait_for_event()?;
  199. assert_eq!(event.source, host_address);
  200. assert_eq!(event.destination.port, port);
  201. assert_eq!(event.event_type, VsockEventType::Connected);
  202. info!("Connected to the host");
  203. const EXCHANGE_NUM: usize = 2;
  204. let messages = ["0-Ack. Hello from guest.", "1-Ack. Received again."];
  205. for k in 0..EXCHANGE_NUM {
  206. let mut buffer = [0u8; 24];
  207. let socket_event = socket.wait_for_event()?;
  208. let VsockEventType::Received {length, ..} = socket_event.event_type else {
  209. panic!("Received unexpected socket event {:?}", socket_event);
  210. };
  211. let read_length = socket.recv(host_address, port, &mut buffer)?;
  212. assert_eq!(length, read_length);
  213. info!(
  214. "Received message: {:?}({:?}), len: {:?}",
  215. buffer,
  216. core::str::from_utf8(&buffer[..length]),
  217. length
  218. );
  219. let message = messages[k % messages.len()];
  220. socket.send(host_address, port, message.as_bytes())?;
  221. info!("Sent message: {:?}", message);
  222. }
  223. socket.shutdown(host_address, port)?;
  224. info!("Shutdown the connection");
  225. Ok(())
  226. }
  227. #[derive(Copy, Clone, Debug, Eq, PartialEq)]
  228. enum PciRangeType {
  229. ConfigurationSpace,
  230. IoSpace,
  231. Memory32,
  232. Memory64,
  233. }
  234. impl From<u8> for PciRangeType {
  235. fn from(value: u8) -> Self {
  236. match value {
  237. 0 => Self::ConfigurationSpace,
  238. 1 => Self::IoSpace,
  239. 2 => Self::Memory32,
  240. 3 => Self::Memory64,
  241. _ => panic!("Tried to convert invalid range type {}", value),
  242. }
  243. }
  244. }
  245. fn enumerate_pci(pci_node: FdtNode, cam: Cam) {
  246. let reg = pci_node.reg().expect("PCI node missing reg property.");
  247. let mut allocator = PciMemory32Allocator::for_pci_ranges(&pci_node);
  248. for region in reg {
  249. info!(
  250. "Reg: {:?}-{:#x}",
  251. region.starting_address,
  252. region.starting_address as usize + region.size.unwrap()
  253. );
  254. assert_eq!(region.size.unwrap(), cam.size() as usize);
  255. // Safe because we know the pointer is to a valid MMIO region.
  256. let mut pci_root = unsafe { PciRoot::new(region.starting_address as *mut u8, cam) };
  257. for (device_function, info) in pci_root.enumerate_bus(0) {
  258. let (status, command) = pci_root.get_status_command(device_function);
  259. info!(
  260. "Found {} at {}, status {:?} command {:?}",
  261. info, device_function, status, command
  262. );
  263. if let Some(virtio_type) = virtio_device_type(&info) {
  264. info!(" VirtIO {:?}", virtio_type);
  265. allocate_bars(&mut pci_root, device_function, &mut allocator);
  266. dump_bar_contents(&mut pci_root, device_function, 4);
  267. let mut transport =
  268. PciTransport::new::<HalImpl>(&mut pci_root, device_function).unwrap();
  269. info!(
  270. "Detected virtio PCI device with device type {:?}, features {:#018x}",
  271. transport.device_type(),
  272. transport.read_device_features(),
  273. );
  274. virtio_device(transport);
  275. }
  276. }
  277. }
  278. }
  279. /// Allocates 32-bit memory addresses for PCI BARs.
  280. struct PciMemory32Allocator {
  281. start: u32,
  282. end: u32,
  283. }
  284. impl PciMemory32Allocator {
  285. /// Creates a new allocator based on the ranges property of the given PCI node.
  286. pub fn for_pci_ranges(pci_node: &FdtNode) -> Self {
  287. let ranges = pci_node
  288. .property("ranges")
  289. .expect("PCI node missing ranges property.");
  290. let mut memory_32_address = 0;
  291. let mut memory_32_size = 0;
  292. for i in 0..ranges.value.len() / 28 {
  293. let range = &ranges.value[i * 28..(i + 1) * 28];
  294. let prefetchable = range[0] & 0x80 != 0;
  295. let range_type = PciRangeType::from(range[0] & 0x3);
  296. let bus_address = u64::from_be_bytes(range[4..12].try_into().unwrap());
  297. let cpu_physical = u64::from_be_bytes(range[12..20].try_into().unwrap());
  298. let size = u64::from_be_bytes(range[20..28].try_into().unwrap());
  299. info!(
  300. "range: {:?} {}prefetchable bus address: {:#018x} host physical address: {:#018x} size: {:#018x}",
  301. range_type,
  302. if prefetchable { "" } else { "non-" },
  303. bus_address,
  304. cpu_physical,
  305. size,
  306. );
  307. // Use the largest range within the 32-bit address space for 32-bit memory, even if it
  308. // is marked as a 64-bit range. This is necessary because crosvm doesn't currently
  309. // provide any 32-bit ranges.
  310. if !prefetchable
  311. && matches!(range_type, PciRangeType::Memory32 | PciRangeType::Memory64)
  312. && size > memory_32_size.into()
  313. && bus_address + size < u32::MAX.into()
  314. {
  315. assert_eq!(bus_address, cpu_physical);
  316. memory_32_address = u32::try_from(cpu_physical).unwrap();
  317. memory_32_size = u32::try_from(size).unwrap();
  318. }
  319. }
  320. if memory_32_size == 0 {
  321. panic!("No 32-bit PCI memory region found.");
  322. }
  323. Self {
  324. start: memory_32_address,
  325. end: memory_32_address + memory_32_size,
  326. }
  327. }
  328. /// Allocates a 32-bit memory address region for a PCI BAR of the given power-of-2 size.
  329. ///
  330. /// It will have alignment matching the size. The size must be a power of 2.
  331. pub fn allocate_memory_32(&mut self, size: u32) -> u32 {
  332. assert!(size.is_power_of_two());
  333. let allocated_address = align_up(self.start, size);
  334. assert!(allocated_address + size <= self.end);
  335. self.start = allocated_address + size;
  336. allocated_address
  337. }
  338. }
  339. const fn align_up(value: u32, alignment: u32) -> u32 {
  340. ((value - 1) | (alignment - 1)) + 1
  341. }
  342. fn dump_bar_contents(root: &mut PciRoot, device_function: DeviceFunction, bar_index: u8) {
  343. let bar_info = root.bar_info(device_function, bar_index).unwrap();
  344. trace!("Dumping bar {}: {:#x?}", bar_index, bar_info);
  345. if let BarInfo::Memory { address, size, .. } = bar_info {
  346. let start = address as *const u8;
  347. unsafe {
  348. let mut buf = [0u8; 32];
  349. for i in 0..size / 32 {
  350. let ptr = start.add(i as usize * 32);
  351. ptr::copy(ptr, buf.as_mut_ptr(), 32);
  352. if buf.iter().any(|b| *b != 0xff) {
  353. trace!(" {:?}: {:x?}", ptr, buf);
  354. }
  355. }
  356. }
  357. }
  358. trace!("End of dump");
  359. }
  360. /// Allocates appropriately-sized memory regions and assigns them to the device's BARs.
  361. fn allocate_bars(
  362. root: &mut PciRoot,
  363. device_function: DeviceFunction,
  364. allocator: &mut PciMemory32Allocator,
  365. ) {
  366. let mut bar_index = 0;
  367. while bar_index < 6 {
  368. let info = root.bar_info(device_function, bar_index).unwrap();
  369. debug!("BAR {}: {}", bar_index, info);
  370. // Ignore I/O bars, as they aren't required for the VirtIO driver.
  371. if let BarInfo::Memory {
  372. address_type, size, ..
  373. } = info
  374. {
  375. match address_type {
  376. MemoryBarType::Width32 => {
  377. if size > 0 {
  378. let address = allocator.allocate_memory_32(size);
  379. debug!("Allocated address {:#010x}", address);
  380. root.set_bar_32(device_function, bar_index, address);
  381. }
  382. }
  383. MemoryBarType::Width64 => {
  384. if size > 0 {
  385. let address = allocator.allocate_memory_32(size);
  386. debug!("Allocated address {:#010x}", address);
  387. root.set_bar_64(device_function, bar_index, address.into());
  388. }
  389. }
  390. _ => panic!("Memory BAR address type {:?} not supported.", address_type),
  391. }
  392. }
  393. bar_index += 1;
  394. if info.takes_two_entries() {
  395. bar_index += 1;
  396. }
  397. }
  398. // Enable the device to use its BARs.
  399. root.set_command(
  400. device_function,
  401. Command::IO_SPACE | Command::MEMORY_SPACE | Command::BUS_MASTER,
  402. );
  403. let (status, command) = root.get_status_command(device_function);
  404. debug!(
  405. "Allocated BARs and enabled device, status {:?} command {:?}",
  406. status, command
  407. );
  408. }
  409. #[panic_handler]
  410. fn panic(info: &PanicInfo) -> ! {
  411. error!("{}", info);
  412. system_off::<Hvc>().unwrap();
  413. loop {}
  414. }