Convert Dirichlet
from an enum
to a struct
containing an enum
.
This commit is contained in:
parent
4ecb35eaf0
commit
7513e838b7
@ -17,7 +17,7 @@ use rand::Rng;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct DirichletFromGamma<F>
|
||||
struct DirichletFromGamma<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
@ -39,15 +39,15 @@ where
|
||||
// This function is part of a private implementation detail.
|
||||
// It assumes that the input is correct, so no validation is done.
|
||||
#[inline]
|
||||
fn new(alpha: &[F]) -> Dirichlet<F> {
|
||||
fn new(alpha: &[F]) -> DirichletFromGamma<F> {
|
||||
let gamma_dists = alpha
|
||||
.iter()
|
||||
.map(|a| Gamma::new(*a, F::one()).unwrap())
|
||||
.collect::<Vec<Gamma<F>>>()
|
||||
.into_boxed_slice();
|
||||
Dirichlet::FromGamma(DirichletFromGamma {
|
||||
DirichletFromGamma {
|
||||
samplers: gamma_dists,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -76,7 +76,7 @@ where
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub struct DirichletFromBeta<F>
|
||||
struct DirichletFromBeta<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
@ -98,7 +98,7 @@ where
|
||||
// This function is part of a private implementation detail.
|
||||
// It assumes that the input is correct, so no validation is done.
|
||||
#[inline]
|
||||
fn new(alpha: &[F]) -> Dirichlet<F> {
|
||||
fn new(alpha: &[F]) -> DirichletFromBeta<F> {
|
||||
// Form the right-to-left cumulative sum of alpha, exluding the
|
||||
// first element of alpha. E.g. if alpha = [a0, a1, a2, a3], then
|
||||
// after the call to `alpha_sum_rl.reverse()` below, alpha_sum_rl
|
||||
@ -120,9 +120,9 @@ where
|
||||
.map(|t| Beta::new(*t.0, *t.1).unwrap())
|
||||
.collect::<Vec<Beta<F>>>()
|
||||
.into_boxed_slice();
|
||||
Dirichlet::FromBeta(DirichletFromBeta {
|
||||
DirichletFromBeta {
|
||||
samplers: beta_dists,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -148,6 +148,22 @@ where
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||
enum DirichletRepr<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Dirichlet distribution that generates samples using the gamma distribution.
|
||||
FromGamma(DirichletFromGamma<F>),
|
||||
|
||||
/// Dirichlet distribution that generates samples using the beta distribution.
|
||||
FromBeta(DirichletFromBeta<F>),
|
||||
}
|
||||
|
||||
/// The Dirichlet distribution `Dirichlet(alpha)`.
|
||||
///
|
||||
/// The Dirichlet distribution is a family of continuous multivariate
|
||||
@ -167,18 +183,14 @@ where
|
||||
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
||||
#[derive(Clone, Debug, PartialEq)]
|
||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||
pub enum Dirichlet<F>
|
||||
pub struct Dirichlet<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Dirichlet distribution that generates samples using the gamma distribution.
|
||||
FromGamma(DirichletFromGamma<F>),
|
||||
|
||||
/// Dirichlet distribution that generates samples using the beta distribution.
|
||||
FromBeta(DirichletFromBeta<F>),
|
||||
repr: DirichletRepr<F>,
|
||||
}
|
||||
|
||||
/// Error type returned from `Dirchlet::new`.
|
||||
@ -231,9 +243,13 @@ where
|
||||
|
||||
if alpha.iter().all(|x| *x <= NumCast::from(0.1).unwrap()) {
|
||||
// All the values in alpha are less than 0.1.
|
||||
Ok(DirichletFromBeta::new(alpha))
|
||||
Ok(Dirichlet {
|
||||
repr: DirichletRepr::FromBeta(DirichletFromBeta::new(alpha)),
|
||||
})
|
||||
} else {
|
||||
Ok(DirichletFromGamma::new(alpha))
|
||||
Ok(Dirichlet {
|
||||
repr: DirichletRepr::FromGamma(DirichletFromGamma::new(alpha)),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -260,9 +276,9 @@ where
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Vec<F> {
|
||||
match self {
|
||||
Dirichlet::FromGamma(dirichlet) => dirichlet.sample(rng),
|
||||
Dirichlet::FromBeta(dirichlet) => dirichlet.sample(rng),
|
||||
match &self.repr {
|
||||
DirichletRepr::FromGamma(dirichlet) => dirichlet.sample(rng),
|
||||
DirichletRepr::FromBeta(dirichlet) => dirichlet.sample(rng),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user