Fix pert for mode approx eq mean; use builder pattern (#1452)
- Fix #1311 (mode close to mean) - Use a builder pattern, allowing specification via mode OR mean
This commit is contained in:
parent
d17ce4e0a1
commit
2584f48ace
@ -5,9 +5,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
|||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
## Unreleased
|
## Unreleased
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
- Add plots for `rand_distr` distributions to documentation (#1434)
|
- Add plots for `rand_distr` distributions to documentation (#1434)
|
||||||
|
- Add `PertBuilder`, fix case where mode ≅ mean (#1452)
|
||||||
|
|
||||||
## [0.5.0-alpha.1] - 2024-03-18
|
## [0.5.0-alpha.1] - 2024-03-18
|
||||||
- Target `rand` version `0.9.0-alpha.1`
|
- Target `rand` version `0.9.0-alpha.1`
|
||||||
|
@ -117,7 +117,7 @@ pub use self::normal_inverse_gaussian::{
|
|||||||
Error as NormalInverseGaussianError, NormalInverseGaussian,
|
Error as NormalInverseGaussianError, NormalInverseGaussian,
|
||||||
};
|
};
|
||||||
pub use self::pareto::{Error as ParetoError, Pareto};
|
pub use self::pareto::{Error as ParetoError, Pareto};
|
||||||
pub use self::pert::{Pert, PertError};
|
pub use self::pert::{Pert, PertBuilder, PertError};
|
||||||
pub use self::poisson::{Error as PoissonError, Poisson};
|
pub use self::poisson::{Error as PoissonError, Poisson};
|
||||||
pub use self::skew_normal::{Error as SkewNormalError, SkewNormal};
|
pub use self::skew_normal::{Error as SkewNormalError, SkewNormal};
|
||||||
pub use self::triangular::{Triangular, TriangularError};
|
pub use self::triangular::{Triangular, TriangularError};
|
||||||
|
@ -31,7 +31,7 @@ use rand::Rng;
|
|||||||
/// ```rust
|
/// ```rust
|
||||||
/// use rand_distr::{Pert, Distribution};
|
/// use rand_distr::{Pert, Distribution};
|
||||||
///
|
///
|
||||||
/// let d = Pert::new(0., 5., 2.5).unwrap();
|
/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap();
|
||||||
/// let v = d.sample(&mut rand::thread_rng());
|
/// let v = d.sample(&mut rand::thread_rng());
|
||||||
/// println!("{} is from a PERT distribution", v);
|
/// println!("{} is from a PERT distribution", v);
|
||||||
/// ```
|
/// ```
|
||||||
@ -82,35 +82,75 @@ where
|
|||||||
Exp1: Distribution<F>,
|
Exp1: Distribution<F>,
|
||||||
Open01: Distribution<F>,
|
Open01: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Set up the PERT distribution with defined `min`, `max` and `mode`.
|
/// Construct a PERT distribution with defined `min`, `max`
|
||||||
///
|
///
|
||||||
/// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`.
|
/// # Example
|
||||||
|
///
|
||||||
|
/// ```
|
||||||
|
/// use rand_distr::Pert;
|
||||||
|
/// let pert_dist = Pert::new(0.0, 10.0)
|
||||||
|
/// .with_shape(3.5)
|
||||||
|
/// .with_mean(3.0)
|
||||||
|
/// .unwrap();
|
||||||
|
/// # let _unused: Pert<f64> = pert_dist;
|
||||||
|
/// ```
|
||||||
|
#[allow(clippy::new_ret_no_self)]
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new(min: F, max: F, mode: F) -> Result<Pert<F>, PertError> {
|
pub fn new(min: F, max: F) -> PertBuilder<F> {
|
||||||
Pert::new_with_shape(min, max, mode, F::from(4.).unwrap())
|
let shape = F::from(4.0).unwrap();
|
||||||
|
PertBuilder { min, max, shape }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Struct used to build a [`Pert`]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct PertBuilder<F> {
|
||||||
|
min: F,
|
||||||
|
max: F,
|
||||||
|
shape: F,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<F> PertBuilder<F>
|
||||||
|
where
|
||||||
|
F: Float,
|
||||||
|
StandardNormal: Distribution<F>,
|
||||||
|
Exp1: Distribution<F>,
|
||||||
|
Open01: Distribution<F>,
|
||||||
|
{
|
||||||
|
/// Set the shape parameter
|
||||||
|
///
|
||||||
|
/// If not specified, this defaults to 4.
|
||||||
|
#[inline]
|
||||||
|
pub fn with_shape(mut self, shape: F) -> PertBuilder<F> {
|
||||||
|
self.shape = shape;
|
||||||
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set up the PERT distribution with defined `min`, `max`, `mode` and
|
/// Specify the mean
|
||||||
/// `shape`.
|
#[inline]
|
||||||
pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result<Pert<F>, PertError> {
|
pub fn with_mean(self, mean: F) -> Result<Pert<F>, PertError> {
|
||||||
if !(max > min) {
|
let two = F::from(2.0).unwrap();
|
||||||
|
let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape;
|
||||||
|
self.with_mode(mode)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Specify the mode
|
||||||
|
#[inline]
|
||||||
|
pub fn with_mode(self, mode: F) -> Result<Pert<F>, PertError> {
|
||||||
|
if !(self.max > self.min) {
|
||||||
return Err(PertError::RangeTooSmall);
|
return Err(PertError::RangeTooSmall);
|
||||||
}
|
}
|
||||||
if !(mode >= min && max >= mode) {
|
if !(mode >= self.min && self.max >= mode) {
|
||||||
return Err(PertError::ModeRange);
|
return Err(PertError::ModeRange);
|
||||||
}
|
}
|
||||||
if !(shape >= F::from(0.).unwrap()) {
|
if !(self.shape >= F::from(0.).unwrap()) {
|
||||||
return Err(PertError::ShapeTooSmall);
|
return Err(PertError::ShapeTooSmall);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let (min, max, shape) = (self.min, self.max, self.shape);
|
||||||
let range = max - min;
|
let range = max - min;
|
||||||
let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap());
|
let v = F::from(1.0).unwrap() + shape * (mode - min) / range;
|
||||||
let v = if mu == mode {
|
let w = F::from(1.0).unwrap() + shape * (max - mode) / range;
|
||||||
shape * F::from(0.5).unwrap() + F::from(1.).unwrap()
|
|
||||||
} else {
|
|
||||||
(mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min))
|
|
||||||
};
|
|
||||||
let w = v * (max - mu) / (mu - min);
|
|
||||||
let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
|
let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
|
||||||
Ok(Pert { min, range, beta })
|
Ok(Pert { min, range, beta })
|
||||||
}
|
}
|
||||||
@ -136,17 +176,38 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_pert() {
|
fn test_pert() {
|
||||||
for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
|
for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
|
||||||
let _distr = Pert::new(min, max, mode).unwrap();
|
let _distr = Pert::new(min, max).with_mode(mode).unwrap();
|
||||||
// TODO: test correctness
|
// TODO: test correctness
|
||||||
}
|
}
|
||||||
|
|
||||||
for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
|
for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
|
||||||
assert!(Pert::new(min, max, mode).is_err());
|
assert!(Pert::new(min, max).with_mode(mode).is_err());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pert_distributions_can_be_compared() {
|
fn distributions_can_be_compared() {
|
||||||
assert_eq!(Pert::new(1.0, 3.0, 2.0), Pert::new(1.0, 3.0, 2.0));
|
let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0);
|
||||||
|
let p1 = Pert::new(min, max).with_mode(mode).unwrap();
|
||||||
|
let mean = (min + shape * mode + max) / (shape + 2.0);
|
||||||
|
let p2 = Pert::new(min, max).with_mean(mean).unwrap();
|
||||||
|
assert_eq!(p1, p2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn mode_almost_half_range() {
|
||||||
|
assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn almost_symmetric_about_zero() {
|
||||||
|
let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON);
|
||||||
|
assert!(distr.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn almost_symmetric() {
|
||||||
|
let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON);
|
||||||
|
assert!(distr.is_ok());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -250,7 +250,7 @@ fn pert_stability() {
|
|||||||
// mean = 4, var = 12/7
|
// mean = 4, var = 12/7
|
||||||
test_samples(
|
test_samples(
|
||||||
860,
|
860,
|
||||||
Pert::new(2., 10., 3.).unwrap(),
|
Pert::new(2., 10.).with_mode(3.).unwrap(),
|
||||||
&[
|
&[
|
||||||
4.908681667460367,
|
4.908681667460367,
|
||||||
4.014196196158352,
|
4.014196196158352,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user