Bläddra i källkod

feat(blocking::mpsc): add `Sender::send(_ref)_timeout` methods (#79)

Follow up to #75.

Co-authored-by: Eliza Weisman <eliza@buoyant.io>
Utkarsh Gupta 2 år sedan
förälder
incheckning
979ed6e8d5
3 ändrade filer med 392 tillägg och 4 borttagningar
  1. 2 4
      README.md
  2. 299 0
      src/mpsc/blocking.rs
  3. 91 0
      src/mpsc/errors.rs

+ 2 - 4
README.md

@@ -73,10 +73,8 @@ some cases where you might be better off considering other options:
   prefer a channel implementation that only allocates memory for messages as
   it's needed (such as [`tokio::sync::mpsc`]).
 
-- **You need a blocking channel with `send_timeout`** or **a blocking channel
-  with a `select` operation**. I'm probably not going to implement these things.
-  The blocking channel isn't particularly important to me compared to the async
-  channel, and I _probably_ won't add a bunch of additional APIs to it.
+- **You need a blocking channel with a `select` operation**.
+  I'm probably not going to implement it. I _may_ accept a PR if you raise it.
 
   If you need a synchronous channel with this kind of functionality,
   [`crossbeam-channel`] is probably a good choice.

+ 299 - 0
src/mpsc/blocking.rs

@@ -351,6 +351,124 @@ feature! {
             }
         }
 
