
fix(mpsc): fix a deadlock in async send_ref (#20)

This fixes a deadlock issue in the async MPSC's `send_ref` method. The
deadlock occurs when a new waker needs to be registered for a task whose
wait node is already in the wait queue. Previously, the new waker would
not be registered, because the waker-registering closure was only called
when the node was being enqueued. If the node was already in the queue,
polling the future would never touch the waker. This means that if the
task was polled with a new waker, it would leave its old waker in the
queue, and might never be notified again.
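
For illustration, here is a minimal sketch of the pre-fix flow, using a
hypothetical simplified type rather than the real `SendRefFuture` and
`queue::Waiter`:

    use std::task::{Context, Poll, Waker};

    // Hypothetical stand-in for the wait node owned by the send future.
    struct OldSendFuture {
        queued: bool,
        waker_slot: Option<Waker>,
    }

    impl OldSendFuture {
        fn poll_wait(&mut self, cx: &mut Context<'_>) -> Poll<()> {
            if !self.queued {
                // The waker was registered *only* on the poll that enqueued
                // the node...
                self.waker_slot = Some(cx.waker().clone());
                self.queued = true;
            }
            // ...so on later polls the stale waker stayed in the queue, and a
            // task re-polled with a new waker might never be notified again.
            Poll::Pending
        }
    }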

This branch fixes that by separating pushing the wait node onto the queue
from registering the waker. We check whether the node already has a waker
prior to registering, and if it does, we don't push it again.
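
As a simplified, free-standing sketch of the new registration logic (the real
version runs inside the closure passed to `poll_send_ref`, under the
wait-queue lock; the function name here is illustrative):

    use std::task::Waker;

    // Returns true if the node should be pushed onto the wait queue, i.e.
    // this was the first registration for this node.
    fn register(slot: &mut Option<Waker>, my_waker: &Waker) -> bool {
        if let Some(existing) = slot.as_mut() {
            // Already registered: if the task was re-polled with a waker that
            // won't wake the stored one, swap in the new waker.
            if !existing.will_wake(my_waker) {
                *existing = my_waker.clone();
            }
            return false;
        }
        // First registration: store the waker and tell the caller to enqueue
        // the node.
        *slot = Some(my_waker.clone());
        true
    }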

Signed-off-by: Eliza Weisman <[email protected]>
Eliza Weisman 3 years ago
parent
commit
c58c620096
6 changed files with 116 additions and 117 deletions
  1. src/lib.rs (+1 -0)
  2. src/macros.rs (+1 -1)
  3. src/mpsc.rs (+4 -2)
  4. src/mpsc/async_impl.rs (+45 -31)
  5. src/mpsc/sync.rs (+1 -1)
  6. src/wait/queue.rs (+64 -82)

+ 1 - 0
src/lib.rs

@@ -131,6 +131,7 @@ impl Core {
     }
 
     fn close(&self) -> bool {
+        test_println!("Core::close");
         if std::thread::panicking() {
             return false;
         }

+ 1 - 1
src/macros.rs

@@ -21,7 +21,7 @@ macro_rules! test_dbg {
     ($e:expr) => {
         match $e {
             e => {
-                #[cfg(test)]
+                #[cfg(any(test, all(thingbuf_trace, feature = "std")))]
                 test_println!("{} = {:?}", stringify!($e), &e);
                 e
             }

+ 4 - 2
src/mpsc.rs

@@ -120,10 +120,11 @@ impl<T: Default, N: Notify + Unpin> Inner<T, N> {
     /// may yield, or might park the thread.
     fn poll_send_ref(
         &self,
-        mut node: Option<Pin<&mut queue::Waiter<N>>>,
+        node: Pin<&mut queue::Waiter<N>>,
         mut register: impl FnMut(&mut Option<N>),
     ) -> Poll<Result<SendRefInner<'_, T, N>, Closed>> {
         let mut backoff = Backoff::new();
+        let mut node = Some(node);
         // try to send a few times in a loop, in case the receiver notifies us
         // right before we park.
         loop {
@@ -136,7 +137,7 @@ impl<T: Default, N: Notify + Unpin> Inner<T, N> {
             }
 
             // try to push a waiter
-            let pushed_waiter = self.tx_wait.push_waiter(&mut node, &mut register);
+            let pushed_waiter = self.tx_wait.wait(&mut node, &mut register);
 
             match test_dbg!(pushed_waiter) {
                 WaitResult::Closed => {
@@ -187,6 +188,7 @@ impl<T: Default, N: Notify + Unpin> Inner<T, N> {
                     // just in case someone sent a message while we were
                     // registering the waiter.
                     try_poll_recv!();
+                    test_println!("-> yield");
                     return Poll::Pending;
                 }
                 WaitResult::Closed => {

+ 45 - 31
src/mpsc/async_impl.rs

@@ -77,7 +77,7 @@ impl<T: Default> Sender<T> {
         #[pin_project::pin_project(PinnedDrop)]
         struct SendRefFuture<'sender, T> {
             tx: &'sender Sender<T>,
-            has_been_queued: bool,
+            queued: bool,
             #[pin]
             waiter: queue::Waiter<Waker>,
         }
@@ -88,41 +88,55 @@ impl<T: Default> Sender<T> {
             fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
                 test_println!("SendRefFuture::poll({:p})", self);
                 // perform one send ref loop iteration
-
-                let this = self.as_mut().project();
-                let waiter = if test_dbg!(*this.has_been_queued) {
-                    None
-                } else {
-                    Some(this.waiter)
-                };
-                this.tx
-                    .inner
-                    .poll_send_ref(waiter, |waker| {
-                        // if this is called, we are definitely getting queued.
-                        *this.has_been_queued = true;
-
-                        // if the wait node does not already have a waker, or the task
-                        // has been polled with a waker that won't wake the previous
-                        // one, register a new waker.
+                let res = {
+                    let this = self.as_mut().project();
+                    this.tx.inner.poll_send_ref(this.waiter, |waker| {
                         let my_waker = cx.waker();
-                        // do we need to re-register?
-                        let will_wake = waker
-                            .as_ref()
-                            .map(|waker| test_dbg!(waker.will_wake(my_waker)))
-                            .unwrap_or(false);
 
-                        if test_dbg!(will_wake) {
+                        // If there's already a waker in the node, we might have
+                        // been woken spuriously for some reason. In that case,
+                        // make sure that the waker in the node will wake the
+                        // waker that was passed in on *this* poll --- the
+                        // future may have moved to another task or something!
+                        if let Some(waker) = waker.as_mut() {
+                            if test_dbg!(!waker.will_wake(my_waker)) {
+                                test_println!(
+                                    "poll_send_ref -> re-registering waker {:?}",
+                                    my_waker
+                                );
+                                *waker = my_waker.clone();
+                            }
                             return;
                         }
 
+                        // Otherwise, we are registering this task for the first
+                        // time.
+                        test_println!("poll_send_ref -> registering initial waker {:?}", my_waker);
                         *waker = Some(my_waker.clone());
+                        *this.queued = true;
                     })
-                    .map(|ok| {
-                        // avoid having to lock the list to remove a node that's
-                        // definitely not queued.
-                        *this.has_been_queued = false;
-                        ok.map(SendRef)
-                    })
+                };
+                res.map(|ready| {
+                    let this = self.as_mut().project();
+                    if test_dbg!(*this.queued) {
+                        // If the node was ever in the queue, we have to make
+                        // sure we're *absolutely certain* it isn't still in the
+                        // queue before we say it's okay to drop the node
+                        // without removing it from the linked list. Check to
+                        // make sure we were woken by the queue, and not by a
+                        // spurious wakeup.
+                        //
+                        // This means we *may* be a little bit aggressive about
+                        // locking the wait queue to make sure the node is
+                        // removed, but that's better than leaving dangling
+                        // pointers in the queue...
+                        *this.queued = test_dbg!(!this
+                            .waiter
+                            .was_woken_from_queue
+                            .swap(false, Ordering::AcqRel));
+                    }
+                    ready.map(SendRef)
+                })
             }
         }
 
@@ -130,7 +144,7 @@ impl<T: Default> Sender<T> {
         impl<T> PinnedDrop for SendRefFuture<'_, T> {
             fn drop(self: Pin<&mut Self>) {
                 test_println!("SendRefFuture::drop({:p})", self);
-                if test_dbg!(self.has_been_queued) {
+                if test_dbg!(self.queued) {
                     let this = self.project();
                     this.waiter.remove(&this.tx.inner.tx_wait)
                 }
@@ -139,7 +153,7 @@ impl<T: Default> Sender<T> {
 
         SendRefFuture {
             tx: self,
-            has_been_queued: false,
+            queued: false,
             waiter: queue::Waiter::new(),
         }
         .await

+ 1 - 1
src/mpsc/sync.rs

@@ -65,7 +65,7 @@ impl<T: Default> Sender<T> {
                 // be moved while this thread is parked.
                 Pin::new_unchecked(&mut waiter)
             };
-            if let Poll::Ready(result) = self.inner.poll_send_ref(Some(waiter), |thread| {
+            if let Poll::Ready(result) = self.inner.poll_send_ref(waiter, |thread| {
                 if thread.is_none() {
                     let current = thread::current();
                     test_println!("registering {:?}", current);

+ 64 - 82
src/wait/queue.rs

@@ -1,6 +1,6 @@
 use crate::{
     loom::{
-        atomic::{AtomicUsize, Ordering::*},
+        atomic::{AtomicBool, AtomicUsize, Ordering::*},
         cell::UnsafeCell,
     },
     util::{mutex::Mutex, CachePadded},
@@ -86,6 +86,7 @@ pub(crate) struct WaitQueue<T> {
 #[derive(Debug)]
 pub(crate) struct Waiter<T> {
     node: UnsafeCell<Node<T>>,
+    pub(crate) was_woken_from_queue: AtomicBool,
 }
 
 #[derive(Debug)]
@@ -108,63 +109,42 @@ struct List<T> {
     tail: Link<Waiter<T>>,
 }
 
-const CLOSED: usize = 1 << 0;
-const ONE_QUEUED: usize = 1 << 1;
+const CLOSED: usize = 1;
+const ONE_QUEUED: usize = 2;
+const EMPTY: usize = 0;
 
 impl<T: Notify + Unpin> WaitQueue<T> {
     pub(crate) fn new() -> Self {
         Self {
-            state: CachePadded(AtomicUsize::new(0)),
+            state: CachePadded(AtomicUsize::new(EMPTY)),
             list: Mutex::new(List::new()),
         }
     }
 
-    pub(crate) fn push_waiter(
+    pub(crate) fn wait(
         &self,
         waiter: &mut Option<Pin<&mut Waiter<T>>>,
         register: impl FnOnce(&mut Option<T>),
     ) -> WaitResult {
         test_println!("WaitQueue::push_waiter()");
 
-        let mut state = test_dbg!(self.state.load(Acquire));
-
         // First, go ahead and check if the queue has been closed. This is
         // necessary even if `waiter` is `None`, as the waiter may already be
         // queued, and just checking if the list was closed.
         // TODO(eliza): that actually kind of sucks lol...
-        if test_dbg!(state & CLOSED != 0) {
-            return WaitResult::Closed;
+        // Is there at least one queued notification assigned to the wait
+        // queue? If so, try to consume that now, rather than waiting.
+        match test_dbg!(self
+            .state
+            .compare_exchange(ONE_QUEUED, EMPTY, AcqRel, Acquire))
+        {
+            Ok(_) => return WaitResult::Notified,
+            Err(CLOSED) => return WaitResult::Closed,
+            Err(_state) => debug_assert_eq!(_state, EMPTY),
         }
 
         // If we were actually called with a real waiter, try to queue the node.
         if test_dbg!(waiter.is_some()) {
-            // Is there at least one queued notification assigned to the wait
-            // queue? If so, try to consume that now, rather than waiting.
-            while test_dbg!(state >= ONE_QUEUED) {
-                match test_dbg!(self.state.compare_exchange_weak(
-                    state,
-                    // Subtract one queued notification from the current state.
-                    state.saturating_sub(ONE_QUEUED),
-                    AcqRel,
-                    Acquire
-                )) {
-                    // We consumed a queued notification! Return `Notified`
-                    // now, so that we'll try our operation again, instead
-                    // of waiting.
-                    Ok(_) => return WaitResult::Notified,
-                    // Someone else was updating the state variable. Try again
-                    // --- but they may have closed the queue, or consumed the last
-                    // queued notification!
-                    Err(actual) => state = test_dbg!(actual),
-                }
-            }
-
-            // Okay, did the queue close while we were trying to consume a
-            // queued notification?
-            if test_dbg!(state & CLOSED != 0) {
-                return WaitResult::Closed;
-            }
-
             // There are no queued notifications to consume, and the queue is
             // still open. Therefore, it's time to actually push the waiter to
             // the queue...finally lol :)
@@ -175,36 +155,46 @@ impl<T: Notify + Unpin> WaitQueue<T> {
             // Okay, we have the lock...but what if someone changed the state
             // WHILE we were waiting to acquire the lock? isn't concurrent
             // programming great? :) :) :) :) :)
-            state = test_dbg!(self.state.load(Acquire));
             // Try to consume a queued notification *again* in case any were
             // assigned to the queue while we were waiting to acquire the lock.
-            while test_dbg!(state >= ONE_QUEUED) {
-                match test_dbg!(self.state.compare_exchange(
-                    state,
-                    state.saturating_sub(ONE_QUEUED),
-                    AcqRel,
-                    Acquire
-                )) {
-                    Ok(_) => return WaitResult::Notified,
-                    Err(actual) => state = actual,
-                }
+            match test_dbg!(self
+                .state
+                .compare_exchange(ONE_QUEUED, EMPTY, AcqRel, Acquire))
+            {
+                Ok(_) => return WaitResult::Notified,
+                Err(CLOSED) => return WaitResult::Closed,
+                Err(_state) => debug_assert_eq!(_state, EMPTY),
             }
 
             // We didn't consume a queued notification. it is now, finally, time
             // to actually put the waiter in the linked list. wasn't that fun?
 
             if let Some(waiter) = waiter.take() {
-                test_println!("WaitQueue::push_waiter -> pushing {:p}", waiter);
-
                 // Now that we have the lock, register the `Waker` or `Thread`
                 // to
-                unsafe {
+                let should_queue = unsafe {
+                    test_println!("WaitQueue::push_waiter -> registering {:p}", waiter);
                     // Safety: the waker can only be registered while holding
                     // the wait queue lock. We are holding the lock, so no one
                     // else will try to touch the waker until we're done.
-                    waiter.with_node(|node| register(&mut node.waiter));
+                    waiter.with_node(|node| {
+                        // Does the node need to be added to the wait queue? If
+                        // it currently has a waiter (prior to registering),
+                        // then we know it's already in the queue. Otherwise, if
+                        // it doesn't have a waiter, it is either waiting for
+                        // the first time, or it is re-registering after a
+                        // notification that it wasn't able to consume (for some
+                        // reason).
+                        let should_queue = node.waiter.is_none();
+                        register(&mut node.waiter);
+                        should_queue
+                    })
+                };
+                if test_dbg!(should_queue) {
+                    test_println!("WaitQueue::push_waiter -> pushing {:p}", waiter);
+                    test_dbg!(waiter.was_woken_from_queue.swap(false, AcqRel));
+                    list.push_front(waiter);
                 }
-                list.push_front(waiter);
             } else {
                 // XXX(eliza): in practice we can't ever get here because of the
                 // `if` above. this should probably be `unreachable_unchecked`
@@ -223,11 +213,16 @@ impl<T: Notify + Unpin> WaitQueue<T> {
     /// notification was assigned to the queue, returns `false`.
     pub(crate) fn notify(&self) -> bool {
         test_println!("WaitQueue::notify()");
-        if let Some(node) = test_dbg!(self.list.lock().pop_back()) {
+        let mut list = self.list.lock();
+        if let Some(node) = list.pop_back() {
+            drop(list);
+            test_println!("notifying {:?}", node);
             node.notify();
             true
         } else {
-            test_dbg!(self.state.fetch_add(ONE_QUEUED, Release));
+            test_println!("no waiters to notify...");
+            // This can be relaxed because we're holding the lock.
+            test_dbg!(self.state.store(ONE_QUEUED, Relaxed));
             false
         }
     }
@@ -235,7 +230,7 @@ impl<T: Notify + Unpin> WaitQueue<T> {
     /// Close the queue, notifying all waiting tasks.
     pub(crate) fn close(&self) {
         test_println!("WaitQueue::close()");
-        test_dbg!(self.state.fetch_or(CLOSED, Release));
+        test_dbg!(self.state.store(CLOSED, Release));
         let mut list = self.list.lock();
         while let Some(node) = list.pop_back() {
             node.notify();
@@ -254,18 +249,9 @@ impl<T: Notify> Waiter<T> {
                 waiter: None,
                 _pin: PhantomPinned,
             }),
+            was_woken_from_queue: AtomicBool::new(false),
         }
     }
-
-    #[inline]
-    fn notify(self: Pin<&mut Self>) -> bool {
-        let waker = unsafe { self.with_node(|node| node.waiter.take()) };
-        if let Some(waker) = waker {
-            waker.notify();
-            return true;
-        }
-        false
-    }
 }
 
 impl<T> Waiter<T> {
@@ -335,7 +321,7 @@ impl<T> List<T> {
         }
     }
 
-    fn pop_back(&mut self) -> Option<Pin<&mut Waiter<T>>> {
+    fn pop_back(&mut self) -> Option<T> {
         let mut last = self.tail?;
         test_println!("List::pop_back() -> {:p}", last);
 
@@ -352,8 +338,8 @@ impl<T> List<T> {
 
             self.tail = prev;
             last.take_next();
-
-            Some(Pin::new_unchecked(last))
+            last.was_woken_from_queue.store(true, Relaxed);
+            last.with_node(|node| node.waiter.take())
         }
     }
 
@@ -363,26 +349,22 @@ impl<T> List<T> {
         let next = node_ref.take_next();
         let ptr = NonNull::from(node_ref);
 
-        match prev {
-            Some(mut prev) => prev.as_mut().with_node(|prev| {
+        if let Some(mut prev) = prev {
+            prev.as_mut().with_node(|prev| {
                 debug_assert_eq!(prev.next, Some(ptr));
                 prev.next = next;
-            }),
-            None => {
-                debug_assert_eq!(self.head, Some(ptr));
-                self.head = next;
-            }
+            });
+        } else if self.head == Some(ptr) {
+            self.head = next;
         }
 
-        match next {
-            Some(mut next) => next.as_mut().with_node(|next| {
+        if let Some(mut next) = next {
+            next.as_mut().with_node(|next| {
                 debug_assert_eq!(next.prev, Some(ptr));
                 next.prev = prev;
-            }),
-            None => {
-                debug_assert_eq!(self.tail, Some(ptr));
-                self.tail = prev;
-            }
+            });
+        } else if self.tail == Some(ptr) {
+            self.tail = prev;
         }
     }
 }