diff --git a/CHANGELOG.md b/CHANGELOG.md index 4945f5e8..e8baa39e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update. ### Additions - Use const-generics to support arrays of all sizes (#1104) - Implement `Clone` and `Copy` for `Alphanumeric` (#1126) +- Add `Distribution::map` to derive a distribution using a closure (#1129) ### Other - Reorder asserts in `Uniform` float distributions for easier debugging of non-finite arguments diff --git a/src/distributions/mod.rs b/src/distributions/mod.rs index 8171e30e..d1ae30c9 100644 --- a/src/distributions/mod.rs +++ b/src/distributions/mod.rs @@ -199,6 +199,35 @@ pub trait Distribution { phantom: ::core::marker::PhantomData, } } + + /// Create a distribution of values of 'S' by mapping the output of `Self` + /// through the closure `F` + /// + /// # Example + /// + /// ``` + /// use rand::thread_rng; + /// use rand::distributions::{Distribution, Uniform}; + /// + /// let mut rng = thread_rng(); + /// + /// let die = Uniform::new_inclusive(1, 6); + /// let even_number = die.map(|num| num % 2 == 0); + /// while !even_number.sample(&mut rng) { + /// println!("Still odd; rolling again!"); + /// } + /// ``` + fn map(self, func: F) -> DistMap + where + F: Fn(T) -> S, + Self: Sized, + { + DistMap { + distr: self, + func, + phantom: ::core::marker::PhantomData, + } + } } impl<'a, T, D: Distribution> Distribution for &'a D { @@ -256,6 +285,28 @@ where { } +/// A distribution of values of type `S` derived from the distribution `D` +/// by mapping its output of type `T` through the closure `F`. +/// +/// This `struct` is created by the [`Distribution::map`] method. +/// See its documentation for more. +#[derive(Debug)] +pub struct DistMap { + distr: D, + func: F, + phantom: ::core::marker::PhantomData S>, +} + +impl Distribution for DistMap +where + D: Distribution, + F: Fn(T) -> S, +{ + fn sample(&self, rng: &mut R) -> S { + (self.func)(self.distr.sample(rng)) + } +} + /// A generic random value distribution, implemented for many primitive types. /// Usually generates values with a numerically uniform distribution, and with a /// range appropriate to the type. @@ -360,6 +411,15 @@ mod tests { assert!(0. < sum && sum < 100.); } + #[test] + fn test_distributions_map() { + let dist = Uniform::new_inclusive(0, 5).map(|val| val + 15); + + let mut rng = crate::test::rng(212); + let val = dist.sample(&mut rng); + assert!(val >= 15 && val <= 20); + } + #[test] fn test_make_an_iter() { fn ten_dice_rolls_other_than_five(