Ver código fonte

Add a newtype for non-bottom levels.

This allows us to express the invariant that l≠0 through the type
system.
ticki 8 anos atrás
pai
commit
8a50d08cc9
2 arquivos alterados com 53 adições e 23 exclusões
  1. 48 18
      src/bk/lv.rs
  2. 5 5
      src/bk/node.rs

+ 48 - 18
src/bk/lv.rs

@@ -47,12 +47,12 @@ impl Level {
     }
 
     #[inline]
-    pub fn above(self) -> Option<Level> {
+    pub fn above(self) -> Option<NonBottomLevel> {
         // TODO: Find a way to eliminate this branch.
         if self == Level::max() {
             None
         } else {
-            Some(Level(self.0 + 1))
+            Some(NonBottomLevel(self.0 + 1))
         }
     }
 
@@ -75,11 +75,32 @@ impl Level {
 }
 
 impl Into<usize> for Level {
+    #[inline]
     fn into(self) -> usize {
         self.0
     }
 }
 
+/// A non-bottom level.
+///
+/// This newtype holds the invariant that the contained level is greater than zero (bottom level).
+pub struct NonBottomLevel(Level);
+
+impl NonBottomLevel {
+    #[inline]
+    pub fn below(self) -> Level {
+        // We can safely do this, because `self.0.0` is never zero.
+        Level(self.0.0 - 1)
+    }
+}
+
+impl From<NonBottomLevel> for Level {
+    #[inline]
+    fn from(from: NonBottomLevel) -> Level {
+        Level(from.0)
+    }
+}
+
 pub struct Iter {
     lv: usize,
     to: usize,
@@ -114,6 +135,7 @@ impl Iter {
 impl Iterator for LevelIter {
     type Item = Level;
 
+    #[inline]
     fn next(&mut self) -> Option<Level> {
         if self.lv <= self.to {
             let ret = self.n;
@@ -154,7 +176,7 @@ impl<T> ops::IndexMut<Level> for Array {
 
 #[cfg(test)]
 mod test {
-    use super::{self, Level};
+    use super;
 
     #[test]
     fn level_generation_dist() {
@@ -164,7 +186,7 @@ mod test {
         let mut occ = lv::Array::default();
         // Simulate tousand level generations.
         for _ in 0..1000 {
-            if let Some(lv) = Level::generate() {
+            if let Some(lv) = lv::Level::generate() {
                 // Increment the occurence counter.
                 occ[lv] += 1;
             } else {
@@ -177,7 +199,7 @@ mod test {
         assert!((490..510).contains(nones));
 
         let mut expected = 250;
-        for lv in Iter::all() {
+        for lv in lv::Iter::all() {
             // Ensure that the occurences of `lv` is within the expected margin.
             assert!((expected - 10..expected + 10).contains(occ[lv]));
         }
@@ -185,44 +207,52 @@ mod test {
 
     #[test]
     fn above() {
-        assert_eq!(Level::max().above(), None);
-        assert_eq!(Level::min().above().unwrap() as usize, 1);
+        assert_eq!(lv::Level::max().above(), None);
+        assert_eq!(lv::Level::min().above().unwrap() as usize, 1);
     }
 
     #[test]
     fn iter() {
-        assert!(Iter::all().eq(0..Level::max() as usize));
-        assert!(Iter::non_bottom().eq(1..Level::max() as usize));
+        assert!(lv::Iter::all().eq(0..lv::Level::max() as usize));
+        assert!(lv::Iter::non_bottom().eq(1..lv::Level::max() as usize));
     }
 
     #[test]
     fn array_max_index() {
-        assert_eq!(lv::Array::<&str>::default()[Level::max()], "");
-        assert_eq!(lv::Array::<u32>::default()[Level::max()], 0);
-        assert_eq!(&mut lv::Array::<&str>::default()[Level::max()], &mut "");
-        assert_eq!(&mut lv::Array::<u32>::default()[Level::max()], &mut 0);
+        assert_eq!(lv::Array::<&str>::default()[lv::Level::max()], "");
+        assert_eq!(lv::Array::<u32>::default()[lv::Level::max()], 0);
+        assert_eq!(&mut lv::Array::<&str>::default()[lv::Level::max()], &mut "");
+        assert_eq!(&mut lv::Array::<u32>::default()[lv::Level::max()], &mut 0);
     }
 
     #[test]
     fn array_iter() {
         let mut arr = lv::Array::default();
-        for lv in Iter::all() {
+        for lv in lv::Iter::all() {
             arr[lv] = lv as usize;
         }
 
-        for lv in Iter::all() {
+        for lv in lv::Iter::all() {
             assert_eq!(arr[lv], lv as usize);
 
-            for lv in Iter::start_at(lv) {
+            for lv in lv::Iter::start_at(lv) {
                 assert_eq!(arr[lv], lv as usize);
             }
-            for lv in Iter::all().to(lv) {
+            for lv in lv::Iter::all().to(lv) {
                 assert_eq!(arr[lv], lv as usize);
             }
         }
 
-        for lv in Iter::non_bottom() {
+        for lv in lv::Iter::non_bottom() {
             assert_eq!(arr[lv], lv as usize);
         }
     }
+
+    #[test]
+    fn non_bottom_below() {
+        let above: lv::NonBottomLevel = lv::Level::min().above().unwrap();
+        let lv: lv::Level = above.below();
+
+        assert_eq!(lv, lv::Level::min());
+    }
 }

+ 5 - 5
src/bk/node.rs

@@ -113,10 +113,10 @@ impl Node {
         new_fat
     }
 
-    /// Calculate the fat value of a non bottom layer (i.e. level is greater than or equal to one).
-    pub fn calculate_fat_value_non_bottom(&self, lv: lv::Level) -> block::Size {
+    /// Calculate the fat value of a non bottom layer.
+    pub fn calculate_fat_value_non_bottom(&self, lv: lv::NonBottomLevel) -> block::Size {
         // Since `lv != 0` decrementing will not underflow.
-        self.calculate_fat_value(lv, self.shortcuts[lv - 1].follow_shortcut(lv - 1))
+        self.calculate_fat_value(lv, self.shortcuts[lv.below()].follow_shortcut(lv.below()))
     }
 
     /// Calculate the fat value of the lowest level.
@@ -159,7 +159,7 @@ impl Node {
                     fat value does not match the calculated fat value.");
 
             // Check the fat values of the non bottom level.
-            for lv in lv::level_iter().skip(1) {
+            for lv in lv::Iter::non_bottom() {
                 assert!(self.shortcuts[lv.into()].fat == self.calculate_fat_value_non_bottom(lv), "The \
                         bottom layer's fat value does not match the calculated fat value.");
             }
@@ -167,7 +167,7 @@ impl Node {
             // Check that the shortcut refers to a node with appropriate (equal to or greater)
             // height.
             // FIXME: Fold this loop to the one above.
-            for lv in lv::level_iter() {
+            for lv in lv::Iter::all() {
                 assert!(!self.shortcuts[lv.into()].next.shortcuts[lv.into()].is_null(), "Shortcut \
                         points to a node with a lower height. Is this a dangling pointer?");
             }