Răsfoiți Sursa

test and fix more scalar add cases

Josh Stone 7 ani în urmă
părinte
comite
6afac825d9
3 a modificat fișierele cu 94 adăugiri și 27 ștergeri
  1. 20 17
      bigint/src/biguint.rs
  2. 47 5
      bigint/src/tests/bigint.rs
  3. 27 5
      bigint/src/tests/biguint.rs

+ 20 - 17
bigint/src/biguint.rs

@@ -405,13 +405,15 @@ impl Add<BigDigit> for BigUint {
 
     #[inline]
     fn add(mut self, other: BigDigit) -> BigUint {
-        if self.data.len() == 0 && other != 0 {
-            self.data.push(0);
-        }
+        if other != 0 {
+            if self.data.len() == 0 {
+                self.data.push(0);
+            }
 
-        let carry = __add2(&mut self.data, &[other]);
-        if carry != 0 {
-            self.data.push(carry);
+            let carry = __add2(&mut self.data, &[other]);
+            if carry != 0 {
+                self.data.push(carry);
+            }
         }
         self
     }
@@ -422,19 +424,20 @@ impl Add<DoubleBigDigit> for BigUint {
 
     #[inline]
     fn add(mut self, other: DoubleBigDigit) -> BigUint {
-        if self.data.len() == 0 && other != 0 {
-            self.data.push(0);
-        }
-        if self.data.len() == 1 && other > BigDigit::max_value() as DoubleBigDigit {
-            self.data.push(0);
-        }
-
         let (hi, lo) = big_digit::from_doublebigdigit(other);
-        let carry = __add2(&mut self.data, &[lo, hi]);
-        if carry != 0 {
-            self.data.push(carry);
+        if hi == 0 {
+            self + lo
+        } else {
+            while self.data.len() < 2 {
+                self.data.push(0);
+            }
+
+            let carry = __add2(&mut self.data, &[lo, hi]);
+            if carry != 0 {
+                self.data.push(carry);
+            }
+            self
         }
-        self
     }
 }
 

+ 47 - 5
bigint/src/tests/bigint.rs

@@ -1,4 +1,4 @@
-use {BigDigit, BigUint, big_digit};
+use {BigDigit, DoubleBigDigit, BigUint, big_digit};
 use {Sign, BigInt, RandBigInt, ToBigInt};
 use Sign::{Minus, NoSign, Plus};
 
@@ -532,6 +532,16 @@ const SUM_TRIPLES: &'static [(&'static [BigDigit],
                                      (&[1, 1, 1], &[N1, N1], &[0, 1, 2]),
                                      (&[2, 2, 1], &[N1, N2], &[1, 1, 2])];
 
+fn get_scalar(vec: &[BigDigit]) -> BigDigit {
+    vec.get(0).map_or(0, BigDigit::clone)
+}
+
+fn get_scalar_double(vec: &[BigDigit]) -> DoubleBigDigit {
+    let lo = vec.get(0).map_or(0, BigDigit::clone);
+    let hi = vec.get(1).map_or(0, BigDigit::clone);
+    big_digit::to_doublebigdigit(hi, lo)
+}
+
 #[test]
 fn test_add() {
     for elm in SUM_TRIPLES.iter() {
@@ -561,8 +571,8 @@ fn test_scalar_add() {
         let c = BigInt::from_slice(Plus, c_vec);
         let (na, nb, nc) = (-&a, -&b, -&c);
 
-        if a_vec.len() == 1 {
-            let a = a_vec[0];
+        if a_vec.len() <= 1 {
+            let a = get_scalar(a_vec);
             assert_op!(a + b == c);
             assert_op!(b + a == c);
             assert_op!(a + nc == nb);
@@ -577,8 +587,24 @@ fn test_scalar_add() {
             }
         }
 
-        if b_vec.len() == 1 {
-            let b = b_vec[0];
+        if a_vec.len() <= 2 {
+            let a = get_scalar_double(a_vec);
+            assert_op!(a + b == c);
+            assert_op!(b + a == c);
+            assert_op!(a + nc == nb);
+            assert_op!(nc + a == nb);
+
+            if a <= i64::max_value() as u64 {
+                let na = -(a as i64);
+                assert_op!(na + nb == nc);
+                assert_op!(nb + na == nc);
+                assert_op!(na + c == b);
+                assert_op!(c + na == b);
+            }
+        }
+
+        if b_vec.len() <= 1 {
+            let b = get_scalar(b_vec);
             assert_op!(a + b == c);
             assert_op!(b + a == c);
             assert_op!(b + nc == na);
@@ -592,6 +618,22 @@ fn test_scalar_add() {
                 assert_op!(c + nb == a);
             }
         }
+
+        if b_vec.len() <= 2 {
+            let b = get_scalar_double(b_vec);
+            assert_op!(a + b == c);
+            assert_op!(b + a == c);
+            assert_op!(b + nc == na);
+            assert_op!(nc + b == na);
+
+            if b <= i64::max_value() as u64 {
+                let nb = -(b as i64);
+                assert_op!(na + nb == nc);
+                assert_op!(nb + na == nc);
+                assert_op!(nb + c == a);
+                assert_op!(c + nb == a);
+            }
+        }
     }
 }
 

+ 27 - 5
bigint/src/tests/biguint.rs

@@ -1,5 +1,5 @@
 use integer::Integer;
-use {BigDigit, BigUint, ToBigUint, big_digit};
+use {BigDigit, DoubleBigDigit, BigUint, ToBigUint, big_digit};
 use {BigInt, RandBigInt, ToBigInt};
 use Sign::Plus;
 
@@ -677,6 +677,16 @@ const SUM_TRIPLES: &'static [(&'static [BigDigit],
                                      (&[1, 1, 1], &[N1, N1], &[0, 1, 2]),
                                      (&[2, 2, 1], &[N1, N2], &[1, 1, 2])];
 
+fn get_scalar(vec: &[BigDigit]) -> BigDigit {
+    vec.get(0).map_or(0, BigDigit::clone)
+}
+
+fn get_scalar_double(vec: &[BigDigit]) -> DoubleBigDigit {
+    let lo = vec.get(0).map_or(0, BigDigit::clone);
+    let hi = vec.get(1).map_or(0, BigDigit::clone);
+    big_digit::to_doublebigdigit(hi, lo)
+}
+
 #[test]
 fn test_add() {
     for elm in SUM_TRIPLES.iter() {
@@ -698,14 +708,26 @@ fn test_scalar_add() {
         let b = BigUint::from_slice(b_vec);
         let c = BigUint::from_slice(c_vec);
 
-        if a_vec.len() == 1 {
-            let a = a_vec[0];
+        if a_vec.len() <= 1 {
+            let a = get_scalar(a_vec);
             assert_op!(a + b == c);
             assert_op!(b + a == c);
         }
 
-        if b_vec.len() == 1 {
-            let b = b_vec[0];
+        if a_vec.len() <= 2 {
+            let a = get_scalar_double(a_vec);
+            assert_op!(a + b == c);
+            assert_op!(b + a == c);
+        }
+
+        if b_vec.len() <= 1 {
+            let b = get_scalar(b_vec);
+            assert_op!(a + b == c);
+            assert_op!(b + a == c);
+        }
+
+        if b_vec.len() <= 2 {
+            let b = get_scalar_double(b_vec);
             assert_op!(a + b == c);
             assert_op!(b + a == c);
         }