Parcourir la source

Use AtomicBool for fake queue notification status.

This prevents races between the thread notifying the queue and the
thread waiting for it to be notified.
Andrew Walbran il y a 2 ans
Parent
commit
2b545f5eb0
4 fichiers modifiés avec 30 ajouts et 13 suppressions
  1. 3 3
      src/device/blk.rs
  2. 2 2
      src/device/console.rs
  3. 10 2
      src/device/socket/vsock.rs
  4. 15 6
      src/transport/fake.rs

+ 3 - 3
src/device/blk.rs

@@ -501,7 +501,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -537,7 +537,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -610,7 +610,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,

+ 2 - 2
src/device/console.rs

@@ -261,7 +261,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -309,7 +309,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,

+ 10 - 2
src/device/socket/vsock.rs

@@ -583,7 +583,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,
@@ -615,7 +619,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,

+ 15 - 6
src/transport/fake.rs

@@ -4,7 +4,12 @@ use crate::{
     PhysAddr, Result,
 };
 use alloc::{sync::Arc, vec::Vec};
-use core::{any::TypeId, ptr::NonNull, time::Duration};
+use core::{
+    any::TypeId,
+    ptr::NonNull,
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
 use std::{sync::Mutex, thread};
 
 /// A fake implementation of [`Transport`] for unit tests.
@@ -35,7 +40,9 @@ impl<C> Transport for FakeTransport<C> {
     }
 
     fn notify(&mut self, queue: u16) {
-        self.state.lock().unwrap().queues[queue as usize].notified = true;
+        self.state.lock().unwrap().queues[queue as usize]
+            .notified
+            .store(true, Ordering::SeqCst);
     }
 
     fn get_status(&self) -> DeviceStatus {
@@ -171,18 +178,20 @@ impl State {
 
     /// Waits until the given queue is notified.
     pub fn wait_until_queue_notified(state: &Mutex<Self>, queue_index: u16) {
-        while !state.lock().unwrap().queues[usize::from(queue_index)].notified {
+        while !state.lock().unwrap().queues[usize::from(queue_index)]
+            .notified
+            .swap(false, Ordering::SeqCst)
+        {
             thread::sleep(Duration::from_millis(10));
         }
-        state.lock().unwrap().queues[usize::from(queue_index)].notified = false;
     }
 }
 
-#[derive(Clone, Debug, Default, Eq, PartialEq)]
+#[derive(Debug, Default)]
 pub struct QueueStatus {
     pub size: u32,
     pub descriptors: PhysAddr,
     pub driver_area: PhysAddr,
     pub device_area: PhysAddr,
-    pub notified: bool,
+    pub notified: AtomicBool,
 }