Browse Source

Add scalar multiplication to BigInt

Sam Cappleman-Lynes 7 years ago
parent
commit
8b1288ea01
3 changed files with 20 additions and 12 deletions
  1. 2 0
      bigint/src/bigint.rs
  2. 7 3
      bigint/src/biguint.rs
  3. 11 9
      bigint/src/tests/bigint.rs

+ 2 - 0
bigint/src/bigint.rs

@@ -485,6 +485,8 @@ impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt {
     }
 }
 
+forward_all_scalar_binop_to_val_val_commutative!(impl Mul<BigDigit> for BigInt, mul);
+
 impl Mul<BigDigit> for BigInt {
     type Output = BigInt;
 

+ 7 - 3
bigint/src/biguint.rs

@@ -483,9 +483,13 @@ impl Mul<BigDigit> for BigUint {
 
     #[inline]
     fn mul(mut self, other: BigDigit) -> BigUint {
-        let carry = scalar_mul(&mut self.data[..], other);
-        if carry != 0 {
-            self.data.push(carry);
+        if other == 0 {
+            self.data.clear();
+        } else {
+            let carry = scalar_mul(&mut self.data[..], other);
+            if carry != 0 {
+                self.data.push(carry);
+            }
         }
         self
     }

+ 11 - 9
bigint/src/tests/bigint.rs

@@ -703,23 +703,25 @@ fn test_mul() {
 fn test_scalar_mul() {
     for elm in MUL_TRIPLES.iter() {
         let (a_vec, b_vec, c_vec) = *elm;
+        let a = BigInt::from_slice(Plus, a_vec);
+        let b = BigInt::from_slice(Plus, b_vec);
         let c = BigInt::from_slice(Plus, c_vec);
-        let nc = BigInt::from_slice(Minus, c_vec);
+        let (na, nb, nc) = (-&a, -&b, -&c);
 
         if a_vec.len() == 1 {
-            let b = BigInt::from_slice(Plus, b_vec);
-            let nb = BigInt::from_slice(Minus, b_vec);
             let a = a_vec[0];
-            assert!(b * a == c);
-            assert!(nb * a == nc);
+            assert_op!(b * a == c);
+            assert_op!(a * b == c);
+            assert_op!(nb * a == nc);
+            assert_op!(a * nb == nc);
         }
 
         if b_vec.len() == 1 {
-            let a = BigInt::from_slice(Plus, a_vec);
-            let na = BigInt::from_slice(Minus, a_vec);
             let b = b_vec[0];
-            assert!(a * b == c);
-            assert!(na * b == nc);
+            assert_op!(a * b == c);
+            assert_op!(b * a == c);
+            assert_op!(na * b == nc);
+            assert_op!(b * na == nc);
         }
     }
 }