Parcourir la source

Add operations on i32 to BigInt

Sam Cappleman-Lynes il y a 7 ans
Parent
commit
94d570697c
2 fichiers modifiés avec 185 ajouts et 0 suppressions
  1. 122 0
      bigint/src/bigint.rs
  2. 63 0
      bigint/src/tests/bigint.rs

+ 122 - 0
bigint/src/bigint.rs

@@ -299,6 +299,14 @@ impl Signed for BigInt {
     }
 }
 
+// A convenience method for getting the absolute value of an i32 in a u32.
+fn i32_abs_as_u32(a: i32) -> u32 {
+    match a.checked_abs() {
+        Some(x) => x as u32,
+        None => a as u32
+    }
+}
+
 // We want to forward to BigUint::add, but it's not clear how that will go until
 // we compare both sign and magnitude.  So we duplicate this body for every
 // val/ref combination, deferring that decision to BigUint's own forwarding.
@@ -382,6 +390,21 @@ impl Add<BigDigit> for BigInt {
     }
 }
 
+forward_all_scalar_binop_to_val_val_commutative!(impl Add<i32> for BigInt, add);
+
+impl Add<i32> for BigInt {
+    type Output = BigInt;
+
+    #[inline]
+    fn add(self, other: i32) -> BigInt {
+        if other >= 0 {
+            self + other as u32
+        } else {
+            self - i32_abs_as_u32(other)
+        }
+    }
+}
+
 // We want to forward to BigUint::sub, but it's not clear how that will go until
 // we compare both sign and magnitude.  So we duplicate this body for every
 // val/ref combination, deferring that decision to BigUint's own forwarding.
@@ -474,6 +497,34 @@ impl Sub<BigInt> for BigDigit {
     }
 }
 
+forward_all_scalar_binop_to_val_val!(impl Sub<i32> for BigInt, sub);
+
+impl Sub<i32> for BigInt {
+    type Output = BigInt;
+
+    #[inline]
+    fn sub(self, other: i32) -> BigInt {
+        if other >= 0 {
+            self - other as u32
+        } else {
+            self + i32_abs_as_u32(other)
+        }
+    }
+}
+
+impl Sub<BigInt> for i32 {
+    type Output = BigInt;
+
+    #[inline]
+    fn sub(self, other: BigInt) -> BigInt {
+        if self >= 0 {
+            self as u32 - other
+        } else {
+            -other - i32_abs_as_u32(self)
+        }
+    }
+}
+
 forward_all_binop_to_ref_ref!(impl Mul for BigInt, mul);
 
 impl<'a, 'b> Mul<&'b BigInt> for &'a BigInt {
@@ -496,6 +547,21 @@ impl Mul<BigDigit> for BigInt {
     }
 }
 
+forward_all_scalar_binop_to_val_val_commutative!(impl Mul<i32> for BigInt, mul);
+
+impl Mul<i32> for BigInt {
+    type Output = BigInt;
+
+    #[inline]
+    fn mul(self, other: i32) -> BigInt {
+        if other >= 0 {
+            self * other as u32
+        } else {
+            -(self * i32_abs_as_u32(other))
+        }
+    }
+}
+
 forward_all_binop_to_ref_ref!(impl Div for BigInt, div);
 
 impl<'a, 'b> Div<&'b BigInt> for &'a BigInt {
@@ -528,6 +594,34 @@ impl Div<BigInt> for BigDigit {
     }
 }
 
+forward_all_scalar_binop_to_val_val!(impl Div<i32> for BigInt, div);
+
+impl Div<i32> for BigInt {
+    type Output = BigInt;
+
+    #[inline]
+    fn div(self, other: i32) -> BigInt {
+        if other >= 0 {
+            self / other as u32
+        } else {
+            -(self / i32_abs_as_u32(other))
+        }
+    }
+}
+
+impl Div<BigInt> for i32 {
+    type Output = BigInt;
+
+    #[inline]
+    fn div(self, other: BigInt) -> BigInt {
+        if self >= 0 {
+            self as u32 / other
+        } else {
+            -(i32_abs_as_u32(self) / other)
+        }
+    }
+}
+
 forward_all_binop_to_ref_ref!(impl Rem for BigInt, rem);
 
 impl<'a, 'b> Rem<&'b BigInt> for &'a BigInt {
@@ -560,6 +654,34 @@ impl Rem<BigInt> for BigDigit {
     }
 }
 
