rand_distr: split gamma module (#1464)

Move Beta, Student's t, Fisher-F, Chi-squared and Zeta distributions
to their own modules.
This commit is contained in:
Diggory Hardy 2024-07-11 14:50:22 +01:00 committed by GitHub
parent 763dbc5bbb
commit d17ce4e0a1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
25 changed files with 952 additions and 828 deletions

298
rand_distr/src/beta.rs Normal file
View File

@ -0,0 +1,298 @@
// Copyright 2018 Developers of the Rand project.
// Copyright 2013 The Rust Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Beta distribution.
use crate::{Distribution, Open01};
use core::fmt;
use num_traits::Float;
use rand::Rng;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// The algorithm used for sampling the Beta distribution.
///
/// Reference:
///
/// R. C. H. Cheng (1978).
/// Generating beta variates with nonintegral shape parameters.
/// Communications of the ACM 21, 317-322.
/// https://doi.org/10.1145/359460.359482
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum BetaAlgorithm<N> {
BB(BB<N>),
BC(BC<N>),
}
/// Algorithm BB for `min(alpha, beta) > 1`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BB<N> {
alpha: N,
beta: N,
gamma: N,
}
/// Algorithm BC for `min(alpha, beta) <= 1`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BC<N> {
alpha: N,
beta: N,
kappa1: N,
kappa2: N,
}
/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`.
///
/// The Beta distribution is a continuous probability distribution
/// defined on the interval `[0, 1]`. It is the conjugate prior for the
/// parameter `p` of the [`Binomial`][crate::Binomial] distribution.
///
/// It has two shape parameters `α` (alpha) and `β` (beta) which control
/// the shape of the distribution. Both `a` and `β` must be greater than zero.
/// The distribution is symmetric when `α = β`.
///
/// # Plot
///
/// The plot shows the Beta distribution with various combinations
/// of `α` and `β`.
///
/// ![Beta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/beta.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{Distribution, Beta};
///
/// let beta = Beta::new(2.0, 5.0).unwrap();
/// let v = beta.sample(&mut rand::thread_rng());
/// println!("{} is from a Beta(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Beta<F>
where
F: Float,
Open01: Distribution<F>,
{
a: F,
b: F,
switched_params: bool,
algorithm: BetaAlgorithm<F>,
}
/// Error type returned from [`Beta::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum Error {
/// `alpha <= 0` or `nan`.
AlphaTooSmall,
/// `beta <= 0` or `nan`.
BetaTooSmall,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::AlphaTooSmall => "alpha is not positive in beta distribution",
Error::BetaTooSmall => "beta is not positive in beta distribution",
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> Beta<F>
where
F: Float,
Open01: Distribution<F>,
{
/// Construct an object representing the `Beta(alpha, beta)`
/// distribution.
pub fn new(alpha: F, beta: F) -> Result<Beta<F>, Error> {
if !(alpha > F::zero()) {
return Err(Error::AlphaTooSmall);
}
if !(beta > F::zero()) {
return Err(Error::BetaTooSmall);
}
// From now on, we use the notation from the reference,
// i.e. `alpha` and `beta` are renamed to `a0` and `b0`.
let (a0, b0) = (alpha, beta);
let (a, b, switched_params) = if a0 < b0 {
(a0, b0, false)
} else {
(b0, a0, true)
};
if a > F::one() {
// Algorithm BB
let alpha = a + b;
let two = F::from(2.).unwrap();
let beta_numer = alpha - two;
let beta_denom = two * a * b - alpha;
let beta = (beta_numer / beta_denom).sqrt();
let gamma = a + F::one() / beta;
Ok(Beta {
a,
b,
switched_params,
algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }),
})
} else {
// Algorithm BC
//
// Here `a` is the maximum instead of the minimum.
let (a, b, switched_params) = (b, a, !switched_params);
let alpha = a + b;
let beta = F::one() / b;
let delta = F::one() + a - b;
let kappa1 = delta
* (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b)
/ (a * beta - F::from(14. / 18.).unwrap());
let kappa2 = F::from(0.25).unwrap()
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b;
Ok(Beta {
a,
b,
switched_params,
algorithm: BetaAlgorithm::BC(BC {
alpha,
beta,
kappa1,
kappa2,
}),
})
}
}
}
impl<F> Distribution<F> for Beta<F>
where
F: Float,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let mut w;
match self.algorithm {
BetaAlgorithm::BB(algo) => {
loop {
// 1.
let u1 = rng.sample(Open01);
let u2 = rng.sample(Open01);
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
let z = u1 * u1 * u2;
let r = algo.gamma * v - F::from(4.).unwrap().ln();
let s = self.a + r - w;
// 2.
if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
break;
}
// 3.
let t = z.ln();
if s >= t {
break;
}
// 4.
if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
break;
}
}
}
BetaAlgorithm::BC(algo) => {
loop {
let z;
// 1.
let u1 = rng.sample(Open01);
let u2 = rng.sample(Open01);
if u1 < F::from(0.5).unwrap() {
// 2.
let y = u1 * u2;
z = u1 * y;
if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
continue;
}
} else {
// 3.
z = u1 * u1 * u2;
if z <= F::from(0.25).unwrap() {
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
break;
}
// 4.
if z >= algo.kappa2 {
continue;
}
}
// 5.
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
- F::from(4.).unwrap().ln()
< z.ln())
{
break;
};
}
}
};
// 5. for BB, 6. for BC
if !self.switched_params {
if w == F::infinity() {
// Assuming `b` is finite, for large `w`:
return F::one();
}
w / (self.b + w)
} else {
self.b / (self.b + w)
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_beta() {
let beta = Beta::new(1.0, 2.0).unwrap();
let mut rng = crate::test::rng(201);
for _ in 0..1000 {
beta.sample(&mut rng);
}
}
#[test]
#[should_panic]
fn test_beta_invalid_dof() {
Beta::new(0., 0.).unwrap();
}
#[test]
fn test_beta_small_param() {
let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
let mut rng = crate::test::rng(206);
for i in 0..1000 {
assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i);
}
}
#[test]
fn beta_distributions_can_be_compared() {
assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
}
}

