wait.rs 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. use crate::{
  2. loom::{
  3. atomic::{AtomicUsize, Ordering::*},
  4. UnsafeCell,
  5. },
  6. util::panic::{self, RefUnwindSafe, UnwindSafe},
  7. };
  8. use core::{fmt, task::Waker};
  9. #[cfg(feature = "std")]
  10. use crate::loom::thread;
  11. /// An atomically registered waiter ([`Waker`] or [`Thread`]).
  12. ///
  13. /// This is inspired by the [`AtomicWaker` type] used in Tokio's
  14. /// synchronization primitives, with the following modifications:
  15. ///
  16. /// - Unlike [`AtomicWaker`], a `WaitCell` is generic over the type of the
  17. /// waiting value. This means it can be used in both asynchronous code (with
  18. /// [`Waker`]s), or in synchronous, multi-threaded code (with a [`Thread`]).
  19. /// - An additional bit of state is added to allow setting a "close" bit. This
  20. /// is so that closing a channel can be tracked in the same atomic as the
  21. /// receiver's notification state, reducing the number of separate atomic RMW
  22. /// ops that have to be synchronized between.
  23. /// - A `WaitCell` is always woken by value. This is just because I didn't
  24. /// actually need separate "take waiter" and "wake" steps for any of the uses
  25. /// in `ThingBuf`...
  26. ///
  27. /// [`AtomicWaker`]: https://github.com/tokio-rs/tokio/blob/09b770c5db31a1f35631600e1d239679354da2dd/tokio/src/sync/task/atomic_waker.rs
  28. pub(crate) struct WaitCell<T> {
  29. lock: AtomicUsize,
  30. waiter: UnsafeCell<Option<T>>,
  31. }
  /// What a caller of `WaitCell::wait_with` should do next.
  #[derive(PartialEq, Eq, Debug)]
  pub(crate) enum WaitResult {
      /// Keep waiting: a waiter is registered (or another registration is in
      /// progress) and will be notified later.
      Wait,
      /// A notification has already arrived, so there is no need to wait.
      Notified,
      /// The sending side has closed.
      TxClosed,
  }
  /// A waiter that can be stored in a `WaitCell` and woken up.
  ///
  /// Waking consumes the value, so any given waiter is notified at most once.
  pub(crate) trait Notify {
      /// Wakes up whoever is represented by `self`.
      fn notify(self);
  }
  41. // === impl WaitCell ===
  42. impl<T: Notify + UnwindSafe + fmt::Debug> WaitCell<T> {
  43. const WAITING: usize = 0b00;
  44. const PARKING: usize = 0b01;
  45. const NOTIFYING: usize = 0b10;
  46. const TX_CLOSED: usize = 0b100;
  47. const RX_CLOSED: usize = 0b1000;
  48. pub(crate) fn new() -> Self {
  49. Self {
  50. lock: AtomicUsize::new(Self::WAITING),
  51. waiter: UnsafeCell::new(None),
  52. }
  53. }
  54. pub(crate) fn close_rx(&self) {
  55. self.lock.fetch_or(Self::RX_CLOSED, AcqRel);
  56. }
  57. pub(crate) fn is_rx_closed(&self) -> bool {
  58. test_dbg!(self.lock.load(Acquire) & Self::RX_CLOSED == Self::RX_CLOSED)
  59. }
  60. pub(crate) fn wait_with(&self, f: impl FnOnce() -> T) -> WaitResult {
  61. test_println!("registering waiter");
  62. // this is based on tokio's AtomicWaker synchronization strategy
  63. match test_dbg!(self
  64. .lock
  65. .compare_exchange(Self::WAITING, Self::PARKING, AcqRel, Acquire,))
  66. {
  67. // someone else is notifying the receiver, so don't park!
  68. Err(actual) if test_dbg!(actual & Self::TX_CLOSED) == Self::TX_CLOSED => {
  69. test_println!("-> state = TX_CLOSED");
  70. return WaitResult::TxClosed;
  71. }
  72. Err(actual) if test_dbg!(actual & Self::NOTIFYING) == Self::NOTIFYING => {
  73. test_println!("-> state = NOTIFYING");
  74. // f().notify();
  75. // loom::hint::spin_loop();
  76. return WaitResult::Notified;
  77. }
  78. Err(actual) => {
  79. debug_assert!(actual == Self::PARKING || actual == Self::PARKING | Self::NOTIFYING);
  80. return WaitResult::Wait;
  81. }
  82. Ok(_) => {}
  83. }
  84. test_println!("-> locked!");
  85. let (panicked, prev_waiter) = match panic::catch_unwind(panic::AssertUnwindSafe(f)) {
  86. Ok(new_waiter) => {
  87. let new_waiter = test_dbg!(new_waiter);
  88. let prev_waiter = self
  89. .waiter
  90. .with_mut(|waiter| unsafe { (*waiter).replace(new_waiter) });
  91. (None, test_dbg!(prev_waiter))
  92. }
  93. Err(panic) => (Some(panic), None),
  94. };
  95. let result = match test_dbg!(self.lock.compare_exchange(
  96. Self::PARKING,
  97. Self::WAITING,
  98. AcqRel,
  99. Acquire
  100. )) {
  101. Ok(_) => {
  102. let _ = panic::catch_unwind(move || drop(prev_waiter));
  103. WaitResult::Wait
  104. }
  105. Err(actual) => {
  106. test_println!("-> was notified; state={:#b}", actual);
  107. let waiter = self.waiter.with_mut(|waiter| unsafe { (*waiter).take() });
  108. // Reset to the WAITING state by clearing everything *except*
  109. // the closed bits (which must remain set).
  110. let state = test_dbg!(self
  111. .lock
  112. .fetch_and(Self::TX_CLOSED | Self::RX_CLOSED, AcqRel));
  113. // The only valid state transition while we were parking is to
  114. // add the TX_CLOSED bit.
  115. debug_assert!(
  116. state == actual || state == actual | Self::TX_CLOSED,
  117. "state changed unexpectedly while parking!"
  118. );
  119. if let Some(prev_waiter) = prev_waiter {
  120. let _ = panic::catch_unwind(move || {
  121. prev_waiter.notify();
  122. });
  123. }
  124. if let Some(waiter) = waiter {
  125. debug_assert!(panicked.is_none());
  126. waiter.notify();
  127. }
  128. if state & Self::TX_CLOSED == Self::TX_CLOSED {
  129. WaitResult::TxClosed
  130. } else {
  131. WaitResult::Notified
  132. }
  133. }
  134. };
  135. if let Some(panic) = panicked {
  136. panic::resume_unwind(panic);
  137. }
  138. result
  139. }
  140. pub(crate) fn notify(&self) {
  141. self.notify2(false)
  142. }
  143. pub(crate) fn close_tx(&self) {
  144. self.notify2(true)
  145. }
  146. fn notify2(&self, close: bool) {
  147. test_println!("notifying; close={:?};", close);
  148. let bits = if close {
  149. Self::NOTIFYING | Self::TX_CLOSED
  150. } else {
  151. Self::NOTIFYING
  152. };
  153. if test_dbg!(self.lock.fetch_or(bits, AcqRel)) == Self::WAITING {
  154. // we have the lock!
  155. let waiter = self.waiter.with_mut(|thread| unsafe { (*thread).take() });
  156. self.lock.fetch_and(!Self::NOTIFYING, Release);
  157. if let Some(waiter) = test_dbg!(waiter) {
  158. waiter.notify();
  159. }
  160. }
  161. }
  162. }
  /// Notifying a thread handle unparks that thread.
  #[cfg(feature = "std")]
  impl Notify for thread::Thread {
      fn notify(self) {
          test_println!("NOTIFYING {:?} (from {:?})", self, thread::current());
          thread::Thread::unpark(&self);
      }
  }
  170. impl Notify for Waker {
  171. fn notify(self) {
  172. test_println!("WAKING TASK {:?} (from {:?})", self, thread::current());
  173. self.wake();
  174. }
  175. }
  176. impl<T: UnwindSafe> UnwindSafe for WaitCell<T> {}
  177. impl<T: RefUnwindSafe> RefUnwindSafe for WaitCell<T> {}
  178. unsafe impl<T: Send> Send for WaitCell<T> {}
  179. unsafe impl<T: Send> Sync for WaitCell<T> {}
  // Every test below shares an `Arc<Chan>`, and `Arc` is only available with
  // the `alloc` feature, so the whole module is gated on it. (Previously
  // `Chan`'s helpers used `Arc` unconditionally, which failed to compile in
  // test builds without `alloc`.)
  #[cfg(all(test, feature = "alloc"))]
  mod tests {
      use super::*;
      use crate::loom::{
          self, future,
          sync::atomic::{AtomicUsize, Ordering::Relaxed},
          thread,
      };
      use alloc::sync::Arc;
      use core::task::{Poll, Waker};

      /// State shared between the notifying threads and the waiting task.
      struct Chan {
          /// Number of notifications sent so far.
          num: AtomicUsize,
          /// Where the waiting task registers its waker.
          task: WaitCell<Waker>,
      }

      /// Number of notifications after which `wait_on` completes regardless of
      /// the last `wait_with` result.
      const NUM_NOTIFY: usize = 2;

      /// Polls `chan.task` until every notification has been counted, or
      /// until `wait_with` reports `Notified` or `TxClosed`.
      async fn wait_on(chan: Arc<Chan>) {
          futures_util::future::poll_fn(move |cx| {
              let res = test_dbg!(chan.task.wait_with(|| cx.waker().clone()));
              if NUM_NOTIFY == chan.num.load(Relaxed) {
                  return Poll::Ready(());
              }
              if matches!(res, WaitResult::Notified | WaitResult::TxClosed) {
                  return Poll::Ready(());
              }
              Poll::Pending
          })
          .await
      }

      #[test]
      fn basic_notification() {
          loom::model(|| {
              let chan = Arc::new(Chan {
                  num: AtomicUsize::new(0),
                  task: WaitCell::new(),
              });
              for _ in 0..NUM_NOTIFY {
                  let chan = chan.clone();
                  thread::spawn(move || {
                      chan.num.fetch_add(1, Relaxed);
                      chan.task.notify();
                  });
              }
              future::block_on(wait_on(chan));
          });
      }

      #[test]
      fn tx_close() {
          loom::model(|| {
              let chan = Arc::new(Chan {
                  num: AtomicUsize::new(0),
                  task: WaitCell::new(),
              });
              thread::spawn({
                  let chan = chan.clone();
                  move || {
                      chan.num.fetch_add(1, Relaxed);
                      chan.task.notify();
                  }
              });
              thread::spawn({
                  let chan = chan.clone();
                  move || {
                      chan.num.fetch_add(1, Relaxed);
                      chan.task.close_tx();
                  }
              });
              future::block_on(wait_on(chan));
          });
      }

      #[test]
      #[cfg(feature = "std")]
      fn test_panicky_waker() {
          use std::panic;
          use std::ptr;
          use std::task::{RawWaker, RawWakerVTable, Waker};

          // A waker whose `clone` always panics, so `wait_with`'s closure panics.
          static PANICKING_VTABLE: RawWakerVTable =
              RawWakerVTable::new(|_| panic!("clone"), |_| (), |_| (), |_| ());

          let panicking = unsafe { Waker::from_raw(RawWaker::new(ptr::null(), &PANICKING_VTABLE)) };

          loom::model(move || {
              let chan = Arc::new(Chan {
                  num: AtomicUsize::new(0),
                  task: WaitCell::new(),
              });
              for _ in 0..NUM_NOTIFY {
                  let chan = chan.clone();
                  thread::spawn(move || {
                      chan.num.fetch_add(1, Relaxed);
                      chan.task.notify();
                  });
              }

              // Note: this panic should have no effect on the overall state of
              // the cell, and waiting should proceed as normal.
              //
              // A thread above might race to flag a wakeup, and a NOTIFYING
              // state will be preserved if this expected panic races with that,
              // so the procedure below should be allowed to continue
              // uninterrupted.
              let _ = panic::catch_unwind(|| chan.task.wait_with(|| panicking.clone()));
              future::block_on(wait_on(chan));
          });
      }
  }