Pārlūkot izejas kodu

bigint: greatly improve to_str_radix performance

Before:
     test fac_to_string     ... bench:      18,183 ns/iter (+/- 310)
     test fib_to_string     ... bench:         819 ns/iter (+/- 8)
     test to_str_radix_02   ... bench:     204,479 ns/iter (+/- 2,826)
     test to_str_radix_08   ... bench:      68,275 ns/iter (+/- 769)
     test to_str_radix_10   ... bench:      61,809 ns/iter (+/- 907)
     test to_str_radix_16   ... bench:      51,438 ns/iter (+/- 539)
     test to_str_radix_36   ... bench:      39,939 ns/iter (+/- 976)

After:
     test fac_to_string     ... bench:       1,204 ns/iter (+/- 16)
     test fib_to_string     ... bench:         269 ns/iter (+/- 3)
     test to_str_radix_02   ... bench:       2,428 ns/iter (+/- 80)
     test to_str_radix_08   ... bench:         820 ns/iter (+/- 38)
     test to_str_radix_10   ... bench:       2,984 ns/iter (+/- 303)
     test to_str_radix_16   ... bench:         689 ns/iter (+/- 25)
     test to_str_radix_36   ... bench:       7,995 ns/iter (+/- 100)
Josh Stone 9 gadi atpakaļ
vecāks
revīzija
49529895a2
1 mainītis faili ar 173 papildinājumiem un 53 dzēšanām
  1. 173 53
      src/bigint.rs

+ 173 - 53
src/bigint.rs

@@ -66,7 +66,7 @@ use std::iter::repeat;
 use std::num::ParseIntError;
 use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Neg, Rem, Shl, Shr, Sub};
 use std::str::{self, FromStr};
-use std::{cmp, fmt, hash, mem};
+use std::{cmp, fmt, hash};
 use std::cmp::Ordering::{self, Less, Greater, Equal};
 use std::{i64, u64};
 
@@ -1194,33 +1194,121 @@ impl_to_biguint!(u16,  FromPrimitive::from_u16);
 impl_to_biguint!(u32,  FromPrimitive::from_u32);
 impl_to_biguint!(u64,  FromPrimitive::from_u64);
 