View File

@ -52,7 +52,7 @@ pub struct Binomial {
p: f64,
}
/// Error type returned from `Binomial::new`.
/// Error type returned from [`Binomial::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `p < 0` or `nan`.

View File

@ -64,7 +64,7 @@ where
scale: F,
}
/// Error type returned from `Cauchy::new`.
/// Error type returned from [`Cauchy::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `scale <= 0` or `nan`.

View File

@ -0,0 +1,179 @@
// Copyright 2018 Developers of the Rand project.
// Copyright 2013 The Rust Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Chi-squared distribution.
use self::ChiSquaredRepr::*;
use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal};
use core::fmt;
use num_traits::Float;
use rand::Rng;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`.
///
/// The chi-squared distribution is a continuous probability
/// distribution with parameter `k > 0` degrees of freedom.
///
/// For `k > 0` integral, this distribution is the sum of the squares
/// of `k` independent standard normal random variables. For other
/// `k`, this uses the equivalent characterisation
/// `χ²(k) = Gamma(k/2, 2)`.
///
/// # Plot
///
/// The plot shows the chi-squared distribution with various degrees
/// of freedom.
///
/// ![Chi-squared distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/chi_squared.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{ChiSquared, Distribution};
///
/// let chi = ChiSquared::new(11.0).unwrap();
/// let v = chi.sample(&mut rand::thread_rng());
/// println!("{} is from a χ²(11) distribution", v)
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ChiSquared<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
repr: ChiSquaredRepr<F>,
}
/// Error type returned from [`ChiSquared::new`] and [`StudentT::new`](crate::StudentT::new).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum Error {
/// `0.5 * k <= 0` or `nan`.
DoFTooSmall,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::DoFTooSmall => {
"degrees-of-freedom k is not positive in chi-squared distribution"
}
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for Error {}
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum ChiSquaredRepr<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
// k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1,
// e.g. when alpha = 1/2 as it would be for this case, so special-
// casing and using the definition of N(0,1)^2 is faster.
DoFExactlyOne,
DoFAnythingElse(Gamma<F>),
}
impl<F> ChiSquared<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Create a new chi-squared distribution with degrees-of-freedom
/// `k`.
pub fn new(k: F) -> Result<ChiSquared<F>, Error> {
let repr = if k == F::one() {
DoFExactlyOne
} else {
if !(F::from(0.5).unwrap() * k > F::zero()) {
return Err(Error::DoFTooSmall);
}
DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap())
};
Ok(ChiSquared { repr })
}
}
impl<F> Distribution<F> for ChiSquared<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
match self.repr {
DoFExactlyOne => {
// k == 1 => N(0,1)^2
let norm: F = rng.sample(StandardNormal);
norm * norm
}
DoFAnythingElse(ref g) => g.sample(rng),
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_chi_squared_one() {
let chi = ChiSquared::new(1.0).unwrap();
let mut rng = crate::test::rng(201);
for _ in 0..1000 {
chi.sample(&mut rng);
}
}
#[test]
fn test_chi_squared_small() {
let chi = ChiSquared::new(0.5).unwrap();
let mut rng = crate::test::rng(202);
for _ in 0..1000 {
chi.sample(&mut rng);
}
}
#[test]
fn test_chi_squared_large() {
let chi = ChiSquared::new(30.0).unwrap();
let mut rng = crate::test::rng(203);
for _ in 0..1000 {
chi.sample(&mut rng);
}
}
#[test]
#[should_panic]
fn test_chi_squared_invalid_dof() {
ChiSquared::new(-1.0).unwrap();
}
#[test]
fn gamma_distributions_can_be_compared() {
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
}
#[test]
fn chi_squared_distributions_can_be_compared() {
assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0));
}
}

View File

@ -31,7 +31,7 @@ where
samplers: [Gamma<F>; N],
}
/// Error type returned from `DirchletFromGamma::new`.
/// Error type returned from [`DirchletFromGamma::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromGammaError {
/// Gamma::new(a, 1) failed.
@ -103,7 +103,7 @@ where
samplers: Box<[Beta<F>]>,
}
/// Error type returned from `DirchletFromBeta::new`.
/// Error type returned from [`DirchletFromBeta::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum DirichletFromBetaError {
/// Beta::new(a, b) failed.
@ -226,7 +226,7 @@ where
repr: DirichletRepr<F, N>,
}
/// Error type returned from `Dirchlet::new`.
/// Error type returned from [`Dirichlet::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `alpha.len() < 2`.

View File

@ -130,7 +130,7 @@ where
lambda_inverse: F,
}
/// Error type returned from `Exp::new`.
/// Error type returned from [`Exp::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `lambda < 0` or `nan`.

