rand/rand_distr/tests/value_stability.rs
Diggory Hardy 2584f48ace
Fix pert for mode approx eq mean; use builder pattern (#1452)
- Fix #1311 (mode close to mean)
- Use a builder pattern, allowing specification via mode OR mean
2024-07-16 09:04:49 +01:00

554 lines
13 KiB
Rust

// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
use average::assert_almost_eq;
use core::fmt::Debug;
use rand::Rng;
use rand_distr::*;
fn get_rng(seed: u64) -> impl Rng {
// For tests, we want a statistically good, fast, reproducible RNG.
// PCG32 will do fine, and will be easy to embed if we ever need to.
const INC: u64 = 11634580027462260723;
rand_pcg::Pcg32::new(seed, INC)
}
/// Equality assertion used to compare sampled values with recorded
/// expectations.
///
/// We only assert approximate equality since some platforms do not perform
/// identically (i686-unknown-linux-gnu and most notably x86_64-pc-windows-gnu).
///
/// The method carries `#[track_caller]`. On a trait method declaration this
/// attribute applies to every implementation. A failing comparison is
/// therefore reported at the call site, not at a line inside one of the
/// `impl` blocks.
trait ApproxEq {
    /// Panics unless `self` and `rhs` are considered equal for this type.
    #[track_caller]
    fn assert_almost_eq(&self, rhs: &Self);
}
impl ApproxEq for f32 {
    /// Single precision: the two values may differ by at most `1e-6`
    /// (an absolute, not relative, tolerance).
    fn assert_almost_eq(&self, rhs: &Self) {
        let (actual, expected) = (*self, *rhs);
        assert_almost_eq!(actual, expected, 1e-6);
    }
}
impl ApproxEq for f64 {
    /// Double precision: the two values may differ by at most `1e-14`
    /// (an absolute, not relative, tolerance).
    fn assert_almost_eq(&self, rhs: &Self) {
        let (actual, expected) = (*self, *rhs);
        assert_almost_eq!(actual, expected, 1e-14);
    }
}
impl ApproxEq for u64 {
    /// Integer samples are compared exactly; no tolerance applies.
    fn assert_almost_eq(&self, rhs: &Self) {
        let (actual, expected) = (*self, *rhs);
        assert_eq!(actual, expected);
    }
}
/// Arrays are compared element-wise, in index order, using each element's
/// own `ApproxEq` implementation.
///
/// A single const-generic impl covers every length `N`. It replaces the
/// former hand-written `[T; 2]` and `[T; 3]` impls, so vector-valued samples
/// (e.g. points from `UnitCircle` or `UnitBall`) and arrays of any other
/// length share one assertion. Both operands have type `[T; N]`, so their
/// lengths always agree.
impl<T: ApproxEq, const N: usize> ApproxEq for [T; N] {
    fn assert_almost_eq(&self, rhs: &Self) {
        for (lhs_elem, rhs_elem) in self.iter().zip(rhs.iter()) {
            lhs_elem.assert_almost_eq(rhs_elem);
        }
    }
}
/// Draws `expected.len()` values from `distr` using a fresh RNG seeded with
/// `seed`, and asserts that each one matches the corresponding recorded value
/// via [`ApproxEq`].
///
/// Samples are drawn lazily. `expected` drives the iteration, so no extra
/// sample is drawn once it is exhausted.
fn test_samples<F: Debug + ApproxEq, D: Distribution<F>>(seed: u64, distr: D, expected: &[F]) {
    let samples = get_rng(seed).sample_iter(&distr);
    for (want, got) in expected.iter().zip(samples) {
        got.assert_almost_eq(want);
    }
}
#[test]
fn binomial_stability() {
    // np < 10 and p > 0.5 select different code paths. The three cases
    // cover np < 10 with p > 0.5, np < 10 with p < 0.5, and np >= 10.
    const SEED: u64 = 353;

    let small_np_high_p = Binomial::new(2, 0.7).unwrap();
    test_samples(SEED, small_np_high_p, &[1, 1, 2, 1]);

    let small_np_low_p = Binomial::new(20, 0.3).unwrap();
    test_samples(SEED, small_np_low_p, &[7, 7, 5, 7]);

    let large_np = Binomial::new(2000, 0.6).unwrap();
    test_samples(SEED, large_np, &[1194, 1208, 1192, 1210]);
}
#[test]
fn geometric_stability() {
    let seed = 464;

    test_samples(seed, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]);

    // Regular success probabilities, checked in this order.
    let cases: [(f64, [u64; 8]); 3] = [
        (0.5, [2, 1, 1, 0, 0, 1, 0, 1]),
        (0.05, [24, 51, 81, 67, 27, 11, 7, 6]),
        (0.95, [0, 0, 0, 0, 1, 0, 0, 0]),
    ];
    for (p, expected) in cases {
        test_samples(seed, Geometric::new(p).unwrap(), &expected);
    }

    // Degenerate probabilities: every trial in the series is predetermined,
    // so the output does not depend on the RNG at all.
    test_samples(seed, Geometric::new(0.0).unwrap(), &[u64::MAX; 100][..]);
    test_samples(seed, Geometric::new(1.0).unwrap(), &[0; 100][..]);
}
#[test]
fn hypergeometric_stability() {
    // The sampling algorithm is chosen from the distribution's mode and
    // sample size; one case per algorithm.
    let seed = 7221;

    let hin = Hypergeometric::new(99, 33, 8).unwrap();
    test_samples(seed, hin, &[4, 3, 2, 2, 3, 2, 3, 1]); // Algorithm HIN

    let h2pe = Hypergeometric::new(100, 50, 50).unwrap();
    test_samples(seed, h2pe, &[23, 27, 26, 27, 22, 24, 31, 22]); // Algorithm H2PE
}
#[test]
fn unit_ball_stability() {
    // Three consecutive 3-D points drawn with seed 2.
    let expected: [[f64; 3]; 3] = [
        [0.018035709265959987, -0.4348771383120438, -0.07982762085055706],
        [0.10588569388223945, -0.4734350111375454, -0.7392104908825501],
        [0.11060237642041049, -0.16065642822852677, -0.8444043930440075],
    ];
    test_samples(2, UnitBall, &expected);
}
#[test]
fn unit_circle_stability() {
    // Three consecutive 2-D points drawn with seed 2.
    let expected: [[f64; 2]; 3] = [
        [-0.9965658683520504, -0.08280380447614634],
        [-0.9790853270389644, -0.20345004884984505],
        [-0.8449189758898707, 0.5348943112253227],
    ];
    test_samples(2, UnitCircle, &expected);
}
#[test]
fn unit_sphere_stability() {
    // Three consecutive 3-D points drawn with seed 2.
    let expected: [[f64; 3]; 3] = [
        [0.03247542860231647, -0.7830477442152738, 0.6211131755296027],
        [-0.09978440840914075, 0.9706650829833128, -0.21875184231323952],
        [0.2735582468624679, 0.9435374242279655, -0.1868234852870203],
    ];
    test_samples(2, UnitSphere, &expected);
}
#[test]
fn unit_disc_stability() {
    // Three consecutive 2-D points drawn with seed 2.
    let expected: [[f64; 2]; 3] = [
        [0.018035709265959987, -0.4348771383120438],
        [-0.07982762085055706, 0.7765329819820659],
        [0.21450745997299503, 0.7398636984333291],
    ];
    test_samples(2, UnitDisc, &expected);
}
#[test]
fn pareto_stability() {
    const SEED: u64 = 213;

    // Single precision.
    let expected: [f32; 4] = [1.0423688, 2.1235929, 4.132709, 1.4679428];
    test_samples(SEED, Pareto::new(1.0, 1.0).unwrap(), &expected);

    // Double precision.
    let expected: [f64; 4] = [
        9.019295276219136,
        4.3097126018270595,
        6.837815045397157,
        105.8826669383772,
    ];
    test_samples(SEED, Pareto::new(2.0, 0.5).unwrap(), &expected);
}
#[test]
fn poisson_stability() {
    let seed = 223;

    let lambda_7_f32: [f32; 4] = [5.0, 11.0, 6.0, 5.0];
    test_samples(seed, Poisson::new(7.0).unwrap(), &lambda_7_f32);

    let lambda_7_f64: [f64; 4] = [9.0, 5.0, 7.0, 6.0];
    test_samples(seed, Poisson::new(7.0).unwrap(), &lambda_7_f64);

    let lambda_27_f32: [f32; 4] = [28.0, 32.0, 36.0, 36.0];
    test_samples(seed, Poisson::new(27.0).unwrap(), &lambda_27_f32);
}
#[test]
fn triangular_stability() {
    let distr = Triangular::new(2., 10., 3.).unwrap();
    let expected: [f64; 5] = [
        5.74373257511361,
        7.890059162791258,
        4.7256280652553455,
        2.9474808121184077,
        3.058301946314053,
    ];
    test_samples(860, distr, &expected);
}
#[test]
fn normal_inverse_gaussian_stability() {
    const SEED: u64 = 213;

    // Same parameters, checked in single and then double precision.
    let expected: [f32; 4] = [0.6568966, 1.3744819, 2.216063, 0.11488572];
    test_samples(SEED, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &expected);

    let expected: [f64; 4] = [
        0.6838707059642927,
        2.4447306460569784,
        0.2361045023235968,
        1.7774534624785319,
    ];
    test_samples(SEED, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &expected);
}
#[test]
fn pert_stability() {
    // Range [2, 10] with mode 3, so mean = (2 + 4 * 3 + 10) / 6 = 4 and
    // var = 12/7.
    let distr = Pert::new(2., 10.).with_mode(3.).unwrap();
    let expected: [f64; 5] = [
        4.908681667460367,
        4.014196196158352,
        2.6489397149197234,
        3.4569780580044727,
        4.242864311947118,
    ];
    test_samples(860, distr, &expected);
}
#[test]
fn inverse_gaussian_stability() {
    const SEED: u64 = 213;

    // Same parameters, checked in single and then double precision.
    let expected: [f32; 4] = [0.9339157, 1.108113, 0.50864697, 0.39849377];
    test_samples(SEED, InverseGaussian::new(1.0, 3.0).unwrap(), &expected);

    let expected: [f64; 4] = [
        1.0707604954722476,
        0.9628140605340697,
        0.4069687656468226,
        0.660283852985818,
    ];
    test_samples(SEED, InverseGaussian::new(1.0, 3.0).unwrap(), &expected);
}
#[test]
fn gamma_stability() {
    const SEED: u64 = 223;

    // --- Gamma: separate paths for shape == 1, shape < 1 and shape > 1. ---
    let expected: [f32; 4] = [5.398085, 9.162783, 0.2300583, 1.7235851];
    test_samples(SEED, Gamma::new(1.0, 5.0).unwrap(), &expected);

    let expected: [f32; 4] = [0.5051203, 0.9048302, 3.095812, 1.8566116];
    test_samples(SEED, Gamma::new(0.8, 5.0).unwrap(), &expected);

    let expected: [f64; 4] = [
        7.783878094584059,
        1.4939528171618057,
        8.638017638857592,
        3.0949337228829004,
    ];
    test_samples(SEED, Gamma::new(1.1, 5.0).unwrap(), &expected);

    // --- ChiSquared: separate paths for k == 1 and k != 1. ---
    let expected: [f64; 4] = [
        0.4893526200348249,
        1.635249736808788,
        0.5013580219361969,
        0.1457735613733489,
    ];
    test_samples(SEED, ChiSquared::new(1.0).unwrap(), &expected);

    let expected: [f64; 4] = [
        0.014824404726978617,
        0.021602123937134326,
        0.0000003431429746851693,
        0.00000002291755769542258,
    ];
    test_samples(SEED, ChiSquared::new(0.1).unwrap(), &expected);

    let expected: [f32; 4] = [12.693656, 6.812016, 11.082001, 12.436167];
    test_samples(SEED, ChiSquared::new(10.0).unwrap(), &expected);

    // --- FisherF: the ChiSquared special cases apply to each parameter. ---
    let expected: [f32; 4] = [0.32283646, 0.048049655, 0.0788893, 1.817178];
    test_samples(SEED, FisherF::new(1.0, 13.5).unwrap(), &expected);

    let expected: [f32; 4] = [0.29925257, 3.4392934, 9.567652, 0.020074];
    test_samples(SEED, FisherF::new(1.0, 1.0).unwrap(), &expected);

    let expected: [f64; 4] = [
        3.3196593155045124,
        0.3409169916262829,
        0.03377989856426519,
        0.00004041672861036937,
    ];
    test_samples(SEED, FisherF::new(0.7, 13.5).unwrap(), &expected);

    // --- StudentT: the same special cases as ChiSquared. ---
    let expected: [f32; 4] = [0.54703987, -1.8545331, 3.093162, -0.14168274];
    test_samples(SEED, StudentT::new(1.0).unwrap(), &expected);

    let expected: [f64; 4] = [
        0.7729195887949754,
        1.2606210611616204,
        -1.7553606501113175,
        -2.377641221169782,
    ];
    test_samples(SEED, StudentT::new(1.1).unwrap(), &expected);

    // --- Beta: min(alpha, beta) <= 1 versus min(alpha, beta) > 1. ---
    let expected: [f64; 4] = [
        0.8300703726659456,
        0.8134131062097899,
        0.47912589330631555,
        0.25323238071138526,
    ];
    test_samples(SEED, Beta::new(1.0, 0.8).unwrap(), &expected);

    let expected: [f64; 4] = [
        0.49563509121756827,
        0.9551305482256759,
        0.5151181353461637,
        0.7551732971235077,
    ];
    test_samples(SEED, Beta::new(3.0, 1.2).unwrap(), &expected);
}
#[test]
fn exponential_stability() {
    const SEED: u64 = 223;

    // `Exp::new(1.0)` is expected to reproduce `Exp1`, so both are checked
    // against the same double-precision table.
    let unit_rate_f64: [f64; 4] = [
        1.0796170642388276,
        1.8325565304274,
        0.04601166186842716,
        0.3447170217100157,
    ];

    let unit_rate_f32: [f32; 4] = [1.079617, 1.8325565, 0.04601166, 0.34471703];
    test_samples(SEED, Exp1, &unit_rate_f32);
    test_samples(SEED, Exp1, &unit_rate_f64);

    let rate_two_f32: [f32; 4] = [0.5398085, 0.91627824, 0.02300583, 0.17235851];
    test_samples(SEED, Exp::new(2.0).unwrap(), &rate_two_f32);
    test_samples(SEED, Exp::new(1.0).unwrap(), &unit_rate_f64);
}
#[test]
fn normal_stability() {
    const SEED: u64 = 213;

    // `Normal::new(0.0, 1.0)` is expected to reproduce `StandardNormal`, so
    // both share the single-precision table.
    let standard_f32: [f32; 4] = [-0.11844189, 0.781378, 0.06563994, -1.1932899];
    let standard_f64: [f64; 4] = [
        -0.11844188827977231,
        0.7813779637772346,
        0.06563993969580051,
        -1.1932899004186373,
    ];

    test_samples(SEED, StandardNormal, &standard_f32);
    test_samples(SEED, StandardNormal, &standard_f64);
    test_samples(SEED, Normal::new(0.0, 1.0).unwrap(), &standard_f32);

    let shifted_f64: [f64; 4] = [
        1.940779055860114,
        2.3906889818886174,
        2.0328199698479,
        1.4033550497906813,
    ];
    test_samples(SEED, Normal::new(2.0, 0.5).unwrap(), &shifted_f64);

    let log_standard_f32: [f32; 4] = [0.88830346, 2.1844804, 1.0678421, 0.30322206];
    test_samples(SEED, LogNormal::new(0.0, 1.0).unwrap(), &log_standard_f32);

    let log_shifted_f64: [f64; 4] = [
        6.964174338639032,
        10.921015733601452,
        7.6355881556915906,
        4.068828213584092,
    ];
    test_samples(SEED, LogNormal::new(2.0, 0.5).unwrap(), &log_shifted_f64);
}
#[test]
fn weibull_stability() {
    const SEED: u64 = 213;

    let expected: [f32; 4] = [0.041495778, 0.7531094, 1.4189332, 0.38386202];
    test_samples(SEED, Weibull::new(1.0, 1.0).unwrap(), &expected);

    let expected: [f64; 4] = [
        1.1343478702739669,
        0.29470010050655226,
        0.7556151370284702,
        7.877212340241561,
    ];
    test_samples(SEED, Weibull::new(2.0, 0.5).unwrap(), &expected);
}
#[cfg(feature = "alloc")]
#[test]
fn dirichlet_stability() {
    // All draws come from one RNG stream, so their order matters.
    let mut rng = get_rng(223);

    let increasing_alphas = Dirichlet::new([1.0, 2.0, 3.0]).unwrap();
    assert_eq!(
        rng.sample(increasing_alphas),
        [0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
    );

    let equal_alphas = Dirichlet::new([8.0; 5]).unwrap();
    assert_eq!(
        rng.sample(equal_alphas),
        [
            0.17684200044809556,
            0.29915953935953055,
            0.1832858056608014,
            0.1425623503573967,
            0.19815030417417595
        ]
    );

    // Every alpha is below 0.1.
    let tiny_alphas = Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap();
    assert_eq!(
        rng.sample(tiny_alphas),
        [
            0.00027580456855692104,
            2.296135759821706e-20,
            3.004118281150937e-9,
            0.9997241924273248
        ]
    );
}
#[test]
fn cauchy_stability() {
    const SEED: u64 = 353;

    test_samples(
        SEED,
        Cauchy::new(100f64, 10.0).unwrap(),
        &[
            77.93369152808678f64,
            90.1606912098641,
            125.31516221323625,
            86.10217834773925,
        ],
    );

    // The f32 case is not fully portable because it relies on the system's
    // implementation of `tanf` (see the doc on the `Cauchy` struct). It is
    // therefore checked directly with a looser absolute tolerance of 1e-5,
    // instead of the 1e-6 that `ApproxEq for f32` would apply.
    let distr = Cauchy::new(10f32, 7.0).unwrap();
    let mut rng = get_rng(SEED);
    let expected: [f32; 4] = [15.023088, -5.446413, 3.7092876, 3.112482];
    expected.iter().for_each(|&want| {
        let got = rng.sample(distr);
        assert_almost_eq!(want, got, 1e-5);
    });
}