123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389 |
- use crate::{
- loom::{
- atomic::{
- AtomicUsize,
- Ordering::{self, *},
- },
- cell::UnsafeCell,
- },
- util::panic::{self, RefUnwindSafe, UnwindSafe},
- wait::{Notify, WaitResult},
- };
- use core::{fmt, ops};
- /// An atomically registered waiter ([`Waker`] or [`Thread`]).
- ///
- /// This is inspired by the [`AtomicWaker` type] used in Tokio's
- /// synchronization primitives, with the following modifications:
- ///
- /// - Unlike [`AtomicWaker`], a `WaitCell` is generic over the type of the
- /// waiting value. This means it can be used in both asynchronous code (with
- /// [`Waker`]s), or in synchronous, multi-threaded code (with a [`Thread`]).
- /// - An additional bit of state is added to allow setting a "close" bit. This
- /// is so that closing a channel can be tracked in the same atomic as the
- /// receiver's notification state, reducing the number of separate atomic RMW
- /// ops that have to be synchronized between.
- /// - A `WaitCell` is always woken by value. This is just because I didn't
- /// actually need separate "take waiter" and "wake" steps for any of the uses
- /// in `ThingBuf`...
- ///
- /// [`AtomicWaker`]: https://github.com/tokio-rs/tokio/blob/09b770c5db31a1f35631600e1d239679354da2dd/tokio/src/sync/task/atomic_waker.rs
- pub(crate) struct WaitCell<T> {
- lock: AtomicUsize,
- waiter: UnsafeCell<Option<T>>,
- }
- #[derive(Eq, PartialEq, Copy, Clone)]
- struct State(usize);
- // === impl WaitCell ===
- impl<T> WaitCell<T> {
- #[cfg(not(all(loom, test)))]
- pub(crate) const fn new() -> Self {
- Self {
- lock: AtomicUsize::new(State::WAITING.0),
- waiter: UnsafeCell::new(None),
- }
- }
- #[cfg(all(loom, test))]
- pub(crate) fn new() -> Self {
- Self {
- lock: AtomicUsize::new(State::WAITING.0),
- waiter: UnsafeCell::new(None),
- }
- }
- }
- impl<T: Notify> WaitCell<T> {
- pub(crate) fn wait_with(&self, f: impl FnOnce() -> T) -> WaitResult {
- test_println!("registering waiter");
- // this is based on tokio's AtomicWaker synchronization strategy
- match test_dbg!(self.compare_exchange(State::WAITING, State::PARKING, Acquire)) {
- // someone else is notifying the receiver, so don't park!
- Err(actual) if test_dbg!(actual.contains(State::TX_CLOSED)) => {
- return WaitResult::Closed;
- }
- Err(actual) if test_dbg!(actual.contains(State::NOTIFYING)) => {
- f().notify();
- crate::loom::hint::spin_loop();
- return WaitResult::Notified;
- }
- Err(actual) => {
- debug_assert!(
- actual == State::PARKING || actual == State::PARKING | State::NOTIFYING
- );
- return WaitResult::Wait;
- }
- Ok(_) => {}
- }
- test_println!("-> locked!");
- let (panicked, prev_waiter) = match panic::catch_unwind(panic::AssertUnwindSafe(f)) {
- Ok(new_waiter) => {
- let new_waiter = test_dbg!(new_waiter);
- let prev_waiter = self
- .waiter
- .with_mut(|waiter| unsafe { (*waiter).replace(new_waiter) });
- (None, test_dbg!(prev_waiter))
- }
- Err(panic) => (Some(panic), None),
- };
- let result = match test_dbg!(self.compare_exchange(State::PARKING, State::WAITING, AcqRel))
- {
- Ok(_) => {
- let _ = panic::catch_unwind(move || drop(prev_waiter));
- WaitResult::Wait
- }
- Err(actual) => {
- test_println!("-> was notified; state={:?}", actual);
- let waiter = self.waiter.with_mut(|waiter| unsafe { (*waiter).take() });
- // Reset to the WAITING state by clearing everything *except*
- // the closed bits (which must remain set).
- let state = test_dbg!(self.fetch_and(State::TX_CLOSED, AcqRel));
- // The only valid state transition while we were parking is to
- // add the TX_CLOSED bit.
- debug_assert!(
- state == actual || state == actual | State::TX_CLOSED,
- "state changed unexpectedly while parking!"
- );
- if let Some(prev_waiter) = prev_waiter {
- let _ = panic::catch_unwind(move || {
- prev_waiter.notify();
- });
- }
- if let Some(waiter) = waiter {
- debug_assert!(panicked.is_none());
- waiter.notify();
- }
- if test_dbg!(state.contains(State::TX_CLOSED)) {
- WaitResult::Closed
- } else {
- WaitResult::Notified
- }
- }
- };
- if let Some(panic) = panicked {
- panic::resume_unwind(panic);
- }
- result
- }
- pub(crate) fn notify(&self) -> bool {
- self.notify2(false)
- }
- pub(crate) fn close_tx(&self) {
- self.notify2(true);
- }
- fn notify2(&self, close: bool) -> bool {
- test_println!("notifying; close={:?};", close);
- let bits = if close {
- State::NOTIFYING | State::TX_CLOSED
- } else {
- State::NOTIFYING
- };
- test_dbg!(bits);
- if test_dbg!(self.fetch_or(bits, AcqRel)) == State::WAITING {
- // we have the lock!
- let waiter = self.waiter.with_mut(|thread| unsafe { (*thread).take() });
- test_dbg!(self.fetch_and(!State::NOTIFYING, AcqRel));
- if let Some(waiter) = test_dbg!(waiter) {
- waiter.notify();
- return true;
- }
- }
- false
- }
- }
- impl<T> WaitCell<T> {
- #[inline(always)]
- fn compare_exchange(
- &self,
- State(curr): State,
- State(new): State,
- success: Ordering,
- ) -> Result<State, State> {
- self.lock
- .compare_exchange(curr, new, success, Acquire)
- .map(State)
- .map_err(State)
- }
- #[inline(always)]
- fn fetch_and(&self, State(state): State, order: Ordering) -> State {
- State(self.lock.fetch_and(state, order))
- }
- #[inline(always)]
- fn fetch_or(&self, State(state): State, order: Ordering) -> State {
- State(self.lock.fetch_or(state, order))
- }
- #[inline(always)]
- fn current_state(&self) -> State {
- State(self.lock.load(Acquire))
- }
- }
- impl<T: UnwindSafe> UnwindSafe for WaitCell<T> {}
- impl<T: RefUnwindSafe> RefUnwindSafe for WaitCell<T> {}
- unsafe impl<T: Send> Send for WaitCell<T> {}
- unsafe impl<T: Send> Sync for WaitCell<T> {}
- impl<T> fmt::Debug for WaitCell<T> {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- f.debug_struct("WaitCell")
- .field("state", &self.current_state())
- .finish()
- }
- }
- // === impl State ===
- impl State {
- const WAITING: Self = Self(0b00);
- const PARKING: Self = Self(0b01);
- const NOTIFYING: Self = Self(0b10);
- const TX_CLOSED: Self = Self(0b100);
- fn contains(self, Self(state): Self) -> bool {
- self.0 & state == state
- }
- }
- impl ops::BitOr for State {
- type Output = Self;
- fn bitor(self, Self(rhs): Self) -> Self::Output {
- Self(self.0 | rhs)
- }
- }
- impl ops::Not for State {
- type Output = Self;
- fn not(self) -> Self::Output {
- Self(!self.0)
- }
- }
- impl fmt::Debug for State {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
- let mut has_states = false;
- fmt_bits!(self, f, has_states, PARKING, NOTIFYING, TX_CLOSED);
- if !has_states {
- if *self == Self::WAITING {
- return f.write_str("WAITING");
- }
- f.debug_tuple("UnknownState")
- .field(&format_args!("{:#b}", self.0))
- .finish()?;
- }
- Ok(())
- }
- }
- #[cfg(all(loom, test))]
- mod tests {
- use super::*;
- use crate::loom::{
- future,
- sync::atomic::{AtomicUsize, Ordering::Relaxed},
- thread,
- };
- #[cfg(feature = "alloc")]
- use alloc::sync::Arc;
- use core::task::{Poll, Waker};
- struct Chan {
- num: AtomicUsize,
- task: WaitCell<Waker>,
- }
- const NUM_NOTIFY: usize = 2;
- async fn wait_on(chan: Arc<Chan>) {
- futures_util::future::poll_fn(move |cx| {
- let res = test_dbg!(chan.task.wait_with(|| cx.waker().clone()));
- if NUM_NOTIFY == chan.num.load(Relaxed) {
- return Poll::Ready(());
- }
- if res == WaitResult::Notified || res == WaitResult::Closed {
- return Poll::Ready(());
- }
- Poll::Pending
- })
- .await
- }
- #[test]
- #[cfg(feature = "alloc")]
- fn basic_notification() {
- crate::loom::model(|| {
- let chan = Arc::new(Chan {
- num: AtomicUsize::new(0),
- task: WaitCell::new(),
- });
- for _ in 0..NUM_NOTIFY {
- let chan = chan.clone();
- thread::spawn(move || {
- chan.num.fetch_add(1, Relaxed);
- chan.task.notify();
- });
- }
- future::block_on(wait_on(chan));
- });
- }
- #[test]
- #[cfg(feature = "alloc")]
- fn tx_close() {
- crate::loom::model(|| {
- let chan = Arc::new(Chan {
- num: AtomicUsize::new(0),
- task: WaitCell::new(),
- });
- thread::spawn({
- let chan = chan.clone();
- move || {
- chan.num.fetch_add(1, Relaxed);
- chan.task.notify();
- }
- });
- thread::spawn({
- let chan = chan.clone();
- move || {
- chan.num.fetch_add(1, Relaxed);
- chan.task.close_tx();
- }
- });
- future::block_on(wait_on(chan));
- });
- }
- // #[test]
- // #[cfg(feature = "std")]
- // fn test_panicky_waker() {
- // use std::panic;
- // use std::ptr;
- // use std::task::{RawWaker, RawWakerVTable, Waker};
- // static PANICKING_VTABLE: RawWakerVTable =
- // RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ());
- // let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) };
- // loom::model(move || {
- // let chan = Arc::new(Chan {
- // num: AtomicUsize::new(0),
- // task: WaitCell::new(),
- // });
- // for _ in 0..NUM_NOTIFY {
- // let chan = chan.clone();
- // thread::spawn(move || {
- // chan.num.fetch_add(1, Relaxed);
- // chan.task.notify();
- // });
- // }
- // // Note: this panic should have no effect on the overall state of the
- // // waker and it should proceed as normal.
- // //
- // // A thread above might race to flag a wakeup, and a WAKING state will
- // // be preserved if this expected panic races with that so the below
- // // procedure should be allowed to continue uninterrupted.
- // let _ = panic::catch_unwind(|| chan.task.wait_with(|| panicking.clone()));
- // future::block_on(wait_on(chan));
- // });
- // }
- }
|