131
rand_distr/src/fisher_f.rs Normal file
View File

@ -0,0 +1,131 @@
// Copyright 2018 Developers of the Rand project.
// Copyright 2013 The Rust Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Fisher F-distribution.
use crate::{ChiSquared, Distribution, Exp1, Open01, StandardNormal};
use core::fmt;
use num_traits::Float;
use rand::Rng;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`.
///
/// This distribution is equivalent to the ratio of two normalised
/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) /
/// (χ²(n)/n)`.
///
/// # Plot
///
/// The plot shows the F-distribution with various values of `m` and `n`.
///
/// ![F-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/fisher_f.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{FisherF, Distribution};
///
/// let f = FisherF::new(2.0, 32.0).unwrap();
/// let v = f.sample(&mut rand::thread_rng());
/// println!("{} is from an F(2, 32) distribution", v)
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct FisherF<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
numer: ChiSquared<F>,
denom: ChiSquared<F>,
// denom_dof / numer_dof so that this can just be a straight
// multiplication, rather than a division.
dof_ratio: F,
}
/// Error type returned from [`FisherF::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum Error {
/// `m <= 0` or `nan`.
MTooSmall,
/// `n <= 0` or `nan`.
NTooSmall,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::MTooSmall => "m is not positive in Fisher F distribution",
Error::NTooSmall => "n is not positive in Fisher F distribution",
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> FisherF<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Create a new `FisherF` distribution, with the given parameter.
pub fn new(m: F, n: F) -> Result<FisherF<F>, Error> {
let zero = F::zero();
if !(m > zero) {
return Err(Error::MTooSmall);
}
if !(n > zero) {
return Err(Error::NTooSmall);
}
Ok(FisherF {
numer: ChiSquared::new(m).unwrap(),
denom: ChiSquared::new(n).unwrap(),
dof_ratio: n / m,
})
}
}
impl<F> Distribution<F> for FisherF<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_f() {
let f = FisherF::new(2.0, 32.0).unwrap();
let mut rng = crate::test::rng(204);
for _ in 0..1000 {
f.sample(&mut rng);
}
}
#[test]
fn fisher_f_distributions_can_be_compared() {
assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0));
}
}

View File

@ -55,7 +55,7 @@ where
shape: F,
}
/// Error type returned from `Frechet::new`.
/// Error type returned from [`Frechet::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// location is infinite or NaN

View File

