Procházet zdrojové kódy

spec: binary: optimize `count`, `last`, `min`, `max` and `is_sorted` to O(1) for iterator `HartIds`

Internal improvements on iterator unit test cases.

Signed-off-by: Zhouqi Jiang <luojia@hust.edu.cn>
Zhouqi Jiang před 4 měsíci
rodič
revize
4821073b56
1 změnil soubory, kde provedl 95 přidání a 24 odebrání
  1. 95 24
      sbi-spec/src/binary.rs

+ 95 - 24
sbi-spec/src/binary.rs

@@ -934,11 +934,11 @@ impl Iterator for HartIds {
                 *unvisited_mask &= !(1usize << low_bit);
                 Some(hart_id)
             }
-            UnvisitedMask::Range(from, to) => {
-                assert!(from <= to);
-                if *from < *to {
-                    let ans = *from;
-                    *from += 1;
+            UnvisitedMask::Range(start, end) => {
+                assert!(start <= end);
+                if *start < *end {
+                    let ans = *start;
+                    *start += 1;
                     Some(ans)
                 } else {
                     None
@@ -954,13 +954,42 @@ impl Iterator for HartIds {
                 let exact_popcnt = usize::try_from(unvisited_mask.count_ones()).unwrap();
                 (exact_popcnt, Some(exact_popcnt))
             }
-            UnvisitedMask::Range(from, to) => {
-                assert!(from <= to);
-                let exact_num_harts = to - from;
+            UnvisitedMask::Range(start, end) => {
+                assert!(start <= end);
+                let exact_num_harts = end - start;
                 (exact_num_harts, Some(exact_num_harts))
             }
         }
     }
+
+    #[inline]
+    fn count(self) -> usize {
+        self.size_hint().0
+    }
+
+    #[inline]
+    fn last(mut self) -> Option<Self::Item> {
+        self.next_back()
+    }
+
+    #[inline]
+    fn min(mut self) -> Option<Self::Item> {
+        self.next()
+    }
+
+    #[inline]
+    fn max(mut self) -> Option<Self::Item> {
+        self.next_back()
+    }
+
+    #[inline]
+    fn is_sorted(self) -> bool {
+        true
+    }
+
+    // TODO: implement fn advance_by once it's stablized: https://github.com/rust-lang/rust/issues/77404
+    // #[inline]
+    // fn advance_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> { ... }
 }
 
 impl DoubleEndedIterator for HartIds {
@@ -974,11 +1003,11 @@ impl DoubleEndedIterator for HartIds {
                 *unvisited_mask &= !(1usize << (usize::BITS - high_bit - 1));
                 Some(hart_id)
             }
-            UnvisitedMask::Range(from, to) => {
-                assert!(from <= to);
-                if *from < *to {
-                    let ans = *to;
-                    *to -= 1;
+            UnvisitedMask::Range(start, end) => {
+                assert!(start <= end);
+                if *start < *end {
+                    let ans = *end;
+                    *end -= 1;
                     Some(ans)
                 } else {
                     None
@@ -986,6 +1015,10 @@ impl DoubleEndedIterator for HartIds {
             }
         }
     }
+
+    // TODO: implement advance_back_by once stablized.
+    // #[inline]
+    // fn advance_back_by(&mut self, n: usize) -> Result<(), core::num::NonZero<usize>> { ... }
 }
 
 impl ExactSizeIterator for HartIds {}
@@ -1402,7 +1435,8 @@ mod tests {
         assert_eq!(iter.size_hint(), (usize::MAX - 1, Some(usize::MAX - 1)));
         assert_eq!(iter.next(), Some(1));
         assert_eq!(iter.next(), Some(2));
-        let mut iter = iter.skip(500);
+        // skip 500 elements
+        let _ = iter.nth(500 - 1);
         assert_eq!(iter.next(), Some(503));
         assert_eq!(iter.size_hint(), (usize::MAX - 504, Some(usize::MAX - 504)));
         assert_eq!(iter.next_back(), Some(usize::MAX));
@@ -1411,20 +1445,57 @@ mod tests {
 
         // A common usage of `HartMask::all`, we assume that this platform filters out hart 0..=3.
         let environment_available_hart_ids = 4..128;
-        // `iter` contains 64..=usize::MAX.
+        // `hart_mask_iter` contains 64..=usize::MAX.
         let hart_mask_iter = all_harts.iter().skip(64);
-        let mut iter_peekable = hart_mask_iter.peekable();
         let filtered_iter = environment_available_hart_ids.filter(|&x| {
-            while let Some(&y) = iter_peekable.peek() {
-                match y.cmp(&x) {
-                    core::cmp::Ordering::Equal => return true,
-                    core::cmp::Ordering::Greater => break,
-                    core::cmp::Ordering::Less => iter_peekable.next(),
-                };
-            }
-            false
+            hart_mask_iter
+                .clone()
+                .find(|&y| y >= x)
+                .map_or(false, |y| y == x)
         });
         assert!(filtered_iter.eq(64..128));
+
+        // The following operations should have O(1) complexity.
+        let all_harts = HartMask::all();
+        assert_eq!(all_harts.iter().count(), usize::MAX);
+        assert_eq!(all_harts.iter().last(), Some(usize::MAX));
+        assert_eq!(all_harts.iter().min(), Some(0));
+        assert_eq!(all_harts.iter().max(), Some(usize::MAX));
+        assert!(all_harts.iter().is_sorted());
+
+        let partial_all_harts = {
+            let mut ans = HartMask::all().iter();
+            let _ = ans.nth(65536 - 1);
+            let _ = ans.nth_back(4096 - 1);
+            ans
+        };
+        assert_eq!(partial_all_harts.clone().count(), usize::MAX - 65536 - 4096);
+        assert_eq!(partial_all_harts.clone().last(), Some(usize::MAX - 4096));
+        assert_eq!(partial_all_harts.clone().min(), Some(65536));
+        assert_eq!(partial_all_harts.clone().max(), Some(usize::MAX - 4096));
+        assert!(partial_all_harts.is_sorted());
+
+        let nothing = HartMask::from_mask_base(0, 1000);
+        assert_eq!(nothing.iter().count(), 0);
+        assert_eq!(nothing.iter().last(), None);
+        assert_eq!(nothing.iter().min(), None);
+        assert_eq!(nothing.iter().max(), None);
+        assert!(nothing.iter().is_sorted());
+
+        let mask = HartMask::from_mask_base(0b101011, 1);
+        assert_eq!(mask.iter().count(), 4);
+        assert_eq!(mask.iter().last(), Some(6));
+        assert_eq!(mask.iter().min(), Some(1));
+        assert_eq!(mask.iter().max(), Some(6));
+        assert!(mask.iter().is_sorted());
+
+        let all_mask_bits_set = HartMask::from_mask_base(usize::MAX, 1000);
+        let last = 1000 + usize::BITS as usize - 1;
+        assert_eq!(all_mask_bits_set.iter().count(), usize::BITS as usize);
+        assert_eq!(all_mask_bits_set.iter().last(), Some(last));
+        assert_eq!(all_mask_bits_set.iter().min(), Some(1000));
+        assert_eq!(all_mask_bits_set.iter().max(), Some(last));
+        assert!(all_mask_bits_set.iter().is_sorted());
     }
 
     #[test]