+        /// Reserves a slot in the channel to mutate in place, blocking until
+        /// there is a free slot to write to, waiting for at most `timeout`.
+        ///
+        /// This is similar to the [`send_timeout`] method, but, rather than taking a
+        /// message by value to write to the channel, this method reserves a
+        /// writable slot in the channel, and returns a [`SendRef`] that allows
+        /// mutating the slot in place. If the [`StaticReceiver`] end of the
+        /// channel uses the [`StaticReceiver::recv_ref`] method for receiving
+        /// from the channel, this allows allocations for channel messages to be
+        /// reused in place.
+        ///
+        /// # Errors
+        ///
+        /// - [`Err`]`(`[`SendTimeoutError::Timeout`]`)` if the timeout has elapsed.
+        /// - [`Err`]`(`[`SendTimeoutError::Closed`]`)` if the channel has closed.
+        ///
+        /// # Examples
+        ///
+        /// Sending formatted strings by writing them directly to channel slots,
+        /// in place:
+        ///
+        /// ```
+        /// use thingbuf::mpsc::{blocking::StaticChannel, errors::SendTimeoutError};
+        /// use std::{fmt::Write, time::Duration, thread};
+        ///
+        /// static CHANNEL: StaticChannel<String, 1> = StaticChannel::new();
+        /// let (tx, rx) = CHANNEL.split();
+        ///
+        /// thread::spawn(move || {
+        ///     thread::sleep(Duration::from_millis(500));
+        ///     let msg = rx.recv_ref().unwrap();
+        ///     println!("{}", msg);
+        ///     thread::sleep(Duration::from_millis(500));
+        /// });
+        ///
+        /// thread::spawn(move || {
+        ///     let mut value = tx.send_ref_timeout(Duration::from_millis(200)).unwrap();
+        ///     write!(value, "hello").expect("writing to a `String` should never fail");
+        ///     thread::sleep(Duration::from_millis(400));
+        ///
+        ///     let mut value = tx.send_ref_timeout(Duration::from_millis(200)).unwrap();
+        ///     write!(value, "world").expect("writing to a `String` should never fail");
+        ///     thread::sleep(Duration::from_millis(400));
+        ///
+        ///     assert_eq!(
+        ///         Err(&SendTimeoutError::Timeout(())),
+        ///         tx.send_ref_timeout(Duration::from_millis(200)).as_deref().map(String::as_str)
+        ///     );
+        /// });
+        /// ```
+        ///
+        /// [`send_timeout`]: Self::send_timeout
+        #[cfg(not(all(test, loom)))]
+        pub fn send_ref_timeout(&self, timeout: Duration) -> Result<SendRef<'_, T>, SendTimeoutError> {
+            send_ref_timeout(self.core, self.slots, self.recycle, timeout)
+        }
+
+        /// Sends a message by value, blocking until there is a free slot to
+        /// write to, waiting for at most `timeout`.
+        ///
+        /// This method takes the message by value, and replaces any previous
+        /// value in the slot. This means that the channel will *not* function
+        /// as an object pool while sending messages with `send_timeout`. This method is
+        /// most appropriate when messages don't own reusable heap allocations,
+        /// or when the [`StaticReceiver`] end of the channel must receive messages
+        /// by moving them out of the channel by value (using the
+        /// [`StaticReceiver::recv`] method). When messages in the channel own
+        /// reusable heap allocations (such as `String`s or `Vec`s), and the
+        /// [`StaticReceiver`] doesn't need to receive them by value, consider using
+        /// [`send_ref_timeout`] instead, to enable allocation reuse.
+        ///
+        /// # Errors
+        ///
+        /// - [`Err`]`(`[`SendTimeoutError::Timeout`]`)` if the timeout has elapsed.
+        /// - [`Err`]`(`[`SendTimeoutError::Closed`]`)` if the channel has closed.
+        ///
+        /// # Examples
+        ///
+        /// ```
+        /// use thingbuf::mpsc::{blocking::StaticChannel, errors::SendTimeoutError};
+        /// use std::{time::Duration, thread};
+        ///
+        /// static CHANNEL: StaticChannel<i32, 1> = StaticChannel::new();
+        /// let (tx, rx) = CHANNEL.split();
+        ///
+        /// thread::spawn(move || {
+        ///     thread::sleep(Duration::from_millis(500));
+        ///     let msg = rx.recv().unwrap();
+        ///     println!("{}", msg);
+        ///     thread::sleep(Duration::from_millis(500));
+        /// });
+        ///
+        /// thread::spawn(move || {
+        ///     tx.send_timeout(1, Duration::from_millis(200)).unwrap();
+        ///     thread::sleep(Duration::from_millis(400));
+        ///
+        ///     tx.send_timeout(2, Duration::from_millis(200)).unwrap();
+        ///     thread::sleep(Duration::from_millis(400));
+        ///
+        ///     assert_eq!(
+        ///         Err(SendTimeoutError::Timeout(3)),
+        ///         tx.send_timeout(3, Duration::from_millis(200))
+        ///     );
+        /// });
+        /// ```
+        ///
+        /// [`send_ref_timeout`]: Self::send_ref_timeout
+        #[cfg(not(all(test, loom)))]
+        pub fn send_timeout(&self, val: T, timeout: Duration) -> Result<(), SendTimeoutError<T>> {
+            match self.send_ref_timeout(timeout) {
+                Err(e) => Err(e.with_value(val)),
+                Ok(mut slot) => {
+                    *slot = val;
+                    Ok(())
+                }
+            }
+        }
+
         /// Attempts to reserve a slot in the channel to mutate in place,
         /// without blocking until capacity is available.
         ///
@@ -586,6 +704,7 @@ feature! {
         /// - [`Ok`]`(`[`RecvRef`]`<T>)` if a message was received.
         /// - [`Err`]`(`[`RecvTimeoutError::Timeout`]`)` if the timeout has elapsed.
         /// - [`Err`]`(`[`RecvTimeoutError::Closed`]`)` if the channel has closed.
+        ///
         /// # Examples
         ///
         /// ```
@@ -643,6 +762,7 @@ feature! {
         /// - [`Ok`]`(<T>)` if a message was received.
         /// - [`Err`]`(`[`RecvTimeoutError::Timeout`]`)` if the timeout has elapsed.
         /// - [`Err`]`(`[`RecvTimeoutError::Closed`]`)` if the channel has closed.
+        ///
         /// # Examples
         ///
         /// ```
@@ -927,6 +1047,127 @@ where
         }
     }
 