@ -7,17 +7,11 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Gamma and derived distributions.
//! The Gamma distribution.
// We use the variable names from the published reference, therefore this
// warning is not helpful.
#![allow(clippy::many_single_char_names)]
use self::ChiSquaredRepr::*;
use self::GammaRepr::*;
use crate::normal::StandardNormal;
use crate::{Distribution, Exp, Exp1, Open01};
use crate::{Distribution, Exp, Exp1, Open01, StandardNormal};
use core::fmt;
use num_traits::Float;
use rand::Rng;
@ -80,7 +74,7 @@ where
repr: GammaRepr<F>,
}
/// Error type returned from `Gamma::new`.
/// Error type returned from [`Gamma::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `shape <= 0` or `nan`.
@ -276,632 +270,12 @@ where
}
}
/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`.
///
/// The chi-squared distribution is a continuous probability
/// distribution with parameter `k > 0` degrees of freedom.
///
/// For `k > 0` integral, this distribution is the sum of the squares
/// of `k` independent standard normal random variables. For other
/// `k`, this uses the equivalent characterisation
/// `χ²(k) = Gamma(k/2, 2)`.
///
/// # Plot
///
/// The plot shows the chi-squared distribution with various degrees
/// of freedom.
///
/// ![Chi-squared distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/chi_squared.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{ChiSquared, Distribution};
///
/// let chi = ChiSquared::new(11.0).unwrap();
/// let v = chi.sample(&mut rand::thread_rng());
/// println!("{} is from a χ²(11) distribution", v)
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct ChiSquared<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
repr: ChiSquaredRepr<F>,
}
/// Error type returned from `ChiSquared::new` and `StudentT::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum ChiSquaredError {
/// `0.5 * k <= 0` or `nan`.
DoFTooSmall,
}
impl fmt::Display for ChiSquaredError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
ChiSquaredError::DoFTooSmall => {
"degrees-of-freedom k is not positive in chi-squared distribution"
}
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for ChiSquaredError {}
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum ChiSquaredRepr<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
// k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1,
// e.g. when alpha = 1/2 as it would be for this case, so special-
// casing and using the definition of N(0,1)^2 is faster.
DoFExactlyOne,
DoFAnythingElse(Gamma<F>),
}
impl<F> ChiSquared<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Create a new chi-squared distribution with degrees-of-freedom
/// `k`.
pub fn new(k: F) -> Result<ChiSquared<F>, ChiSquaredError> {
let repr = if k == F::one() {
DoFExactlyOne
} else {
if !(F::from(0.5).unwrap() * k > F::zero()) {
return Err(ChiSquaredError::DoFTooSmall);
}
DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap())
};
Ok(ChiSquared { repr })
}
}
impl<F> Distribution<F> for ChiSquared<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
match self.repr {
DoFExactlyOne => {
// k == 1 => N(0,1)^2
let norm: F = rng.sample(StandardNormal);
norm * norm
}
DoFAnythingElse(ref g) => g.sample(rng),
}
}
}
/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`.
///
/// This distribution is equivalent to the ratio of two normalised
/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) /
/// (χ²(n)/n)`.
///
/// # Plot
///
/// The plot shows the F-distribution with various values of `m` and `n`.
///
/// ![F-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/fisher_f.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{FisherF, Distribution};
///
/// let f = FisherF::new(2.0, 32.0).unwrap();
/// let v = f.sample(&mut rand::thread_rng());
/// println!("{} is from an F(2, 32) distribution", v)
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct FisherF<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
numer: ChiSquared<F>,
denom: ChiSquared<F>,
// denom_dof / numer_dof so that this can just be a straight
// multiplication, rather than a division.
dof_ratio: F,
}
/// Error type returned from `FisherF::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum FisherFError {
/// `m <= 0` or `nan`.
MTooSmall,
/// `n <= 0` or `nan`.
NTooSmall,
}
impl fmt::Display for FisherFError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
FisherFError::MTooSmall => "m is not positive in Fisher F distribution",
FisherFError::NTooSmall => "n is not positive in Fisher F distribution",
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for FisherFError {}
impl<F> FisherF<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Create a new `FisherF` distribution, with the given parameter.
pub fn new(m: F, n: F) -> Result<FisherF<F>, FisherFError> {
let zero = F::zero();
if !(m > zero) {
return Err(FisherFError::MTooSmall);
}
if !(n > zero) {
return Err(FisherFError::NTooSmall);
}
Ok(FisherF {
numer: ChiSquared::new(m).unwrap(),
denom: ChiSquared::new(n).unwrap(),
dof_ratio: n / m,
})
}
}
impl<F> Distribution<F> for FisherF<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio
}
}
/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`.
///
/// The t-distribution is a continuous probability distribution
/// parameterized by degrees of freedom `ν` (`nu`), which
/// arises when estimating the mean of a normally-distributed
/// population in situations where the sample size is small and
/// the population's standard deviation is unknown.
/// It is widely used in hypothesis testing.
///
/// For `ν = 1`, this is equivalent to the standard
/// [`Cauchy`](crate::Cauchy) distribution,
/// and as `ν` diverges to infinity, `t(ν)` converges to
/// [`StandardNormal`](crate::StandardNormal).
///
/// # Plot
///
/// The plot shows the t-distribution with various degrees of freedom.
///
/// ![T-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/student_t.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{StudentT, Distribution};
///
/// let t = StudentT::new(11.0).unwrap();
/// let v = t.sample(&mut rand::thread_rng());
/// println!("{} is from a t(11) distribution", v)
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StudentT<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
chi: ChiSquared<F>,
dof: F,
}
impl<F> StudentT<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Create a new Student t-distribution with `ν` (nu)
/// degrees of freedom.
pub fn new(nu: F) -> Result<StudentT<F>, ChiSquaredError> {
Ok(StudentT {
chi: ChiSquared::new(nu)?,
dof: nu,
})
}
}
impl<F> Distribution<F> for StudentT<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let norm: F = rng.sample(StandardNormal);
norm * (self.dof / self.chi.sample(rng)).sqrt()
}
}
/// The algorithm used for sampling the Beta distribution.
///
/// Reference:
///
/// R. C. H. Cheng (1978).
/// Generating beta variates with nonintegral shape parameters.
/// Communications of the ACM 21, 317-322.
/// https://doi.org/10.1145/359460.359482
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum BetaAlgorithm<N> {
BB(BB<N>),
BC(BC<N>),
}
/// Algorithm BB for `min(alpha, beta) > 1`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BB<N> {
alpha: N,
beta: N,
gamma: N,
}
/// Algorithm BC for `min(alpha, beta) <= 1`.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
struct BC<N> {
alpha: N,
beta: N,
kappa1: N,
kappa2: N,
}
/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`.
///
/// The Beta distribution is a continuous probability distribution
/// defined on the interval `[0, 1]`. It is the conjugate prior for the
/// parameter `p` of the [`Binomial`][crate::Binomial] distribution.
///
/// It has two shape parameters `α` (alpha) and `β` (beta) which control
/// the shape of the distribution. Both `a` and `β` must be greater than zero.
/// The distribution is symmetric when `α = β`.
///
/// # Plot
///
/// The plot shows the Beta distribution with various combinations
/// of `α` and `β`.
///
/// ![Beta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/beta.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{Distribution, Beta};
///
/// let beta = Beta::new(2.0, 5.0).unwrap();
/// let v = beta.sample(&mut rand::thread_rng());
/// println!("{} is from a Beta(2, 5) distribution", v);
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Beta<F>
where
F: Float,
Open01: Distribution<F>,
{
a: F,
b: F,
switched_params: bool,
algorithm: BetaAlgorithm<F>,
}
/// Error type returned from `Beta::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub enum BetaError {
/// `alpha <= 0` or `nan`.
AlphaTooSmall,
/// `beta <= 0` or `nan`.
BetaTooSmall,
}
impl fmt::Display for BetaError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
BetaError::AlphaTooSmall => "alpha is not positive in beta distribution",
BetaError::BetaTooSmall => "beta is not positive in beta distribution",
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for BetaError {}
impl<F> Beta<F>
where
F: Float,
Open01: Distribution<F>,
{
/// Construct an object representing the `Beta(alpha, beta)`
/// distribution.
pub fn new(alpha: F, beta: F) -> Result<Beta<F>, BetaError> {
if !(alpha > F::zero()) {
return Err(BetaError::AlphaTooSmall);
}
if !(beta > F::zero()) {
return Err(BetaError::BetaTooSmall);
}
// From now on, we use the notation from the reference,
// i.e. `alpha` and `beta` are renamed to `a0` and `b0`.
let (a0, b0) = (alpha, beta);
let (a, b, switched_params) = if a0 < b0 {
(a0, b0, false)
} else {
(b0, a0, true)
};
if a > F::one() {
// Algorithm BB
let alpha = a + b;
let two = F::from(2.).unwrap();
let beta_numer = alpha - two;
let beta_denom = two * a * b - alpha;
let beta = (beta_numer / beta_denom).sqrt();
let gamma = a + F::one() / beta;
Ok(Beta {
a,
b,
switched_params,
algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }),
})
} else {
// Algorithm BC
//
// Here `a` is the maximum instead of the minimum.
let (a, b, switched_params) = (b, a, !switched_params);
let alpha = a + b;
let beta = F::one() / b;
let delta = F::one() + a - b;
let kappa1 = delta
* (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b)
/ (a * beta - F::from(14. / 18.).unwrap());
let kappa2 = F::from(0.25).unwrap()
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b;
Ok(Beta {
a,
b,
switched_params,
algorithm: BetaAlgorithm::BC(BC {
alpha,
beta,
kappa1,
kappa2,
}),
})
}
}
}
impl<F> Distribution<F> for Beta<F>
where
F: Float,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let mut w;
match self.algorithm {
BetaAlgorithm::BB(algo) => {
loop {
// 1.
let u1 = rng.sample(Open01);
let u2 = rng.sample(Open01);
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
let z = u1 * u1 * u2;
let r = algo.gamma * v - F::from(4.).unwrap().ln();
let s = self.a + r - w;
// 2.
if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
break;
}
// 3.
let t = z.ln();
if s >= t {
break;
}
// 4.
if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) {
break;
}
}
}
BetaAlgorithm::BC(algo) => {
loop {
let z;
// 1.
let u1 = rng.sample(Open01);
let u2 = rng.sample(Open01);
if u1 < F::from(0.5).unwrap() {
// 2.
let y = u1 * u2;
z = u1 * y;
if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 {
continue;
}
} else {
// 3.
z = u1 * u1 * u2;
if z <= F::from(0.25).unwrap() {
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
break;
}
// 4.
if z >= algo.kappa2 {
continue;
}
}
// 5.
let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp();
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
- F::from(4.).unwrap().ln()
< z.ln())
{
break;
};
}
}
};
// 5. for BB, 6. for BC
if !self.switched_params {
if w == F::infinity() {
// Assuming `b` is finite, for large `w`:
return F::one();
}
w / (self.b + w)
} else {
self.b / (self.b + w)
}
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_chi_squared_one() {
let chi = ChiSquared::new(1.0).unwrap();
let mut rng = crate::test::rng(201);
for _ in 0..1000 {
chi.sample(&mut rng);
}
}
#[test]
fn test_chi_squared_small() {
let chi = ChiSquared::new(0.5).unwrap();
let mut rng = crate::test::rng(202);
for _ in 0..1000 {
chi.sample(&mut rng);
}
}
#[test]
fn test_chi_squared_large() {
let chi = ChiSquared::new(30.0).unwrap();
let mut rng = crate::test::rng(203);
for _ in 0..1000 {
chi.sample(&mut rng);
}
}
#[test]
#[should_panic]
fn test_chi_squared_invalid_dof() {
ChiSquared::new(-1.0).unwrap();
}
#[test]
fn test_f() {
let f = FisherF::new(2.0, 32.0).unwrap();
let mut rng = crate::test::rng(204);
for _ in 0..1000 {
f.sample(&mut rng);
}
}
#[test]
fn test_t() {
let t = StudentT::new(11.0).unwrap();
let mut rng = crate::test::rng(205);
for _ in 0..1000 {
t.sample(&mut rng);
}
}
#[test]
fn test_beta() {
let beta = Beta::new(1.0, 2.0).unwrap();
let mut rng = crate::test::rng(201);
for _ in 0..1000 {
beta.sample(&mut rng);
}
}
#[test]
#[should_panic]
fn test_beta_invalid_dof() {
Beta::new(0., 0.).unwrap();
}
#[test]
fn test_beta_small_param() {
let beta = Beta::<f64>::new(1e-3, 1e-3).unwrap();
let mut rng = crate::test::rng(206);
for i in 0..1000 {
assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i);
}
}
#[test]
fn gamma_distributions_can_be_compared() {
assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0));
}
#[test]
fn beta_distributions_can_be_compared() {
assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0));
}
#[test]
fn chi_squared_distributions_can_be_compared() {
assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0));
}
#[test]
fn fisher_f_distributions_can_be_compared() {
assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0));
}
#[test]
fn student_t_distributions_can_be_compared() {
assert_eq!(StudentT::new(1.0), StudentT::new(1.0));
}
}

