ソースを参照

fix(mpsc): ensure un-received messages are dropped (#29)

This also adds loom leak checking tests. 

I also made `WaitQueue::close` into an RMW op to work around `loom`
not modeling `SeqCst` properly.
Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* fix(mpsc): ensure un-received messages are dropped

Signed-off-by: Eliza Weisman <eliza@buoyant.io>

* fix(mpsc): make `WaitQueue::close` an RMW

I *think* this only fails loom because it doesn't fully model SeqCst,
correctly...but making this a swap rather than a store ensures it's an
RMW op, which appears to fix the loom test where the close was missed by
a sender...

Signed-off-by: Eliza Weisman <eliza@buoyant.io>
Eliza Weisman 3 年 前
コミット
c444e50b8d

+ 36 - 0
.github/workflows/loom.yml

@@ -55,6 +55,42 @@ jobs:
           command: test
           args: --profile loom --lib -- mpsc_try_send_recv
 
+  async_rx_close_unconsumed:
+    name: "mpsc::rx_close_unconsumed"
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Install stable toolchain
+        uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+          components: rustfmt
+      - name: Run cargo test
+        uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --profile loom --lib -- mpsc::rx_close_unconsumed
+
+  sync_rx_close_unconsumed:
+    name: "sync::rx_close_unconsumed"
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Install stable toolchain
+        uses: actions-rs/toolchain@v1
+        with:
+          profile: minimal
+          toolchain: stable
+          override: true
+          components: rustfmt
+      - name: Run cargo test
+        uses: actions-rs/cargo@v1
+        with:
+          command: test
+          args: --profile loom --lib -- mpsc_sync::rx_close_unconsumed
+
   loom_mpsc_async:
     name: "mpsc"
     runs-on: ubuntu-latest

+ 41 - 1
src/lib.rs

@@ -1,7 +1,7 @@
 #![cfg_attr(docsrs, doc = include_str!("../README.md"))]
 #![cfg_attr(not(feature = "std"), no_std)]
 #![cfg_attr(docsrs, feature(doc_cfg, doc_auto_cfg))]
-use core::{cmp, fmt, mem::MaybeUninit, ops};
+use core::{cmp, fmt, mem::MaybeUninit, ops, ptr};
 
 #[macro_use]
 mod macros;
@@ -72,6 +72,8 @@ struct Core {
     idx_mask: usize,
     closed: usize,
     capacity: usize,
+    /// Set when dropping the slots in the ring buffer, to avoid potential double-frees.
+    has_dropped_slots: bool,
 }
 
 struct Slot<T> {
@@ -94,6 +96,8 @@ impl Core {
             closed,
             idx_mask,
             capacity,
+
+            has_dropped_slots: false,
         }
     }
 
@@ -111,6 +115,9 @@ impl Core {
             gen_mask,
             idx_mask,
             capacity,
+
+            #[cfg(debug_assertions)]
+            has_dropped_slots: false,
         }
     }
 
@@ -321,6 +328,39 @@ impl Core {
             }
         }
     }
+
+    fn drop_slots<T>(&mut self, slots: &mut [Slot<T>]) {
+        debug_assert!(
+            !self.has_dropped_slots,
+            "tried to drop slots twice! core={:#?}",
+            self
+        );
+        if self.has_dropped_slots {
+            return;
+        }
+
+        let tail = self.tail.load(SeqCst);
+        let (idx, gen) = self.idx_gen(tail);
+        let num_initialized = if gen > 0 { self.capacity() } else { idx };
+        for slot in &mut slots[..num_initialized] {
+            unsafe {
+                slot.value
+                    .with_mut(|value| ptr::drop_in_place((*value).as_mut_ptr()));
+            }
+        }
+
+        self.has_dropped_slots = true;
+    }
+}
+
+impl Drop for Core {
+    fn drop(&mut self) {
+        debug_assert!(
+            self.has_dropped_slots,
+            "tried to drop Core without dropping slots! core={:#?}",
+            self
+        );
+    }
 }
 
 // === impl Ref ===