-fn to_str_radix_reversed(u: &BigUint, radix: u32) -> Vec<u8> {
-    if radix < 2 || radix > 36 {
-        panic!("invalid radix: {}", radix);
+// Extract little-endian digits of `bits` bits each; `bits` must evenly divide big_digit::BITS
+fn to_bitwise_digits_le(u: &BigUint, bits: usize) -> Vec<u8> {
+    debug_assert!(!u.is_zero() && bits <= 8 && big_digit::BITS % bits == 0);
+
+    let last_i = u.data.len() - 1;
+    let mask: BigDigit = (1 << bits) - 1;
+    let digits_per_big_digit = big_digit::BITS / bits;
+    let digits = (u.bits() + bits - 1) / bits;
+    let mut res = Vec::with_capacity(digits);
+
+    for mut r in u.data[..last_i].iter().cloned() {
+        for _ in 0..digits_per_big_digit {
+            res.push((r & mask) as u8);
+            r >>= bits;
+        }
     }
 
-    if u.is_zero() {
-        vec![b'0']
-    } else {
-        let mut res = Vec::new();
-        let mut digits = u.clone();
+    let mut r = u.data[last_i];
+    while r != 0 {
+        res.push((r & mask) as u8);
+        r >>= bits;
+    }
+
+    res
+}
+
+// Extract little-endian digits of `bits` bits each when `bits` doesn't evenly divide big_digit::BITS
+fn to_inexact_bitwise_digits_le(u: &BigUint, bits: usize) -> Vec<u8> {
+    debug_assert!(!u.is_zero() && bits <= 8 && big_digit::BITS % bits != 0);
+
+    let last_i = u.data.len() - 1;
+    let mask: DoubleBigDigit = (1 << bits) - 1;
+    let digits = (u.bits() + bits - 1) / bits;
+    let mut res = Vec::with_capacity(digits);
+
+    let mut r = 0;
+    let mut rbits = 0;
+    for hi in u.data[..last_i].iter().cloned() {
+        r |= (hi as DoubleBigDigit) << rbits;
+        rbits += big_digit::BITS;
 
-        while digits != Zero::zero() {
-            let (q, r) = div_rem_digit(digits, radix as BigDigit);
-            res.push(to_digit(r as u8));
-            digits = q;
+        while rbits >= bits {
+            res.push((r & mask) as u8);
+            r >>= bits;
+            rbits -= bits;
         }
+    }
+
+    r |= (u.data[last_i] as DoubleBigDigit) << rbits;
+    while r != 0 {
+        res.push((r & mask) as u8);
+        r >>= bits;
+    }
+
+    res
+}
+
+// Extract little-endian radix digits
+#[inline(always)] // forced inline to get const-prop for radix=10
+fn to_radix_digits_le(u: &BigUint, radix: u32) -> Vec<u8> {
+    debug_assert!(!u.is_zero() && !radix.is_power_of_two());
 
-        res
+    let mut res = Vec::new();
+    let mut digits = u.clone();
+    let (base, power) = get_radix_base(radix);
+    debug_assert!(base < (1 << 32));
+    let base = base as BigDigit;
+
+    while digits.data.len() > 1 {
+        let (q, mut r) = div_rem_digit(digits, base);
+        for _ in 0..power {
+            res.push((r % radix) as u8);
+            r /= radix;
+        }
+        digits = q;
+    }
+
+    let mut r = digits.data[0];
+    while r != 0 {
+        res.push((r % radix) as u8);
+        r /= radix;
     }
+
+    res
 }
 
-fn to_digit(b: u8) -> u8 {
-    match b {
-        0 ... 9 => b'0' + b,
-        10 ... 35 => b'a' - 10 + b,
-        _ => panic!("invalid digit: {}", b)
+fn to_str_radix_reversed(u: &BigUint, radix: u32) -> Vec<u8> {
+    assert!(2 <= radix && radix <= 36, "The radix must be within 2...36");
+
+    if u.is_zero() {
+        return vec![b'0']
+    }
+
+    let mut res = if radix.is_power_of_two() {
+        // Powers of two can use bitwise masks and shifting instead of division
+        let bits = radix.trailing_zeros() as usize;
+        if big_digit::BITS % bits == 0 {
+            to_bitwise_digits_le(u, bits)
+        } else {
+            to_inexact_bitwise_digits_le(u, bits)
+        }
+    } else if radix == 10 {
+        // 10 is so common that it's worth separating out for const-propagation.
+        // Optimizers can often turn constant division into a faster multiplication.
+        to_radix_digits_le(u, 10)
+    } else {
+        to_radix_digits_le(u, radix)
+    };
+
+    // Now convert everything to ASCII digits.
+    for r in &mut res {
+        const DIGITS: &'static [u8; 36] = b"0123456789abcdefghijklmnopqrstuvwxyz";
+        *r = DIGITS[*r as usize];
     }
+    res
 }
 
 impl BigUint {
@@ -1289,24 +1377,10 @@ impl BigUint {
     /// ```
     #[inline]
     pub fn to_bytes_le(&self) -> Vec<u8> {
-        let mut result = Vec::new();
-        for word in self.data.iter() {
-            let mut w = *word;
-            for _ in 0..mem::size_of::<BigDigit>() {
-                let b = (w & 0xFF) as u8;
-                w = w >> 8;
-                result.push(b);
-            }
-        }
-
-        while let Some(&0) = result.last() {
-            result.pop();
-        }
-
-        if result.is_empty() {
+        if self.is_zero() {
             vec![0]
         } else {
-            result
+            to_bitwise_digits_le(self, 8)
         }
     }
 
@@ -1431,26 +1505,57 @@ impl BigUint {
 }
 
 // `DoubleBigDigit` size dependent
+/// Returns the greatest power of the radix <= BigDigit::MAX + 1, paired with its exponent
 #[inline]
 fn get_radix_base(radix: u32) -> (DoubleBigDigit, usize) {
-    match radix {
-        2  => (4294967296, 32),
-        3  => (3486784401, 20),
-        4  => (4294967296, 16),
-        5  => (1220703125, 13),
-        6  => (2176782336, 12),
-        7  => (1977326743, 11),
-        8  => (1073741824, 10),
-        9  => (3486784401, 10),
-        10 => (1000000000, 9),
-        11 => (2357947691, 9),
-        12 => (429981696,  8),
-        13 => (815730721,  8),
-        14 => (1475789056, 8),
-        15 => (2562890625, 8),
-        16 => (4294967296, 8),
-        _  => panic!("The radix must be within (1, 16]")
-    }
+    // To generate this table:
+    //    let target = std::u32::MAX as u64 + 1;
+    //    for radix in 2u64..37 {
+    //        let power = (target as f64).log(radix as f64) as u32;
+    //        let base = radix.pow(power);
+    //        println!("({:10}, {:2}), // {:2}", base, power, radix);
+    //    }
+    const BASES: [(DoubleBigDigit, usize); 37] = [
+        (0, 0), (0, 0),
+        (4294967296, 32), //  2
+        (3486784401, 20), //  3
+        (4294967296, 16), //  4
+        (1220703125, 13), //  5
+        (2176782336, 12), //  6
+        (1977326743, 11), //  7
+        (1073741824, 10), //  8
+        (3486784401, 10), //  9
+        (1000000000,  9), // 10
+        (2357947691,  9), // 11
+        ( 429981696,  8), // 12
+        ( 815730721,  8), // 13
+        (1475789056,  8), // 14
+        (2562890625,  8), // 15
+        (4294967296,  8), // 16
+        ( 410338673,  7), // 17
+        ( 612220032,  7), // 18
+        ( 893871739,  7), // 19
+        (1280000000,  7), // 20
+        (1801088541,  7), // 21
+        (2494357888,  7), // 22
+        (3404825447,  7), // 23
+        ( 191102976,  6), // 24
+        ( 244140625,  6), // 25
+        ( 308915776,  6), // 26
+        ( 387420489,  6), // 27
+        ( 481890304,  6), // 28
+        ( 594823321,  6), // 29
+        ( 729000000,  6), // 30
+        ( 887503681,  6), // 31
+        (1073741824,  6), // 32
+        (1291467969,  6), // 33
+        (1544804416,  6), // 34
+        (1838265625,  6), // 35
+        (2176782336,  6), // 36
+    ];
+
+    assert!(2 <= radix && radix <= 36, "The radix must be within 2...36");
+    BASES[radix as usize]
 }
 
 /// A Sign is a `BigInt`'s composing element.
@@ -3242,6 +3347,11 @@ mod biguint_tests {
              format!("3{}2{}1",
                      repeat("0").take(bits / 2 - 1).collect::<String>(),
                      repeat("0").take(bits / 2 - 1).collect::<String>())),
+            (8, match bits {
+                32 => "6000000000100000000001".to_string(),
+                16 => "140000400001".to_string(),
+                _ => panic!()
+            }),
             (10, match bits {
                 32 => "55340232229718589441".to_string(),
                 16 => "12885032961".to_string(),
@@ -3286,6 +3396,16 @@ mod biguint_tests {
         assert_eq!(minus_one, None);
     }
 
+    #[test]
+    fn test_all_str_radix() {
+        let n = BigUint::new((0..10).collect());
+        for radix in 2..37 {
+            let s = n.to_str_radix(radix);
+            let x = BigUint::from_str_radix(&s, radix);
+            assert_eq!(x.unwrap(), n);
+        }
+    }
+
     #[test]
     fn test_factor() {
         fn factor(n: usize) -> BigUint {