View File

@ -46,7 +46,7 @@ pub struct Geometric {
k: u64,
}
/// Error type returned from `Geometric::new`.
/// Error type returned from [`Geometric::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `p < 0 || p > 1` or `nan`

View File

@ -52,7 +52,7 @@ where
scale: F,
}
/// Error type returned from `Gumbel::new`.
/// Error type returned from [`Gumbel::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// location is infinite or NaN

View File

@ -68,7 +68,7 @@ pub struct Hypergeometric {
sampling_method: SamplingMethod,
}
/// Error type returned from `Hypergeometric::new`.
/// Error type returned from [`Hypergeometric::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `total_population_size` is too large, causing floating point underflow.

View File

@ -5,7 +5,7 @@ use core::fmt;
use num_traits::Float;
use rand::Rng;
/// Error type returned from `InverseGaussian::new`
/// Error type returned from [`InverseGaussian::new`]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
/// `mean <= 0` or `nan`.

View File

@ -98,16 +98,16 @@ pub use rand::distributions::{
Standard, Uniform,
};
pub use self::beta::{Beta, Error as BetaError};
pub use self::binomial::{Binomial, Error as BinomialError};
pub use self::cauchy::{Cauchy, Error as CauchyError};
pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError};
#[cfg(feature = "alloc")]
pub use self::dirichlet::{Dirichlet, Error as DirichletError};
pub use self::exponential::{Error as ExpError, Exp, Exp1};
pub use self::fisher_f::{Error as FisherFError, FisherF};
pub use self::frechet::{Error as FrechetError, Frechet};
pub use self::gamma::{
Beta, BetaError, ChiSquared, ChiSquaredError, Error as GammaError, FisherF, FisherFError,
Gamma, StudentT,
};
pub use self::gamma::{Error as GammaError, Gamma};
pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric};
pub use self::gumbel::{Error as GumbelError, Gumbel};
pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric};
@ -126,9 +126,11 @@ pub use self::unit_circle::UnitCircle;
pub use self::unit_disc::UnitDisc;
pub use self::unit_sphere::UnitSphere;
pub use self::weibull::{Error as WeibullError, Weibull};
pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError};
pub use self::zeta::{Error as ZetaError, Zeta};
pub use self::zipf::{Error as ZipfError, Zipf};
#[cfg(feature = "alloc")]
pub use rand::distributions::{WeightError, WeightedIndex};
pub use student_t::StudentT;
#[cfg(feature = "alloc")]
pub use weighted_alias::WeightedAliasIndex;
#[cfg(feature = "alloc")]
@ -192,10 +194,13 @@ pub mod weighted_alias;
#[cfg(feature = "alloc")]
pub mod weighted_tree;
mod beta;
mod binomial;
mod cauchy;
mod chi_squared;
mod dirichlet;
mod exponential;
mod fisher_f;
mod frechet;
mod gamma;
mod geometric;
@ -208,6 +213,7 @@ mod pareto;
mod pert;
mod poisson;
mod skew_normal;
mod student_t;
mod triangular;
mod unit_ball;
mod unit_circle;
@ -215,5 +221,6 @@ mod unit_disc;
mod unit_sphere;
mod utils;
mod weibull;
mod zeta;
mod ziggurat_tables;
mod zipf;

