rand_distr: split gamma module (#1464)
Move Beta, Student's t, Fisher-F, Chi-squared and Zeta distributions to their own modules.
This commit is contained in:
parent
763dbc5bbb
commit
d17ce4e0a1
298
rand_distr/src/beta.rs
Normal file
298
rand_distr/src/beta.rs
Normal file
@ -0,0 +1,298 @@
|
||||
// Copyright 2018 Developers of the Rand project.
|
||||
// Copyright 2013 The Rust Project Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! The Beta distribution.
|
||||
|
||||
use crate::{Distribution, Open01};
|
||||
use core::fmt;
|
||||
use num_traits::Float;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde1")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The algorithm used for sampling the Beta distribution.
|
||||
///
|
||||
/// Reference:
|
||||
///
|
||||
/// R. C. H. Cheng (1978).
|
||||
/// Generating beta variates with nonintegral shape parameters.
|
||||
/// Communications of the ACM 21, 317-322.
|
||||
/// https://doi.org/10.1145/359460.359482
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
enum BetaAlgorithm<N> {
|
||||
BB(BB<N>),
|
||||
BC(BC<N>),
|
||||
}
|
||||
|
||||
/// Algorithm BB for `min(alpha, beta) > 1`.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
struct BB<N> {
|
||||
alpha: N,
|
||||
beta: N,
|
||||
gamma: N,
|
||||
}
|
||||
|
||||
/// Algorithm BC for `min(alpha, beta) <= 1`.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
struct BC<N> {
|
||||
alpha: N,
|
||||
beta: N,
|
||||
kappa1: N,
|
||||
kappa2: N,
|
||||
}
|
||||
|
||||
/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`.
|
||||
///
|
||||
/// The Beta distribution is a continuous probability distribution
|
||||
/// defined on the interval `[0, 1]`. It is the conjugate prior for the
|
||||
/// parameter `p` of the [`Binomial`][crate::Binomial] distribution.
|
||||
///
|
||||
/// It has two shape parameters `α` (alpha) and `β` (beta) which control
|
||||
/// the shape of the distribution. Both `a` and `β` must be greater than zero.
|
||||
/// The distribution is symmetric when `α = β`.
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The plot shows the Beta distribution with various combinations
|
||||
/// of `α` and `β`.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::{Distribution, Beta};
|
||||
///
|
||||
/// let beta = Beta::new(2.0, 5.0).unwrap();
|
||||
/// let v = beta.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from a Beta(2, 5) distribution", v);
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub struct Beta<F>
|
||||
where
|
||||
F: Float,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
a: F,
|
||||
b: F,
|
||||
switched_params: bool,
|
||||
algorithm: BetaAlgorithm<F>,
|
||||
}
|
||||
|
||||
/// Error type returned from [`Beta::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub enum Error {
|
||||
/// `alpha <= 0` or `nan`.
|
||||
AlphaTooSmall,
|
||||
/// `beta <= 0` or `nan`.
|
||||
BetaTooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
Error::AlphaTooSmall => "alpha is not positive in beta distribution",
|
||||
Error::BetaTooSmall => "beta is not positive in beta distribution",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl<F> Beta<F>
|
||||
where
|
||||
F: Float,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Construct an object representing the `Beta(alpha, beta)`
|
||||
/// distribution.
|
||||
pub fn new(alpha: F, beta: F) -> Result<Beta<F>, Error> {
|
||||
if !(alpha > F::zero()) {
|
||||
return Err(Error::AlphaTooSmall);
|
||||
}
|
||||
if !(beta > F::zero()) {
|
||||
return Err(Error::BetaTooSmall);
|
||||
}
|
||||
// From now on, we use the notation from the reference,
|
||||
// i.e. `alpha` and `beta` are renamed to `a0` and `b0`.
|
||||
let (a0, b0) = (alpha, beta);
|
||||
let (a, b, switched_params) = if a0 < b0 {
|
||||
(a0, b0, false)
|
||||
} else {
|
||||
(b0, a0, true)
|
||||
};
|
||||
if a > F::one() {
|
||||
// Algorithm BB
|
||||
let alpha = a + b;
|
||||
|
||||
let two = F::from(2.).unwrap();
|
||||
let beta_numer = alpha - two;
|
||||
let beta_denom = two * a * b - alpha;
|
||||
let beta = (beta_numer / beta_denom).sqrt();
|
||||
|
||||
let gamma = a + F::one() / beta;
|
||||
|
||||
Ok(Beta {
|
||||
a,
|
||||
b,
|
||||
switched_params,
|
||||
algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }),
|
||||
})
|
||||
} else {
|
||||
// Algorithm BC
|
||||
//
|
||||
// Here `a` is the maximum instead of the minimum.
|
||||
let (a, b, switched_params) = (b, a, !switched_params);
|
||||
let alpha = a + b;
|
||||
let beta = F::one() / b;
|
||||
let delta = F::one() + a - b;
|
||||
let kappa1 = delta
|
||||
* (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b)
|
||||
/ (a * beta - F::from(14. / 18.).unwrap());
|
||||
let kappa2 = F::from(0.25).unwrap()
|
||||
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b;
|
||||
|
||||
Ok(Beta {
|
||||
a,
|
||||
b,
|
||||
switched_params,
|
||||
algorithm: BetaAlgorithm::BC(BC {
|
||||
alpha,
|
||||
beta,
|
||||
kappa1,
|
||||
kappa2,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Distribution<F> for Beta<F>
|
||||
where
|
||||
F: Float,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
let mut w;
|
||||
match self.algorithm {
|
||||
BetaAlgorithm::BB(algo) => {
|
||||
loop {
|
||||
// 1.
|
||||
let u1 = rng.sample(Open01);
|
||||
let u2 = rng.sample(Open01);
|
||||
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
||||
w = self.a * v.exp();
|
||||
let z = u1 * u1 * u2;
|
||||
let r = algo.gamma * v - F::from(4.).unwrap().ln();
|
||||
let s = self.a + r - w;
|
||||
// 2.
|
||||
if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
|
||||
break;
|
||||
}
|
||||
// 3.
|
||||
let t = z.ln();
|
||||
if s >= t {
|
||||
break;
|
||||
}
|
||||
// 4.
|
||||
if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
BetaAlgorithm::BC(algo) => {
|
||||
loop {
|
||||
let z;
|
||||
// 1.
|
||||
let u1 = rng.sample(Open01);
|
||||
let u2 = rng.sample(Open01);
|
||||
if u1 < F::from(0.5).unwrap() {
|
||||
// 2.
|
||||
let y = u1 * u2;
|
||||
z = u1 * y;
|
||||
if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// 3.
|
||||
z = u1 * u1 * u2;
|
||||
if z <= F::from(0.25).unwrap() {
|
||||
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
||||
w = self.a * v.exp();
|
||||
break;
|
||||
}
|
||||
// 4.
|
||||
if z >= algo.kappa2 {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// 5.
|
||||
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
||||
w = self.a * v.exp();
|
||||
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
|
||||
- F::from(4.).unwrap().ln()
|
||||
< z.ln())
|
||||
{
|
||||
break;
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
// 5. for BB, 6. for BC
|
||||
if !self.switched_params {
|
||||
if w == F::infinity() {
|
||||
// Assuming `b` is finite, for large `w`:
|
||||
return F::one();
|
||||
}
|
||||
w / (self.b + w)
|
||||
} else {
|
||||
self.b / (self.b + w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_beta() {
|
||||
let beta = Beta::new(1.0, 2.0).unwrap();
|
||||
let mut rng = crate::test::rng(201);
|
||||
for _ in 0..1000 {
|
||||
beta.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_beta_invalid_dof() {
|
||||
Beta::new(0., 0.).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_beta_small_param() {
|
||||
let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
|
||||
let mut rng = crate::test::rng(206);
|
||||
for i in 0..1000 {
|
||||
assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn beta_distributions_can_be_compared() {
|
||||
assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
|
||||
}
|
||||
}
|
@ -52,7 +52,7 @@ pub struct Binomial {
|
||||
p: f64,
|
||||
}
|
||||
|
||||
/// Error type returned from `Binomial::new`.
|
||||
/// Error type returned from [`Binomial::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `p < 0` or `nan`.
|
||||
|
@ -64,7 +64,7 @@ where
|
||||
scale: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Cauchy::new`.
|
||||
/// Error type returned from [`Cauchy::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `scale <= 0` or `nan`.
|
||||
|
179
rand_distr/src/chi_squared.rs
Normal file
179
rand_distr/src/chi_squared.rs
Normal file
@ -0,0 +1,179 @@
|
||||
// Copyright 2018 Developers of the Rand project.
|
||||
// Copyright 2013 The Rust Project Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! The Chi-squared distribution.
|
||||
|
||||
use self::ChiSquaredRepr::*;
|
||||
|
||||
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
|
||||
use core::fmt;
|
||||
use num_traits::Float;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde1")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`.
|
||||
///
|
||||
/// The chi-squared distribution is a continuous probability
|
||||
/// distribution with parameter `k > 0` degrees of freedom.
|
||||
///
|
||||
/// For `k > 0` integral, this distribution is the sum of the squares
|
||||
/// of `k` independent standard normal random variables. For other
|
||||
/// `k`, this uses the equivalent characterisation
|
||||
/// `χ²(k) = Gamma(k/2, 2)`.
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The plot shows the chi-squared distribution with various degrees
|
||||
/// of freedom.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::{ChiSquared, Distribution};
|
||||
///
|
||||
/// let chi = ChiSquared::new(11.0).unwrap();
|
||||
/// let v = chi.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from a χ²(11) distribution", v)
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub struct ChiSquared<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
repr: ChiSquaredRepr<F>,
|
||||
}
|
||||
|
||||
/// Error type returned from [`ChiSquared::new`] and [`StudentT::new`](crate::StudentT::new).
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub enum Error {
|
||||
/// `0.5 * k <= 0` or `nan`.
|
||||
DoFTooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
Error::DoFTooSmall => {
|
||||
"degrees-of-freedom k is not positive in chi-squared distribution"
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
enum ChiSquaredRepr<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
// k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1,
|
||||
// e.g. when alpha = 1/2 as it would be for this case, so special-
|
||||
// casing and using the definition of N(0,1)^2 is faster.
|
||||
DoFExactlyOne,
|
||||
DoFAnythingElse(Gamma<F>),
|
||||
}
|
||||
|
||||
impl<F> ChiSquared<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Create a new chi-squared distribution with degrees-of-freedom
|
||||
/// `k`.
|
||||
pub fn new(k: F) -> Result<ChiSquared<F>, Error> {
|
||||
let repr = if k == F::one() {
|
||||
DoFExactlyOne
|
||||
} else {
|
||||
if !(F::from(0.5).unwrap() * k > F::zero()) {
|
||||
return Err(Error::DoFTooSmall);
|
||||
}
|
||||
DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap())
|
||||
};
|
||||
Ok(ChiSquared { repr })
|
||||
}
|
||||
}
|
||||
impl<F> Distribution<F> for ChiSquared<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
match self.repr {
|
||||
DoFExactlyOne => {
|
||||
// k == 1 => N(0,1)^2
|
||||
let norm: F = rng.sample(StandardNormal);
|
||||
norm * norm
|
||||
}
|
||||
DoFAnythingElse(ref g) => g.sample(rng),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chi_squared_one() {
|
||||
let chi = ChiSquared::new(1.0).unwrap();
|
||||
let mut rng = crate::test::rng(201);
|
||||
for _ in 0..1000 {
|
||||
chi.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn test_chi_squared_small() {
|
||||
let chi = ChiSquared::new(0.5).unwrap();
|
||||
let mut rng = crate::test::rng(202);
|
||||
for _ in 0..1000 {
|
||||
chi.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn test_chi_squared_large() {
|
||||
let chi = ChiSquared::new(30.0).unwrap();
|
||||
let mut rng = crate::test::rng(203);
|
||||
for _ in 0..1000 {
|
||||
chi.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_chi_squared_invalid_dof() {
|
||||
ChiSquared::new(-1.0).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gamma_distributions_can_be_compared() {
|
||||
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chi_squared_distributions_can_be_compared() {
|
||||
assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0));
|
||||
}
|
||||
}
|
@ -31,7 +31,7 @@ where
|
||||
samplers: [Gamma<F>; N],
|
||||
}
|
||||
|
||||
/// Error type returned from `DirchletFromGamma::new`.
|
||||
/// Error type returned from [`DirchletFromGamma::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum DirichletFromGammaError {
|
||||
/// Gamma::new(a, 1) failed.
|
||||
@ -103,7 +103,7 @@ where
|
||||
samplers: Box<[Beta<F>]>,
|
||||
}
|
||||
|
||||
/// Error type returned from `DirchletFromBeta::new`.
|
||||
/// Error type returned from [`DirchletFromBeta::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
enum DirichletFromBetaError {
|
||||
/// Beta::new(a, b) failed.
|
||||
@ -226,7 +226,7 @@ where
|
||||
repr: DirichletRepr<F, N>,
|
||||
}
|
||||
|
||||
/// Error type returned from `Dirchlet::new`.
|
||||
/// Error type returned from [`Dirichlet::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `alpha.len() < 2`.
|
||||
|
@ -130,7 +130,7 @@ where
|
||||
lambda_inverse: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Exp::new`.
|
||||
/// Error type returned from [`Exp::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `lambda < 0` or `nan`.
|
||||
|
131
rand_distr/src/fisher_f.rs
Normal file
131
rand_distr/src/fisher_f.rs
Normal file
@ -0,0 +1,131 @@
|
||||
// Copyright 2018 Developers of the Rand project.
|
||||
// Copyright 2013 The Rust Project Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! The Fisher F-distribution.
|
||||
|
||||
use crate::{ChiSquared, Distribution, Exp1, Open01, StandardNormal};
|
||||
use core::fmt;
|
||||
use num_traits::Float;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde1")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`.
|
||||
///
|
||||
/// This distribution is equivalent to the ratio of two normalised
|
||||
/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) /
|
||||
/// (χ²(n)/n)`.
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The plot shows the F-distribution with various values of `m` and `n`.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::{FisherF, Distribution};
|
||||
///
|
||||
/// let f = FisherF::new(2.0, 32.0).unwrap();
|
||||
/// let v = f.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from an F(2, 32) distribution", v)
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub struct FisherF<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
numer: ChiSquared<F>,
|
||||
denom: ChiSquared<F>,
|
||||
// denom_dof / numer_dof so that this can just be a straight
|
||||
// multiplication, rather than a division.
|
||||
dof_ratio: F,
|
||||
}
|
||||
|
||||
/// Error type returned from [`FisherF::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub enum Error {
|
||||
/// `m <= 0` or `nan`.
|
||||
MTooSmall,
|
||||
/// `n <= 0` or `nan`.
|
||||
NTooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
Error::MTooSmall => "m is not positive in Fisher F distribution",
|
||||
Error::NTooSmall => "n is not positive in Fisher F distribution",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl<F> FisherF<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Create a new `FisherF` distribution, with the given parameter.
|
||||
pub fn new(m: F, n: F) -> Result<FisherF<F>, Error> {
|
||||
let zero = F::zero();
|
||||
if !(m > zero) {
|
||||
return Err(Error::MTooSmall);
|
||||
}
|
||||
if !(n > zero) {
|
||||
return Err(Error::NTooSmall);
|
||||
}
|
||||
|
||||
Ok(FisherF {
|
||||
numer: ChiSquared::new(m).unwrap(),
|
||||
denom: ChiSquared::new(n).unwrap(),
|
||||
dof_ratio: n / m,
|
||||
})
|
||||
}
|
||||
}
|
||||
impl<F> Distribution<F> for FisherF<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_f() {
|
||||
let f = FisherF::new(2.0, 32.0).unwrap();
|
||||
let mut rng = crate::test::rng(204);
|
||||
for _ in 0..1000 {
|
||||
f.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fisher_f_distributions_can_be_compared() {
|
||||
assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0));
|
||||
}
|
||||
}
|
@ -55,7 +55,7 @@ where
|
||||
shape: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Frechet::new`.
|
||||
/// Error type returned from [`Frechet::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// location is infinite or NaN
|
||||
|
@ -7,17 +7,11 @@
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! The Gamma and derived distributions.
|
||||
//! The Gamma distribution.
|
||||
|
||||
// We use the variable names from the published reference, therefore this
|
||||
// warning is not helpful.
|
||||
#![allow(clippy::many_single_char_names)]
|
||||
|
||||
use self::ChiSquaredRepr::*;
|
||||
use self::GammaRepr::*;
|
||||
|
||||
use crate::normal::StandardNormal;
|
||||
use crate::{Distribution, Exp, Exp1, Open01};
|
||||
use crate::{Distribution, Exp, Exp1, Open01, StandardNormal};
|
||||
use core::fmt;
|
||||
use num_traits::Float;
|
||||
use rand::Rng;
|
||||
@ -80,7 +74,7 @@ where
|
||||
repr: GammaRepr<F>,
|
||||
}
|
||||
|
||||
/// Error type returned from `Gamma::new`.
|
||||
/// Error type returned from [`Gamma::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `shape <= 0` or `nan`.
|
||||
@ -276,632 +270,12 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`.
|
||||
///
|
||||
/// The chi-squared distribution is a continuous probability
|
||||
/// distribution with parameter `k > 0` degrees of freedom.
|
||||
///
|
||||
/// For `k > 0` integral, this distribution is the sum of the squares
|
||||
/// of `k` independent standard normal random variables. For other
|
||||
/// `k`, this uses the equivalent characterisation
|
||||
/// `χ²(k) = Gamma(k/2, 2)`.
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The plot shows the chi-squared distribution with various degrees
|
||||
/// of freedom.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::{ChiSquared, Distribution};
|
||||
///
|
||||
/// let chi = ChiSquared::new(11.0).unwrap();
|
||||
/// let v = chi.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from a χ²(11) distribution", v)
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub struct ChiSquared<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
repr: ChiSquaredRepr<F>,
|
||||
}
|
||||
|
||||
/// Error type returned from `ChiSquared::new` and `StudentT::new`.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub enum ChiSquaredError {
|
||||
/// `0.5 * k <= 0` or `nan`.
|
||||
DoFTooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for ChiSquaredError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
ChiSquaredError::DoFTooSmall => {
|
||||
"degrees-of-freedom k is not positive in chi-squared distribution"
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for ChiSquaredError {}
|
||||
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
enum ChiSquaredRepr<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
// k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1,
|
||||
// e.g. when alpha = 1/2 as it would be for this case, so special-
|
||||
// casing and using the definition of N(0,1)^2 is faster.
|
||||
DoFExactlyOne,
|
||||
DoFAnythingElse(Gamma<F>),
|
||||
}
|
||||
|
||||
impl<F> ChiSquared<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Create a new chi-squared distribution with degrees-of-freedom
|
||||
/// `k`.
|
||||
pub fn new(k: F) -> Result<ChiSquared<F>, ChiSquaredError> {
|
||||
let repr = if k == F::one() {
|
||||
DoFExactlyOne
|
||||
} else {
|
||||
if !(F::from(0.5).unwrap() * k > F::zero()) {
|
||||
return Err(ChiSquaredError::DoFTooSmall);
|
||||
}
|
||||
DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap())
|
||||
};
|
||||
Ok(ChiSquared { repr })
|
||||
}
|
||||
}
|
||||
impl<F> Distribution<F> for ChiSquared<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
match self.repr {
|
||||
DoFExactlyOne => {
|
||||
// k == 1 => N(0,1)^2
|
||||
let norm: F = rng.sample(StandardNormal);
|
||||
norm * norm
|
||||
}
|
||||
DoFAnythingElse(ref g) => g.sample(rng),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`.
|
||||
///
|
||||
/// This distribution is equivalent to the ratio of two normalised
|
||||
/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) /
|
||||
/// (χ²(n)/n)`.
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The plot shows the F-distribution with various values of `m` and `n`.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::{FisherF, Distribution};
|
||||
///
|
||||
/// let f = FisherF::new(2.0, 32.0).unwrap();
|
||||
/// let v = f.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from an F(2, 32) distribution", v)
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub struct FisherF<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
numer: ChiSquared<F>,
|
||||
denom: ChiSquared<F>,
|
||||
// denom_dof / numer_dof so that this can just be a straight
|
||||
// multiplication, rather than a division.
|
||||
dof_ratio: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `FisherF::new`.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub enum FisherFError {
|
||||
/// `m <= 0` or `nan`.
|
||||
MTooSmall,
|
||||
/// `n <= 0` or `nan`.
|
||||
NTooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for FisherFError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
FisherFError::MTooSmall => "m is not positive in Fisher F distribution",
|
||||
FisherFError::NTooSmall => "n is not positive in Fisher F distribution",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for FisherFError {}
|
||||
|
||||
impl<F> FisherF<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Create a new `FisherF` distribution, with the given parameter.
|
||||
pub fn new(m: F, n: F) -> Result<FisherF<F>, FisherFError> {
|
||||
let zero = F::zero();
|
||||
if !(m > zero) {
|
||||
return Err(FisherFError::MTooSmall);
|
||||
}
|
||||
if !(n > zero) {
|
||||
return Err(FisherFError::NTooSmall);
|
||||
}
|
||||
|
||||
Ok(FisherF {
|
||||
numer: ChiSquared::new(m).unwrap(),
|
||||
denom: ChiSquared::new(n).unwrap(),
|
||||
dof_ratio: n / m,
|
||||
})
|
||||
}
|
||||
}
|
||||
impl<F> Distribution<F> for FisherF<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio
|
||||
}
|
||||
}
|
||||
|
||||
/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`.
|
||||
///
|
||||
/// The t-distribution is a continuous probability distribution
|
||||
/// parameterized by degrees of freedom `ν` (`nu`), which
|
||||
/// arises when estimating the mean of a normally-distributed
|
||||
/// population in situations where the sample size is small and
|
||||
/// the population's standard deviation is unknown.
|
||||
/// It is widely used in hypothesis testing.
|
||||
///
|
||||
/// For `ν = 1`, this is equivalent to the standard
|
||||
/// [`Cauchy`](crate::Cauchy) distribution,
|
||||
/// and as `ν` diverges to infinity, `t(ν)` converges to
|
||||
/// [`StandardNormal`](crate::StandardNormal).
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The plot shows the t-distribution with various degrees of freedom.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::{StudentT, Distribution};
|
||||
///
|
||||
/// let t = StudentT::new(11.0).unwrap();
|
||||
/// let v = t.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from a t(11) distribution", v)
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub struct StudentT<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
chi: ChiSquared<F>,
|
||||
dof: F,
|
||||
}
|
||||
|
||||
impl<F> StudentT<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Create a new Student t-distribution with `ν` (nu)
|
||||
/// degrees of freedom.
|
||||
pub fn new(nu: F) -> Result<StudentT<F>, ChiSquaredError> {
|
||||
Ok(StudentT {
|
||||
chi: ChiSquared::new(nu)?,
|
||||
dof: nu,
|
||||
})
|
||||
}
|
||||
}
|
||||
impl<F> Distribution<F> for StudentT<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
let norm: F = rng.sample(StandardNormal);
|
||||
norm * (self.dof / self.chi.sample(rng)).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
/// The algorithm used for sampling the Beta distribution.
|
||||
///
|
||||
/// Reference:
|
||||
///
|
||||
/// R. C. H. Cheng (1978).
|
||||
/// Generating beta variates with nonintegral shape parameters.
|
||||
/// Communications of the ACM 21, 317-322.
|
||||
/// https://doi.org/10.1145/359460.359482
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
enum BetaAlgorithm<N> {
|
||||
BB(BB<N>),
|
||||
BC(BC<N>),
|
||||
}
|
||||
|
||||
/// Algorithm BB for `min(alpha, beta) > 1`.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
struct BB<N> {
|
||||
alpha: N,
|
||||
beta: N,
|
||||
gamma: N,
|
||||
}
|
||||
|
||||
/// Algorithm BC for `min(alpha, beta) <= 1`.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
struct BC<N> {
|
||||
alpha: N,
|
||||
beta: N,
|
||||
kappa1: N,
|
||||
kappa2: N,
|
||||
}
|
||||
|
||||
/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`.
|
||||
///
|
||||
/// The Beta distribution is a continuous probability distribution
|
||||
/// defined on the interval `[0, 1]`. It is the conjugate prior for the
|
||||
/// parameter `p` of the [`Binomial`][crate::Binomial] distribution.
|
||||
///
|
||||
/// It has two shape parameters `α` (alpha) and `β` (beta) which control
|
||||
/// the shape of the distribution. Both `a` and `β` must be greater than zero.
|
||||
/// The distribution is symmetric when `α = β`.
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The plot shows the Beta distribution with various combinations
|
||||
/// of `α` and `β`.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::{Distribution, Beta};
|
||||
///
|
||||
/// let beta = Beta::new(2.0, 5.0).unwrap();
|
||||
/// let v = beta.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from a Beta(2, 5) distribution", v);
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub struct Beta<F>
|
||||
where
|
||||
F: Float,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
a: F,
|
||||
b: F,
|
||||
switched_params: bool,
|
||||
algorithm: BetaAlgorithm<F>,
|
||||
}
|
||||
|
||||
/// Error type returned from `Beta::new`.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub enum BetaError {
|
||||
/// `alpha <= 0` or `nan`.
|
||||
AlphaTooSmall,
|
||||
/// `beta <= 0` or `nan`.
|
||||
BetaTooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for BetaError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
BetaError::AlphaTooSmall => "alpha is not positive in beta distribution",
|
||||
BetaError::BetaTooSmall => "beta is not positive in beta distribution",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for BetaError {}
|
||||
|
||||
impl<F> Beta<F>
|
||||
where
|
||||
F: Float,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Construct an object representing the `Beta(alpha, beta)`
|
||||
/// distribution.
|
||||
pub fn new(alpha: F, beta: F) -> Result<Beta<F>, BetaError> {
|
||||
if !(alpha > F::zero()) {
|
||||
return Err(BetaError::AlphaTooSmall);
|
||||
}
|
||||
if !(beta > F::zero()) {
|
||||
return Err(BetaError::BetaTooSmall);
|
||||
}
|
||||
// From now on, we use the notation from the reference,
|
||||
// i.e. `alpha` and `beta` are renamed to `a0` and `b0`.
|
||||
let (a0, b0) = (alpha, beta);
|
||||
let (a, b, switched_params) = if a0 < b0 {
|
||||
(a0, b0, false)
|
||||
} else {
|
||||
(b0, a0, true)
|
||||
};
|
||||
if a > F::one() {
|
||||
// Algorithm BB
|
||||
let alpha = a + b;
|
||||
|
||||
let two = F::from(2.).unwrap();
|
||||
let beta_numer = alpha - two;
|
||||
let beta_denom = two * a * b - alpha;
|
||||
let beta = (beta_numer / beta_denom).sqrt();
|
||||
|
||||
let gamma = a + F::one() / beta;
|
||||
|
||||
Ok(Beta {
|
||||
a,
|
||||
b,
|
||||
switched_params,
|
||||
algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }),
|
||||
})
|
||||
} else {
|
||||
// Algorithm BC
|
||||
//
|
||||
// Here `a` is the maximum instead of the minimum.
|
||||
let (a, b, switched_params) = (b, a, !switched_params);
|
||||
let alpha = a + b;
|
||||
let beta = F::one() / b;
|
||||
let delta = F::one() + a - b;
|
||||
let kappa1 = delta
|
||||
* (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b)
|
||||
/ (a * beta - F::from(14. / 18.).unwrap());
|
||||
let kappa2 = F::from(0.25).unwrap()
|
||||
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b;
|
||||
|
||||
Ok(Beta {
|
||||
a,
|
||||
b,
|
||||
switched_params,
|
||||
algorithm: BetaAlgorithm::BC(BC {
|
||||
alpha,
|
||||
beta,
|
||||
kappa1,
|
||||
kappa2,
|
||||
}),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Distribution<F> for Beta<F>
|
||||
where
|
||||
F: Float,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
let mut w;
|
||||
match self.algorithm {
|
||||
BetaAlgorithm::BB(algo) => {
|
||||
loop {
|
||||
// 1.
|
||||
let u1 = rng.sample(Open01);
|
||||
let u2 = rng.sample(Open01);
|
||||
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
||||
w = self.a * v.exp();
|
||||
let z = u1 * u1 * u2;
|
||||
let r = algo.gamma * v - F::from(4.).unwrap().ln();
|
||||
let s = self.a + r - w;
|
||||
// 2.
|
||||
if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
|
||||
break;
|
||||
}
|
||||
// 3.
|
||||
let t = z.ln();
|
||||
if s >= t {
|
||||
break;
|
||||
}
|
||||
// 4.
|
||||
if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
BetaAlgorithm::BC(algo) => {
|
||||
loop {
|
||||
let z;
|
||||
// 1.
|
||||
let u1 = rng.sample(Open01);
|
||||
let u2 = rng.sample(Open01);
|
||||
if u1 < F::from(0.5).unwrap() {
|
||||
// 2.
|
||||
let y = u1 * u2;
|
||||
z = u1 * y;
|
||||
if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// 3.
|
||||
z = u1 * u1 * u2;
|
||||
if z <= F::from(0.25).unwrap() {
|
||||
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
||||
w = self.a * v.exp();
|
||||
break;
|
||||
}
|
||||
// 4.
|
||||
if z >= algo.kappa2 {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
// 5.
|
||||
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
||||
w = self.a * v.exp();
|
||||
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
|
||||
- F::from(4.).unwrap().ln()
|
||||
< z.ln())
|
||||
{
|
||||
break;
|
||||
};
|
||||
}
|
||||
}
|
||||
};
|
||||
// 5. for BB, 6. for BC
|
||||
if !self.switched_params {
|
||||
if w == F::infinity() {
|
||||
// Assuming `b` is finite, for large `w`:
|
||||
return F::one();
|
||||
}
|
||||
w / (self.b + w)
|
||||
} else {
|
||||
self.b / (self.b + w)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_chi_squared_one() {
|
||||
let chi = ChiSquared::new(1.0).unwrap();
|
||||
let mut rng = crate::test::rng(201);
|
||||
for _ in 0..1000 {
|
||||
chi.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn test_chi_squared_small() {
|
||||
let chi = ChiSquared::new(0.5).unwrap();
|
||||
let mut rng = crate::test::rng(202);
|
||||
for _ in 0..1000 {
|
||||
chi.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
fn test_chi_squared_large() {
|
||||
let chi = ChiSquared::new(30.0).unwrap();
|
||||
let mut rng = crate::test::rng(203);
|
||||
for _ in 0..1000 {
|
||||
chi.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_chi_squared_invalid_dof() {
|
||||
ChiSquared::new(-1.0).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_f() {
|
||||
let f = FisherF::new(2.0, 32.0).unwrap();
|
||||
let mut rng = crate::test::rng(204);
|
||||
for _ in 0..1000 {
|
||||
f.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_t() {
|
||||
let t = StudentT::new(11.0).unwrap();
|
||||
let mut rng = crate::test::rng(205);
|
||||
for _ in 0..1000 {
|
||||
t.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_beta() {
|
||||
let beta = Beta::new(1.0, 2.0).unwrap();
|
||||
let mut rng = crate::test::rng(201);
|
||||
for _ in 0..1000 {
|
||||
beta.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn test_beta_invalid_dof() {
|
||||
Beta::new(0., 0.).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_beta_small_param() {
|
||||
let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
|
||||
let mut rng = crate::test::rng(206);
|
||||
for i in 0..1000 {
|
||||
assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn gamma_distributions_can_be_compared() {
|
||||
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn beta_distributions_can_be_compared() {
|
||||
assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chi_squared_distributions_can_be_compared() {
|
||||
assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn fisher_f_distributions_can_be_compared() {
|
||||
assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn student_t_distributions_can_be_compared() {
|
||||
assert_eq!(StudentT::new(1.0), StudentT::new(1.0));
|
||||
}
|
||||
}
|
||||
|
@ -46,7 +46,7 @@ pub struct Geometric {
|
||||
k: u64,
|
||||
}
|
||||
|
||||
/// Error type returned from `Geometric::new`.
|
||||
/// Error type returned from [`Geometric::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `p < 0 || p > 1` or `nan`
|
||||
|
@ -52,7 +52,7 @@ where
|
||||
scale: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Gumbel::new`.
|
||||
/// Error type returned from [`Gumbel::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// location is infinite or NaN
|
||||
|
@ -68,7 +68,7 @@ pub struct Hypergeometric {
|
||||
sampling_method: SamplingMethod,
|
||||
}
|
||||
|
||||
/// Error type returned from `Hypergeometric::new`.
|
||||
/// Error type returned from [`Hypergeometric::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `total_population_size` is too large, causing floating point underflow.
|
||||
|
@ -5,7 +5,7 @@ use core::fmt;
|
||||
use num_traits::Float;
|
||||
use rand::Rng;
|
||||
|
||||
/// Error type returned from `InverseGaussian::new`
|
||||
/// Error type returned from [`InverseGaussian::new`]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `mean <= 0` or `nan`.
|
||||
|
@ -98,16 +98,16 @@ pub use rand::distributions::{
|
||||
Standard, Uniform,
|
||||
};
|
||||
|
||||
pub use self::beta::{Beta, Error as BetaError};
|
||||
pub use self::binomial::{Binomial, Error as BinomialError};
|
||||
pub use self::cauchy::{Cauchy, Error as CauchyError};
|
||||
pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError};
|
||||
#[cfg(feature = "alloc")]
|
||||
pub use self::dirichlet::{Dirichlet, Error as DirichletError};
|
||||
pub use self::exponential::{Error as ExpError, Exp, Exp1};
|
||||
pub use self::fisher_f::{Error as FisherFError, FisherF};
|
||||
pub use self::frechet::{Error as FrechetError, Frechet};
|
||||
pub use self::gamma::{
|
||||
Beta, BetaError, ChiSquared, ChiSquaredError, Error as GammaError, FisherF, FisherFError,
|
||||
Gamma, StudentT,
|
||||
};
|
||||
pub use self::gamma::{Error as GammaError, Gamma};
|
||||
pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric};
|
||||
pub use self::gumbel::{Error as GumbelError, Gumbel};
|
||||
pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric};
|
||||
@ -126,9 +126,11 @@ pub use self::unit_circle::UnitCircle;
|
||||
pub use self::unit_disc::UnitDisc;
|
||||
pub use self::unit_sphere::UnitSphere;
|
||||
pub use self::weibull::{Error as WeibullError, Weibull};
|
||||
pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError};
|
||||
pub use self::zeta::{Error as ZetaError, Zeta};
|
||||
pub use self::zipf::{Error as ZipfError, Zipf};
|
||||
#[cfg(feature = "alloc")]
|
||||
pub use rand::distributions::{WeightError, WeightedIndex};
|
||||
pub use student_t::StudentT;
|
||||
#[cfg(feature = "alloc")]
|
||||
pub use weighted_alias::WeightedAliasIndex;
|
||||
#[cfg(feature = "alloc")]
|
||||
@ -192,10 +194,13 @@ pub mod weighted_alias;
|
||||
#[cfg(feature = "alloc")]
|
||||
pub mod weighted_tree;
|
||||
|
||||
mod beta;
|
||||
mod binomial;
|
||||
mod cauchy;
|
||||
mod chi_squared;
|
||||
mod dirichlet;
|
||||
mod exponential;
|
||||
mod fisher_f;
|
||||
mod frechet;
|
||||
mod gamma;
|
||||
mod geometric;
|
||||
@ -208,6 +213,7 @@ mod pareto;
|
||||
mod pert;
|
||||
mod poisson;
|
||||
mod skew_normal;
|
||||
mod student_t;
|
||||
mod triangular;
|
||||
mod unit_ball;
|
||||
mod unit_circle;
|
||||
@ -215,5 +221,6 @@ mod unit_disc;
|
||||
mod unit_sphere;
|
||||
mod utils;
|
||||
mod weibull;
|
||||
mod zeta;
|
||||
mod ziggurat_tables;
|
||||
mod zipf;
|
||||
|
@ -153,7 +153,7 @@ where
|
||||
std_dev: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Normal::new` and `LogNormal::new`.
|
||||
/// Error type returned from [`Normal::new`] and [`LogNormal::new`](crate::LogNormal::new).
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// The mean value is too small (log-normal samples must be positive)
|
||||
|
@ -3,7 +3,7 @@ use core::fmt;
|
||||
use num_traits::Float;
|
||||
use rand::Rng;
|
||||
|
||||
/// Error type returned from `NormalInverseGaussian::new`
|
||||
/// Error type returned from [`NormalInverseGaussian::new`]
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `alpha <= 0` or `nan`.
|
||||
|
@ -46,7 +46,7 @@ where
|
||||
inv_neg_shape: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Pareto::new`.
|
||||
/// Error type returned from [`Pareto::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `scale <= 0` or `nan`.
|
||||
|
@ -54,7 +54,7 @@ where
|
||||
magic_val: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Poisson::new`.
|
||||
/// Error type returned from [`Poisson::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `lambda <= 0`
|
||||
|
@ -68,7 +68,7 @@ where
|
||||
shape: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `SkewNormal::new`.
|
||||
/// Error type returned from [`SkewNormal::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// The scale parameter is not finite or it is less or equal to zero.
|
||||
|
107
rand_distr/src/student_t.rs
Normal file
107
rand_distr/src/student_t.rs
Normal file
@ -0,0 +1,107 @@
|
||||
// Copyright 2018 Developers of the Rand project.
|
||||
// Copyright 2013 The Rust Project Developers.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! The Student's t-distribution.
|
||||
|
||||
use crate::{ChiSquared, ChiSquaredError};
|
||||
use crate::{Distribution, Exp1, Open01, StandardNormal};
|
||||
use num_traits::Float;
|
||||
use rand::Rng;
|
||||
#[cfg(feature = "serde1")]
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`.
|
||||
///
|
||||
/// The t-distribution is a continuous probability distribution
|
||||
/// parameterized by degrees of freedom `ν` (`nu`), which
|
||||
/// arises when estimating the mean of a normally-distributed
|
||||
/// population in situations where the sample size is small and
|
||||
/// the population's standard deviation is unknown.
|
||||
/// It is widely used in hypothesis testing.
|
||||
///
|
||||
/// For `ν = 1`, this is equivalent to the standard
|
||||
/// [`Cauchy`](crate::Cauchy) distribution,
|
||||
/// and as `ν` diverges to infinity, `t(ν)` converges to
|
||||
/// [`StandardNormal`](crate::StandardNormal).
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The plot shows the t-distribution with various degrees of freedom.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::{StudentT, Distribution};
|
||||
///
|
||||
/// let t = StudentT::new(11.0).unwrap();
|
||||
/// let v = t.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from a t(11) distribution", v)
|
||||
/// ```
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
pub struct StudentT<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
chi: ChiSquared<F>,
|
||||
dof: F,
|
||||
}
|
||||
|
||||
impl<F> StudentT<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Create a new Student t-distribution with `ν` (nu)
|
||||
/// degrees of freedom.
|
||||
pub fn new(nu: F) -> Result<StudentT<F>, ChiSquaredError> {
|
||||
Ok(StudentT {
|
||||
chi: ChiSquared::new(nu)?,
|
||||
dof: nu,
|
||||
})
|
||||
}
|
||||
}
|
||||
impl<F> Distribution<F> for StudentT<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
let norm: F = rng.sample(StandardNormal);
|
||||
norm * (self.dof / self.chi.sample(rng)).sqrt()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_t() {
|
||||
let t = StudentT::new(11.0).unwrap();
|
||||
let mut rng = crate::test::rng(205);
|
||||
for _ in 0..1000 {
|
||||
t.sample(&mut rng);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn student_t_distributions_can_be_compared() {
|
||||
assert_eq!(StudentT::new(1.0), StudentT::new(1.0));
|
||||
}
|
||||
}
|
@ -44,7 +44,7 @@ where
|
||||
scale: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Weibull::new`.
|
||||
/// Error type returned from [`Weibull::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `scale <= 0` or `nan`.
|
||||
|
192
rand_distr/src/zeta.rs
Normal file
192
rand_distr/src/zeta.rs
Normal file
@ -0,0 +1,192 @@
|
||||
// Copyright 2021 Developers of the Rand project.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
|
||||
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
|
||||
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! The Zeta distribution.
|
||||
|
||||
use crate::{Distribution, Standard};
|
||||
use core::fmt;
|
||||
use num_traits::Float;
|
||||
use rand::{distributions::OpenClosed01, Rng};
|
||||
|
||||
/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`.
|
||||
///
|
||||
/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution)
|
||||
/// is a discrete probability distribution with parameter `s`.
|
||||
/// It is a special case of the [`Zipf`](crate::Zipf) distribution with `n = ∞`.
|
||||
/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or Zipf–Estoup distribution.
|
||||
///
|
||||
/// # Density function
|
||||
///
|
||||
/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the
|
||||
/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function).
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The following plot illustrates the zeta distribution for various values of `s`.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rand::prelude::*;
|
||||
/// use rand_distr::Zeta;
|
||||
///
|
||||
/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap());
|
||||
/// println!("{}", val);
|
||||
/// ```
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The zeta distribution has no upper limit. Sampled values may be infinite.
|
||||
/// In particular, a value of infinity might be returned for the following
|
||||
/// reasons:
|
||||
/// 1. it is the best representation in the type `F` of the actual sample.
|
||||
/// 2. to prevent infinite loops for very small `s`.
|
||||
///
|
||||
/// # Implementation details
|
||||
///
|
||||
/// We are using the algorithm from
|
||||
/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8),
|
||||
/// Section 6.1, page 551.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub struct Zeta<F>
|
||||
where
|
||||
F: Float,
|
||||
Standard: Distribution<F>,
|
||||
OpenClosed01: Distribution<F>,
|
||||
{
|
||||
s_minus_1: F,
|
||||
b: F,
|
||||
}
|
||||
|
||||
/// Error type returned from [`Zeta::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum Error {
|
||||
/// `s <= 1` or `nan`.
|
||||
STooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
Error::STooSmall => "s <= 1 or is NaN in Zeta distribution",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl<F> Zeta<F>
|
||||
where
|
||||
F: Float,
|
||||
Standard: Distribution<F>,
|
||||
OpenClosed01: Distribution<F>,
|
||||
{
|
||||
/// Construct a new `Zeta` distribution with given `s` parameter.
|
||||
#[inline]
|
||||
pub fn new(s: F) -> Result<Zeta<F>, Error> {
|
||||
if !(s > F::one()) {
|
||||
return Err(Error::STooSmall);
|
||||
}
|
||||
let s_minus_1 = s - F::one();
|
||||
let two = F::one() + F::one();
|
||||
Ok(Zeta {
|
||||
s_minus_1,
|
||||
b: two.powf(s_minus_1),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Distribution<F> for Zeta<F>
|
||||
where
|
||||
F: Float,
|
||||
Standard: Distribution<F>,
|
||||
OpenClosed01: Distribution<F>,
|
||||
{
|
||||
#[inline]
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
loop {
|
||||
let u = rng.sample(OpenClosed01);
|
||||
let x = u.powf(-F::one() / self.s_minus_1).floor();
|
||||
debug_assert!(x >= F::one());
|
||||
if x.is_infinite() {
|
||||
// For sufficiently small `s`, `x` will always be infinite,
|
||||
// which is rejected, resulting in an infinite loop. We avoid
|
||||
// this by always returning infinity instead.
|
||||
return x;
|
||||
}
|
||||
|
||||
let t = (F::one() + F::one() / x).powf(self.s_minus_1);
|
||||
|
||||
let v = rng.sample(Standard);
|
||||
if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
|
||||
let mut rng = crate::test::rng(213);
|
||||
let mut buf = [zero; 4];
|
||||
for x in &mut buf {
|
||||
*x = rng.sample(&distr);
|
||||
}
|
||||
assert_eq!(buf, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn zeta_invalid() {
|
||||
Zeta::new(1.).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn zeta_nan() {
|
||||
Zeta::new(f64::NAN).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeta_sample() {
|
||||
let a = 2.0;
|
||||
let d = Zeta::new(a).unwrap();
|
||||
let mut rng = crate::test::rng(1);
|
||||
for _ in 0..1000 {
|
||||
let r = d.sample(&mut rng);
|
||||
assert!(r >= 1.);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeta_small_a() {
|
||||
let a = 1. + 1e-15;
|
||||
let d = Zeta::new(a).unwrap();
|
||||
let mut rng = crate::test::rng(2);
|
||||
for _ in 0..1000 {
|
||||
let r = d.sample(&mut rng);
|
||||
assert!(r >= 1.);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeta_value_stability() {
|
||||
test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]);
|
||||
test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeta_distributions_can_be_compared() {
|
||||
assert_eq!(Zeta::new(1.0), Zeta::new(1.0));
|
||||
}
|
||||
}
|
@ -6,131 +6,12 @@
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
//! The Zeta and related distributions.
|
||||
//! The Zipf distribution.
|
||||
|
||||
use crate::{Distribution, Standard};
|
||||
use core::fmt;
|
||||
use num_traits::Float;
|
||||
use rand::{distributions::OpenClosed01, Rng};
|
||||
|
||||
/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`.
|
||||
///
|
||||
/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution)
|
||||
/// is a discrete probability distribution with parameter `s`.
|
||||
/// It is a special case of the [`Zipf`] distribution with `n = ∞`.
|
||||
/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or Zipf–Estoup distribution.
|
||||
///
|
||||
/// # Density function
|
||||
///
|
||||
/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the
|
||||
/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function).
|
||||
///
|
||||
/// # Plot
|
||||
///
|
||||
/// The following plot illustrates the zeta distribution for various values of `s`.
|
||||
///
|
||||
/// 
|
||||
///
|
||||
/// # Example
|
||||
/// ```
|
||||
/// use rand::prelude::*;
|
||||
/// use rand_distr::Zeta;
|
||||
///
|
||||
/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap());
|
||||
/// println!("{}", val);
|
||||
/// ```
|
||||
///
|
||||
/// # Notes
|
||||
///
|
||||
/// The zeta distribution has no upper limit. Sampled values may be infinite.
|
||||
/// In particular, a value of infinity might be returned for the following
|
||||
/// reasons:
|
||||
/// 1. it is the best representation in the type `F` of the actual sample.
|
||||
/// 2. to prevent infinite loops for very small `s`.
|
||||
///
|
||||
/// # Implementation details
|
||||
///
|
||||
/// We are using the algorithm from
|
||||
/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8),
|
||||
/// Section 6.1, page 551.
|
||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||
pub struct Zeta<F>
|
||||
where
|
||||
F: Float,
|
||||
Standard: Distribution<F>,
|
||||
OpenClosed01: Distribution<F>,
|
||||
{
|
||||
s_minus_1: F,
|
||||
b: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Zeta::new`.
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum ZetaError {
|
||||
/// `s <= 1` or `nan`.
|
||||
STooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for ZetaError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
ZetaError::STooSmall => "s <= 1 or is NaN in Zeta distribution",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for ZetaError {}
|
||||
|
||||
impl<F> Zeta<F>
|
||||
where
|
||||
F: Float,
|
||||
Standard: Distribution<F>,
|
||||
OpenClosed01: Distribution<F>,
|
||||
{
|
||||
/// Construct a new `Zeta` distribution with given `s` parameter.
|
||||
#[inline]
|
||||
pub fn new(s: F) -> Result<Zeta<F>, ZetaError> {
|
||||
if !(s > F::one()) {
|
||||
return Err(ZetaError::STooSmall);
|
||||
}
|
||||
let s_minus_1 = s - F::one();
|
||||
let two = F::one() + F::one();
|
||||
Ok(Zeta {
|
||||
s_minus_1,
|
||||
b: two.powf(s_minus_1),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<F> Distribution<F> for Zeta<F>
|
||||
where
|
||||
F: Float,
|
||||
Standard: Distribution<F>,
|
||||
OpenClosed01: Distribution<F>,
|
||||
{
|
||||
#[inline]
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||
loop {
|
||||
let u = rng.sample(OpenClosed01);
|
||||
let x = u.powf(-F::one() / self.s_minus_1).floor();
|
||||
debug_assert!(x >= F::one());
|
||||
if x.is_infinite() {
|
||||
// For sufficiently small `s`, `x` will always be infinite,
|
||||
// which is rejected, resulting in an infinite loop. We avoid
|
||||
// this by always returning infinity instead.
|
||||
return x;
|
||||
}
|
||||
|
||||
let t = (F::one() + F::one() / x).powf(self.s_minus_1);
|
||||
|
||||
let v = rng.sample(Standard);
|
||||
if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) {
|
||||
return x;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
use rand::Rng;
|
||||
|
||||
/// The Zipf (Zipfian) distribution `Zipf(n, s)`.
|
||||
///
|
||||
@ -175,26 +56,26 @@ where
|
||||
q: F,
|
||||
}
|
||||
|
||||
/// Error type returned from `Zipf::new`.
|
||||
/// Error type returned from [`Zipf::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum ZipfError {
|
||||
pub enum Error {
|
||||
/// `s < 0` or `nan`.
|
||||
STooSmall,
|
||||
/// `n < 1`.
|
||||
NTooSmall,
|
||||
}
|
||||
|
||||
impl fmt::Display for ZipfError {
|
||||
impl fmt::Display for Error {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(match self {
|
||||
ZipfError::STooSmall => "s < 0 or is NaN in Zipf distribution",
|
||||
ZipfError::NTooSmall => "n < 1 in Zipf distribution",
|
||||
Error::STooSmall => "s < 0 or is NaN in Zipf distribution",
|
||||
Error::NTooSmall => "n < 1 in Zipf distribution",
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature = "std")]
|
||||
impl std::error::Error for ZipfError {}
|
||||
impl std::error::Error for Error {}
|
||||
|
||||
impl<F> Zipf<F>
|
||||
where
|
||||
@ -206,12 +87,12 @@ where
|
||||
///
|
||||
/// For large `n`, rounding may occur to fit the number into the float type.
|
||||
#[inline]
|
||||
pub fn new(n: u64, s: F) -> Result<Zipf<F>, ZipfError> {
|
||||
pub fn new(n: u64, s: F) -> Result<Zipf<F>, Error> {
|
||||
if !(s >= F::zero()) {
|
||||
return Err(ZipfError::STooSmall);
|
||||
return Err(Error::STooSmall);
|
||||
}
|
||||
if n < 1 {
|
||||
return Err(ZipfError::NTooSmall);
|
||||
return Err(Error::NTooSmall);
|
||||
}
|
||||
let n = F::from(n).unwrap(); // This does not fail.
|
||||
let q = if s != F::one() {
|
||||
@ -282,46 +163,6 @@ mod tests {
|
||||
assert_eq!(buf, expected);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn zeta_invalid() {
|
||||
Zeta::new(1.).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn zeta_nan() {
|
||||
Zeta::new(f64::NAN).unwrap();
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeta_sample() {
|
||||
let a = 2.0;
|
||||
let d = Zeta::new(a).unwrap();
|
||||
let mut rng = crate::test::rng(1);
|
||||
for _ in 0..1000 {
|
||||
let r = d.sample(&mut rng);
|
||||
assert!(r >= 1.);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeta_small_a() {
|
||||
let a = 1. + 1e-15;
|
||||
let d = Zeta::new(a).unwrap();
|
||||
let mut rng = crate::test::rng(2);
|
||||
for _ in 0..1000 {
|
||||
let r = d.sample(&mut rng);
|
||||
assert!(r >= 1.);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeta_value_stability() {
|
||||
test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]);
|
||||
test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[should_panic]
|
||||
fn zipf_s_too_small() {
|
||||
@ -392,9 +233,4 @@ mod tests {
|
||||
fn zipf_distributions_can_be_compared() {
|
||||
assert_eq!(Zipf::new(1, 2.0), Zipf::new(1, 2.0));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn zeta_distributions_can_be_compared() {
|
||||
assert_eq!(Zeta::new(1.0), Zeta::new(1.0));
|
||||
}
|
||||
}
|
||||
|
@ -75,7 +75,7 @@ const ALWAYS_TRUE: u64 = u64::MAX;
|
||||
// in `no_std` mode.
|
||||
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
|
||||
|
||||
/// Error type returned from `Bernoulli::new`.
|
||||
/// Error type returned from [`Bernoulli::new`].
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||
pub enum BernoulliError {
|
||||
/// `p < 0` or `p > 1`.
|
||||
|
@ -702,7 +702,7 @@ mod test {
|
||||
}
|
||||
}
|
||||
|
||||
/// Errors returned by weighted distributions
|
||||
/// Errors returned by [`WeightedIndex::new`], [`WeightedIndex::update_weights`] and other weighted distributions
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum WeightError {
|
||||
/// The input weight sequence is empty, too long, or wrongly ordered
|
||||
|
Loading…
x
Reference in New Issue
Block a user