|
@@ -817,6 +817,20 @@ impl HartMask {
|
|
|
}
|
|
|
}
|
|
|
|
|
|
+ /// Construct a [HartMask] that selects all available harts on the current environment.
|
|
|
+ ///
|
|
|
+ /// According to the RISC-V SBI Specification, `hart_mask_base` can be set to `-1` (i.e. `usize::MAX`)
|
|
|
+ /// to indicate that `hart_mask` shall be ignored and all available harts must be considered.
|
|
|
+ /// In case of this function in the `sbi-spec` crate, we fill in `usize::MAX` in `hart_mask_base`
|
|
|
+ /// parameter to match the RISC-V SBI standard, while choosing 0 as the ignored `hart_mask` value.
|
|
|
+ #[inline]
|
|
|
+ pub const fn all() -> Self {
|
|
|
+ Self {
|
|
|
+ hart_mask: 0,
|
|
|
+ hart_mask_base: usize::MAX,
|
|
|
+ }
|
|
|
+ }
|
|
|
+
|
|
|
/// Gets the special value for ignoring the `mask` parameter.
|
|
|
#[inline]
|
|
|
pub const fn ignore_mask(&self) -> usize {
|
|
@@ -874,8 +888,10 @@ impl HartMask {
|
|
|
#[inline]
|
|
|
pub const fn iter(&self) -> HartIds {
|
|
|
HartIds {
|
|
|
- unvisited_mask: self.hart_mask,
|
|
|
- base: self.hart_mask_base,
|
|
|
+ inner: match self.hart_mask_base {
|
|
|
+ Self::IGNORE_MASK => UnvisitedMask::Range(0, usize::MAX),
|
|
|
+ _ => UnvisitedMask::MaskBase(self.hart_mask, self.hart_mask_base),
|
|
|
+ },
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -896,8 +912,13 @@ impl IntoIterator for HartMask {
|
|
|
/// It will iterate hart id from low to high.
|
|
|
#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
|
|
|
pub struct HartIds {
|
|
|
- unvisited_mask: usize,
|
|
|
- base: usize,
|
|
|
+ inner: UnvisitedMask,
|
|
|
+}
|
|
|
+
|
|
|
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
|
|
|
+enum UnvisitedMask {
|
|
|
+ MaskBase(usize, usize),
|
|
|
+ Range(usize, usize),
|
|
|
}
|
|
|
|
|
|
impl Iterator for HartIds {
|
|
@@ -905,35 +926,64 @@ impl Iterator for HartIds {
|
|
|
|
|
|
#[inline]
|
|
|
fn next(&mut self) -> Option<Self::Item> {
|
|
|
- if self.unvisited_mask == 0 {
|
|
|
- None
|
|
|
- } else {
|
|
|
- let low_bit = self.unvisited_mask.trailing_zeros();
|
|
|
- let hart_id = usize::try_from(low_bit).unwrap() + self.base;
|
|
|
- self.unvisited_mask &= !(1usize << low_bit);
|
|
|
-
|
|
|
- Some(hart_id)
|
|
|
+ match &mut self.inner {
|
|
|
+ UnvisitedMask::MaskBase(0, _base) => None,
|
|
|
+ UnvisitedMask::MaskBase(unvisited_mask, base) => {
|
|
|
+ let low_bit = unvisited_mask.trailing_zeros();
|
|
|
+ let hart_id = usize::try_from(low_bit).unwrap() + *base;
|
|
|
+ *unvisited_mask &= !(1usize << low_bit);
|
|
|
+ Some(hart_id)
|
|
|
+ }
|
|
|
+ UnvisitedMask::Range(from, to) => {
|
|
|
+ assert!(from <= to);
|
|
|
+ if *from < *to {
|
|
|
+ let ans = *from;
|
|
|
+ *from += 1;
|
|
|
+ Some(ans)
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
#[inline]
|
|
|
fn size_hint(&self) -> (usize, Option<usize>) {
|
|
|
- let exact_popcnt = usize::try_from(self.unvisited_mask.count_ones()).unwrap();
|
|
|
- (exact_popcnt, Some(exact_popcnt))
|
|
|
+ match self.inner {
|
|
|
+ UnvisitedMask::MaskBase(unvisited_mask, _base) => {
|
|
|
+ let exact_popcnt = usize::try_from(unvisited_mask.count_ones()).unwrap();
|
|
|
+ (exact_popcnt, Some(exact_popcnt))
|
|
|
+ }
|
|
|
+ UnvisitedMask::Range(from, to) => {
|
|
|
+ assert!(from <= to);
|
|
|
+ let exact_num_harts = to - from;
|
|
|
+ (exact_num_harts, Some(exact_num_harts))
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
|
|
|
impl DoubleEndedIterator for HartIds {
|
|
|
#[inline]
|
|
|
fn next_back(&mut self) -> Option<Self::Item> {
|
|
|
- if self.unvisited_mask == 0 {
|
|
|
- None
|
|
|
- } else {
|
|
|
- let high_bit = self.unvisited_mask.leading_zeros();
|
|
|
- let hart_id = usize::try_from(usize::BITS - high_bit - 1).unwrap() + self.base;
|
|
|
- self.unvisited_mask &= !(1usize << (usize::BITS - high_bit - 1));
|
|
|
-
|
|
|
- Some(hart_id)
|
|
|
+ match &mut self.inner {
|
|
|
+ UnvisitedMask::MaskBase(0, _base) => None,
|
|
|
+ UnvisitedMask::MaskBase(unvisited_mask, base) => {
|
|
|
+ let high_bit = unvisited_mask.leading_zeros();
|
|
|
+ let hart_id = usize::try_from(usize::BITS - high_bit - 1).unwrap() + *base;
|
|
|
+ *unvisited_mask &= !(1usize << (usize::BITS - high_bit - 1));
|
|
|
+ Some(hart_id)
|
|
|
+ }
|
|
|
+ UnvisitedMask::Range(from, to) => {
|
|
|
+ assert!(from <= to);
|
|
|
+ if *from < *to {
|
|
|
+ let ans = *to;
|
|
|
+ *to -= 1;
|
|
|
+ Some(ans)
|
|
|
+ } else {
|
|
|
+ None
|
|
|
+ }
|
|
|
+ }
|
|
|
}
|
|
|
}
|
|
|
}
|
|
@@ -1343,7 +1393,38 @@ mod tests {
|
|
|
let range = 1000..(1000 + usize::BITS as usize);
|
|
|
assert!(all_mask_bits_set.iter().eq(range));
|
|
|
|
|
|
- // TODO: full-range hart mask
|
|
|
+ let all_harts = HartMask::all();
|
|
|
+ let mut iter = all_harts.iter();
|
|
|
+ assert_eq!(iter.size_hint(), (usize::MAX, Some(usize::MAX)));
|
|
|
+ // Don't use `Iterator::eq` here; it would literally run `Iterator::try_for_each` from 0 to usize::MAX
|
|
|
+ // which will cost us forever to run the test.
|
|
|
+ assert_eq!(iter.next(), Some(0));
|
|
|
+ assert_eq!(iter.size_hint(), (usize::MAX - 1, Some(usize::MAX - 1)));
|
|
|
+ assert_eq!(iter.next(), Some(1));
|
|
|
+ assert_eq!(iter.next(), Some(2));
|
|
|
+ let mut iter = iter.skip(500);
|
|
|
+ assert_eq!(iter.next(), Some(503));
|
|
|
+ assert_eq!(iter.size_hint(), (usize::MAX - 504, Some(usize::MAX - 504)));
|
|
|
+ assert_eq!(iter.next_back(), Some(usize::MAX));
|
|
|
+ assert_eq!(iter.next_back(), Some(usize::MAX - 1));
|
|
|
+ assert_eq!(iter.size_hint(), (usize::MAX - 506, Some(usize::MAX - 506)));
|
|
|
+
|
|
|
+ // A common usage of `HartMask::all`, we assume that this platform filters out hart 0..=3.
|
|
|
+ let environment_available_hart_ids = 4..128;
|
|
|
+ // `iter` contains 64..=usize::MAX.
|
|
|
+ let hart_mask_iter = all_harts.iter().skip(64);
|
|
|
+ let mut iter_peekable = hart_mask_iter.peekable();
|
|
|
+ let filtered_iter = environment_available_hart_ids.filter(|&x| {
|
|
|
+ while let Some(&y) = iter_peekable.peek() {
|
|
|
+ match y.cmp(&x) {
|
|
|
+ core::cmp::Ordering::Equal => return true,
|
|
|
+ core::cmp::Ordering::Greater => break,
|
|
|
+ core::cmp::Ordering::Less => iter_peekable.next(),
|
|
|
+ };
|
|
|
+ }
|
|
|
+ false
|
|
|
+ });
|
|
|
+ assert!(filtered_iter.eq(64..128));
|
|
|
}
|
|
|
|
|
|
#[test]
|