View File

@ -153,7 +153,7 @@ where
std_dev: F,
}
/// Error type returned from `Normal::new` and `LogNormal::new`.
/// Error type returned from [`Normal::new`] and [`LogNormal::new`](crate::LogNormal::new).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// The mean value is too small (log-normal samples must be positive)

View File

@ -3,7 +3,7 @@ use core::fmt;
use num_traits::Float;
use rand::Rng;
/// Error type returned from `NormalInverseGaussian::new`
/// Error type returned from [`NormalInverseGaussian::new`]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
/// `alpha <= 0` or `nan`.

View File

@ -46,7 +46,7 @@ where
inv_neg_shape: F,
}
/// Error type returned from `Pareto::new`.
/// Error type returned from [`Pareto::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `scale <= 0` or `nan`.

View File

@ -54,7 +54,7 @@ where
magic_val: F,
}
/// Error type returned from `Poisson::new`.
/// Error type returned from [`Poisson::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `lambda <= 0`

View File

@ -68,7 +68,7 @@ where
shape: F,
}
/// Error type returned from `SkewNormal::new`.
/// Error type returned from [`SkewNormal::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// The scale parameter is not finite or it is less or equal to zero.

107
rand_distr/src/student_t.rs Normal file
View File

@ -0,0 +1,107 @@
// Copyright 2018 Developers of the Rand project.
// Copyright 2013 The Rust Project Developers.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Student's t-distribution.
use crate::{ChiSquared, ChiSquaredError};
use crate::{Distribution, Exp1, Open01, StandardNormal};
use num_traits::Float;
use rand::Rng;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`.
///
/// The t-distribution is a continuous probability distribution
/// parameterized by degrees of freedom `ν` (`nu`), which
/// arises when estimating the mean of a normally-distributed
/// population in situations where the sample size is small and
/// the population's standard deviation is unknown.
/// It is widely used in hypothesis testing.
///
/// For `ν = 1`, this is equivalent to the standard
/// [`Cauchy`](crate::Cauchy) distribution,
/// and as `ν` diverges to infinity, `t(ν)` converges to
/// [`StandardNormal`](crate::StandardNormal).
///
/// # Plot
///
/// The plot shows the t-distribution with various degrees of freedom.
///
/// ![T-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/student_t.svg)
///
/// # Example
///
/// ```
/// use rand_distr::{StudentT, Distribution};
///
/// let t = StudentT::new(11.0).unwrap();
/// let v = t.sample(&mut rand::thread_rng());
/// println!("{} is from a t(11) distribution", v)
/// ```
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct StudentT<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
chi: ChiSquared<F>,
dof: F,
}
impl<F> StudentT<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Create a new Student t-distribution with `ν` (nu)
/// degrees of freedom.
pub fn new(nu: F) -> Result<StudentT<F>, ChiSquaredError> {
Ok(StudentT {
chi: ChiSquared::new(nu)?,
dof: nu,
})
}
}
impl<F> Distribution<F> for StudentT<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let norm: F = rng.sample(StandardNormal);
norm * (self.dof / self.chi.sample(rng)).sqrt()
}
}
#[cfg(test)]
mod test {
use super::*;
#[test]
fn test_t() {
let t = StudentT::new(11.0).unwrap();
let mut rng = crate::test::rng(205);
for _ in 0..1000 {
t.sample(&mut rng);
}
}
#[test]
fn student_t_distributions_can_be_compared() {
assert_eq!(StudentT::new(1.0), StudentT::new(1.0));
}
}

