Browse Source

perf(mpsc): replace bad VecDeque wait queue with intrusive list (#16)

The new wait queue implementation doesn't allocate, and doesn't require
resizing a `VecDeque` inside a lock. This should improve performance and
(hopefully) make it possible to use the MPSC queue without any
allocations on no-std (once I figure out the static stuff).

Performance for blocking/yielding sends is now very competitive after
switching to the new wait queue; compare the new violin plot of the integer
MPSC benchmark with the violin plot for the old wait queue that I posted
in #14.

Old'n'busted:
![image](https://user-images.githubusercontent.com/2796466/145252070-2cdda7bb-5cca-4b85-ab2d-082f3af5f990.png)

New hotness:
![image](https://user-images.githubusercontent.com/2796466/145482064-c0223f3c-feae-4f23-a1db-704a2c18eb16.png)

Big comparison benchmark:
![image](https://user-images.githubusercontent.com/2796466/145482883-a38f253c-17c8-4a0b-a798-67458c5fa27a.png)

Signed-off-by: Eliza Weisman <[email protected]>
Eliza Weisman 3 years ago
parent
commit
23f4c96fa4
14 changed files with 580 additions and 319 deletions
  1. 1 0
      Cargo.toml
  2. 2 0
      src/lib.rs
  3. 12 22
      src/mpsc.rs
  4. 103 8
      src/mpsc/async_impl.rs
  5. 16 1
      src/mpsc/sync.rs
  6. 9 7
      src/thingbuf/tests.rs
  7. 1 1
      src/util.rs
  8. 11 0
      src/util/mutex.rs
  9. 98 0
      src/util/mutex/spin_impl.rs
  10. 29 0
      src/util/mutex/std_impl.rs
  11. 0 248
      src/util/wait/wait_queue.rs
  12. 3 31
      src/wait.rs
  13. 2 1
      src/wait/cell.rs
  14. 293 0
      src/wait/queue.rs

+ 1 - 0
Cargo.toml

@@ -17,6 +17,7 @@ alloc = []
 default = ["std"]
 
 [dependencies]
+pin-project = "1"
 
 [dev-dependencies]
 tokio = { version = "1.14.0", features = ["rt", "rt-multi-thread", "macros", "sync"] }

+ 2 - 0
src/lib.rs

@@ -7,6 +7,7 @@ mod macros;
 
 mod loom;
 mod util;
+mod wait;
 
 feature! {
     #![feature = "alloc"]
@@ -266,6 +267,7 @@ impl Core {
                     return if test_dbg!(tail & self.closed != 0) {
                         Err(mpsc::TrySendError::Closed(()))
                     } else {
+                        test_println!("--> channel full!");
                         Err(mpsc::TrySendError::Full(()))
                     };
                 }

+ 12 - 22
src/mpsc.rs

@@ -12,18 +12,14 @@
 
 use crate::{
     loom::{atomic::AtomicUsize, hint},
-    util::{
-        wait::{Notify, WaitCell, WaitResult},
-        Backoff,
-    },
+    util::Backoff,
+    wait::{queue, Notify, WaitCell, WaitQueue, WaitResult},
     Ref, ThingBuf,
 };
 use core::fmt;
+use core::pin::Pin;
 use core::task::Poll;
 
-#[cfg(feature = "alloc")]
-use crate::util::wait::{NotifyOnDrop, WaitQueue};
-
 #[derive(Debug)]
 #[non_exhaustive]
 pub enum TrySendError<T = ()> {
@@ -39,8 +35,7 @@ struct Inner<T, N: Notify> {
     thingbuf: ThingBuf<T>,
     rx_wait: WaitCell<N>,
     tx_count: AtomicUsize,
-    #[cfg(feature = "alloc")]
-    tx_wait: WaitQueue<NotifyOnDrop<N>>,
+    tx_wait: WaitQueue<N>,
 }
 
 struct SendRefInner<'a, T, N: Notify> {
@@ -64,7 +59,7 @@ struct SendRefInner<'a, T, N: Notify> {
 }
 
 struct NotifyRx<'a, N: Notify>(&'a WaitCell<N>);
-struct NotifyTx<'a, N: Notify>(&'a WaitQueue<NotifyOnDrop<N>>);
+struct NotifyTx<'a, N: Notify + Unpin>(&'a WaitQueue<N>);
 
 // ==== impl TrySendError ===
 
@@ -78,13 +73,12 @@ impl TrySendError {
 }
 
 // ==== impl Inner ====
-impl<T, N: Notify> Inner<T, N> {
+impl<T, N: Notify + Unpin> Inner<T, N> {
     fn new(thingbuf: ThingBuf<T>) -> Self {
         Self {
             thingbuf,
             rx_wait: WaitCell::new(),
             tx_count: AtomicUsize::new(1),
-            #[cfg(feature = "alloc")]
             tx_wait: WaitQueue::new(),
         }
     }
@@ -93,12 +87,12 @@ impl<T, N: Notify> Inner<T, N> {
         if self.thingbuf.core.close() {
             crate::loom::hint::spin_loop();
             test_println!("draining_queue");
-            self.tx_wait.drain();
+            self.tx_wait.close();
         }
     }
 }
 
-impl<T: Default, N: Notify> Inner<T, N> {
+impl<T: Default, N: Notify + Unpin> Inner<T, N> {
     fn try_send_ref(&self) -> Result<SendRefInner<'_, T, N>, TrySendError> {
         self.thingbuf
             .core
@@ -126,7 +120,8 @@ impl<T: Default, N: Notify> Inner<T, N> {
     /// may yield, or might park the thread.
     fn poll_send_ref(
         &self,
-        mk_waiter: impl Fn() -> N,
+        mut node: Option<Pin<&mut queue::Waiter<N>>>,
+        mut register: impl FnMut(&mut Option<N>),
     ) -> Poll<Result<SendRefInner<'_, T, N>, Closed>> {
         let mut backoff = Backoff::new();
         // try to send a few times in a loop, in case the receiver notifies us
@@ -141,11 +136,7 @@ impl<T: Default, N: Notify> Inner<T, N> {
             }
 
             // try to push a waiter
-            let pushed_waiter = self.tx_wait.push_waiter(|| {
-                let current = mk_waiter();
-                test_println!("parking sender ({:?})", current);
-                NotifyOnDrop::new(current)
-            });
+            let pushed_waiter = self.tx_wait.push_waiter(&mut node, &mut register);
 
             match test_dbg!(pushed_waiter) {
                 WaitResult::TxClosed => {
@@ -196,7 +187,6 @@ impl<T: Default, N: Notify> Inner<T, N> {
                     // just in case someone sent a message while we were
                     // registering the waiter.
                     try_poll_recv!();
-                    test_dbg!(self.tx_wait.notify());
                     return Poll::Pending;
                 }
                 WaitResult::TxClosed => {
@@ -280,7 +270,7 @@ impl<N: Notify> Drop for NotifyRx<'_, N> {
     }
 }
 
-impl<N: Notify> Drop for NotifyTx<'_, N> {
+impl<N: Notify + Unpin> Drop for NotifyTx<'_, N> {
     #[inline]
     fn drop(&mut self) {
         test_println!("notifying tx ({})", core::any::type_name::<N>());

+ 103 - 8
src/mpsc/async_impl.rs

@@ -4,6 +4,7 @@ use crate::{
         atomic::{self, Ordering},
         sync::Arc,
     },
+    wait::queue,
     Ref, ThingBuf,
 };
 use core::{
@@ -73,22 +74,75 @@ impl<T: Default> Sender<T> {
     }
 
     pub async fn send_ref(&self) -> Result<SendRef<'_, T>, Closed> {
-        // This future is private because if we replace the waiter queue thing with an
-        // intrusive list, we won't want to expose the future type publicly, for safety reasons.
-        struct SendRefFuture<'sender, T>(&'sender Sender<T>);
+        #[pin_project::pin_project(PinnedDrop)]
+        struct SendRefFuture<'sender, T> {
+            tx: &'sender Sender<T>,
+            has_been_queued: bool,
+            #[pin]
+            waiter: queue::Waiter<Waker>,
+        }
+
         impl<'sender, T: Default + 'sender> Future for SendRefFuture<'sender, T> {
             type Output = Result<SendRef<'sender, T>, Closed>;
 
-            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+            fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
+                test_println!("SendRefFuture::poll({:p})", self);
                 // perform one send ref loop iteration
-                self.0
+
+                let this = self.as_mut().project();
+                let waiter = if test_dbg!(*this.has_been_queued) {
+                    None
+                } else {
+                    Some(this.waiter)
+                };
+                this.tx
                     .inner
-                    .poll_send_ref(|| cx.waker().clone())
-                    .map(|ok| ok.map(SendRef))
+                    .poll_send_ref(waiter, |waker| {
+                        // if this is called, we are definitely getting queued.
+                        *this.has_been_queued = true;
+
+                        // if the wait node does not already have a waker, or the task
+                        // has been polled with a waker that won't wake the previous
+                        // one, register a new waker.
+                        let my_waker = cx.waker();
+                        // do we need to re-register?
+                        let will_wake = waker
+                            .as_ref()
+                            .map(|waker| test_dbg!(waker.will_wake(my_waker)))
+                            .unwrap_or(false);
+
+                        if test_dbg!(will_wake) {
+                            return;
+                        }
+
+                        *waker = Some(my_waker.clone());
+                    })
+                    .map(|ok| {
+                        // avoid having to lock the list to remove a node that's
+                        // definitely not queued.
+                        *this.has_been_queued = false;
+                        ok.map(SendRef)
+                    })
             }
         }
 
-        SendRefFuture(self).await
+        #[pin_project::pinned_drop]
+        impl<T> PinnedDrop for SendRefFuture<'_, T> {
+            fn drop(self: Pin<&mut Self>) {
+                test_println!("SendRefFuture::drop({:p})", self);
+                if test_dbg!(self.has_been_queued) {
+                    let this = self.project();
+                    this.waiter.remove(&this.tx.inner.tx_wait)
+                }
+            }
+        }
+
+        SendRefFuture {
+            tx: self,
+            has_been_queued: false,
+            waiter: queue::Waiter::new(),
+        }
+        .await
     }
 
     pub async fn send(&self, val: T) -> Result<(), Closed<T>> {
@@ -205,3 +259,44 @@ impl<'a, T: Default> Future for RecvFuture<'a, T> {
         self.rx.poll_recv(cx)
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::ThingBuf;
+
+    fn _assert_sync<T: Sync>(_: T) {}
+    fn _assert_send<T: Send>(_: T) {}
+
+    #[test]
+    fn recv_ref_future_is_send() {
+        fn _compiles() {
+            let (_, rx) = channel::<usize>(ThingBuf::new(10));
+            _assert_send(rx.recv_ref());
+        }
+    }
+
+    #[test]
+    fn recv_ref_future_is_sync() {
+        fn _compiles() {
+            let (_, rx) = channel::<usize>(ThingBuf::new(10));
+            _assert_sync(rx.recv_ref());
+        }
+    }
+
+    #[test]
+    fn send_ref_future_is_send() {
+        fn _compiles() {
+            let (tx, _) = channel::<usize>(ThingBuf::new(10));
+            _assert_send(tx.send_ref());
+        }
+    }
+
+    #[test]
+    fn send_ref_future_is_sync() {
+        fn _compiles() {
+            let (tx, _) = channel::<usize>(ThingBuf::new(10));
+            _assert_sync(tx.send_ref());
+        }
+    }
+}

+ 16 - 1
src/mpsc/sync.rs

@@ -10,6 +10,7 @@ use crate::{
         sync::Arc,
         thread::{self, Thread},
     },
+    wait::queue,
     Ref, ThingBuf,
 };
 use core::fmt;
@@ -54,9 +55,23 @@ impl<T: Default> Sender<T> {
     }
 
     pub fn send_ref(&self) -> Result<SendRef<'_, T>, Closed> {
+        let mut waiter = queue::Waiter::new();
         loop {
             // perform one send ref loop iteration
-            if let Poll::Ready(result) = self.inner.poll_send_ref(thread::current) {
+
+            let waiter = unsafe {
+                // Safety: in this case, it's totally safe to pin the waiter, as
+                // it is owned uniquely by this function, and it cannot possibly
+                // be moved while this thread is parked.
+                Pin::new_unchecked(&mut waiter)
+            };
+            if let Poll::Ready(result) = self.inner.poll_send_ref(Some(waiter), |thread| {
+                if thread.is_none() {
+                    let current = thread::current();
+                    test_println!("registering {:?}", current);
+                    *thread = Some(current);
+                }
+            }) {
                 return result.map(SendRef);
             }
 

+ 9 - 7
src/thingbuf/tests.rs

@@ -97,13 +97,15 @@ fn linearizable() {
     fn thread(i: usize, q: &Arc<ThingBuf<usize>>) -> impl FnOnce() {
         let q = q.clone();
         move || {
-            while q
-                .push_ref()
-                .map(|mut val| {
-                    *val = i;
-                })
-                .is_err()
-            {}
+            let mut pushed = false;
+            while !pushed {
+                pushed = q
+                    .push_ref()
+                    .map(|mut val| {
+                        *val = i;
+                    })
+                    .is_ok();
+            }
 
             if let Some(mut val) = q.pop_ref() {
                 *val = 0;

+ 1 - 1
src/util.rs

@@ -4,8 +4,8 @@ use core::{
     ops::{Deref, DerefMut},
 };
 
+pub(crate) mod mutex;
 pub(crate) mod panic;
-pub(crate) mod wait;
 
 #[derive(Debug)]
 pub(crate) struct Backoff(u8);

+ 11 - 0
src/util/mutex.rs

@@ -0,0 +1,11 @@
+feature! {
+    #![feature = "std"]
+    pub(crate) use self::std_impl::*;
+    mod std_impl;
+}
+
+#[cfg(not(feature = "std"))]
+pub(crate) use self::spin_impl::*;
+
+#[cfg(any(not(feature = "std"), test))]
+mod spin_impl;

+ 98 - 0
src/util/mutex/spin_impl.rs

@@ -0,0 +1,98 @@
+#![cfg_attr(feature = "std", allow(dead_code))]
+use crate::{
+    loom::{
+        atomic::{AtomicBool, Ordering::*},
+        cell::{MutPtr, UnsafeCell},
+    },
+    util::Backoff,
+};
+use core::{fmt, ops};
+
+#[derive(Debug)]
+pub(crate) struct Mutex<T> {
+    locked: AtomicBool,
+    data: UnsafeCell<T>,
+}
+
+pub(crate) struct MutexGuard<'lock, T> {
+    locked: &'lock AtomicBool,
+    data: MutPtr<T>,
+}
+
+impl<T> Mutex<T> {
+    pub(crate) fn new(data: T) -> Self {
+        Self {
+            locked: AtomicBool::new(false),
+            data: UnsafeCell::new(data),
+        }
+    }
+
+    #[inline]
+    pub(crate) fn lock(&self) -> MutexGuard<'_, T> {
+        test_println!("locking {}...", core::any::type_name::<T>());
+        let mut backoff = Backoff::new();
+        while test_dbg!(self.locked.compare_exchange(false, true, AcqRel, Acquire)).is_err() {
+            while self.locked.load(Relaxed) {
+                backoff.spin_yield();
+            }
+        }
+
+        test_println!("-> locked {}!", core::any::type_name::<T>());
+        MutexGuard {
+            locked: &self.locked,
+            data: self.data.get_mut(),
+        }
+    }
+}
+
+impl<T> MutexGuard<'_, T> {
+    // this is factored out into its own function so that the debug impl is a
+    // little easier lol
+    #[inline]
+    fn get_ref(&self) -> &T {
+        unsafe {
+            // Safety: the mutex is locked, so we cannot create a concurrent
+            // mutable access, and we have a borrow on the lock's state boolean,
+            // so it will not be dropped while the guard exists.
+            &*self.data.deref()
+        }
+    }
+}
+
+impl<T> ops::Deref for MutexGuard<'_, T> {
+    type Target = T;
+    #[inline]
+    fn deref(&self) -> &T {
+        self.get_ref()
+    }
+}
+
+impl<T> ops::DerefMut for MutexGuard<'_, T> {
+    #[inline]
+    fn deref_mut(&mut self) -> &mut T {
+        unsafe {
+            // Safety: the mutex is locked, so we cannot create a concurrent
+            // mutable access, and we have a borrow on the lock's state boolean,
+            // so it will not be dropped while the guard exists.
+            &mut *self.data.deref()
+        }
+    }
+}
+
+impl<T> Drop for MutexGuard<'_, T> {
+    fn drop(&mut self) {
+        test_dbg!(self.locked.store(false, Release));
+        test_println!("unlocked!");
+    }
+}
+
+impl<T: fmt::Debug> fmt::Debug for MutexGuard<'_, T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        self.get_ref().fmt(f)
+    }
+}
+
+unsafe impl<T: Send> Send for Mutex<T> {}
+unsafe impl<T: Send> Sync for Mutex<T> {}
+unsafe impl<T: Send> Send for MutexGuard<'_, T> {}
+unsafe impl<T: Send> Sync for MutexGuard<'_, T> {}

