diff --git a/rand_distr/CHANGELOG.md b/rand_distr/CHANGELOG.md index cab75915..570474be 100644 --- a/rand_distr/CHANGELOG.md +++ b/rand_distr/CHANGELOG.md @@ -5,9 +5,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/) and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). ## Unreleased - ### Added - Add plots for `rand_distr` distributions to documentation (#1434) +- Add `PertBuilder`, fix case where mode ≅ mean (#1452) ## [0.5.0-alpha.1] - 2024-03-18 - Target `rand` version `0.9.0-alpha.1` diff --git a/rand_distr/src/lib.rs b/rand_distr/src/lib.rs index a3852256..f6f3ad54 100644 --- a/rand_distr/src/lib.rs +++ b/rand_distr/src/lib.rs @@ -117,7 +117,7 @@ pub use self::normal_inverse_gaussian::{ Error as NormalInverseGaussianError, NormalInverseGaussian, }; pub use self::pareto::{Error as ParetoError, Pareto}; -pub use self::pert::{Pert, PertError}; +pub use self::pert::{Pert, PertBuilder, PertError}; pub use self::poisson::{Error as PoissonError, Poisson}; pub use self::skew_normal::{Error as SkewNormalError, SkewNormal}; pub use self::triangular::{Triangular, TriangularError}; diff --git a/rand_distr/src/pert.rs b/rand_distr/src/pert.rs index df5361d7..ae268dad 100644 --- a/rand_distr/src/pert.rs +++ b/rand_distr/src/pert.rs @@ -31,7 +31,7 @@ use rand::Rng; /// ```rust /// use rand_distr::{Pert, Distribution}; /// -/// let d = Pert::new(0., 5., 2.5).unwrap(); +/// let d = Pert::new(0., 5.).with_mode(2.5).unwrap(); /// let v = d.sample(&mut rand::thread_rng()); /// println!("{} is from a PERT distribution", v); /// ``` @@ -82,35 +82,75 @@ where Exp1: Distribution, Open01: Distribution, { - /// Set up the PERT distribution with defined `min`, `max` and `mode`. + /// Construct a PERT distribution with defined `min`, `max` /// - /// This is equivalent to calling `Pert::new_with_shape` with `shape == 4.0`. + /// # Example + /// + /// ``` + /// use rand_distr::Pert; + /// let pert_dist = Pert::new(0.0, 10.0) + /// .with_shape(3.5) + /// .with_mean(3.0) + /// .unwrap(); + /// # let _unused: Pert = pert_dist; + /// ``` + #[allow(clippy::new_ret_no_self)] #[inline] - pub fn new(min: F, max: F, mode: F) -> Result, PertError> { - Pert::new_with_shape(min, max, mode, F::from(4.).unwrap()) + pub fn new(min: F, max: F) -> PertBuilder { + let shape = F::from(4.0).unwrap(); + PertBuilder { min, max, shape } + } +} + +/// Struct used to build a [`Pert`] +#[derive(Debug)] +pub struct PertBuilder { + min: F, + max: F, + shape: F, +} + +impl PertBuilder +where + F: Float, + StandardNormal: Distribution, + Exp1: Distribution, + Open01: Distribution, +{ + /// Set the shape parameter + /// + /// If not specified, this defaults to 4. + #[inline] + pub fn with_shape(mut self, shape: F) -> PertBuilder { + self.shape = shape; + self } - /// Set up the PERT distribution with defined `min`, `max`, `mode` and - /// `shape`. - pub fn new_with_shape(min: F, max: F, mode: F, shape: F) -> Result, PertError> { - if !(max > min) { + /// Specify the mean + #[inline] + pub fn with_mean(self, mean: F) -> Result, PertError> { + let two = F::from(2.0).unwrap(); + let mode = ((self.shape + two) * mean - self.min - self.max) / self.shape; + self.with_mode(mode) + } + + /// Specify the mode + #[inline] + pub fn with_mode(self, mode: F) -> Result, PertError> { + if !(self.max > self.min) { return Err(PertError::RangeTooSmall); } - if !(mode >= min && max >= mode) { + if !(mode >= self.min && self.max >= mode) { return Err(PertError::ModeRange); } - if !(shape >= F::from(0.).unwrap()) { + if !(self.shape >= F::from(0.).unwrap()) { return Err(PertError::ShapeTooSmall); } + let (min, max, shape) = (self.min, self.max, self.shape); let range = max - min; - let mu = (min + max + shape * mode) / (shape + F::from(2.).unwrap()); - let v = if mu == mode { - shape * F::from(0.5).unwrap() + F::from(1.).unwrap() - } else { - (mu - min) * (F::from(2.).unwrap() * mode - min - max) / ((mode - mu) * (max - min)) - }; - let w = v * (max - mu) / (mu - min); + let v = F::from(1.0).unwrap() + shape * (mode - min) / range; + let w = F::from(1.0).unwrap() + shape * (max - mode) / range; let beta = Beta::new(v, w).map_err(|_| PertError::RangeTooSmall)?; Ok(Pert { min, range, beta }) } @@ -136,17 +176,38 @@ mod test { #[test] fn test_pert() { for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] { - let _distr = Pert::new(min, max, mode).unwrap(); + let _distr = Pert::new(min, max).with_mode(mode).unwrap(); // TODO: test correctness } for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] { - assert!(Pert::new(min, max, mode).is_err()); + assert!(Pert::new(min, max).with_mode(mode).is_err()); } } #[test] - fn pert_distributions_can_be_compared() { - assert_eq!(Pert::new(1.0, 3.0, 2.0), Pert::new(1.0, 3.0, 2.0)); + fn distributions_can_be_compared() { + let (min, mode, max, shape) = (1.0, 2.0, 3.0, 4.0); + let p1 = Pert::new(min, max).with_mode(mode).unwrap(); + let mean = (min + shape * mode + max) / (shape + 2.0); + let p2 = Pert::new(min, max).with_mean(mean).unwrap(); + assert_eq!(p1, p2); + } + + #[test] + fn mode_almost_half_range() { + assert!(Pert::new(0.0f32, 0.48258883).with_mode(0.24129441).is_ok()); + } + + #[test] + fn almost_symmetric_about_zero() { + let distr = Pert::new(-10f32, 10f32).with_mode(f32::EPSILON); + assert!(distr.is_ok()); + } + + #[test] + fn almost_symmetric() { + let distr = Pert::new(0f32, 2f32).with_mode(1f32 + f32::EPSILON); + assert!(distr.is_ok()); } } diff --git a/rand_distr/tests/value_stability.rs b/rand_distr/tests/value_stability.rs index 31bfce52..b142741e 100644 --- a/rand_distr/tests/value_stability.rs +++ b/rand_distr/tests/value_stability.rs @@ -250,7 +250,7 @@ fn pert_stability() { // mean = 4, var = 12/7 test_samples( 860, - Pert::new(2., 10., 3.).unwrap(), + Pert::new(2., 10.).with_mode(3.).unwrap(), &[ 4.908681667460367, 4.014196196158352,