View File

@ -44,7 +44,7 @@ where
scale: F,
}
/// Error type returned from `Weibull::new`.
/// Error type returned from [`Weibull::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `scale <= 0` or `nan`.

192
rand_distr/src/zeta.rs Normal file
View File

@ -0,0 +1,192 @@
// Copyright 2021 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Zeta distribution.
use crate::{Distribution, Standard};
use core::fmt;
use num_traits::Float;
use rand::{distributions::OpenClosed01, Rng};
/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`.
///
/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution)
/// is a discrete probability distribution with parameter `s`.
/// It is a special case of the [`Zipf`](crate::Zipf) distribution with `n = ∞`.
/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or ZipfEstoup distribution.
///
/// # Density function
///
/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the
/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function).
///
/// # Plot
///
/// The following plot illustrates the zeta distribution for various values of `s`.
///
/// ![Zeta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zeta.svg)
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::Zeta;
///
/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap());
/// println!("{}", val);
/// ```
///
/// # Notes
///
/// The zeta distribution has no upper limit. Sampled values may be infinite.
/// In particular, a value of infinity might be returned for the following
/// reasons:
/// 1. it is the best representation in the type `F` of the actual sample.
/// 2. to prevent infinite loops for very small `s`.
///
/// # Implementation details
///
/// We are using the algorithm from
/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8),
/// Section 6.1, page 551.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Zeta<F>
where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{
s_minus_1: F,
b: F,
}
/// Error type returned from [`Zeta::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
/// `s <= 1` or `nan`.
STooSmall,
}
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Error::STooSmall => "s <= 1 or is NaN in Zeta distribution",
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> Zeta<F>
where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{
/// Construct a new `Zeta` distribution with given `s` parameter.
#[inline]
pub fn new(s: F) -> Result<Zeta<F>, Error> {
if !(s > F::one()) {
return Err(Error::STooSmall);
}
let s_minus_1 = s - F::one();
let two = F::one() + F::one();
Ok(Zeta {
s_minus_1,
b: two.powf(s_minus_1),
})
}
}
impl<F> Distribution<F> for Zeta<F>
where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
loop {
let u = rng.sample(OpenClosed01);
let x = u.powf(-F::one() / self.s_minus_1).floor();
debug_assert!(x >= F::one());
if x.is_infinite() {
// For sufficiently small `s`, `x` will always be infinite,
// which is rejected, resulting in an infinite loop. We avoid
// this by always returning infinity instead.
return x;
}
let t = (F::one() + F::one() / x).powf(self.s_minus_1);
let v = rng.sample(Standard);
if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) {
return x;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
let mut rng = crate::test::rng(213);
let mut buf = [zero; 4];
for x in &mut buf {
*x = rng.sample(&distr);
}
assert_eq!(buf, expected);
}
#[test]
#[should_panic]
fn zeta_invalid() {
Zeta::new(1.).unwrap();
}
#[test]
#[should_panic]
fn zeta_nan() {
Zeta::new(f64::NAN).unwrap();
}
#[test]
fn zeta_sample() {
let a = 2.0;
let d = Zeta::new(a).unwrap();
let mut rng = crate::test::rng(1);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
}
#[test]
fn zeta_small_a() {
let a = 1. + 1e-15;
let d = Zeta::new(a).unwrap();
let mut rng = crate::test::rng(2);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
}
#[test]
fn zeta_value_stability() {
test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]);
test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]);
}
#[test]
fn zeta_distributions_can_be_compared() {
assert_eq!(Zeta::new(1.0), Zeta::new(1.0));
}
}

View File

