algorithms.rs 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603
  1. use std::borrow::Cow;
  2. use std::cmp;
  3. use std::cmp::Ordering::{self, Less, Greater, Equal};
  4. use std::iter::repeat;
  5. use std::mem;
  6. use traits;
  7. use traits::{Zero, One};
  8. use biguint::BigUint;
  9. use bigint::Sign;
  10. use bigint::Sign::{Minus, NoSign, Plus};
  11. #[allow(non_snake_case)]
  12. pub mod big_digit {
  13. /// A `BigDigit` is a `BigUint`'s composing element.
  14. pub type BigDigit = u32;
  15. /// A `DoubleBigDigit` is the internal type used to do the computations. Its
  16. /// size is the double of the size of `BigDigit`.
  17. pub type DoubleBigDigit = u64;
  18. pub const ZERO_BIG_DIGIT: BigDigit = 0;
  19. // `DoubleBigDigit` size dependent
  20. pub const BITS: usize = 32;
  21. pub const BASE: DoubleBigDigit = 1 << BITS;
  22. const LO_MASK: DoubleBigDigit = (-1i32 as DoubleBigDigit) >> BITS;
  23. #[inline]
  24. fn get_hi(n: DoubleBigDigit) -> BigDigit {
  25. (n >> BITS) as BigDigit
  26. }
  27. #[inline]
  28. fn get_lo(n: DoubleBigDigit) -> BigDigit {
  29. (n & LO_MASK) as BigDigit
  30. }
  31. /// Split one `DoubleBigDigit` into two `BigDigit`s.
  32. #[inline]
  33. pub fn from_doublebigdigit(n: DoubleBigDigit) -> (BigDigit, BigDigit) {
  34. (get_hi(n), get_lo(n))
  35. }
  36. /// Join two `BigDigit`s into one `DoubleBigDigit`
  37. #[inline]
  38. pub fn to_doublebigdigit(hi: BigDigit, lo: BigDigit) -> DoubleBigDigit {
  39. (lo as DoubleBigDigit) | ((hi as DoubleBigDigit) << BITS)
  40. }
  41. }
  42. use big_digit::{BigDigit, DoubleBigDigit};
  43. // Generic functions for add/subtract/multiply with carry/borrow:
  44. // Add with carry:
  45. #[inline]
  46. fn adc(a: BigDigit, b: BigDigit, carry: &mut BigDigit) -> BigDigit {
  47. let (hi, lo) = big_digit::from_doublebigdigit((a as DoubleBigDigit) + (b as DoubleBigDigit) +
  48. (*carry as DoubleBigDigit));
  49. *carry = hi;
  50. lo
  51. }
  52. // Subtract with borrow:
  53. #[inline]
  54. fn sbb(a: BigDigit, b: BigDigit, borrow: &mut BigDigit) -> BigDigit {
  55. let (hi, lo) = big_digit::from_doublebigdigit(big_digit::BASE + (a as DoubleBigDigit) -
  56. (b as DoubleBigDigit) -
  57. (*borrow as DoubleBigDigit));
  58. // hi * (base) + lo == 1*(base) + ai - bi - borrow
  59. // => ai - bi - borrow < 0 <=> hi == 0
  60. *borrow = (hi == 0) as BigDigit;
  61. lo
  62. }
  63. #[inline]
  64. pub fn mac_with_carry(a: BigDigit, b: BigDigit, c: BigDigit, carry: &mut BigDigit) -> BigDigit {
  65. let (hi, lo) = big_digit::from_doublebigdigit((a as DoubleBigDigit) +
  66. (b as DoubleBigDigit) * (c as DoubleBigDigit) +
  67. (*carry as DoubleBigDigit));
  68. *carry = hi;
  69. lo
  70. }
  71. #[inline]
  72. pub fn mul_with_carry(a: BigDigit, b: BigDigit, carry: &mut BigDigit) -> BigDigit {
  73. let (hi, lo) = big_digit::from_doublebigdigit((a as DoubleBigDigit) * (b as DoubleBigDigit) +
  74. (*carry as DoubleBigDigit));
  75. *carry = hi;
  76. lo
  77. }
  78. /// Divide a two digit numerator by a one digit divisor, returns quotient and remainder:
  79. ///
  80. /// Note: the caller must ensure that both the quotient and remainder will fit into a single digit.
  81. /// This is _not_ true for an arbitrary numerator/denominator.
  82. ///
  83. /// (This function also matches what the x86 divide instruction does).
  84. #[inline]
  85. fn div_wide(hi: BigDigit, lo: BigDigit, divisor: BigDigit) -> (BigDigit, BigDigit) {
  86. debug_assert!(hi < divisor);
  87. let lhs = big_digit::to_doublebigdigit(hi, lo);
  88. let rhs = divisor as DoubleBigDigit;
  89. ((lhs / rhs) as BigDigit, (lhs % rhs) as BigDigit)
  90. }
  91. pub fn div_rem_digit(mut a: BigUint, b: BigDigit) -> (BigUint, BigDigit) {
  92. let mut rem = 0;
  93. for d in a.data.iter_mut().rev() {
  94. let (q, r) = div_wide(rem, *d, b);
  95. *d = q;
  96. rem = r;
  97. }
  98. (a.normalize(), rem)
  99. }
  100. // Only for the Add impl:
  101. #[must_use]
  102. #[inline]
  103. pub fn __add2(a: &mut [BigDigit], b: &[BigDigit]) -> BigDigit {
  104. debug_assert!(a.len() >= b.len());
  105. let mut carry = 0;
  106. let (a_lo, a_hi) = a.split_at_mut(b.len());
  107. for (a, b) in a_lo.iter_mut().zip(b) {
  108. *a = adc(*a, *b, &mut carry);
  109. }
  110. if carry != 0 {
  111. for a in a_hi {
  112. *a = adc(*a, 0, &mut carry);
  113. if carry == 0 { break }
  114. }
  115. }
  116. carry
  117. }
  118. /// /Two argument addition of raw slices:
  119. /// a += b
  120. ///
  121. /// The caller _must_ ensure that a is big enough to store the result - typically this means
  122. /// resizing a to max(a.len(), b.len()) + 1, to fit a possible carry.
  123. pub fn add2(a: &mut [BigDigit], b: &[BigDigit]) {
  124. let carry = __add2(a, b);
  125. debug_assert!(carry == 0);
  126. }
  127. pub fn sub2(a: &mut [BigDigit], b: &[BigDigit]) {
  128. let mut borrow = 0;
  129. let len = cmp::min(a.len(), b.len());
  130. let (a_lo, a_hi) = a.split_at_mut(len);
  131. let (b_lo, b_hi) = b.split_at(len);
  132. for (a, b) in a_lo.iter_mut().zip(b_lo) {
  133. *a = sbb(*a, *b, &mut borrow);
  134. }
  135. if borrow != 0 {
  136. for a in a_hi {
  137. *a = sbb(*a, 0, &mut borrow);
  138. if borrow == 0 { break }
  139. }
  140. }
  141. // note: we're _required_ to fail on underflow
  142. assert!(borrow == 0 && b_hi.iter().all(|x| *x == 0),
  143. "Cannot subtract b from a because b is larger than a.");
  144. }
  145. pub fn sub2rev(a: &[BigDigit], b: &mut [BigDigit]) {
  146. debug_assert!(b.len() >= a.len());
  147. let mut borrow = 0;
  148. let len = cmp::min(a.len(), b.len());
  149. let (a_lo, a_hi) = a.split_at(len);
  150. let (b_lo, b_hi) = b.split_at_mut(len);
  151. for (a, b) in a_lo.iter().zip(b_lo) {
  152. *b = sbb(*a, *b, &mut borrow);
  153. }
  154. assert!(a_hi.is_empty());
  155. // note: we're _required_ to fail on underflow
  156. assert!(borrow == 0 && b_hi.iter().all(|x| *x == 0),
  157. "Cannot subtract b from a because b is larger than a.");
  158. }
  159. pub fn sub_sign(a: &[BigDigit], b: &[BigDigit]) -> (Sign, BigUint) {
  160. // Normalize:
  161. let a = &a[..a.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
  162. let b = &b[..b.iter().rposition(|&x| x != 0).map_or(0, |i| i + 1)];
  163. match cmp_slice(a, b) {
  164. Greater => {
  165. let mut a = a.to_vec();
  166. sub2(&mut a, b);
  167. (Plus, BigUint::new(a))
  168. }
  169. Less => {
  170. let mut b = b.to_vec();
  171. sub2(&mut b, a);
  172. (Minus, BigUint::new(b))
  173. }
  174. _ => (NoSign, Zero::zero()),
  175. }
  176. }
  177. /// Three argument multiply accumulate:
  178. /// acc += b * c
  179. fn mac_digit(acc: &mut [BigDigit], b: &[BigDigit], c: BigDigit) {
  180. if c == 0 {
  181. return;
  182. }
  183. let mut b_iter = b.iter();
  184. let mut carry = 0;
  185. for ai in acc.iter_mut() {
  186. if let Some(bi) = b_iter.next() {
  187. *ai = mac_with_carry(*ai, *bi, c, &mut carry);
  188. } else if carry != 0 {
  189. *ai = mac_with_carry(*ai, 0, c, &mut carry);
  190. } else {
  191. break;
  192. }
  193. }
  194. assert!(carry == 0);
  195. }
/// Three argument multiply accumulate:
/// acc += b * c
///
/// The operands are reordered so that `x` is the shorter of `b` and `c` and `y` the longer.
/// When `x` has at most 4 digits this uses schoolbook long multiplication (one `mac_digit`
/// pass per digit of `x`). Otherwise it recurses using Karatsuba multiplication.
///
/// NOTE(review): `acc` is assumed to have room for the whole result. `mul3` allocates
/// `x.len() + y.len() + 1` digits, and the Karatsuba temporary `p` below is sized the same way.
/// Confirm before calling with a tighter buffer.
fn mac3(acc: &mut [BigDigit], b: &[BigDigit], c: &[BigDigit]) {
    // x = shorter operand, y = longer operand.
    let (x, y) = if b.len() < c.len() {
        (b, c)
    } else {
        (c, b)
    };
    // Karatsuba multiplication is slower than long multiplication for small x and y:
    if x.len() <= 4 {
        // Long multiplication: add y * x[i], shifted by i digits, for each digit of x.
        for (i, xi) in x.iter().enumerate() {
            mac_digit(&mut acc[i..], y, *xi);
        }
    } else {
        /*
         * Karatsuba multiplication:
         *
         * The idea is that we break x and y up into two smaller numbers that each have about half
         * as many digits, like so (note that multiplying by b is just a shift — here `b` is the
         * split point in digits, so "* b" means "shifted left by b digits"):
         *
         * x = x0 + x1 * b
         * y = y0 + y1 * b
         *
         * With some algebra, we can compute x * y with three smaller products, where the inputs to
         * each of the smaller products have only about half as many digits as x and y:
         *
         * x * y = (x0 + x1 * b) * (y0 + y1 * b)
         *
         * x * y = x0 * y0
         * + x0 * y1 * b
         * + x1 * y0 * b
         * + x1 * y1 * b^2
         *
         * Let p0 = x0 * y0 and p2 = x1 * y1:
         *
         * x * y = p0
         * + (x0 * y1 + x1 * y0) * b
         * + p2 * b^2
         *
         * The real trick is that middle term:
         *
         * x0 * y1 + x1 * y0
         *
         * = x0 * y1 + x1 * y0 - p0 + p0 - p2 + p2
         *
         * = x0 * y1 + x1 * y0 - x0 * y0 - x1 * y1 + p0 + p2
         *
         * Now we complete the square:
         *
         * = -(x0 * y0 - x0 * y1 - x1 * y0 + x1 * y1) + p0 + p2
         *
         * = -((x1 - x0) * (y1 - y0)) + p0 + p2
         *
         * Let p1 = (x1 - x0) * (y1 - y0), and substitute back into our original formula:
         *
         * x * y = p0
         * + (p0 + p2 - p1) * b
         * + p2 * b^2
         *
         * Where the three intermediate products are:
         *
         * p0 = x0 * y0
         * p1 = (x1 - x0) * (y1 - y0)
         * p2 = x1 * y1
         *
         * In doing the computation, we take great care to avoid unnecessary temporary variables
         * (since creating a BigUint requires a heap allocation): thus, we rearrange the formula a
         * bit so we can use the same temporary variable for all the intermediate products:
         *
         * x * y = p2 * b^2 + p2 * b
         * + p0 * b + p0
         * - p1 * b
         *
         * The other trick we use is instead of doing explicit shifts, we slice acc at the
         * appropriate offset when doing the add.
         */
        /*
         * When x is smaller than y, it's significantly faster to pick b such that x is split in
         * half, not y:
         */
        let b = x.len() / 2;
        let (x0, x1) = x.split_at(b);
        let (y0, y1) = y.split_at(b);
        /*
         * We reuse the same BigUint for all the intermediate multiplies and have to size p
         * appropriately here: x1.len() >= x0.len() and y1.len() >= y0.len():
         */
        let len = x1.len() + y1.len() + 1;
        let mut p = BigUint { data: vec![0; len] };
        // p2 = x1 * y1
        mac3(&mut p.data[..], x1, y1);
        // Not required, but the adds go faster if we drop any unneeded 0s from the end:
        p = p.normalize();
        // acc += p2 * b + p2 * b^2
        add2(&mut acc[b..], &p.data[..]);
        add2(&mut acc[b * 2..], &p.data[..]);
        // Zero out p before the next multiply:
        p.data.truncate(0);
        p.data.extend(repeat(0).take(len));
        // p0 = x0 * y0
        mac3(&mut p.data[..], x0, y0);
        p = p.normalize();
        // acc += p0 + p0 * b
        add2(&mut acc[..], &p.data[..]);
        add2(&mut acc[b..], &p.data[..]);
        // p1 = (x1 - x0) * (y1 - y0)
        // We do this one last, since it may be negative and acc can't ever be negative:
        let (j0_sign, j0) = sub_sign(x1, x0);
        let (j1_sign, j1) = sub_sign(y1, y0);
        match j0_sign * j1_sign {
            // p1 > 0: compute it into the scratch buffer, then acc -= p1 * b.
            Plus => {
                p.data.truncate(0);
                p.data.extend(repeat(0).take(len));
                mac3(&mut p.data[..], &j0.data[..], &j1.data[..]);
                p = p.normalize();
                sub2(&mut acc[b..], &p.data[..]);
            },
            // p1 < 0: subtracting p1 * b means adding |p1| * b, which can accumulate directly.
            Minus => {
                mac3(&mut acc[b..], &j0.data[..], &j1.data[..]);
            },
            // p1 == 0: nothing to do.
            NoSign => (),
        }
    }
}
  319. pub fn mul3(x: &[BigDigit], y: &[BigDigit]) -> BigUint {
  320. let len = x.len() + y.len() + 1;
  321. let mut prod = BigUint { data: vec![0; len] };
  322. mac3(&mut prod.data[..], x, y);
  323. prod.normalize()
  324. }
  325. pub fn scalar_mul(a: &mut [BigDigit], b: BigDigit) -> BigDigit {
  326. let mut carry = 0;
  327. for a in a.iter_mut() {
  328. *a = mul_with_carry(*a, b, &mut carry);
  329. }
  330. carry
  331. }
  332. pub fn div_rem(u: &BigUint, d: &BigUint) -> (BigUint, BigUint) {
  333. if d.is_zero() {
  334. panic!()
  335. }
  336. if u.is_zero() {
  337. return (Zero::zero(), Zero::zero());
  338. }
  339. if *d == One::one() {
  340. return (u.clone(), Zero::zero());
  341. }
  342. // Required or the q_len calculation below can underflow:
  343. match u.cmp(d) {
  344. Less => return (Zero::zero(), u.clone()),
  345. Equal => return (One::one(), Zero::zero()),
  346. Greater => {} // Do nothing
  347. }
  348. // This algorithm is from Knuth, TAOCP vol 2 section 4.3, algorithm D:
  349. //
  350. // First, normalize the arguments so the highest bit in the highest digit of the divisor is
  351. // set: the main loop uses the highest digit of the divisor for generating guesses, so we
  352. // want it to be the largest number we can efficiently divide by.
  353. //
  354. let shift = d.data.last().unwrap().leading_zeros() as usize;
  355. let mut a = u << shift;
  356. let b = d << shift;
  357. // The algorithm works by incrementally calculating "guesses", q0, for part of the
  358. // remainder. Once we have any number q0 such that q0 * b <= a, we can set
  359. //
  360. // q += q0
  361. // a -= q0 * b
  362. //
  363. // and then iterate until a < b. Then, (q, a) will be our desired quotient and remainder.
  364. //
  365. // q0, our guess, is calculated by dividing the last few digits of a by the last digit of b
  366. // - this should give us a guess that is "close" to the actual quotient, but is possibly
  367. // greater than the actual quotient. If q0 * b > a, we simply use iterated subtraction
  368. // until we have a guess such that q0 * b <= a.
  369. //
  370. let bn = *b.data.last().unwrap();
  371. let q_len = a.data.len() - b.data.len() + 1;
  372. let mut q = BigUint { data: vec![0; q_len] };
  373. // We reuse the same temporary to avoid hitting the allocator in our inner loop - this is
  374. // sized to hold a0 (in the common case; if a particular digit of the quotient is zero a0
  375. // can be bigger).
  376. //
  377. let mut tmp = BigUint { data: Vec::with_capacity(2) };
  378. for j in (0..q_len).rev() {
  379. /*
  380. * When calculating our next guess q0, we don't need to consider the digits below j
  381. * + b.data.len() - 1: we're guessing digit j of the quotient (i.e. q0 << j) from
  382. * digit bn of the divisor (i.e. bn << (b.data.len() - 1) - so the product of those
  383. * two numbers will be zero in all digits up to (j + b.data.len() - 1).
  384. */
  385. let offset = j + b.data.len() - 1;
  386. if offset >= a.data.len() {
  387. continue;
  388. }
  389. /* just avoiding a heap allocation: */
  390. let mut a0 = tmp;
  391. a0.data.truncate(0);
  392. a0.data.extend(a.data[offset..].iter().cloned());
  393. /*
  394. * q0 << j * big_digit::BITS is our actual quotient estimate - we do the shifts
  395. * implicitly at the end, when adding and subtracting to a and q. Not only do we
  396. * save the cost of the shifts, the rest of the arithmetic gets to work with
  397. * smaller numbers.
  398. */
  399. let (mut q0, _) = div_rem_digit(a0, bn);
  400. let mut prod = &b * &q0;
  401. while cmp_slice(&prod.data[..], &a.data[j..]) == Greater {
  402. let one: BigUint = One::one();
  403. q0 = q0 - one;
  404. prod = prod - &b;
  405. }
  406. add2(&mut q.data[j..], &q0.data[..]);
  407. sub2(&mut a.data[j..], &prod.data[..]);
  408. a = a.normalize();
  409. tmp = q0;
  410. }
  411. debug_assert!(a < b);
  412. (q.normalize(), a >> shift)
  413. }
  414. /// Find last set bit
  415. /// fls(0) == 0, fls(u32::MAX) == 32
  416. pub fn fls<T: traits::PrimInt>(v: T) -> usize {
  417. mem::size_of::<T>() * 8 - v.leading_zeros() as usize
  418. }
  419. pub fn ilog2<T: traits::PrimInt>(v: T) -> usize {
  420. fls(v) - 1
  421. }
  422. #[inline]
  423. pub fn biguint_shl(n: Cow<BigUint>, bits: usize) -> BigUint {
  424. let n_unit = bits / big_digit::BITS;
  425. let mut data = match n_unit {
  426. 0 => n.into_owned().data,
  427. _ => {
  428. let len = n_unit + n.data.len() + 1;
  429. let mut data = Vec::with_capacity(len);
  430. data.extend(repeat(0).take(n_unit));
  431. data.extend(n.data.iter().cloned());
  432. data
  433. }
  434. };
  435. let n_bits = bits % big_digit::BITS;
  436. if n_bits > 0 {
  437. let mut carry = 0;
  438. for elem in data[n_unit..].iter_mut() {
  439. let new_carry = *elem >> (big_digit::BITS - n_bits);
  440. *elem = (*elem << n_bits) | carry;
  441. carry = new_carry;
  442. }
  443. if carry != 0 {
  444. data.push(carry);
  445. }
  446. }
  447. BigUint::new(data)
  448. }
  449. #[inline]
  450. pub fn biguint_shr(n: Cow<BigUint>, bits: usize) -> BigUint {
  451. let n_unit = bits / big_digit::BITS;
  452. if n_unit >= n.data.len() {
  453. return Zero::zero();
  454. }
  455. let mut data = match n_unit {
  456. 0 => n.into_owned().data,
  457. _ => n.data[n_unit..].to_vec(),
  458. };
  459. let n_bits = bits % big_digit::BITS;
  460. if n_bits > 0 {
  461. let mut borrow = 0;
  462. for elem in data.iter_mut().rev() {
  463. let new_borrow = *elem << (big_digit::BITS - n_bits);
  464. *elem = (*elem >> n_bits) | borrow;
  465. borrow = new_borrow;
  466. }
  467. }
  468. BigUint::new(data)
  469. }
  470. pub fn cmp_slice(a: &[BigDigit], b: &[BigDigit]) -> Ordering {
  471. debug_assert!(a.last() != Some(&0));
  472. debug_assert!(b.last() != Some(&0));
  473. let (a_len, b_len) = (a.len(), b.len());
  474. if a_len < b_len {
  475. return Less;
  476. }
  477. if a_len > b_len {
  478. return Greater;
  479. }
  480. for (&ai, &bi) in a.iter().rev().zip(b.iter().rev()) {
  481. if ai < bi {
  482. return Less;
  483. }
  484. if ai > bi {
  485. return Greater;
  486. }
  487. }
  488. return Equal;
  489. }
  490. #[cfg(test)]
  491. mod algorithm_tests {
  492. use {BigDigit, BigUint, BigInt};
  493. use Sign::Plus;
  494. use traits::Num;
  495. #[test]
  496. fn test_sub_sign() {
  497. use super::sub_sign;
  498. fn sub_sign_i(a: &[BigDigit], b: &[BigDigit]) -> BigInt {
  499. let (sign, val) = sub_sign(a, b);
  500. BigInt::from_biguint(sign, val)
  501. }
  502. let a = BigUint::from_str_radix("265252859812191058636308480000000", 10).unwrap();
  503. let b = BigUint::from_str_radix("26525285981219105863630848000000", 10).unwrap();
  504. let a_i = BigInt::from_biguint(Plus, a.clone());
  505. let b_i = BigInt::from_biguint(Plus, b.clone());
  506. assert_eq!(sub_sign_i(&a.data[..], &b.data[..]), &a_i - &b_i);
  507. assert_eq!(sub_sign_i(&b.data[..], &a.data[..]), &b_i - &a_i);
  508. }
  509. }