|
@@ -1,11 +1,14 @@
|
|
|
use crate::{
|
|
|
loom::{
|
|
|
- atomic::{AtomicUsize, Ordering::*},
|
|
|
+ atomic::{
|
|
|
+ AtomicUsize,
|
|
|
+ Ordering::{self, *},
|
|
|
+ },
|
|
|
UnsafeCell,
|
|
|
},
|
|
|
util::panic::{self, RefUnwindSafe, UnwindSafe},
|
|
|
};
|
|
|
-use core::{fmt, task::Waker};
|
|
|
+use core::{fmt, ops, task::Waker};
|
|
|
|
|
|
#[cfg(feature = "std")]
|
|
|
use crate::loom::thread;
|
|
@@ -43,52 +46,46 @@ pub(crate) trait Notify {
|
|
|
fn notify(self);
|
|
|
}
|
|
|
|
|
|
+#[derive(Eq, PartialEq, Copy, Clone)]
|
|
|
+struct State(usize);
|
|
|
+
|
|
|
// === impl WaitCell ===
|
|
|
|
|
|
impl<T: Notify + UnwindSafe + fmt::Debug> WaitCell<T> {
|
|
|
- const WAITING: usize = 0b00;
|
|
|
- const PARKING: usize = 0b01;
|
|
|
- const NOTIFYING: usize = 0b10;
|
|
|
- const TX_CLOSED: usize = 0b100;
|
|
|
- const RX_CLOSED: usize = 0b1000;
|
|
|
-
|
|
|
pub(crate) fn new() -> Self {
|
|
|
Self {
|
|
|
- lock: AtomicUsize::new(Self::WAITING),
|
|
|
+ lock: AtomicUsize::new(State::WAITING.0),
|
|
|
waiter: UnsafeCell::new(None),
|
|
|
}
|
|
|
}
|
|
|
|
|
|
pub(crate) fn close_rx(&self) {
|
|
|
- self.lock.fetch_or(Self::RX_CLOSED, AcqRel);
|
|
|
+ test_dbg!(self.fetch_or(State::RX_CLOSED, AcqRel));
|
|
|
}
|
|
|
|
|
|
pub(crate) fn is_rx_closed(&self) -> bool {
|
|
|
- test_dbg!(self.lock.load(Acquire) & Self::RX_CLOSED == Self::RX_CLOSED)
|
|
|
+ test_dbg!(self.current_state().contains(State::RX_CLOSED))
|
|
|
}
|
|
|
|
|
|
pub(crate) fn wait_with(&self, f: impl FnOnce() -> T) -> WaitResult {
|
|
|
test_println!("registering waiter");
|
|
|
|
|
|
// this is based on tokio's AtomicWaker synchronization strategy
|
|
|
- match test_dbg!(self
|
|
|
- .lock
|
|
|
- .compare_exchange(Self::WAITING, Self::PARKING, AcqRel, Acquire,))
|
|
|
- {
|
|
|
+ match test_dbg!(self.compare_exchange(State::WAITING, State::PARKING)) {
|
|
|
// someone else is notifying the receiver, so don't park!
|
|
|
- Err(actual) if test_dbg!(actual & Self::TX_CLOSED) == Self::TX_CLOSED => {
|
|
|
- test_println!("-> state = TX_CLOSED");
|
|
|
+ Err(actual) if test_dbg!(actual.contains(State::TX_CLOSED)) => {
|
|
|
return WaitResult::TxClosed;
|
|
|
}
|
|
|
- Err(actual) if test_dbg!(actual & Self::NOTIFYING) == Self::NOTIFYING => {
|
|
|
- test_println!("-> state = NOTIFYING");
|
|
|
+ Err(actual) if test_dbg!(actual.contains(State::NOTIFYING)) => {
|
|
|
// f().notify();
|
|
|
// loom::hint::spin_loop();
|
|
|
return WaitResult::Notified;
|
|
|
}
|
|
|
|
|
|
Err(actual) => {
|
|
|
- debug_assert!(actual == Self::PARKING || actual == Self::PARKING | Self::NOTIFYING);
|
|
|
+ debug_assert!(
|
|
|
+ actual == State::PARKING || actual == State::PARKING | State::NOTIFYING
|
|
|
+ );
|
|
|
return WaitResult::Wait;
|
|
|
}
|
|
|
Ok(_) => {}
|
|
@@ -106,29 +103,22 @@ impl<T: Notify + UnwindSafe + fmt::Debug> WaitCell<T> {
|
|
|
Err(panic) => (Some(panic), None),
|
|
|
};
|
|
|
|
|
|
- let result = match test_dbg!(self.lock.compare_exchange(
|
|
|
- Self::PARKING,
|
|
|
- Self::WAITING,
|
|
|
- AcqRel,
|
|
|
- Acquire
|
|
|
- )) {
|
|
|
+ let result = match test_dbg!(self.compare_exchange(State::PARKING, State::WAITING)) {
|
|
|
Ok(_) => {
|
|
|
let _ = panic::catch_unwind(move || drop(prev_waiter));
|
|
|
|
|
|
WaitResult::Wait
|
|
|
}
|
|
|
Err(actual) => {
|
|
|
- test_println!("-> was notified; state={:#b}", actual);
|
|
|
+ test_println!("-> was notified; state={:?}", actual);
|
|
|
let waiter = self.waiter.with_mut(|waiter| unsafe { (*waiter).take() });
|
|
|
// Reset to the WAITING state by clearing everything *except*
|
|
|
// the closed bits (which must remain set).
|
|
|
- let state = test_dbg!(self
|
|
|
- .lock
|
|
|
- .fetch_and(Self::TX_CLOSED | Self::RX_CLOSED, AcqRel));
|
|
|
+ let state = test_dbg!(self.fetch_and(State::TX_CLOSED | State::RX_CLOSED, AcqRel));
|
|
|
// The only valid state transition while we were parking is to
|
|
|
// add the TX_CLOSED bit.
|
|
|
debug_assert!(
|
|
|
- state == actual || state == actual | Self::TX_CLOSED,
|
|
|
+ state == actual || state == actual | State::TX_CLOSED,
|
|
|
"state changed unexpectedly while parking!"
|
|
|
);
|
|
|
|
|
@@ -143,7 +133,7 @@ impl<T: Notify + UnwindSafe + fmt::Debug> WaitCell<T> {
|
|
|
waiter.notify();
|
|
|
}
|
|
|
|
|
|
- if state & Self::TX_CLOSED == Self::TX_CLOSED {
|
|
|
+ if test_dbg!(state.contains(State::TX_CLOSED)) {
|
|
|
WaitResult::TxClosed
|
|
|
} else {
|
|
|
WaitResult::Notified
|
|
@@ -169,15 +159,16 @@ impl<T: Notify + UnwindSafe + fmt::Debug> WaitCell<T> {
|
|
|
fn notify2(&self, close: bool) {
|
|
|
test_println!("notifying; close={:?};", close);
|
|
|
let bits = if close {
|
|
|
- Self::NOTIFYING | Self::TX_CLOSED
|
|
|
+ State::NOTIFYING | State::TX_CLOSED
|
|
|
} else {
|
|
|
- Self::NOTIFYING
|
|
|
+ State::NOTIFYING
|
|
|
};
|
|
|
- if test_dbg!(self.lock.fetch_or(bits, AcqRel)) == Self::WAITING {
|
|
|
+ test_dbg!(bits);
|
|
|
+ if test_dbg!(self.fetch_or(bits, AcqRel)) == State::WAITING {
|
|
|
// we have the lock!
|
|
|
let waiter = self.waiter.with_mut(|thread| unsafe { (*thread).take() });
|
|
|
|
|
|
- self.lock.fetch_and(!Self::NOTIFYING, Release);
|
|
|
+ test_dbg!(self.fetch_and(!State::NOTIFYING, AcqRel));
|
|
|
|
|
|
if let Some(waiter) = test_dbg!(waiter) {
|
|
|
waiter.notify();
|
|
@@ -186,6 +177,31 @@ impl<T: Notify + UnwindSafe + fmt::Debug> WaitCell<T> {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+impl<T> WaitCell<T> {
|
|
|
+ #[inline(always)]
|
|
|
+ fn compare_exchange(&self, State(curr): State, State(new): State) -> Result<State, State> {
|
|
|
+ self.lock
|
|
|
+ .compare_exchange(curr, new, AcqRel, Acquire)
|
|
|
+ .map(State)
|
|
|
+ .map_err(State)
|
|
|
+ }
|
|
|
+
|
|
|
+ #[inline(always)]
|
|
|
+ fn fetch_and(&self, State(state): State, order: Ordering) -> State {
|
|
|
+ State(self.lock.fetch_and(state, order))
|
|
|
+ }
|
|
|
+
|
|
|
+ #[inline(always)]
|
|
|
+ fn fetch_or(&self, State(state): State, order: Ordering) -> State {
|
|
|
+ State(self.lock.fetch_or(state, order))
|
|
|
+ }
|
|
|
+
|
|
|
+ #[inline(always)]
|
|
|
+ fn current_state(&self) -> State {
|
|
|
+ State(self.lock.load(Acquire))
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
#[cfg(feature = "std")]
|
|
|
impl Notify for thread::Thread {
|
|
|
fn notify(self) {
|
|
@@ -206,6 +222,78 @@ impl<T: RefUnwindSafe> RefUnwindSafe for WaitCell<T> {}
|
|
|
unsafe impl<T: Send> Send for WaitCell<T> {}
|
|
|
unsafe impl<T: Send> Sync for WaitCell<T> {}
|
|
|
|
|
|
+impl<T> fmt::Debug for WaitCell<T> {
|
|
|
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
|
+ f.debug_struct("WaitCell")
|
|
|
+ .field("state", &self.current_state())
|
|
|
+ .finish()
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+// === impl State ===
|
|
|
+
|
|
|
+impl State {
|
|
|
+ const WAITING: Self = Self(0b00);
|
|
|
+ const PARKING: Self = Self(0b01);
|
|
|
+ const NOTIFYING: Self = Self(0b10);
|
|
|
+ const TX_CLOSED: Self = Self(0b100);
|
|
|
+ const RX_CLOSED: Self = Self(0b1000);
|
|
|
+
|
|
|
+ fn contains(self, Self(state): Self) -> bool {
|
|
|
+ self.0 & state == state
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl ops::BitOr for State {
|
|
|
+ type Output = Self;
|
|
|
+
|
|
|
+ fn bitor(self, Self(rhs): Self) -> Self::Output {
|
|
|
+ Self(self.0 | rhs)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl ops::Not for State {
|
|
|
+ type Output = Self;
|
|
|
+
|
|
|
+ fn not(self) -> Self::Output {
|
|
|
+ Self(!self.0)
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
+impl fmt::Debug for State {
|
|
|
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
|
+ let mut has_states = false;
|
|
|
+ macro_rules! f_bits {
|
|
|
+ ($self: expr, $f: expr, $has_states: ident, $($name: ident),+) => {
|
|
|
+ $(
|
|
|
+ if $self.contains(Self::$name) {
|
|
|
+ if $has_states {
|
|
|
+ $f.write_str(" | ")?;
|
|
|
+ }
|
|
|
+ $f.write_str(stringify!($name))?;
|
|
|
+ $has_states = true;
|
|
|
+ }
|
|
|
+ )+
|
|
|
+
|
|
|
+ };
|
|
|
+ }
|
|
|
+
|
|
|
+ f_bits!(self, f, has_states, PARKING, NOTIFYING, TX_CLOSED, RX_CLOSED);
|
|
|
+
|
|
|
+ if !has_states {
|
|
|
+ if *self == Self::WAITING {
|
|
|
+ return f.write_str("WAITING");
|
|
|
+ }
|
|
|
+
|
|
|
+ f.debug_tuple("UnknownState")
|
|
|
+ .field(&format_args!("{:#b}", self.0))
|
|
|
+ .finish()?;
|
|
|
+ }
|
|
|
+
|
|
|
+ Ok(())
|
|
|
+ }
|
|
|
+}
|
|
|
+
|
|
|
#[cfg(test)]
|
|
|
mod tests {
|
|
|
use super::*;
|