+    /// Reserves a slot in the channel to mutate in place, blocking until
+    /// there is a free slot to write to, waiting for at most `timeout`.
+    ///
+    /// This is similar to the [`send_timeout`] method, but, rather than taking a
+    /// message by value to write to the channel, this method reserves a
+    /// writable slot in the channel, and returns a [`SendRef`] that allows
+    /// mutating the slot in place. If the [`Receiver`] end of the channel
+    /// uses the [`Receiver::recv_ref`] method for receiving from the channel,
+    /// this allows allocations for channel messages to be reused in place.
+    ///
+    /// # Errors
+    ///
+    /// - [`Err`]`(`[`SendTimeoutError::Timeout`]`)` if the timeout has elapsed.
+    /// - [`Err`]`(`[`SendTimeoutError::Closed`]`)` if the channel has closed.
+    ///
+    /// # Examples
+    ///
+    /// Sending formatted strings by writing them directly to channel slots,
+    /// in place:
+    ///
+    /// ```
+    /// use thingbuf::mpsc::{blocking, errors::SendTimeoutError};
+    /// use std::{fmt::Write, time::Duration, thread};
+    ///
+    /// let (tx, rx) = blocking::channel::<String>(1);
+    ///
+    /// thread::spawn(move || {
+    ///     thread::sleep(Duration::from_millis(500));
+    ///     let msg = rx.recv_ref().unwrap();
+    ///     println!("{}", msg);
+    ///     thread::sleep(Duration::from_millis(500));
+    /// });
+    ///
+    /// thread::spawn(move || {
+    ///     let mut value = tx.send_ref_timeout(Duration::from_millis(200)).unwrap();
+    ///     write!(value, "hello").expect("writing to a `String` should never fail");
+    ///     thread::sleep(Duration::from_millis(400));
+    ///
+    ///     let mut value = tx.send_ref_timeout(Duration::from_millis(200)).unwrap();
+    ///     write!(value, "world").expect("writing to a `String` should never fail");
+    ///     thread::sleep(Duration::from_millis(400));
+    ///
+    ///     assert_eq!(
+    ///         Err(&SendTimeoutError::Timeout(())),
+    ///         tx.send_ref_timeout(Duration::from_millis(200)).as_deref().map(String::as_str)
+    ///     );
+    /// });
+    /// ```
+    ///
+    /// [`send_timeout`]: Self::send_timeout
+    #[cfg(not(all(test, loom)))]
+    pub fn send_ref_timeout(&self, timeout: Duration) -> Result<SendRef<'_, T>, SendTimeoutError> {
+        send_ref_timeout(
+            &self.inner.core,
+            self.inner.slots.as_ref(),
+            &self.inner.recycle,
+            timeout,
+        )
+    }
+
+    /// Sends a message by value, blocking until there is a free slot to
+    /// write to, for at most `timeout`.
+    ///
+    /// This method takes the message by value, and replaces any previous
+    /// value in the slot. This means that the channel will *not* function
+    /// as an object pool while sending messages with `send_timeout`. This method is
+    /// most appropriate when messages don't own reusable heap allocations,
+    /// or when the [`Receiver`] end of the channel must receive messages by
+    /// moving them out of the channel by value (using the
+    /// [`Receiver::recv`] method). When messages in the channel own
+    /// reusable heap allocations (such as `String`s or `Vec`s), and the
+    /// [`Receiver`] doesn't need to receive them by value, consider using
+    /// [`send_ref_timeout`] instead, to enable allocation reuse.
+    ///
+    ///
+    /// # Errors
+    ///
+    /// - [`Err`]`(`[`SendTimeoutError::Timeout`]`)` if the timeout has elapsed.
+    /// - [`Err`]`(`[`SendTimeoutError::Closed`]`)` if the channel has closed.
+    ///
+    /// # Examples
+    ///
+    /// ```
+    /// use thingbuf::mpsc::{blocking, errors::SendTimeoutError};
+    /// use std::{time::Duration, thread};
+    ///
+    /// let (tx, rx) = blocking::channel(1);
+    ///
+    /// thread::spawn(move || {
+    ///     thread::sleep(Duration::from_millis(500));
+    ///     let msg = rx.recv().unwrap();
+    ///     println!("{}", msg);
+    ///     thread::sleep(Duration::from_millis(500));
+    /// });
+    ///
+    /// thread::spawn(move || {
+    ///     tx.send_timeout(1, Duration::from_millis(200)).unwrap();
+    ///     thread::sleep(Duration::from_millis(400));
+    ///
+    ///     tx.send_timeout(2, Duration::from_millis(200)).unwrap();
+    ///     thread::sleep(Duration::from_millis(400));
+    ///
+    ///     assert_eq!(
+    ///         Err(SendTimeoutError::Timeout(3)),
+    ///         tx.send_timeout(3, Duration::from_millis(200))
+    ///     );
+    /// });
+    /// ```
+    ///
+    /// [`send_ref_timeout`]: Self::send_ref_timeout
+    #[cfg(not(all(test, loom)))]
+    pub fn send_timeout(&self, val: T, timeout: Duration) -> Result<(), SendTimeoutError<T>> {
+        match self.send_ref_timeout(timeout) {
+            Err(e) => Err(e.with_value(val)),
+            Ok(mut slot) => {
+                *slot = val;
+                Ok(())
+            }
+        }
+    }
+
     /// Attempts to reserve a slot in the channel to mutate in place,
     /// without blocking until capacity is available.
     ///
