Convert Dirichlet from an enum to a struct containing an enum.

This commit is contained in:
warren 2023-02-28 20:01:42 -05:00
parent 4ecb35eaf0
commit 7513e838b7

View File

@ -17,7 +17,7 @@ use rand::Rng;
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct DirichletFromGamma<F>
struct DirichletFromGamma<F>
where
F: Float,
StandardNormal: Distribution<F>,
@ -39,15 +39,15 @@ where
// This function is part of a private implementation detail.
// It assumes that the input is correct, so no validation is done.
#[inline]
fn new(alpha: &[F]) -> Dirichlet<F> {
fn new(alpha: &[F]) -> DirichletFromGamma<F> {
let gamma_dists = alpha
.iter()
.map(|a| Gamma::new(*a, F::one()).unwrap())
.collect::<Vec<Gamma<F>>>()
.into_boxed_slice();
Dirichlet::FromGamma(DirichletFromGamma {
DirichletFromGamma {
samplers: gamma_dists,
})
}
}
}
@ -76,7 +76,7 @@ where
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct DirichletFromBeta<F>
struct DirichletFromBeta<F>
where
F: Float,
StandardNormal: Distribution<F>,
@ -98,7 +98,7 @@ where
// This function is part of a private implementation detail.
// It assumes that the input is correct, so no validation is done.
#[inline]
fn new(alpha: &[F]) -> Dirichlet<F> {
fn new(alpha: &[F]) -> DirichletFromBeta<F> {
// Form the right-to-left cumulative sum of alpha, exluding the
// first element of alpha. E.g. if alpha = [a0, a1, a2, a3], then
// after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl
@ -120,9 +120,9 @@ where
.map(|t| Beta::new(*t.0, *t.1).unwrap())
.collect::<Vec<Beta<F>>>()
.into_boxed_slice();
Dirichlet::FromBeta(DirichletFromBeta {
DirichletFromBeta {
samplers: beta_dists,
})
}
}
}
@ -148,6 +148,22 @@ where
}
}
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
enum DirichletRepr<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Dirichlet distribution that generates samples using the gamma distribution.
FromGamma(DirichletFromGamma<F>),
/// Dirichlet distribution that generates samples using the beta distribution.
FromBeta(DirichletFromBeta<F>),
}
/// The Dirichlet distribution `Dirichlet(alpha)`.
///
/// The Dirichlet distribution is a family of continuous multivariate
@ -167,18 +183,14 @@ where
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[derive(Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub enum Dirichlet<F>
pub struct Dirichlet<F>
where
F: Float,
StandardNormal: Distribution<F>,
Exp1: Distribution<F>,
Open01: Distribution<F>,
{
/// Dirichlet distribution that generates samples using the gamma distribution.
FromGamma(DirichletFromGamma<F>),
/// Dirichlet distribution that generates samples using the beta distribution.
FromBeta(DirichletFromBeta<F>),
repr: DirichletRepr<F>,
}
/// Error type returned from `Dirchlet::new`.
@ -231,9 +243,13 @@ where
if alpha.iter().all(|x| *x <= NumCast::from(0.1).unwrap()) {
// All the values in alpha are less than 0.1.
Ok(DirichletFromBeta::new(alpha))
Ok(Dirichlet {
repr: DirichletRepr::FromBeta(DirichletFromBeta::new(alpha)),
})
} else {
Ok(DirichletFromGamma::new(alpha))
Ok(Dirichlet {
repr: DirichletRepr::FromGamma(DirichletFromGamma::new(alpha)),
})
}
}
@ -260,9 +276,9 @@ where
Open01: Distribution<F>,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
match self {
Dirichlet::FromGamma(dirichlet) => dirichlet.sample(rng),
Dirichlet::FromBeta(dirichlet) => dirichlet.sample(rng),
match &self.repr {
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng),
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng),
}
}
}