Implement optimized sqrt, cbrt methods
This commit overrides default implementations of Roots::sqrt and Roots::cbrt for BigInt and BigUint with optimized ones. It also improves tests and resolves minor inconsistencies. Signed-off-by: Manca Bizjak <manca.bizjak@xlab.si>
This commit is contained in:
parent
1f2590656b
commit
2b473e9403
@ -11,7 +11,7 @@ use std::mem::replace;
|
||||
use test::Bencher;
|
||||
use num_bigint::{BigInt, BigUint, RandBigInt};
|
||||
use num_traits::{Zero, One, FromPrimitive, Num};
|
||||
use rand::{SeedableRng, StdRng, Rng};
|
||||
use rand::{SeedableRng, StdRng};
|
||||
|
||||
fn get_rng() -> StdRng {
|
||||
let mut seed = [0; 32];
|
||||
@ -361,14 +361,9 @@ fn roots_cbrt(b: &mut Bencher) {
|
||||
}
|
||||
|
||||
#[bench]
|
||||
fn roots_nth(b: &mut Bencher) {
|
||||
fn roots_nth_100(b: &mut Bencher) {
|
||||
let mut rng = get_rng();
|
||||
let x = rng.gen_biguint(2048);
|
||||
// Although n is u32, here we limit it to the set of u8 values since it
|
||||
// hugely impacts the performance of nth_root due to exponentiation to
|
||||
// the power of n-1. Using very large values for n is also not very realistic,
|
||||
// and any n > x's bit size produces 1 as a result anyway.
|
||||
let n: u8 = rng.gen();
|
||||
|
||||
b.iter(|| { x.nth_root(n as u32) });
|
||||
b.iter(|| x.nth_root(100));
|
||||
}
|
||||
|
@ -1805,10 +1805,20 @@ impl Integer for BigInt {
|
||||
impl Roots for BigInt {
|
||||
fn nth_root(&self, n: u32) -> Self {
|
||||
assert!(!(self.is_negative() && n.is_even()),
|
||||
"n-th root is undefined for number (n={})", n);
|
||||
"root of degree {} is imaginary", n);
|
||||
|
||||
BigInt::from_biguint(self.sign, self.data.nth_root(n))
|
||||
}
|
||||
|
||||
fn sqrt(&self) -> Self {
|
||||
assert!(!self.is_negative(), "square root is imaginary");
|
||||
|
||||
BigInt::from_biguint(self.sign, self.data.sqrt())
|
||||
}
|
||||
|
||||
fn cbrt(&self) -> Self {
|
||||
BigInt::from_biguint(self.sign, self.data.cbrt())
|
||||
}
|
||||
}
|
||||
|
||||
impl ToPrimitive for BigInt {
|
||||
|
@ -1027,32 +1027,30 @@ impl Integer for BigUint {
|
||||
}
|
||||
|
||||
impl Roots for BigUint {
|
||||
// nth_root, sqrt and cbrt use Newton's method to compute
|
||||
// principal root of a given degree for a given integer.
|
||||
|
||||
// Reference:
|
||||
// Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.14
|
||||
fn nth_root(&self, n: u32) -> Self {
|
||||
assert!(n > 0, "n must be at least 1");
|
||||
assert!(n > 0, "root degree n must be at least 1");
|
||||
|
||||
let one = BigUint::one();
|
||||
|
||||
// Trivial cases
|
||||
if self.is_zero() {
|
||||
return BigUint::zero();
|
||||
if self.is_zero() || self.is_one() {
|
||||
return self.clone()
|
||||
}
|
||||
|
||||
if self.is_one() {
|
||||
return one;
|
||||
match n { // Optimize for small n
|
||||
1 => return self.clone(),
|
||||
2 => return self.sqrt(),
|
||||
3 => return self.cbrt(),
|
||||
_ => (),
|
||||
}
|
||||
|
||||
let n = n as usize;
|
||||
let n_min_1 = (n as usize) - 1;
|
||||
let n_min_1 = n - 1;
|
||||
|
||||
// Newton's method to compute the nth root of an integer.
|
||||
//
|
||||
// Reference:
|
||||
// Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.14
|
||||
//
|
||||
// Set initial guess to something definitely >= floor(nth_root of self)
|
||||
// but as low as possible to speed up convergence.
|
||||
let bit_len = self.len() * big_digit::BITS;
|
||||
let guess = one << (bit_len/n + 1);
|
||||
let guess = BigUint::one() << (bit_len/n + 1);
|
||||
|
||||
let mut u = guess;
|
||||
let mut s: BigUint;
|
||||
@ -1062,7 +1060,6 @@ impl Roots for BigUint {
|
||||
let q = self / pow(s.clone(), n_min_1);
|
||||
let t: BigUint = n_min_1 * &s + q;
|
||||
|
||||
// Compute the candidate value for next iteration
|
||||
u = t / n;
|
||||
|
||||
if u >= s { break; }
|
||||
@ -1070,6 +1067,54 @@ impl Roots for BigUint {
|
||||
|
||||
s
|
||||
}
|
||||
|
||||
// Reference:
|
||||
// Brent & Zimmermann, Modern Computer Arithmetic, v0.5.9, Algorithm 1.13
|
||||
fn sqrt(&self) -> Self {
|
||||
if self.is_zero() || self.is_one() {
|
||||
return self.clone()
|
||||
}
|
||||
|
||||
let bit_len = self.len() * big_digit::BITS;
|
||||
let guess = BigUint::one() << (bit_len/2 + 1);
|
||||
|
||||
let mut u = guess;
|
||||
let mut s: BigUint;
|
||||
|
||||
loop {
|
||||
s = u;
|
||||
let q = self / &s;
|
||||
let t: BigUint = &s + q;
|
||||
u = t >> 1;
|
||||
|
||||
if u >= s { break; }
|
||||
}
|
||||
|
||||
s
|
||||
}
|
||||
|
||||
fn cbrt(&self) -> Self {
|
||||
if self.is_zero() || self.is_one() {
|
||||
return self.clone()
|
||||
}
|
||||
|
||||
let bit_len = self.len() * big_digit::BITS;
|
||||
let guess = BigUint::one() << (bit_len/3 + 1);
|
||||
|
||||
let mut u = guess;
|
||||
let mut s: BigUint;
|
||||
|
||||
loop {
|
||||
s = u;
|
||||
let q = self / (&s * &s);
|
||||
let t: BigUint = (&s << 1) + q;
|
||||
u = t / 3u32;
|
||||
|
||||
if u >= s { break; }
|
||||
}
|
||||
|
||||
s
|
||||
}
|
||||
}
|
||||
|
||||
fn high_bits_to_u64(v: &BigUint) -> u64 {
|
||||
@ -1797,8 +1842,7 @@ impl BigUint {
|
||||
}
|
||||
|
||||
/// Returns the truncated principal square root of `self` --
|
||||
/// see [Roots::sqrt](Roots::sqrt).
|
||||
// struct.BigInt.html#trait.Roots
|
||||
/// see [Roots::sqrt](Roots::sqrt)
|
||||
pub fn sqrt(&self) -> Self {
|
||||
Roots::sqrt(self)
|
||||
}
|
||||
@ -1810,7 +1854,7 @@ impl BigUint {
|
||||
}
|
||||
|
||||
/// Returns the truncated principal `n`th root of `self` --
|
||||
/// See [Roots::nth_root](Roots::nth_root).
|
||||
/// see [Roots::nth_root](Roots::nth_root).
|
||||
pub fn nth_root(&self, n: u32) -> Self {
|
||||
Roots::nth_root(self, n)
|
||||
}
|
||||
|
@ -4,46 +4,53 @@ extern crate num_traits;
|
||||
|
||||
mod biguint {
|
||||
use num_bigint::BigUint;
|
||||
use num_traits::FromPrimitive;
|
||||
use num_traits::pow;
|
||||
use std::str::FromStr;
|
||||
|
||||
fn check(x: i32, n: u32, expected: i32) {
|
||||
let big_x: BigUint = FromPrimitive::from_i32(x).unwrap();
|
||||
let big_expected: BigUint = FromPrimitive::from_i32(expected).unwrap();
|
||||
fn check(x: u64, n: u32) {
|
||||
let big_x = BigUint::from(x);
|
||||
let res = big_x.nth_root(n);
|
||||
|
||||
assert_eq!(big_x.nth_root(n), big_expected);
|
||||
if n == 2 {
|
||||
assert_eq!(&res, &big_x.sqrt())
|
||||
} else if n == 3 {
|
||||
assert_eq!(&res, &big_x.cbrt())
|
||||
}
|
||||
|
||||
assert!(pow(res.clone(), n as usize) <= big_x);
|
||||
assert!(pow(res.clone() + 1u32, n as usize) > big_x);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_sqrt() {
|
||||
check(99, 2, 9);
|
||||
check(100, 2, 10);
|
||||
check(120, 2, 10);
|
||||
check(99, 2);
|
||||
check(100, 2);
|
||||
check(120, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cbrt() {
|
||||
check(8, 3, 2);
|
||||
check(26, 3, 2);
|
||||
check(8, 3);
|
||||
check(26, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nth_root() {
|
||||
check(0, 1, 0);
|
||||
check(10, 1, 10);
|
||||
check(100, 4, 3);
|
||||
check(0, 1);
|
||||
check(10, 1);
|
||||
check(100, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_nth_root_n_is_zero() {
|
||||
check(4, 0, 0);
|
||||
check(4, 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nth_root_big() {
|
||||
let x: BigUint = FromStr::from_str("123_456_789").unwrap();
|
||||
let expected : BigUint = FromPrimitive::from_i32(6).unwrap();
|
||||
let x = BigUint::from_str("123_456_789").unwrap();
|
||||
let expected = BigUint::from(6u32);
|
||||
|
||||
assert_eq!(x.nth_root(10), expected);
|
||||
}
|
||||
@ -51,34 +58,47 @@ mod biguint {
|
||||
|
||||
mod bigint {
|
||||
use num_bigint::BigInt;
|
||||
use num_traits::FromPrimitive;
|
||||
use num_traits::{Signed, pow};
|
||||
|
||||
fn check(x: i32, n: u32, expected: i32) {
|
||||
let big_x: BigInt = FromPrimitive::from_i32(x).unwrap();
|
||||
let big_expected: BigInt = FromPrimitive::from_i32(expected).unwrap();
|
||||
fn check(x: i64, n: u32) {
|
||||
let big_x = BigInt::from(x);
|
||||
let res = big_x.nth_root(n);
|
||||
|
||||
assert_eq!(big_x.nth_root(n), big_expected);
|
||||
if n == 2 {
|
||||
assert_eq!(&res, &big_x.sqrt())
|
||||
} else if n == 3 {
|
||||
assert_eq!(&res, &big_x.cbrt())
|
||||
}
|
||||
|
||||
if big_x.is_negative() {
|
||||
assert!(pow(res.clone() - 1u32, n as usize) < big_x);
|
||||
assert!(pow(res.clone(), n as usize) >= big_x);
|
||||
} else {
|
||||
assert!(pow(res.clone(), n as usize) <= big_x);
|
||||
assert!(pow(res.clone() + 1u32, n as usize) > big_x);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nth_root() {
|
||||
check(-100, 3, -4);
|
||||
check(-100, 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_nth_root_x_neg_n_even() {
|
||||
check(-100, 4, 0);
|
||||
check(-100, 4);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_sqrt_x_neg() {
|
||||
check(-4, 2, -2);
|
||||
check(-4, 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_cbrt() {
|
||||
check(-8, 3, -2);
|
||||
check(8, 3);
|
||||
check(-8, 3);
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user