Use better initial guesses for Roots

This commit is contained in:
Josh Stone 2018-12-05 14:32:10 -08:00 committed by dignifiedquire
parent 97b0b17145
commit 07a22d447b
2 changed files with 211 additions and 68 deletions

View File

@ -1397,6 +1397,35 @@ impl Integer for BigUint {
}
}
#[inline]
fn fixpoint<F>(mut x: BigUint, max_bits: usize, f: F) -> BigUint
where
F: Fn(&BigUint) -> BigUint,
{
let mut xn = f(&x);
// If the value increased, then the initial guess must have been low.
// Repeat until we reverse course.
while x < xn {
// Sometimes an increase will go way too far, especially with large
// powers, and then take a long time to walk back. We know an upper
// bound based on bit size, so saturate on that.
x = if xn.bits() > max_bits {
BigUint::one() << max_bits
} else {
xn
};
xn = f(&x);
}
// Now keep repeating while the estimate is decreasing.
while x > xn {
x = xn;
xn = f(&x);
}
x
}
impl Roots for BigUint {
// nth_root, sqrt and cbrt use Newton's method to compute
// principal root of a given degree for a given integer.
@ -1418,27 +1447,42 @@ impl Roots for BigUint {
_ => (),
}
let n = n as usize;
let n_min_1 = n - 1;
let guess = BigUint::one() << (self.bits() / n + 1);
let mut u = guess;
let mut s: BigUint;
loop {
s = u;
let q = self / s.pow(n_min_1);
let t: BigUint = n_min_1 * &s + q;
u = t / n;
if u >= s {
break;
}
// The root of non-zero values less than 2ⁿ can only be 1.
let bits = self.bits();
if bits <= n as usize {
return BigUint::one()
}
s
// If we fit in `u64`, compute the root that way.
if let Some(x) = self.to_u64() {
return x.nth_root(n).into();
}
let max_bits = bits / n as usize + 1;
let guess = if let Some(f) = self.to_f64() {
// We fit in `f64` (lossy), so get a better initial guess from that.
BigUint::from_f64((f.ln() / f64::from(n)).exp()).unwrap()
} else {
// Try to guess by scaling down such that it does fit in `f64`.
// With some (x * 2ⁿᵏ), its nth root ≈ (ⁿ√x * 2ᵏ)
let nsz = n as usize;
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
let root_scale = (extra_bits + (nsz - 1)) / nsz;
let scale = root_scale * nsz;
if scale < bits && bits - scale > nsz {
(self >> scale).nth_root(n) << root_scale
} else {
BigUint::one() << max_bits
}
};
let n_min_1 = n - 1;
fixpoint(guess, max_bits, move |s| {
let q = self / s.pow(n_min_1);
let t = n_min_1 * s + q;
t / n
})
}
// Reference:
@ -1448,23 +1492,31 @@ impl Roots for BigUint {
return self.clone();
}
let guess = BigUint::one() << (self.bits() / 2 + 1);
let mut u = guess;
let mut s: BigUint;
loop {
s = u;
let q = self / &s;
let t: BigUint = &s + q;
u = t >> 1;
if u >= s {
break;
}
// If we fit in `u64`, compute the root that way.
if let Some(x) = self.to_u64() {
return x.sqrt().into();
}
s
let bits = self.bits();
let max_bits = bits / 2 as usize + 1;
let guess = if let Some(f) = self.to_f64() {
// We fit in `f64` (lossy), so get a better initial guess from that.
BigUint::from_f64(f.sqrt()).unwrap()
} else {
// Try to guess by scaling down such that it does fit in `f64`.
// With some (x * 2²ᵏ), its sqrt ≈ (√x * 2ᵏ)
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
let root_scale = (extra_bits + 1) / 2;
let scale = root_scale * 2;
(self >> scale).sqrt() << root_scale
};
fixpoint(guess, max_bits, move |s| {
let q = self / s;
let t = s + q;
t >> 1
})
}
fn cbrt(&self) -> Self {
@ -1472,23 +1524,32 @@ impl Roots for BigUint {
return self.clone();
}
let guess = BigUint::one() << (self.bits() / 3 + 1);
let mut u = guess;
let mut s: BigUint;
loop {
s = u;
let q = self / (&s * &s);
let t: BigUint = (&s << 1) + q;
u = t / 3u32;
if u >= s {
break;
}
// If we fit in `u64`, compute the root that way.
if let Some(x) = self.to_u64() {
return x.cbrt().into();
}
s
let bits = self.bits();
let max_bits = bits / 3 as usize + 1;
let guess = if let Some(f) = self.to_f64() {
// We fit in `f64` (lossy), so get a better initial guess from that.
BigUint::from_f64(f.cbrt()).unwrap()
} else {
// Try to guess by scaling down such that it does fit in `f64`.
// With some (x * 2³ᵏ), its cbrt ≈ (∛x * 2ᵏ)
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
let root_scale = (extra_bits + 2) / 3;
let scale = root_scale * 3;
(self >> scale).cbrt() << root_scale
};
fixpoint(guess, max_bits, move |s| {
let q = self / (s * s);
let t = (s << 1) + q;
t / 3u32
})
}
}

View File

@ -2,57 +2,139 @@ extern crate num_bigint_dig as num_bigint;
extern crate num_integer;
extern crate num_traits;
#[cfg(feature = "rand")]
extern crate rand;
mod biguint {
use num_bigint::BigUint;
use num_traits::Pow;
use std::str::FromStr;
use num_traits::{One, Pow, Zero};
use std::{i32, u32};
fn check(x: u64, n: u32) {
let big_x = BigUint::from(x);
let res = big_x.nth_root(n);
fn check<T: Into<BigUint>>(x: T, n: u32) {
let x: BigUint = x.into();
let root = x.nth_root(n);
println!("check {}.nth_root({}) = {}", x, n, root);
if n == 2 {
assert_eq!(&res, &big_x.sqrt())
assert_eq!(root, x.sqrt())
} else if n == 3 {
assert_eq!(&res, &big_x.cbrt())
assert_eq!(root, x.cbrt())
}
assert!(res.pow(n) <= big_x);
assert!((res + 1u32).pow(n) > big_x);
let lo = root.pow(n);
assert!(lo <= x);
assert_eq!(lo.nth_root(n), root);
if !lo.is_zero() {
assert_eq!((&lo - 1u32).nth_root(n), &root - 1u32);
}
let hi = (&root + 1u32).pow(n);
assert!(hi > x);
assert_eq!(hi.nth_root(n), &root + 1u32);
assert_eq!((&hi - 1u32).nth_root(n), root);
}
#[test]
fn test_sqrt() {
check(99, 2);
check(100, 2);
check(120, 2);
check(99u32, 2);
check(100u32, 2);
check(120u32, 2);
}
#[test]
fn test_cbrt() {
check(8, 3);
check(26, 3);
check(8u32, 3);
check(26u32, 3);
}
#[test]
fn test_nth_root() {
check(0, 1);
check(10, 1);
check(100, 4);
check(0u32, 1);
check(10u32, 1);
check(100u32, 4);
}
#[test]
#[should_panic]
fn test_nth_root_n_is_zero() {
check(4, 0);
check(4u32, 0);
}
#[test]
fn test_nth_root_big() {
let x = BigUint::from_str("123_456_789").unwrap();
let x = BigUint::from(123_456_789_u32);
let expected = BigUint::from(6u32);
assert_eq!(x.nth_root(10), expected);
check(x, 10);
}
#[test]
fn test_nth_root_googol() {
let googol = BigUint::from(10u32).pow(100u32);
// perfect divisors of 100
for &n in &[2, 4, 5, 10, 20, 25, 50, 100] {
let expected = BigUint::from(10u32).pow(100u32 / n);
assert_eq!(googol.nth_root(n), expected);
check(googol.clone(), n);
}
}
#[test]
fn test_nth_root_twos() {
const EXP: u32 = 12;
const LOG2: usize = 1 << EXP;
let x = BigUint::one() << LOG2;
// the perfect divisors are just powers of two
for exp in 1..EXP + 1 {
let n = 2u32.pow(exp);
let expected = BigUint::one() << (LOG2 / n as usize);
assert_eq!(x.nth_root(n), expected);
check(x.clone(), n);
}
// degenerate cases should return quickly
assert!(x.nth_root(x.bits() as u32).is_one());
assert!(x.nth_root(i32::MAX as u32).is_one());
assert!(x.nth_root(u32::MAX).is_one());
}
#[cfg(feature = "rand")]
#[test]
fn test_roots_rand() {
use num_bigint::RandBigInt;
use rand::{thread_rng, Rng};
use rand::distributions::Uniform;
let mut rng = thread_rng();
let bit_range = Uniform::new(0, 2048);
let sample_bits: Vec<_> = rng.sample_iter(&bit_range).take(100).collect();
for bits in sample_bits {
let x = rng.gen_biguint(bits);
for n in 2..11 {
check(x.clone(), n);
}
check(x.clone(), 100);
}
}
#[test]
fn test_roots_rand1() {
// A random input that found regressions
let s = "575981506858479247661989091587544744717244516135539456183849\
986593934723426343633698413178771587697273822147578889823552\
182702908597782734558103025298880194023243541613924361007059\
353344183590348785832467726433749431093350684849462759540710\
026019022227591412417064179299354183441181373862905039254106\
4781867";
let x: BigUint = s.parse().unwrap();
check(x.clone(), 2);
check(x.clone(), 3);
check(x.clone(), 10);
check(x.clone(), 100);
}
}