+ 29 - 0
src/util/mutex/std_impl.rs

@@ -0,0 +1,29 @@
+#[cfg(all(test, loom))]
+use crate::loom::sync::Mutex as Inner;
+#[cfg(all(test, loom))]
+pub(crate) use crate::loom::sync::MutexGuard;
+
+#[cfg(not(all(test, loom)))]
+use std::sync::Mutex as Inner;
+
+#[cfg(not(all(test, loom)))]
+pub(crate) use std::sync::MutexGuard;
+
+use std::sync::PoisonError;
+
+#[derive(Debug)]
+pub(crate) struct Mutex<T>(Inner<T>);
+
+impl<T> Mutex<T> {
+    pub(crate) fn new(data: T) -> Self {
+        Self(Inner::new(data))
+    }
+
+    #[inline]
+    pub(crate) fn lock(&self) -> MutexGuard<'_, T> {
+        test_println!("locking {}...", core::any::type_name::<T>());
+        let lock = self.0.lock().unwrap_or_else(PoisonError::into_inner);
+        test_println!("-> locked {}!", core::any::type_name::<T>());
+        lock
+    }
+}

+ 0 - 248
src/util/wait/wait_queue.rs

@@ -1,248 +0,0 @@
-use super::{Notify, WaitResult};
-use crate::{
-    loom::{
-        atomic::{
-            AtomicUsize,
-            Ordering::{self, *},
-        },
-        cell::UnsafeCell,
-    },
-    util::{panic, Backoff, CachePadded},
-};
-use alloc::collections::VecDeque;
-use core::fmt;
-
-/// A mediocre wait queue, implemented as a spinlock around a `VecDeque` of
-/// waiters.
-// TODO(eliza): this can almost certainly be replaced with an intrusive list of
-// some kind, but crossbeam uses a spinlock + vec, so it's _probably_ fine...
-// XXX(eliza): the biggest downside of this is that it can't be used without
-// `liballoc`, which is sad for `no-std` async-await users...
-pub(crate) struct WaitQueue<T> {
-    locked: CachePadded<AtomicUsize>,
-    queue: UnsafeCell<VecDeque<T>>,
-}
-
-pub(crate) struct Locked<'a, T> {
-    queue: &'a WaitQueue<T>,
-    state: State,
-}
-
-#[derive(Copy, Clone)]
-struct State(usize);
-
-impl<T> WaitQueue<T> {
-    pub(crate) fn new() -> Self {
-        Self {
-            locked: CachePadded(AtomicUsize::new(State::UNLOCKED.0 | State::EMPTY.0)),
-            queue: UnsafeCell::new(VecDeque::new()),
-        }
-    }
-
-    fn compare_exchange_weak(
-        &self,
-        curr: State,
-        next: State,
-        success: Ordering,
-        failure: Ordering,
-    ) -> Result<State, State> {
-        let res = self
-            .locked
-            .compare_exchange_weak(curr.0, next.0, success, failure)
-            .map(State)
-            .map_err(State);
-        test_println!(
-            "self.state.compare_exchange_weak({:?}, {:?}, {:?}, {:?}) = {:?}",
-            curr,
-            next,
-            success,
-            failure,
-            res
-        );
-        res
-    }
-
-    fn fetch_clear(&self, state: State, order: Ordering) -> State {
-        let res = State(self.locked.fetch_and(!state.0, order));
-        test_println!(
-            "self.state.fetch_clear({:?}, {:?}) = {:?}",
-            state,
-            order,
-            res
-        );
-        res
-    }
-
-    fn lock(&self) -> Result<Locked<'_, T>, State> {
-        let mut backoff = Backoff::new();
-        let mut state = State(self.locked.load(Ordering::Relaxed));
-        loop {
-            test_dbg!(&state);
-            if test_dbg!(state.contains(State::CLOSED)) {
-                return Err(state);
-            }
-
-            if !test_dbg!(state.contains(State::LOCKED)) {
-                match self.compare_exchange_weak(
-                    state,
-                    State(state.0 | State::LOCKED.0),
-                    AcqRel,
-                    Acquire,
-                ) {
-                    Ok(_) => return Ok(Locked { queue: self, state }),
-                    Err(actual) => {
-                        state = actual;
-                        backoff.spin();
-                    }
-                }
-            } else {
-                state = State(self.locked.load(Ordering::Relaxed));
-                backoff.spin_yield();
-            }
-        }
-    }
-
-    pub(crate) fn push_waiter(&self, mk_waiter: impl FnOnce() -> T) -> WaitResult {
-        if let Ok(mut lock) = self.lock() {
-            if lock.state.queued() > 0 {
-                lock.state = lock.state.sub_queued();
-                return WaitResult::Notified;
-            }
-            lock.queue.queue.with_mut(|q| unsafe {
-                (*q).push_back(mk_waiter());
-            });
-            WaitResult::Wait
-        } else {
-            WaitResult::TxClosed
-        }
-    }
-
-    pub(crate) fn drain(&self) {
-        if let Ok(lock) = self.lock() {
-            // if test_dbg!(lock.state.contains(State::EMPTY)) {
-            //     return;
-            // }
-            lock.queue.queue.with_mut(|q| {
-                let waiters = unsafe { (*q).drain(..) };
-                for waiter in waiters {
-                    drop(waiter);
-                }
-            })
-        }
-    }
-}
-
-impl<T: Notify> WaitQueue<T> {
-    pub(crate) fn notify(&self) -> bool {
-        test_println!("notifying tx");
-
-        if let Ok(mut lock) = self.lock() {
-            return lock.notify();
-        }
-
-        false
-    }
-}
-
-impl<T> Drop for WaitQueue<T> {
-    fn drop(&mut self) {
-        self.drain();
-    }
-}
-
-impl<T> fmt::Debug for WaitQueue<T> {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.write_str("WaitQueue(..)")
-    }
-}
-
-impl<T: Notify> Locked<'_, T> {
-    fn notify(&mut self) -> bool {
-        // if test_dbg!(self.state.contains(State::EMPTY)) {
-        //     self.state = self.state.add_queued();
-        //     return false;
-        // }
-        self.queue.queue.with_mut(|q| {
-            let q = unsafe { &mut *q };
-            if let Some(waiter) = q.pop_front() {
-                waiter.notify();
-                if q.is_empty() {
-                    self.queue.fetch_clear(State::EMPTY, Release);
-                }
-                true
-            } else {
-                self.state = self.state.add_queued();
-                false
-            }
-        })
-    }
-
-    // TODO(eliza): future cancellation nonsense...
-    #[allow(dead_code)]
-    pub(crate) fn remove(&mut self, i: usize) -> Option<T> {
-        self.queue.queue.with_mut(|q| unsafe { (*q).remove(i) })
-    }
-}
-
-impl<T> Drop for Locked<'_, T> {
-    fn drop(&mut self) {
-        test_dbg!(State(self.queue.locked.swap(self.state.0, Release)));
-    }
-}
-
-impl<T: panic::UnwindSafe> panic::UnwindSafe for WaitQueue<T> {}
-impl<T: panic::RefUnwindSafe> panic::RefUnwindSafe for WaitQueue<T> {}
-unsafe impl<T: Send> Send for WaitQueue<T> {}
-unsafe impl<T: Send> Sync for WaitQueue<T> {}
-
-// === impl State ===
-
-impl State {
-    const UNLOCKED: Self = Self(0b00);
-    const LOCKED: Self = Self(0b01);
-    const EMPTY: Self = Self(0b10);
-    const CLOSED: Self = Self(0b100);
-
-    const FLAG_BITS: usize = Self::LOCKED.0 | Self::EMPTY.0 | Self::CLOSED.0;
-    const QUEUED_SHIFT: usize = Self::FLAG_BITS.trailing_ones() as usize;
-    const QUEUED_ONE: usize = 1 << Self::QUEUED_SHIFT;
-
-    fn queued(self) -> usize {
-        self.0 >> Self::QUEUED_SHIFT
-    }
-
-    fn add_queued(self) -> Self {
-        Self(self.0 + Self::QUEUED_ONE)
-    }
-
-    fn contains(self, Self(state): Self) -> bool {
-        self.0 & state == state
-    }
-
-    fn sub_queued(self) -> Self {
-        let flags = self.0 & Self::FLAG_BITS;
-        Self(self.0 & (!Self::FLAG_BITS).saturating_sub(Self::QUEUED_ONE) | flags)
-    }
-}
-
-impl fmt::Debug for State {
-    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
-        f.write_str("State(")?;
-        let mut has_flags = false;
-
-        fmt_bits!(self, f, has_flags, LOCKED, EMPTY, CLOSED);
-
-        if !has_flags {
-            f.write_str("UNLOCKED")?;
-        }
-
-        let queued = self.queued();
-        if queued > 0 {
-            write!(f, ", queued: {})", queued)?;
-        } else {
-            f.write_str(")")?;
-        }
-
-        Ok(())
-    }
-}

