Use better initial guesses for Roots
This commit is contained in:
parent
97b0b17145
commit
07a22d447b
159
src/biguint.rs
159
src/biguint.rs
@ -1397,6 +1397,35 @@ impl Integer for BigUint {
|
||||
}
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn fixpoint<F>(mut x: BigUint, max_bits: usize, f: F) -> BigUint
|
||||
where
|
||||
F: Fn(&BigUint) -> BigUint,
|
||||
{
|
||||
let mut xn = f(&x);
|
||||
|
||||
// If the value increased, then the initial guess must have been low.
|
||||
// Repeat until we reverse course.
|
||||
while x < xn {
|
||||
// Sometimes an increase will go way too far, especially with large
|
||||
// powers, and then take a long time to walk back. We know an upper
|
||||
// bound based on bit size, so saturate on that.
|
||||
x = if xn.bits() > max_bits {
|
||||
BigUint::one() << max_bits
|
||||
} else {
|
||||
xn
|
||||
};
|
||||
xn = f(&x);
|
||||
}
|
||||
|
||||
// Now keep repeating while the estimate is decreasing.
|
||||
while x > xn {
|
||||
x = xn;
|
||||
xn = f(&x);
|
||||
}
|
||||
x
|
||||
}
|
||||
|
||||
impl Roots for BigUint {
|
||||
// nth_root, sqrt and cbrt use Newton's method to compute
|
||||
// principal root of a given degree for a given integer.
|
||||
@ -1418,27 +1447,42 @@ impl Roots for BigUint {
|
||||
_ => (),
|
||||
}
|
||||
|
||||
let n = n as usize;
|
||||
let n_min_1 = n - 1;
|
||||
|
||||
let guess = BigUint::one() << (self.bits() / n + 1);
|
||||
|
||||
let mut u = guess;
|
||||
let mut s: BigUint;
|
||||
|
||||
loop {
|
||||
s = u;
|
||||
let q = self / s.pow(n_min_1);
|
||||
let t: BigUint = n_min_1 * &s + q;
|
||||
|
||||
u = t / n;
|
||||
|
||||
if u >= s {
|
||||
break;
|
||||
}
|
||||
// The root of non-zero values less than 2ⁿ can only be 1.
|
||||
let bits = self.bits();
|
||||
if bits <= n as usize {
|
||||
return BigUint::one()
|
||||
}
|
||||
|
||||
s
|
||||
// If we fit in `u64`, compute the root that way.
|
||||
if let Some(x) = self.to_u64() {
|
||||
return x.nth_root(n).into();
|
||||
}
|
||||
|
||||
let max_bits = bits / n as usize + 1;
|
||||
|
||||
let guess = if let Some(f) = self.to_f64() {
|
||||
// We fit in `f64` (lossy), so get a better initial guess from that.
|
||||
BigUint::from_f64((f.ln() / f64::from(n)).exp()).unwrap()
|
||||
} else {
|
||||
// Try to guess by scaling down such that it does fit in `f64`.
|
||||
// With some (x * 2ⁿᵏ), its nth root ≈ (ⁿ√x * 2ᵏ)
|
||||
let nsz = n as usize;
|
||||
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
|
||||
let root_scale = (extra_bits + (nsz - 1)) / nsz;
|
||||
let scale = root_scale * nsz;
|
||||
if scale < bits && bits - scale > nsz {
|
||||
(self >> scale).nth_root(n) << root_scale
|
||||
} else {
|
||||
BigUint::one() << max_bits
|
||||
}
|
||||
};
|
||||
|
||||
let n_min_1 = n - 1;
|
||||
fixpoint(guess, max_bits, move |s| {
|
||||
let q = self / s.pow(n_min_1);
|
||||
let t = n_min_1 * s + q;
|
||||
t / n
|
||||
})
|
||||
}
|
||||
|
||||
// Reference:
|
||||
@ -1448,23 +1492,31 @@ impl Roots for BigUint {
|
||||
return self.clone();
|
||||
}
|
||||
|
||||
let guess = BigUint::one() << (self.bits() / 2 + 1);
|
||||
|
||||
let mut u = guess;
|
||||
let mut s: BigUint;
|
||||
|
||||
loop {
|
||||
s = u;
|
||||
let q = self / &s;
|
||||
let t: BigUint = &s + q;
|
||||
u = t >> 1;
|
||||
|
||||
if u >= s {
|
||||
break;
|
||||
}
|
||||
// If we fit in `u64`, compute the root that way.
|
||||
if let Some(x) = self.to_u64() {
|
||||
return x.sqrt().into();
|
||||
}
|
||||
|
||||
s
|
||||
let bits = self.bits();
|
||||
let max_bits = bits / 2 as usize + 1;
|
||||
|
||||
let guess = if let Some(f) = self.to_f64() {
|
||||
// We fit in `f64` (lossy), so get a better initial guess from that.
|
||||
BigUint::from_f64(f.sqrt()).unwrap()
|
||||
} else {
|
||||
// Try to guess by scaling down such that it does fit in `f64`.
|
||||
// With some (x * 2²ᵏ), its sqrt ≈ (√x * 2ᵏ)
|
||||
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
|
||||
let root_scale = (extra_bits + 1) / 2;
|
||||
let scale = root_scale * 2;
|
||||
(self >> scale).sqrt() << root_scale
|
||||
};
|
||||
|
||||
fixpoint(guess, max_bits, move |s| {
|
||||
let q = self / s;
|
||||
let t = s + q;
|
||||
t >> 1
|
||||
})
|
||||
}
|
||||
|
||||
fn cbrt(&self) -> Self {
|
||||
@ -1472,23 +1524,32 @@ impl Roots for BigUint {
|
||||
return self.clone();
|
||||
}
|
||||
|
||||
let guess = BigUint::one() << (self.bits() / 3 + 1);
|
||||
|
||||
let mut u = guess;
|
||||
let mut s: BigUint;
|
||||
|
||||
loop {
|
||||
s = u;
|
||||
let q = self / (&s * &s);
|
||||
let t: BigUint = (&s << 1) + q;
|
||||
u = t / 3u32;
|
||||
|
||||
if u >= s {
|
||||
break;
|
||||
}
|
||||
// If we fit in `u64`, compute the root that way.
|
||||
if let Some(x) = self.to_u64() {
|
||||
return x.cbrt().into();
|
||||
}
|
||||
|
||||
s
|
||||
let bits = self.bits();
|
||||
let max_bits = bits / 3 as usize + 1;
|
||||
|
||||
let guess = if let Some(f) = self.to_f64() {
|
||||
// We fit in `f64` (lossy), so get a better initial guess from that.
|
||||
BigUint::from_f64(f.cbrt()).unwrap()
|
||||
} else {
|
||||
// Try to guess by scaling down such that it does fit in `f64`.
|
||||
// With some (x * 2³ᵏ), its cbrt ≈ (∛x * 2ᵏ)
|
||||
let extra_bits = bits - (f64::MAX_EXP as usize - 1);
|
||||
let root_scale = (extra_bits + 2) / 3;
|
||||
let scale = root_scale * 3;
|
||||
(self >> scale).cbrt() << root_scale
|
||||
};
|
||||
|
||||
|
||||
fixpoint(guess, max_bits, move |s| {
|
||||
let q = self / (s * s);
|
||||
let t = (s << 1) + q;
|
||||
t / 3u32
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
120
tests/roots.rs
120
tests/roots.rs
@ -2,57 +2,139 @@ extern crate num_bigint_dig as num_bigint;
|
||||
extern crate num_integer;
|
||||
extern crate num_traits;
|
||||
|
||||
#[cfg(feature = "rand")]
|
||||
extern crate rand;
|
||||
|
||||
mod biguint {
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::Pow;
|
||||
use std::str::FromStr;
|
||||
use num_traits::{One, Pow, Zero};
|
||||
use std::{i32, u32};
|
||||
|
||||
fn check(x: u64, n: u32) {
|
||||
let big_x = BigUint::from(x);
|
||||
let res = big_x.nth_root(n);
|
||||
fn check<T: Into<BigUint>>(x: T, n: u32) {
|
||||
let x: BigUint = x.into();
|
||||
let root = x.nth_root(n);
|
||||
println!("check {}.nth_root({}) = {}", x, n, root);
|
||||
|
||||
if n == 2 {
|
||||
assert_eq!(&res, &big_x.sqrt())
|
||||
assert_eq!(root, x.sqrt())
|
||||
} else if n == 3 {
|
||||
assert_eq!(&res, &big_x.cbrt())
|
||||
assert_eq!(root, x.cbrt())
|
||||
}
|
||||
|
||||
assert!(res.pow(n) <= big_x);
|
||||
assert!((res + 1u32).pow(n) > big_x);
|
||||
let lo = root.pow(n);
|
||||
assert!(lo <= x);
|
||||
assert_eq!(lo.nth_root(n), root);
|
||||
if !lo.is_zero() {
|
||||
assert_eq!((&lo - 1u32).nth_root(n), &root - 1u32);
|
||||
}
|
||||
|
||||
let hi = (&root + 1u32).pow(n);
|
||||
assert!(hi > x);
|
||||
assert_eq!(hi.nth_root(n), &root + 1u32);
|
||||
assert_eq!((&hi - 1u32).nth_root(n), root);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt() {
|
||||
check(99, 2);
|
||||
check(100, 2);
|
||||
check(120, 2);
|
||||
check(99u32, 2);
|
||||
check(100u32, 2);
|
||||
check(120u32, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cbrt() {
|
||||
check(8, 3);
|
||||
check(26, 3);
|
||||
check(8u32, 3);
|
||||
check(26u32, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nth_root() {
|
||||
check(0, 1);
|
||||
check(10, 1);
|
||||
check(100, 4);
|
||||
check(0u32, 1);
|
||||
check(10u32, 1);
|
||||
check(100u32, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_nth_root_n_is_zero() {
|
||||
check(4, 0);
|
||||
check(4u32, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nth_root_big() {
|
||||
let x = BigUint::from_str("123_456_789").unwrap();
|
||||
let x = BigUint::from(123_456_789_u32);
|
||||
let expected = BigUint::from(6u32);
|
||||
|
||||
assert_eq!(x.nth_root(10), expected);
|
||||
check(x, 10);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nth_root_googol() {
|
||||
let googol = BigUint::from(10u32).pow(100u32);
|
||||
|
||||
// perfect divisors of 100
|
||||
for &n in &[2, 4, 5, 10, 20, 25, 50, 100] {
|
||||
let expected = BigUint::from(10u32).pow(100u32 / n);
|
||||
assert_eq!(googol.nth_root(n), expected);
|
||||
check(googol.clone(), n);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nth_root_twos() {
|
||||
const EXP: u32 = 12;
|
||||
const LOG2: usize = 1 << EXP;
|
||||
let x = BigUint::one() << LOG2;
|
||||
|
||||
// the perfect divisors are just powers of two
|
||||
for exp in 1..EXP + 1 {
|
||||
let n = 2u32.pow(exp);
|
||||
let expected = BigUint::one() << (LOG2 / n as usize);
|
||||
assert_eq!(x.nth_root(n), expected);
|
||||
check(x.clone(), n);
|
||||
}
|
||||
|
||||
// degenerate cases should return quickly
|
||||
assert!(x.nth_root(x.bits() as u32).is_one());
|
||||
assert!(x.nth_root(i32::MAX as u32).is_one());
|
||||
assert!(x.nth_root(u32::MAX).is_one());
|
||||
}
|
||||
|
||||
#[cfg(feature = "rand")]
|
||||
#[test]
|
||||
fn test_roots_rand() {
|
||||
use num_bigint::RandBigInt;
|
||||
use rand::{thread_rng, Rng};
|
||||
use rand::distributions::Uniform;
|
||||
|
||||
let mut rng = thread_rng();
|
||||
let bit_range = Uniform::new(0, 2048);
|
||||
let sample_bits: Vec<_> = rng.sample_iter(&bit_range).take(100).collect();
|
||||
for bits in sample_bits {
|
||||
let x = rng.gen_biguint(bits);
|
||||
for n in 2..11 {
|
||||
check(x.clone(), n);
|
||||
}
|
||||
check(x.clone(), 100);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_roots_rand1() {
|
||||
// A random input that found regressions
|
||||
let s = "575981506858479247661989091587544744717244516135539456183849\
|
||||
986593934723426343633698413178771587697273822147578889823552\
|
||||
182702908597782734558103025298880194023243541613924361007059\
|
||||
353344183590348785832467726433749431093350684849462759540710\
|
||||
026019022227591412417064179299354183441181373862905039254106\
|
||||
4781867";
|
||||
let x: BigUint = s.parse().unwrap();
|
||||
|
||||
check(x.clone(), 2);
|
||||
check(x.clone(), 3);
|
||||
check(x.clone(), 10);
|
||||
check(x.clone(), 100);
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user