diff --git a/src/bigrand.rs b/src/bigrand.rs index a163bbf..f67c079 100644 --- a/src/bigrand.rs +++ b/src/bigrand.rs @@ -1,6 +1,7 @@ //! Randomization of big integers use rand::Rng; +use rand::distributions::uniform::{SampleUniform, UniformSampler}; use BigInt; use BigUint; @@ -34,8 +35,7 @@ pub trait RandBigInt { fn gen_bigint_range(&mut self, lbound: &BigInt, ubound: &BigInt) -> BigInt; } -#[cfg(any(feature = "rand", test))] -impl RandBigInt for R { +impl RandBigInt for R { fn gen_biguint(&mut self, bit_size: usize) -> BigUint { use super::big_digit::BITS; let (digits, rem) = bit_size.div_rem(&BITS); @@ -106,3 +106,40 @@ impl RandBigInt for R { } } } + + +/// The back-end implementing rand's `UniformSampler` for `BigUint`. +#[derive(Clone, Debug)] +pub struct UniformBigUint { + base: BigUint, + len: BigUint, +} + +impl UniformSampler for UniformBigUint { + type X = BigUint; + + fn new(low: Self::X, high: Self::X) -> Self { + assert!(low < high); + UniformBigUint { + len: high - &low, + base: low, + } + } + + fn new_inclusive(low: Self::X, high: Self::X) -> Self { + assert!(low <= high); + Self::new(low, high + 1u32) + } + + fn sample(&self, rng: &mut R) -> Self::X { + &self.base + rng.gen_biguint_below(&self.len) + } + + fn sample_single(low: Self::X, high: Self::X, rng: &mut R) -> Self::X { + rng.gen_biguint_range(&low, &high) + } +} + +impl SampleUniform for BigUint { + type Sampler = UniformBigUint; +} diff --git a/src/lib.rs b/src/lib.rs index 6bc15f6..922fcd9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -163,7 +163,7 @@ pub use bigint::BigInt; pub use bigint::ToBigInt; #[cfg(feature = "rand")] -pub use bigrand::RandBigInt; +pub use bigrand::{RandBigInt, UniformBigUint}; mod big_digit { /// A `BigDigit` is a `BigUint`'s composing element. diff --git a/tests/rand.rs b/tests/rand.rs index 553c263..284fa69 100644 --- a/tests/rand.rs +++ b/tests/rand.rs @@ -8,6 +8,8 @@ mod biguint { use num_bigint::{BigUint, RandBigInt}; use num_traits::Zero; use rand::thread_rng; + use rand::Rng; + use rand::distributions::Uniform; #[test] fn test_rand() { @@ -54,6 +56,29 @@ mod biguint { // Switching u and l should fail: let _n: BigUint = rng.gen_biguint_range(&u, &l); } + + #[test] + fn test_rand_uniform() { + let mut rng = thread_rng(); + + let tiny = Uniform::new(BigUint::from(236u32), BigUint::from(237u32)); + for _ in 0..10 { + assert_eq!(rng.sample(&tiny), BigUint::from(236u32)); + } + + let l = BigUint::from(403469000u32 + 2352); + let u = BigUint::from(403469000u32 + 3513); + let below = Uniform::new(BigUint::zero(), u.clone()); + let range = Uniform::new(l.clone(), u.clone()); + for _ in 0..1000 { + let n: BigUint = rng.sample(&below); + assert!(n < u); + + let n: BigUint = rng.sample(&range); + assert!(n >= l); + assert!(n < u); + } + } } mod bigint {