@@ -1149,6 +1390,7 @@ impl<T, R> Receiver<T, R> {
     /// - [`Ok`]`(`[`RecvRef`]`<T>)` if a message was received.
     /// - [`Err`]`(`[`RecvTimeoutError::Timeout`]`)` if the timeout has elapsed.
     /// - [`Err`]`(`[`RecvTimeoutError::Closed`]`)` if the channel has closed.
+    ///
     /// # Examples
     ///
     /// ```
@@ -1454,3 +1696,60 @@ fn send_ref<'a, T, R: Recycle<T>>(
         }
     }
 }
+
+#[cfg(not(all(test, loom)))]
+#[inline]
+fn send_ref_timeout<'a, T, R: Recycle<T>>(
+    core: &'a ChannelCore<Thread>,
+    slots: &'a [Slot<T>],
+    recycle: &'a R,
+    timeout: Duration,
+) -> Result<SendRef<'a, T>, SendTimeoutError> {
+    // fast path: avoid getting the thread and constructing the node if the
+    // slot is immediately ready.
+    match core.try_send_ref(slots, recycle) {
+        Ok(slot) => return Ok(SendRef(slot)),
+        Err(TrySendError::Closed(_)) => return Err(SendTimeoutError::Closed(())),
+        _ => {}
+    }
+
+    let mut waiter = queue::Waiter::new();
+    let mut unqueued = true;
+    let thread = thread::current();
+    let mut boff = Backoff::new();
+    let beginning_park = Instant::now();
+    loop {
+        let node = unsafe {
+            // Safety: in this case, it's totally safe to pin the waiter, as
+            // it is owned uniquely by this function, and it cannot possibly
+            // be moved while this thread is parked.
+            Pin::new_unchecked(&mut waiter)
+        };
+
+        let wait = if unqueued {
+            test_dbg!(core.tx_wait.start_wait(node, &thread))
+        } else {
+            test_dbg!(core.tx_wait.continue_wait(node, &thread))
+        };
+
+        match wait {
+            WaitResult::Closed => return Err(SendTimeoutError::Closed(())),
+            WaitResult::Notified => {
+                boff.spin_yield();
+                match core.try_send_ref(slots.as_ref(), recycle) {
+                    Ok(slot) => return Ok(SendRef(slot)),
+                    Err(TrySendError::Closed(_)) => return Err(SendTimeoutError::Closed(())),
+                    _ => {}
+                }
+            }
+            WaitResult::Wait => {
+                unqueued = false;
+                thread::park_timeout(timeout);
+                let elapsed = beginning_park.elapsed();
+                if elapsed >= timeout {
+                    return Err(SendTimeoutError::Timeout(()));
+                }
+            }
+        }
+    }
+}