+ 16 - 7
src/mpsc/async_impl.rs

@@ -595,16 +595,25 @@ impl<T> PinnedDrop for SendRefFuture<'_, T> {
     }
 }
 
-#[cfg(feature = "alloc")]
-impl<T> fmt::Debug for Inner<T> {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.debug_struct("Inner")
-            .field("core", &self.core)
-            .field("slots", &format_args!("Box<[..]>"))
-            .finish()
+feature! {
+    #![feature = "alloc"]
+    impl<T> fmt::Debug for Inner<T> {
+        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+            f.debug_struct("Inner")
+                .field("core", &self.core)
+                .field("slots", &format_args!("Box<[..]>"))
+                .finish()
+        }
+    }
+
+    impl<T> Drop for Inner<T> {
+        fn drop(&mut self) {
+            self.core.core.drop_slots(&mut self.slots[..])
+        }
     }
 }
 
+#[cfg(feature = "alloc")]
 #[cfg(test)]
 mod tests {
     use super::*;

+ 16 - 5
src/mpsc/sync.rs

@@ -10,6 +10,7 @@ use crate::{
         sync::Arc,
         thread::{self, Thread},
     },
+    util::Backoff,
     wait::queue,
     Ref,
 };
@@ -385,6 +386,12 @@ impl<T> fmt::Debug for Inner<T> {
     }
 }
 
+impl<T> Drop for Inner<T> {
+    fn drop(&mut self) {
+        self.core.core.drop_slots(&mut self.slots[..])
+    }
+}
+
 #[inline]
 fn recv_ref<'a, T: Default>(
     core: &'a ChannelCore<Thread>,
@@ -422,6 +429,7 @@ fn send_ref<'a, T: Default>(
     let mut waiter = queue::Waiter::new();
     let mut unqueued = true;
     let thread = thread::current();