+forward_all_scalar_binop_to_val_val!(impl Rem<i32> for BigInt, rem);
+
+impl Rem<i32> for BigInt {
+    type Output = BigInt;
+
+    #[inline]
+    fn rem(self, other: i32) -> BigInt {
+        if other >= 0 {
+            self % other as u32
+        } else {
+            self % i32_abs_as_u32(other)
+        }
+    }
+}
+
+impl Rem<BigInt> for i32 {
+    type Output = BigInt;
+
+    #[inline]
+    fn rem(self, other: BigInt) -> BigInt {
+        if self >= 0 {
+            self as u32 % other
+        } else {
+            -(i32_abs_as_u32(self) % other)
+        }
+    }
+}
+
 impl Neg for BigInt {
     type Output = BigInt;
 

+ 63 - 0
bigint/src/tests/bigint.rs

@@ -567,6 +567,14 @@ fn test_scalar_add() {
             assert_op!(b + a == c);
             assert_op!(a + nc == nb);
             assert_op!(nc + a == nb);
+
+            if a <= i32::max_value() as u32 {
+                let na = -(a as i32);
+                assert_op!(na + nb == nc);
+                assert_op!(nb + na == nc);
+                assert_op!(na + c == b);
+                assert_op!(c + na == b);
+            }
         }
 
         if b_vec.len() == 1 {
@@ -575,6 +583,14 @@ fn test_scalar_add() {
             assert_op!(b + a == c);
             assert_op!(b + nc == na);
             assert_op!(nc + b == na);
+
+            if b <= i32::max_value() as u32 {
+                let nb = -(b as i32);
+                assert_op!(na + nb == nc);
+                assert_op!(nb + na == nc);
+                assert_op!(nb + c == a);
+                assert_op!(c + nb == a);
+            }
         }
     }
 }
@@ -614,6 +630,14 @@ fn test_scalar_sub() {
             assert_op!(a - c == nb);
             assert_op!(a - nb == c);
             assert_op!(nb - a == nc);
+
+            if a <= i32::max_value() as u32 {
+                let na = -(a as i32);
+                assert_op!(nc - na == nb);
+                assert_op!(na - nc == b);
+                assert_op!(na - b == nc);
+                assert_op!(b - na == c);
+            }
         }
 
         if b_vec.len() == 1 {
@@ -622,6 +646,14 @@ fn test_scalar_sub() {
             assert_op!(b - c == na);
             assert_op!(b - na == c);
             assert_op!(na - b == nc);
+
+            if b <= i32::max_value() as u32 {
+                let nb = -(b as i32);
+                assert_op!(nc - nb == na);
+                assert_op!(nb - nc == a);
+                assert_op!(nb - a == nc);
+                assert_op!(a - nb == c);
+            }
         }
 
         if c_vec.len() == 1 {
@@ -630,6 +662,14 @@ fn test_scalar_sub() {
             assert_op!(a - c == nb);
             assert_op!(c - b == a);
             assert_op!(b - c == na);
+
+            if c <= i32::max_value() as u32 {
+                let nc = -(c as i32);
+                assert_op!(nc - na == nb);
+                assert_op!(na - nc == b);
+                assert_op!(nc - nb == na);
+                assert_op!(nb - nc == a);
+            }
         }
     }
 }
@@ -665,6 +705,7 @@ static DIV_REM_QUADRUPLES: &'static [(&'static [BigDigit],
            &'static [BigDigit],
            &'static [BigDigit],
            &'static [BigDigit])] = &[(&[1], &[2], &[], &[1]),
+                                     (&[3], &[2], &[1], &[1]),
                                      (&[1, 1], &[2], &[M / 2 + 1], &[1]),
                                      (&[1, 1, 1], &[2], &[M / 2 + 1, M / 2 + 1], &[1]),
                                      (&[0, 1], &[N1], &[1], &[1]),
@@ -714,6 +755,14 @@ fn test_scalar_mul() {
             assert_op!(a * b == c);
             assert_op!(nb * a == nc);
             assert_op!(a * nb == nc);
+
+            if a <= i32::max_value() as u32 {
+                let na = -(a as i32);
+                assert_op!(nb * na == c);
+                assert_op!(na * nb == c);
+                assert_op!(b * na == nc);
+                assert_op!(na * b == nc);
+            }
         }
 
         if b_vec.len() == 1 {
@@ -722,6 +771,14 @@ fn test_scalar_mul() {
             assert_op!(b * a == c);
             assert_op!(na * b == nc);
             assert_op!(b * na == nc);
+
+            if b <= i32::max_value() as u32 {
+                let nb = -(b as i32);
+                assert_op!(na * nb == c);
+                assert_op!(nb * na == c);
+                assert_op!(a * nb == nc);
+                assert_op!(nb * a == nc);
+            }
         }
     }
 }
@@ -847,6 +904,12 @@ fn test_scalar_div_rem() {
         let (a, b, ans_q, ans_r) = (a.clone(), b.clone(), ans_q.clone(), ans_r.clone());
         assert_op!(a / b == ans_q);
         assert_op!(a % b == ans_r);
+
+        if b <= i32::max_value() as u32 {
+            let nb = -(b as i32);
+            assert_op!(a / nb == -ans_q.clone());
+            assert_op!(a % nb == ans_r);
+        }
     }
 
     fn check(a: &BigInt, b: BigDigit, q: &BigInt, r: &BigInt) {