Преглед изворни кода

Use AtomicBool for fake queue notification status.

This prevents races between the thread notifying the queue and the
thread waiting for it to be notified.
Andrew Walbran пре 2 година
родитељ
комит
2b545f5eb0
4 измењених фајлова са 30 додато и 13 уклоњено
  1. 3 3
      src/device/blk.rs
  2. 2 2
      src/device/console.rs
  3. 10 2
      src/device/socket/vsock.rs
  4. 15 6
      src/transport/fake.rs

+ 3 - 3
src/device/blk.rs

@@ -501,7 +501,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -537,7 +537,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -610,7 +610,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,

+ 2 - 2
src/device/console.rs

@@ -261,7 +261,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -309,7 +309,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,

+ 10 - 2
src/device/socket/vsock.rs

@@ -583,7 +583,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,
@@ -615,7 +619,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,

+ 15 - 6
src/transport/fake.rs

@@ -4,7 +4,12 @@ use crate::{
     PhysAddr, Result,
 };
 use alloc::{sync::Arc, vec::Vec};
-use core::{any::TypeId, ptr::NonNull, time::Duration};
+use core::{
+    any::TypeId,
+    ptr::NonNull,
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
 use std::{sync::Mutex, thread};
 
 /// A fake implementation of [`Transport`] for unit tests.
@@ -35,7 +40,9 @@ impl<C> Transport for FakeTransport<C> {
     }
 
     fn notify(&mut self, queue: u16) {
-        self.state.lock().unwrap().queues[queue as usize].notified = true;
+        self.state.lock().unwrap().queues[queue as usize]
+            .notified
+            .store(true, Ordering::SeqCst);
     }
 
     fn get_status(&self) -> DeviceStatus {
@@ -171,18 +178,20 @@ impl State {
 
     /// Waits until the given queue is notified.
     pub fn wait_until_queue_notified(state: &Mutex<Self>, queue_index: u16) {
-        while !state.lock().unwrap().queues[usize::from(queue_index)].notified {
+        while !state.lock().unwrap().queues[usize::from(queue_index)]
+            .notified
+            .swap(false, Ordering::SeqCst)
+        {
             thread::sleep(Duration::from_millis(10));
         }
-        state.lock().unwrap().queues[usize::from(queue_index)].notified = false;
     }
 }
 
-#[derive(Clone, Debug, Default, Eq, PartialEq)]
+#[derive(Debug, Default)]
 pub struct QueueStatus {
     pub size: u32,
     pub descriptors: PhysAddr,
     pub driver_area: PhysAddr,
     pub device_area: PhysAddr,
-    pub notified: bool,
+    pub notified: AtomicBool,
 }