Quellcode durchsuchen

Add scalar subtraction to BigInt

Sam Cappleman-Lynes vor 7 Jahren
Ursprung
Commit
79448cbdf9
2 geänderte Dateien mit 74 neuen und 1 gelöschten Zeilen
  1. 29 0
      bigint/src/bigint.rs
  2. 45 1
      bigint/src/tests/bigint.rs

+ 29 - 0
bigint/src/bigint.rs

@@ -445,6 +445,35 @@ impl Sub<BigInt> for BigInt {
     }
 }
 
+forward_all_scalar_binop_to_val_val!(impl Sub<BigDigit> for BigInt, sub);
+
+impl Sub<BigDigit> for BigInt {
+    type Output = BigInt;
+
+    #[inline]
+    fn sub(self, other: BigDigit) -> BigInt {
+        match self.sign {
+            NoSign => BigInt::from_biguint(Minus, From::from(other)),
+            Minus => BigInt::from_biguint(Minus, self.data + other),
+            Plus =>
+                match self.data.cmp(&From::from(other)) {
+                    Equal => Zero::zero(),
+                    Greater => BigInt::from_biguint(Plus, self.data - other),
+                    Less => BigInt::from_biguint(Minus, other - self.data),
+                }
+        }
+    }
+}
+
+impl Sub<BigInt> for BigDigit {
+    type Output = BigInt;
+
+    #[inline]
+    fn sub(self, other: BigInt) -> BigInt {
+        -(other - self)
+    }
+}
+
 forward_all_binop_to_ref_ref!(impl Mul for BigInt, mul);
 
 impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt {

+ 45 - 1
bigint/src/tests/bigint.rs

@@ -556,9 +556,10 @@ fn test_add() {
 fn test_scalar_add() {
     for elm in SUM_TRIPLES.iter() {
         let (a_vec, b_vec, c_vec) = *elm;
+        let a = BigInt::from_slice(Plus, a_vec);
         let b = BigInt::from_slice(Plus, b_vec);
         let c = BigInt::from_slice(Plus, c_vec);
-        let (nb, nc) = (-&b, -&c);
+        let (na, nb, nc) = (-&a, -&b, -&c);
 
         if a_vec.len() == 1 {
             let a = a_vec[0];
@@ -567,6 +568,14 @@ fn test_scalar_add() {
             assert_op!(a + nc == nb);
             assert_op!(nc + a == nb);
         }
+
+        if b_vec.len() == 1 {
+            let b = b_vec[0];
+            assert_op!(a + b == c);
+            assert_op!(b + a == c);
+            assert_op!(b + nc == na);
+            assert_op!(nc + b == na);
+        }
     }
 }
 
@@ -590,6 +599,41 @@ fn test_sub() {
     }
 }
 
+#[test]
+fn test_scalar_sub() {
+    for elm in SUM_TRIPLES.iter() {
+        let (a_vec, b_vec, c_vec) = *elm;
+        let a = BigInt::from_slice(Plus, a_vec);
+        let b = BigInt::from_slice(Plus, b_vec);
+        let c = BigInt::from_slice(Plus, c_vec);
+        let (na, nb, nc) = (-&a, -&b, -&c);
+
+        if a_vec.len() == 1 {
+            let a = a_vec[0];
+            assert_op!(c - a == b);
+            assert_op!(a - c == nb);
+            assert_op!(a - nb == c);
+            assert_op!(nb - a == nc);
+        }
+
+        if b_vec.len() == 1 {
+            let b = b_vec[0];
+            assert_op!(c - b == a);
+            assert_op!(b - c == na);
+            assert_op!(b - na == c);
+            assert_op!(na - b == nc);
+        }
+
+        if c_vec.len() == 1 {
+            let c = c_vec[0];
+            assert_op!(c - a == b);
+            assert_op!(a - c == nb);
+            assert_op!(c - b == a);
+            assert_op!(b - c == na);
+        }
+    }
+}
+
 const M: u32 = ::std::u32::MAX;
 static MUL_TRIPLES: &'static [(&'static [BigDigit],
            &'static [BigDigit],