+ 3 - 31
src/util/wait.rs → src/wait.rs

@@ -1,14 +1,9 @@
 use crate::util::panic::UnwindSafe;
 use core::{fmt, task::Waker};
 
-mod wait_cell;
-pub(crate) use self::wait_cell::WaitCell;
-
-feature! {
-    #![feature = "alloc"]
-    pub(crate) mod wait_queue;
-    pub(crate) use self::wait_queue::WaitQueue;
-}
+mod cell;
+pub(crate) mod queue;
+pub(crate) use self::{cell::WaitCell, queue::WaitQueue};
 
 #[cfg(feature = "std")]
 use crate::loom::thread;
@@ -20,9 +15,6 @@ pub(crate) enum WaitResult {
     Notified,
 }
 
-#[derive(Debug)]
-pub(crate) struct NotifyOnDrop<T: Notify>(Option<T>);
-
 pub(crate) trait Notify: UnwindSafe + fmt::Debug {
     fn notify(self);
 }
@@ -41,23 +33,3 @@ impl Notify for Waker {
         self.wake();
     }
 }
-
-impl<T: Notify> NotifyOnDrop<T> {
-    pub(crate) fn new(notify: T) -> Self {
-        Self(Some(notify))
-    }
-}
-
-impl<T: Notify> Notify for NotifyOnDrop<T> {
-    fn notify(self) {
-        drop(self)
-    }
-}
-
-impl<T: Notify> Drop for NotifyOnDrop<T> {
-    fn drop(&mut self) {
-        if let Some(notify) = self.0.take() {
-            notify.notify();
-        }
-    }
-}

