main.rs 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438
  1. #![no_std]
  2. #![no_main]
  3. mod exceptions;
  4. mod hal;
  5. mod logger;
  6. #[cfg(platform = "qemu")]
  7. mod pl011;
  8. #[cfg(platform = "qemu")]
  9. use pl011 as uart;
  10. #[cfg(platform = "crosvm")]
  11. mod uart8250;
  12. #[cfg(platform = "crosvm")]
  13. use uart8250 as uart;
  14. use buddy_system_allocator::LockedHeap;
  15. use core::{
  16. mem::size_of,
  17. panic::PanicInfo,
  18. ptr::{self, NonNull},
  19. };
  20. use fdt::{node::FdtNode, standard_nodes::Compatible, Fdt};
  21. use hal::HalImpl;
  22. use log::{debug, error, info, trace, warn, LevelFilter};
  23. use smccc::{psci::system_off, Hvc};
  24. use virtio_drivers::{
  25. device::{
  26. blk::VirtIOBlk,
  27. console::VirtIOConsole,
  28. gpu::VirtIOGpu,
  29. socket::{VirtIOSocket, VsockEventType},
  30. },
  31. transport::{
  32. mmio::{MmioTransport, VirtIOHeader},
  33. pci::{
  34. bus::{BarInfo, Cam, Command, DeviceFunction, MemoryBarType, PciRoot},
  35. virtio_device_type, PciTransport,
  36. },
  37. DeviceType, Transport,
  38. },
  39. };
  40. /// Base memory-mapped address of the primary PL011 UART device.
  41. #[cfg(platform = "qemu")]
  42. pub const UART_BASE_ADDRESS: usize = 0x900_0000;
  43. /// The base address of the first 8250 UART.
  44. #[cfg(platform = "crosvm")]
  45. pub const UART_BASE_ADDRESS: usize = 0x3f8;
  46. #[global_allocator]
  47. static HEAP_ALLOCATOR: LockedHeap<32> = LockedHeap::new();
  48. static mut HEAP: [u8; 0x1000000] = [0; 0x1000000];
  49. #[no_mangle]
  50. extern "C" fn main(x0: u64, x1: u64, x2: u64, x3: u64) {
  51. logger::init(LevelFilter::Debug).unwrap();
  52. info!("virtio-drivers example started.");
  53. debug!(
  54. "x0={:#018x}, x1={:#018x}, x2={:#018x}, x3={:#018x}",
  55. x0, x1, x2, x3
  56. );
  57. // Safe because `HEAP` is only used here and `entry` is only called once.
  58. unsafe {
  59. // Give the allocator some memory to allocate.
  60. HEAP_ALLOCATOR
  61. .lock()
  62. .init(HEAP.as_mut_ptr() as usize, HEAP.len());
  63. }
  64. info!("Loading FDT from {:#018x}", x0);
  65. // Safe because the pointer is a valid pointer to unaliased memory.
  66. let fdt = unsafe { Fdt::from_ptr(x0 as *const u8).unwrap() };
  67. for node in fdt.all_nodes() {
  68. // Dump information about the node for debugging.
  69. trace!(
  70. "{}: {:?}",
  71. node.name,
  72. node.compatible().map(Compatible::first),
  73. );
  74. if let Some(reg) = node.reg() {
  75. for range in reg {
  76. trace!(
  77. " {:#018x?}, length {:?}",
  78. range.starting_address,
  79. range.size
  80. );
  81. }
  82. }
  83. // Check whether it is a VirtIO MMIO device.
  84. if let (Some(compatible), Some(region)) =
  85. (node.compatible(), node.reg().and_then(|mut reg| reg.next()))
  86. {
  87. if compatible.all().any(|s| s == "virtio,mmio")
  88. && region.size.unwrap_or(0) > size_of::<VirtIOHeader>()
  89. {
  90. debug!("Found VirtIO MMIO device at {:?}", region);
  91. let header = NonNull::new(region.starting_address as *mut VirtIOHeader).unwrap();
  92. match unsafe { MmioTransport::new(header) } {
  93. Err(e) => warn!("Error creating VirtIO MMIO transport: {}", e),
  94. Ok(transport) => {
  95. info!(
  96. "Detected virtio MMIO device with vendor id {:#X}, device type {:?}, version {:?}",
  97. transport.vendor_id(),
  98. transport.device_type(),
  99. transport.version(),
  100. );
  101. virtio_device(transport);
  102. }
  103. }
  104. }
  105. }
  106. }
  107. if let Some(pci_node) = fdt.find_compatible(&["pci-host-cam-generic"]) {
  108. info!("Found PCI node: {}", pci_node.name);
  109. enumerate_pci(pci_node, Cam::MmioCam);
  110. }
  111. if let Some(pcie_node) = fdt.find_compatible(&["pci-host-ecam-generic"]) {
  112. info!("Found PCIe node: {}", pcie_node.name);
  113. enumerate_pci(pcie_node, Cam::Ecam);
  114. }
  115. system_off::<Hvc>().unwrap();
  116. }
  117. fn virtio_device(transport: impl Transport) {
  118. match transport.device_type() {
  119. DeviceType::Block => virtio_blk(transport),
  120. DeviceType::GPU => virtio_gpu(transport),
  121. // DeviceType::Network => virtio_net(transport), // currently is unsupported without alloc
  122. DeviceType::Console => virtio_console(transport),
  123. DeviceType::Socket => match virtio_socket(transport) {
  124. Ok(()) => info!("virtio-socket test finished successfully"),
  125. Err(e) => error!("virtio-socket test finished with error '{e:?}'"),
  126. },
  127. t => warn!("Unrecognized virtio device: {:?}", t),
  128. }
  129. }
  130. fn virtio_blk<T: Transport>(transport: T) {
  131. let mut blk = VirtIOBlk::<HalImpl, T>::new(transport).expect("failed to create blk driver");
  132. assert!(!blk.readonly());
  133. let mut input = [0xffu8; 512];
  134. let mut output = [0; 512];
  135. for i in 0..32 {
  136. for x in input.iter_mut() {
  137. *x = i as u8;
  138. }
  139. blk.write_block(i, &input).expect("failed to write");
  140. blk.read_block(i, &mut output).expect("failed to read");
  141. assert_eq!(input, output);
  142. }
  143. info!("virtio-blk test finished");
  144. }
  145. fn virtio_gpu<T: Transport>(transport: T) {
  146. let mut gpu = VirtIOGpu::<HalImpl, T>::new(transport).expect("failed to create gpu driver");
  147. let (width, height) = gpu.resolution().expect("failed to get resolution");
  148. let width = width as usize;
  149. let height = height as usize;
  150. info!("GPU resolution is {}x{}", width, height);
  151. let fb = gpu.setup_framebuffer().expect("failed to get fb");
  152. for y in 0..height {
  153. for x in 0..width {
  154. let idx = (y * width + x) * 4;
  155. fb[idx] = x as u8;
  156. fb[idx + 1] = y as u8;
  157. fb[idx + 2] = (x + y) as u8;
  158. }
  159. }
  160. gpu.flush().expect("failed to flush");
  161. //delay some time
  162. info!("virtio-gpu show graphics....");
  163. for _ in 0..1000 {
  164. for _ in 0..100000 {
  165. unsafe {
  166. core::arch::asm!("nop");
  167. }
  168. }
  169. }
  170. info!("virtio-gpu test finished");
  171. }
  172. fn virtio_console<T: Transport>(transport: T) {
  173. let mut console =
  174. VirtIOConsole::<HalImpl, T>::new(transport).expect("Failed to create console driver");
  175. let info = console.info();
  176. info!("VirtIO console {}x{}", info.rows, info.columns);
  177. for &c in b"Hello world on console!\n" {
  178. console.send(c).expect("Failed to send character");
  179. }
  180. let c = console.recv(true).expect("Failed to read from console");
  181. info!("Read {:?}", c);
  182. info!("virtio-console test finished");
  183. }
  184. fn virtio_socket<T: Transport>(transport: T) -> virtio_drivers::Result<()> {
  185. let mut socket =
  186. VirtIOSocket::<HalImpl, T>::new(transport).expect("Failed to create socket driver");
  187. let host_cid = 2;
  188. let port = 1221;
  189. info!("Connecting to host on port {port}...");
  190. socket.connect(host_cid, port, port)?;
  191. socket.wait_for_connect()?;
  192. info!("Connected to the host");
  193. const EXCHANGE_NUM: usize = 2;
  194. let messages = ["0-Ack. Hello from guest.", "1-Ack. Received again."];
  195. for k in 0..EXCHANGE_NUM {
  196. let mut buffer = [0u8; 24];
  197. let socket_event = socket.wait_for_recv(&mut buffer)?;
  198. let VsockEventType::Received {length, ..} = socket_event.event_type else {
  199. panic!("Received unexpected socket event {:?}", socket_event);
  200. };
  201. info!(
  202. "Received message: {:?}({:?}), len: {:?}",
  203. buffer,
  204. core::str::from_utf8(&buffer[..length]),
  205. length
  206. );
  207. let message = messages[k % messages.len()];
  208. socket.send(message.as_bytes())?;
  209. info!("Sent message: {:?}", message);
  210. }
  211. socket.shutdown()?;
  212. info!("Shutdown the connection");
  213. Ok(())
  214. }
  /// Address-space kind of one entry of a PCI host bridge's `ranges` property.
  #[derive(Copy, Clone, Debug, Eq, PartialEq)]
  enum PciRangeType {
      /// Space code 0: configuration space.
      ConfigurationSpace,
      /// Space code 1: I/O space.
      IoSpace,
      /// Space code 2: 32-bit memory space.
      Memory32,
      /// Space code 3: 64-bit memory space.
      Memory64,
  }

  impl From<u8> for PciRangeType {
      /// Maps a space code in `0..=3` to its variant.
      ///
      /// # Panics
      /// If `value` is greater than 3.
      fn from(value: u8) -> Self {
          const BY_CODE: [PciRangeType; 4] = [
              PciRangeType::ConfigurationSpace,
              PciRangeType::IoSpace,
              PciRangeType::Memory32,
              PciRangeType::Memory64,
          ];
          BY_CODE
              .get(usize::from(value))
              .copied()
              .unwrap_or_else(|| panic!("Tried to convert invalid range type {}", value))
      }
  }
  233. fn enumerate_pci(pci_node: FdtNode, cam: Cam) {
  234. let reg = pci_node.reg().expect("PCI node missing reg property.");
  235. let mut allocator = PciMemory32Allocator::for_pci_ranges(&pci_node);
  236. for region in reg {
  237. info!(
  238. "Reg: {:?}-{:#x}",
  239. region.starting_address,
  240. region.starting_address as usize + region.size.unwrap()
  241. );
  242. assert_eq!(region.size.unwrap(), cam.size() as usize);
  243. // Safe because we know the pointer is to a valid MMIO region.
  244. let mut pci_root = unsafe { PciRoot::new(region.starting_address as *mut u8, cam) };
  245. for (device_function, info) in pci_root.enumerate_bus(0) {
  246. let (status, command) = pci_root.get_status_command(device_function);
  247. info!(
  248. "Found {} at {}, status {:?} command {:?}",
  249. info, device_function, status, command
  250. );
  251. if let Some(virtio_type) = virtio_device_type(&info) {
  252. info!(" VirtIO {:?}", virtio_type);
  253. allocate_bars(&mut pci_root, device_function, &mut allocator);
  254. dump_bar_contents(&mut pci_root, device_function, 4);
  255. let mut transport =
  256. PciTransport::new::<HalImpl>(&mut pci_root, device_function).unwrap();
  257. info!(
  258. "Detected virtio PCI device with device type {:?}, features {:#018x}",
  259. transport.device_type(),
  260. transport.read_device_features(),
  261. );
  262. virtio_device(transport);
  263. }
  264. }
  265. }
  266. }
  267. /// Allocates 32-bit memory addresses for PCI BARs.
  268. struct PciMemory32Allocator {
  269. start: u32,
  270. end: u32,
  271. }
  272. impl PciMemory32Allocator {
  273. /// Creates a new allocator based on the ranges property of the given PCI node.
  274. pub fn for_pci_ranges(pci_node: &FdtNode) -> Self {
  275. let ranges = pci_node
  276. .property("ranges")
  277. .expect("PCI node missing ranges property.");
  278. let mut memory_32_address = 0;
  279. let mut memory_32_size = 0;
  280. for i in 0..ranges.value.len() / 28 {
  281. let range = &ranges.value[i * 28..(i + 1) * 28];
  282. let prefetchable = range[0] & 0x80 != 0;
  283. let range_type = PciRangeType::from(range[0] & 0x3);
  284. let bus_address = u64::from_be_bytes(range[4..12].try_into().unwrap());
  285. let cpu_physical = u64::from_be_bytes(range[12..20].try_into().unwrap());
  286. let size = u64::from_be_bytes(range[20..28].try_into().unwrap());
  287. info!(
  288. "range: {:?} {}prefetchable bus address: {:#018x} host physical address: {:#018x} size: {:#018x}",
  289. range_type,
  290. if prefetchable { "" } else { "non-" },
  291. bus_address,
  292. cpu_physical,
  293. size,
  294. );
  295. // Use the largest range within the 32-bit address space for 32-bit memory, even if it
  296. // is marked as a 64-bit range. This is necessary because crosvm doesn't currently
  297. // provide any 32-bit ranges.
  298. if !prefetchable
  299. && matches!(range_type, PciRangeType::Memory32 | PciRangeType::Memory64)
  300. && size > memory_32_size.into()
  301. && bus_address + size < u32::MAX.into()
  302. {
  303. assert_eq!(bus_address, cpu_physical);
  304. memory_32_address = u32::try_from(cpu_physical).unwrap();
  305. memory_32_size = u32::try_from(size).unwrap();
  306. }
  307. }
  308. if memory_32_size == 0 {
  309. panic!("No 32-bit PCI memory region found.");
  310. }
  311. Self {
  312. start: memory_32_address,
  313. end: memory_32_address + memory_32_size,
  314. }
  315. }
  316. /// Allocates a 32-bit memory address region for a PCI BAR of the given power-of-2 size.
  317. ///
  318. /// It will have alignment matching the size. The size must be a power of 2.
  319. pub fn allocate_memory_32(&mut self, size: u32) -> u32 {
  320. assert!(size.is_power_of_two());
  321. let allocated_address = align_up(self.start, size);
  322. assert!(allocated_address + size <= self.end);
  323. self.start = allocated_address + size;
  324. allocated_address
  325. }
  326. }
  /// Rounds `value` up to the next multiple of `alignment`.
  ///
  /// `alignment` must be a non-zero power of two. `align_up(0, a)` is `0`.
  ///
  /// # Panics
  /// If the rounded-up result does not fit in a `u32`.
  const fn align_up(value: u32, alignment: u32) -> u32 {
      let mask = alignment - 1;
      // Adding the mask first, rather than computing `value - 1`, means `value == 0` cannot
      // underflow. The checked add turns a too-large result into a panic instead of a silent
      // wrap.
      match value.checked_add(mask) {
          Some(bumped) => bumped & !mask,
          None => panic!("align_up: aligned value does not fit in u32"),
      }
  }
  330. fn dump_bar_contents(root: &mut PciRoot, device_function: DeviceFunction, bar_index: u8) {
  331. let bar_info = root.bar_info(device_function, bar_index).unwrap();
  332. trace!("Dumping bar {}: {:#x?}", bar_index, bar_info);
  333. if let BarInfo::Memory { address, size, .. } = bar_info {
  334. let start = address as *const u8;
  335. unsafe {
  336. let mut buf = [0u8; 32];
  337. for i in 0..size / 32 {
  338. let ptr = start.add(i as usize * 32);
  339. ptr::copy(ptr, buf.as_mut_ptr(), 32);
  340. if buf.iter().any(|b| *b != 0xff) {
  341. trace!(" {:?}: {:x?}", ptr, buf);
  342. }
  343. }
  344. }
  345. }
  346. trace!("End of dump");
  347. }
  348. /// Allocates appropriately-sized memory regions and assigns them to the device's BARs.
  349. fn allocate_bars(
  350. root: &mut PciRoot,
  351. device_function: DeviceFunction,
  352. allocator: &mut PciMemory32Allocator,
  353. ) {
  354. let mut bar_index = 0;
  355. while bar_index < 6 {
  356. let info = root.bar_info(device_function, bar_index).unwrap();
  357. debug!("BAR {}: {}", bar_index, info);
  358. // Ignore I/O bars, as they aren't required for the VirtIO driver.
  359. if let BarInfo::Memory {
  360. address_type, size, ..
  361. } = info
  362. {
  363. match address_type {
  364. MemoryBarType::Width32 => {
  365. if size > 0 {
  366. let address = allocator.allocate_memory_32(size);
  367. debug!("Allocated address {:#010x}", address);
  368. root.set_bar_32(device_function, bar_index, address);
  369. }
  370. }
  371. MemoryBarType::Width64 => {
  372. if size > 0 {
  373. let address = allocator.allocate_memory_32(size);
  374. debug!("Allocated address {:#010x}", address);
  375. root.set_bar_64(device_function, bar_index, address.into());
  376. }
  377. }
  378. _ => panic!("Memory BAR address type {:?} not supported.", address_type),
  379. }
  380. }
  381. bar_index += 1;
  382. if info.takes_two_entries() {
  383. bar_index += 1;
  384. }
  385. }
  386. // Enable the device to use its BARs.
  387. root.set_command(
  388. device_function,
  389. Command::IO_SPACE | Command::MEMORY_SPACE | Command::BUS_MASTER,
  390. );
  391. let (status, command) = root.get_status_command(device_function);
  392. debug!(
  393. "Allocated BARs and enabled device, status {:?} command {:?}",
  394. status, command
  395. );
  396. }
  397. #[panic_handler]
  398. fn panic(info: &PanicInfo) -> ! {
  399. error!("{}", info);
  400. system_off::<Hvc>().unwrap();
  401. loop {}
  402. }