diff --git a/rand_distr/src/beta.rs b/rand_distr/src/beta.rs new file mode 100644 index 00000000..9ef5d002 --- /dev/null +++ b/rand_distr/src/beta.rs @@ -0,0 +1,298 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Beta distribution. + +use crate::{Distribution, Open01}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +/// The algorithm used for sampling the Beta distribution. +/// +/// Reference: +/// +/// R. C. H. Cheng (1978). +/// Generating beta variates with nonintegral shape parameters. +/// Communications of the ACM 21, 317-322. +/// https://doi.org/10.1145/359460.359482 +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +enum BetaAlgorithm { + BB(BB), + BC(BC), +} + +/// Algorithm BB for `min(alpha, beta) > 1`. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +struct BB { + alpha: N, + beta: N, + gamma: N, +} + +/// Algorithm BC for `min(alpha, beta) <= 1`. +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +struct BC { + alpha: N, + beta: N, + kappa1: N, + kappa2: N, +} + +/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`. +/// +/// The Beta distribution is a continuous probability distribution +/// defined on the interval `[0, 1]`. It is the conjugate prior for the +/// parameter `p` of the [`Binomial`][crate::Binomial] distribution. +/// +/// It has two shape parameters `α` (alpha) and `β` (beta) which control +/// the shape of the distribution. Both `α` and `β` must be greater than zero. 
+/// The distribution is symmetric when `α = β`. +/// +/// # Plot +/// +/// The plot shows the Beta distribution with various combinations +/// of `α` and `β`. +/// +/// ![Beta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/beta.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{Distribution, Beta}; +/// +/// let beta = Beta::new(2.0, 5.0).unwrap(); +/// let v = beta.sample(&mut rand::thread_rng()); +/// println!("{} is from a Beta(2, 5) distribution", v); +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct Beta +where + F: Float, + Open01: Distribution, +{ + a: F, + b: F, + switched_params: bool, + algorithm: BetaAlgorithm, +} + +/// Error type returned from [`Beta::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub enum Error { + /// `alpha <= 0` or `nan`. + AlphaTooSmall, + /// `beta <= 0` or `nan`. + BetaTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::AlphaTooSmall => "alpha is not positive in beta distribution", + Error::BetaTooSmall => "beta is not positive in beta distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Beta +where + F: Float, + Open01: Distribution, +{ + /// Construct an object representing the `Beta(alpha, beta)` + /// distribution. + pub fn new(alpha: F, beta: F) -> Result, Error> { + if !(alpha > F::zero()) { + return Err(Error::AlphaTooSmall); + } + if !(beta > F::zero()) { + return Err(Error::BetaTooSmall); + } + // From now on, we use the notation from the reference, + // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. 
+ let (a0, b0) = (alpha, beta); + let (a, b, switched_params) = if a0 < b0 { + (a0, b0, false) + } else { + (b0, a0, true) + }; + if a > F::one() { + // Algorithm BB + let alpha = a + b; + + let two = F::from(2.).unwrap(); + let beta_numer = alpha - two; + let beta_denom = two * a * b - alpha; + let beta = (beta_numer / beta_denom).sqrt(); + + let gamma = a + F::one() / beta; + + Ok(Beta { + a, + b, + switched_params, + algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }), + }) + } else { + // Algorithm BC + // + // Here `a` is the maximum instead of the minimum. + let (a, b, switched_params) = (b, a, !switched_params); + let alpha = a + b; + let beta = F::one() / b; + let delta = F::one() + a - b; + let kappa1 = delta + * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b) + / (a * beta - F::from(14. / 18.).unwrap()); + let kappa2 = F::from(0.25).unwrap() + + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b; + + Ok(Beta { + a, + b, + switched_params, + algorithm: BetaAlgorithm::BC(BC { + alpha, + beta, + kappa1, + kappa2, + }), + }) + } + } +} + +impl Distribution for Beta +where + F: Float, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let mut w; + match self.algorithm { + BetaAlgorithm::BB(algo) => { + loop { + // 1. + let u1 = rng.sample(Open01); + let u2 = rng.sample(Open01); + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + let z = u1 * u1 * u2; + let r = algo.gamma * v - F::from(4.).unwrap().ln(); + let s = self.a + r - w; + // 2. + if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z { + break; + } + // 3. + let t = z.ln(); + if s >= t { + break; + } + // 4. + if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { + break; + } + } + } + BetaAlgorithm::BC(algo) => { + loop { + let z; + // 1. + let u1 = rng.sample(Open01); + let u2 = rng.sample(Open01); + if u1 < F::from(0.5).unwrap() { + // 2. 
+ let y = u1 * u2; + z = u1 * y; + if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { + continue; + } + } else { + // 3. + z = u1 * u1 * u2; + if z <= F::from(0.25).unwrap() { + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + break; + } + // 4. + if z >= algo.kappa2 { + continue; + } + } + // 5. + let v = algo.beta * (u1 / (F::one() - u1)).ln(); + w = self.a * v.exp(); + if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) + - F::from(4.).unwrap().ln() + < z.ln()) + { + break; + }; + } + } + }; + // 5. for BB, 6. for BC + if !self.switched_params { + if w == F::infinity() { + // Assuming `b` is finite, for large `w`: + return F::one(); + } + w / (self.b + w) + } else { + self.b / (self.b + w) + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_beta() { + let beta = Beta::new(1.0, 2.0).unwrap(); + let mut rng = crate::test::rng(201); + for _ in 0..1000 { + beta.sample(&mut rng); + } + } + + #[test] + #[should_panic] + fn test_beta_invalid_dof() { + Beta::new(0., 0.).unwrap(); + } + + #[test] + fn test_beta_small_param() { + let beta = Beta::::new(1e-3, 1e-3).unwrap(); + let mut rng = crate::test::rng(206); + for i in 0..1000 { + assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); + } + } + + #[test] + fn beta_distributions_can_be_compared() { + assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0)); + } +} diff --git a/rand_distr/src/binomial.rs b/rand_distr/src/binomial.rs index 514dbeca..02f7fc37 100644 --- a/rand_distr/src/binomial.rs +++ b/rand_distr/src/binomial.rs @@ -52,7 +52,7 @@ pub struct Binomial { p: f64, } -/// Error type returned from `Binomial::new`. +/// Error type returned from [`Binomial::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `p < 0` or `nan`. 
diff --git a/rand_distr/src/cauchy.rs b/rand_distr/src/cauchy.rs index 6d4ff4ec..1c6f9123 100644 --- a/rand_distr/src/cauchy.rs +++ b/rand_distr/src/cauchy.rs @@ -64,7 +64,7 @@ where scale: F, } -/// Error type returned from `Cauchy::new`. +/// Error type returned from [`Cauchy::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. diff --git a/rand_distr/src/chi_squared.rs b/rand_distr/src/chi_squared.rs new file mode 100644 index 00000000..fcdc397b --- /dev/null +++ b/rand_distr/src/chi_squared.rs @@ -0,0 +1,179 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Chi-squared distribution. + +use self::ChiSquaredRepr::*; + +use crate::{Distribution, Exp1, Gamma, Open01, StandardNormal}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`. +/// +/// The chi-squared distribution is a continuous probability +/// distribution with parameter `k > 0` degrees of freedom. +/// +/// For `k > 0` integral, this distribution is the sum of the squares +/// of `k` independent standard normal random variables. For other +/// `k`, this uses the equivalent characterisation +/// `χ²(k) = Gamma(k/2, 2)`. +/// +/// # Plot +/// +/// The plot shows the chi-squared distribution with various degrees +/// of freedom. 
+/// +/// ![Chi-squared distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/chi_squared.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{ChiSquared, Distribution}; +/// +/// let chi = ChiSquared::new(11.0).unwrap(); +/// let v = chi.sample(&mut rand::thread_rng()); +/// println!("{} is from a χ²(11) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + repr: ChiSquaredRepr, +} + +/// Error type returned from [`ChiSquared::new`] and [`StudentT::new`](crate::StudentT::new). +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub enum Error { + /// `0.5 * k <= 0` or `nan`. + DoFTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::DoFTooSmall => { + "degrees-of-freedom k is not positive in chi-squared distribution" + } + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +enum ChiSquaredRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, + // e.g. when alpha = 1/2 as it would be for this case, so special- + // casing and using the definition of N(0,1)^2 is faster. + DoFExactlyOne, + DoFAnythingElse(Gamma), +} + +impl ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new chi-squared distribution with degrees-of-freedom + /// `k`. 
+ pub fn new(k: F) -> Result, Error> { + let repr = if k == F::one() { + DoFExactlyOne + } else { + if !(F::from(0.5).unwrap() * k > F::zero()) { + return Err(Error::DoFTooSmall); + } + DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) + }; + Ok(ChiSquared { repr }) + } +} +impl Distribution for ChiSquared +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + match self.repr { + DoFExactlyOne => { + // k == 1 => N(0,1)^2 + let norm: F = rng.sample(StandardNormal); + norm * norm + } + DoFAnythingElse(ref g) => g.sample(rng), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_chi_squared_one() { + let chi = ChiSquared::new(1.0).unwrap(); + let mut rng = crate::test::rng(201); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + fn test_chi_squared_small() { + let chi = ChiSquared::new(0.5).unwrap(); + let mut rng = crate::test::rng(202); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + fn test_chi_squared_large() { + let chi = ChiSquared::new(30.0).unwrap(); + let mut rng = crate::test::rng(203); + for _ in 0..1000 { + chi.sample(&mut rng); + } + } + #[test] + #[should_panic] + fn test_chi_squared_invalid_dof() { + ChiSquared::new(-1.0).unwrap(); + } + + #[test] + fn gamma_distributions_can_be_compared() { + assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); + } + + #[test] + fn chi_squared_distributions_can_be_compared() { + assert_eq!(ChiSquared::new(1.0), ChiSquared::new(1.0)); + } +} diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 96053084..aae1e075 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -31,7 +31,7 @@ where samplers: [Gamma; N], } -/// Error type returned from `DirchletFromGamma::new`. +/// Error type returned from [`DirichletFromGamma::new`]. 
#[derive(Clone, Copy, Debug, PartialEq, Eq)] enum DirichletFromGammaError { /// Gamma::new(a, 1) failed. @@ -103,7 +103,7 @@ where samplers: Box<[Beta]>, } -/// Error type returned from `DirchletFromBeta::new`. +/// Error type returned from [`DirichletFromBeta::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum DirichletFromBetaError { /// Beta::new(a, b) failed. @@ -226,7 +226,7 @@ where repr: DirichletRepr, } -/// Error type returned from `Dirchlet::new`. +/// Error type returned from [`Dirichlet::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `alpha.len() < 2`. diff --git a/rand_distr/src/exponential.rs b/rand_distr/src/exponential.rs index 4c919b20..c31fc952 100644 --- a/rand_distr/src/exponential.rs +++ b/rand_distr/src/exponential.rs @@ -130,7 +130,7 @@ where lambda_inverse: F, } -/// Error type returned from `Exp::new`. +/// Error type returned from [`Exp::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `lambda < 0` or `nan`. diff --git a/rand_distr/src/fisher_f.rs b/rand_distr/src/fisher_f.rs new file mode 100644 index 00000000..1a16b6d6 --- /dev/null +++ b/rand_distr/src/fisher_f.rs @@ -0,0 +1,131 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Fisher F-distribution. + +use crate::{ChiSquared, Distribution, Exp1, Open01, StandardNormal}; +use core::fmt; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`. +/// +/// This distribution is equivalent to the ratio of two normalised +/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / +/// (χ²(n)/n)`. 
+/// +/// # Plot +/// +/// The plot shows the F-distribution with various values of `m` and `n`. +/// +/// ![F-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/fisher_f.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{FisherF, Distribution}; +/// +/// let f = FisherF::new(2.0, 32.0).unwrap(); +/// let v = f.sample(&mut rand::thread_rng()); +/// println!("{} is from an F(2, 32) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + numer: ChiSquared, + denom: ChiSquared, + // denom_dof / numer_dof so that this can just be a straight + // multiplication, rather than a division. + dof_ratio: F, +} + +/// Error type returned from [`FisherF::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub enum Error { + /// `m <= 0` or `nan`. + MTooSmall, + /// `n <= 0` or `nan`. + NTooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::MTooSmall => "m is not positive in Fisher F distribution", + Error::NTooSmall => "n is not positive in Fisher F distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new `FisherF` distribution, with the given parameters. 
+ pub fn new(m: F, n: F) -> Result, Error> { + let zero = F::zero(); + if !(m > zero) { + return Err(Error::MTooSmall); + } + if !(n > zero) { + return Err(Error::NTooSmall); + } + + Ok(FisherF { + numer: ChiSquared::new(m).unwrap(), + denom: ChiSquared::new(n).unwrap(), + dof_ratio: n / m, + }) + } +} +impl Distribution for FisherF +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_f() { + let f = FisherF::new(2.0, 32.0).unwrap(); + let mut rng = crate::test::rng(204); + for _ in 0..1000 { + f.sample(&mut rng); + } + } + + #[test] + fn fisher_f_distributions_can_be_compared() { + assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0)); + } +} diff --git a/rand_distr/src/frechet.rs b/rand_distr/src/frechet.rs index b274946d..831561d6 100644 --- a/rand_distr/src/frechet.rs +++ b/rand_distr/src/frechet.rs @@ -55,7 +55,7 @@ where shape: F, } -/// Error type returned from `Frechet::new`. +/// Error type returned from [`Frechet::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// location is infinite or NaN diff --git a/rand_distr/src/gamma.rs b/rand_distr/src/gamma.rs index 23051e45..4699bbb6 100644 --- a/rand_distr/src/gamma.rs +++ b/rand_distr/src/gamma.rs @@ -7,17 +7,11 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Gamma and derived distributions. +//! The Gamma distribution. -// We use the variable names from the published reference, therefore this -// warning is not helpful. 
-#![allow(clippy::many_single_char_names)] - -use self::ChiSquaredRepr::*; use self::GammaRepr::*; -use crate::normal::StandardNormal; -use crate::{Distribution, Exp, Exp1, Open01}; +use crate::{Distribution, Exp, Exp1, Open01, StandardNormal}; use core::fmt; use num_traits::Float; use rand::Rng; @@ -80,7 +74,7 @@ where repr: GammaRepr, } -/// Error type returned from `Gamma::new`. +/// Error type returned from [`Gamma::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `shape <= 0` or `nan`. @@ -276,632 +270,12 @@ where } } -/// The [chi-squared distribution](https://en.wikipedia.org/wiki/Chi-squared_distribution) `χ²(k)`. -/// -/// The chi-squared distribution is a continuous probability -/// distribution with parameter `k > 0` degrees of freedom. -/// -/// For `k > 0` integral, this distribution is the sum of the squares -/// of `k` independent standard normal random variables. For other -/// `k`, this uses the equivalent characterisation -/// `χ²(k) = Gamma(k/2, 2)`. -/// -/// # Plot -/// -/// The plot shows the chi-squared distribution with various degrees -/// of freedom. -/// -/// ![Chi-squared distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/chi_squared.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{ChiSquared, Distribution}; -/// -/// let chi = ChiSquared::new(11.0).unwrap(); -/// let v = chi.sample(&mut rand::thread_rng()); -/// println!("{} is from a χ²(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - repr: ChiSquaredRepr, -} - -/// Error type returned from `ChiSquared::new` and `StudentT::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum ChiSquaredError { - /// `0.5 * k <= 0` or `nan`. 
- DoFTooSmall, -} - -impl fmt::Display for ChiSquaredError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ChiSquaredError::DoFTooSmall => { - "degrees-of-freedom k is not positive in chi-squared distribution" - } - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for ChiSquaredError {} - -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum ChiSquaredRepr -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - // k == 1, Gamma(alpha, ..) is particularly slow for alpha < 1, - // e.g. when alpha = 1/2 as it would be for this case, so special- - // casing and using the definition of N(0,1)^2 is faster. - DoFExactlyOne, - DoFAnythingElse(Gamma), -} - -impl ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new chi-squared distribution with degrees-of-freedom - /// `k`. - pub fn new(k: F) -> Result, ChiSquaredError> { - let repr = if k == F::one() { - DoFExactlyOne - } else { - if !(F::from(0.5).unwrap() * k > F::zero()) { - return Err(ChiSquaredError::DoFTooSmall); - } - DoFAnythingElse(Gamma::new(F::from(0.5).unwrap() * k, F::from(2.0).unwrap()).unwrap()) - }; - Ok(ChiSquared { repr }) - } -} -impl Distribution for ChiSquared -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - match self.repr { - DoFExactlyOne => { - // k == 1 => N(0,1)^2 - let norm: F = rng.sample(StandardNormal); - norm * norm - } - DoFAnythingElse(ref g) => g.sample(rng), - } - } -} - -/// The [Fisher F-distribution](https://en.wikipedia.org/wiki/F-distribution) `F(m, n)`. -/// -/// This distribution is equivalent to the ratio of two normalised -/// chi-squared distributions, that is, `F(m,n) = (χ²(m)/m) / -/// (χ²(n)/n)`. 
-/// -/// # Plot -/// -/// The plot shows the F-distribution with various values of `m` and `n`. -/// -/// ![F-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/fisher_f.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{FisherF, Distribution}; -/// -/// let f = FisherF::new(2.0, 32.0).unwrap(); -/// let v = f.sample(&mut rand::thread_rng()); -/// println!("{} is from an F(2, 32) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - numer: ChiSquared, - denom: ChiSquared, - // denom_dof / numer_dof so that this can just be a straight - // multiplication, rather than a division. - dof_ratio: F, -} - -/// Error type returned from `FisherF::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum FisherFError { - /// `m <= 0` or `nan`. - MTooSmall, - /// `n <= 0` or `nan`. - NTooSmall, -} - -impl fmt::Display for FisherFError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - FisherFError::MTooSmall => "m is not positive in Fisher F distribution", - FisherFError::NTooSmall => "n is not positive in Fisher F distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for FisherFError {} - -impl FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new `FisherF` distribution, with the given parameter. 
- pub fn new(m: F, n: F) -> Result, FisherFError> { - let zero = F::zero(); - if !(m > zero) { - return Err(FisherFError::MTooSmall); - } - if !(n > zero) { - return Err(FisherFError::NTooSmall); - } - - Ok(FisherF { - numer: ChiSquared::new(m).unwrap(), - denom: ChiSquared::new(n).unwrap(), - dof_ratio: n / m, - }) - } -} -impl Distribution for FisherF -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - self.numer.sample(rng) / self.denom.sample(rng) * self.dof_ratio - } -} - -/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`. -/// -/// The t-distribution is a continuous probability distribution -/// parameterized by degrees of freedom `ν` (`nu`), which -/// arises when estimating the mean of a normally-distributed -/// population in situations where the sample size is small and -/// the population's standard deviation is unknown. -/// It is widely used in hypothesis testing. -/// -/// For `ν = 1`, this is equivalent to the standard -/// [`Cauchy`](crate::Cauchy) distribution, -/// and as `ν` diverges to infinity, `t(ν)` converges to -/// [`StandardNormal`](crate::StandardNormal). -/// -/// # Plot -/// -/// The plot shows the t-distribution with various degrees of freedom. 
-/// -/// ![T-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/student_t.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{StudentT, Distribution}; -/// -/// let t = StudentT::new(11.0).unwrap(); -/// let v = t.sample(&mut rand::thread_rng()); -/// println!("{} is from a t(11) distribution", v) -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - chi: ChiSquared, - dof: F, -} - -impl StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - /// Create a new Student t-distribution with `ν` (nu) - /// degrees of freedom. - pub fn new(nu: F) -> Result, ChiSquaredError> { - Ok(StudentT { - chi: ChiSquared::new(nu)?, - dof: nu, - }) - } -} -impl Distribution for StudentT -where - F: Float, - StandardNormal: Distribution, - Exp1: Distribution, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let norm: F = rng.sample(StandardNormal); - norm * (self.dof / self.chi.sample(rng)).sqrt() - } -} - -/// The algorithm used for sampling the Beta distribution. -/// -/// Reference: -/// -/// R. C. H. Cheng (1978). -/// Generating beta variates with nonintegral shape parameters. -/// Communications of the ACM 21, 317-322. -/// https://doi.org/10.1145/359460.359482 -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -enum BetaAlgorithm { - BB(BB), - BC(BC), -} - -/// Algorithm BB for `min(alpha, beta) > 1`. -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -struct BB { - alpha: N, - beta: N, - gamma: N, -} - -/// Algorithm BC for `min(alpha, beta) <= 1`. 
-#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -struct BC { - alpha: N, - beta: N, - kappa1: N, - kappa2: N, -} - -/// The [Beta distribution](https://en.wikipedia.org/wiki/Beta_distribution) `Beta(α, β)`. -/// -/// The Beta distribution is a continuous probability distribution -/// defined on the interval `[0, 1]`. It is the conjugate prior for the -/// parameter `p` of the [`Binomial`][crate::Binomial] distribution. -/// -/// It has two shape parameters `α` (alpha) and `β` (beta) which control -/// the shape of the distribution. Both `a` and `β` must be greater than zero. -/// The distribution is symmetric when `α = β`. -/// -/// # Plot -/// -/// The plot shows the Beta distribution with various combinations -/// of `α` and `β`. -/// -/// ![Beta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/beta.svg) -/// -/// # Example -/// -/// ``` -/// use rand_distr::{Distribution, Beta}; -/// -/// let beta = Beta::new(2.0, 5.0).unwrap(); -/// let v = beta.sample(&mut rand::thread_rng()); -/// println!("{} is from a Beta(2, 5) distribution", v); -/// ``` -#[derive(Clone, Copy, Debug, PartialEq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub struct Beta -where - F: Float, - Open01: Distribution, -{ - a: F, - b: F, - switched_params: bool, - algorithm: BetaAlgorithm, -} - -/// Error type returned from `Beta::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] -pub enum BetaError { - /// `alpha <= 0` or `nan`. - AlphaTooSmall, - /// `beta <= 0` or `nan`. 
- BetaTooSmall, -} - -impl fmt::Display for BetaError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - BetaError::AlphaTooSmall => "alpha is not positive in beta distribution", - BetaError::BetaTooSmall => "beta is not positive in beta distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for BetaError {} - -impl Beta -where - F: Float, - Open01: Distribution, -{ - /// Construct an object representing the `Beta(alpha, beta)` - /// distribution. - pub fn new(alpha: F, beta: F) -> Result, BetaError> { - if !(alpha > F::zero()) { - return Err(BetaError::AlphaTooSmall); - } - if !(beta > F::zero()) { - return Err(BetaError::BetaTooSmall); - } - // From now on, we use the notation from the reference, - // i.e. `alpha` and `beta` are renamed to `a0` and `b0`. - let (a0, b0) = (alpha, beta); - let (a, b, switched_params) = if a0 < b0 { - (a0, b0, false) - } else { - (b0, a0, true) - }; - if a > F::one() { - // Algorithm BB - let alpha = a + b; - - let two = F::from(2.).unwrap(); - let beta_numer = alpha - two; - let beta_denom = two * a * b - alpha; - let beta = (beta_numer / beta_denom).sqrt(); - - let gamma = a + F::one() / beta; - - Ok(Beta { - a, - b, - switched_params, - algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }), - }) - } else { - // Algorithm BC - // - // Here `a` is the maximum instead of the minimum. - let (a, b, switched_params) = (b, a, !switched_params); - let alpha = a + b; - let beta = F::one() / b; - let delta = F::one() + a - b; - let kappa1 = delta - * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b) - / (a * beta - F::from(14. 
/ 18.).unwrap()); - let kappa2 = F::from(0.25).unwrap() - + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b; - - Ok(Beta { - a, - b, - switched_params, - algorithm: BetaAlgorithm::BC(BC { - alpha, - beta, - kappa1, - kappa2, - }), - }) - } - } -} - -impl Distribution for Beta -where - F: Float, - Open01: Distribution, -{ - fn sample(&self, rng: &mut R) -> F { - let mut w; - match self.algorithm { - BetaAlgorithm::BB(algo) => { - loop { - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - let z = u1 * u1 * u2; - let r = algo.gamma * v - F::from(4.).unwrap().ln(); - let s = self.a + r - w; - // 2. - if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z { - break; - } - // 3. - let t = z.ln(); - if s >= t { - break; - } - // 4. - if !(r + algo.alpha * (algo.alpha / (self.b + w)).ln() < t) { - break; - } - } - } - BetaAlgorithm::BC(algo) => { - loop { - let z; - // 1. - let u1 = rng.sample(Open01); - let u2 = rng.sample(Open01); - if u1 < F::from(0.5).unwrap() { - // 2. - let y = u1 * u2; - z = u1 * y; - if F::from(0.25).unwrap() * u2 + z - y >= algo.kappa1 { - continue; - } - } else { - // 3. - z = u1 * u1 * u2; - if z <= F::from(0.25).unwrap() { - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - break; - } - // 4. - if z >= algo.kappa2 { - continue; - } - } - // 5. - let v = algo.beta * (u1 / (F::one() - u1)).ln(); - w = self.a * v.exp(); - if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) - - F::from(4.).unwrap().ln() - < z.ln()) - { - break; - }; - } - } - }; - // 5. for BB, 6. 
for BC - if !self.switched_params { - if w == F::infinity() { - // Assuming `b` is finite, for large `w`: - return F::one(); - } - w / (self.b + w) - } else { - self.b / (self.b + w) - } - } -} - #[cfg(test)] mod test { use super::*; - #[test] - fn test_chi_squared_one() { - let chi = ChiSquared::new(1.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_small() { - let chi = ChiSquared::new(0.5).unwrap(); - let mut rng = crate::test::rng(202); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - fn test_chi_squared_large() { - let chi = ChiSquared::new(30.0).unwrap(); - let mut rng = crate::test::rng(203); - for _ in 0..1000 { - chi.sample(&mut rng); - } - } - #[test] - #[should_panic] - fn test_chi_squared_invalid_dof() { - ChiSquared::new(-1.0).unwrap(); - } - - #[test] - fn test_f() { - let f = FisherF::new(2.0, 32.0).unwrap(); - let mut rng = crate::test::rng(204); - for _ in 0..1000 { - f.sample(&mut rng); - } - } - - #[test] - fn test_t() { - let t = StudentT::new(11.0).unwrap(); - let mut rng = crate::test::rng(205); - for _ in 0..1000 { - t.sample(&mut rng); - } - } - - #[test] - fn test_beta() { - let beta = Beta::new(1.0, 2.0).unwrap(); - let mut rng = crate::test::rng(201); - for _ in 0..1000 { - beta.sample(&mut rng); - } - } - - #[test] - #[should_panic] - fn test_beta_invalid_dof() { - Beta::new(0., 0.).unwrap(); - } - - #[test] - fn test_beta_small_param() { - let beta = Beta::::new(1e-3, 1e-3).unwrap(); - let mut rng = crate::test::rng(206); - for i in 0..1000 { - assert!(!beta.sample(&mut rng).is_nan(), "failed at i={}", i); - } - } - #[test] fn gamma_distributions_can_be_compared() { assert_eq!(Gamma::new(1.0, 2.0), Gamma::new(1.0, 2.0)); } - - #[test] - fn beta_distributions_can_be_compared() { - assert_eq!(Beta::new(1.0, 2.0), Beta::new(1.0, 2.0)); - } - - #[test] - fn chi_squared_distributions_can_be_compared() { - assert_eq!(ChiSquared::new(1.0), 
ChiSquared::new(1.0)); - } - - #[test] - fn fisher_f_distributions_can_be_compared() { - assert_eq!(FisherF::new(1.0, 2.0), FisherF::new(1.0, 2.0)); - } - - #[test] - fn student_t_distributions_can_be_compared() { - assert_eq!(StudentT::new(1.0), StudentT::new(1.0)); - } } diff --git a/rand_distr/src/geometric.rs b/rand_distr/src/geometric.rs index e54496d8..9beb9382 100644 --- a/rand_distr/src/geometric.rs +++ b/rand_distr/src/geometric.rs @@ -46,7 +46,7 @@ pub struct Geometric { k: u64, } -/// Error type returned from `Geometric::new`. +/// Error type returned from [`Geometric::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `p < 0 || p > 1` or `nan` diff --git a/rand_distr/src/gumbel.rs b/rand_distr/src/gumbel.rs index fd9324ac..6a7f1ae7 100644 --- a/rand_distr/src/gumbel.rs +++ b/rand_distr/src/gumbel.rs @@ -52,7 +52,7 @@ where scale: F, } -/// Error type returned from `Gumbel::new`. +/// Error type returned from [`Gumbel::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// location is infinite or NaN diff --git a/rand_distr/src/hypergeometric.rs b/rand_distr/src/hypergeometric.rs index c15b143b..4e4f4306 100644 --- a/rand_distr/src/hypergeometric.rs +++ b/rand_distr/src/hypergeometric.rs @@ -68,7 +68,7 @@ pub struct Hypergeometric { sampling_method: SamplingMethod, } -/// Error type returned from `Hypergeometric::new`. +/// Error type returned from [`Hypergeometric::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `total_population_size` is too large, causing floating point underflow. 
diff --git a/rand_distr/src/inverse_gaussian.rs b/rand_distr/src/inverse_gaussian.rs index 1039f604..4a5aad79 100644 --- a/rand_distr/src/inverse_gaussian.rs +++ b/rand_distr/src/inverse_gaussian.rs @@ -5,7 +5,7 @@ use core::fmt; use num_traits::Float; use rand::Rng; -/// Error type returned from `InverseGaussian::new` +/// Error type returned from [`InverseGaussian::new`] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Error { /// `mean <= 0` or `nan`. diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index 2394f549..a3852256 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -98,16 +98,16 @@ pub use rand::distributions::{ Standard, Uniform, }; +pub use self::beta::{Beta, Error as BetaError}; pub use self::binomial::{Binomial, Error as BinomialError}; pub use self::cauchy::{Cauchy, Error as CauchyError}; +pub use self::chi_squared::{ChiSquared, Error as ChiSquaredError}; #[cfg(feature = "alloc")] pub use self::dirichlet::{Dirichlet, Error as DirichletError}; pub use self::exponential::{Error as ExpError, Exp, Exp1}; +pub use self::fisher_f::{Error as FisherFError, FisherF}; pub use self::frechet::{Error as FrechetError, Frechet}; -pub use self::gamma::{ - Beta, BetaError, ChiSquared, ChiSquaredError, Error as GammaError, FisherF, FisherFError, - Gamma, StudentT, -}; +pub use self::gamma::{Error as GammaError, Gamma}; pub use self::geometric::{Error as GeoError, Geometric, StandardGeometric}; pub use self::gumbel::{Error as GumbelError, Gumbel}; pub use self::hypergeometric::{Error as HyperGeoError, Hypergeometric}; @@ -126,9 +126,11 @@ pub use self::unit_circle::UnitCircle; pub use self::unit_disc::UnitDisc; pub use self::unit_sphere::UnitSphere; pub use self::weibull::{Error as WeibullError, Weibull}; -pub use self::zipf::{Zeta, ZetaError, Zipf, ZipfError}; +pub use self::zeta::{Error as ZetaError, Zeta}; +pub use self::zipf::{Error as ZipfError, Zipf}; #[cfg(feature = "alloc")] pub use rand::distributions::{WeightError, 
WeightedIndex}; +pub use student_t::StudentT; #[cfg(feature = "alloc")] pub use weighted_alias::WeightedAliasIndex; #[cfg(feature = "alloc")] @@ -192,10 +194,13 @@ pub mod weighted_alias; #[cfg(feature = "alloc")] pub mod weighted_tree; +mod beta; mod binomial; mod cauchy; +mod chi_squared; mod dirichlet; mod exponential; +mod fisher_f; mod frechet; mod gamma; mod geometric; @@ -208,6 +213,7 @@ mod pareto; mod pert; mod poisson; mod skew_normal; +mod student_t; mod triangular; mod unit_ball; mod unit_circle; @@ -215,5 +221,6 @@ mod unit_disc; mod unit_sphere; mod utils; mod weibull; +mod zeta; mod ziggurat_tables; mod zipf; diff --git a/rand_distr/src/normal.rs b/rand_distr/src/normal.rs index 1b698ec4..f6e6adc4 100644 --- a/rand_distr/src/normal.rs +++ b/rand_distr/src/normal.rs @@ -153,7 +153,7 @@ where std_dev: F, } -/// Error type returned from `Normal::new` and `LogNormal::new`. +/// Error type returned from [`Normal::new`] and [`LogNormal::new`](crate::LogNormal::new). #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// The mean value is too small (log-normal samples must be positive) diff --git a/rand_distr/src/normal_inverse_gaussian.rs b/rand_distr/src/normal_inverse_gaussian.rs index f8f62170..2c7fe71a 100644 --- a/rand_distr/src/normal_inverse_gaussian.rs +++ b/rand_distr/src/normal_inverse_gaussian.rs @@ -3,7 +3,7 @@ use core::fmt; use num_traits::Float; use rand::Rng; -/// Error type returned from `NormalInverseGaussian::new` +/// Error type returned from [`NormalInverseGaussian::new`] #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum Error { /// `alpha <= 0` or `nan`. diff --git a/rand_distr/src/pareto.rs b/rand_distr/src/pareto.rs index ba0465f7..f8b86c70 100644 --- a/rand_distr/src/pareto.rs +++ b/rand_distr/src/pareto.rs @@ -46,7 +46,7 @@ where inv_neg_shape: F, } -/// Error type returned from `Pareto::new`. +/// Error type returned from [`Pareto::new`]. 
#[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. diff --git a/rand_distr/src/poisson.rs b/rand_distr/src/poisson.rs index c84d4dce..78675ad9 100644 --- a/rand_distr/src/poisson.rs +++ b/rand_distr/src/poisson.rs @@ -54,7 +54,7 @@ where magic_val: F, } -/// Error type returned from `Poisson::new`. +/// Error type returned from [`Poisson::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `lambda <= 0` diff --git a/rand_distr/src/skew_normal.rs b/rand_distr/src/skew_normal.rs index 6ef521be..8ef88428 100644 --- a/rand_distr/src/skew_normal.rs +++ b/rand_distr/src/skew_normal.rs @@ -68,7 +68,7 @@ where shape: F, } -/// Error type returned from `SkewNormal::new`. +/// Error type returned from [`SkewNormal::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// The scale parameter is not finite or it is less or equal to zero. diff --git a/rand_distr/src/student_t.rs b/rand_distr/src/student_t.rs new file mode 100644 index 00000000..86a5fb5b --- /dev/null +++ b/rand_distr/src/student_t.rs @@ -0,0 +1,107 @@ +// Copyright 2018 Developers of the Rand project. +// Copyright 2013 The Rust Project Developers. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Student's t-distribution. + +use crate::{ChiSquared, ChiSquaredError}; +use crate::{Distribution, Exp1, Open01, StandardNormal}; +use num_traits::Float; +use rand::Rng; +#[cfg(feature = "serde1")] +use serde::{Deserialize, Serialize}; + +/// The [Student t-distribution](https://en.wikipedia.org/wiki/Student%27s_t-distribution) `t(ν)`. 
+/// +/// The t-distribution is a continuous probability distribution +/// parameterized by degrees of freedom `ν` (`nu`), which +/// arises when estimating the mean of a normally-distributed +/// population in situations where the sample size is small and +/// the population's standard deviation is unknown. +/// It is widely used in hypothesis testing. +/// +/// For `ν = 1`, this is equivalent to the standard +/// [`Cauchy`](crate::Cauchy) distribution, +/// and as `ν` diverges to infinity, `t(ν)` converges to +/// [`StandardNormal`](crate::StandardNormal). +/// +/// # Plot +/// +/// The plot shows the t-distribution with various degrees of freedom. +/// +/// ![T-distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/student_t.svg) +/// +/// # Example +/// +/// ``` +/// use rand_distr::{StudentT, Distribution}; +/// +/// let t = StudentT::new(11.0).unwrap(); +/// let v = t.sample(&mut rand::thread_rng()); +/// println!("{} is from a t(11) distribution", v) +/// ``` +#[derive(Clone, Copy, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +pub struct StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + chi: ChiSquared, + dof: F, +} + +impl StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Create a new Student t-distribution with `ν` (nu) + /// degrees of freedom. 
+ pub fn new(nu: F) -> Result, ChiSquaredError> { + Ok(StudentT { + chi: ChiSquared::new(nu)?, + dof: nu, + }) + } +} +impl Distribution for StudentT +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + fn sample(&self, rng: &mut R) -> F { + let norm: F = rng.sample(StandardNormal); + norm * (self.dof / self.chi.sample(rng)).sqrt() + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_t() { + let t = StudentT::new(11.0).unwrap(); + let mut rng = crate::test::rng(205); + for _ in 0..1000 { + t.sample(&mut rng); + } + } + + #[test] + fn student_t_distributions_can_be_compared() { + assert_eq!(StudentT::new(1.0), StudentT::new(1.0)); + } +} diff --git a/rand_distr/src/weibull.rs b/rand_distr/src/weibull.rs index e6f80736..145a4df3 100644 --- a/rand_distr/src/weibull.rs +++ b/rand_distr/src/weibull.rs @@ -44,7 +44,7 @@ where scale: F, } -/// Error type returned from `Weibull::new`. +/// Error type returned from [`Weibull::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum Error { /// `scale <= 0` or `nan`. diff --git a/rand_distr/src/zeta.rs b/rand_distr/src/zeta.rs new file mode 100644 index 00000000..da146883 --- /dev/null +++ b/rand_distr/src/zeta.rs @@ -0,0 +1,192 @@ +// Copyright 2021 Developers of the Rand project. +// +// Licensed under the Apache License, Version 2.0 or the MIT license +// , at your +// option. This file may not be copied, modified, or distributed +// except according to those terms. + +//! The Zeta distribution. + +use crate::{Distribution, Standard}; +use core::fmt; +use num_traits::Float; +use rand::{distributions::OpenClosed01, Rng}; + +/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`. +/// +/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) +/// is a discrete probability distribution with parameter `s`. +/// It is a special case of the [`Zipf`](crate::Zipf) distribution with `n = ∞`. 
+/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or Zipf–Estoup distribution. +/// +/// # Density function +/// +/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the +/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function). +/// +/// # Plot +/// +/// The following plot illustrates the zeta distribution for various values of `s`. +/// +/// ![Zeta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zeta.svg) +/// +/// # Example +/// ``` +/// use rand::prelude::*; +/// use rand_distr::Zeta; +/// +/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap()); +/// println!("{}", val); +/// ``` +/// +/// # Notes +/// +/// The zeta distribution has no upper limit. Sampled values may be infinite. +/// In particular, a value of infinity might be returned for the following +/// reasons: +/// 1. it is the best representation in the type `F` of the actual sample. +/// 2. to prevent infinite loops for very small `s`. +/// +/// # Implementation details +/// +/// We are using the algorithm from +/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8), +/// Section 6.1, page 551. +#[derive(Clone, Copy, Debug, PartialEq)] +pub struct Zeta +where + F: Float, + Standard: Distribution, + OpenClosed01: Distribution, +{ + s_minus_1: F, + b: F, +} + +/// Error type returned from [`Zeta::new`]. +#[derive(Clone, Copy, Debug, PartialEq, Eq)] +pub enum Error { + /// `s <= 1` or `nan`. + STooSmall, +} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(match self { + Error::STooSmall => "s <= 1 or is NaN in Zeta distribution", + }) + } +} + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl Zeta +where + F: Float, + Standard: Distribution, + OpenClosed01: Distribution, +{ + /// Construct a new `Zeta` distribution with given `s` parameter. 
+ #[inline] + pub fn new(s: F) -> Result, Error> { + if !(s > F::one()) { + return Err(Error::STooSmall); + } + let s_minus_1 = s - F::one(); + let two = F::one() + F::one(); + Ok(Zeta { + s_minus_1, + b: two.powf(s_minus_1), + }) + } +} + +impl Distribution for Zeta +where + F: Float, + Standard: Distribution, + OpenClosed01: Distribution, +{ + #[inline] + fn sample(&self, rng: &mut R) -> F { + loop { + let u = rng.sample(OpenClosed01); + let x = u.powf(-F::one() / self.s_minus_1).floor(); + debug_assert!(x >= F::one()); + if x.is_infinite() { + // For sufficiently small `s`, `x` will always be infinite, + // which is rejected, resulting in an infinite loop. We avoid + // this by always returning infinity instead. + return x; + } + + let t = (F::one() + F::one() / x).powf(self.s_minus_1); + + let v = rng.sample(Standard); + if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) { + return x; + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn test_samples>(distr: D, zero: F, expected: &[F]) { + let mut rng = crate::test::rng(213); + let mut buf = [zero; 4]; + for x in &mut buf { + *x = rng.sample(&distr); + } + assert_eq!(buf, expected); + } + + #[test] + #[should_panic] + fn zeta_invalid() { + Zeta::new(1.).unwrap(); + } + + #[test] + #[should_panic] + fn zeta_nan() { + Zeta::new(f64::NAN).unwrap(); + } + + #[test] + fn zeta_sample() { + let a = 2.0; + let d = Zeta::new(a).unwrap(); + let mut rng = crate::test::rng(1); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + } + + #[test] + fn zeta_small_a() { + let a = 1. 
+ 1e-15; + let d = Zeta::new(a).unwrap(); + let mut rng = crate::test::rng(2); + for _ in 0..1000 { + let r = d.sample(&mut rng); + assert!(r >= 1.); + } + } + + #[test] + fn zeta_value_stability() { + test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]); + test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]); + } + + #[test] + fn zeta_distributions_can_be_compared() { + assert_eq!(Zeta::new(1.0), Zeta::new(1.0)); + } +} diff --git a/rand_distr/src/zipf.rs b/rand_distr/src/zipf.rs index c8e3fef9..70bb891a 100644 --- a/rand_distr/src/zipf.rs +++ b/rand_distr/src/zipf.rs @@ -6,131 +6,12 @@ // option. This file may not be copied, modified, or distributed // except according to those terms. -//! The Zeta and related distributions. +//! The Zipf distribution. use crate::{Distribution, Standard}; use core::fmt; use num_traits::Float; -use rand::{distributions::OpenClosed01, Rng}; - -/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) `Zeta(s)`. -/// -/// The [Zeta distribution](https://en.wikipedia.org/wiki/Zeta_distribution) -/// is a discrete probability distribution with parameter `s`. -/// It is a special case of the [`Zipf`] distribution with `n = ∞`. -/// It is also known as the discrete Pareto, Riemann-Zeta, Zipf, or Zipf–Estoup distribution. -/// -/// # Density function -/// -/// `f(k) = k^(-s) / ζ(s)` for `k >= 1`, where `ζ` is the -/// [Riemann zeta function](https://en.wikipedia.org/wiki/Riemann_zeta_function). -/// -/// # Plot -/// -/// The following plot illustrates the zeta distribution for various values of `s`. -/// -/// ![Zeta distribution](https://raw.githubusercontent.com/rust-random/charts/main/charts/zeta.svg) -/// -/// # Example -/// ``` -/// use rand::prelude::*; -/// use rand_distr::Zeta; -/// -/// let val: f64 = thread_rng().sample(Zeta::new(1.5).unwrap()); -/// println!("{}", val); -/// ``` -/// -/// # Notes -/// -/// The zeta distribution has no upper limit. Sampled values may be infinite. 
-/// In particular, a value of infinity might be returned for the following -/// reasons: -/// 1. it is the best representation in the type `F` of the actual sample. -/// 2. to prevent infinite loops for very small `s`. -/// -/// # Implementation details -/// -/// We are using the algorithm from -/// [Non-Uniform Random Variate Generation](https://doi.org/10.1007/978-1-4613-8643-8), -/// Section 6.1, page 551. -#[derive(Clone, Copy, Debug, PartialEq)] -pub struct Zeta -where - F: Float, - Standard: Distribution, - OpenClosed01: Distribution, -{ - s_minus_1: F, - b: F, -} - -/// Error type returned from `Zeta::new`. -#[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ZetaError { - /// `s <= 1` or `nan`. - STooSmall, -} - -impl fmt::Display for ZetaError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(match self { - ZetaError::STooSmall => "s <= 1 or is NaN in Zeta distribution", - }) - } -} - -#[cfg(feature = "std")] -impl std::error::Error for ZetaError {} - -impl Zeta -where - F: Float, - Standard: Distribution, - OpenClosed01: Distribution, -{ - /// Construct a new `Zeta` distribution with given `s` parameter. - #[inline] - pub fn new(s: F) -> Result, ZetaError> { - if !(s > F::one()) { - return Err(ZetaError::STooSmall); - } - let s_minus_1 = s - F::one(); - let two = F::one() + F::one(); - Ok(Zeta { - s_minus_1, - b: two.powf(s_minus_1), - }) - } -} - -impl Distribution for Zeta -where - F: Float, - Standard: Distribution, - OpenClosed01: Distribution, -{ - #[inline] - fn sample(&self, rng: &mut R) -> F { - loop { - let u = rng.sample(OpenClosed01); - let x = u.powf(-F::one() / self.s_minus_1).floor(); - debug_assert!(x >= F::one()); - if x.is_infinite() { - // For sufficiently small `s`, `x` will always be infinite, - // which is rejected, resulting in an infinite loop. We avoid - // this by always returning infinity instead. 
- return x; - } - - let t = (F::one() + F::one() / x).powf(self.s_minus_1); - - let v = rng.sample(Standard); - if v * x * (t - F::one()) * self.b <= t * (self.b - F::one()) { - return x; - } - } - } -} +use rand::Rng; /// The Zipf (Zipfian) distribution `Zipf(n, s)`. /// @@ -175,26 +56,26 @@ where q: F, } -/// Error type returned from `Zipf::new`. +/// Error type returned from [`Zipf::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] -pub enum ZipfError { +pub enum Error { /// `s < 0` or `nan`. STooSmall, /// `n < 1`. NTooSmall, } -impl fmt::Display for ZipfError { +impl fmt::Display for Error { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.write_str(match self { - ZipfError::STooSmall => "s < 0 or is NaN in Zipf distribution", - ZipfError::NTooSmall => "n < 1 in Zipf distribution", + Error::STooSmall => "s < 0 or is NaN in Zipf distribution", + Error::NTooSmall => "n < 1 in Zipf distribution", }) } } #[cfg(feature = "std")] -impl std::error::Error for ZipfError {} +impl std::error::Error for Error {} impl Zipf where @@ -206,12 +87,12 @@ where /// /// For large `n`, rounding may occur to fit the number into the float type. #[inline] - pub fn new(n: u64, s: F) -> Result, ZipfError> { + pub fn new(n: u64, s: F) -> Result, Error> { if !(s >= F::zero()) { - return Err(ZipfError::STooSmall); + return Err(Error::STooSmall); } if n < 1 { - return Err(ZipfError::NTooSmall); + return Err(Error::NTooSmall); } let n = F::from(n).unwrap(); // This does not fail. let q = if s != F::one() { @@ -282,46 +163,6 @@ mod tests { assert_eq!(buf, expected); } - #[test] - #[should_panic] - fn zeta_invalid() { - Zeta::new(1.).unwrap(); - } - - #[test] - #[should_panic] - fn zeta_nan() { - Zeta::new(f64::NAN).unwrap(); - } - - #[test] - fn zeta_sample() { - let a = 2.0; - let d = Zeta::new(a).unwrap(); - let mut rng = crate::test::rng(1); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zeta_small_a() { - let a = 1. 
+ 1e-15; - let d = Zeta::new(a).unwrap(); - let mut rng = crate::test::rng(2); - for _ in 0..1000 { - let r = d.sample(&mut rng); - assert!(r >= 1.); - } - } - - #[test] - fn zeta_value_stability() { - test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]); - test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]); - } - #[test] #[should_panic] fn zipf_s_too_small() { @@ -392,9 +233,4 @@ mod tests { fn zipf_distributions_can_be_compared() { assert_eq!(Zipf::new(1, 2.0), Zipf::new(1, 2.0)); } - - #[test] - fn zeta_distributions_can_be_compared() { - assert_eq!(Zeta::new(1.0), Zeta::new(1.0)); - } } diff --git a/src/distributions/bernoulli.rs b/src/distributions/bernoulli.rs index 80453496..e49b415f 100644 --- a/src/distributions/bernoulli.rs +++ b/src/distributions/bernoulli.rs @@ -75,7 +75,7 @@ const ALWAYS_TRUE: u64 = u64::MAX; // in `no_std` mode. const SCALE: f64 = 2.0 * (1u64 << 63) as f64; -/// Error type returned from `Bernoulli::new`. +/// Error type returned from [`Bernoulli::new`]. #[derive(Clone, Copy, Debug, PartialEq, Eq)] pub enum BernoulliError { /// `p < 0` or `p > 1`. diff --git a/src/distributions/weighted_index.rs b/src/distributions/weighted_index.rs index 8a887ce3..88fad5a8 100644 --- a/src/distributions/weighted_index.rs +++ b/src/distributions/weighted_index.rs @@ -702,7 +702,7 @@ mod test { } } -/// Errors returned by weighted distributions +/// Errors returned by [`WeightedIndex::new`], [`WeightedIndex::update_weights`] and other weighted distributions #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum WeightError { /// The input weight sequence is empty, too long, or wrongly ordered