瀏覽代碼

Drop some dependencies on BigDigit's size

Before:
test divide_0          ... bench:       1,011 ns/iter (+/- 184)
test divide_1          ... bench:      18,535 ns/iter (+/- 770)
test divide_2          ... bench:     990,467 ns/iter (+/- 91,980)
test fac_to_string     ... bench:       1,275 ns/iter (+/- 60)
test factorial_100     ... bench:       6,453 ns/iter (+/- 101)
test fib_100           ... bench:       1,142 ns/iter (+/- 99)
test fib_1000          ... bench:      18,713 ns/iter (+/- 2,172)
test fib_10000         ... bench:   1,197,965 ns/iter (+/- 21,178)
test fib_to_string     ... bench:         225 ns/iter (+/- 13)
test from_str_radix_02 ... bench:       3,460 ns/iter (+/- 626)
test from_str_radix_08 ... bench:       1,324 ns/iter (+/- 24)
test from_str_radix_10 ... bench:       1,488 ns/iter (+/- 19)
test from_str_radix_16 ... bench:         969 ns/iter (+/- 22)
test from_str_radix_36 ... bench:       1,135 ns/iter (+/- 23)
test hash              ... bench:     102,126 ns/iter (+/- 1,016)
test multiply_0        ... bench:         353 ns/iter (+/- 74)
test multiply_1        ... bench:      31,006 ns/iter (+/- 679)
test multiply_2        ... bench:   3,438,143 ns/iter (+/- 47,640)
test pow_bench         ... bench:   7,457,045 ns/iter (+/- 96,175)
test shl               ... bench:       5,627 ns/iter (+/- 121)
test shr               ... bench:       5,054 ns/iter (+/- 112)
test to_str_radix_02   ... bench:       2,774 ns/iter (+/- 88)
test to_str_radix_08   ... bench:         980 ns/iter (+/- 425)
test to_str_radix_10   ... bench:       3,029 ns/iter (+/- 115)
test to_str_radix_16   ... bench:         788 ns/iter (+/- 14)
test to_str_radix_36   ... bench:       8,285 ns/iter (+/- 175)

After:
test divide_0          ... bench:         925 ns/iter (+/- 30)
test divide_1          ... bench:      17,660 ns/iter (+/- 379)
test divide_2          ... bench:     972,427 ns/iter (+/- 7,560)
test fac_to_string     ... bench:       1,260 ns/iter (+/- 36)
test factorial_100     ... bench:       7,077 ns/iter (+/- 204)
test fib_100           ... bench:       1,124 ns/iter (+/- 32)
test fib_1000          ... bench:      18,475 ns/iter (+/- 166)
test fib_10000         ... bench:   1,192,748 ns/iter (+/- 27,128)
test fib_to_string     ... bench:         228 ns/iter (+/- 10)
test from_str_radix_02 ... bench:       3,379 ns/iter (+/- 74)
test from_str_radix_08 ... bench:       1,355 ns/iter (+/- 24)
test from_str_radix_10 ... bench:       1,470 ns/iter (+/- 20)
test from_str_radix_16 ... bench:         958 ns/iter (+/- 239)
test from_str_radix_36 ... bench:       1,137 ns/iter (+/- 19)
test hash              ... bench:     102,730 ns/iter (+/- 39,897)
test multiply_0        ... bench:         351 ns/iter (+/- 15)
test multiply_1        ... bench:      31,139 ns/iter (+/- 1,053)
test multiply_2        ... bench:   3,464,509 ns/iter (+/- 124,235)
test pow_bench         ... bench:   7,448,428 ns/iter (+/- 326,903)
test shl               ... bench:       5,784 ns/iter (+/- 190)
test shr               ... bench:       4,820 ns/iter (+/- 63)
test to_str_radix_02   ... bench:       2,757 ns/iter (+/- 33)
test to_str_radix_08   ... bench:         989 ns/iter (+/- 67)
test to_str_radix_10   ... bench:       3,045 ns/iter (+/- 70)
test to_str_radix_16   ... bench:         787 ns/iter (+/- 24)
test to_str_radix_36   ... bench:       8,257 ns/iter (+/- 117)/
Kent Overstreet 8 年之前
父節點
當前提交
8e0baecf5c
共有 1 個文件被更改,包括 227 次插入133 次删除
  1. 227 133
      bigint/src/lib.rs

