cell.rs 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389
  1. use crate::{
  2. loom::{
  3. atomic::{
  4. AtomicUsize,
  5. Ordering::{self, *},
  6. },
  7. cell::UnsafeCell,
  8. },
  9. util::panic::{self, RefUnwindSafe, UnwindSafe},
  10. wait::{Notify, WaitResult},
  11. };
  12. use core::{fmt, ops};
  /// An atomically registered waiter ([`Waker`] or [`Thread`]).
  ///
  /// This is inspired by the [`AtomicWaker`] type used in Tokio's
  /// synchronization primitives, with the following modifications:
  ///
  /// - Unlike [`AtomicWaker`], a `WaitCell` is generic over the type of the
  ///   waiting value. This means it can be used in both asynchronous code (with
  ///   [`Waker`]s), or in synchronous, multi-threaded code (with a [`Thread`]).
  /// - An additional bit of state is added to allow setting a "closed" bit. This
  ///   is so that closing a channel can be tracked in the same atomic as the
  ///   receiver's notification state, reducing the number of separate atomic
  ///   read-modify-write operations that must be synchronized.
  /// - A `WaitCell` is always woken by value. Separate "take waiter" and "wake"
  ///   steps were not needed by any of the uses in `ThingBuf`.
  ///
  /// [`AtomicWaker`]: https://github.com/tokio-rs/tokio/blob/09b770c5db31a1f35631600e1d239679354da2dd/tokio/src/sync/task/atomic_waker.rs
  /// [`Waker`]: core::task::Waker
  /// [`Thread`]: https://doc.rust-lang.org/std/thread/struct.Thread.html
  pub(crate) struct WaitCell<T> {
      /// Bitfield of `State` flags. Setting `PARKING` or `NOTIFYING` from the
      /// idle state also claims exclusive access to `waiter`.
      lock: AtomicUsize,
      /// The currently registered waiter, if any.
      waiter: UnsafeCell<Option<T>>,
  }
  /// Bitfield of `WaitCell` state flags, stored in the cell's `lock` word.
  #[derive(Clone, Copy, PartialEq, Eq)]
  struct State(usize);
  36. // === impl WaitCell ===
  37. impl<T> WaitCell<T> {
  38. #[cfg(not(all(loom, test)))]
  39. pub(crate) const fn new() -> Self {
  40. Self {
  41. lock: AtomicUsize::new(State::WAITING.0),
  42. waiter: UnsafeCell::new(None),
  43. }
  44. }
  45. #[cfg(all(loom, test))]
  46. pub(crate) fn new() -> Self {
  47. Self {
  48. lock: AtomicUsize::new(State::WAITING.0),
  49. waiter: UnsafeCell::new(None),
  50. }
  51. }
  52. }
  53. impl<T: Notify> WaitCell<T> {
  54. pub(crate) fn wait_with(&self, f: impl FnOnce() -> T) -> WaitResult {
  55. test_println!("registering waiter");
  56. // this is based on tokio's AtomicWaker synchronization strategy
  57. match test_dbg!(self.compare_exchange(State::WAITING, State::PARKING, Acquire)) {
  58. // someone else is notifying the receiver, so don't park!
  59. Err(actual) if test_dbg!(actual.contains(State::TX_CLOSED)) => {
  60. return WaitResult::Closed;
  61. }
  62. Err(actual) if test_dbg!(actual.contains(State::NOTIFYING)) => {
  63. f().notify();
  64. crate::loom::hint::spin_loop();
  65. return WaitResult::Notified;
  66. }
  67. Err(actual) => {
  68. debug_assert!(
  69. actual == State::PARKING || actual == State::PARKING | State::NOTIFYING
  70. );
  71. return WaitResult::Wait;
  72. }
  73. Ok(_) => {}
  74. }
  75. test_println!("-> locked!");
  76. let (panicked, prev_waiter) = match panic::catch_unwind(panic::AssertUnwindSafe(f)) {
  77. Ok(new_waiter) => {
  78. let new_waiter = test_dbg!(new_waiter);
  79. let prev_waiter = self
  80. .waiter
  81. .with_mut(|waiter| unsafe { (*waiter).replace(new_waiter) });
  82. (None, test_dbg!(prev_waiter))
  83. }
  84. Err(panic) => (Some(panic), None),
  85. };
  86. let result = match test_dbg!(self.compare_exchange(State::PARKING, State::WAITING, AcqRel))
  87. {
  88. Ok(_) => {
  89. let _ = panic::catch_unwind(move || drop(prev_waiter));
  90. WaitResult::Wait
  91. }
  92. Err(actual) => {
  93. test_println!("-> was notified; state={:?}", actual);
  94. let waiter = self.waiter.with_mut(|waiter| unsafe { (*waiter).take() });
  95. // Reset to the WAITING state by clearing everything *except*
  96. // the closed bits (which must remain set).
  97. let state = test_dbg!(self.fetch_and(State::TX_CLOSED, AcqRel));
  98. // The only valid state transition while we were parking is to
  99. // add the TX_CLOSED bit.
  100. debug_assert!(
  101. state == actual || state == actual | State::TX_CLOSED,
  102. "state changed unexpectedly while parking!"
  103. );
  104. if let Some(prev_waiter) = prev_waiter {
  105. let _ = panic::catch_unwind(move || {
  106. prev_waiter.notify();
  107. });
  108. }
  109. if let Some(waiter) = waiter {
  110. debug_assert!(panicked.is_none());
  111. waiter.notify();
  112. }
  113. if test_dbg!(state.contains(State::TX_CLOSED)) {
  114. WaitResult::Closed
  115. } else {
  116. WaitResult::Notified
  117. }
  118. }
  119. };
  120. if let Some(panic) = panicked {
  121. panic::resume_unwind(panic);
  122. }
  123. result
  124. }
  125. pub(crate) fn notify(&self) -> bool {
  126. self.notify2(false)
  127. }
  128. pub(crate) fn close_tx(&self) {
  129. self.notify2(true);
  130. }
  131. fn notify2(&self, close: bool) -> bool {
  132. test_println!("notifying; close={:?};", close);
  133. let bits = if close {
  134. State::NOTIFYING | State::TX_CLOSED
  135. } else {
  136. State::NOTIFYING
  137. };
  138. test_dbg!(bits);
  139. if test_dbg!(self.fetch_or(bits, AcqRel)) == State::WAITING {
  140. // we have the lock!
  141. let waiter = self.waiter.with_mut(|thread| unsafe { (*thread).take() });
  142. test_dbg!(self.fetch_and(!State::NOTIFYING, AcqRel));
  143. if let Some(waiter) = test_dbg!(waiter) {
  144. waiter.notify();
  145. return true;
  146. }
  147. }
  148. false
  149. }
  150. }
  151. impl<T> WaitCell<T> {
  152. #[inline(always)]
  153. fn compare_exchange(
  154. &self,
  155. State(curr): State,
  156. State(new): State,
  157. success: Ordering,
  158. ) -> Result<State, State> {
  159. self.lock
  160. .compare_exchange(curr, new, success, Acquire)
  161. .map(State)
  162. .map_err(State)
  163. }
  164. #[inline(always)]
  165. fn fetch_and(&self, State(state): State, order: Ordering) -> State {
  166. State(self.lock.fetch_and(state, order))
  167. }
  168. #[inline(always)]
  169. fn fetch_or(&self, State(state): State, order: Ordering) -> State {
  170. State(self.lock.fetch_or(state, order))
  171. }
  172. #[inline(always)]
  173. fn current_state(&self) -> State {
  174. State(self.lock.load(Acquire))
  175. }
  176. }
  177. impl<T: UnwindSafe> UnwindSafe for WaitCell<T> {}
  178. impl<T: RefUnwindSafe> RefUnwindSafe for WaitCell<T> {}
  179. unsafe impl<T: Send> Send for WaitCell<T> {}
  180. unsafe impl<T: Send> Sync for WaitCell<T> {}
  181. impl<T> fmt::Debug for WaitCell<T> {
  182. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  183. f.debug_struct("WaitCell")
  184. .field("state", &self.current_state())
  185. .finish()
  186. }
  187. }
  188. // === impl State ===
  189. impl State {
  190. const WAITING: Self = Self(0b00);
  191. const PARKING: Self = Self(0b01);
  192. const NOTIFYING: Self = Self(0b10);
  193. const TX_CLOSED: Self = Self(0b100);
  194. fn contains(self, Self(state): Self) -> bool {
  195. self.0 & state == state
  196. }
  197. }
  198. impl ops::BitOr for State {
  199. type Output = Self;
  200. fn bitor(self, Self(rhs): Self) -> Self::Output {
  201. Self(self.0 | rhs)
  202. }
  203. }
  204. impl ops::Not for State {
  205. type Output = Self;
  206. fn not(self) -> Self::Output {
  207. Self(!self.0)
  208. }
  209. }
  210. impl fmt::Debug for State {
  211. fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
  212. let mut has_states = false;
  213. fmt_bits!(self, f, has_states, PARKING, NOTIFYING, TX_CLOSED);
  214. if !has_states {
  215. if *self == Self::WAITING {
  216. return f.write_str("WAITING");
  217. }
  218. f.debug_tuple("UnknownState")
  219. .field(&format_args!("{:#b}", self.0))
  220. .finish()?;
  221. }
  222. Ok(())
  223. }
  224. }
  // The loom models share `Arc<Chan>` between threads, so the whole module is
  // gated on the `alloc` feature, which provides `alloc::sync::Arc`. Gating only
  // the import and the `#[test]`s would leave `wait_on` naming an unimported
  // `Arc`.
  #[cfg(all(loom, test, feature = "alloc"))]
  mod tests {
      use super::*;
      use crate::loom::{
          future,
          sync::atomic::{AtomicUsize, Ordering::Relaxed},
          thread,
      };
      use alloc::sync::Arc;
      use core::task::{Poll, Waker};

      /// State shared between the notifying threads and the waiting task.
      struct Chan {
          /// Number of notifying threads that have started notifying.
          num: AtomicUsize,
          /// The cell under test.
          task: WaitCell<Waker>,
      }

      /// Once `num` reaches this value, `wait_on` completes regardless of the
      /// `WaitResult` it observed.
      const NUM_NOTIFY: usize = 2;

      /// Registers the current task's waker with `chan.task` on every poll. It
      /// completes when the cell reports `Notified`/`Closed`, or when all
      /// notifiers have run.
      async fn wait_on(chan: Arc<Chan>) {
          futures_util::future::poll_fn(move |cx| {
              let res = test_dbg!(chan.task.wait_with(|| cx.waker().clone()));
              if NUM_NOTIFY == chan.num.load(Relaxed) {
                  return Poll::Ready(());
              }
              if res == WaitResult::Notified || res == WaitResult::Closed {
                  return Poll::Ready(());
              }
              Poll::Pending
          })
          .await
      }

      #[test]
      fn basic_notification() {
          crate::loom::model(|| {
              let chan = Arc::new(Chan {
                  num: AtomicUsize::new(0),
                  task: WaitCell::new(),
              });
              for _ in 0..NUM_NOTIFY {
                  let chan = chan.clone();
                  thread::spawn(move || {
                      chan.num.fetch_add(1, Relaxed);
                      chan.task.notify();
                  });
              }
              future::block_on(wait_on(chan));
          });
      }

      #[test]
      fn tx_close() {
          crate::loom::model(|| {
              let chan = Arc::new(Chan {
                  num: AtomicUsize::new(0),
                  task: WaitCell::new(),
              });
              thread::spawn({
                  let chan = chan.clone();
                  move || {
                      chan.num.fetch_add(1, Relaxed);
                      chan.task.notify();
                  }
              });
              thread::spawn({
                  let chan = chan.clone();
                  move || {
                      chan.num.fetch_add(1, Relaxed);
                      chan.task.close_tx();
                  }
              });
              future::block_on(wait_on(chan));
          });
      }

      // #[test]
      // #[cfg(feature = "std")]
      // fn test_panicky_waker() {
      //     use std::panic;
      //     use std::ptr;
      //     use std::task::{RawWaker, RawWakerVTable, Waker};
      //     static PANICKING_VTABLE: RawWakerVTable =
      //         RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ());
      //     let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) };
      //     loom::model(move || {
      //         let chan = Arc::new(Chan {
      //             num: AtomicUsize::new(0),
      //             task: WaitCell::new(),
      //         });
      //         for _ in 0..NUM_NOTIFY {
      //             let chan = chan.clone();
      //             thread::spawn(move || {
      //                 chan.num.fetch_add(1, Relaxed);
      //                 chan.task.notify();
      //             });
      //         }
      //         // Note: this panic should have no effect on the overall state of the
      //         // waker and it should proceed as normal.
      //         //
      //         // A thread above might race to flag a wakeup, and a WAKING state will
      //         // be preserved if this expected panic races with that so the below
      //         // procedure should be allowed to continue uninterrupted.
      //         let _ = panic::catch_unwind(|| chan.task.wait_with(|| panicking.clone()));
      //         future::block_on(wait_on(chan));
      //     });
      // }
  }