@ -6,131 +6,12 @@
// option. This file may not be copied, modified, or distributed
// except according to those terms.
//! The Zeta and related distributions.
//! The Zipf distribution.
use crate::{Distribution, Standard};
use core::fmt;
use num_traits::Float;
use rand::{distributions::OpenClosed01, Rng};
/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`.
///
/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution)
/// is a discrete probability distribution with parameter `s`.
/// It is a special case of the [`Zipf`] distribution with `n = ∞`.
/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or ZipfEstoup distribution.
///
/// # Density function
///
/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the
/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function).
///
/// # Plot
///
/// The following plot illustrates the zeta distribution for various values of `s`.
///
/// ![Zeta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zeta.svg)
///
/// # Example
/// ```
/// use rand::prelude::*;
/// use rand_distr::Zeta;
///
/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap());
/// println!("{}", val);
/// ```
///
/// # Notes
///
/// The zeta distribution has no upper limit. Sampled values may be infinite.
/// In particular, a value of infinity might be returned for the following
/// reasons:
/// 1. it is the best representation in the type `F` of the actual sample.
/// 2. to prevent infinite loops for very small `s`.
///
/// # Implementation details
///
/// We are using the algorithm from
/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8),
/// Section 6.1, page 551.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Zeta<F>
where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{
s_minus_1: F,
b: F,
}
/// Error type returned from `Zeta::new`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZetaError {
/// `s <= 1` or `nan`.
STooSmall,
}
impl fmt::Display for ZetaError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
ZetaError::STooSmall => "s <= 1 or is NaN in Zeta distribution",
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for ZetaError {}
impl<F> Zeta<F>
where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{
/// Construct a new `Zeta` distribution with given `s` parameter.
#[inline]
pub fn new(s: F) -> Result<Zeta<F>, ZetaError> {
if !(s > F::one()) {
return Err(ZetaError::STooSmall);
}
let s_minus_1 = s - F::one();
let two = F::one() + F::one();
Ok(Zeta {
s_minus_1,
b: two.powf(s_minus_1),
})
}
}
impl<F> Distribution<F> for Zeta<F>
where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
loop {
let u = rng.sample(OpenClosed01);
let x = u.powf(-F::one() / self.s_minus_1).floor();
debug_assert!(x >= F::one());
if x.is_infinite() {
// For sufficiently small `s`, `x` will always be infinite,
// which is rejected, resulting in an infinite loop. We avoid
// this by always returning infinity instead.
return x;
}
let t = (F::one() + F::one() / x).powf(self.s_minus_1);
let v = rng.sample(Standard);
if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) {
return x;
}
}
}
}
use rand::Rng;
/// The Zipf (Zipfian) distribution `Zipf(n, s)`.
///
@ -175,26 +56,26 @@ where
q: F,
}
/// Error type returned from `Zipf::new`.
/// Error type returned from [`Zipf::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ZipfError {
pub enum Error {
/// `s < 0` or `nan`.
STooSmall,
/// `n < 1`.
NTooSmall,
}
impl fmt::Display for ZipfError {
impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
ZipfError::STooSmall => "s < 0 or is NaN in Zipf distribution",
ZipfError::NTooSmall => "n < 1 in Zipf distribution",
Error::STooSmall => "s < 0 or is NaN in Zipf distribution",
Error::NTooSmall => "n < 1 in Zipf distribution",
})
}
}
#[cfg(feature = "std")]
impl std::error::Error for ZipfError {}
impl std::error::Error for Error {}
impl<F> Zipf<F>
where
@ -206,12 +87,12 @@ where
///
/// For large `n`, rounding may occur to fit the number into the float type.
#[inline]
pub fn new(n: u64, s: F) -> Result<Zipf<F>, ZipfError> {
pub fn new(n: u64, s: F) -> Result<Zipf<F>, Error> {
if !(s >= F::zero()) {
return Err(ZipfError::STooSmall);
return Err(Error::STooSmall);
}
if n < 1 {
return Err(ZipfError::NTooSmall);
return Err(Error::NTooSmall);
}
let n = F::from(n).unwrap(); // This does not fail.
let q = if s != F::one() {
@ -282,46 +163,6 @@ mod tests {
assert_eq!(buf, expected);
}
#[test]
#[should_panic]
fn zeta_invalid() {
Zeta::new(1.).unwrap();
}
#[test]
#[should_panic]
fn zeta_nan() {
Zeta::new(f64::NAN).unwrap();
}
#[test]
fn zeta_sample() {
let a = 2.0;
let d = Zeta::new(a).unwrap();
let mut rng = crate::test::rng(1);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
}
#[test]
fn zeta_small_a() {
let a = 1. + 1e-15;
let d = Zeta::new(a).unwrap();
let mut rng = crate::test::rng(2);
for _ in 0..1000 {
let r = d.sample(&mut rng);
assert!(r >= 1.);
}
}
#[test]
fn zeta_value_stability() {
test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]);
test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]);
}
#[test]
#[should_panic]
fn zipf_s_too_small() {
@ -392,9 +233,4 @@ mod tests {
fn zipf_distributions_can_be_compared() {
assert_eq!(Zipf::new(1, 2.0), Zipf::new(1, 2.0));
}
#[test]
fn zeta_distributions_can_be_compared() {
assert_eq!(Zeta::new(1.0), Zeta::new(1.0));
}
}

View File

@ -75,7 +75,7 @@ const ALWAYS_TRUE: u64 = u64::MAX;
// in `no_std` mode.
const SCALE: f64 = 2.0 * (1u64 << 63) as f64;
/// Error type returned from `Bernoulli::new`.
/// Error type returned from [`Bernoulli::new`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum BernoulliError {
/// `p < 0` or `p > 1`.

View File

@ -702,7 +702,7 @@ mod test {
}
}
/// Errors returned by weighted distributions
/// Errors returned by [`WeightedIndex::new`], [`WeightedIndex::update_weights`] and other weighted distributions
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightError {
/// The input weight sequence is empty, too long, or wrongly ordered