Browse Source

more flexible async implementation

Román Cárdenas 1 year ago
parent
commit
5b012c0978

+ 1 - 0
riscv-peripheral/Cargo.toml

@@ -15,6 +15,7 @@ riscv-pac = { path = "../riscv-pac", version = "0.1.0" }
 aclint-hal-async = ["embedded-hal-async"]
 
 [package.metadata.docs.rs]
+all-features = true
 default-target = "riscv64imac-unknown-none-elf"
 targets = [
     "riscv32i-unknown-none-elf", "riscv32imc-unknown-none-elf", "riscv32imac-unknown-none-elf",

+ 126 - 105
riscv-peripheral/src/hal_async/aclint.rs

@@ -1,4 +1,22 @@
 //! Asynchronous delay implementation for the (A)CLINT peripheral.
+//!
+//! # Note
+//!
+//! The asynchronous delay implementation for the (A)CLINT peripheral relies on the machine-level timer interrupts.
+//! Therefore, it needs to schedule the machine-level timer interrupts via the [`MTIMECMP`] register assigned to the current HART.
+//! Thus, the [`Delay`] instance must be created on the same HART that is used to call the asynchronous delay methods.
+//!
+//! # Requirements
+//!
+//! The following `extern "Rust"` functions must be implemented:
+//!
+//! - `fn _riscv_peripheral_aclint_mtimer(hart_id: usize) -> MTIMER`: This function returns the `MTIMER` register for the given HART ID.
+//! - `fn _riscv_peripheral_aclint_push_timer(t: Timer) -> Result<(), Timer>`: This function pushes a new timer to a timer queue assigned to the given HART ID.
+//! If it fails (e.g., the timer queue is full), it returns back the timer that failed to be pushed.
+//! The logic of timer queues are application-specific and are not provided by this crate.
+//! - `fn _riscv_peripheral_aclint_wake_timers(hart_id: usize, current_tick: u64) -> Option<u64>`:
+//! This function pops all the expired timers from a timer queue assigned to the given HART ID and wakes their associated wakers.
+//! The function returns the next [`MTIME`] tick at which the next timer expires. If the queue is empty, it returns `None`.
 
 use crate::aclint::mtimer::{MTIME, MTIMECMP, MTIMER};
 pub use crate::hal_async::delay::DelayNs;
@@ -21,101 +39,76 @@ extern "Rust" {
     /// Tries to push a new timer to the timer queue assigned to the given HART ID.
     /// If it fails (e.g., the timer queue is full), it returns back the timer that failed to be pushed.
     ///
-    /// # Note
-    ///
-    /// the [`Delay`] reference allows to access the `MTIME` and `MTIMECMP` registers,
-    /// as well as handy information such as the HART ID or the clock frequency of the `MTIMER` peripheral.
-    ///
     /// # Safety
     ///
     /// Do not call this function directly. It is only meant to be called by [`DelayAsync`].
-    fn _riscv_peripheral_push_timer(hart_id: usize, delay: &Delay, t: Timer) -> Result<(), Timer>;
+    fn _riscv_peripheral_aclint_push_timer(t: Timer) -> Result<(), Timer>;
 
-    /// Pops a expired timer from the timer queue assigned to the given HART ID.
-    /// If the queue is empty, it returns `Err(None)`.
+    /// Pops all the expired timers from the timer queue assigned to the given HART ID and wakes their associated wakers.
+    /// Once it is done, if the queue is empty, it returns `None`.
     /// Alternatively, if the queue is not empty but the earliest timer has not expired yet,
-    /// it returns `Err(Some(next_expires))` where `next_expires` is the tick at which this timer expires.
+    /// it returns `Some(next_expires)` where `next_expires` is the tick at which this timer expires.
     ///
     /// # Safety
     ///
-    /// It is extremely important that this function only returns a timer that has expired.
-    /// Otherwise, the timer will be lost and the waker will never be called.
-    ///
     /// Do not call this function directly. It is only meant to be called by [`MachineExternal`] and [`DelayAsync`].
-    fn _riscv_peripheral_pop_timer(hart_id: usize, current_tick: u64)
-        -> Result<Timer, Option<u64>>;
+    fn _riscv_peripheral_aclint_wake_timers(hart_id: usize, current_tick: u64) -> Option<u64>;
 }
 
-/// Machine-level timer interrupt handler.
-/// This handler is triggered whenever the `MTIME` register reaches the value of the `MTIMECMP` register.
+/// Machine-level timer interrupt handler. This handler is triggered whenever the `MTIME`
+/// register reaches the value of the `MTIMECMP` register of the current HART.
 #[no_mangle]
 #[allow(non_snake_case)]
 fn MachineExternal() {
+    // recover the MTIME and MTIMECMP registers for the current HART
     let hart_id = riscv::register::mhartid::read();
     let mtimer = unsafe { _riscv_peripheral_aclint_mtimer(hart_id) };
     let (mtime, mtimercmp) = (mtimer.mtime, mtimer.mtimecmp_mhartid());
-    schedule_machine_external(hart_id, mtime, mtimercmp);
+    // schedule the next machine timer interrupt
+    schedule_machine_timer(hart_id, mtime, mtimercmp);
 }
 
-fn schedule_machine_external(hart_id: usize, mtime: MTIME, mtimercmp: MTIMECMP) {
+/// Schedules the next machine timer interrupt for the given HART ID according to the timer queue.
+fn schedule_machine_timer(hart_id: usize, mtime: MTIME, mtimercmp: MTIMECMP) {
     unsafe { riscv::register::mie::clear_mtimer() }; // disable machine timer interrupts to avoid reentrancy
-    loop {
-        let current_tick = mtime.read();
-        let timer = unsafe { _riscv_peripheral_pop_timer(hart_id, current_tick) };
-        match timer {
-            Ok(timer) => {
-                debug_assert!(timer.expires() <= current_tick);
-                timer.wake();
-            }
-            Err(e) => {
-                if let Some(next_expires) = e {
-                    debug_assert!(next_expires > current_tick);
-                    mtimercmp.write(next_expires); // schedule next interrupt at next_expires
-                    unsafe { riscv::register::mie::set_mtimer() }; // enable machine timer interrupts again
-                } else {
-                    mtimercmp.write(u64::MAX); // write max to clear and "disable" the interrupt
-                }
-                break;
-            }
-        }
+    let current_tick = mtime.read();
+    if let Some(next_expires) =
+        unsafe { _riscv_peripheral_aclint_wake_timers(hart_id, current_tick) }
+    {
+        debug_assert!(next_expires > current_tick);
+        mtimercmp.write(next_expires); // schedule next interrupt at next_expires
+        unsafe { riscv::register::mie::set_mtimer() }; // enable machine timer interrupts again if necessary
     }
 }
 
 /// Asynchronous delay implementation for (A)CLINT peripherals.
+///
+/// # Note
+///
+/// The asynchronous delay implementation for (A)CLINT peripherals relies on the machine-level timer interrupts.
+/// Therefore, it needs to schedule the machine-level timer interrupts via the [`MTIMECMP`] register assigned to the current HART.
+/// Thus, the [`Delay`] instance must be created on the same HART that is used to call the asynchronous delay methods.
+/// Additionally, the rest of the application must not modify the [`MTIMER`] register assigned to the current HART.
 #[derive(Clone)]
 pub struct Delay {
-    mtime: MTIME,
     hart_id: usize,
-    mtimecmp: MTIMECMP,
     freq: usize,
+    mtime: MTIME,
+    mtimecmp: MTIMECMP,
 }
 
 impl Delay {
-    /// Creates a new `Delay` instance.
-    #[inline]
-    pub fn new<H: riscv_pac::HartIdNumber>(mtimer: MTIMER, hart_id: H, freq: usize) -> Self {
-        Self {
-            mtime: mtimer.mtime,
-            hart_id: hart_id.number() as _,
-            mtimecmp: mtimer.mtimecmp(hart_id),
-            freq,
-        }
-    }
-
     /// Creates a new `Delay` instance for the current HART.
-    /// This function determines the current HART ID by reading the [`riscv::register::mhartid`] CSR.
-    ///
-    /// # Note
-    ///
-    /// This function can only be used in M-mode. For S-mode, use [`Delay::new_mhartid`] instead.
     #[inline]
-    pub fn new_mhartid(mtimer: MTIMER, freq: usize) -> Self {
+    pub fn new(freq: usize) -> Self {
         let hart_id = riscv::register::mhartid::read();
+        let mtimer = unsafe { _riscv_peripheral_aclint_mtimer(hart_id) };
+        let (mtime, mtimecmp) = (mtimer.mtime, mtimer.mtimecmp_mhartid());
         Self {
-            mtime: mtimer.mtime,
             hart_id,
-            mtimecmp: mtimer.mtimecmp_mhartid(),
             freq,
+            mtime,
+            mtimecmp,
         }
     }
 
@@ -130,29 +123,37 @@ impl Delay {
     pub fn set_freq(&mut self, freq: usize) {
         self.freq = freq;
     }
+}
 
-    /// Returns the `MTIME` register.
+impl DelayNs for Delay {
     #[inline]
-    pub const fn get_mtime(&self) -> MTIME {
-        self.mtime
+    async fn delay_ns(&mut self, ns: u32) {
+        let n_ticks = ns as u64 * self.get_freq() as u64 / 1_000_000_000;
+        DelayAsync::new(self, n_ticks).await;
     }
 
-    /// Returns the `MTIMECMP` register.
     #[inline]
-    pub const fn get_mtimecmp(&self) -> MTIMECMP {
-        self.mtimecmp
+    async fn delay_us(&mut self, us: u32) {
+        let n_ticks = us as u64 * self.get_freq() as u64 / 1_000_000;
+        DelayAsync::new(self, n_ticks).await;
     }
 
-    /// Returns the hart ID.
     #[inline]
-    pub const fn get_hart_id(&self) -> usize {
-        self.hart_id
+    async fn delay_ms(&mut self, ms: u32) {
+        let n_ticks = ms as u64 * self.get_freq() as u64 / 1_000;
+        DelayAsync::new(self, n_ticks).await;
     }
 }
 
 /// Timer queue entry.
+/// When pushed to the timer queue via the `_riscv_peripheral_aclint_push_timer` function,
+/// this entry provides the necessary information to adapt it to the timer queue implementation.
 #[derive(Debug)]
 pub struct Timer {
+    hart_id: usize,
+    freq: usize,
+    mtime: MTIME,
+    mtimecmp: MTIMECMP,
     expires: u64,
     waker: Waker,
 }
@@ -160,8 +161,46 @@ pub struct Timer {
 impl Timer {
     /// Creates a new timer queue entry.
     #[inline]
-    pub fn new(expires: u64, waker: Waker) -> Self {
-        Self { expires, waker }
+    const fn new(
+        hart_id: usize,
+        freq: usize,
+        mtime: MTIME,
+        mtimecmp: MTIMECMP,
+        expires: u64,
+        waker: Waker,
+    ) -> Self {
+        Self {
+            hart_id,
+            freq,
+            mtime,
+            mtimecmp,
+            expires,
+            waker,
+        }
+    }
+
+    /// Returns the HART ID associated with this timer.
+    #[inline]
+    pub const fn hart_id(&self) -> usize {
+        self.hart_id
+    }
+
+    /// Returns the frequency of the [`MTIME`] register associated with this timer.
+    #[inline]
+    pub const fn freq(&self) -> usize {
+        self.freq
+    }
+
+    /// Returns the [`MTIME`] register associated with this timer.
+    #[inline]
+    pub const fn mtime(&self) -> MTIME {
+        self.mtime
+    }
+
+    /// Returns the [`MTIMECMP`] register associated with this timer.
+    #[inline]
+    pub const fn mtimecmp(&self) -> MTIMECMP {
+        self.mtimecmp
     }
 
     /// Returns the tick at which the timer expires.
@@ -170,16 +209,16 @@ impl Timer {
         self.expires
     }
 
-    /// Wakes the waker associated with this timer.
+    /// Returns the waker associated with this timer.
     #[inline]
-    pub fn wake(&self) {
-        self.waker.wake_by_ref();
+    pub fn waker(&self) -> Waker {
+        self.waker.clone()
     }
 }
 
 impl PartialEq for Timer {
     fn eq(&self, other: &Self) -> bool {
-        self.expires == other.expires
+        self.hart_id == other.hart_id && self.freq == other.freq && self.expires == other.expires
     }
 }
 
@@ -197,14 +236,14 @@ impl PartialOrd for Timer {
     }
 }
 
-struct DelayAsync {
-    delay: Delay,
+struct DelayAsync<'a> {
+    delay: &'a Delay,
     expires: u64,
     pushed: bool,
 }
 
-impl DelayAsync {
-    pub fn new(delay: Delay, n_ticks: u64) -> Self {
+impl<'a> DelayAsync<'a> {
+    pub fn new(delay: &'a Delay, n_ticks: u64) -> Self {
         let t0 = delay.mtime.read();
         let expires = t0.wrapping_add(n_ticks);
         Self {
@@ -215,7 +254,7 @@ impl DelayAsync {
     }
 }
 
-impl Future for DelayAsync {
+impl<'a> Future for DelayAsync<'a> {
     type Output = ();
 
     #[inline]
@@ -224,17 +263,19 @@ impl Future for DelayAsync {
             if !self.pushed {
                 // we only push the timer to the queue the first time we poll
                 self.pushed = true;
-                let timer = Timer::new(self.expires, cx.waker().clone());
-                unsafe {
-                    _riscv_peripheral_push_timer(self.delay.hart_id, &self.delay, timer)
-                        .expect("timer queue is full");
-                };
-                // we also need to schedule the interrupt if the timer we just pushed is the earliest one
-                schedule_machine_external(
+                let timer = Timer::new(
                     self.delay.hart_id,
+                    self.delay.freq,
                     self.delay.mtime,
                     self.delay.mtimecmp,
+                    self.expires,
+                    cx.waker().clone(),
                 );
+                unsafe {
+                    _riscv_peripheral_aclint_push_timer(timer).expect("timer queue is full");
+                };
+                // we also need to reschedule the machine timer interrupt
+                schedule_machine_timer(self.delay.hart_id, self.delay.mtime, self.delay.mtimecmp);
             }
             Poll::Pending
         } else {
@@ -242,23 +283,3 @@ impl Future for DelayAsync {
         }
     }
 }
-
-impl DelayNs for Delay {
-    #[inline]
-    async fn delay_ns(&mut self, ns: u32) {
-        let n_ticks = ns as u64 * self.get_freq() as u64 / 1_000_000_000;
-        DelayAsync::new(self.clone(), n_ticks).await;
-    }
-
-    #[inline]
-    async fn delay_us(&mut self, us: u32) {
-        let n_ticks = us as u64 * self.get_freq() as u64 / 1_000_000;
-        DelayAsync::new(self.clone(), n_ticks).await;
-    }
-
-    #[inline]
-    async fn delay_ms(&mut self, ms: u32) {
-        let n_ticks = ms as u64 * self.get_freq() as u64 / 1_000;
-        DelayAsync::new(self.clone(), n_ticks).await;
-    }
-}

+ 6 - 1
riscv-peripheral/src/lib.rs

@@ -1,4 +1,9 @@
-//! Standard RISC-V peripherals for embedded systems written in Rust
+//! Standard RISC-V peripherals for embedded systems written in Rust.
+//!
+//! ## Features
+//!
+//! - `aclint-hal-async`: enables the [`hal_async::delay::DelayNs`] implementation for the ACLINT peripheral.
+//! This feature relies on external functions that must be provided by the user. See [`hal_async::aclint`] for more information.
 
 #![deny(missing_docs)]
 #![no_std]

+ 6 - 15
riscv-peripheral/src/macros.rs

@@ -216,27 +216,18 @@ macro_rules! clint_codegen {
     (async_delay, $($tail:tt)*) => {
         impl CLINT {
             /// Asynchronous delay implementation for CLINT peripherals.
-            /// You must specify which HART ID you want to use for the delay.
-            ///
-            /// # Note
-            ///
-            /// You must export the `riscv_peripheral::hal_async::delay::DelayNs` trait in order to use delay methods.
-            #[inline]
-            pub fn async_delay<H: $crate::plic::HartIdNumber>(hart_id: H) -> $crate::hal_async::aclint::Delay {
-                $crate::hal_async::aclint::Delay::new(Self::mtimer(), hart_id, Self::freq())
-            }
-
-            /// Asynchronous delay implementation for CLINT peripherals.
-            /// This function determines the current HART ID by reading the [`riscv::register::mhartid`] CSR.
             ///
             /// # Note
             ///
             /// You must export the `riscv_peripheral::hal_async::delay::DelayNs` trait in order to use delay methods.
             ///
-            /// This function can only be used in M-mode. For S-mode, use [`CLINT::async_delay`] instead.
+            /// This implementation relies on the machine-level timer interrupts to wake futures.
+            /// Therefore, it needs to schedule the machine-level timer interrupts via the `MTIMECMP` register assigned to the current HART.
+            /// Thus, the `Delay` instance must be created on the same HART that is used to call the asynchronous delay methods.
+            /// Additionally, the rest of the application must not modify the `MTIMER` register assigned to the current HART.
             #[inline]
-            pub fn async_delay_mhartid() -> $crate::hal_async::aclint::Delay {
-                $crate::hal_async::aclint::Delay::new_mhartid(Self::mtimer(), Self::freq())
+            pub fn async_delay() -> $crate::hal_async::aclint::Delay {
+                $crate::hal_async::aclint::Delay::new(Self::freq())
             }
         }
         $crate::clint_codegen!($($tail)*);