Make sure BTPE is not entered when np < 10 (#1484)
This commit is contained in:
parent
66b11eb17b
commit
bc3341185e
@ -6,6 +6,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
|
||||
|
||||
## Unreleased
|
||||
- The `serde1` feature has been renamed `serde` (#1477)
|
||||
- Fix panic in Binomial (#1484)
|
||||
- Move some of the computations in Binomial from `sample` to `new` (#1484)
|
||||
|
||||
### Added
|
||||
- Add plots for `rand_distr` distributions to documentation (#1434)
|
||||
|
@ -26,10 +26,6 @@ use rand::Rng;
|
||||
///
|
||||
/// `f(k) = n!/(k! (n-k)!) p^k (1-p)^(n-k)` for `k >= 0`.
|
||||
///
|
||||
/// # Known issues
|
||||
///
|
||||
/// See documentation of [`Binomial::new`].
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The following plot of the binomial distribution illustrates the
|
||||
@ -50,10 +46,34 @@ use rand::Rng;
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Binomial {
|
||||
/// Sampling method selected by [`Binomial::new`], together with the
/// parameters that were precomputed for it.
|
||||
method: Method,
|
||||
}
|
||||
|
||||
/// Sampling strategy chosen once in `Binomial::new` and dispatched on in `sample`.
///
/// The `bool` carried by `Binv` and `Btpe` is the `flipped` flag: it is `true`
/// when the requested `p` was greater than 0.5 and was replaced by `1 - p`, in
/// which case the sampler returns `n - sample` instead of `sample`.
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
enum Method {
|
||||
/// BINV (inverse transformation), used when `n * p` is below `BINV_THRESHOLD`.
Binv(Binv, bool),
|
||||
/// BTPE, used when `n * p` is at least `BINV_THRESHOLD`.
Btpe(Btpe, bool),
|
||||
/// Poisson sampler with mean `n * p`, used in the small-`n * p` branch when
/// `1 - p` evaluates to exactly `1.0`.
Poisson(crate::poisson::KnuthMethod<f64>),
|
||||
/// Degenerate distribution: always returns the stored value
/// (`0` for `p == 0`, `n` for `p == 1`).
Constant(u64),
|
||||
}
|
||||
|
||||
/// Precomputed parameters for the BINV sampler (`q` below denotes `1 - p`).
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
struct Binv {
|
||||
/// `q^n`, the starting value of the running term `r` in the BINV loop.
r: f64,
|
||||
/// `p / q`.
s: f64,
|
||||
/// `(n + 1) * s`.
a: f64,
|
||||
/// Number of trials; used to mirror the sample (`n - sample`) when flipped.
n: u64,
|
||||
}
|
||||
|
||||
/// Precomputed parameters for the BTPE sampler.
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
struct Btpe {
|
||||
/// Number of trials.
n: u64,
|
||||
/// Probability of success, after flipping so that it is at most 0.5.
|
||||
p: f64,
|
||||
/// `n * p + p` converted to `i64`; the recursive evaluation of `f(y)`
/// starts its search from this value (the mode).
m: i64,
|
||||
/// Radius of the triangle region:
/// `(2.195 * sqrt(n * p * q) - 4.6 * q).floor() + 0.5` with `q = 1 - p`.
p1: f64,
|
||||
}
|
||||
|
||||
/// Error type returned from [`Binomial::new`].
|
||||
@ -82,13 +102,6 @@ impl std::error::Error for Error {}
|
||||
impl Binomial {
|
||||
/// Construct a new `Binomial` with the given shape parameters `n` (number
|
||||
/// of trials) and `p` (probability of success).
|
||||
///
|
||||
/// # Known issues
|
||||
///
|
||||
/// Although this method should return an [`Error`] on invalid parameters,
|
||||
/// some (extreme) parameter combinations are known to return a [`Binomial`]
|
||||
/// object which panics when [sampled](Distribution::sample).
|
||||
/// See [#1378](https://github.com/rust-random/rand/issues/1378).
|
||||
pub fn new(n: u64, p: f64) -> Result<Binomial, Error> {
|
||||
if !(p >= 0.0) {
|
||||
return Err(Error::ProbabilityTooSmall);
|
||||
@ -96,33 +109,22 @@ impl Binomial {
|
||||
if !(p <= 1.0) {
|
||||
return Err(Error::ProbabilityTooLarge);
|
||||
}
|
||||
Ok(Binomial { n, p })
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert a `f64` to an `i64`, panicking on overflow.
///
/// Only the upper bound is asserted. `NaN` also fails the assertion because
/// the comparison is false. Values below `i64::MIN` (including negative
/// infinity) are not rejected; the `as` cast saturates them to `i64::MIN`.
|
||||
fn f64_to_i64(x: f64) -> i64 {
|
||||
// Panics for x >= i64::MAX as f64 and for NaN.
assert!(x < (i64::MAX as f64));
|
||||
// Float-to-int `as` truncates toward zero.
x as i64
|
||||
}
|
||||
|
||||
impl Distribution<u64> for Binomial {
|
||||
#[allow(clippy::many_single_char_names)] // Same names as in the reference.
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
|
||||
// Handle these values directly.
|
||||
if self.p == 0.0 {
|
||||
return 0;
|
||||
} else if self.p == 1.0 {
|
||||
return self.n;
|
||||
if p == 0.0 {
|
||||
return Ok(Binomial {
|
||||
method: Method::Constant(0),
|
||||
});
|
||||
}
|
||||
|
||||
// The binomial distribution is symmetrical with respect to p -> 1-p,
|
||||
// k -> n-k. Switch p so that it is at most 0.5, which gives lower
|
||||
// expected values; we will just invert the result at the end.
|
||||
let p = if self.p <= 0.5 { self.p } else { 1.0 - self.p };
|
||||
if p == 1.0 {
|
||||
return Ok(Binomial {
|
||||
method: Method::Constant(n),
|
||||
});
|
||||
}
|
||||
|
||||
let result;
|
||||
let q = 1. - p;
|
||||
// The binomial distribution is symmetrical with respect to p -> 1-p
|
||||
let flipped = p > 0.5;
|
||||
let p = if flipped { 1.0 - p } else { p };
|
||||
|
||||
// For small n * min(p, 1 - p), the BINV algorithm based on the inverse
|
||||
// transformation of the binomial distribution is efficient. Otherwise,
|
||||
@ -136,204 +138,253 @@ impl Distribution<u64> for Binomial {
|
||||
// Ranlib uses 30, and GSL uses 14.
|
||||
const BINV_THRESHOLD: f64 = 10.;
|
||||
|
||||
// Same value as in GSL.
|
||||
// It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again.
|
||||
// It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant.
|
||||
// When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away.
|
||||
const BINV_MAX_X: u64 = 110;
|
||||
|
||||
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (i32::MAX as u64) {
|
||||
// Use the BINV algorithm.
|
||||
let s = p / q;
|
||||
let a = ((self.n + 1) as f64) * s;
|
||||
|
||||
result = 'outer: loop {
|
||||
let mut r = q.powi(self.n as i32);
|
||||
let mut u: f64 = rng.random();
|
||||
let mut x = 0;
|
||||
|
||||
while u > r {
|
||||
u -= r;
|
||||
x += 1;
|
||||
if x > BINV_MAX_X {
|
||||
continue 'outer;
|
||||
}
|
||||
r *= a / (x as f64) - s;
|
||||
}
|
||||
break x;
|
||||
let np = n as f64 * p;
|
||||
let method = if np < BINV_THRESHOLD {
|
||||
let q = 1.0 - p;
|
||||
if q == 1.0 {
|
||||
// p is so small that this is extremely close to a Poisson distribution.
|
||||
// The flipped case cannot occur here.
|
||||
Method::Poisson(crate::poisson::KnuthMethod::new(np))
|
||||
} else {
|
||||
let s = p / q;
|
||||
Method::Binv(
|
||||
Binv {
|
||||
r: q.powf(n as f64),
|
||||
s,
|
||||
a: (n as f64 + 1.0) * s,
|
||||
n,
|
||||
},
|
||||
flipped,
|
||||
)
|
||||
}
|
||||
} else {
|
||||
// Use the BTPE algorithm.
|
||||
|
||||
// Threshold for using the squeeze algorithm. This can be freely
|
||||
// chosen based on performance. Ranlib and GSL use 20.
|
||||
const SQUEEZE_THRESHOLD: i64 = 20;
|
||||
|
||||
// Step 0: Calculate constants as functions of `n` and `p`.
|
||||
let n = self.n as f64;
|
||||
let np = n * p;
|
||||
let q = 1.0 - p;
|
||||
let npq = np * q;
|
||||
let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
|
||||
let f_m = np + p;
|
||||
let m = f64_to_i64(f_m);
|
||||
// radius of triangle region, since height=1 also area of region
|
||||
let p1 = (2.195 * npq.sqrt() - 4.6 * q).floor() + 0.5;
|
||||
// tip of triangle
|
||||
let x_m = (m as f64) + 0.5;
|
||||
// left edge of triangle
|
||||
let x_l = x_m - p1;
|
||||
// right edge of triangle
|
||||
let x_r = x_m + p1;
|
||||
let c = 0.134 + 20.5 / (15.3 + (m as f64));
|
||||
// p1 + area of parallelogram region
|
||||
let p2 = p1 * (1. + 2. * c);
|
||||
Method::Btpe(Btpe { n, p, m, p1 }, flipped)
|
||||
};
|
||||
Ok(Binomial { method })
|
||||
}
|
||||
}
|
||||
|
||||
fn lambda(a: f64) -> f64 {
|
||||
a * (1. + 0.5 * a)
|
||||
/// Convert a `f64` to an `i64`, panicking on overflow.
///
/// Only the upper bound is asserted. `NaN` also fails the assertion because
/// the comparison is false. Values below `i64::MIN` (including negative
/// infinity) are not rejected; the `as` cast saturates them to `i64::MIN`.
|
||||
fn f64_to_i64(x: f64) -> i64 {
|
||||
// Panics for x >= i64::MAX as f64 and for NaN.
assert!(x < (i64::MAX as f64));
|
||||
// Float-to-int `as` truncates toward zero.
x as i64
|
||||
}
|
||||
|
||||
fn binv<R: Rng + ?Sized>(binv: Binv, flipped: bool, rng: &mut R) -> u64 {
|
||||
// Same value as in GSL.
|
||||
// It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again.
|
||||
// It would be safer to set BINV_MAX_X to binv.n, but it is extremely unlikely to be relevant.
|
||||
// When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away.
|
||||
const BINV_MAX_X: u64 = 110;
|
||||
|
||||
let sample = 'outer: loop {
|
||||
let mut r = binv.r;
|
||||
let mut u: f64 = rng.random();
|
||||
let mut x = 0;
|
||||
|
||||
while u > r {
|
||||
u -= r;
|
||||
x += 1;
|
||||
if x > BINV_MAX_X {
|
||||
continue 'outer;
|
||||
}
|
||||
r *= binv.a / (x as f64) - binv.s;
|
||||
}
|
||||
break x;
|
||||
};
|
||||
|
||||
let lambda_l = lambda((f_m - x_l) / (f_m - x_l * p));
|
||||
let lambda_r = lambda((x_r - f_m) / (x_r * q));
|
||||
// p1 + area of left tail
|
||||
let p3 = p2 + c / lambda_l;
|
||||
// p1 + area of right tail
|
||||
let p4 = p3 + c / lambda_r;
|
||||
if flipped {
|
||||
binv.n - sample
|
||||
} else {
|
||||
sample
|
||||
}
|
||||
}
|
||||
|
||||
// return value
|
||||
let mut y: i64;
|
||||
#[allow(clippy::many_single_char_names)] // Same names as in the reference.
|
||||
fn btpe<R: Rng + ?Sized>(btpe: Btpe, flipped: bool, rng: &mut R) -> u64 {
|
||||
// Threshold for using the squeeze algorithm. This can be freely
|
||||
// chosen based on performance. Ranlib and GSL use 20.
|
||||
const SQUEEZE_THRESHOLD: i64 = 20;
|
||||
|
||||
let gen_u = Uniform::new(0., p4).unwrap();
|
||||
let gen_v = Uniform::new(0., 1.).unwrap();
|
||||
// Step 0: Calculate constants as functions of `n` and `p`.
|
||||
let n = btpe.n as f64;
|
||||
let np = n * btpe.p;
|
||||
let q = 1. - btpe.p;
|
||||
let npq = np * q;
|
||||
let f_m = np + btpe.p;
|
||||
let m = btpe.m;
|
||||
// radius of triangle region, since height=1 also area of region
|
||||
let p1 = btpe.p1;
|
||||
// tip of triangle
|
||||
let x_m = (m as f64) + 0.5;
|
||||
// left edge of triangle
|
||||
let x_l = x_m - p1;
|
||||
// right edge of triangle
|
||||
let x_r = x_m + p1;
|
||||
let c = 0.134 + 20.5 / (15.3 + (m as f64));
|
||||
// p1 + area of parallelogram region
|
||||
let p2 = p1 * (1. + 2. * c);
|
||||
|
||||
loop {
|
||||
// Step 1: Generate `u` for selecting the region. If region 1 is
|
||||
// selected, generate a triangularly distributed variate.
|
||||
let u = gen_u.sample(rng);
|
||||
let mut v = gen_v.sample(rng);
|
||||
if !(u > p1) {
|
||||
y = f64_to_i64(x_m - p1 * v + u);
|
||||
break;
|
||||
}
|
||||
fn lambda(a: f64) -> f64 {
|
||||
a * (1. + 0.5 * a)
|
||||
}
|
||||
|
||||
if !(u > p2) {
|
||||
// Step 2: Region 2, parallelograms. Check if region 2 is
|
||||
// used. If so, generate `y`.
|
||||
let x = x_l + (u - p1) / c;
|
||||
v = v * c + 1.0 - (x - x_m).abs() / p1;
|
||||
if v > 1. {
|
||||
continue;
|
||||
} else {
|
||||
y = f64_to_i64(x);
|
||||
}
|
||||
} else if !(u > p3) {
|
||||
// Step 3: Region 3, left exponential tail.
|
||||
y = f64_to_i64(x_l + v.ln() / lambda_l);
|
||||
if y < 0 {
|
||||
continue;
|
||||
} else {
|
||||
v *= (u - p2) * lambda_l;
|
||||
}
|
||||
} else {
|
||||
// Step 4: Region 4, right exponential tail.
|
||||
y = f64_to_i64(x_r - v.ln() / lambda_r);
|
||||
if y > 0 && (y as u64) > self.n {
|
||||
continue;
|
||||
} else {
|
||||
v *= (u - p3) * lambda_r;
|
||||
}
|
||||
}
|
||||
let lambda_l = lambda((f_m - x_l) / (f_m - x_l * btpe.p));
|
||||
let lambda_r = lambda((x_r - f_m) / (x_r * q));
|
||||
|
||||
// Step 5: Acceptance/rejection comparison.
|
||||
let p3 = p2 + c / lambda_l;
|
||||
|
||||
// Step 5.0: Test for appropriate method of evaluating f(y).
|
||||
let k = (y - m).abs();
|
||||
if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
|
||||
// Step 5.1: Evaluate f(y) via the recursive relationship. Start the
|
||||
// search from the mode.
|
||||
let s = p / q;
|
||||
let a = s * (n + 1.);
|
||||
let mut f = 1.0;
|
||||
match m.cmp(&y) {
|
||||
Ordering::Less => {
|
||||
let mut i = m;
|
||||
loop {
|
||||
i += 1;
|
||||
f *= a / (i as f64) - s;
|
||||
if i == y {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ordering::Greater => {
|
||||
let mut i = y;
|
||||
loop {
|
||||
i += 1;
|
||||
f /= a / (i as f64) - s;
|
||||
if i == m {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ordering::Equal => {}
|
||||
}
|
||||
if v > f {
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
let p4 = p3 + c / lambda_r;
|
||||
|
||||
// Step 5.2: Squeezing. Check the value of ln(v) against upper and
|
||||
// lower bound of ln(f(y)).
|
||||
let k = k as f64;
|
||||
let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
|
||||
let t = -0.5 * k * k / npq;
|
||||
let alpha = v.ln();
|
||||
if alpha < t - rho {
|
||||
break;
|
||||
}
|
||||
if alpha > t + rho {
|
||||
continue;
|
||||
}
|
||||
// return value
|
||||
let mut y: i64;
|
||||
|
||||
// Step 5.3: Final acceptance/rejection test.
|
||||
let x1 = (y + 1) as f64;
|
||||
let f1 = (m + 1) as f64;
|
||||
let z = (f64_to_i64(n) + 1 - m) as f64;
|
||||
let w = (f64_to_i64(n) - y + 1) as f64;
|
||||
let gen_u = Uniform::new(0., p4).unwrap();
|
||||
let gen_v = Uniform::new(0., 1.).unwrap();
|
||||
|
||||
fn stirling(a: f64) -> f64 {
|
||||
let a2 = a * a;
|
||||
(13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
|
||||
}
|
||||
|
||||
if alpha
|
||||
> x_m * (f1 / x1).ln()
|
||||
+ (n - (m as f64) + 0.5) * (z / w).ln()
|
||||
+ ((y - m) as f64) * (w * p / (x1 * q)).ln()
|
||||
// We use the signs from the GSL implementation, which are
|
||||
// different than the ones in the reference. According to
|
||||
// the GSL authors, the new signs were verified to be
|
||||
// correct by one of the original designers of the
|
||||
// algorithm.
|
||||
+ stirling(f1)
|
||||
+ stirling(z)
|
||||
- stirling(x1)
|
||||
- stirling(w)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
assert!(y >= 0);
|
||||
result = y as u64;
|
||||
loop {
|
||||
// Step 1: Generate `u` for selecting the region. If region 1 is
|
||||
// selected, generate a triangularly distributed variate.
|
||||
let u = gen_u.sample(rng);
|
||||
let mut v = gen_v.sample(rng);
|
||||
if !(u > p1) {
|
||||
y = f64_to_i64(x_m - p1 * v + u);
|
||||
break;
|
||||
}
|
||||
|
||||
// Invert the result if p was flipped, i.e. the original p > 0.5.
|
||||
if p != self.p {
|
||||
self.n - result
|
||||
if !(u > p2) {
|
||||
// Step 2: Region 2, parallelograms. Check if region 2 is
|
||||
// used. If so, generate `y`.
|
||||
let x = x_l + (u - p1) / c;
|
||||
v = v * c + 1.0 - (x - x_m).abs() / p1;
|
||||
if v > 1. {
|
||||
continue;
|
||||
} else {
|
||||
y = f64_to_i64(x);
|
||||
}
|
||||
} else if !(u > p3) {
|
||||
// Step 3: Region 3, left exponential tail.
|
||||
y = f64_to_i64(x_l + v.ln() / lambda_l);
|
||||
if y < 0 {
|
||||
continue;
|
||||
} else {
|
||||
v *= (u - p2) * lambda_l;
|
||||
}
|
||||
} else {
|
||||
result
|
||||
// Step 4: Region 4, right exponential tail.
|
||||
y = f64_to_i64(x_r - v.ln() / lambda_r);
|
||||
if y > 0 && (y as u64) > btpe.n {
|
||||
continue;
|
||||
} else {
|
||||
v *= (u - p3) * lambda_r;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5: Acceptance/rejection comparison.
|
||||
|
||||
// Step 5.0: Test for appropriate method of evaluating f(y).
|
||||
let k = (y - m).abs();
|
||||
if !(k > SQUEEZE_THRESHOLD && (k as f64) < 0.5 * npq - 1.) {
|
||||
// Step 5.1: Evaluate f(y) via the recursive relationship. Start the
|
||||
// search from the mode.
|
||||
let s = btpe.p / q;
|
||||
let a = s * (n + 1.);
|
||||
let mut f = 1.0;
|
||||
match m.cmp(&y) {
|
||||
Ordering::Less => {
|
||||
let mut i = m;
|
||||
loop {
|
||||
i += 1;
|
||||
f *= a / (i as f64) - s;
|
||||
if i == y {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ordering::Greater => {
|
||||
let mut i = y;
|
||||
loop {
|
||||
i += 1;
|
||||
f /= a / (i as f64) - s;
|
||||
if i == m {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ordering::Equal => {}
|
||||
}
|
||||
if v > f {
|
||||
continue;
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Step 5.2: Squeezing. Check the value of ln(v) against upper and
|
||||
// lower bound of ln(f(y)).
|
||||
let k = k as f64;
|
||||
let rho = (k / npq) * ((k * (k / 3. + 0.625) + 1. / 6.) / npq + 0.5);
|
||||
let t = -0.5 * k * k / npq;
|
||||
let alpha = v.ln();
|
||||
if alpha < t - rho {
|
||||
break;
|
||||
}
|
||||
if alpha > t + rho {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Step 5.3: Final acceptance/rejection test.
|
||||
let x1 = (y + 1) as f64;
|
||||
let f1 = (m + 1) as f64;
|
||||
let z = (f64_to_i64(n) + 1 - m) as f64;
|
||||
let w = (f64_to_i64(n) - y + 1) as f64;
|
||||
|
||||
fn stirling(a: f64) -> f64 {
|
||||
let a2 = a * a;
|
||||
(13860. - (462. - (132. - (99. - 140. / a2) / a2) / a2) / a2) / a / 166320.
|
||||
}
|
||||
|
||||
if alpha
|
||||
> x_m * (f1 / x1).ln()
|
||||
+ (n - (m as f64) + 0.5) * (z / w).ln()
|
||||
+ ((y - m) as f64) * (w * btpe.p / (x1 * q)).ln()
|
||||
// We use the signs from the GSL implementation, which are
|
||||
// different than the ones in the reference. According to
|
||||
// the GSL authors, the new signs were verified to be
|
||||
// correct by one of the original designers of the
|
||||
// algorithm.
|
||||
+ stirling(f1)
|
||||
+ stirling(z)
|
||||
- stirling(x1)
|
||||
- stirling(w)
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
assert!(y >= 0);
|
||||
let y = y as u64;
|
||||
|
||||
if flipped {
|
||||
btpe.n - y
|
||||
} else {
|
||||
y
|
||||
}
|
||||
}
|
||||
|
||||
/// Sampling only dispatches on the method that was chosen at construction time.
impl Distribution<u64> for Binomial {
|
||||
/// Draw one sample using the stored `Method`.
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
|
||||
match self.method {
|
||||
// `flipped` tells the helper to return `n - sample` instead of `sample`.
Method::Binv(binv_para, flipped) => binv(binv_para, flipped, rng),
|
||||
Method::Btpe(btpe_para, flipped) => btpe(btpe_para, flipped, rng),
|
||||
// The Poisson sampler yields an `f64`, which is cast to `u64` here.
Method::Poisson(poisson) => poisson.sample(rng) as u64,
|
||||
// Degenerate case: no randomness is consumed.
Method::Constant(c) => c,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -371,6 +422,8 @@ mod test {
|
||||
test_binomial_mean_and_variance(40, 0.5, &mut rng);
|
||||
test_binomial_mean_and_variance(20, 0.7, &mut rng);
|
||||
test_binomial_mean_and_variance(20, 0.5, &mut rng);
|
||||
test_binomial_mean_and_variance(1 << 61, 1e-17, &mut rng);
|
||||
test_binomial_mean_and_variance(u64::MAX, 1e-19, &mut rng);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
Loading…
x
Reference in New Issue
Block a user