diff --git a/rand_distr/src/dirichlet.rs b/rand_distr/src/dirichlet.rs index 414d694e..825812d5 100644 --- a/rand_distr/src/dirichlet.rs +++ b/rand_distr/src/dirichlet.rs @@ -17,7 +17,7 @@ use rand::Rng; #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct DirichletFromGamma +struct DirichletFromGamma where F: Float, StandardNormal: Distribution, @@ -39,15 +39,15 @@ where // This function is part of a private implementation detail. // It assumes that the input is correct, so no validation is done. #[inline] - fn new(alpha: &[F]) -> Dirichlet { + fn new(alpha: &[F]) -> DirichletFromGamma { let gamma_dists = alpha .iter() .map(|a| Gamma::new(*a, F::one()).unwrap()) .collect::>>() .into_boxed_slice(); - Dirichlet::FromGamma(DirichletFromGamma { + DirichletFromGamma { samplers: gamma_dists, - }) + } } } @@ -76,7 +76,7 @@ where #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub struct DirichletFromBeta +struct DirichletFromBeta where F: Float, StandardNormal: Distribution, @@ -98,7 +98,7 @@ where // This function is part of a private implementation detail. // It assumes that the input is correct, so no validation is done. #[inline] - fn new(alpha: &[F]) -> Dirichlet { + fn new(alpha: &[F]) -> DirichletFromBeta { // Form the right-to-left cumulative sum of alpha, exluding the // first element of alpha. E.g. if alpha = [a0, a1, a2, a3], then // after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl @@ -120,9 +120,9 @@ where .map(|t| Beta::new(*t.0, *t.1).unwrap()) .collect::>>() .into_boxed_slice(); - Dirichlet::FromBeta(DirichletFromBeta { + DirichletFromBeta { samplers: beta_dists, - }) + } } } @@ -148,6 +148,22 @@ where } } +#[derive(Clone, Debug, PartialEq)] +#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] +enum DirichletRepr +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Dirichlet distribution that generates samples using the gamma distribution. + FromGamma(DirichletFromGamma), + + /// Dirichlet distribution that generates samples using the beta distribution. + FromBeta(DirichletFromBeta), +} + /// The Dirichlet distribution `Dirichlet(alpha)`. /// /// The Dirichlet distribution is a family of continuous multivariate @@ -167,18 +183,14 @@ where #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[derive(Clone, Debug, PartialEq)] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] -pub enum Dirichlet +pub struct Dirichlet where F: Float, StandardNormal: Distribution, Exp1: Distribution, Open01: Distribution, { - /// Dirichlet distribution that generates samples using the gamma distribution. - FromGamma(DirichletFromGamma), - - /// Dirichlet distribution that generates samples using the beta distribution. - FromBeta(DirichletFromBeta), + repr: DirichletRepr, } /// Error type returned from `Dirchlet::new`. @@ -231,9 +243,13 @@ where if alpha.iter().all(|x| *x <= NumCast::from(0.1).unwrap()) { // All the values in alpha are less than 0.1. - Ok(DirichletFromBeta::new(alpha)) + Ok(Dirichlet { + repr: DirichletRepr::FromBeta(DirichletFromBeta::new(alpha)), + }) } else { - Ok(DirichletFromGamma::new(alpha)) + Ok(Dirichlet { + repr: DirichletRepr::FromGamma(DirichletFromGamma::new(alpha)), + }) } } @@ -260,9 +276,9 @@ where Open01: Distribution, { fn sample(&self, rng: &mut R) -> Vec { - match self { - Dirichlet::FromGamma(dirichlet) => dirichlet.sample(rng), - Dirichlet::FromBeta(dirichlet) => dirichlet.sample(rng), + match &self.repr { + DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng), + DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng), } } }