+ 2 - 1
src/util/wait/wait_cell.rs → src/wait/cell.rs

@@ -1,4 +1,3 @@
-use super::{Notify, WaitResult};
 use crate::{
     loom::{
         atomic::{
@@ -8,6 +7,7 @@ use crate::{
         cell::UnsafeCell,
     },
     util::panic::{self, RefUnwindSafe, UnwindSafe},
+    wait::{Notify, WaitResult},
 };
 use core::{fmt, ops};
 
@@ -37,6 +37,7 @@ pub(crate) struct WaitCell<T> {
 struct State(usize);
 
 // === impl WaitCell ===
+
 impl<T> WaitCell<T> {
     #[cfg(not(all(loom, test)))]
     pub(crate) const fn new() -> Self {

+ 293 - 0
src/wait/queue.rs

@@ -0,0 +1,293 @@
+use crate::{
+    loom::{
+        atomic::{AtomicUsize, Ordering::*},
+        cell::UnsafeCell,
+    },
+    util::{mutex::Mutex, CachePadded},
+    wait::{Notify, WaitResult},
+};
+
+use core::{fmt, marker::PhantomPinned, pin::Pin, ptr::NonNull};
+
+#[derive(Debug)]
+pub(crate) struct WaitQueue<T> {
+    state: CachePadded<AtomicUsize>,
+    list: Mutex<List<T>>,
+}
+
+#[derive(Debug)]
+pub(crate) struct Waiter<T> {
+    node: UnsafeCell<Node<T>>,
+}
+
+#[derive(Debug)]
+#[pin_project::pin_project]
+struct Node<T> {
+    next: Link<Waiter<T>>,
+    prev: Link<Waiter<T>>,
+    waiter: Option<T>,
+
+    // This type is !Unpin due to the heuristic from:
+    // <https://github.com/rust-lang/rust/pull/82834>
+    #[pin]
+    _pin: PhantomPinned,
+}
+
+type Link<T> = Option<NonNull<T>>;
+
+struct List<T> {
+    head: Link<Waiter<T>>,
+    tail: Link<Waiter<T>>,
+}
+
+const CLOSED: usize = 1 << 0;
+const ONE_QUEUED: usize = 1 << 1;
+
+impl<T: Notify + Unpin> WaitQueue<T> {
+    pub(crate) fn new() -> Self {
+        Self {
+            state: CachePadded(AtomicUsize::new(0)),
+            list: Mutex::new(List::new()),
+        }
+    }
+
+    pub(crate) fn push_waiter(
+        &self,
+        waiter: &mut Option<Pin<&mut Waiter<T>>>,
+        register: impl FnOnce(&mut Option<T>),
+    ) -> WaitResult {
+        test_println!("WaitQueue::push_waiter()");
+        let mut state = test_dbg!(self.state.load(Acquire));
+        if test_dbg!(state & CLOSED != 0) {
+            return WaitResult::TxClosed;
+        }
+
+        if test_dbg!(waiter.is_some()) {
+            while test_dbg!(state > CLOSED) {
+                match test_dbg!(self.state.compare_exchange_weak(
+                    state,
+                    state.saturating_sub(ONE_QUEUED),
+                    AcqRel,
+                    Acquire
+                )) {
+                    Ok(_) => return WaitResult::Notified,
+                    Err(actual) => state = test_dbg!(actual),
+                }
+            }
+
+            if test_dbg!(state & CLOSED != 0) {
+                return WaitResult::TxClosed;
+            }
+
+            let mut list = self.list.lock();
+            // Reload the state inside the lock.
+            state = test_dbg!(self.state.load(Acquire));
+            while test_dbg!(state >= ONE_QUEUED) {
+                match test_dbg!(self.state.compare_exchange(
+                    state,
+                    state.saturating_sub(ONE_QUEUED),
+                    AcqRel,
+                    Acquire
+                )) {
+                    Ok(_) => return WaitResult::Notified,
+                    Err(actual) => state = actual,
+                }
+            }
+
+            if let Some(waiter) = waiter.take() {
+                test_println!("WaitQueue::push_waiter -> pushing {:p}", waiter);
+
+                unsafe {
+                    // Safety: the waker can only be registered while holding
+                    // the wait queue lock. We are holding the lock, so no one
+                    // else will try to touch the waker until we're done.
+                    waiter.with_node(|node| register(&mut node.waiter));
+                }
+                list.push_front(waiter);
+            } else {
+                unreachable!("this could be unchecked...")
+            }
+        }
+
+        WaitResult::Wait
+    }
+
+    pub(crate) fn notify(&self) -> bool {
+        test_println!("WaitQueue::notify()");
+        if let Some(node) = test_dbg!(self.list.lock().pop_back()) {
+            node.notify();
+            true
+        } else {
+            test_dbg!(self.state.fetch_add(ONE_QUEUED, Release));
+            false
+        }
+    }
+
+    pub(crate) fn close(&self) {
+        test_println!("WaitQueue::close()");
+        test_dbg!(self.state.fetch_or(CLOSED, Release));
+        let mut list = self.list.lock();
+        while let Some(node) = list.pop_back() {
+            node.notify();
+        }
+    }
+}
+
+// === impl Waiter ===
+
+impl<T: Notify> Waiter<T> {
+    pub(crate) fn new() -> Self {
+        Self {
+            node: UnsafeCell::new(Node {
+                next: None,
+                prev: None,
+                waiter: None,
+                _pin: PhantomPinned,
+            }),
+        }
+    }
+
+    #[inline]
+    fn notify(self: Pin<&mut Self>) -> bool {
+        let waker = unsafe { self.with_node(|node| node.waiter.take()) };
+        if let Some(waker) = waker {
+            waker.notify();
+            return true;
+        }
+        false
+    }
+}
+
+impl<T> Waiter<T> {
+    unsafe fn with_node<U>(&self, f: impl FnOnce(&mut Node<T>) -> U) -> U {
+        self.node.with_mut(|node| f(&mut *node))
+    }
+
+    unsafe fn set_prev(&mut self, prev: Option<NonNull<Waiter<T>>>) {
+        self.node.with_mut(|node| (*node).prev = prev);
+    }
+
+    // unsafe fn set_next(&mut self, next: Option<NonNull<Waiter<T>>>) {
+    //     self.node.with_mut(|node| (*node).next = next);
+    // }
+
+    unsafe fn take_prev(&mut self) -> Option<NonNull<Waiter<T>>> {
+        self.node.with_mut(|node| (*node).prev.take())
+    }
+
+    unsafe fn take_next(&mut self) -> Option<NonNull<Waiter<T>>> {
+        self.node.with_mut(|node| (*node).next.take())
+    }
+}
+
+impl<T: Notify> Waiter<T> {
+    pub(crate) fn remove(self: Pin<&mut Self>, q: &WaitQueue<T>) {
+        test_println!("Waiter::remove({:p})", self);
+        unsafe {
+            // Safety: removing a node is unsafe even when the list is locked,
+            // because there's no way to guarantee that the node is part of
+            // *this* list. However, the potential callers of this method will
+            // never have access to any other linked lists, so we can just kind
+            // of assume that this is safe.
+            q.list.lock().remove(self);
+        }
+    }
+}
+
+unsafe impl<T: Send> Send for Waiter<T> {}
+unsafe impl<T: Send> Sync for Waiter<T> {}
+
+// === impl List ===
+
+impl<T> List<T> {
+    fn new() -> Self {
+        Self {
+            head: None,
+            tail: None,
+        }
+    }
+
+    fn push_front(&mut self, waiter: Pin<&mut Waiter<T>>) {
+        unsafe {
+            waiter.with_node(|node| {
+                node.next = self.head;
+                node.prev = None;
+            })
+        }
+
+        let ptr = unsafe { NonNull::from(Pin::into_inner_unchecked(waiter)) };
+
+        debug_assert_ne!(self.head, Some(ptr), "tried to push the same waiter twice!");
+        if let Some(mut head) = self.head.replace(ptr) {
+            unsafe {
+                head.as_mut().set_prev(Some(ptr));
+            }
+        }
+
+        if self.tail.is_none() {
+            self.tail = Some(ptr);
+        }
+    }
+
+    fn pop_back(&mut self) -> Option<Pin<&mut Waiter<T>>> {
+        let mut last = self.tail?;
+        test_println!("List::pop_back() -> {:p}", last);
+
+        unsafe {
+            let last = last.as_mut();
+            let prev = last.take_prev();
+
+            match prev {
+                Some(mut prev) => {
+                    let _ = prev.as_mut().take_next();
+                }
+                None => self.head = None,
+            }
+
+            self.tail = prev;
+            last.take_next();
+
+            Some(Pin::new_unchecked(last))
+        }
+    }
+
+    unsafe fn remove(&mut self, node: Pin<&mut Waiter<T>>) {
+        let node_ref = node.get_unchecked_mut();
+        let prev = node_ref.take_prev();
+        let next = node_ref.take_next();
+        let ptr = NonNull::from(node_ref);
+
+        match prev {
+            Some(mut prev) => prev.as_mut().with_node(|prev| {
+                debug_assert_eq!(prev.next, Some(ptr));
+                prev.next = next;
+            }),
+            None => {
+                debug_assert_eq!(self.head, Some(ptr));
+                self.head = next;
+            }
+        }
+
+        match next {
+            Some(mut next) => next.as_mut().with_node(|next| {
+                debug_assert_eq!(next.prev, Some(ptr));
+                next.prev = prev;
+            }),
+            None => {
+                debug_assert_eq!(self.tail, Some(ptr));
+                self.tail = prev;
+            }
+        }
+    }
+}
+
+impl<T> fmt::Debug for List<T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.debug_struct("List")
+            .field("head", &self.head)
+            .field("tail", &self.tail)
+            .finish()
+    }
+}
+
+unsafe impl<T: Send> Send for List<T> {}