Browse Source

bigint: avoid new allocations for small shifts

Before:
    test shl               ... bench:       7,312 ns/iter (+/- 218)
    test shr               ... bench:       5,282 ns/iter (+/- 243)

After:
    test shl               ... bench:       4,946 ns/iter (+/- 88)
    test shr               ... bench:       4,121 ns/iter (+/- 52)
Josh Stone 9 years ago
parent
commit
b454a14bc4
1 changed files with 64 additions and 60 deletions
  1. 64 60
      src/bigint.rs

+ 64 - 60
src/bigint.rs

@@ -66,6 +66,7 @@
 
 use Integer;
 
+use std::borrow::Cow;
 use std::default::Default;
 use std::error::Error;
 use std::iter::repeat;
@@ -590,11 +591,43 @@ impl<'a> BitXor<&'a BigUint> for BigUint {
     }
 }
 
+#[inline]
+fn biguint_shl(n: Cow<BigUint>, bits: usize) -> BigUint {
+    let n_unit = bits / big_digit::BITS;
+    let mut data = match n_unit {
+        0 => n.into_owned().data,
+        _ => {
+            let len = n_unit + n.data.len() + 1;
+            let mut data = Vec::with_capacity(len);
+            data.extend(repeat(0).take(n_unit));
+            data.extend(n.data.iter().cloned());
+            data
+        },
+    };
+
+    let n_bits = bits % big_digit::BITS;
+    if n_bits > 0 {
+        let mut carry = 0;
+        for elem in data[n_unit..].iter_mut() {
+            let new_carry = *elem >> (big_digit::BITS - n_bits);
+            *elem = (*elem << n_bits) | carry;
+            carry = new_carry;
+        }
+        if carry != 0 {
+            data.push(carry);
+        }
+    }
+
+    BigUint::new(data)
+}
+
 impl Shl<usize> for BigUint {
     type Output = BigUint;
 
     #[inline]
-    fn shl(self, rhs: usize) -> BigUint { (&self) << rhs }
+    fn shl(self, rhs: usize) -> BigUint {
+        biguint_shl(Cow::Owned(self), rhs)
+    }
 }
 
 impl<'a> Shl<usize> for &'a BigUint {
@@ -602,17 +635,39 @@ impl<'a> Shl<usize> for &'a BigUint {
 
     #[inline]
     fn shl(self, rhs: usize) -> BigUint {
-        let n_unit = rhs / big_digit::BITS;
-        let n_bits = rhs % big_digit::BITS;
-        self.shl_unit(n_unit).shl_bits(n_bits)
+        biguint_shl(Cow::Borrowed(self), rhs)
+    }
+}
+
+#[inline]
+fn biguint_shr(n: Cow<BigUint>, bits: usize) -> BigUint {
+    let n_unit = bits / big_digit::BITS;
+    if n_unit >= n.data.len() { return Zero::zero(); }
+    let mut data = match n_unit {
+        0 => n.into_owned().data,
+        _ => n.data[n_unit..].to_vec(),
+    };
+
+    let n_bits = bits % big_digit::BITS;
+    if n_bits > 0 {
+        let mut borrow = 0;
+        for elem in data.iter_mut().rev() {
+            let new_borrow = *elem << (big_digit::BITS - n_bits);
+            *elem = (*elem >> n_bits) | borrow;
+            borrow = new_borrow;
+        }
     }
+
+    BigUint::new(data)
 }
 
 impl Shr<usize> for BigUint {
     type Output = BigUint;
 
     #[inline]
-    fn shr(self, rhs: usize) -> BigUint { (&self) >> rhs }
+    fn shr(self, rhs: usize) -> BigUint {
+        biguint_shr(Cow::Owned(self), rhs)
+    }
 }
 
 impl<'a> Shr<usize> for &'a BigUint {
@@ -620,9 +675,7 @@ impl<'a> Shr<usize> for &'a BigUint {
 
     #[inline]
     fn shr(self, rhs: usize) -> BigUint {
-        let n_unit = rhs / big_digit::BITS;
-        let n_bits = rhs % big_digit::BITS;
-        self.shr_unit(n_unit).shr_bits(n_bits)
+        biguint_shr(Cow::Borrowed(self), rhs)
     }
 }
 
@@ -1646,57 +1699,6 @@ impl BigUint {
         str::from_utf8(buf).ok().and_then(|s| BigUint::from_str_radix(s, radix).ok())
     }
 
-    #[inline]
-    fn shl_unit(&self, n_unit: usize) -> BigUint {
-        if n_unit == 0 || self.is_zero() { return self.clone(); }
-
-        let mut v = vec![0; n_unit];
-        v.extend(self.data.iter().cloned());
-        BigUint::new(v)
-    }
-
-    #[inline]
-    fn shl_bits(self, n_bits: usize) -> BigUint {
-        if n_bits == 0 || self.is_zero() { return self; }
-
-        assert!(n_bits < big_digit::BITS);
-
-        let mut carry = 0;
-        let mut shifted = self.data;
-        for elem in shifted.iter_mut() {
-            let new_carry = *elem >> (big_digit::BITS - n_bits);
-            *elem = (*elem << n_bits) | carry;
-            carry = new_carry;
-        }
-        if carry != 0 {
-            shifted.push(carry);
-        }
-        BigUint::new(shifted)
-    }
-
-    #[inline]
-    fn shr_unit(&self, n_unit: usize) -> BigUint {
-        if n_unit == 0 { return self.clone(); }
-        if self.data.len() < n_unit { return Zero::zero(); }
-        BigUint::from_slice(&self.data[n_unit ..])
-    }
-
-    #[inline]
-    fn shr_bits(self, n_bits: usize) -> BigUint {
-        if n_bits == 0 || self.data.is_empty() { return self; }
-
-        assert!(n_bits < big_digit::BITS);
-
-        let mut borrow = 0;
-        let mut shifted = self.data;
-        for elem in shifted.iter_mut().rev() {
-            let new_borrow = *elem << (big_digit::BITS - n_bits);
-            *elem = (*elem >> n_bits) | borrow;
-            borrow = new_borrow;
-        }
-        BigUint::new(shifted)
-    }
-
     /// Determines the fewest bits necessary to express the `BigUint`.
     pub fn bits(&self) -> usize {
         if self.is_zero() { return 0; }
@@ -1915,7 +1917,9 @@ impl Shr<usize> for BigInt {
     type Output = BigInt;
 
     #[inline]
-    fn shr(self, rhs: usize) -> BigInt { (&self) >> rhs }
+    fn shr(self, rhs: usize) -> BigInt {
+        BigInt::from_biguint(self.sign, self.data >> rhs)
+    }
 }
 
 impl<'a> Shr<usize> for &'a BigInt {