Sfoglia il codice sorgente

Avoid storing PciRoot in Pci transport.

This means we don't have to clone it.
Andrew Walbran 2 anni fa
parent
commit
cf5575ecdd
2 ha cambiato i file con 24 aggiunte e 15 eliminazioni
  1. 1 1
      examples/aarch64/src/main.rs
  2. 23 14
      src/transport/pci.rs

+ 1 - 1
examples/aarch64/src/main.rs

@@ -186,7 +186,7 @@ fn enumerate_pci(pci_node: FdtNode, cam: Cam) {
                 allocate_bars(&mut pci_root, device_function, &mut allocator);
                 dump_bar_contents(&mut pci_root, device_function, 4);
                 let mut transport =
-                    PciTransport::new::<HalImpl>(pci_root.clone(), device_function).unwrap();
+                    PciTransport::new::<HalImpl>(&mut pci_root, device_function).unwrap();
                 info!(
                     "Detected virtio PCI device with device type {:?}, features {:#018x}",
                     transport.device_type(),

+ 23 - 14
src/transport/pci.rs

@@ -79,7 +79,7 @@ pub fn virtio_device_type(device_function_info: &DeviceFunctionInfo) -> Option<D
 /// Ref: 4.1 Virtio Over PCI Bus
 #[derive(Debug)]
 pub struct PciTransport {
-    root: PciRoot,
+    device_type: DeviceType,
     /// The bus, device and function identifier for the VirtIO device.
     device_function: DeviceFunction,
     /// The common configuration structure within some BAR.
@@ -104,9 +104,17 @@ impl PciTransport {
     ///
     /// The PCI device must already have had its BARs allocated.
     pub fn new<H: Hal>(
-        mut root: PciRoot,
+        root: &mut PciRoot,
         device_function: DeviceFunction,
     ) -> Result<Self, VirtioPciError> {
+        let device_vendor = root.config_read_word(device_function, 0);
+        let device_id = (device_vendor >> 16) as u16;
+        let vendor_id = device_vendor as u16;
+        if vendor_id != VIRTIO_VENDOR_ID {
+            return Err(VirtioPciError::InvalidVendorId(vendor_id));
+        }
+        let device_type = device_type(device_id);
+
         // Find the PCI capabilities we need.
         let mut common_cfg = None;
         let mut notify_cfg = None;
@@ -153,7 +161,7 @@ impl PciTransport {
         }
 
         let common_cfg = get_bar_region::<H, _>(
-            &mut root,
+            root,
             device_function,
             &common_cfg.ok_or(VirtioPciError::MissingCommonConfig)?,
         )?;
@@ -164,10 +172,10 @@ impl PciTransport {
                 notify_off_multiplier,
             ));
         }
-        let notify_region = get_bar_region::<H, _>(&mut root, device_function, &notify_cfg)?;
+        let notify_region = get_bar_region::<H, _>(root, device_function, &notify_cfg)?;
 
         let isr_status = get_bar_region::<H, _>(
-            &mut root,
+            root,
             device_function,
             &isr_cfg.ok_or(VirtioPciError::MissingIsrConfig)?,
         )?;
@@ -175,11 +183,7 @@ impl PciTransport {
         let config_space;
         let config_space_size;
         if let Some(device_cfg) = device_cfg {
-            config_space = Some(get_bar_region::<H, _>(
-                &mut root,
-                device_function,
-                &device_cfg,
-            )?);
+            config_space = Some(get_bar_region::<H, _>(root, device_function, &device_cfg)?);
             config_space_size = device_cfg.length as usize;
         } else {
             config_space = None;
@@ -187,7 +191,7 @@ impl PciTransport {
         }
 
         Ok(Self {
-            root,
+            device_type,
             device_function,
             common_cfg,
             notify_region,
@@ -202,9 +206,7 @@ impl PciTransport {
 
 impl Transport for PciTransport {
     fn device_type(&self) -> DeviceType {
-        let header = self.root.config_read_word(self.device_function, 0);
-        let device_id = (header >> 16) as u16;
-        device_type(device_id)
+        self.device_type
     }
 
     fn read_device_features(&mut self) -> u64 {
@@ -360,6 +362,8 @@ fn get_bar_region<H: Hal, T>(
 /// An error encountered initialising a VirtIO PCI transport.
 #[derive(Clone, Debug, Eq, PartialEq)]
 pub enum VirtioPciError {
+    /// PCI device vender ID was not the VirtIO vendor ID.
+    InvalidVendorId(u16),
     /// No valid `VIRTIO_PCI_CAP_COMMON_CFG` capability was found.
     MissingCommonConfig,
     /// No valid `VIRTIO_PCI_CAP_NOTIFY_CFG` capability was found.
@@ -382,6 +386,11 @@ pub enum VirtioPciError {
 impl Display for VirtioPciError {
     fn fmt(&self, f: &mut Formatter) -> fmt::Result {
         match self {
+            Self::InvalidVendorId(vendor_id) => write!(
+                f,
+                "PCI device vender ID {:#06x} was not the VirtIO vendor ID {:#06x}.",
+                vendor_id, VIRTIO_VENDOR_ID
+            ),
             Self::MissingCommonConfig => write!(
                 f,
                 "No valid `VIRTIO_PCI_CAP_COMMON_CFG` capability was found."