Poisson: split Knuth/Rejection methods (#1493)
This commit is contained in:
parent
ef052ec539
commit
f2638201ff
@ -11,95 +11,46 @@
|
||||
// Rustfmt splits macro invocations to shorten lines; in this case longer-lines are more readable
|
||||
#![rustfmt::skip]
|
||||
|
||||
const RAND_BENCH_N: u64 = 1000;
|
||||
|
||||
use criterion::{criterion_group, criterion_main, Criterion, Throughput};
|
||||
use criterion_cycles_per_byte::CyclesPerByte;
|
||||
|
||||
use core::mem::size_of;
|
||||
|
||||
use rand::prelude::*;
|
||||
use rand_distr::*;
|
||||
|
||||
// At this time, distributions are optimised for 64-bit platforms.
|
||||
use rand_pcg::Pcg64Mcg;
|
||||
|
||||
const ITER_ELTS: u64 = 100;
|
||||
|
||||
macro_rules! distr_int {
|
||||
($group:ident, $fnn:expr, $ty:ty, $distr:expr) => {
|
||||
$group.throughput(Throughput::Bytes(
|
||||
size_of::<$ty>() as u64 * RAND_BENCH_N));
|
||||
$group.bench_function($fnn, |c| {
|
||||
let mut rng = Pcg64Mcg::from_os_rng();
|
||||
let distr = $distr;
|
||||
|
||||
c.iter(|| {
|
||||
let mut accum: $ty = 0;
|
||||
for _ in 0..RAND_BENCH_N {
|
||||
let x: $ty = distr.sample(&mut rng);
|
||||
accum = accum.wrapping_add(x);
|
||||
}
|
||||
accum
|
||||
});
|
||||
c.iter(|| distr.sample(&mut rng));
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! distr_float {
|
||||
($group:ident, $fnn:expr, $ty:ty, $distr:expr) => {
|
||||
$group.throughput(Throughput::Bytes(
|
||||
size_of::<$ty>() as u64 * RAND_BENCH_N));
|
||||
$group.bench_function($fnn, |c| {
|
||||
let mut rng = Pcg64Mcg::from_os_rng();
|
||||
let distr = $distr;
|
||||
|
||||
c.iter(|| {
|
||||
let mut accum = 0.;
|
||||
for _ in 0..RAND_BENCH_N {
|
||||
let x: $ty = distr.sample(&mut rng);
|
||||
accum += x;
|
||||
}
|
||||
accum
|
||||
});
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! distr {
|
||||
($group:ident, $fnn:expr, $ty:ty, $distr:expr) => {
|
||||
$group.throughput(Throughput::Bytes(
|
||||
size_of::<$ty>() as u64 * RAND_BENCH_N));
|
||||
$group.bench_function($fnn, |c| {
|
||||
let mut rng = Pcg64Mcg::from_os_rng();
|
||||
let distr = $distr;
|
||||
|
||||
c.iter(|| {
|
||||
let mut accum: u32 = 0;
|
||||
for _ in 0..RAND_BENCH_N {
|
||||
let x: $ty = distr.sample(&mut rng);
|
||||
accum = accum.wrapping_add(x as u32);
|
||||
}
|
||||
accum
|
||||
});
|
||||
c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng));
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
macro_rules! distr_arr {
|
||||
($group:ident, $fnn:expr, $ty:ty, $distr:expr) => {
|
||||
$group.throughput(Throughput::Bytes(
|
||||
size_of::<$ty>() as u64 * RAND_BENCH_N));
|
||||
$group.bench_function($fnn, |c| {
|
||||
let mut rng = Pcg64Mcg::from_os_rng();
|
||||
let distr = $distr;
|
||||
|
||||
c.iter(|| {
|
||||
let mut accum: u32 = 0;
|
||||
for _ in 0..RAND_BENCH_N {
|
||||
let x: $ty = distr.sample(&mut rng);
|
||||
accum = accum.wrapping_add(x[0] as u32);
|
||||
}
|
||||
accum
|
||||
});
|
||||
c.iter(|| Distribution::<$ty>::sample(&distr, &mut rng));
|
||||
});
|
||||
};
|
||||
}
|
||||
@ -111,122 +62,126 @@ macro_rules! sample_binomial {
|
||||
}
|
||||
|
||||
fn bench(c: &mut Criterion<CyclesPerByte>) {
|
||||
{
|
||||
let mut g = c.benchmark_group("exp");
|
||||
distr_float!(g, "exp", f64, Exp::new(1.23 * 4.56).unwrap());
|
||||
distr_float!(g, "exp1_specialized", f64, Exp1);
|
||||
distr_float!(g, "exp1_general", f64, Exp::new(1.).unwrap());
|
||||
}
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("normal");
|
||||
distr_float!(g, "normal", f64, Normal::new(-1.23, 4.56).unwrap());
|
||||
distr_float!(g, "standardnormal_specialized", f64, StandardNormal);
|
||||
distr_float!(g, "standardnormal_general", f64, Normal::new(0., 1.).unwrap());
|
||||
distr_float!(g, "log_normal", f64, LogNormal::new(-1.23, 4.56).unwrap());
|
||||
g.throughput(Throughput::Bytes(size_of::<f64>() as u64 * RAND_BENCH_N));
|
||||
g.throughput(Throughput::Elements(ITER_ELTS));
|
||||
g.bench_function("iter", |c| {
|
||||
use core::f64::consts::{E, PI};
|
||||
let mut rng = Pcg64Mcg::from_os_rng();
|
||||
let distr = Normal::new(-E, PI).unwrap();
|
||||
let mut iter = distr.sample_iter(&mut rng);
|
||||
|
||||
c.iter(|| {
|
||||
let mut accum = 0.0;
|
||||
for _ in 0..RAND_BENCH_N {
|
||||
accum += iter.next().unwrap();
|
||||
}
|
||||
accum
|
||||
distr.sample_iter(&mut rng)
|
||||
.take(ITER_ELTS as usize)
|
||||
.fold(0.0, |a, r| a + r)
|
||||
});
|
||||
});
|
||||
}
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("skew_normal");
|
||||
distr_float!(g, "shape_zero", f64, SkewNormal::new(0.0, 1.0, 0.0).unwrap());
|
||||
distr_float!(g, "shape_positive", f64, SkewNormal::new(0.0, 1.0, 100.0).unwrap());
|
||||
distr_float!(g, "shape_negative", f64, SkewNormal::new(0.0, 1.0, -100.0).unwrap());
|
||||
}
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("gamma");
|
||||
distr_float!(g, "gamma_large_shape", f64, Gamma::new(10., 1.0).unwrap());
|
||||
distr_float!(g, "gamma_small_shape", f64, Gamma::new(0.1, 1.0).unwrap());
|
||||
distr_float!(g, "beta_small_param", f64, Beta::new(0.1, 0.1).unwrap());
|
||||
distr_float!(g, "beta_large_param_similar", f64, Beta::new(101., 95.).unwrap());
|
||||
distr_float!(g, "beta_large_param_different", f64, Beta::new(10., 1000.).unwrap());
|
||||
distr_float!(g, "beta_mixed_param", f64, Beta::new(0.5, 100.).unwrap());
|
||||
}
|
||||
distr_float!(g, "large_shape", f64, Gamma::new(10., 1.0).unwrap());
|
||||
distr_float!(g, "small_shape", f64, Gamma::new(0.1, 1.0).unwrap());
|
||||
g.finish();
|
||||
|
||||
let mut g = c.benchmark_group("beta");
|
||||
distr_float!(g, "small_param", f64, Beta::new(0.1, 0.1).unwrap());
|
||||
distr_float!(g, "large_param_similar", f64, Beta::new(101., 95.).unwrap());
|
||||
distr_float!(g, "large_param_different", f64, Beta::new(10., 1000.).unwrap());
|
||||
distr_float!(g, "mixed_param", f64, Beta::new(0.5, 100.).unwrap());
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("cauchy");
|
||||
distr_float!(g, "cauchy", f64, Cauchy::new(4.2, 6.9).unwrap());
|
||||
}
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("triangular");
|
||||
distr_float!(g, "triangular", f64, Triangular::new(0., 1., 0.9).unwrap());
|
||||
}
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("geometric");
|
||||
distr_int!(g, "geometric", u64, Geometric::new(0.5).unwrap());
|
||||
distr_int!(g, "standard_geometric", u64, StandardGeometric);
|
||||
}
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("weighted");
|
||||
distr_int!(g, "weighted_i8", usize, WeightedIndex::new([1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
|
||||
distr_int!(g, "weighted_u32", usize, WeightedIndex::new([1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
|
||||
distr_int!(g, "weighted_f64", usize, WeightedIndex::new([1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
|
||||
distr_int!(g, "weighted_large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());
|
||||
distr_int!(g, "weighted_alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
|
||||
distr_int!(g, "weighted_alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
|
||||
distr_int!(g, "weighted_alias_method_f64", usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
|
||||
distr_int!(g, "weighted_alias_method_large_set", usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap());
|
||||
}
|
||||
distr_int!(g, "i8", usize, WeightedIndex::new([1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
|
||||
distr_int!(g, "u32", usize, WeightedIndex::new([1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
|
||||
distr_int!(g, "f64", usize, WeightedIndex::new([1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
|
||||
distr_int!(g, "large_set", usize, WeightedIndex::new((0..10000).rev().chain(1..10001)).unwrap());
|
||||
distr_int!(g, "alias_method_i8", usize, WeightedAliasIndex::new(vec![1i8, 2, 3, 4, 12, 0, 2, 1]).unwrap());
|
||||
distr_int!(g, "alias_method_u32", usize, WeightedAliasIndex::new(vec![1u32, 2, 3, 4, 12, 0, 2, 1]).unwrap());
|
||||
distr_int!(g, "alias_method_f64", usize, WeightedAliasIndex::new(vec![1.0f64, 0.001, 1.0/3.0, 4.01, 0.0, 3.3, 22.0, 0.001]).unwrap());
|
||||
distr_int!(g, "alias_method_large_set", usize, WeightedAliasIndex::new((0..10000).rev().chain(1..10001).collect()).unwrap());
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("binomial");
|
||||
sample_binomial!(g, "binomial", 20, 0.7);
|
||||
sample_binomial!(g, "binomial_small", 1_000_000, 1e-30);
|
||||
sample_binomial!(g, "binomial_1", 1, 0.9);
|
||||
sample_binomial!(g, "binomial_10", 10, 0.9);
|
||||
sample_binomial!(g, "binomial_100", 100, 0.99);
|
||||
sample_binomial!(g, "binomial_1000", 1000, 0.01);
|
||||
sample_binomial!(g, "binomial_1e12", 1_000_000_000_000, 0.2);
|
||||
}
|
||||
sample_binomial!(g, "small", 1_000_000, 1e-30);
|
||||
sample_binomial!(g, "1", 1, 0.9);
|
||||
sample_binomial!(g, "10", 10, 0.9);
|
||||
sample_binomial!(g, "100", 100, 0.99);
|
||||
sample_binomial!(g, "1000", 1000, 0.01);
|
||||
sample_binomial!(g, "1e12", 1_000_000_000_000, 0.2);
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("poisson");
|
||||
distr_float!(g, "poisson", f64, Poisson::new(4.0).unwrap());
|
||||
for lambda in [1f64, 4.0, 10.0, 100.0].into_iter() {
|
||||
let name = format!("{lambda}");
|
||||
distr_float!(g, name, f64, Poisson::new(lambda).unwrap());
|
||||
}
|
||||
g.throughput(Throughput::Elements(ITER_ELTS));
|
||||
g.bench_function("variable", |c| {
|
||||
let mut rng = Pcg64Mcg::from_os_rng();
|
||||
let ldistr = Uniform::new(0.1, 10.0).unwrap();
|
||||
|
||||
c.iter(|| {
|
||||
let l = rng.sample(ldistr);
|
||||
let distr = Poisson::new(l * l).unwrap();
|
||||
Distribution::<f64>::sample_iter(&distr, &mut rng)
|
||||
.take(ITER_ELTS as usize)
|
||||
.fold(0.0, |a, r| a + r)
|
||||
})
|
||||
});
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("zipf");
|
||||
distr_float!(g, "zipf", f64, Zipf::new(10, 1.5).unwrap());
|
||||
distr_float!(g, "zeta", f64, Zeta::new(1.5).unwrap());
|
||||
}
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("bernoulli");
|
||||
distr!(g, "bernoulli", bool, Bernoulli::new(0.18).unwrap());
|
||||
}
|
||||
g.bench_function("bernoulli", |c| {
|
||||
let mut rng = Pcg64Mcg::from_os_rng();
|
||||
let distr = Bernoulli::new(0.18).unwrap();
|
||||
c.iter(|| distr.sample(&mut rng))
|
||||
});
|
||||
g.finish();
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("circle");
|
||||
let mut g = c.benchmark_group("unit");
|
||||
distr_arr!(g, "circle", [f64; 2], UnitCircle);
|
||||
}
|
||||
|
||||
{
|
||||
let mut g = c.benchmark_group("sphere");
|
||||
distr_arr!(g, "sphere", [f64; 3], UnitSphere);
|
||||
}
|
||||
g.finish();
|
||||
}
|
||||
|
||||
criterion_group!(
|
||||
name = benches;
|
||||
config = Criterion::default().with_measurement(CyclesPerByte);
|
||||
config = Criterion::default().with_measurement(CyclesPerByte)
|
||||
.warm_up_time(core::time::Duration::from_secs(1))
|
||||
.measurement_time(core::time::Duration::from_secs(2));
|
||||
targets = bench
|
||||
);
|
||||
criterion_main!(benches);
|
||||
|
@ -211,7 +211,7 @@ mod normal;
|
||||
mod normal_inverse_gaussian;
|
||||
mod pareto;
|
||||
mod pert;
|
||||
mod poisson;
|
||||
pub(crate) mod poisson;
|
||||
mod skew_normal;
|
||||
mod student_t;
|
||||
mod triangular;
|
||||
|
@ -45,18 +45,10 @@ use rand::Rng;
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct Poisson<F>
|
||||
pub struct Poisson<F>(Method<F>)
|
||||
where
|
||||
F: Float + FloatConst,
|
||||
Standard: Distribution<F>,
|
||||
{
|
||||
lambda: F,
|
||||
// precalculated values
|
||||
exp_lambda: F,
|
||||
log_lambda: F,
|
||||
sqrt_2lambda: F,
|
||||
magic_val: F,
|
||||
}
|
||||
Standard: Distribution<F>;
|
||||
|
||||
/// Error type returned from [`Poisson::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
@ -81,6 +73,50 @@ impl fmt::Display for Error {
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub(crate) struct KnuthMethod<F> {
|
||||
exp_lambda: F,
|
||||
}
|
||||
|
||||
impl<F: Float> KnuthMethod<F> {
|
||||
pub(crate) fn new(lambda: F) -> Self {
|
||||
KnuthMethod {
|
||||
exp_lambda: (-lambda).exp(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
struct RejectionMethod<F> {
|
||||
lambda: F,
|
||||
log_lambda: F,
|
||||
sqrt_2lambda: F,
|
||||
magic_val: F,
|
||||
}
|
||||
|
||||
impl<F: Float> RejectionMethod<F> {
|
||||
pub(crate) fn new(lambda: F) -> Self {
|
||||
let log_lambda = lambda.ln();
|
||||
let sqrt_2lambda = (F::from(2.0).unwrap() * lambda).sqrt();
|
||||
let magic_val = lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda);
|
||||
RejectionMethod {
|
||||
lambda,
|
||||
log_lambda,
|
||||
sqrt_2lambda,
|
||||
magic_val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
|
||||
enum Method<F> {
|
||||
Knuth(KnuthMethod<F>),
|
||||
Rejection(RejectionMethod<F>),
|
||||
}
|
||||
|
||||
impl<F> Poisson<F>
|
||||
where
|
||||
F: Float + FloatConst,
|
||||
@ -104,14 +140,81 @@ where
|
||||
if !(lambda > F::zero()) {
|
||||
return Err(Error::ShapeTooSmall);
|
||||
}
|
||||
let log_lambda = lambda.ln();
|
||||
Ok(Poisson {
|
||||
lambda,
|
||||
exp_lambda: (-lambda).exp(),
|
||||
log_lambda,
|
||||
sqrt_2lambda: (F::from(2.0).unwrap() * lambda).sqrt(),
|
||||
magic_val: lambda * log_lambda - crate::utils::log_gamma(F::one() + lambda),
|
||||
})
|
||||
|
||||
// Use the Knuth method only for low expected values
|
||||
let method = if lambda < F::from(12.0).unwrap() {
|
||||
Method::Knuth(KnuthMethod::new(lambda))
|
||||
} else {
|
||||
Method::Rejection(RejectionMethod::new(lambda))
|
||||
};
|
||||
|
||||
Ok(Poisson(method))
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Distribution<F> for KnuthMethod<F>
|
||||
where
|
||||
F: Float + FloatConst,
|
||||
Standard: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
let mut result = F::one();
|
||||
let mut p = rng.random::<F>();
|
||||
while p > self.exp_lambda {
|
||||
p = p * rng.random::<F>();
|
||||
result = result + F::one();
|
||||
}
|
||||
result - F::one()
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Distribution<F> for RejectionMethod<F>
|
||||
where
|
||||
F: Float + FloatConst,
|
||||
Standard: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
// The algorithm from Numerical Recipes in C
|
||||
|
||||
// we use the Cauchy distribution as the comparison distribution
|
||||
// f(x) ~ 1/(1+x^2)
|
||||
let cauchy = Cauchy::new(F::zero(), F::one()).unwrap();
|
||||
let mut result;
|
||||
|
||||
loop {
|
||||
let mut comp_dev;
|
||||
|
||||
loop {
|
||||
// draw from the Cauchy distribution
|
||||
comp_dev = rng.sample(cauchy);
|
||||
// shift the peak of the comparison distribution
|
||||
result = self.sqrt_2lambda * comp_dev + self.lambda;
|
||||
// repeat the drawing until we are in the range of possible values
|
||||
if result >= F::zero() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// now the result is a random variable greater than 0 with Cauchy distribution
|
||||
// the result should be an integer value
|
||||
result = result.floor();
|
||||
|
||||
// this is the ratio of the Poisson distribution to the comparison distribution
|
||||
// the magic value scales the distribution function to a range of approximately 0-1
|
||||
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
|
||||
// this doesn't change the resulting distribution, only increases the rate of failed drawings
|
||||
let check = F::from(0.9).unwrap()
|
||||
* (F::one() + comp_dev * comp_dev)
|
||||
* (result * self.log_lambda
|
||||
- crate::utils::log_gamma(F::one() + result)
|
||||
- self.magic_val)
|
||||
.exp();
|
||||
|
||||
// check with uniform random value - if below the threshold, we are within the target distribution
|
||||
if rng.random::<F>() <= check {
|
||||
break;
|
||||
}
|
||||
}
|
||||
result
|
||||
}
|
||||
}
|
||||
|
||||
@ -122,59 +225,9 @@ where
|
||||
{
|
||||
#[inline]
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
// using the algorithm from Numerical Recipes in C
|
||||
|
||||
// for low expected values use the Knuth method
|
||||
if self.lambda < F::from(12.0).unwrap() {
|
||||
let mut result = F::one();
|
||||
let mut p = rng.random::<F>();
|
||||
while p > self.exp_lambda {
|
||||
p = p * rng.random::<F>();
|
||||
result = result + F::one();
|
||||
}
|
||||
result - F::one()
|
||||
}
|
||||
// high expected values - rejection method
|
||||
else {
|
||||
// we use the Cauchy distribution as the comparison distribution
|
||||
// f(x) ~ 1/(1+x^2)
|
||||
let cauchy = Cauchy::new(F::zero(), F::one()).unwrap();
|
||||
let mut result;
|
||||
|
||||
loop {
|
||||
let mut comp_dev;
|
||||
|
||||
loop {
|
||||
// draw from the Cauchy distribution
|
||||
comp_dev = rng.sample(cauchy);
|
||||
// shift the peak of the comparison distribution
|
||||
result = self.sqrt_2lambda * comp_dev + self.lambda;
|
||||
// repeat the drawing until we are in the range of possible values
|
||||
if result >= F::zero() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
// now the result is a random variable greater than 0 with Cauchy distribution
|
||||
// the result should be an integer value
|
||||
result = result.floor();
|
||||
|
||||
// this is the ratio of the Poisson distribution to the comparison distribution
|
||||
// the magic value scales the distribution function to a range of approximately 0-1
|
||||
// since it is not exact, we multiply the ratio by 0.9 to avoid ratios greater than 1
|
||||
// this doesn't change the resulting distribution, only increases the rate of failed drawings
|
||||
let check = F::from(0.9).unwrap()
|
||||
* (F::one() + comp_dev * comp_dev)
|
||||
* (result * self.log_lambda
|
||||
- crate::utils::log_gamma(F::one() + result)
|
||||
- self.magic_val)
|
||||
.exp();
|
||||
|
||||
// check with uniform random value - if below the threshold, we are within the target distribution
|
||||
if rng.random::<F>() <= check {
|
||||
break;
|
||||
}
|
||||
}
|
||||
result
|
||||
match &self.0 {
|
||||
Method::Knuth(method) => method.sample(rng),
|
||||
Method::Rejection(method) => method.sample(rng),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user