Parcourir la source

Add scalar multiplication to BigUint, BigInt

BigUint and BigInt can now be multiplied by a BigDigit, re-using the same buffer for the output, thereby reducing allocations and copying.
Sam Cappleman-Lynes il y a 8 ans
Parent
commit
e520bdad0d

+ 17 - 0
bigint/src/algorithms.rs

@@ -85,6 +85,15 @@ pub fn mac_with_carry(a: BigDigit, b: BigDigit, c: BigDigit, carry: &mut BigDigi
     lo
 }
 
+#[inline]
+pub fn mul_with_carry(a: BigDigit, b: BigDigit, carry: &mut BigDigit) -> BigDigit {
+    let (hi, lo) = big_digit::from_doublebigdigit((a as DoubleBigDigit) * (b as DoubleBigDigit) +
+                                                  (*carry as DoubleBigDigit));
+
+    *carry = hi;
+    lo
+}
+
 /// Divide a two digit numerator by a one digit divisor, returns quotient and remainder:
 ///
 /// Note: the caller must ensure that both the quotient and remainder will fit into a single digit.
@@ -377,6 +386,14 @@ pub fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
     prod.normalize()
 }
 
+pub fn scalar_mul(a: &mut [BigDigit], b: BigDigit) -> BigDigit {
+    let mut carry = 0;
+    for a in a.iter_mut() {
+        *a = mul_with_carry(*a, b, &mut carry);
+    }
+    carry
+}
+
 pub fn div_rem(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) {
     if d.is_zero() {
         panic!()

+ 9 - 0
bigint/src/bigint.rs

@@ -436,6 +436,15 @@ impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt {
     }
 }
 
+impl Mul<BigDigit> for BigInt {
+    type Output = BigInt;
+
+    #[inline]
+    fn mul(self, other: BigDigit) -> BigInt {
+        BigInt::from_biguint(self.sign, self.data * other)
+    }
+}
+
 forward_all_binop_to_ref_ref!(impl Div for BigInt, div);
 
 impl<'a, 'b> Div<&'b BigInt> for &'a BigInt {

+ 14 - 1
bigint/src/biguint.rs

@@ -22,7 +22,7 @@ mod algorithms;
 pub use self::algorithms::big_digit;
 pub use self::big_digit::{BigDigit, DoubleBigDigit, ZERO_BIG_DIGIT};
 
-use self::algorithms::{mac_with_carry, mul3, div_rem, div_rem_digit};
+use self::algorithms::{mac_with_carry, mul3, scalar_mul, div_rem, div_rem_digit};
 use self::algorithms::{__add2, add2, sub2, sub2rev};
 use self::algorithms::{biguint_shl, biguint_shr};
 use self::algorithms::{cmp_slice, fls, ilog2};
@@ -431,6 +431,19 @@ impl<'a, 'b> Mul<&'b BigUint> for &'a BigUint {
     }
 }
 
+impl Mul<BigDigit> for BigUint {
+    type Output = BigUint;
+
+    #[inline]
+    fn mul(mut self, other: BigDigit) -> BigUint {
+        let carry = scalar_mul(&mut self.data[..], other);
+        if carry != 0 {
+            self.data.push(carry);
+        }
+        self
+    }
+}
+
 forward_all_binop_to_ref_ref!(impl Div for BigUint, div);
 
 impl<'a, 'b> Div<&'b BigUint> for &'a BigUint {

+ 25 - 0
bigint/src/tests/bigint.rs

@@ -637,6 +637,31 @@ fn test_mul() {
     }
 }
 
+#[test]
+fn test_scalar_mul() {
+    for elm in MUL_TRIPLES.iter() {
+        let (a_vec, b_vec, c_vec) = *elm;
+        let c = BigInt::from_slice(Plus, c_vec);
+        let nc = BigInt::from_slice(Minus, c_vec);
+
+        if a_vec.len() == 1 {
+            let b = BigInt::from_slice(Plus, b_vec);
+            let nb = BigInt::from_slice(Minus, b_vec);
+            let a = a_vec[0];
+            assert!(b * a == c);
+            assert!(nb * a == nc);
+        }
+
+        if b_vec.len() == 1 {
+            let a = BigInt::from_slice(Plus, a_vec);
+            let na = BigInt::from_slice(Minus, a_vec);
+            let b = b_vec[0];
+            assert!(a * b == c);
+            assert!(na * b == nc);
+        }
+    }
+}
+
 #[test]
 fn test_div_mod_floor() {
     fn check_sub(a: &BigInt, b: &BigInt, ans_d: &BigInt, ans_m: &BigInt) {

+ 20 - 0
bigint/src/tests/biguint.rs

@@ -770,6 +770,26 @@ fn test_mul() {
     }
 }
 
+#[test]
+fn test_scalar_mul() {
+    for elm in MUL_TRIPLES.iter() {
+        let (a_vec, b_vec, c_vec) = *elm;
+        let c = BigUint::from_slice(c_vec);
+
+        if a_vec.len() == 1 {
+            let b = BigUint::from_slice(b_vec);
+            let a = a_vec[0];
+            assert!(b * a == c);
+        }
+
+        if b_vec.len() == 1 {
+            let a = BigUint::from_slice(a_vec);
+            let b = b_vec[0];
+            assert!(a * b == c);
+        }
+    }
+}
+
 #[test]
 fn test_div_rem() {
     for elm in MUL_TRIPLES.iter() {