Quellcode durchsuchen

spec: binary: add iterator tests over HartIds

Signed-off-by: Zhouqi Jiang <luojia@hust.edu.cn>
Zhouqi Jiang vor 4 Monaten
Ursprung
Commit
f563672f4e
2 geänderte Dateien mit 235 neuen und 39 gelöschten Zeilen
  1. 1 0
      sbi-spec/CHANGELOG.md
  2. 234 39
      sbi-spec/src/binary.rs

+ 1 - 0
sbi-spec/CHANGELOG.md

@@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
 - pmu: add config flags with bitflags in chapter 11
 - fwft: add support for FWFT extension in chapter 18
 - sse: add support for Supervisor Software Events Extension in chapter 17
+- binary: add `HartIds` structure for an iterator over `HartMask`
 
 ### Modified
 

+ 234 - 39
sbi-spec/src/binary.rs

@@ -804,22 +804,6 @@ pub struct HartMask {
     hart_mask_base: usize,
 }
 
-/// Iteration for HartMask, from low to high.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-pub struct HartIds {
-    inner: HartMask,
-    visited_mask: usize,
-}
-
-/// Error of mask modification.
-#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
-pub enum MaskError {
-    /// This mask has been ignored.
-    Ignored,
-    /// Request bit is invalid.
-    InvalidBit,
-}
-
 impl HartMask {
     /// Special value to ignore the `mask`, and consider all `bit`s as set.
     pub const IGNORE_MASK: usize = usize::MAX;
@@ -856,10 +840,11 @@ impl HartMask {
         )
     }
 
-    /// Add a hart id to this [HartMark].
-    /// Returns error when hart_id is invalid.
+    /// Insert a hart id into this [HartMask].
+    ///
+    /// Returns error when `hart_id` is invalid.
     #[inline]
