浏览代码

Use AtomicBool for fake queue notification status.

This prevents races between the thread notifying the queue and the
thread waiting for it to be notified.
Andrew Walbran 2 年之前
父节点
当前提交
2b545f5eb0
共有 4 个文件被更改,包括 30 次插入13 次删除
  1. 3 3
      src/device/blk.rs
  2. 2 2
      src/device/console.rs
  3. 10 2
      src/device/socket/vsock.rs
  4. 15 6
      src/transport/fake.rs

+ 3 - 3
src/device/blk.rs

@@ -501,7 +501,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -537,7 +537,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -610,7 +610,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 1],
+            queues: vec![QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,

+ 2 - 2
src/device/console.rs

@@ -261,7 +261,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,
@@ -309,7 +309,7 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 2],
+            queues: vec![QueueStatus::default(), QueueStatus::default()],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Console,

+ 10 - 2
src/device/socket/vsock.rs

@@ -583,7 +583,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,
@@ -615,7 +619,11 @@ mod tests {
             driver_features: 0,
             guest_page_size: 0,
             interrupt_pending: false,
-            queues: vec![QueueStatus::default(); 3],
+            queues: vec![
+                QueueStatus::default(),
+                QueueStatus::default(),
+                QueueStatus::default(),
+            ],
         }));
         let transport = FakeTransport {
             device_type: DeviceType::Socket,

+ 15 - 6
src/transport/fake.rs

@@ -4,7 +4,12 @@ use crate::{
     PhysAddr, Result,
 };
 use alloc::{sync::Arc, vec::Vec};
-use core::{any::TypeId, ptr::NonNull, time::Duration};
+use core::{
+    any::TypeId,
+    ptr::NonNull,
+    sync::atomic::{AtomicBool, Ordering},
+    time::Duration,
+};
 use std::{sync::Mutex, thread};
 
 /// A fake implementation of [`Transport`] for unit tests.
@@ -35,7 +40,9 @@ impl<C> Transport for FakeTransport<C> {
     }
 
     fn notify(&mut self, queue: u16) {
-        self.state.lock().unwrap().queues[queue as usize].notified = true;
+        self.state.lock().unwrap().queues[queue as usize]
+            .notified
+            .store(true, Ordering::SeqCst);
     }
 
     fn get_status(&self) -> DeviceStatus {
@@ -171,18 +178,20 @@ impl State {
 
     /// Waits until the given queue is notified.
     pub fn wait_until_queue_notified(state: &Mutex<Self>, queue_index: u16) {
-        while !state.lock().unwrap().queues[usize::from(queue_index)].notified {
+        while !state.lock().unwrap().queues[usize::from(queue_index)]
+            .notified
+            .swap(false, Ordering::SeqCst)
+        {
             thread::sleep(Duration::from_millis(10));
         }
-        state.lock().unwrap().queues[usize::from(queue_index)].notified = false;
     }
 }
 
-#[derive(Clone, Debug, Default, Eq, PartialEq)]
+#[derive(Debug, Default)]
 pub struct QueueStatus {
     pub size: u32,
     pub descriptors: PhysAddr,
     pub driver_area: PhysAddr,
     pub device_area: PhysAddr,
-    pub notified: bool,
+    pub notified: AtomicBool,
 }