Browse Source

feat: add `Deref` and `DerefMut` impls to `Ref` types (#13)

Now that tokio-rs/loom#219 has merged, we can add
`Deref`/`DerefMut` for the `Ref` types without breaking
Loom's concurrent access checking! :D

This makes the `Ref` APIs much easier to work with.
Eliza Weisman 3 years ago
parent
commit
6ebfe7b8fd

+ 1 - 1
Cargo.toml

@@ -24,7 +24,7 @@ tokio = { version = "1.14.0", features = ["rt", "rt-multi-thread", "macros", "sy
 futures-util = { version = "0.3", default-features = false }
 
 [target.'cfg(loom)'.dev-dependencies]
-loom = { version = "0.5.3", features = ["checkpoint", "futures"] }
+loom = { version = "0.5.4", features = ["checkpoint", "futures"] }
 tracing-subscriber = { version = "0.3", default-features = false, features = ["std", "fmt"] }
 tracing = { version = "0.1", default-features = false, features = ["std"] }
 

+ 1 - 5
README.md

@@ -30,8 +30,4 @@
   **A:** Originally, I imagined it as a kind of ring buffer, so (as a pun on
   "ringbuf"), I called it "stringbuf". Then, I realized you could do this with
   more than just strings. In fact, it can be generalized to arbitrary...things.
-  So, "thingbuf".
-
-- **Q: Why don't the `Ref` types implement `Deref` and `DerefMut`?**
-
-  **A:** [Blame `loom` for this.](https://github.com/tokio-rs/loom/pull/219)
+  So, "thingbuf".

+ 39 - 19
src/lib.rs

@@ -1,6 +1,6 @@
 #![cfg_attr(not(feature = "std"), no_std)]
 #![cfg_attr(docsrs, feature(doc_cfg))]
-use core::{cmp, fmt, mem::MaybeUninit, ops::Index};
+use core::{cmp, fmt, mem::MaybeUninit, ops};
 
 #[macro_use]
 mod macros;
@@ -27,12 +27,13 @@ pub use self::static_thingbuf::StaticThingBuf;
 use crate::{
     loom::{
         atomic::{AtomicUsize, Ordering::*},
-        UnsafeCell,
+        cell::{MutPtr, UnsafeCell},
     },
     util::{Backoff, CachePadded},
 };
 
 pub struct Ref<'slot, T> {
+    ptr: MutPtr<MaybeUninit<T>>,
     slot: &'slot Slot<T>,
     new_state: usize,
 }
@@ -127,7 +128,7 @@ impl Core {
     ) -> Result<Ref<'slots, T>, mpsc::TrySendError<()>>
     where
         T: Default,
-        S: Index<usize, Output = Slot<T>> + ?Sized,
+        S: ops::Index<usize, Output = Slot<T>> + ?Sized,
     {
         test_println!("push_ref");
         let mut backoff = Backoff::new();
@@ -158,16 +159,20 @@ impl Core {
                     Ok(_) => {
                         // We got the slot! It's now okay to write to it
                         test_println!("claimed tail slot [{}]", idx);
+                        // Claim exclusive ownership over the slot
+                        let ptr = slot.value.get_mut();
+
                         if gen == 0 {
-                            slot.value.with_mut(|value| unsafe {
+                            unsafe {
                                 // Safety: we have just claimed exclusive ownership over
                                 // this slot.
-                                (*value).write(T::default());
-                            });
+                                ptr.deref().write(T::default());
+                            };
                             test_println!("-> initialized");
                         }
 
                         return Ok(Ref {
+                            ptr,
                             new_state: tail + 1,
                             slot,
                         });
@@ -207,7 +212,7 @@ impl Core {
 
     fn pop_ref<'slots, T, S>(&self, slots: &'slots S) -> Result<Ref<'slots, T>, mpsc::TrySendError>
     where
-        S: Index<usize, Output = Slot<T>> + ?Sized,
+        S: ops::Index<usize, Output = Slot<T>> + ?Sized,
     {
         test_println!("pop_ref");
         let mut backoff = Backoff::new();
@@ -234,6 +239,7 @@ impl Core {
                         test_println!("claimed head slot [{}]", idx);
                         return Ok(Ref {
                             new_state: head.wrapping_add(self.gen),
+                            ptr: slot.value.get_mut(),
                             slot,
                         });
                     }
@@ -299,11 +305,9 @@ impl Core {
 // === impl Ref ===
 
 impl<T> Ref<'_, T> {
-    const RELEASED: usize = usize::MAX;
-
     #[inline]
     pub fn with<U>(&self, f: impl FnOnce(&T) -> U) -> U {
-        self.slot.value.with(|value| unsafe {
+        self.ptr.with(|value| unsafe {
             // Safety: if a `Ref` exists, we have exclusive ownership of the
             // slot. A `Ref` is only created if the slot has already been
             // initialized.
@@ -315,7 +319,7 @@ impl<T> Ref<'_, T> {
 
     #[inline]
     pub fn with_mut<U>(&mut self, f: impl FnOnce(&mut T) -> U) -> U {
-        self.slot.value.with_mut(|value| unsafe {
+        self.ptr.with(|value| unsafe {
             // Safety: if a `Ref` exists, we have exclusive ownership of the
             // slot.
             // TODO(eliza): use `MaybeUninit::assume_init_mut` here once it's
@@ -323,23 +327,39 @@ impl<T> Ref<'_, T> {
             f(&mut *(&mut *value).as_mut_ptr())
         })
     }
+}
 
-    pub(crate) fn release(&mut self) {
-        if self.new_state == Self::RELEASED {
-            test_println!("release_ref; already released");
-            return;
+impl<T> ops::Deref for Ref<'_, T> {
+    type Target = T;
+
+    #[inline]
+    fn deref(&self) -> &Self::Target {
+        unsafe {
+            // Safety: if a `Ref` exists, we have exclusive ownership of the
+            // slot. A `Ref` is only created if the slot has already been
+            // initialized.
+            &*self.ptr.deref().as_ptr()
         }
+    }
+}
 
-        test_println!("release_ref");
-        test_dbg!(self.slot.state.store(test_dbg!(self.new_state), Release));
-        self.new_state = Self::RELEASED;
+impl<T> ops::DerefMut for Ref<'_, T> {
+    #[inline]
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        unsafe {
+            // Safety: if a `Ref` exists, we have exclusive ownership of the
+            // slot. A `Ref` is only created if the slot has already been
+            // initialized.
+            &mut *self.ptr.deref().as_mut_ptr()
+        }
     }
 }
 
 impl<T> Drop for Ref<'_, T> {
     #[inline]
     fn drop(&mut self) {
-        self.release();
+        test_println!("drop Ref<{}>", core::any::type_name::<T>());
+        test_dbg!(self.slot.state.store(test_dbg!(self.new_state), Release));
     }
 }
 

+ 48 - 20
src/loom.rs

@@ -8,7 +8,7 @@ mod inner {
         pub use std::sync::atomic::Ordering;
     }
 
-    pub(crate) use loom::{cell::UnsafeCell, future, hint, sync, thread};
+    pub(crate) use loom::{cell, future, hint, sync, thread};
     use std::{cell::RefCell, fmt::Write};
 
     pub(crate) mod model {
@@ -213,31 +213,59 @@ mod inner {
         }
     }
 
-    #[derive(Debug)]
-    pub(crate) struct UnsafeCell<T>(core::cell::UnsafeCell<T>);
+    pub(crate) mod cell {
+        #[derive(Debug)]
+        pub(crate) struct UnsafeCell<T>(core::cell::UnsafeCell<T>);
 
-    impl<T> UnsafeCell<T> {
-        pub const fn new(data: T) -> UnsafeCell<T> {
-            UnsafeCell(core::cell::UnsafeCell::new(data))
-        }
+        impl<T> UnsafeCell<T> {
+            pub const fn new(data: T) -> UnsafeCell<T> {
+                UnsafeCell(core::cell::UnsafeCell::new(data))
+            }
 
-        #[inline(always)]
-        pub fn with<F, R>(&self, f: F) -> R
-        where
-            F: FnOnce(*const T) -> R,
-        {
-            f(self.0.get())
+            #[inline(always)]
+            pub fn with<F, R>(&self, f: F) -> R
+            where
+                F: FnOnce(*const T) -> R,
+            {
+                f(self.0.get())
+            }
+
+            #[inline(always)]
+            pub fn with_mut<F, R>(&self, f: F) -> R
+            where
+                F: FnOnce(*mut T) -> R,
+            {
+                f(self.0.get())
+            }
+
+            #[inline(always)]
+            pub(crate) fn get_mut(&self) -> MutPtr<T> {
+                MutPtr(self.0.get())
+            }
         }
 
-        #[inline(always)]
-        pub fn with_mut<F, R>(&self, f: F) -> R
-        where
-            F: FnOnce(*mut T) -> R,
-        {
-            f(self.0.get())
+        #[derive(Debug)]
+        pub(crate) struct MutPtr<T: ?Sized>(*mut T);
+
+        impl<T: ?Sized> MutPtr<T> {
+            // Clippy knows that it's Bad and Wrong to construct a mutable reference
+            // from an immutable one...but this function is intended to simulate a raw
+            // pointer, so we have to do that here.
+            #[allow(clippy::mut_from_ref)]
+            #[inline(always)]
+            pub(crate) unsafe fn deref(&self) -> &mut T {
+                &mut *self.0
+            }
+
+            #[inline(always)]
+            pub fn with<F, R>(&self, f: F) -> R
+            where
+                F: FnOnce(*mut T) -> R,
+            {
+                f(self.0)
+            }
         }
     }
-
     pub(crate) mod alloc {
         /// Track allocations, detecting leaks
         #[derive(Debug, Default)]

+ 103 - 20
src/mpsc.rs

@@ -44,10 +44,28 @@ struct Inner<T, N: Notify> {
 }
 
 struct SendRefInner<'a, T, N: Notify> {
-    inner: &'a Inner<T, N>,
+    // /!\ LOAD BEARING STRUCT DROP ORDER /!\
+    //
+    // The `Ref` field *must* be dropped before the `NotifyInner` field, or else
+    // loom tests will fail. This ensures that the mutable access to the slot is
+    // considered to have ended *before* the receiver thread/task is notified.
+    //
+    // The alternatives to a load-bearing drop order would be:
+    // (a) put one field inside an `Option` so it can be dropped before the
+    //     other (not great, as it adds a little extra overhead even outside
+    //     of Loom tests),
+    // (b) use `core::mem::ManuallyDrop` (also not great, requires additional
+    //     unsafe code that in this case we can avoid)
+    //
+    // So, given that, relying on struct field drop order seemed like the least
+    // bad option here. Just don't reorder these fields. :)
     slot: Ref<'a, T>,
+    _notify: NotifyRx<'a, N>,
 }
 
+struct NotifyRx<'a, N: Notify>(&'a WaitCell<N>);
+struct NotifyTx<'a, N: Notify>(&'a WaitQueue<NotifyOnDrop<N>>);
+
 // ==== impl TrySendError ===
 
 impl TrySendError {
@@ -85,7 +103,10 @@ impl<T: Default, N: Notify> Inner<T, N> {
         self.thingbuf
             .core
             .push_ref(self.thingbuf.slots.as_ref())
-            .map(|slot| SendRefInner { inner: self, slot })
+            .map(|slot| SendRefInner {
+                _notify: NotifyRx(&self.rx_wait),
+                slot,
+            })
     }
 
     fn try_send(&self, val: T) -> Result<(), TrySendError<T>> {
@@ -194,6 +215,22 @@ impl<T: Default, N: Notify> Inner<T, N> {
     }
 }
 
+impl<T, N: Notify> core::ops::Deref for SendRefInner<'_, T, N> {
+    type Target = T;
+
+    #[inline]
+    fn deref(&self) -> &Self::Target {
+        self.slot.deref()
+    }
+}
+
+impl<T, N: Notify> core::ops::DerefMut for SendRefInner<'_, T, N> {
+    #[inline]
+    fn deref_mut(&mut self) -> &mut Self::Target {
+        self.slot.deref_mut()
+    }
+}
+
 impl<T, N: Notify> SendRefInner<'_, T, N> {
     #[inline]
     pub fn with<U>(&self, f: impl FnOnce(&T) -> U) -> U {
@@ -206,15 +243,6 @@ impl<T, N: Notify> SendRefInner<'_, T, N> {
     }
 }
 
-impl<T, N: Notify> Drop for SendRefInner<'_, T, N> {
-    #[inline]
-    fn drop(&mut self) {
-        test_println!("drop SendRef<T, {}>", std::any::type_name::<N>());
-        self.slot.release();
-        self.inner.rx_wait.notify();
-    }
-}
-
 impl<T: fmt::Debug, N: Notify> fmt::Debug for SendRefInner<'_, T, N> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         self.with(|val| fmt::Debug::fmt(val, f))
@@ -244,6 +272,22 @@ impl<T: fmt::Write, N: Notify> fmt::Write for SendRefInner<'_, T, N> {
     }
 }
 
+impl<N: Notify> Drop for NotifyRx<'_, N> {
+    #[inline]
+    fn drop(&mut self) {
+        test_println!("notifying rx ({})", core::any::type_name::<N>());
+        self.0.notify();
+    }
+}
+
+impl<N: Notify> Drop for NotifyTx<'_, N> {
+    #[inline]
+    fn drop(&mut self) {
+        test_println!("notifying tx ({})", core::any::type_name::<N>());
+        self.0.notify();
+    }
+}
+
 macro_rules! impl_send_ref {
     (pub struct $name:ident<$notify:ty>;) => {
         pub struct $name<'sender, T>(SendRefInner<'sender, T, $notify>);
@@ -260,6 +304,22 @@ macro_rules! impl_send_ref {
             }
         }
 
+        impl<T> core::ops::Deref for $name<'_, T> {
+            type Target = T;
+
+            #[inline]
+            fn deref(&self) -> &Self::Target {
+                self.0.deref()
+            }
+        }
+
+        impl<T> core::ops::DerefMut for $name<'_, T> {
+            #[inline]
+            fn deref_mut(&mut self) -> &mut Self::Target {
+                self.0.deref_mut()
+            }
+        }
+
         impl<T: fmt::Debug> fmt::Debug for $name<'_, T> {
             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                 self.0.fmt(f)
@@ -294,8 +354,23 @@ macro_rules! impl_send_ref {
 macro_rules! impl_recv_ref {
     (pub struct $name:ident<$notify:ty>;) => {
         pub struct $name<'recv, T> {
+            // /!\ LOAD BEARING STRUCT DROP ORDER /!\
+            //
+            // The `Ref` field *must* be dropped before the `NotifyTx` field, or else
+            // loom tests will fail. This ensures that the mutable access to the slot is
+            // considered to have ended *before* the receiver thread/task is notified.
+            //
+            // The alternatives to a load-bearing drop order would be:
+            // (a) put one field inside an `Option` so it can be dropped before the
+            //     other (not great, as it adds a little extra overhead even outside
+            //     of Loom tests),
+            // (b) use `core::mem::ManuallyDrop` (also not great, requires additional
+            //     unsafe code that in this case we can avoid)
+            //
+            // So, given that, relying on struct field drop order seemed like the least
+            // bad option here. Just don't reorder these fields. :)
             slot: Ref<'recv, T>,
-            inner: &'recv Inner<T, $notify>,
+            _notify: crate::mpsc::NotifyTx<'recv, $notify>,
         }
 
         impl<T> $name<'_, T> {
@@ -310,6 +385,22 @@ macro_rules! impl_recv_ref {
             }
         }
 
+        impl<T> core::ops::Deref for $name<'_, T> {
+            type Target = T;
+
+            #[inline]
+            fn deref(&self) -> &Self::Target {
+                self.slot.deref()
+            }
+        }
+
+        impl<T> core::ops::DerefMut for $name<'_, T> {
+            #[inline]
+            fn deref_mut(&mut self) -> &mut Self::Target {
+                self.slot.deref_mut()
+            }
+        }
+
         impl<T: fmt::Debug> fmt::Debug for $name<'_, T> {
             fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                 self.slot.fmt(f)
@@ -338,14 +429,6 @@ macro_rules! impl_recv_ref {
                 self.slot.write_fmt(f)
             }
         }
-
-        impl<T> Drop for RecvRef<'_, T> {
-            fn drop(&mut self) {
-                test_println!("drop RecvRef<T, {}>", stringify!($notify));
-                self.slot.release();
-                self.inner.tx_wait.notify();
-            }
-        }
     };
 }
 