+    let mut boff = Backoff::new();
     loop {
         let node = unsafe {
             // Safety: in this case, it's totally safe to pin the waiter, as
@@ -438,11 +446,14 @@ fn send_ref<'a, T: Default>(
 
         match wait {
             WaitResult::Closed => return Err(Closed(())),
-            WaitResult::Notified => match core.try_send_ref(slots.as_ref()) {
-                Ok(slot) => return Ok(SendRef(slot)),
-                Err(TrySendError::Closed(_)) => return Err(Closed(())),
-                _ => {}
-            },
+            WaitResult::Notified => {
+                boff.spin_yield();
+                match core.try_send_ref(slots.as_ref()) {
+                    Ok(slot) => return Ok(SendRef(slot)),
+                    Err(TrySendError::Closed(_)) => return Err(Closed(())),
+                    _ => {}
+                }
+            }
             WaitResult::Wait => {
                 unqueued = false;
                 thread::park();

+ 76 - 5
src/mpsc/tests/mpsc_async.rs

@@ -1,6 +1,6 @@
 use super::*;
 use crate::{
-    loom::{self, future, thread},
+    loom::{self, alloc::Track, future, thread},
     ThingBuf,
 };
 
@@ -13,18 +13,18 @@ fn mpsc_try_send_recv() {
         let p1 = {
             let tx = tx.clone();
             thread::spawn(move || {
-                *tx.try_send_ref().unwrap() = 1;
+                *tx.try_send_ref().unwrap() = Track::new(1);
             })
         };
         let p2 = thread::spawn(move || {
-            *tx.try_send_ref().unwrap() = 2;
-            *tx.try_send_ref().unwrap() = 3;
+            *tx.try_send_ref().unwrap() = Track::new(2);
+            *tx.try_send_ref().unwrap() = Track::new(3);
         });
 
         let mut vals = future::block_on(async move {
             let mut vals = Vec::new();
             while let Some(val) = rx.recv_ref().await {
-                vals.push(*val);
+                vals.push(*val.get_ref());
             }
             vals
         });
@@ -74,6 +74,77 @@ fn rx_closes() {
     })
 }
 
+#[test]
+fn rx_close_unconsumed_spsc() {
+    // Tests that messages that have not been consumed by the receiver are
+    // dropped when dropping the channel.
+    const MESSAGES: usize = 4;
+
+    loom::model(|| {
+        let (tx, rx) = channel(MESSAGES);
+
+        let consumer = thread::spawn(move || {
+            future::block_on(async move {
+                // recieve one message
+                let msg = rx.recv().await;
+                test_println!("recv {:?}", msg);
+                assert!(msg.is_some());
+                // drop the receiver...
+            })
+        });
+
+        future::block_on(async move {
+            let mut i = 1;
+            while let Ok(mut slot) = tx.send_ref().await {
+                test_println!("producer sending {}...", i);
+                *slot = Track::new(i);
+                i += 1;
+            }
+        });
+
+        consumer.join().unwrap();
+    })
+}
+
+#[test]
+#[ignore] // This is marked as `ignore` because it takes over an hour to run.
+fn rx_close_unconsumed_mpsc() {
+    const MESSAGES: usize = 2;
+
+    async fn do_producer(tx: Sender<Track<i32>>, n: usize) {
+        let mut i = 1;
+        while let Ok(mut slot) = tx.send_ref().await {
+            test_println!("producer {} sending {}...", n, i);
+            *slot = Track::new(i);
+            i += 1;
+        }
+    }
+
+    loom::model(|| {
+        let (tx, rx) = channel(MESSAGES);
+
+        let consumer = thread::spawn(move || {
+            future::block_on(async move {
+                // recieve one message
+                let msg = rx.recv().await;
+                test_println!("recv {:?}", msg);
+                assert!(msg.is_some());
+                // drop the receiver...
+            })
+        });
+
+        let producer = {
+            let tx = tx.clone();
+            thread::spawn(move || future::block_on(do_producer(tx, 1)))
+        };
+
+        future::block_on(do_producer(tx, 2));
+
+        producer.join().unwrap();
+        consumer.join().unwrap();
+    })
+}
+
 #[test]
 fn spsc_recv_then_send() {
     loom::model(|| {

+ 66 - 1
src/mpsc/tests/mpsc_sync.rs

@@ -1,6 +1,6 @@
 use super::*;
 use crate::{
-    loom::{self, thread},
+    loom::{self, alloc::Track, thread},
     ThingBuf,
 };
 
@@ -78,6 +78,71 @@ fn rx_closes() {
     })
 }
 
+#[test]
+fn rx_close_unconsumed_spsc() {
+    // Tests that messages that have not been consumed by the receiver are
+    // dropped when dropping the channel.
+    const MESSAGES: usize = 4;
+
+    loom::model(|| {
+        let (tx, rx) = sync::channel(MESSAGES);
+
+        let consumer = thread::spawn(move || {
+            // recieve one message
+            let msg = rx.recv();
+            test_println!("recv {:?}", msg);
+            assert!(msg.is_some());
+            // drop the receiver...
+        });
+
+        let mut i = 1;
+        while let Ok(mut slot) = tx.send_ref() {
+            test_println!("producer sending {}...", i);
+            *slot = Track::new(i);
+            i += 1;
+        }
+
+        consumer.join().unwrap();
+        drop(tx);
+    })
+}
+
+#[test]
+#[ignore] // This is marked as `ignore` because it takes over an hour to run.
+fn rx_close_unconsumed_mpsc() {
+    const MESSAGES: usize = 2;
+
+    fn do_producer(tx: sync::Sender<Track<i32>>, n: usize) -> impl FnOnce() + Send + Sync {
+        move || {
+            let mut i = 1;
+            while let Ok(mut slot) = tx.send_ref() {
+                test_println!("producer {} sending {}...", n, i);
+                *slot = Track::new(i);
+                i += 1;
+            }
+        }
+    }
+
+    loom::model(|| {
+        let (tx, rx) = sync::channel(MESSAGES);
+
+        let consumer = thread::spawn(move || {
+            // recieve one message
+            let msg = rx.recv();
+            test_println!("recv {:?}", msg);
+            assert!(msg.is_some());
+            // drop the receiver...
+        });
+
+        let producer = thread::spawn(do_producer(tx.clone(), 1));
+
+        do_producer(tx, 2)();
+
+        producer.join().unwrap();
+        consumer.join().unwrap();
+    })
+}
+
 #[test]
 fn spsc_recv_then_try_send() {
     loom::model(|| {

+ 7 - 16
src/static_thingbuf.rs

@@ -1,6 +1,5 @@
-use crate::loom::atomic::Ordering;
 use crate::{Core, Full, Ref, Slot};
-use core::{fmt, mem, ptr};
+use core::{fmt, mem};
 
 /// A statically allocated, fixed-size lock-free multi-producer multi-consumer
 /// queue.
@@ -486,20 +485,6 @@ impl<T: Default, const CAP: usize> StaticThingBuf<T, CAP> {
     }
 }
 
-impl<T, const CAP: usize> Drop for StaticThingBuf<T, CAP> {
-    fn drop(&mut self) {
-        let tail = self.core.tail.load(Ordering::SeqCst);
-        let (idx, gen) = self.core.idx_gen(tail);
-        let num_initialized = if gen > 0 { self.capacity() } else { idx };
-        for slot in &self.slots[..num_initialized] {
-            unsafe {
-                slot.value
-                    .with_mut(|value| ptr::drop_in_place((*value).as_mut_ptr()));
-            }
-        }
-    }
-}
-
 impl<T, const CAP: usize> fmt::Debug for StaticThingBuf<T, CAP> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         f.debug_struct("StaticThingBuf")
@@ -509,3 +494,9 @@ impl<T, const CAP: usize> fmt::Debug for StaticThingBuf<T, CAP> {
             .finish()
     }
 }
+
+impl<T, const CAP: usize> Drop for StaticThingBuf<T, CAP> {
+    fn drop(&mut self) {
+        self.core.drop_slots(&mut self.slots[..]);
+    }
+}

+ 2 - 11
src/thingbuf.rs

@@ -1,7 +1,6 @@
-use crate::loom::atomic::Ordering;
 use crate::{Core, Full, Ref, Slot};
 use alloc::boxed::Box;
-use core::{fmt, mem, ptr};
+use core::{fmt, mem};
 
 #[cfg(all(loom, test))]
 mod tests;
@@ -474,15 +473,7 @@ impl<T> ThingBuf<T> {
 
 impl<T> Drop for ThingBuf<T> {
     fn drop(&mut self) {
-        let tail = self.core.tail.load(Ordering::SeqCst);
-        let (idx, gen) = self.core.idx_gen(tail);
-        let num_initialized = if gen > 0 { self.capacity() } else { idx };
-        for slot in &self.slots[..num_initialized] {
-            unsafe {
-                slot.value
-                    .with_mut(|value| ptr::drop_in_place((*value).as_mut_ptr()));
-            }
-        }
+        self.core.drop_slots(&mut self.slots[..]);
     }
 }
 

+ 1 - 1
src/wait/queue.rs

@@ -437,7 +437,7 @@ impl<T: Notify + Unpin> WaitQueue<T> {
     pub(crate) fn close(&self) {
         test_println!("WaitQueue::close()");
 
-        test_dbg!(self.state.store(CLOSED, SeqCst));
+        test_dbg!(self.state.swap(CLOSED, SeqCst));
         let mut list = self.list.lock();
         while !list.is_empty() {
             if let Some(waiter) = list.dequeue(CLOSED) {