Explorar el Código

Use AtomicBool for fake queue notification status.

This prevents races between the thread notifying the queue and the
thread waiting for it to be notified.
Andrew Walbran hace 2 años
padre
commit
2b545f5eb0
Se han modificado 4 ficheros con 30 adiciones y 13 borrados
  1. 3 3
      src/device/blk.rs
  2. 2 2
      src/device/console.rs
  3. 10 2
      src/device/socket/vsock.rs
  4. 15 6
      src/transport/fake.rs

+ 3 - 3
src/device/blk.rs

@@ -501,7 +501,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -537,7 +537,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -610,7 +610,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,

+ 2 - 2
src/device/console.rs

@@ -261,7 +261,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -309,7 +309,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,

+ 10 - 2
src/device/socket/vsock.rs

@@ -583,7 +583,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,
@@ -615,7 +619,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,

+ 15 - 6
src/transport/fake.rs

@@ -4,7 +4,12 @@ use crate::{
     PhysAddr, Result,
 };
 use alloc::{sync::Arc, vec::Vec};
-use core::{any::TypeId, ptr::NonNull, time::Duration};
+use core::{
+    any::TypeId,
+    ptr::NonNull,
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
 use std::{sync::Mutex, thread};
 
 /// A fake implementation of [`Transport`] for unit tests.
@@ -35,7 +40,9 @@ impl<C> Transport for FakeTransport<C> {
     }
 
     fn notify(&mut self, queue: u16) {
-        self.state.lock().unwrap().queues[queue as usize].notified = true;
+        self.state.lock().unwrap().queues[queue as usize]
+            .notified
+            .store(true, Ordering::SeqCst);
     }
 
     fn get_status(&self) -> DeviceStatus {
@@ -171,18 +178,20 @@ impl State {
 
     /// Waits until the given queue is notified.
     pub fn wait_until_queue_notified(state: &Mutex<Self>, queue_index: u16) {
-        while !state.lock().unwrap().queues[usize::from(queue_index)].notified {
+        while !state.lock().unwrap().queues[usize::from(queue_index)]
+            .notified
+            .swap(false, Ordering::SeqCst)
+        {
             thread::sleep(Duration::from_millis(10));
         }
-        state.lock().unwrap().queues[usize::from(queue_index)].notified = false;
     }
 }
 
-#[derive(Clone, Debug, Default, Eq, PartialEq)]
+#[derive(Debug, Default)]
 pub struct QueueStatus {
     pub size: u32,
     pub descriptors: PhysAddr,
     pub driver_area: PhysAddr,
     pub device_area: PhysAddr,
-    pub notified: bool,
+    pub notified: AtomicBool,
 }