+ 1 - 1
src/mpsc/async_impl.rs

@@ -151,8 +151,8 @@ impl<T: Default> Receiver<T> {
     pub fn poll_recv_ref(&self, cx: &mut Context<'_>) -> Poll<Option<RecvRef<'_, T>>> {
         self.inner.poll_recv_ref(|| cx.waker().clone()).map(|some| {
             some.map(|slot| RecvRef {
+                _notify: super::NotifyTx(&self.inner.tx_wait),
                 slot,
-                inner: &*self.inner,
             })
         })
     }

+ 1 - 1
src/mpsc/sync.rs

@@ -107,8 +107,8 @@ impl<T: Default> Receiver<T> {
             match self.inner.poll_recv_ref(thread::current) {
                 Poll::Ready(r) => {
                     return r.map(|slot| RecvRef {
+                        _notify: super::NotifyTx(&self.inner.tx_wait),
                         slot,
-                        inner: &*self.inner,
                     })
                 }
                 Poll::Pending => {

+ 9 - 6
src/mpsc/tests/mpsc_async.rs

@@ -12,18 +12,18 @@ fn mpsc_try_send_recv() {
         let p1 = {
             let tx = tx.clone();
             thread::spawn(move || {
-                tx.try_send_ref().unwrap().with_mut(|val| *val = 1);
+                *tx.try_send_ref().unwrap() = 1;
             })
         };
         let p2 = thread::spawn(move || {
-            tx.try_send(2).unwrap();
-            tx.try_send(3).unwrap();
+            *tx.try_send_ref().unwrap() = 2;
+            *tx.try_send_ref().unwrap() = 3;
         });
 
         let mut vals = future::block_on(async move {
             let mut vals = Vec::new();
             while let Some(val) = rx.recv_ref().await {
-                val.with(|val| vals.push(*val));
+                vals.push(*val);
             }
             vals
         });
@@ -46,8 +46,11 @@ fn rx_closes() {
             'iters: for i in 0..=ITERATIONS {
                 test_println!("sending {}...", i);
                 'send: loop {
-                    match tx.try_send(i) {
-                        Ok(_) => break 'send,
+                    match tx.try_send_ref() {
+                        Ok(mut slot) => {
+                            *slot = i;
+                            break 'send;
+                        }
                         Err(TrySendError::Full(_)) => thread::yield_now(),
                         Err(TrySendError::Closed(_)) => break 'iters,
                     }

+ 8 - 5
src/mpsc/tests/mpsc_sync.rs

@@ -21,12 +21,12 @@ fn mpsc_try_send_recv() {
         let p1 = {
             let tx = tx.clone();
             thread::spawn(move || {
-                tx.try_send_ref().unwrap().with_mut(|val| *val = 1);
+                *tx.try_send_ref().unwrap() = 1;
             })
         };
         let p2 = thread::spawn(move || {
-            tx.try_send(2).unwrap();
-            tx.try_send(3).unwrap();
+            *tx.try_send_ref().unwrap() = 2;
+            *tx.try_send_ref().unwrap() = 3;
         });
 
         let mut vals = Vec::new();
@@ -53,8 +53,11 @@ fn rx_closes() {
             'iters: for i in 0..=ITERATIONS {
                 test_println!("sending {}", i);
                 'send: loop {
-                    match tx.try_send(i) {
-                        Ok(_) => break 'send,
+                    match tx.try_send_ref() {
+                        Ok(mut slot) => {
+                            *slot = i;
+                            break 'send;
+                        }
                         Err(TrySendError::Full(_)) => thread::yield_now(),
                         Err(TrySendError::Closed(_)) => break 'iters,
                     }

+ 15 - 13
src/thingbuf/tests.rs

@@ -14,8 +14,8 @@ fn push_many_mpsc() {
         let q = q.clone();
         move || {
             for &val in vals {
-                if let Ok(mut r) = test_dbg!(q.push_ref()) {
-                    r.with_mut(|r| r.get_mut().push_str(val));
+                if let Ok(mut slot) = test_dbg!(q.push_ref()) {
+                    slot.get_mut().push_str(val);
                 } else {
                     return;
                 }
@@ -32,8 +32,8 @@ fn push_many_mpsc() {
         let mut all_vals = Vec::new();
 
         while Arc::strong_count(&q) > 1 {
-            if let Some(r) = q.pop_ref() {
-                r.with(|val| all_vals.push(val.get_ref().to_string()));
+            if let Some(val) = q.pop_ref() {
+                all_vals.push(val.get_ref().to_string());
             }
             thread::yield_now();
         }
@@ -41,8 +41,8 @@ fn push_many_mpsc() {
         t1.join().unwrap();
         t2.join().unwrap();
 
-        while let Some(r) = test_dbg!(q.pop_ref()) {
-            r.with(|val| all_vals.push(val.get_ref().to_string()));
+        while let Some(val) = test_dbg!(q.pop_ref()) {
+            all_vals.push(val.get_ref().to_string());
         }
 
         test_dbg!(&all_vals);
@@ -65,8 +65,8 @@ fn spsc() {
             thread::spawn(move || {
                 for i in 0..COUNT {
                     loop {
-                        if let Ok(mut guard) = q.push_ref() {
-                            guard.with_mut(|val| *val = i);
+                        if let Ok(mut val) = q.push_ref() {
+                            *val = i;
                             break;
                         }
                         thread::yield_now();
@@ -77,8 +77,8 @@ fn spsc() {
 
         for i in 0..COUNT {
             loop {
-                if let Some(guard) = q.pop_ref() {
-                    guard.with(|val| assert_eq!(*val, i));
+                if let Some(val) = q.pop_ref() {
+                    assert_eq!(*val, i);
                     break;
                 }
                 thread::yield_now();
@@ -99,12 +99,14 @@ fn linearizable() {
         move || {
             while q
                 .push_ref()
-                .map(|mut r| r.with_mut(|val| *val = i))
+                .map(|mut val| {
+                    *val = i;
+                })
                 .is_err()
             {}
 
-            if let Some(mut r) = q.pop_ref() {
-                r.with_mut(|val| *val = 0);
+            if let Some(mut val) = q.pop_ref() {
+                *val = 0;
             }
         }
     }

+ 1 - 1
src/util/wait/wait_cell.rs

@@ -5,7 +5,7 @@ use crate::{
             AtomicUsize,
             Ordering::{self, *},
         },
-        UnsafeCell,
+        cell::UnsafeCell,
     },
     util::panic::{self, RefUnwindSafe, UnwindSafe},
 };

+ 1 - 1
src/util/wait/wait_queue.rs

@@ -5,7 +5,7 @@ use crate::{
             AtomicUsize,
             Ordering::{self, *},
         },
-        UnsafeCell,
+        cell::UnsafeCell,
     },
     util::{panic, Backoff, CachePadded},
 };