-    pub fn insert(&mut self, hart_id: usize) -> Result<(), MaskError> {
+    pub const fn insert(&mut self, hart_id: usize) -> Result<(), MaskError> {
         if self.hart_mask_base == Self::IGNORE_MASK {
             Ok(())
         } else if valid_bit(self.hart_mask_base, hart_id) {
@@ -870,10 +855,11 @@ impl HartMask {
         }
     }
 
-    /// Remove a hart id from this [HartMark].
-    /// Returns error when hart_id is invalid, or it has been ignored.
+    /// Remove a hart id from this [HartMask].
+    ///
+    /// Returns error when `hart_id` is invalid, or it has been ignored.
     #[inline]
-    pub fn remove(&mut self, hart_id: usize) -> Result<(), MaskError> {
+    pub const fn remove(&mut self, hart_id: usize) -> Result<(), MaskError> {
         if self.hart_mask_base == Self::IGNORE_MASK {
             Err(MaskError::Ignored)
         } else if valid_bit(self.hart_mask_base, hart_id) {
@@ -888,31 +874,81 @@ impl HartMask {
     #[inline]
     pub const fn iter(&self) -> HartIds {
         HartIds {
-            inner: HartMask {
-                hart_mask: self.hart_mask,
-                hart_mask_base: self.hart_mask_base,
-            },
-            visited_mask: 0,
+            unvisited_mask: self.hart_mask,
+            base: self.hart_mask_base,
         }
     }
 }
 
+impl IntoIterator for HartMask {
+    type Item = usize;
+
+    type IntoIter = HartIds;
+
+    #[inline]
+    fn into_iter(self) -> Self::IntoIter {
+        self.iter()
+    }
+}
+
+/// Iterator structure for `HartMask`.
+///
+/// It will iterate hart id from low to high.
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+pub struct HartIds {
+    unvisited_mask: usize,
+    base: usize,
+}
+
 impl Iterator for HartIds {
     type Item = usize;
 
     #[inline]
     fn next(&mut self) -> Option<Self::Item> {
-        let non_visited_mask = (!self.visited_mask) & (self.inner.hart_mask);
-        if non_visited_mask == 0 {
+        if self.unvisited_mask == 0 {
             None
         } else {
-            let low_bit = non_visited_mask.trailing_zeros();
-            let hart_id = usize::try_from(low_bit).unwrap() + self.inner.hart_mask_base;
-            self.visited_mask |= 1usize << low_bit;
+            let low_bit = self.unvisited_mask.trailing_zeros();
+            let hart_id = usize::try_from(low_bit).unwrap() + self.base;
+            self.unvisited_mask &= !(1usize << low_bit);
 
             Some(hart_id)
         }
     }
+
+    #[inline]
+    fn size_hint(&self) -> (usize, Option<usize>) {
+        let exact_popcnt = usize::try_from(self.unvisited_mask.count_ones()).unwrap();
+        (exact_popcnt, Some(exact_popcnt))
+    }
+}
+
+impl DoubleEndedIterator for HartIds {
+    #[inline]
+    fn next_back(&mut self) -> Option<Self::Item> {
+        if self.unvisited_mask == 0 {
+            None
+        } else {
+            let high_bit = self.unvisited_mask.leading_zeros();
+            let hart_id = usize::try_from(usize::BITS - high_bit - 1).unwrap() + self.base;
+            self.unvisited_mask &= !(1usize << (usize::BITS - high_bit - 1));
+
+            Some(hart_id)
+        }
+    }
+}
+
+impl ExactSizeIterator for HartIds {}
+
+impl core::iter::FusedIterator for HartIds {}
+
+/// Error of mask modification.
+#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash)]
+pub enum MaskError {
+    /// This mask has been ignored.
+    Ignored,
+    /// Request bit is invalid.
+    InvalidBit,
 }
 
 /// Counter index mask structure in SBI function calls for the `PMU` extension §11.
@@ -1136,13 +1172,7 @@ mod tests {
             assert!(mask.has_bit(i));
         }
         assert!(mask.has_bit(usize::MAX));
-        // Test HartIds
-        let mut mask_iter = HartMask::from_mask_base(0b101011, 1).iter();
-        assert_eq!(mask_iter.next(), Some(1usize));
-        assert_eq!(mask_iter.next(), Some(2usize));
-        assert_eq!(mask_iter.next(), Some(4usize));
-        assert_eq!(mask_iter.next(), Some(6usize));
-        assert_eq!(mask_iter.next(), None);
+
         let mut mask = HartMask::from_mask_base(0, 1);
         assert!(!mask.has_bit(1));
         assert!(mask.insert(1).is_ok());
@@ -1151,6 +1181,171 @@ mod tests {
         assert!(!mask.has_bit(1));
     }
 
+    #[test]
+    fn rustsbi_hart_ids_iterator() {
+        let mask = HartMask::from_mask_base(0b101011, 1);
+        // Test the `next` method of `HartIds` structure.
+        let mut hart_ids = mask.iter();
+        assert_eq!(hart_ids.next(), Some(1));
+        assert_eq!(hart_ids.next(), Some(2));
+        assert_eq!(hart_ids.next(), Some(4));
+        assert_eq!(hart_ids.next(), Some(6));
+        assert_eq!(hart_ids.next(), None);
+        // `HartIds` structures are fused, meaning they return `None` forever once iteration finished.
+        assert_eq!(hart_ids.next(), None);
+
+        // Test `for` loop on mask (`HartMask`) as `IntoIterator`.
+        let mut ans = [0; 4];
+        let mut idx = 0;
+        for hart_id in mask {
+            ans[idx] = hart_id;
+            idx += 1;
+        }
+        assert_eq!(ans, [1, 2, 4, 6]);
+
+        // Test `Iterator` methods on `HartIds`.
+        let mut hart_ids = mask.iter();
+        assert_eq!(hart_ids.size_hint(), (4, Some(4)));
+        let _ = hart_ids.next();
+        assert_eq!(hart_ids.size_hint(), (3, Some(3)));
+        let _ = hart_ids.next();
+        let _ = hart_ids.next();
+        assert_eq!(hart_ids.size_hint(), (1, Some(1)));
+        let _ = hart_ids.next();
+        assert_eq!(hart_ids.size_hint(), (0, Some(0)));
+        let _ = hart_ids.next();
+        assert_eq!(hart_ids.size_hint(), (0, Some(0)));
+
+        let mut hart_ids = mask.iter();
+        assert_eq!(hart_ids.count(), 4);
+        let _ = hart_ids.next();
+        assert_eq!(hart_ids.count(), 3);
+        let _ = hart_ids.next();
+        let _ = hart_ids.next();
+        let _ = hart_ids.next();
+        assert_eq!(hart_ids.count(), 0);
+        let _ = hart_ids.next();
+        assert_eq!(hart_ids.count(), 0);
+
+        let hart_ids = mask.iter();
+        assert_eq!(hart_ids.last(), Some(6));
+
+        let mut hart_ids = mask.iter();
+        assert_eq!(hart_ids.nth(2), Some(4));
+        let mut hart_ids = mask.iter();
+        assert_eq!(hart_ids.nth(0), Some(1));
+
+        let mut iter = mask.iter().step_by(2);
+        assert_eq!(iter.next(), Some(1));
+        assert_eq!(iter.next(), Some(4));
+        assert_eq!(iter.next(), None);
+
+        let mask_2 = HartMask::from_mask_base(0b1001101, 64);
+        let mut iter = mask.iter().chain(mask_2);
+        assert_eq!(iter.next(), Some(1));
+        assert_eq!(iter.next(), Some(2));
+        assert_eq!(iter.next(), Some(4));
+        assert_eq!(iter.next(), Some(6));
+        assert_eq!(iter.next(), Some(64));
+        assert_eq!(iter.next(), Some(66));
+        assert_eq!(iter.next(), Some(67));
+        assert_eq!(iter.next(), Some(70));
+        assert_eq!(iter.next(), None);
+
+        let mut iter = mask.iter().zip(mask_2);
+        assert_eq!(iter.next(), Some((1, 64)));
+        assert_eq!(iter.next(), Some((2, 66)));
+        assert_eq!(iter.next(), Some((4, 67)));
+        assert_eq!(iter.next(), Some((6, 70)));
+        assert_eq!(iter.next(), None);
+
+        fn to_plic_context_id(hart_id_machine: usize) -> usize {
+            hart_id_machine * 2
+        }
+        let mut iter = mask.iter().map(to_plic_context_id);
+        assert_eq!(iter.next(), Some(2));
+        assert_eq!(iter.next(), Some(4));
+        assert_eq!(iter.next(), Some(8));
+        assert_eq!(iter.next(), Some(12));
+        assert_eq!(iter.next(), None);
+
+        let mut channel_received = [0; 4];
+        let mut idx = 0;
+        let mut channel_send = |hart_id| {
+            channel_received[idx] = hart_id;
+            idx += 1;
+        };
+        mask.iter().for_each(|value| channel_send(value));
+        assert_eq!(channel_received, [1, 2, 4, 6]);
+
+        let is_in_cluster_1 = |hart_id: &usize| *hart_id >= 4 && *hart_id < 7;
+        let mut iter = mask.iter().filter(is_in_cluster_1);
+        assert_eq!(iter.next(), Some(4));
+        assert_eq!(iter.next(), Some(6));
+        assert_eq!(iter.next(), None);
+
+        let if_in_cluster_1_get_plic_context_id = |hart_id: usize| {
+            if hart_id >= 4 && hart_id < 7 {
+                Some(hart_id * 2)
+            } else {
+                None
+            }
+        };
+        let mut iter = mask.iter().filter_map(if_in_cluster_1_get_plic_context_id);
+        assert_eq!(iter.next(), Some(8));
+        assert_eq!(iter.next(), Some(12));
+        assert_eq!(iter.next(), None);
+
+        let mut iter = mask.iter().enumerate();
+        assert_eq!(iter.next(), Some((0, 1)));
+        assert_eq!(iter.next(), Some((1, 2)));
+        assert_eq!(iter.next(), Some((2, 4)));
+        assert_eq!(iter.next(), Some((3, 6)));
+        assert_eq!(iter.next(), None);
+        let mut ans = [(0, 0); 4];
+        let mut idx = 0;
+        for (i, hart_id) in mask.iter().enumerate() {
+            ans[idx] = (i, hart_id);
+            idx += 1;
+        }
+        assert_eq!(ans, [(0, 1), (1, 2), (2, 4), (3, 6)]);
+
+        let mut iter = mask.iter().peekable();
+        assert_eq!(iter.peek(), Some(&1));
+        assert_eq!(iter.next(), Some(1));
+        assert_eq!(iter.peek(), Some(&2));
+        assert_eq!(iter.next(), Some(2));
+        assert_eq!(iter.peek(), Some(&4));
+        assert_eq!(iter.next(), Some(4));
+        assert_eq!(iter.peek(), Some(&6));
+        assert_eq!(iter.next(), Some(6));
+        assert_eq!(iter.peek(), None);
+        assert_eq!(iter.next(), None);
+
+        // TODO: other iterator tests.
+
+        assert!(mask.iter().is_sorted());
+        assert!(mask.iter().is_sorted_by(|a, b| a <= b));
+
+        // Reverse iterator as `DoubleEndedIterator`.
+        let mut iter = mask.iter().rev();
+        assert_eq!(iter.next(), Some(6));
+        assert_eq!(iter.next(), Some(4));
+        assert_eq!(iter.next(), Some(2));
+        assert_eq!(iter.next(), Some(1));
+        assert_eq!(iter.next(), None);
+
+        // Special iterator values.
+        let nothing = HartMask::from_mask_base(0, 1000);
+        assert!(nothing.iter().eq([]));
+
+        let all_mask_bits_set = HartMask::from_mask_base(usize::MAX, 1000);
+        let range = 1000..(1000 + usize::BITS as usize);
+        assert!(all_mask_bits_set.iter().eq(range));
+
+        // TODO: full-range hart mask
+    }
+
     #[test]
     fn rustsbi_counter_index_mask() {
         let mask = CounterMask::from_mask_base(0b1, 400);