+ 91 - 0
src/mpsc/errors.rs

@@ -1,6 +1,28 @@
 //! Errors returned by channels.
 use core::fmt;
 
+/// Error returned by the [`Sender::send_timeout`] or [`Sender::send_ref_timeout`]
+/// (and [`StaticSender::send_timeout`]/[`StaticSender::send_ref_timeout`]) methods
+/// (blocking only).
+///
+/// [`Sender::send_timeout`]: super::blocking::Sender::send_timeout
+/// [`Sender::send_ref_timeout`]: super::blocking::Sender::send_ref_timeout
+/// [`StaticSender::send_timeout`]: super::blocking::StaticSender::send_timeout
+/// [`StaticSender::send_ref_timeout`]: super::blocking::StaticSender::send_ref_timeout
+#[cfg(feature = "std")]
+#[non_exhaustive]
+#[derive(PartialEq, Eq)]
+pub enum SendTimeoutError<T = ()> {
+    /// The data could not be sent on the channel because the channel is
+    /// currently full and sending would require waiting for capacity.
+    Timeout(T),
+    /// The data could not be sent because the [`Receiver`] half of the channel
+    /// has been dropped.
+    ///
+    /// [`Receiver`]: super::Receiver
+    Closed(T),
+}
+
 /// Error returned by the [`Sender::try_send`] or [`Sender::try_send_ref`] (and
 /// [`StaticSender::try_send`]/[`StaticSender::try_send_ref`]) methods.
 ///
@@ -88,6 +110,75 @@ impl<T> fmt::Display for Closed<T> {
 #[cfg(feature = "std")]
 impl<T> std::error::Error for Closed<T> {}
 
+// === impl SendTimeoutError ===
+
+#[cfg(feature = "std")]
+impl SendTimeoutError {
+    pub(crate) fn with_value<T>(self, value: T) -> SendTimeoutError<T> {
+        match self {
+            Self::Timeout(()) => SendTimeoutError::Timeout(value),
+            Self::Closed(()) => SendTimeoutError::Closed(value),
+        }
+    }
+}
+
+#[cfg(feature = "std")]
+impl<T> SendTimeoutError<T> {
+    /// Returns `true` if this error was returned because the channel is still
+    /// full after the timeout has elapsed.
+    pub fn is_timeout(&self) -> bool {
+        matches!(self, Self::Timeout(_))
+    }
+
+    /// Returns `true` if this error was returned because the channel has closed
+    /// (e.g. the [`Receiver`] end has been dropped).
+    ///
+    /// If this returns `true`, no future [`try_send`] or [`send`] operation on
+    /// this channel will succeed.
+    ///
+    /// [`Receiver`]: super::blocking::Receiver
+    /// [`try_send`]: super::blocking::Sender::try_send
+    /// [`send`]: super::blocking::Sender::send
+    /// [`Receiver`]: super::blocking::Receiver
+    pub fn is_closed(&self) -> bool {
+        matches!(self, Self::Timeout(_))
+    }
+
+    /// Unwraps the inner `T` value held by this error.
+    ///
+    /// This method allows recovering the original message when sending to a
+    /// channel has failed.
+    pub fn into_inner(self) -> T {
+        match self {
+            Self::Timeout(val) => val,
+            Self::Closed(val) => val,
+        }
+    }
+}
+
+#[cfg(feature = "std")]
+impl<T> fmt::Debug for SendTimeoutError<T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(match self {
+            Self::Timeout(_) => "SendTimeoutError::Timeout(..)",
+            Self::Closed(_) => "SendTimeoutError::Closed(..)",
+        })
+    }
+}
+
+#[cfg(feature = "std")]
+impl<T> fmt::Display for SendTimeoutError<T> {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(match self {
+            Self::Timeout(_) => "timed out waiting for channel capacity",
+            Self::Closed(_) => "channel closed",
+        })
+    }
+}
+
+#[cfg(feature = "std")]
+impl<T> std::error::Error for SendTimeoutError<T> {}
+
 // === impl TrySendError ===
 
 impl TrySendError {