Fix pert for mode approx eq mean; use builder pattern (#1452)
- Fix #1311 (mode close to mean) - Use a builder pattern, allowing specification via mode OR mean
This commit is contained in:
parent
d17ce4e0a1
commit
2584f48ace
@ -5,9 +5,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
|
||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||
|
||||
## Unreleased
|
||||
|
||||
### Added
|
||||
- Add plots for `rand_distr` distributions to documentation (#1434)
|
||||
- Add `PertBuilder`, fix case where mode ≅ mean (#1452)
|
||||
|
||||
## [0.5.0-alpha.1] - 2024-03-18
|
||||
- Target `rand` version `0.9.0-alpha.1`
|
||||
|
@ -117,7 +117,7 @@ pub use self::normal_inverse_gaussian::{
|
||||
Error as NormalInverseGaussianError, NormalInverseGaussian,
|
||||
};
|
||||
pub use self::pareto::{Error as ParetoError, Pareto};
|
||||
pub use self::pert::{Pert, PertError};
|
||||
pub use self::pert::{Pert, PertBuilder, PertError};
|
||||
pub use self::poisson::{Error as PoissonError, Poisson};
|
||||
pub use self::skew_normal::{Error as SkewNormalError, SkewNormal};
|
||||
pub use self::triangular::{Triangular, TriangularError};
|
||||
|
@ -31,7 +31,7 @@ use rand::Rng;
|
||||
/// ```rust
|
||||
/// use rand_distr::{Pert, Distribution};
|
||||
///
|
||||
/// let d = Pert::new(0., 5., 2.5).unwrap();
|
||||
/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap();
|
||||
/// let v = d.sample(&mut rand::thread_rng());
|
||||
/// println!("{} is from a PERT distribution", v);
|
||||
/// ```
|
||||
@ -82,35 +82,75 @@ where
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Set up the PERT distribution with defined `min`, `max` and `mode`.
|
||||
/// Construct a PERT distribution with defined `min`, `max`
|
||||
///
|
||||
/// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`.
|
||||
/// # Example
|
||||
///
|
||||
/// ```
|
||||
/// use rand_distr::Pert;
|
||||
/// let pert_dist = Pert::new(0.0, 10.0)
|
||||
/// .with_shape(3.5)
|
||||
/// .with_mean(3.0)
|
||||
/// .unwrap();
|
||||
/// # let _unused: Pert<f64> = pert_dist;
|
||||
/// ```
|
||||
#[allow(clippy::new_ret_no_self)]
|
||||
#[inline]
|
||||
pub fn new(min: F, max: F, mode: F) -> Result<Pert<F>, PertError> {
|
||||
Pert::new_with_shape(min, max, mode, F::from(4.).unwrap())
|
||||
pub fn new(min: F, max: F) -> PertBuilder<F> {
|
||||
let shape = F::from(4.0).unwrap();
|
||||
PertBuilder { min, max, shape }
|
||||
}
|
||||
}
|
||||
|
||||
/// Struct used to build a [`Pert`]
|
||||
#[derive(Debug)]
|
||||
pub struct PertBuilder<F> {
|
||||
min: F,
|
||||
max: F,
|
||||
shape: F,
|
||||
}
|
||||
|
||||
impl<F> PertBuilder<F>
|
||||
where
|
||||
F: Float,
|
||||
StandardNormal: Distribution<F>,
|
||||
Exp1: Distribution<F>,
|
||||
Open01: Distribution<F>,
|
||||
{
|
||||
/// Set the shape parameter
|
||||
///
|
||||
/// If not specified, this defaults to 4.
|
||||
#[inline]
|
||||
pub fn with_shape(mut self, shape: F) -> PertBuilder<F> {
|
||||
self.shape = shape;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set up the PERT distribution with defined `min`, `max`, `mode` and
|
||||
/// `shape`.
|
||||
pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result<Pert<F>, PertError> {
|
||||
if !(max > min) {
|
||||
/// Specify the mean
|
||||
#[inline]
|
||||
pub fn with_mean(self, mean: F) -> Result<Pert<F>, PertError> {
|
||||
let two = F::from(2.0).unwrap();
|
||||
let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape;
|
||||
self.with_mode(mode)
|
||||
}
|
||||
|
||||
/// Specify the mode
|
||||
#[inline]
|
||||
pub fn with_mode(self, mode: F) -> Result<Pert<F>, PertError> {
|
||||
if !(self.max > self.min) {
|
||||
return Err(PertError::RangeTooSmall);
|
||||
}
|
||||
if !(mode >= min && max >= mode) {
|
||||
if !(mode >= self.min && self.max >= mode) {
|
||||
return Err(PertError::ModeRange);
|
||||
}
|
||||
if !(shape >= F::from(0.).unwrap()) {
|
||||
if !(self.shape >= F::from(0.).unwrap()) {
|
||||
return Err(PertError::ShapeTooSmall);
|
||||
}
|
||||
|
||||
let (min, max, shape) = (self.min, self.max, self.shape);
|
||||
let range = max - min;
|
||||
let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap());
|
||||
let v = if mu == mode {
|
||||
shape * F::from(0.5).unwrap() + F::from(1.).unwrap()
|
||||
} else {
|
||||
(mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min))
|
||||
};
|
||||
let w = v * (max - mu) / (mu - min);
|
||||
let v = F::from(1.0).unwrap() + shape * (mode - min) / range;
|
||||
let w = F::from(1.0).unwrap() + shape * (max - mode) / range;
|
||||
let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?;
|
||||
Ok(Pert { min, range, beta })
|
||||
}
|
||||
@ -136,17 +176,38 @@ mod test {
|
||||
#[test]
|
||||
fn test_pert() {
|
||||
for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
|
||||
let _distr = Pert::new(min, max, mode).unwrap();
|
||||
let _distr = Pert::new(min, max).with_mode(mode).unwrap();
|
||||
// TODO: test correctness
|
||||
}
|
||||
|
||||
for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
|
||||
assert!(Pert::new(min, max, mode).is_err());
|
||||
assert!(Pert::new(min, max).with_mode(mode).is_err());
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pert_distributions_can_be_compared() {
|
||||
assert_eq!(Pert::new(1.0, 3.0, 2.0), Pert::new(1.0, 3.0, 2.0));
|
||||
fn distributions_can_be_compared() {
|
||||
let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0);
|
||||
let p1 = Pert::new(min, max).with_mode(mode).unwrap();
|
||||
let mean = (min + shape * mode + max) / (shape + 2.0);
|
||||
let p2 = Pert::new(min, max).with_mean(mean).unwrap();
|
||||
assert_eq!(p1, p2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn mode_almost_half_range() {
|
||||
assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn almost_symmetric_about_zero() {
|
||||
let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON);
|
||||
assert!(distr.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn almost_symmetric() {
|
||||
let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON);
|
||||
assert!(distr.is_ok());
|
||||
}
|
||||
}
|
||||
|
@ -250,7 +250,7 @@ fn pert_stability() {
|
||||
// mean = 4, var = 12/7
|
||||
test_samples(
|
||||
860,
|
||||
Pert::new(2., 10., 3.).unwrap(),
|
||||
Pert::new(2., 10.).with_mode(3.).unwrap(),
|
||||
&[
|
||||
4.908681667460367,
|
||||
4.014196196158352,
|
||||
|
Loading…
x
Reference in New Issue
Block a user