Browse Source

spec: binary: implement full-range constructor (`HartMask::all`) and iterator for `HartMask`

Signed-off-by: Zhouqi Jiang <luojia@hust.edu.cn>
Zhouqi Jiang 4 months ago
parent
commit
9ec3a8d2aa
1 changed files with 104 additions and 23 deletions
  1. 104 23
      sbi-spec/src/binary.rs

+ 104 - 23
sbi-spec/src/binary.rs

@@ -817,6 +817,20 @@ impl HartMask {
         }
     }
 
+    /// Construct a [HartMask] that selects all available harts on the current environment.
+    ///
+    /// According to the RISC-V SBI Specification, `hart_mask_base` can be set to `-1` (i.e. `usize::MAX`)
+    /// to indicate that `hart_mask` shall be ignored and all available harts must be considered.
+    /// In case of this function in the `sbi-spec` crate, we fill in `usize::MAX` in `hart_mask_base`
+    /// parameter to match the RISC-V SBI standard, while choosing 0 as the ignored `hart_mask` value.
+    #[inline]
+    pub const fn all() -> Self {
+        Self {
+            hart_mask: 0,
+            hart_mask_base: usize::MAX,
+        }
+    }
+
     /// Gets the special value for ignoring the `mask` parameter.
     #[inline]
     pub const fn ignore_mask(&self) -> usize {
@@ -874,8 +888,10 @@ impl HartMask {
     #[inline]
     pub const fn iter(&self) -> HartIds {
         HartIds {
-            unvisited_mask: self.hart_mask,
-            base: self.hart_mask_base,
+            inner: match self.hart_mask_base {
+                Self::IGNORE_MASK => UnvisitedMask::Range(0, usize::MAX),
+                _ => UnvisitedMask::MaskBase(self.hart_mask, self.hart_mask_base),
+            },
         }
     }
 }
@@ -896,8 +912,13 @@ impl IntoIterator for HartMask {
 /// It will iterate hart id from low to high.
 #[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
 pub struct HartIds {
-    unvisited_mask: usize,
-    base: usize,
+    inner: UnvisitedMask,
+}
+
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+enum UnvisitedMask {
+    MaskBase(usize, usize),
+    Range(usize, usize),
 }
 
 impl Iterator for HartIds {
@@ -905,35 +926,64 @@ impl Iterator for HartIds {
 
     #[inline]
     fn next(&mut self) -> Option<Self::Item> {
-        if self.unvisited_mask == 0 {
-            None
-        } else {
-            let low_bit = self.unvisited_mask.trailing_zeros();
-            let hart_id = usize::try_from(low_bit).unwrap() + self.base;
-            self.unvisited_mask &= !(1usize << low_bit);
-
-            Some(hart_id)
+        match &mut self.inner {
+            UnvisitedMask::MaskBase(0, _base) => None,
+            UnvisitedMask::MaskBase(unvisited_mask, base) => {
+                let low_bit = unvisited_mask.trailing_zeros();
+                let hart_id = usize::try_from(low_bit).unwrap() + *base;
+                *unvisited_mask &= !(1usize << low_bit);
+                Some(hart_id)
+            }
+            UnvisitedMask::Range(from, to) => {
+                assert!(from <= to);
+                if *from < *to {
+                    let ans = *from;
+                    *from += 1;
+                    Some(ans)
+                } else {
+                    None
+                }
+            }
         }
     }
 
     #[inline]
     fn size_hint(&self) -> (usize, Option<usize>) {
-        let exact_popcnt = usize::try_from(self.unvisited_mask.count_ones()).unwrap();
-        (exact_popcnt, Some(exact_popcnt))
+        match self.inner {
+            UnvisitedMask::MaskBase(unvisited_mask, _base) => {
+                let exact_popcnt = usize::try_from(unvisited_mask.count_ones()).unwrap();
+                (exact_popcnt, Some(exact_popcnt))
+            }
+            UnvisitedMask::Range(from, to) => {
+                assert!(from <= to);
+                let exact_num_harts = to - from;
+                (exact_num_harts, Some(exact_num_harts))
+            }
+        }
     }
 }
 
 impl DoubleEndedIterator for HartIds {
     #[inline]
     fn next_back(&mut self) -> Option<Self::Item> {
-        if self.unvisited_mask == 0 {
-            None
-        } else {
-            let high_bit = self.unvisited_mask.leading_zeros();
-            let hart_id = usize::try_from(usize::BITS - high_bit - 1).unwrap() + self.base;
-            self.unvisited_mask &= !(1usize << (usize::BITS - high_bit - 1));
-
-            Some(hart_id)
+        match &mut self.inner {
+            UnvisitedMask::MaskBase(0, _base) => None,
+            UnvisitedMask::MaskBase(unvisited_mask, base) => {
+                let high_bit = unvisited_mask.leading_zeros();
+                let hart_id = usize::try_from(usize::BITS - high_bit - 1).unwrap() + *base;
+                *unvisited_mask &= !(1usize << (usize::BITS - high_bit - 1));
+                Some(hart_id)
+            }
+            UnvisitedMask::Range(from, to) => {
+                assert!(from <= to);
+                if *from < *to {
+                    let ans = *to;
+                    *to -= 1;
+                    Some(ans)
+                } else {
+                    None
+                }
+            }
         }
     }
 }
@@ -1343,7 +1393,38 @@ mod tests {
         let range = 1000..(1000 + usize::BITS as usize);
         assert!(all_mask_bits_set.iter().eq(range));
 
-        // TODO: full-range hart mask
+        let all_harts = HartMask::all();
+        let mut iter = all_harts.iter();
+        assert_eq!(iter.size_hint(), (usize::MAX, Some(usize::MAX)));
+        // Don't use `Iterator::eq` here; it would literally run `Iterator::try_for_each` from 0 to usize::MAX
+        // which will cost us forever to run the test.
+        assert_eq!(iter.next(), Some(0));
+        assert_eq!(iter.size_hint(), (usize::MAX - 1, Some(usize::MAX - 1)));
+        assert_eq!(iter.next(), Some(1));
+        assert_eq!(iter.next(), Some(2));
+        let mut iter = iter.skip(500);
+        assert_eq!(iter.next(), Some(503));
+        assert_eq!(iter.size_hint(), (usize::MAX - 504, Some(usize::MAX - 504)));
+        assert_eq!(iter.next_back(), Some(usize::MAX));
+        assert_eq!(iter.next_back(), Some(usize::MAX - 1));
+        assert_eq!(iter.size_hint(), (usize::MAX - 506, Some(usize::MAX - 506)));
+
+        // A common usage of `HartMask::all`, we assume that this platform filters out hart 0..=3.
+        let environment_available_hart_ids = 4..128;
+        // `iter` contains 64..=usize::MAX.
+        let hart_mask_iter = all_harts.iter().skip(64);
+        let mut iter_peekable = hart_mask_iter.peekable();
+        let filtered_iter = environment_available_hart_ids.filter(|&x| {
+            while let Some(&y) = iter_peekable.peek() {
+                match y.cmp(&x) {
+                    core::cmp::Ordering::Equal => return true,
+                    core::cmp::Ordering::Greater => break,
+                    core::cmp::Ordering::Less => iter_peekable.next(),
+                };
+            }
+            false
+        });
+        assert!(filtered_iter.eq(64..128));
     }
 
     #[test]