+ 227 - 133
bigint/src/lib.rs

@@ -306,7 +306,8 @@ impl FromStr for BigUint {
     }
 }
 
-// Read bitwise digits that evenly divide BigDigit
+// Convert from a power of two radix (bits == ilog2(radix)) where bits evenly divides
+// BigDigit::BITS
 fn from_bitwise_digits_le(v: &[u8], bits: usize) -> BigUint {
     debug_assert!(!v.is_empty() && bits <= 8 && big_digit::BITS % bits == 0);
     debug_assert!(v.iter().all(|&c| (c as BigDigit) < (1 << bits)));
@@ -315,14 +316,15 @@ fn from_bitwise_digits_le(v: &[u8], bits: usize) -> BigUint {
 
     let data = v.chunks(digits_per_big_digit)
                 .map(|chunk| {
-                    chunk.iter().rev().fold(0u32, |acc, &c| (acc << bits) | c as BigDigit)
+                    chunk.iter().rev().fold(0, |acc, &c| (acc << bits) | c as BigDigit)
                 })
                 .collect();
 
     BigUint::new(data)
 }
 
-// Read bitwise digits that don't evenly divide BigDigit
+// Convert from a power of two radix (bits == ilog2(radix)) where bits doesn't evenly divide
+// BigDigit::BITS
 fn from_inexact_bitwise_digits_le(v: &[u8], bits: usize) -> BigUint {
     debug_assert!(!v.is_empty() && bits <= 8 && big_digit::BITS % bits != 0);
     debug_assert!(v.iter().all(|&c| (c as BigDigit) < (1 << bits)));
@@ -331,15 +333,20 @@ fn from_inexact_bitwise_digits_le(v: &[u8], bits: usize) -> BigUint {
     let mut data = Vec::with_capacity(big_digits);
 
     let mut d = 0;
-    let mut dbits = 0;
+    let mut dbits = 0; // number of bits we currently have in d
+
+    // walk v accumululating bits in d; whenever we accumulate big_digit::BITS in d, spit out a
+    // big_digit:
     for &c in v {
-        d |= (c as DoubleBigDigit) << dbits;
+        d |= (c as BigDigit) << dbits;
         dbits += bits;
+
         if dbits >= big_digit::BITS {
-            let (hi, lo) = big_digit::from_doublebigdigit(d);
-            data.push(lo);
-            d = hi as DoubleBigDigit;
+            data.push(d);
             dbits -= big_digit::BITS;
+            // if dbits was > big_digit::BITS, we dropped some of the bits in c (they couldn't fit
+            // in d) - grab the bits we lost here:
+            d = (c as BigDigit) >> (bits - dbits);
         }
     }
 
@@ -362,8 +369,7 @@ fn from_radix_digits_be(v: &[u8], radix: u32) -> BigUint {
     let mut data = Vec::with_capacity(big_digits as usize);
 
     let (base, power) = get_radix_base(radix);
-    debug_assert!(base < (1 << 32));
-    let base = base as BigDigit;
+    let radix = radix as BigDigit;
 
     let r = v.len() % power;
     let i = if r == 0 {
@@ -435,7 +441,7 @@ impl Num for BigUint {
 
         let res = if radix.is_power_of_two() {
             // Powers of two can use bitwise masks and shifting instead of multiplication
-            let bits = radix.trailing_zeros() as usize;
+            let bits = ilog2(radix);
             v.reverse();
             if big_digit::BITS % bits == 0 {
                 from_bitwise_digits_le(&v, bits)
@@ -1349,6 +1355,46 @@ impl Integer for BigUint {
     }
 }
 
+fn high_bits_to_u64(v: &BigUint) -> u64 {
+    match v.data.len() {
+        0   => 0,
+        1   => v.data[0] as u64,
+        _   => {
+            let mut bits = v.bits();
+            let mut ret = 0u64;
+            let mut ret_bits = 0;
+
+            for d in v.data.iter().rev() {
+                let digit_bits = (bits - 1) % big_digit::BITS + 1;
+                let bits_want = cmp::min(64 - ret_bits, digit_bits);
+
+                if bits_want != 64 {
+                    ret <<= bits_want;
+                }
+                ret      |= *d as u64 >> (digit_bits - bits_want);
+                ret_bits += bits_want;
+                bits     -= bits_want;
+
+                if ret_bits == 64 {
+                    break;
+                }
+            }
+
+            ret
+        }
+    }
+}
+
+/// Find last set bit
+/// fls(0) == 0, fls(u32::MAX) == 32
+fn fls<T: traits::PrimInt>(v: T) -> usize {
+    std::mem::size_of::<T>() * 8 - v.leading_zeros() as usize
+}
+
+fn ilog2<T: traits::PrimInt>(v: T) -> usize {
+    fls(v) - 1
+}
+
 impl ToPrimitive for BigUint {
     #[inline]
     fn to_i64(&self) -> Option<i64> {
@@ -1362,76 +1408,53 @@ impl ToPrimitive for BigUint {
         })
     }
 
-    // `DoubleBigDigit` size dependent
     #[inline]
     fn to_u64(&self) -> Option<u64> {
-        match self.data.len() {
-            0 => Some(0),
-            1 => Some(self.data[0] as u64),
-            2 => Some(big_digit::to_doublebigdigit(self.data[1], self.data[0]) as u64),
-            _ => None,
+        let mut ret: u64 = 0;
+        let mut bits = 0;
+
+        for i in self.data.iter() {
+            if bits >= 64 {
+                return None;
+            }
+
+            ret += (*i as u64) << bits;
+            bits += big_digit::BITS;
         }
+
+        Some(ret)
     }
 
-    // `DoubleBigDigit` size dependent
     #[inline]
     fn to_f32(&self) -> Option<f32> {
-        match self.data.len() {
-            0 => Some(f32::zero()),
-            1 => Some(self.data[0] as f32),
-            len => {
-                // this will prevent any overflow of exponent
-                if len > (f32::MAX_EXP as usize) / big_digit::BITS {
-                    None
-                } else {
-                    let exponent = (len - 2) * big_digit::BITS;
-                    // we need 25 significant digits, 24 to be stored and 1 for rounding
-                    // this gives at least 33 significant digits
-                    let mantissa = big_digit::to_doublebigdigit(self.data[len - 1],
-                                                                self.data[len - 2]);
-                    // this cast handles rounding
-                    let ret = (mantissa as f32) * 2.0.powi(exponent as i32);
-                    if ret.is_infinite() {
-                        None
-                    } else {
-                        Some(ret)
-                    }
-                }
+        let mantissa = high_bits_to_u64(self);
+        let exponent = self.bits() - fls(mantissa);
+
+        if exponent > f32::MAX_EXP as usize {
+            None
+        } else {
+            let ret = (mantissa as f32) * 2.0f32.powi(exponent as i32);
+            if ret.is_infinite() {
+                None
+            } else {
+                Some(ret)
             }
         }
     }
 
-    // `DoubleBigDigit` size dependent
     #[inline]
     fn to_f64(&self) -> Option<f64> {
-        match self.data.len() {
-            0 => Some(f64::zero()),
-            1 => Some(self.data[0] as f64),
-            2 => Some(big_digit::to_doublebigdigit(self.data[1], self.data[0]) as f64),
-            len => {
-                // this will prevent any overflow of exponent
-                if len > (f64::MAX_EXP as usize) / big_digit::BITS {
-                    None
-                } else {
-                    let mut exponent = (len - 2) * big_digit::BITS;
-                    let mut mantissa = big_digit::to_doublebigdigit(self.data[len - 1],
-                                                                    self.data[len - 2]);
-                    // we need at least 54 significant bit digits, 53 to be stored and 1 for rounding
-                    // so we take enough from the next BigDigit to make it up to 64
-                    let shift = mantissa.leading_zeros() as usize;
-                    if shift > 0 {
-                        mantissa <<= shift;
-                        mantissa |= self.data[len - 3] as u64 >> (big_digit::BITS - shift);
-                        exponent -= shift;
-                    }
-                    // this cast handles rounding
-                    let ret = (mantissa as f64) * 2.0.powi(exponent as i32);
-                    if ret.is_infinite() {
-                        None
-                    } else {
-                        Some(ret)
-                    }
-                }
+        let mantissa = high_bits_to_u64(self);
+        let exponent = self.bits() - fls(mantissa);
+
+        if exponent > f64::MAX_EXP as usize {
+            None
+        } else {
+            let ret = (mantissa as f64) * 2.0f64.powi(exponent as i32);
+            if ret.is_infinite() {
+                None
+            } else {
+                Some(ret)
             }
         }
     }
@@ -1484,14 +1507,17 @@ impl FromPrimitive for BigUint {
 }
 
 impl From<u64> for BigUint {
-    // `DoubleBigDigit` size dependent
     #[inline]
-    fn from(n: u64) -> Self {
-        match big_digit::from_doublebigdigit(n) {
-            (0, 0) => BigUint::zero(),
-            (0, n0) => BigUint { data: vec![n0] },
-            (n1, n0) => BigUint { data: vec![n0, n1] },
+    fn from(mut n: u64) -> Self {
+        let mut ret: BigUint = Zero::zero();
+
+        while n != 0 {
+            ret.data.push(n as BigDigit);
+            // don't overflow if BITS is 64:
+            n = (n >> 1) >> (big_digit::BITS - 1);
         }
+
+        ret
     }
 }
 
@@ -1591,28 +1617,36 @@ fn to_bitwise_digits_le(u: &BigUint, bits: usize) -> Vec<u8> {
 fn to_inexact_bitwise_digits_le(u: &BigUint, bits: usize) -> Vec<u8> {
     debug_assert!(!u.is_zero() && bits <= 8 && big_digit::BITS % bits != 0);
 
-    let last_i = u.data.len() - 1;
-    let mask: DoubleBigDigit = (1 << bits) - 1;
+    let mask: BigDigit = (1 << bits) - 1;
     let digits = (u.bits() + bits - 1) / bits;
     let mut res = Vec::with_capacity(digits);
 
     let mut r = 0;
     let mut rbits = 0;
-    for hi in u.data[..last_i].iter().cloned() {
-        r |= (hi as DoubleBigDigit) << rbits;
+
+    for c in &u.data {
+        r |= *c << rbits;
         rbits += big_digit::BITS;
 
         while rbits >= bits {
             res.push((r & mask) as u8);
             r >>= bits;
+
+            // r had more bits than it could fit - grab the bits we lost
+            if rbits > big_digit::BITS {
+                r = *c >> (big_digit::BITS - (rbits - bits));
+            }
+
             rbits -= bits;
         }
     }
 
-    r |= (u.data[last_i] as DoubleBigDigit) << rbits;
-    while r != 0 {
-        res.push((r & mask) as u8);
-        r >>= bits;
+    if rbits != 0 {
+        res.push(r as u8);
+    }
+
+    while let Some(&0) = res.last() {
+        res.pop();
     }
 
     res
@@ -1629,8 +1663,7 @@ fn to_radix_digits_le(u: &BigUint, radix: u32) -> Vec<u8> {
     let mut digits = u.clone();
 
     let (base, power) = get_radix_base(radix);
-    debug_assert!(base < (1 << 32));
-    let base = base as BigDigit;
+    let radix = radix as BigDigit;
 
     while digits.data.len() > 1 {
         let (q, mut r) = div_rem_digit(digits, base);
@@ -1659,7 +1692,7 @@ fn to_str_radix_reversed(u: &BigUint, radix: u32) -> Vec<u8> {
 
     let mut res = if radix.is_power_of_two() {
         // Powers of two can use bitwise masks and shifting instead of division
-        let bits = radix.trailing_zeros() as usize;
+        let bits = ilog2(radix);
         if big_digit::BITS % bits == 0 {
             to_bitwise_digits_le(u, bits)
         } else {
@@ -1852,57 +1885,115 @@ impl serde::Deserialize for BigUint {
     }
 }
 
-// `DoubleBigDigit` size dependent
 /// Returns the greatest power of the radix <= big_digit::BASE
 #[inline]
-fn get_radix_base(radix: u32) -> (DoubleBigDigit, usize) {
+fn get_radix_base(radix: u32) -> (BigDigit, usize) {
+    debug_assert!(2 <= radix && radix <= 36, "The radix must be within 2...36");
+    debug_assert!(!radix.is_power_of_two());
+
     // To generate this table:
-    //    let target = std::u32::max as u64 + 1;
     //    for radix in 2u64..37 {
-    //        let power = (target as f64).log(radix as f64) as u32;
-    //        let base = radix.pow(power);
+    //        let mut power = big_digit::BITS / fls(radix as u64);
+    //        let mut base = radix.pow(power as u32);
+    //
+    //        while let Some(b) = base.checked_mul(radix) {
+    //            if b > big_digit::MAX {
+    //                break;
+    //            }
+    //            base = b;
+    //            power += 1;
+    //        }
+    //
     //        println!("({:10}, {:2}), // {:2}", base, power, radix);
     //    }
-    const BASES: [(DoubleBigDigit, usize); 37] = [(0, 0),
-                                                  (0, 0),
-                                                  (4294967296, 32), // 2
-                                                  (3486784401, 20), // 3
-                                                  (4294967296, 16), // 4
-                                                  (1220703125, 13), // 5
-                                                  (2176782336, 12), // 6
-                                                  (1977326743, 11), // 7
-                                                  (1073741824, 10), // 8
-                                                  (3486784401, 10), // 9
-                                                  (1000000000, 9), // 10
-                                                  (2357947691, 9), // 11
-                                                  (429981696, 8), // 12
-                                                  (815730721, 8), // 13
-                                                  (1475789056, 8), // 14
-                                                  (2562890625, 8), // 15
-                                                  (4294967296, 8), // 16
-                                                  (410338673, 7), // 17
-                                                  (612220032, 7), // 18
-                                                  (893871739, 7), // 19
-                                                  (1280000000, 7), // 20
-                                                  (1801088541, 7), // 21
-                                                  (2494357888, 7), // 22
-                                                  (3404825447, 7), // 23
-                                                  (191102976, 6), // 24
-                                                  (244140625, 6), // 25
-                                                  (308915776, 6), // 26
-                                                  (387420489, 6), // 27
-                                                  (481890304, 6), // 28
-                                                  (594823321, 6), // 29
-                                                  (729000000, 6), // 30
-                                                  (887503681, 6), // 31
-                                                  (1073741824, 6), // 32
-                                                  (1291467969, 6), // 33
-                                                  (1544804416, 6), // 34
-                                                  (1838265625, 6), // 35
-                                                  (2176782336, 6) /* 36 */];
 
-    assert!(2 <= radix && radix <= 36, "The radix must be within 2...36");
-    BASES[radix as usize]
+    match big_digit::BITS {
+        32  => {
+            const BASES: [(u32, usize); 37] = [(0, 0), (0, 0),
+                (0,                     0), // 2
+                (3486784401,            20),// 3
+                (0,                     0), // 4
+                (1220703125,            13),// 5
+                (2176782336,            12),// 6
+                (1977326743,            11),// 7
+                (0,                     0), // 8
+                (3486784401,            10),// 9
+                (1000000000,            9), // 10
+                (2357947691,            9), // 11
+                (429981696,             8), // 12
+                (815730721,             8), // 13
+                (1475789056,            8), // 14
+                (2562890625,            8), // 15
+                (0,                     0), // 16
+                (410338673,             7), // 17
+                (612220032,             7), // 18
+                (893871739,             7), // 19
+                (1280000000,            7), // 20
+                (1801088541,            7), // 21
+                (2494357888,            7), // 22
+                (3404825447,            7), // 23
+                (191102976,             6), // 24
+                (244140625,             6), // 25
+                (308915776,             6), // 26
+                (387420489,             6), // 27
+                (481890304,             6), // 28
+                (594823321,             6), // 29
+                (729000000,             6), // 30
+                (887503681,             6), // 31
+                (0,                     0), // 32
+                (1291467969,            6), // 33
+                (1544804416,            6), // 34
+                (1838265625,            6), // 35
+                (2176782336,            6)  // 36
+            ];
+
+            let (base, power) = BASES[radix as usize];
+            (base as BigDigit, power)
+        }
+        64  => {
+            const BASES: [(u64, usize); 37] = [(0, 0), (0, 0),
+                (9223372036854775808,	63), //  2
+                (12157665459056928801,	40), //  3
+                (4611686018427387904,	31), //  4
+                (7450580596923828125,	27), //  5
+                (4738381338321616896,	24), //  6
+                (3909821048582988049,	22), //  7
+                (9223372036854775808,	21), //  8
+                (12157665459056928801,	20), //  9
+                (10000000000000000000,	19), // 10
+                (5559917313492231481,	18), // 11
+                (2218611106740436992,	17), // 12
+                (8650415919381337933,	17), // 13
+                (2177953337809371136,	16), // 14
+                (6568408355712890625,	16), // 15
+                (1152921504606846976,	15), // 16
+                (2862423051509815793,	15), // 17
+                (6746640616477458432,	15), // 18
+                (15181127029874798299,	15), // 19
+                (1638400000000000000,	14), // 20
+                (3243919932521508681,	14), // 21
+                (6221821273427820544,	14), // 22
+                (11592836324538749809,	14), // 23
+                (876488338465357824,	13), // 24
+                (1490116119384765625,	13), // 25
+                (2481152873203736576,	13), // 26
+                (4052555153018976267,	13), // 27
+                (6502111422497947648,	13), // 28
+                (10260628712958602189,	13), // 29
+                (15943230000000000000,	13), // 30
+                (787662783788549761,	12), // 31
+                (1152921504606846976,	12), // 32
+                (1667889514952984961,	12), // 33
+                (2386420683693101056,	12), // 34
+                (3379220508056640625,	12), // 35
+                (4738381338321616896,	12), // 36
+            ];
+
+            let (base, power) = BASES[radix as usize];
+            (base as BigDigit, power)
+        }
+        _   => panic!("Invalid bigdigit size")
+    }
 }
 
 /// A Sign is a `BigInt`'s composing element.
@@ -3459,8 +3550,8 @@ mod biguint_tests {
     fn test_convert_i64() {
         fn check(b1: BigUint, i: i64) {
             let b2: BigUint = FromPrimitive::from_i64(i).unwrap();
-            assert!(b1 == b2);
-            assert!(b1.to_i64().unwrap() == i);
+            assert_eq!(b1, b2);
+            assert_eq!(b1.to_i64().unwrap(), i);
         }
 
         check(Zero::zero(), 0);
@@ -3484,8 +3575,8 @@ mod biguint_tests {
     fn test_convert_u64() {
         fn check(b1: BigUint, u: u64) {
             let b2: BigUint = FromPrimitive::from_u64(u).unwrap();
-            assert!(b1 == b2);
-            assert!(b1.to_u64().unwrap() == u);
+            assert_eq!(b1, b2);
+            assert_eq!(b1.to_u64().unwrap(), u);
         }
 
         check(Zero::zero(), 0);
@@ -3976,6 +4067,7 @@ mod biguint_tests {
                     format!("2{}1", repeat("0").take(bits / 2 - 1).collect::<String>())),
                    (10,
                     match bits {
+                       64 => "36893488147419103233".to_string(),
                        32 => "8589934593".to_string(),
                        16 => "131073".to_string(),
                        _ => panic!(),
@@ -3993,12 +4085,14 @@ mod biguint_tests {
                             repeat("0").take(bits / 2 - 1).collect::<String>())),
                    (8,
                     match bits {
+                       64 => "14000000000000000000004000000000000000000001".to_string(),
                        32 => "6000000000100000000001".to_string(),
                        16 => "140000400001".to_string(),
                        _ => panic!(),
                    }),
                    (10,
                     match bits {
+                       64 => "1020847100762815390427017310442723737601".to_string(),
                        32 => "55340232229718589441".to_string(),
                        16 => "12885032961".to_string(),
                        _ => panic!(),