Apply rustfmt and fix Clippy warnings (#1448)

This commit is contained in:
Artyom Pavlov 2024-05-09 09:50:08 +03:00 committed by GitHub
parent e93776960e
commit 1b762b2867
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
64 changed files with 1533 additions and 952 deletions

23
.github/workflows/benches.yml vendored Normal file
View File

@ -0,0 +1,23 @@
name: Benches
on:
pull_request:
paths:
- ".github/workflows/benches.yml"
- "benches/**"
jobs:
benches:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@master
with:
toolchain: nightly
components: clippy, rustfmt
- name: Rustfmt
run: cargo fmt --all -- --check
- name: Clippy
run: cargo clippy --all --all-targets -- -D warnings
- name: Build
run: RUSTFLAGS=-Dwarnings cargo build --all-targets

View File

@ -72,8 +72,6 @@ jobs:
if: ${{ matrix.variant == 'minimal_versions' }} if: ${{ matrix.variant == 'minimal_versions' }}
run: | run: |
cargo generate-lockfile -Z minimal-versions cargo generate-lockfile -Z minimal-versions
# Overrides for dependencies with incorrect requirements (may need periodic updating)
cargo update -p regex --precise 1.5.1
- name: Maybe nightly - name: Maybe nightly
if: ${{ matrix.toolchain == 'nightly' }} if: ${{ matrix.toolchain == 'nightly' }}
run: | run: |

33
.github/workflows/workspace.yml vendored Normal file
View File

@ -0,0 +1,33 @@
name: Workspace
on:
pull_request:
paths-ignore:
- README.md
- "benches/**"
push:
branches: master
paths-ignore:
- README.md
- "benches/**"
jobs:
clippy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@master
with:
toolchain: 1.78.0
components: clippy
- run: cargo clippy --all --all-targets -- -D warnings
rustfmt:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: dtolnay/rust-toolchain@master
with:
toolchain: stable
components: rustfmt
- run: cargo fmt --all -- --check

View File

@ -58,12 +58,12 @@ unbiased = []
[workspace] [workspace]
members = [ members = [
"benches",
"rand_core", "rand_core",
"rand_distr", "rand_distr",
"rand_chacha", "rand_chacha",
"rand_pcg", "rand_pcg",
] ]
exclude = ["benches"]
[dependencies] [dependencies]
rand_core = { path = "rand_core", version = "=0.9.0-alpha.1", default-features = false } rand_core = { path = "rand_core", version = "=0.9.0-alpha.1", default-features = false }

View File

@ -50,7 +50,6 @@ gen_bytes!(gen_bytes_chacha8, ChaCha8Rng::from_os_rng());
gen_bytes!(gen_bytes_chacha12, ChaCha12Rng::from_os_rng()); gen_bytes!(gen_bytes_chacha12, ChaCha12Rng::from_os_rng());
gen_bytes!(gen_bytes_chacha20, ChaCha20Rng::from_os_rng()); gen_bytes!(gen_bytes_chacha20, ChaCha20Rng::from_os_rng());
gen_bytes!(gen_bytes_std, StdRng::from_os_rng()); gen_bytes!(gen_bytes_std, StdRng::from_os_rng());
#[cfg(feature = "small_rng")]
gen_bytes!(gen_bytes_small, SmallRng::from_thread_rng()); gen_bytes!(gen_bytes_small, SmallRng::from_thread_rng());
gen_bytes!(gen_bytes_os, UnwrapErr(OsRng)); gen_bytes!(gen_bytes_os, UnwrapErr(OsRng));
gen_bytes!(gen_bytes_thread, thread_rng()); gen_bytes!(gen_bytes_thread, thread_rng());
@ -81,7 +80,6 @@ gen_uint!(gen_u32_chacha8, u32, ChaCha8Rng::from_os_rng());
gen_uint!(gen_u32_chacha12, u32, ChaCha12Rng::from_os_rng()); gen_uint!(gen_u32_chacha12, u32, ChaCha12Rng::from_os_rng());
gen_uint!(gen_u32_chacha20, u32, ChaCha20Rng::from_os_rng()); gen_uint!(gen_u32_chacha20, u32, ChaCha20Rng::from_os_rng());
gen_uint!(gen_u32_std, u32, StdRng::from_os_rng()); gen_uint!(gen_u32_std, u32, StdRng::from_os_rng());
#[cfg(feature = "small_rng")]
gen_uint!(gen_u32_small, u32, SmallRng::from_thread_rng()); gen_uint!(gen_u32_small, u32, SmallRng::from_thread_rng());
gen_uint!(gen_u32_os, u32, UnwrapErr(OsRng)); gen_uint!(gen_u32_os, u32, UnwrapErr(OsRng));
gen_uint!(gen_u32_thread, u32, thread_rng()); gen_uint!(gen_u32_thread, u32, thread_rng());
@ -95,7 +93,6 @@ gen_uint!(gen_u64_chacha8, u64, ChaCha8Rng::from_os_rng());
gen_uint!(gen_u64_chacha12, u64, ChaCha12Rng::from_os_rng()); gen_uint!(gen_u64_chacha12, u64, ChaCha12Rng::from_os_rng());
gen_uint!(gen_u64_chacha20, u64, ChaCha20Rng::from_os_rng()); gen_uint!(gen_u64_chacha20, u64, ChaCha20Rng::from_os_rng());
gen_uint!(gen_u64_std, u64, StdRng::from_os_rng()); gen_uint!(gen_u64_std, u64, StdRng::from_os_rng());
#[cfg(feature = "small_rng")]
gen_uint!(gen_u64_small, u64, SmallRng::from_thread_rng()); gen_uint!(gen_u64_small, u64, SmallRng::from_thread_rng());
gen_uint!(gen_u64_os, u64, UnwrapErr(OsRng)); gen_uint!(gen_u64_os, u64, UnwrapErr(OsRng));
gen_uint!(gen_u64_thread, u64, thread_rng()); gen_uint!(gen_u64_thread, u64, thread_rng());

View File

@ -8,8 +8,10 @@
//! The ChaCha random number generator. //! The ChaCha random number generator.
#[cfg(not(feature = "std"))] use core; #[cfg(not(feature = "std"))]
#[cfg(feature = "std")] use std as core; use core;
#[cfg(feature = "std")]
use std as core;
use self::core::fmt; use self::core::fmt;
use crate::guts::ChaCha; use crate::guts::ChaCha;
@ -27,7 +29,8 @@ const BLOCK_WORDS: u8 = 16;
#[repr(transparent)] #[repr(transparent)]
pub struct Array64<T>([T; 64]); pub struct Array64<T>([T; 64]);
impl<T> Default for Array64<T> impl<T> Default for Array64<T>
where T: Default where
T: Default,
{ {
#[rustfmt::skip] #[rustfmt::skip]
fn default() -> Self { fn default() -> Self {
@ -54,7 +57,8 @@ impl<T> AsMut<[T]> for Array64<T> {
} }
} }
impl<T> Clone for Array64<T> impl<T> Clone for Array64<T>
where T: Copy + Default where
T: Copy + Default,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
let mut new = Self::default(); let mut new = Self::default();
@ -275,20 +279,25 @@ macro_rules! chacha_impl {
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
impl Serialize for $ChaChaXRng { impl Serialize for $ChaChaXRng {
fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error> fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
where S: Serializer { where
S: Serializer,
{
$abst::$ChaChaXRng::from(self).serialize(s) $abst::$ChaChaXRng::from(self).serialize(s)
} }
} }
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
impl<'de> Deserialize<'de> for $ChaChaXRng { impl<'de> Deserialize<'de> for $ChaChaXRng {
fn deserialize<D>(d: D) -> Result<Self, D::Error> fn deserialize<D>(d: D) -> Result<Self, D::Error>
where D: Deserializer<'de> { where
D: Deserializer<'de>,
{
$abst::$ChaChaXRng::deserialize(d).map(|x| Self::from(&x)) $abst::$ChaChaXRng::deserialize(d).map(|x| Self::from(&x))
} }
} }
mod $abst { mod $abst {
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; #[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
// The abstract state of a ChaCha stream, independent of implementation choices. The // The abstract state of a ChaCha stream, independent of implementation choices. The
// comparison and serialization of this object is considered a semver-covered part of // comparison and serialization of this object is considered a semver-covered part of
@ -353,7 +362,8 @@ chacha_impl!(
mod test { mod test {
use rand_core::{RngCore, SeedableRng}; use rand_core::{RngCore, SeedableRng};
#[cfg(feature = "serde1")] use super::{ChaCha12Rng, ChaCha20Rng, ChaCha8Rng}; #[cfg(feature = "serde1")]
use super::{ChaCha12Rng, ChaCha20Rng, ChaCha8Rng};
type ChaChaRng = super::ChaCha20Rng; type ChaChaRng = super::ChaCha20Rng;

View File

@ -12,7 +12,9 @@
use ppv_lite86::{dispatch, dispatch_light128}; use ppv_lite86::{dispatch, dispatch_light128};
pub use ppv_lite86::Machine; pub use ppv_lite86::Machine;
use ppv_lite86::{vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector}; use ppv_lite86::{
vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector,
};
pub(crate) const BLOCK: usize = 16; pub(crate) const BLOCK: usize = 16;
pub(crate) const BLOCK64: u64 = BLOCK as u64; pub(crate) const BLOCK64: u64 = BLOCK as u64;
@ -140,14 +142,18 @@ fn add_pos<Mach: Machine>(m: Mach, d: Mach::u32x4, i: u64) -> Mach::u32x4 {
#[cfg(target_endian = "little")] #[cfg(target_endian = "little")]
fn d0123<Mach: Machine>(m: Mach, d: vec128_storage) -> Mach::u32x4x4 { fn d0123<Mach: Machine>(m: Mach, d: vec128_storage) -> Mach::u32x4x4 {
let d0: Mach::u64x2 = m.unpack(d); let d0: Mach::u64x2 = m.unpack(d);
let incr = Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]); let incr =
Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]);
m.unpack((Mach::u64x2x4::from_lanes([d0, d0, d0, d0]) + incr).into()) m.unpack((Mach::u64x2x4::from_lanes([d0, d0, d0, d0]) + incr).into())
} }
#[allow(clippy::many_single_char_names)] #[allow(clippy::many_single_char_names)]
#[inline(always)] #[inline(always)]
fn refill_wide_impl<Mach: Machine>( fn refill_wide_impl<Mach: Machine>(
m: Mach, state: &mut ChaCha, drounds: u32, out: &mut [u32; BUFSZ], m: Mach,
state: &mut ChaCha,
drounds: u32,
out: &mut [u32; BUFSZ],
) { ) {
let k = m.vec([0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574]); let k = m.vec([0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574]);
let b = m.unpack(state.b); let b = m.unpack(state.b);

View File

@ -1,4 +1,5 @@
#[cfg(feature = "alloc")] use alloc::boxed::Box; #[cfg(feature = "alloc")]
use alloc::boxed::Box;
use crate::{CryptoRng, RngCore, TryCryptoRng, TryRngCore}; use crate::{CryptoRng, RngCore, TryCryptoRng, TryRngCore};

View File

@ -56,7 +56,8 @@
use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks}; use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks};
use crate::{CryptoRng, RngCore, SeedableRng, TryRngCore}; use crate::{CryptoRng, RngCore, SeedableRng, TryRngCore};
use core::fmt; use core::fmt;
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; #[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// A trait for RNGs which do not generate random numbers individually, but in /// A trait for RNGs which do not generate random numbers individually, but in
/// blocks (typically `[u32; N]`). This technique is commonly used by /// blocks (typically `[u32; N]`). This technique is commonly used by

View File

@ -199,7 +199,7 @@ macro_rules! impl_try_rng_from_rng_core {
macro_rules! impl_try_crypto_rng_from_crypto_rng { macro_rules! impl_try_crypto_rng_from_crypto_rng {
($t:ty) => { ($t:ty) => {
$crate::impl_try_rng_from_rng_core!($t); $crate::impl_try_rng_from_rng_core!($t);
impl $crate::TryCryptoRng for $t {} impl $crate::TryCryptoRng for $t {}
/// Check at compile time that `$t` implements `CryptoRng` /// Check at compile time that `$t` implements `CryptoRng`

View File

@ -32,11 +32,14 @@
#![deny(missing_docs)] #![deny(missing_docs)]
#![deny(missing_debug_implementations)] #![deny(missing_debug_implementations)]
#![doc(test(attr(allow(unused_variables), deny(warnings))))] #![doc(test(attr(allow(unused_variables), deny(warnings))))]
#![allow(unexpected_cfgs)]
#![cfg_attr(doc_cfg, feature(doc_cfg))] #![cfg_attr(doc_cfg, feature(doc_cfg))]
#![no_std] #![no_std]
#[cfg(feature = "alloc")] extern crate alloc; #[cfg(feature = "alloc")]
#[cfg(feature = "std")] extern crate std; extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
use core::fmt; use core::fmt;
@ -44,11 +47,13 @@ mod blanket_impls;
pub mod block; pub mod block;
pub mod impls; pub mod impls;
pub mod le; pub mod le;
#[cfg(feature = "getrandom")] mod os; #[cfg(feature = "getrandom")]
mod os;
#[cfg(feature = "getrandom")] pub use getrandom;
#[cfg(feature = "getrandom")] pub use os::OsRng;
#[cfg(feature = "getrandom")]
pub use getrandom;
#[cfg(feature = "getrandom")]
pub use os::OsRng;
/// The core of a random number generator. /// The core of a random number generator.
/// ///
@ -213,14 +218,18 @@ pub trait TryRngCore {
/// Wrap RNG with the [`UnwrapErr`] wrapper. /// Wrap RNG with the [`UnwrapErr`] wrapper.
fn unwrap_err(self) -> UnwrapErr<Self> fn unwrap_err(self) -> UnwrapErr<Self>
where Self: Sized { where
Self: Sized,
{
UnwrapErr(self) UnwrapErr(self)
} }
/// Convert an [`RngCore`] to a [`RngReadAdapter`]. /// Convert an [`RngCore`] to a [`RngReadAdapter`].
#[cfg(feature = "std")] #[cfg(feature = "std")]
fn read_adapter(&mut self) -> RngReadAdapter<'_, Self> fn read_adapter(&mut self) -> RngReadAdapter<'_, Self>
where Self: Sized { where
Self: Sized,
{
RngReadAdapter { inner: self } RngReadAdapter { inner: self }
} }
} }

View File

@ -10,11 +10,11 @@
//! The binomial distribution. //! The binomial distribution.
use crate::{Distribution, Uniform}; use crate::{Distribution, Uniform};
use rand::Rng;
use core::fmt;
use core::cmp::Ordering; use core::cmp::Ordering;
use core::fmt;
#[allow(unused_imports)] #[allow(unused_imports)]
use num_traits::Float; use num_traits::Float;
use rand::Rng;
/// The binomial distribution `Binomial(n, p)`. /// The binomial distribution `Binomial(n, p)`.
/// ///
@ -110,21 +110,21 @@ impl Distribution<u64> for Binomial {
// Threshold for preferring the BINV algorithm. The paper suggests 10, // Threshold for preferring the BINV algorithm. The paper suggests 10,
// Ranlib uses 30, and GSL uses 14. // Ranlib uses 30, and GSL uses 14.
const BINV_THRESHOLD: f64 = 10.; const BINV_THRESHOLD: f64 = 10.;
// Same value as in GSL. // Same value as in GSL.
// It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again. // It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again.
// It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant. // It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant.
// When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away. // When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away.
const BINV_MAX_X : u64 = 110; const BINV_MAX_X: u64 = 110;
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (i32::MAX as u64) { if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (i32::MAX as u64) {
// Use the BINV algorithm. // Use the BINV algorithm.
let s = p / q; let s = p / q;
let a = ((self.n + 1) as f64) * s; let a = ((self.n + 1) as f64) * s;
result = 'outer: loop { result = 'outer: loop {
let mut r = q.powi(self.n as i32); let mut r = q.powi(self.n as i32);
let mut u: f64 = rng.gen(); let mut u: f64 = rng.random();
let mut x = 0; let mut x = 0;
while u > r { while u > r {
@ -136,7 +136,6 @@ impl Distribution<u64> for Binomial {
r *= a / (x as f64) - s; r *= a / (x as f64) - s;
} }
break x; break x;
} }
} else { } else {
// Use the BTPE algorithm. // Use the BTPE algorithm.
@ -238,7 +237,7 @@ impl Distribution<u64> for Binomial {
break; break;
} }
} }
}, }
Ordering::Greater => { Ordering::Greater => {
let mut i = y; let mut i = y;
loop { loop {
@ -248,8 +247,8 @@ impl Distribution<u64> for Binomial {
break; break;
} }
} }
}, }
Ordering::Equal => {}, Ordering::Equal => {}
} }
if v > f { if v > f {
continue; continue;
@ -366,7 +365,7 @@ mod test {
fn binomial_distributions_can_be_compared() { fn binomial_distributions_can_be_compared() {
assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0)); assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
} }
#[test] #[test]
fn binomial_avoid_infinite_loop() { fn binomial_avoid_infinite_loop() {
let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap(); let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap();

View File

@ -9,10 +9,10 @@
//! The Cauchy distribution. //! The Cauchy distribution.
use num_traits::{Float, FloatConst};
use crate::{Distribution, Standard}; use crate::{Distribution, Standard};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::{Float, FloatConst};
use rand::Rng;
/// The Cauchy distribution `Cauchy(median, scale)`. /// The Cauchy distribution `Cauchy(median, scale)`.
/// ///
@ -34,7 +34,9 @@ use core::fmt;
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Cauchy<F> pub struct Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F> where
F: Float + FloatConst,
Standard: Distribution<F>,
{ {
median: F, median: F,
scale: F, scale: F,
@ -60,7 +62,9 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
impl<F> Cauchy<F> impl<F> Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F> where
F: Float + FloatConst,
Standard: Distribution<F>,
{ {
/// Construct a new `Cauchy` with the given shape parameters /// Construct a new `Cauchy` with the given shape parameters
/// `median` the peak location and `scale` the scale factor. /// `median` the peak location and `scale` the scale factor.
@ -73,7 +77,9 @@ where F: Float + FloatConst, Standard: Distribution<F>
} }
impl<F> Distribution<F> for Cauchy<F> impl<F> Distribution<F> for Cauchy<F>
where F: Float + FloatConst, Standard: Distribution<F> where
F: Float + FloatConst,
Standard: Distribution<F>,
{ {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
// sample from [0, 1) // sample from [0, 1)
@ -138,7 +144,9 @@ mod test {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn gen_samples<F: Float + FloatConst + fmt::Debug>(m: F, s: F, buf: &mut [F]) fn gen_samples<F: Float + FloatConst + fmt::Debug>(m: F, s: F, buf: &mut [F])
where Standard: Distribution<F> { where
Standard: Distribution<F>,
{
let distr = Cauchy::new(m, s).unwrap(); let distr = Cauchy::new(m, s).unwrap();
let mut rng = crate::test::rng(353); let mut rng = crate::test::rng(353);
for x in buf { for x in buf {
@ -148,12 +156,15 @@ mod test {
let mut buf = [0.0; 4]; let mut buf = [0.0; 4];
gen_samples(100f64, 10.0, &mut buf); gen_samples(100f64, 10.0, &mut buf);
assert_eq!(&buf, &[ assert_eq!(
77.93369152808678, &buf,
90.1606912098641, &[
125.31516221323625, 77.93369152808678,
86.10217834773925 90.1606912098641,
]); 125.31516221323625,
86.10217834773925
]
);
// Unfortunately this test is not fully portable due to reliance on the // Unfortunately this test is not fully portable due to reliance on the
// system's implementation of tanf (see doc on Cauchy struct). // system's implementation of tanf (see doc on Cauchy struct).

View File

@ -13,7 +13,8 @@ use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
use core::fmt; use core::fmt;
use num_traits::{Float, NumCast}; use num_traits::{Float, NumCast};
use rand::Rng; use rand::Rng;
#[cfg(feature = "serde_with")] use serde_with::serde_as; #[cfg(feature = "serde_with")]
use serde_with::serde_as;
use alloc::{boxed::Box, vec, vec::Vec}; use alloc::{boxed::Box, vec, vec::Vec};

View File

@ -10,10 +10,10 @@
//! The exponential distribution. //! The exponential distribution.
use crate::utils::ziggurat; use crate::utils::ziggurat;
use num_traits::Float;
use crate::{ziggurat_tables, Distribution}; use crate::{ziggurat_tables, Distribution};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::Float;
use rand::Rng;
/// Samples floating-point numbers according to the exponential distribution, /// Samples floating-point numbers according to the exponential distribution,
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or /// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
@ -61,7 +61,7 @@ impl Distribution<f64> for Exp1 {
} }
#[inline] #[inline]
fn zero_case<R: Rng + ?Sized>(rng: &mut R, _u: f64) -> f64 { fn zero_case<R: Rng + ?Sized>(rng: &mut R, _u: f64) -> f64 {
ziggurat_tables::ZIG_EXP_R - rng.gen::<f64>().ln() ziggurat_tables::ZIG_EXP_R - rng.random::<f64>().ln()
} }
ziggurat( ziggurat(
@ -94,7 +94,9 @@ impl Distribution<f64> for Exp1 {
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Exp<F> pub struct Exp<F>
where F: Float, Exp1: Distribution<F> where
F: Float,
Exp1: Distribution<F>,
{ {
/// `lambda` stored as `1/lambda`, since this is what we scale by. /// `lambda` stored as `1/lambda`, since this is what we scale by.
lambda_inverse: F, lambda_inverse: F,
@ -120,16 +122,18 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
impl<F: Float> Exp<F> impl<F: Float> Exp<F>
where F: Float, Exp1: Distribution<F> where
F: Float,
Exp1: Distribution<F>,
{ {
/// Construct a new `Exp` with the given shape parameter /// Construct a new `Exp` with the given shape parameter
/// `lambda`. /// `lambda`.
/// ///
/// # Remarks /// # Remarks
/// ///
/// For custom types `N` implementing the [`Float`] trait, /// For custom types `N` implementing the [`Float`] trait,
/// the case `lambda = 0` is handled as follows: each sample corresponds /// the case `lambda = 0` is handled as follows: each sample corresponds
/// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types /// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types
/// yield infinity, since `1 / 0 = infinity`. /// yield infinity, since `1 / 0 = infinity`.
#[inline] #[inline]
pub fn new(lambda: F) -> Result<Exp<F>, Error> { pub fn new(lambda: F) -> Result<Exp<F>, Error> {
@ -143,7 +147,9 @@ where F: Float, Exp1: Distribution<F>
} }
impl<F> Distribution<F> for Exp<F> impl<F> Distribution<F> for Exp<F>
where F: Float, Exp1: Distribution<F> where
F: Float,
Exp1: Distribution<F>,
{ {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
rng.sample(Exp1) * self.lambda_inverse rng.sample(Exp1) * self.lambda_inverse

View File

@ -17,12 +17,12 @@ use self::ChiSquaredRepr::*;
use self::GammaRepr::*; use self::GammaRepr::*;
use crate::normal::StandardNormal; use crate::normal::StandardNormal;
use num_traits::Float;
use crate::{Distribution, Exp, Exp1, Open01}; use crate::{Distribution, Exp, Exp1, Open01};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::Float;
use rand::Rng;
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// The Gamma distribution `Gamma(shape, scale)` distribution. /// The Gamma distribution `Gamma(shape, scale)` distribution.
/// ///
@ -566,7 +566,9 @@ where
F: Float, F: Float,
Open01: Distribution<F>, Open01: Distribution<F>,
{ {
a: F, b: F, switched_params: bool, a: F,
b: F,
switched_params: bool,
algorithm: BetaAlgorithm<F>, algorithm: BetaAlgorithm<F>,
} }
@ -618,15 +620,19 @@ where
if a > F::one() { if a > F::one() {
// Algorithm BB // Algorithm BB
let alpha = a + b; let alpha = a + b;
let beta = ((alpha - F::from(2.).unwrap())
/ (F::from(2.).unwrap()*a*b - alpha)).sqrt(); let two = F::from(2.).unwrap();
let beta_numer = alpha - two;
let beta_denom = two * a * b - alpha;
let beta = (beta_numer / beta_denom).sqrt();
let gamma = a + F::one() / beta; let gamma = a + F::one() / beta;
Ok(Beta { Ok(Beta {
a, b, switched_params, a,
algorithm: BetaAlgorithm::BB(BB { b,
alpha, beta, gamma, switched_params,
}) algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }),
}) })
} else { } else {
// Algorithm BC // Algorithm BC
@ -637,16 +643,21 @@ where
let beta = F::one() / b; let beta = F::one() / b;
let delta = F::one() + a - b; let delta = F::one() + a - b;
let kappa1 = delta let kappa1 = delta
* (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b) * (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b)
/ (a*beta - F::from(14. / 18.).unwrap()); / (a * beta - F::from(14. / 18.).unwrap());
let kappa2 = F::from(0.25).unwrap() let kappa2 = F::from(0.25).unwrap()
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b; + (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b;
Ok(Beta { Ok(Beta {
a, b, switched_params, a,
b,
switched_params,
algorithm: BetaAlgorithm::BC(BC { algorithm: BetaAlgorithm::BC(BC {
alpha, beta, kappa1, kappa2, alpha,
}) beta,
kappa1,
kappa2,
}),
}) })
} }
} }
@ -667,12 +678,11 @@ where
let u2 = rng.sample(Open01); let u2 = rng.sample(Open01);
let v = algo.beta * (u1 / (F::one() - u1)).ln(); let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp(); w = self.a * v.exp();
let z = u1*u1 * u2; let z = u1 * u1 * u2;
let r = algo.gamma * v - F::from(4.).unwrap().ln(); let r = algo.gamma * v - F::from(4.).unwrap().ln();
let s = self.a + r - w; let s = self.a + r - w;
// 2. // 2.
if s + F::one() + F::from(5.).unwrap().ln() if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
>= F::from(5.).unwrap() * z {
break; break;
} }
// 3. // 3.
@ -685,7 +695,7 @@ where
break; break;
} }
} }
}, }
BetaAlgorithm::BC(algo) => { BetaAlgorithm::BC(algo) => {
loop { loop {
let z; let z;
@ -716,11 +726,13 @@ where
let v = algo.beta * (u1 / (F::one() - u1)).ln(); let v = algo.beta * (u1 / (F::one() - u1)).ln();
w = self.a * v.exp(); w = self.a * v.exp();
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v) if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
- F::from(4.).unwrap().ln() < z.ln()) { - F::from(4.).unwrap().ln()
< z.ln())
{
break; break;
}; };
} }
}, }
}; };
// 5. for BB, 6. for BC // 5. for BB, 6. for BC
if !self.switched_params { if !self.switched_params {

View File

@ -1,20 +1,20 @@
//! The geometric distribution. //! The geometric distribution.
use crate::Distribution; use crate::Distribution;
use rand::Rng;
use core::fmt; use core::fmt;
#[allow(unused_imports)] #[allow(unused_imports)]
use num_traits::Float; use num_traits::Float;
use rand::Rng;
/// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`. /// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`.
/// ///
/// This is the probability distribution of the number of failures before the /// This is the probability distribution of the number of failures before the
/// first success in a series of Bernoulli trials. It has the density function /// first success in a series of Bernoulli trials. It has the density function
/// `f(k) = (1 - p)^k p` for `k >= 0`, where `p` is the probability of success /// `f(k) = (1 - p)^k p` for `k >= 0`, where `p` is the probability of success
/// on each trial. /// on each trial.
/// ///
/// This is the discrete analogue of the [exponential distribution](crate::Exp). /// This is the discrete analogue of the [exponential distribution](crate::Exp).
/// ///
/// Note that [`StandardGeometric`](crate::StandardGeometric) is an optimised /// Note that [`StandardGeometric`](crate::StandardGeometric) is an optimised
/// implementation for `p = 0.5`. /// implementation for `p = 0.5`.
/// ///
@ -29,11 +29,10 @@ use num_traits::Float;
/// ``` /// ```
#[derive(Copy, Clone, Debug, PartialEq)] #[derive(Copy, Clone, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Geometric pub struct Geometric {
{
p: f64, p: f64,
pi: f64, pi: f64,
k: u64 k: u64,
} }
/// Error type returned from `Geometric::new`. /// Error type returned from `Geometric::new`.
@ -46,7 +45,9 @@ pub enum Error {
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self { f.write_str(match self {
Error::InvalidProbability => "p is NaN or outside the interval [0, 1] in geometric distribution", Error::InvalidProbability => {
"p is NaN or outside the interval [0, 1] in geometric distribution"
}
}) })
} }
} }
@ -80,21 +81,24 @@ impl Geometric {
} }
} }
impl Distribution<u64> for Geometric impl Distribution<u64> for Geometric {
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
if self.p >= 2.0 / 3.0 { if self.p >= 2.0 / 3.0 {
// use the trivial algorithm: // use the trivial algorithm:
let mut failures = 0; let mut failures = 0;
loop { loop {
let u = rng.gen::<f64>(); let u = rng.random::<f64>();
if u <= self.p { break; } if u <= self.p {
break;
}
failures += 1; failures += 1;
} }
return failures; return failures;
} }
if self.p == 0.0 { return u64::MAX; } if self.p == 0.0 {
return u64::MAX;
}
let Geometric { p, pi, k } = *self; let Geometric { p, pi, k } = *self;
@ -108,7 +112,7 @@ impl Distribution<u64> for Geometric
// Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k: // Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k:
let d = { let d = {
let mut failures = 0; let mut failures = 0;
while rng.gen::<f64>() < pi { while rng.random::<f64>() < pi {
failures += 1; failures += 1;
} }
failures failures
@ -116,18 +120,18 @@ impl Distribution<u64> for Geometric
// Use rejection sampling for the remainder M from Geo(p) % 2^k: // Use rejection sampling for the remainder M from Geo(p) % 2^k:
// choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M // choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M
// NOTE: The paper suggests using bitwise sampling here, which is // NOTE: The paper suggests using bitwise sampling here, which is
// currently unsupported, but should improve performance by requiring // currently unsupported, but should improve performance by requiring
// fewer iterations on average. ~ October 28, 2020 // fewer iterations on average. ~ October 28, 2020
let m = loop { let m = loop {
let m = rng.gen::<u64>() & ((1 << k) - 1); let m = rng.random::<u64>() & ((1 << k) - 1);
let p_reject = if m <= i32::MAX as u64 { let p_reject = if m <= i32::MAX as u64 {
(1.0 - p).powi(m as i32) (1.0 - p).powi(m as i32)
} else { } else {
(1.0 - p).powf(m as f64) (1.0 - p).powf(m as f64)
}; };
let u = rng.gen::<f64>(); let u = rng.random::<f64>();
if u < p_reject { if u < p_reject {
break m; break m;
} }
@ -140,17 +144,17 @@ impl Distribution<u64> for Geometric
/// Samples integers according to the geometric distribution with success /// Samples integers according to the geometric distribution with success
/// probability `p = 0.5`. This is equivalent to `Geometeric::new(0.5)`, /// probability `p = 0.5`. This is equivalent to `Geometeric::new(0.5)`,
/// but faster. /// but faster.
/// ///
/// See [`Geometric`](crate::Geometric) for the general geometric distribution. /// See [`Geometric`](crate::Geometric) for the general geometric distribution.
/// ///
/// Implemented via iterated /// Implemented via iterated
/// [`Rng::gen::<u64>().leading_zeros()`](Rng::gen::<u64>().leading_zeros()). /// [`Rng::gen::<u64>().leading_zeros()`](Rng::gen::<u64>().leading_zeros()).
/// ///
/// # Example /// # Example
/// ``` /// ```
/// use rand::prelude::*; /// use rand::prelude::*;
/// use rand_distr::StandardGeometric; /// use rand_distr::StandardGeometric;
/// ///
/// let v = StandardGeometric.sample(&mut thread_rng()); /// let v = StandardGeometric.sample(&mut thread_rng());
/// println!("{} is from a Geometric(0.5) distribution", v); /// println!("{} is from a Geometric(0.5) distribution", v);
/// ``` /// ```
@ -162,9 +166,11 @@ impl Distribution<u64> for StandardGeometric {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
let mut result = 0; let mut result = 0;
loop { loop {
let x = rng.gen::<u64>().leading_zeros() as u64; let x = rng.random::<u64>().leading_zeros() as u64;
result += x; result += x;
if x < 64 { break; } if x < 64 {
break;
}
} }
result result
} }

View File

@ -1,17 +1,20 @@
//! The hypergeometric distribution. //! The hypergeometric distribution.
use crate::Distribution; use crate::Distribution;
use rand::Rng;
use rand::distributions::uniform::Uniform;
use core::fmt; use core::fmt;
#[allow(unused_imports)] #[allow(unused_imports)]
use num_traits::Float; use num_traits::Float;
use rand::distributions::uniform::Uniform;
use rand::Rng;
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
enum SamplingMethod { enum SamplingMethod {
InverseTransform{ initial_p: f64, initial_x: i64 }, InverseTransform {
RejectionAcceptance{ initial_p: f64,
initial_x: i64,
},
RejectionAcceptance {
m: f64, m: f64,
a: f64, a: f64,
lambda_l: f64, lambda_l: f64,
@ -20,24 +23,24 @@ enum SamplingMethod {
x_r: f64, x_r: f64,
p1: f64, p1: f64,
p2: f64, p2: f64,
p3: f64 p3: f64,
}, },
} }
/// The hypergeometric distribution `Hypergeometric(N, K, n)`. /// The hypergeometric distribution `Hypergeometric(N, K, n)`.
/// ///
/// This is the distribution of successes in samples of size `n` drawn without /// This is the distribution of successes in samples of size `n` drawn without
/// replacement from a population of size `N` containing `K` success states. /// replacement from a population of size `N` containing `K` success states.
/// It has the density function: /// It has the density function:
/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`, /// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`,
/// where `binomial(a, b) = a! / (b! * (a - b)!)`. /// where `binomial(a, b) = a! / (b! * (a - b)!)`.
/// ///
/// The [binomial distribution](crate::Binomial) is the analogous distribution /// The [binomial distribution](crate::Binomial) is the analogous distribution
/// for sampling with replacement. It is a good approximation when the population /// for sampling with replacement. It is a good approximation when the population
/// size is much larger than the sample size. /// size is much larger than the sample size.
/// ///
/// # Example /// # Example
/// ///
/// ``` /// ```
/// use rand_distr::{Distribution, Hypergeometric}; /// use rand_distr::{Distribution, Hypergeometric};
/// ///
@ -70,9 +73,15 @@ pub enum Error {
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self { f.write_str(match self {
Error::PopulationTooLarge => "total_population_size is too large causing underflow in geometric distribution", Error::PopulationTooLarge => {
Error::ProbabilityTooLarge => "population_with_feature > total_population_size in geometric distribution", "total_population_size is too large causing underflow in geometric distribution"
Error::SampleSizeTooLarge => "sample_size > total_population_size in geometric distribution", }
Error::ProbabilityTooLarge => {
"population_with_feature > total_population_size in geometric distribution"
}
Error::SampleSizeTooLarge => {
"sample_size > total_population_size in geometric distribution"
}
}) })
} }
} }
@ -97,20 +106,20 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64,
if i <= min_top { if i <= min_top {
result *= i as f64; result *= i as f64;
} }
if i <= min_bottom { if i <= min_bottom {
result /= i as f64; result /= i as f64;
} }
if i <= max_top { if i <= max_top {
result *= i as f64; result *= i as f64;
} }
if i <= max_bottom { if i <= max_bottom {
result /= i as f64; result /= i as f64;
} }
} }
result result
} }
@ -126,7 +135,11 @@ impl Hypergeometric {
/// `K = population_with_feature`, /// `K = population_with_feature`,
/// `n = sample_size`. /// `n = sample_size`.
#[allow(clippy::many_single_char_names)] // Same names as in the reference. #[allow(clippy::many_single_char_names)] // Same names as in the reference.
pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result<Self, Error> { pub fn new(
total_population_size: u64,
population_with_feature: u64,
sample_size: u64,
) -> Result<Self, Error> {
if population_with_feature > total_population_size { if population_with_feature > total_population_size {
return Err(Error::ProbabilityTooLarge); return Err(Error::ProbabilityTooLarge);
} }
@ -151,7 +164,7 @@ impl Hypergeometric {
}; };
// when sampling more than half the total population, take the smaller // when sampling more than half the total population, take the smaller
// group as sampled instead (we can then return n1-x instead). // group as sampled instead (we can then return n1-x instead).
// //
// Note: the boundary condition given in the paper is `sample_size < n / 2`; // Note: the boundary condition given in the paper is `sample_size < n / 2`;
// we're deviating here, because when n is even, it doesn't matter whether // we're deviating here, because when n is even, it doesn't matter whether
// we switch here or not, but when n is odd `n/2 < n - n/2`, so switching // we switch here or not, but when n is odd `n/2 < n - n/2`, so switching
@ -167,7 +180,7 @@ impl Hypergeometric {
// Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`, // Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`,
// where `M` is the mode of the distribution. // where `M` is the mode of the distribution.
// Use algorithm HIN for the remaining parameter space. // Use algorithm HIN for the remaining parameter space.
// //
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer // Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer
// generation of hypergeometric random variates. // generation of hypergeometric random variates.
// J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145 // J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145
@ -176,21 +189,30 @@ impl Hypergeometric {
let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor(); let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor();
let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD { let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD {
let (initial_p, initial_x) = if k < n2 { let (initial_p, initial_x) = if k < n2 {
(fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0) (
fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)),
0,
)
} else { } else {
(fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64) (
fraction_of_products_of_factorials((n1, k), (n, k - n2)),
(k - n2) as i64,
)
}; };
if initial_p <= 0.0 || !initial_p.is_finite() { if initial_p <= 0.0 || !initial_p.is_finite() {
return Err(Error::PopulationTooLarge); return Err(Error::PopulationTooLarge);
} }
SamplingMethod::InverseTransform { initial_p, initial_x } SamplingMethod::InverseTransform {
initial_p,
initial_x,
}
} else { } else {
let a = ln_of_factorial(m) + let a = ln_of_factorial(m)
ln_of_factorial(n1 as f64 - m) + + ln_of_factorial(n1 as f64 - m)
ln_of_factorial(k as f64 - m) + + ln_of_factorial(k as f64 - m)
ln_of_factorial((n2 - k) as f64 + m); + ln_of_factorial((n2 - k) as f64 + m);
let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64; let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64;
let denominator = (n - 1) as f64 * n as f64 * n as f64; let denominator = (n - 1) as f64 * n as f64 * n as f64;
@ -199,17 +221,19 @@ impl Hypergeometric {
let x_l = m - d + 0.5; let x_l = m - d + 0.5;
let x_r = m + d + 0.5; let x_r = m + d + 0.5;
let k_l = f64::exp(a - let k_l = f64::exp(
ln_of_factorial(x_l) - a - ln_of_factorial(x_l)
ln_of_factorial(n1 as f64 - x_l) - - ln_of_factorial(n1 as f64 - x_l)
ln_of_factorial(k as f64 - x_l) - - ln_of_factorial(k as f64 - x_l)
ln_of_factorial((n2 - k) as f64 + x_l)); - ln_of_factorial((n2 - k) as f64 + x_l),
let k_r = f64::exp(a - );
ln_of_factorial(x_r - 1.0) - let k_r = f64::exp(
ln_of_factorial(n1 as f64 - x_r + 1.0) - a - ln_of_factorial(x_r - 1.0)
ln_of_factorial(k as f64 - x_r + 1.0) - - ln_of_factorial(n1 as f64 - x_r + 1.0)
ln_of_factorial((n2 - k) as f64 + x_r - 1.0)); - ln_of_factorial(k as f64 - x_r + 1.0)
- ln_of_factorial((n2 - k) as f64 + x_r - 1.0),
);
let numerator = x_l * ((n2 - k) as f64 + x_l); let numerator = x_l * ((n2 - k) as f64 + x_l);
let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0); let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0);
let lambda_l = -((numerator / denominator).ln()); let lambda_l = -((numerator / denominator).ln());
@ -225,11 +249,26 @@ impl Hypergeometric {
let p3 = p2 + k_r / lambda_r; let p3 = p2 + k_r / lambda_r;
SamplingMethod::RejectionAcceptance { SamplingMethod::RejectionAcceptance {
m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 m,
a,
lambda_l,
lambda_r,
x_l,
x_r,
p1,
p2,
p3,
} }
}; };
Ok(Hypergeometric { n1, n2, k, offset_x, sign_x, sampling_method }) Ok(Hypergeometric {
n1,
n2,
k,
offset_x,
sign_x,
sampling_method,
})
} }
} }
@ -238,25 +277,47 @@ impl Distribution<u64> for Hypergeometric {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
use SamplingMethod::*; use SamplingMethod::*;
let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self; let Hypergeometric {
n1,
n2,
k,
sign_x,
offset_x,
sampling_method,
} = *self;
let x = match sampling_method { let x = match sampling_method {
InverseTransform { initial_p: mut p, initial_x: mut x } => { InverseTransform {
let mut u = rng.gen::<f64>(); initial_p: mut p,
while u > p && x < k as i64 { // the paper erroneously uses `until n < p`, which doesn't make any sense initial_x: mut x,
} => {
let mut u = rng.random::<f64>();
// the paper erroneously uses `until n < p`, which doesn't make any sense
while u > p && x < k as i64 {
u -= p; u -= p;
p *= ((n1 as i64 - x) * (k as i64 - x)) as f64; p *= ((n1 as i64 - x) * (k as i64 - x)) as f64;
p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64; p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64;
x += 1; x += 1;
} }
x x
}, }
RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => { RejectionAcceptance {
m,
a,
lambda_l,
lambda_r,
x_l,
x_r,
p1,
p2,
p3,
} => {
let distr_region_select = Uniform::new(0.0, p3).unwrap(); let distr_region_select = Uniform::new(0.0, p3).unwrap();
loop { loop {
let (y, v) = loop { let (y, v) = loop {
let u = distr_region_select.sample(rng); let u = distr_region_select.sample(rng);
let v = rng.gen::<f64>(); // for the accept/reject decision let v = rng.random::<f64>(); // for the accept/reject decision
if u <= p1 { if u <= p1 {
// Region 1, central bell // Region 1, central bell
let y = (x_l + u).floor(); let y = (x_l + u).floor();
@ -277,7 +338,7 @@ impl Distribution<u64> for Hypergeometric {
} }
} }
}; };
// Step 4: Acceptance/Rejection Comparison // Step 4: Acceptance/Rejection Comparison
if m < 100.0 || y <= 50.0 { if m < 100.0 || y <= 50.0 {
// Step 4.1: evaluate f(y) via recursive relationship // Step 4.1: evaluate f(y) via recursive relationship
@ -293,8 +354,10 @@ impl Distribution<u64> for Hypergeometric {
f /= (n1 - i) as f64 * (k - i) as f64; f /= (n1 - i) as f64 * (k - i) as f64;
} }
} }
if v <= f { break y as i64; } if v <= f {
break y as i64;
}
} else { } else {
// Step 4.2: Squeezing // Step 4.2: Squeezing
let y1 = y + 1.0; let y1 = y + 1.0;
@ -307,24 +370,24 @@ impl Distribution<u64> for Hypergeometric {
let t = ym / yk; let t = ym / yk;
let e = -ym / nk; let e = -ym / nk;
let g = yn * yk / (y1 * nk) - 1.0; let g = yn * yk / (y1 * nk) - 1.0;
let dg = if g < 0.0 { let dg = if g < 0.0 { 1.0 + g } else { 1.0 };
1.0 + g
} else {
1.0
};
let gu = g * (1.0 + g * (-0.5 + g / 3.0)); let gu = g * (1.0 + g * (-0.5 + g / 3.0));
let gl = gu - g.powi(4) / (4.0 * dg); let gl = gu - g.powi(4) / (4.0 * dg);
let xm = m + 0.5; let xm = m + 0.5;
let xn = n1 as f64 - m + 0.5; let xn = n1 as f64 - m + 0.5;
let xk = k as f64 - m + 0.5; let xk = k as f64 - m + 0.5;
let nm = n2 as f64 - k as f64 + xm; let nm = n2 as f64 - k as f64 + xm;
let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) + let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0))
xn * s * (1.0 + s * (-0.5 + s / 3.0)) + + xn * s * (1.0 + s * (-0.5 + s / 3.0))
xk * t * (1.0 + t * (-0.5 + t / 3.0)) + + xk * t * (1.0 + t * (-0.5 + t / 3.0))
nm * e * (1.0 + e * (-0.5 + e / 3.0)) + + nm * e * (1.0 + e * (-0.5 + e / 3.0))
y * gu - m * gl + 0.0034; + y * gu
- m * gl
+ 0.0034;
let av = v.ln(); let av = v.ln();
if av > ub { continue; } if av > ub {
continue;
}
let dr = if r < 0.0 { let dr = if r < 0.0 {
xm * r.powi(4) / (1.0 + r) xm * r.powi(4) / (1.0 + r)
} else { } else {
@ -345,17 +408,17 @@ impl Distribution<u64> for Hypergeometric {
} else { } else {
nm * e.powi(4) nm * e.powi(4)
}; };
if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 { if av < ub - 0.25 * (dr + ds + dt + de) + (y + m) * (gl - gu) - 0.0078 {
break y as i64; break y as i64;
} }
// Step 4.3: Final Acceptance/Rejection Test // Step 4.3: Final Acceptance/Rejection Test
let av_critical = a - let av_critical = a
ln_of_factorial(y) - - ln_of_factorial(y)
ln_of_factorial(n1 as f64 - y) - - ln_of_factorial(n1 as f64 - y)
ln_of_factorial(k as f64 - y) - - ln_of_factorial(k as f64 - y)
ln_of_factorial((n2 - k) as f64 + y); - ln_of_factorial((n2 - k) as f64 + y);
if v.ln() <= av_critical { if v.ln() <= av_critical {
break y as i64; break y as i64;
} }
@ -380,8 +443,7 @@ mod test {
assert!(Hypergeometric::new(100, 10, 5).is_ok()); assert!(Hypergeometric::new(100, 10, 5).is_ok());
} }
fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R) fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R) {
{
let distr = Hypergeometric::new(n, k, s).unwrap(); let distr = Hypergeometric::new(n, k, s).unwrap();
let expected_mean = s as f64 * k as f64 / n as f64; let expected_mean = s as f64 * k as f64 / n as f64;

View File

@ -1,7 +1,7 @@
use crate::{Distribution, Standard, StandardNormal}; use crate::{Distribution, Standard, StandardNormal};
use core::fmt;
use num_traits::Float; use num_traits::Float;
use rand::Rng; use rand::Rng;
use core::fmt;
/// Error type returned from `InverseGaussian::new` /// Error type returned from `InverseGaussian::new`
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -68,7 +68,9 @@ where
{ {
#[allow(clippy::many_single_char_names)] #[allow(clippy::many_single_char_names)]
fn sample<R>(&self, rng: &mut R) -> F fn sample<R>(&self, rng: &mut R) -> F
where R: Rng + ?Sized { where
R: Rng + ?Sized,
{
let mu = self.mean; let mu = self.mean;
let l = self.shape; let l = self.shape;
@ -79,7 +81,7 @@ where
let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt()); let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt());
let u: F = rng.gen(); let u: F = rng.random();
if u <= mu / (mu + x) { if u <= mu / (mu + x) {
return x; return x;
@ -112,6 +114,9 @@ mod tests {
#[test] #[test]
fn inverse_gaussian_distributions_can_be_compared() { fn inverse_gaussian_distributions_can_be_compared() {
assert_eq!(InverseGaussian::new(1.0, 2.0), InverseGaussian::new(1.0, 2.0)); assert_eq!(
InverseGaussian::new(1.0, 2.0),
InverseGaussian::new(1.0, 2.0)
);
} }
} }

View File

@ -21,6 +21,7 @@
)] )]
#![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose #![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose
#![no_std] #![no_std]
#![allow(unexpected_cfgs)]
#![cfg_attr(doc_cfg, feature(doc_cfg))] #![cfg_attr(doc_cfg, feature(doc_cfg))]
//! Generating random samples from probability distributions. //! Generating random samples from probability distributions.
@ -178,10 +179,14 @@ mod test {
macro_rules! assert_almost_eq { macro_rules! assert_almost_eq {
($a:expr, $b:expr, $prec:expr) => { ($a:expr, $b:expr, $prec:expr) => {
let diff = ($a - $b).abs(); let diff = ($a - $b).abs();
assert!(diff <= $prec, assert!(
diff <= $prec,
"assertion failed: `abs(left - right) = {:.1e} < {:e}`, \ "assertion failed: `abs(left - right) = {:.1e} < {:e}`, \
(left: `{}`, right: `{}`)", (left: `{}`, right: `{}`)",
diff, $prec, $a, $b diff,
$prec,
$a,
$b
); );
}; };
} }

View File

@ -10,10 +10,10 @@
//! The normal and derived distributions. //! The normal and derived distributions.
use crate::utils::ziggurat; use crate::utils::ziggurat;
use num_traits::Float;
use crate::{ziggurat_tables, Distribution, Open01}; use crate::{ziggurat_tables, Distribution, Open01};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::Float;
use rand::Rng;
/// Samples floating-point numbers according to the normal distribution /// Samples floating-point numbers according to the normal distribution
/// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to /// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to
@ -115,7 +115,9 @@ impl Distribution<f64> for StandardNormal {
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Normal<F> pub struct Normal<F>
where F: Float, StandardNormal: Distribution<F> where
F: Float,
StandardNormal: Distribution<F>,
{ {
mean: F, mean: F,
std_dev: F, std_dev: F,
@ -144,7 +146,9 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
impl<F> Normal<F> impl<F> Normal<F>
where F: Float, StandardNormal: Distribution<F> where
F: Float,
StandardNormal: Distribution<F>,
{ {
/// Construct, from mean and standard deviation /// Construct, from mean and standard deviation
/// ///
@ -204,14 +208,15 @@ where F: Float, StandardNormal: Distribution<F>
} }
impl<F> Distribution<F> for Normal<F> impl<F> Distribution<F> for Normal<F>
where F: Float, StandardNormal: Distribution<F> where
F: Float,
StandardNormal: Distribution<F>,
{ {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
self.from_zscore(rng.sample(StandardNormal)) self.from_zscore(rng.sample(StandardNormal))
} }
} }
/// The log-normal distribution `ln N(mean, std_dev**2)`. /// The log-normal distribution `ln N(mean, std_dev**2)`.
/// ///
/// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)` /// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)`
@ -230,13 +235,17 @@ where F: Float, StandardNormal: Distribution<F>
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct LogNormal<F> pub struct LogNormal<F>
where F: Float, StandardNormal: Distribution<F> where
F: Float,
StandardNormal: Distribution<F>,
{ {
norm: Normal<F>, norm: Normal<F>,
} }
impl<F> LogNormal<F> impl<F> LogNormal<F>
where F: Float, StandardNormal: Distribution<F> where
F: Float,
StandardNormal: Distribution<F>,
{ {
/// Construct, from (log-space) mean and standard deviation /// Construct, from (log-space) mean and standard deviation
/// ///
@ -307,7 +316,9 @@ where F: Float, StandardNormal: Distribution<F>
} }
impl<F> Distribution<F> for LogNormal<F> impl<F> Distribution<F> for LogNormal<F>
where F: Float, StandardNormal: Distribution<F> where
F: Float,
StandardNormal: Distribution<F>,
{ {
#[inline] #[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
@ -348,7 +359,10 @@ mod tests {
#[test] #[test]
fn test_log_normal_cv() { fn test_log_normal_cv() {
let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap(); let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap();
assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (f64::NEG_INFINITY, 0.0)); assert_eq!(
(lnorm.norm.mean, lnorm.norm.std_dev),
(f64::NEG_INFINITY, 0.0)
);
let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap(); let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap();
assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0)); assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0));

View File

@ -1,7 +1,7 @@
use crate::{Distribution, InverseGaussian, Standard, StandardNormal}; use crate::{Distribution, InverseGaussian, Standard, StandardNormal};
use core::fmt;
use num_traits::Float; use num_traits::Float;
use rand::Rng; use rand::Rng;
use core::fmt;
/// Error type returned from `NormalInverseGaussian::new` /// Error type returned from `NormalInverseGaussian::new`
#[derive(Debug, Clone, Copy, PartialEq, Eq)] #[derive(Debug, Clone, Copy, PartialEq, Eq)]
@ -15,8 +15,12 @@ pub enum Error {
impl fmt::Display for Error { impl fmt::Display for Error {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self { f.write_str(match self {
Error::AlphaNegativeOrNull => "alpha <= 0 or is NaN in normal inverse Gaussian distribution", Error::AlphaNegativeOrNull => {
Error::AbsoluteBetaNotLessThanAlpha => "|beta| >= alpha or is NaN in normal inverse Gaussian distribution", "alpha <= 0 or is NaN in normal inverse Gaussian distribution"
}
Error::AbsoluteBetaNotLessThanAlpha => {
"|beta| >= alpha or is NaN in normal inverse Gaussian distribution"
}
}) })
} }
} }
@ -75,7 +79,9 @@ where
Standard: Distribution<F>, Standard: Distribution<F>,
{ {
fn sample<R>(&self, rng: &mut R) -> F fn sample<R>(&self, rng: &mut R) -> F
where R: Rng + ?Sized { where
R: Rng + ?Sized,
{
let inv_gauss = rng.sample(self.inverse_gaussian); let inv_gauss = rng.sample(self.inverse_gaussian);
self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal) self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal)
@ -105,6 +111,9 @@ mod tests {
#[test] #[test]
fn normal_inverse_gaussian_distributions_can_be_compared() { fn normal_inverse_gaussian_distributions_can_be_compared() {
assert_eq!(NormalInverseGaussian::new(1.0, 2.0), NormalInverseGaussian::new(1.0, 2.0)); assert_eq!(
NormalInverseGaussian::new(1.0, 2.0),
NormalInverseGaussian::new(1.0, 2.0)
);
} }
} }

View File

@ -8,10 +8,10 @@
//! The Pareto distribution. //! The Pareto distribution.
use num_traits::Float;
use crate::{Distribution, OpenClosed01}; use crate::{Distribution, OpenClosed01};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::Float;
use rand::Rng;
/// Samples floating-point numbers according to the Pareto distribution /// Samples floating-point numbers according to the Pareto distribution
/// ///
@ -26,7 +26,9 @@ use core::fmt;
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Pareto<F> pub struct Pareto<F>
where F: Float, OpenClosed01: Distribution<F> where
F: Float,
OpenClosed01: Distribution<F>,
{ {
scale: F, scale: F,
inv_neg_shape: F, inv_neg_shape: F,
@ -55,7 +57,9 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
impl<F> Pareto<F> impl<F> Pareto<F>
where F: Float, OpenClosed01: Distribution<F> where
F: Float,
OpenClosed01: Distribution<F>,
{ {
/// Construct a new Pareto distribution with given `scale` and `shape`. /// Construct a new Pareto distribution with given `scale` and `shape`.
/// ///
@ -78,7 +82,9 @@ where F: Float, OpenClosed01: Distribution<F>
} }
impl<F> Distribution<F> for Pareto<F> impl<F> Distribution<F> for Pareto<F>
where F: Float, OpenClosed01: Distribution<F> where
F: Float,
OpenClosed01: Distribution<F>,
{ {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let u: F = OpenClosed01.sample(rng); let u: F = OpenClosed01.sample(rng);
@ -112,7 +118,9 @@ mod tests {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn test_samples<F: Float + Debug + Display + LowerExp, D: Distribution<F>>( fn test_samples<F: Float + Debug + Display + LowerExp, D: Distribution<F>>(
distr: D, thresh: F, expected: &[F], distr: D,
thresh: F,
expected: &[F],
) { ) {
let mut rng = crate::test::rng(213); let mut rng = crate::test::rng(213);
for v in expected { for v in expected {
@ -121,15 +129,21 @@ mod tests {
} }
} }
test_samples(Pareto::new(1f32, 1.0).unwrap(), 1e-6, &[ test_samples(
1.0423688, 2.1235929, 4.132709, 1.4679428, Pareto::new(1f32, 1.0).unwrap(),
]); 1e-6,
test_samples(Pareto::new(2.0, 0.5).unwrap(), 1e-14, &[ &[1.0423688, 2.1235929, 4.132709, 1.4679428],
9.019295276219136, );
4.3097126018270595, test_samples(
6.837815045397157, Pareto::new(2.0, 0.5).unwrap(),
105.8826669383772, 1e-14,
]); &[
9.019295276219136,
4.3097126018270595,
6.837815045397157,
105.8826669383772,
],
);
} }
#[test] #[test]

View File

@ -7,10 +7,10 @@
// except according to those terms. // except according to those terms.
//! The PERT distribution. //! The PERT distribution.
use num_traits::Float;
use crate::{Beta, Distribution, Exp1, Open01, StandardNormal}; use crate::{Beta, Distribution, Exp1, Open01, StandardNormal};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::Float;
use rand::Rng;
/// The PERT distribution. /// The PERT distribution.
/// ///
@ -129,20 +129,12 @@ mod test {
#[test] #[test]
fn test_pert() { fn test_pert() {
for &(min, max, mode) in &[ for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
(-1., 1., 0.),
(1., 2., 1.),
(5., 25., 25.),
] {
let _distr = Pert::new(min, max, mode).unwrap(); let _distr = Pert::new(min, max, mode).unwrap();
// TODO: test correctness // TODO: test correctness
} }
for &(min, max, mode) in &[ for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
(-1., 1., 2.),
(-1., 1., -2.),
(2., 1., 1.),
] {
assert!(Pert::new(min, max, mode).is_err()); assert!(Pert::new(min, max, mode).is_err());
} }
} }

View File

@ -9,10 +9,10 @@
//! The Poisson distribution. //! The Poisson distribution.
use num_traits::{Float, FloatConst};
use crate::{Cauchy, Distribution, Standard}; use crate::{Cauchy, Distribution, Standard};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::{Float, FloatConst};
use rand::Rng;
/// The Poisson distribution `Poisson(lambda)`. /// The Poisson distribution `Poisson(lambda)`.
/// ///
@ -31,7 +31,9 @@ use core::fmt;
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Poisson<F> pub struct Poisson<F>
where F: Float + FloatConst, Standard: Distribution<F> where
F: Float + FloatConst,
Standard: Distribution<F>,
{ {
lambda: F, lambda: F,
// precalculated values // precalculated values
@ -64,7 +66,9 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
impl<F> Poisson<F> impl<F> Poisson<F>
where F: Float + FloatConst, Standard: Distribution<F> where
F: Float + FloatConst,
Standard: Distribution<F>,
{ {
/// Construct a new `Poisson` with the given shape parameter /// Construct a new `Poisson` with the given shape parameter
/// `lambda`. /// `lambda`.
@ -87,7 +91,9 @@ where F: Float + FloatConst, Standard: Distribution<F>
} }
impl<F> Distribution<F> for Poisson<F> impl<F> Distribution<F> for Poisson<F>
where F: Float + FloatConst, Standard: Distribution<F> where
F: Float + FloatConst,
Standard: Distribution<F>,
{ {
#[inline] #[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
@ -96,9 +102,9 @@ where F: Float + FloatConst, Standard: Distribution<F>
// for low expected values use the Knuth method // for low expected values use the Knuth method
if self.lambda < F::from(12.0).unwrap() { if self.lambda < F::from(12.0).unwrap() {
let mut result = F::one(); let mut result = F::one();
let mut p = rng.gen::<F>(); let mut p = rng.random::<F>();
while p > self.exp_lambda { while p > self.exp_lambda {
p = p*rng.gen::<F>(); p = p * rng.random::<F>();
result = result + F::one(); result = result + F::one();
} }
result - F::one() result - F::one()
@ -139,7 +145,7 @@ where F: Float + FloatConst, Standard: Distribution<F>
.exp(); .exp();
// check with uniform random value - if below the threshold, we are within the target distribution // check with uniform random value - if below the threshold, we are within the target distribution
if rng.gen::<F>() <= check { if rng.random::<F>() <= check {
break; break;
} }
} }
@ -153,7 +159,8 @@ mod test {
use super::*; use super::*;
fn test_poisson_avg_gen<F: Float + FloatConst>(lambda: F, tol: F) fn test_poisson_avg_gen<F: Float + FloatConst>(lambda: F, tol: F)
where Standard: Distribution<F> where
Standard: Distribution<F>,
{ {
let poisson = Poisson::new(lambda).unwrap(); let poisson = Poisson::new(lambda).unwrap();
let mut rng = crate::test::rng(123); let mut rng = crate::test::rng(123);
@ -173,7 +180,7 @@ mod test {
test_poisson_avg_gen::<f32>(10.0, 0.1); test_poisson_avg_gen::<f32>(10.0, 0.1);
test_poisson_avg_gen::<f32>(15.0, 0.1); test_poisson_avg_gen::<f32>(15.0, 0.1);
//Small lambda will use Knuth's method with exp_lambda == 1.0 // Small lambda will use Knuth's method with exp_lambda == 1.0
test_poisson_avg_gen::<f32>(0.00000000000000005, 0.1); test_poisson_avg_gen::<f32>(0.00000000000000005, 0.1);
test_poisson_avg_gen::<f64>(0.00000000000000005, 0.1); test_poisson_avg_gen::<f64>(0.00000000000000005, 0.1);
} }

View File

@ -150,9 +150,7 @@ where
mod tests { mod tests {
use super::*; use super::*;
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>( fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
distr: D, zero: F, expected: &[F],
) {
let mut rng = crate::test::rng(213); let mut rng = crate::test::rng(213);
let mut buf = [zero; 4]; let mut buf = [zero; 4];
for x in &mut buf { for x in &mut buf {
@ -222,12 +220,7 @@ mod tests {
test_samples( test_samples(
SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(), SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(),
0f64, 0f64,
&[ &[f64::INFINITY, f64::INFINITY, f64::INFINITY, f64::INFINITY],
f64::INFINITY,
f64::INFINITY,
f64::INFINITY,
f64::INFINITY,
],
); );
test_samples( test_samples(
SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(), SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(),
@ -256,6 +249,9 @@ mod tests {
#[test] #[test]
fn skew_normal_distributions_can_be_compared() { fn skew_normal_distributions_can_be_compared() {
assert_eq!(SkewNormal::new(1.0, 2.0, 3.0), SkewNormal::new(1.0, 2.0, 3.0)); assert_eq!(
SkewNormal::new(1.0, 2.0, 3.0),
SkewNormal::new(1.0, 2.0, 3.0)
);
} }
} }

View File

@ -7,10 +7,10 @@
// except according to those terms. // except according to those terms.
//! The triangular distribution. //! The triangular distribution.
use num_traits::Float;
use crate::{Distribution, Standard}; use crate::{Distribution, Standard};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::Float;
use rand::Rng;
/// The triangular distribution. /// The triangular distribution.
/// ///
@ -34,7 +34,9 @@ use core::fmt;
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Triangular<F> pub struct Triangular<F>
where F: Float, Standard: Distribution<F> where
F: Float,
Standard: Distribution<F>,
{ {
min: F, min: F,
max: F, max: F,
@ -66,7 +68,9 @@ impl fmt::Display for TriangularError {
impl std::error::Error for TriangularError {} impl std::error::Error for TriangularError {}
impl<F> Triangular<F> impl<F> Triangular<F>
where F: Float, Standard: Distribution<F> where
F: Float,
Standard: Distribution<F>,
{ {
/// Set up the Triangular distribution with defined `min`, `max` and `mode`. /// Set up the Triangular distribution with defined `min`, `max` and `mode`.
#[inline] #[inline]
@ -82,7 +86,9 @@ where F: Float, Standard: Distribution<F>
} }
impl<F> Distribution<F> for Triangular<F> impl<F> Distribution<F> for Triangular<F>
where F: Float, Standard: Distribution<F> where
F: Float,
Standard: Distribution<F>,
{ {
#[inline] #[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
@ -106,7 +112,7 @@ mod test {
#[test] #[test]
fn test_triangular() { fn test_triangular() {
let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0); let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0);
assert_eq!(half_rng.gen::<f64>(), 0.5); assert_eq!(half_rng.random::<f64>(), 0.5);
for &(min, max, mode, median) in &[ for &(min, max, mode, median) in &[
(-1., 1., 0., 0.), (-1., 1., 0., 0.),
(1., 2., 1., 2. - 0.5f64.sqrt()), (1., 2., 1., 2. - 0.5f64.sqrt()),
@ -122,17 +128,16 @@ mod test {
assert_eq!(distr.sample(&mut half_rng), median); assert_eq!(distr.sample(&mut half_rng), median);
} }
for &(min, max, mode) in &[ for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
(-1., 1., 2.),
(-1., 1., -2.),
(2., 1., 1.),
] {
assert!(Triangular::new(min, max, mode).is_err()); assert!(Triangular::new(min, max, mode).is_err());
} }
} }
#[test] #[test]
fn triangular_distributions_can_be_compared() { fn triangular_distributions_can_be_compared() {
assert_eq!(Triangular::new(1.0, 3.0, 2.0), Triangular::new(1.0, 3.0, 2.0)); assert_eq!(
Triangular::new(1.0, 3.0, 2.0),
Triangular::new(1.0, 3.0, 2.0)
);
} }
} }

View File

@ -6,8 +6,8 @@
// option. This file may not be copied, modified, or distributed // option. This file may not be copied, modified, or distributed
// except according to those terms. // except according to those terms.
use num_traits::Float;
use crate::{uniform::SampleUniform, Distribution, Uniform}; use crate::{uniform::SampleUniform, Distribution, Uniform};
use num_traits::Float;
use rand::Rng; use rand::Rng;
/// Samples uniformly from the unit ball (surface and interior) in three /// Samples uniformly from the unit ball (surface and interior) in three

View File

@ -6,8 +6,8 @@
// option. This file may not be copied, modified, or distributed // option. This file may not be copied, modified, or distributed
// except according to those terms. // except according to those terms.
use num_traits::Float;
use crate::{uniform::SampleUniform, Distribution, Uniform}; use crate::{uniform::SampleUniform, Distribution, Uniform};
use num_traits::Float;
use rand::Rng; use rand::Rng;
/// Samples uniformly from the edge of the unit circle in two dimensions. /// Samples uniformly from the edge of the unit circle in two dimensions.

View File

@ -6,8 +6,8 @@
// option. This file may not be copied, modified, or distributed // option. This file may not be copied, modified, or distributed
// except according to those terms. // except according to those terms.
use num_traits::Float;
use crate::{uniform::SampleUniform, Distribution, Uniform}; use crate::{uniform::SampleUniform, Distribution, Uniform};
use num_traits::Float;
use rand::Rng; use rand::Rng;
/// Samples uniformly from the unit disc in two dimensions. /// Samples uniformly from the unit disc in two dimensions.

View File

@ -6,8 +6,8 @@
// option. This file may not be copied, modified, or distributed // option. This file may not be copied, modified, or distributed
// except according to those terms. // except according to those terms.
use num_traits::Float;
use crate::{uniform::SampleUniform, Distribution, Uniform}; use crate::{uniform::SampleUniform, Distribution, Uniform};
use num_traits::Float;
use rand::Rng; use rand::Rng;
/// Samples uniformly from the surface of the unit sphere in three dimensions. /// Samples uniformly from the surface of the unit sphere in three dimensions.
@ -42,7 +42,11 @@ impl<F: Float + SampleUniform> Distribution<[F; 3]> for UnitSphere {
continue; continue;
} }
let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt(); let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt();
return [x1 * factor, x2 * factor, F::from(1.).unwrap() - F::from(2.).unwrap() * sum]; return [
x1 * factor,
x2 * factor,
F::from(1.).unwrap() - F::from(2.).unwrap() * sum,
];
} }
} }
} }

View File

@ -9,9 +9,9 @@
//! Math helper functions //! Math helper functions
use crate::ziggurat_tables; use crate::ziggurat_tables;
use num_traits::Float;
use rand::distributions::hidden_export::IntoFloat; use rand::distributions::hidden_export::IntoFloat;
use rand::Rng; use rand::Rng;
use num_traits::Float;
/// Calculates ln(gamma(x)) (natural logarithm of the gamma /// Calculates ln(gamma(x)) (natural logarithm of the gamma
/// function) using the Lanczos approximation. /// function) using the Lanczos approximation.
@ -77,7 +77,7 @@ pub(crate) fn ziggurat<R: Rng + ?Sized, P, Z>(
x_tab: ziggurat_tables::ZigTable, x_tab: ziggurat_tables::ZigTable,
f_tab: ziggurat_tables::ZigTable, f_tab: ziggurat_tables::ZigTable,
mut pdf: P, mut pdf: P,
mut zero_case: Z mut zero_case: Z,
) -> f64 ) -> f64
where where
P: FnMut(f64) -> f64, P: FnMut(f64) -> f64,
@ -114,7 +114,7 @@ where
return zero_case(rng, u); return zero_case(rng, u);
} }
// algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1 // algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) { if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.random::<f64>() < pdf(x) {
return x; return x;
} }
} }

View File

@ -8,10 +8,10 @@
//! The Weibull distribution. //! The Weibull distribution.
use num_traits::Float;
use crate::{Distribution, OpenClosed01}; use crate::{Distribution, OpenClosed01};
use rand::Rng;
use core::fmt; use core::fmt;
use num_traits::Float;
use rand::Rng;
/// Samples floating-point numbers according to the Weibull distribution /// Samples floating-point numbers according to the Weibull distribution
/// ///
@ -26,7 +26,9 @@ use core::fmt;
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))] #[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
pub struct Weibull<F> pub struct Weibull<F>
where F: Float, OpenClosed01: Distribution<F> where
F: Float,
OpenClosed01: Distribution<F>,
{ {
inv_shape: F, inv_shape: F,
scale: F, scale: F,
@ -55,7 +57,9 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
impl<F> Weibull<F> impl<F> Weibull<F>
where F: Float, OpenClosed01: Distribution<F> where
F: Float,
OpenClosed01: Distribution<F>,
{ {
/// Construct a new `Weibull` distribution with given `scale` and `shape`. /// Construct a new `Weibull` distribution with given `scale` and `shape`.
pub fn new(scale: F, shape: F) -> Result<Weibull<F>, Error> { pub fn new(scale: F, shape: F) -> Result<Weibull<F>, Error> {
@ -73,7 +77,9 @@ where F: Float, OpenClosed01: Distribution<F>
} }
impl<F> Distribution<F> for Weibull<F> impl<F> Distribution<F> for Weibull<F>
where F: Float, OpenClosed01: Distribution<F> where
F: Float,
OpenClosed01: Distribution<F>,
{ {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
let x: F = rng.sample(OpenClosed01); let x: F = rng.sample(OpenClosed01);
@ -106,7 +112,9 @@ mod tests {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>( fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(
distr: D, zero: F, expected: &[F], distr: D,
zero: F,
expected: &[F],
) { ) {
let mut rng = crate::test::rng(213); let mut rng = crate::test::rng(213);
let mut buf = [zero; 4]; let mut buf = [zero; 4];
@ -116,18 +124,21 @@ mod tests {
assert_eq!(buf, expected); assert_eq!(buf, expected);
} }
test_samples(Weibull::new(1.0, 1.0).unwrap(), 0f32, &[ test_samples(
0.041495778, Weibull::new(1.0, 1.0).unwrap(),
0.7531094, 0f32,
1.4189332, &[0.041495778, 0.7531094, 1.4189332, 0.38386202],
0.38386202, );
]); test_samples(
test_samples(Weibull::new(2.0, 0.5).unwrap(), 0f64, &[ Weibull::new(2.0, 0.5).unwrap(),
1.1343478702739669, 0f64,
0.29470010050655226, &[
0.7556151370284702, 1.1343478702739669,
7.877212340241561, 0.29470010050655226,
]); 0.7556151370284702,
7.877212340241561,
],
);
} }
#[test] #[test]

View File

@ -11,13 +11,13 @@
use super::WeightError; use super::WeightError;
use crate::{uniform::SampleUniform, Distribution, Uniform}; use crate::{uniform::SampleUniform, Distribution, Uniform};
use alloc::{boxed::Box, vec, vec::Vec};
use core::fmt; use core::fmt;
use core::iter::Sum; use core::iter::Sum;
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
use rand::Rng; use rand::Rng;
use alloc::{boxed::Box, vec, vec::Vec};
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// A distribution using weighted sampling to pick a discretely selected item. /// A distribution using weighted sampling to pick a discretely selected item.
/// ///
@ -67,8 +67,14 @@ use serde::{Serialize, Deserialize};
/// [`Uniform<W>::sample`]: Distribution::sample /// [`Uniform<W>::sample`]: Distribution::sample
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")))] #[cfg_attr(
#[cfg_attr(feature = "serde1", serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")))] feature = "serde1",
serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
)]
#[cfg_attr(
feature = "serde1",
serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
)]
pub struct WeightedAliasIndex<W: AliasableWeight> { pub struct WeightedAliasIndex<W: AliasableWeight> {
aliases: Box<[u32]>, aliases: Box<[u32]>,
no_alias_odds: Box<[W]>, no_alias_odds: Box<[W]>,
@ -257,7 +263,8 @@ where
} }
impl<W: AliasableWeight> Clone for WeightedAliasIndex<W> impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
where Uniform<W>: Clone where
Uniform<W>: Clone,
{ {
fn clone(&self) -> Self { fn clone(&self) -> Self {
Self { Self {
@ -308,7 +315,7 @@ pub trait AliasableWeight:
macro_rules! impl_weight_for_float { macro_rules! impl_weight_for_float {
($T: ident) => { ($T: ident) => {
impl AliasableWeight for $T { impl AliasableWeight for $T {
const MAX: Self = ::core::$T::MAX; const MAX: Self = $T::MAX;
const ZERO: Self = 0.0; const ZERO: Self = 0.0;
fn try_from_u32_lossy(n: u32) -> Option<Self> { fn try_from_u32_lossy(n: u32) -> Option<Self> {
@ -337,7 +344,7 @@ fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
macro_rules! impl_weight_for_int { macro_rules! impl_weight_for_int {
($T: ident) => { ($T: ident) => {
impl AliasableWeight for $T { impl AliasableWeight for $T {
const MAX: Self = ::core::$T::MAX; const MAX: Self = $T::MAX;
const ZERO: Self = 0; const ZERO: Self = 0;
fn try_from_u32_lossy(n: u32) -> Option<Self> { fn try_from_u32_lossy(n: u32) -> Option<Self> {
@ -444,7 +451,9 @@ mod test {
} }
fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F) fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
where WeightedAliasIndex<W>: fmt::Debug { where
WeightedAliasIndex<W>: fmt::Debug,
{
const NUM_WEIGHTS: u32 = 10; const NUM_WEIGHTS: u32 = 10;
const ZERO_WEIGHT_INDEX: u32 = 3; const ZERO_WEIGHT_INDEX: u32 = 3;
const NUM_SAMPLES: u32 = 15000; const NUM_SAMPLES: u32 = 15000;
@ -455,7 +464,8 @@ mod test {
let random_weight_distribution = Uniform::new_inclusive( let random_weight_distribution = Uniform::new_inclusive(
W::ZERO, W::ZERO,
W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(), W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
).unwrap(); )
.unwrap();
for _ in 0..NUM_WEIGHTS { for _ in 0..NUM_WEIGHTS {
weights.push(rng.sample(&random_weight_distribution)); weights.push(rng.sample(&random_weight_distribution));
} }
@ -497,7 +507,11 @@ mod test {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn test_samples<W: AliasableWeight>(weights: Vec<W>, buf: &mut [usize], expected: &[usize]) { fn test_samples<W: AliasableWeight>(
weights: Vec<W>,
buf: &mut [usize],
expected: &[usize],
) {
assert_eq!(buf.len(), expected.len()); assert_eq!(buf.len(), expected.len());
let distr = WeightedAliasIndex::new(weights).unwrap(); let distr = WeightedAliasIndex::new(weights).unwrap();
let mut rng = crate::test::rng(0x9c9fa0b0580a7031); let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
@ -508,14 +522,20 @@ mod test {
} }
let mut buf = [0; 10]; let mut buf = [0; 10];
test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[ test_samples(
6, 5, 7, 5, 8, 7, 6, 2, 3, 7, vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1],
]); &mut buf,
test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[ &[6, 5, 7, 5, 8, 7, 6, 2, 3, 7],
2, 0, 0, 0, 0, 0, 0, 0, 1, 3, );
]); test_samples(
test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[ vec![0.7f32, 0.1, 0.1, 0.1],
2, 1, 2, 3, 2, 1, 3, 2, 1, 1, &mut buf,
]); &[2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
);
test_samples(
vec![1.0f64, 0.999, 0.998, 0.997],
&mut buf,
&[2, 1, 2, 3, 2, 1, 3, 2, 1, 1],
);
} }
} }

View File

@ -303,6 +303,7 @@ mod test {
#[test] #[test]
fn test_no_item_error() { fn test_no_item_error() {
let mut rng = crate::test::rng(0x9c9fa0b0580a7031); let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
#[allow(clippy::needless_borrows_for_generic_args)]
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap(); let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
assert_eq!( assert_eq!(
tree.try_sample(&mut rng).unwrap_err(), tree.try_sample(&mut rng).unwrap_err(),
@ -313,10 +314,10 @@ mod test {
#[test] #[test]
fn test_overflow_error() { fn test_overflow_error() {
assert_eq!( assert_eq!(
WeightedTreeIndex::new(&[i32::MAX, 2]), WeightedTreeIndex::new([i32::MAX, 2]),
Err(WeightError::Overflow) Err(WeightError::Overflow)
); );
let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap(); let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap();
assert_eq!(tree.push(3), Err(WeightError::Overflow)); assert_eq!(tree.push(3), Err(WeightError::Overflow));
assert_eq!(tree.update(1, 4), Err(WeightError::Overflow)); assert_eq!(tree.update(1, 4), Err(WeightError::Overflow));
tree.update(1, 2).unwrap(); tree.update(1, 2).unwrap();
@ -324,7 +325,7 @@ mod test {
#[test] #[test]
fn test_all_weights_zero_error() { fn test_all_weights_zero_error() {
let tree = WeightedTreeIndex::<f64>::new(&[0.0, 0.0]).unwrap(); let tree = WeightedTreeIndex::<f64>::new([0.0, 0.0]).unwrap();
let mut rng = crate::test::rng(0x9c9fa0b0580a7031); let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
assert_eq!( assert_eq!(
tree.try_sample(&mut rng).unwrap_err(), tree.try_sample(&mut rng).unwrap_err(),
@ -335,37 +336,36 @@ mod test {
#[test] #[test]
fn test_invalid_weight_error() { fn test_invalid_weight_error() {
assert_eq!( assert_eq!(
WeightedTreeIndex::<i32>::new(&[1, -1]).unwrap_err(), WeightedTreeIndex::<i32>::new([1, -1]).unwrap_err(),
WeightError::InvalidWeight WeightError::InvalidWeight
); );
#[allow(clippy::needless_borrows_for_generic_args)]
let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap(); let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight); assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight);
tree.push(1).unwrap(); tree.push(1).unwrap();
assert_eq!( assert_eq!(tree.update(0, -1).unwrap_err(), WeightError::InvalidWeight);
tree.update(0, -1).unwrap_err(),
WeightError::InvalidWeight
);
} }
#[test] #[test]
fn test_tree_modifications() { fn test_tree_modifications() {
let mut tree = WeightedTreeIndex::new(&[9, 1, 2]).unwrap(); let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap();
tree.push(3).unwrap(); tree.push(3).unwrap();
tree.push(5).unwrap(); tree.push(5).unwrap();
tree.update(0, 0).unwrap(); tree.update(0, 0).unwrap();
assert_eq!(tree.pop(), Some(5)); assert_eq!(tree.pop(), Some(5));
let expected = WeightedTreeIndex::new(&[0, 1, 2, 3]).unwrap(); let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap();
assert_eq!(tree, expected); assert_eq!(tree, expected);
} }
#[test] #[test]
#[allow(clippy::needless_range_loop)]
fn test_sample_counts_match_probabilities() { fn test_sample_counts_match_probabilities() {
let start = 1; let start = 1;
let end = 3; let end = 3;
let samples = 20; let samples = 20;
let mut rng = crate::test::rng(0x9c9fa0b0580a7031); let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect(); let weights: Vec<f64> = (0..end).map(|_| rng.random()).collect();
let mut tree = WeightedTreeIndex::new(&weights).unwrap(); let mut tree = WeightedTreeIndex::new(weights).unwrap();
let mut total_weight = 0.0; let mut total_weight = 0.0;
let mut weights = alloc::vec![0.0; end]; let mut weights = alloc::vec![0.0; end];
for i in 0..end { for i in 0..end {

View File

@ -8,10 +8,10 @@
//! The Zeta and related distributions. //! The Zeta and related distributions.
use num_traits::Float;
use crate::{Distribution, Standard}; use crate::{Distribution, Standard};
use rand::{Rng, distributions::OpenClosed01};
use core::fmt; use core::fmt;
use num_traits::Float;
use rand::{distributions::OpenClosed01, Rng};
/// Samples integers according to the [zeta distribution]. /// Samples integers according to the [zeta distribution].
/// ///
@ -48,7 +48,10 @@ use core::fmt;
/// [Non-Uniform Random Variate Generation]: https://doi.org/10.1007/978-1-4613-8643-8 /// [Non-Uniform Random Variate Generation]: https://doi.org/10.1007/978-1-4613-8643-8
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
pub struct Zeta<F> pub struct Zeta<F>
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F> where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{ {
a_minus_1: F, a_minus_1: F,
b: F, b: F,
@ -74,7 +77,10 @@ impl fmt::Display for ZetaError {
impl std::error::Error for ZetaError {} impl std::error::Error for ZetaError {}
impl<F> Zeta<F> impl<F> Zeta<F>
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F> where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{ {
/// Construct a new `Zeta` distribution with given `a` parameter. /// Construct a new `Zeta` distribution with given `a` parameter.
#[inline] #[inline]
@ -92,7 +98,10 @@ where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
} }
impl<F> Distribution<F> for Zeta<F> impl<F> Distribution<F> for Zeta<F>
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F> where
F: Float,
Standard: Distribution<F>,
OpenClosed01: Distribution<F>,
{ {
#[inline] #[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
@ -144,7 +153,10 @@ where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
/// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa /// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa
#[derive(Clone, Copy, Debug, PartialEq)] #[derive(Clone, Copy, Debug, PartialEq)]
pub struct Zipf<F> pub struct Zipf<F>
where F: Float, Standard: Distribution<F> { where
F: Float,
Standard: Distribution<F>,
{
s: F, s: F,
t: F, t: F,
q: F, q: F,
@ -173,7 +185,10 @@ impl fmt::Display for ZipfError {
impl std::error::Error for ZipfError {} impl std::error::Error for ZipfError {}
impl<F> Zipf<F> impl<F> Zipf<F>
where F: Float, Standard: Distribution<F> { where
F: Float,
Standard: Distribution<F>,
{
/// Construct a new `Zipf` distribution for a set with `n` elements and a /// Construct a new `Zipf` distribution for a set with `n` elements and a
/// frequency rank exponent `s`. /// frequency rank exponent `s`.
/// ///
@ -186,7 +201,7 @@ where F: Float, Standard: Distribution<F> {
if n < 1 { if n < 1 {
return Err(ZipfError::NTooSmall); return Err(ZipfError::NTooSmall);
} }
let n = F::from(n).unwrap(); // This does not fail. let n = F::from(n).unwrap(); // This does not fail.
let q = if s != F::one() { let q = if s != F::one() {
// Make sure to calculate the division only once. // Make sure to calculate the division only once.
F::one() / (F::one() - s) F::one() / (F::one() - s)
@ -200,9 +215,7 @@ where F: Float, Standard: Distribution<F> {
F::one() + n.ln() F::one() + n.ln()
}; };
debug_assert!(t > F::zero()); debug_assert!(t > F::zero());
Ok(Zipf { Ok(Zipf { s, t, q })
s, t, q
})
} }
/// Inverse cumulative density function /// Inverse cumulative density function
@ -221,7 +234,9 @@ where F: Float, Standard: Distribution<F> {
} }
impl<F> Distribution<F> for Zipf<F> impl<F> Distribution<F> for Zipf<F>
where F: Float, Standard: Distribution<F> where
F: Float,
Standard: Distribution<F>,
{ {
#[inline] #[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
@ -246,9 +261,7 @@ where F: Float, Standard: Distribution<F>
mod tests { mod tests {
use super::*; use super::*;
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>( fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
distr: D, zero: F, expected: &[F],
) {
let mut rng = crate::test::rng(213); let mut rng = crate::test::rng(213);
let mut buf = [zero; 4]; let mut buf = [zero; 4];
for x in &mut buf { for x in &mut buf {
@ -293,12 +306,8 @@ mod tests {
#[test] #[test]
fn zeta_value_stability() { fn zeta_value_stability() {
test_samples(Zeta::new(1.5).unwrap(), 0f32, &[ test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]);
1.0, 2.0, 1.0, 1.0, test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]);
]);
test_samples(Zeta::new(2.0).unwrap(), 0f64, &[
2.0, 1.0, 1.0, 1.0,
]);
} }
#[test] #[test]
@ -363,12 +372,8 @@ mod tests {
#[test] #[test]
fn zipf_value_stability() { fn zipf_value_stability() {
test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[ test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]);
10.0, 2.0, 6.0, 7.0 test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]);
]);
test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[
1.0, 2.0, 3.0, 2.0
]);
} }
#[test] #[test]

View File

@ -57,7 +57,7 @@ fn normal() {
let mut diff = [0.; HIST_LEN]; let mut diff = [0.; HIST_LEN];
for (i, n) in hist.normalized_bins().enumerate() { for (i, n) in hist.normalized_bins().enumerate() {
let bin = (n as f64) / (N_SAMPLES as f64); let bin = n / (N_SAMPLES as f64);
diff[i] = (bin - expected[i]).abs(); diff[i] = (bin - expected[i]).abs();
} }
@ -140,7 +140,7 @@ fn skew_normal() {
let mut diff = [0.; HIST_LEN]; let mut diff = [0.; HIST_LEN];
for (i, n) in hist.normalized_bins().enumerate() { for (i, n) in hist.normalized_bins().enumerate() {
let bin = (n as f64) / (N_SAMPLES as f64); let bin = n / (N_SAMPLES as f64);
diff[i] = (bin - expected[i]).abs(); diff[i] = (bin - expected[i]).abs();
} }

View File

@ -16,7 +16,7 @@ pub fn render_u64(data: &[u64], buffer: &mut String) {
match data.len() { match data.len() {
0 => { 0 => {
return; return;
}, }
1 => { 1 => {
if data[0] == 0 { if data[0] == 0 {
buffer.push(TICKS[0]); buffer.push(TICKS[0]);
@ -24,8 +24,8 @@ pub fn render_u64(data: &[u64], buffer: &mut String) {
buffer.push(TICKS[N - 1]); buffer.push(TICKS[N - 1]);
} }
return; return;
}, }
_ => {}, _ => {}
} }
let max = data.iter().max().unwrap(); let max = data.iter().max().unwrap();
let min = data.iter().min().unwrap(); let min = data.iter().min().unwrap();
@ -56,7 +56,7 @@ pub fn render_f64(data: &[f64], buffer: &mut String) {
match data.len() { match data.len() {
0 => { 0 => {
return; return;
}, }
1 => { 1 => {
if data[0] == 0. { if data[0] == 0. {
buffer.push(TICKS[0]); buffer.push(TICKS[0]);
@ -64,16 +64,14 @@ pub fn render_f64(data: &[f64], buffer: &mut String) {
buffer.push(TICKS[N - 1]); buffer.push(TICKS[N - 1]);
} }
return; return;
}, }
_ => {}, _ => {}
} }
for x in data { for x in data {
assert!(x.is_finite(), "can only render finite values"); assert!(x.is_finite(), "can only render finite values");
} }
let max = data.iter().fold( let max = data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
f64::NEG_INFINITY, |a, &b| a.max(b)); let min = data.iter().fold(f64::INFINITY, |a, &b| a.min(b));
let min = data.iter().fold(
f64::INFINITY, |a, &b| a.min(b));
let scale = ((N - 1) as f64) / (max - min); let scale = ((N - 1) as f64) / (max - min);
for x in data { for x in data {
let tick = ((x - min) * scale) as usize; let tick = ((x - min) * scale) as usize;

View File

@ -53,9 +53,7 @@ impl<T: ApproxEq> ApproxEq for [T; 3] {
} }
} }
fn test_samples<F: Debug + ApproxEq, D: Distribution<F>>( fn test_samples<F: Debug + ApproxEq, D: Distribution<F>>(seed: u64, distr: D, expected: &[F]) {
seed: u64, distr: D, expected: &[F],
) {
let mut rng = get_rng(seed); let mut rng = get_rng(seed);
for val in expected { for val in expected {
let x = rng.sample(&distr); let x = rng.sample(&distr);
@ -68,16 +66,28 @@ fn binomial_stability() {
// We have multiple code paths: np < 10, p > 0.5 // We have multiple code paths: np < 10, p > 0.5
test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]); test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]);
test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]); test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]);
test_samples(353, Binomial::new(2000, 0.6).unwrap(), &[1194, 1208, 1192, 1210]); test_samples(
353,
Binomial::new(2000, 0.6).unwrap(),
&[1194, 1208, 1192, 1210],
);
} }
#[test] #[test]
fn geometric_stability() { fn geometric_stability() {
test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]); test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]);
test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]); test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]);
test_samples(464, Geometric::new(0.05).unwrap(), &[24, 51, 81, 67, 27, 11, 7, 6]); test_samples(
test_samples(464, Geometric::new(0.95).unwrap(), &[0, 0, 0, 0, 1, 0, 0, 0]); 464,
Geometric::new(0.05).unwrap(),
&[24, 51, 81, 67, 27, 11, 7, 6],
);
test_samples(
464,
Geometric::new(0.95).unwrap(),
&[0, 0, 0, 0, 1, 0, 0, 0],
);
// expect non-random behaviour for series of pre-determined trials // expect non-random behaviour for series of pre-determined trials
test_samples(464, Geometric::new(0.0).unwrap(), &[u64::MAX; 100][..]); test_samples(464, Geometric::new(0.0).unwrap(), &[u64::MAX; 100][..]);
@ -87,260 +97,404 @@ fn geometric_stability() {
#[test] #[test]
fn hypergeometric_stability() { fn hypergeometric_stability() {
// We have multiple code paths based on the distribution's mode and sample_size // We have multiple code paths based on the distribution's mode and sample_size
test_samples(7221, Hypergeometric::new(99, 33, 8).unwrap(), &[4, 3, 2, 2, 3, 2, 3, 1]); // Algorithm HIN test_samples(
test_samples(7221, Hypergeometric::new(100, 50, 50).unwrap(), &[23, 27, 26, 27, 22, 24, 31, 22]); // Algorithm H2PE 7221,
Hypergeometric::new(99, 33, 8).unwrap(),
&[4, 3, 2, 2, 3, 2, 3, 1],
); // Algorithm HIN
test_samples(
7221,
Hypergeometric::new(100, 50, 50).unwrap(),
&[23, 27, 26, 27, 22, 24, 31, 22],
); // Algorithm H2PE
} }
#[test] #[test]
fn unit_ball_stability() { fn unit_ball_stability() {
test_samples(2, UnitBall, &[ test_samples(
[0.018035709265959987f64, -0.4348771383120438, -0.07982762085055706], 2,
[0.10588569388223945, -0.4734350111375454, -0.7392104908825501], UnitBall,
[0.11060237642041049, -0.16065642822852677, -0.8444043930440075] &[
]); [
0.018035709265959987f64,
-0.4348771383120438,
-0.07982762085055706,
],
[
0.10588569388223945,
-0.4734350111375454,
-0.7392104908825501,
],
[
0.11060237642041049,
-0.16065642822852677,
-0.8444043930440075,
],
],
);
} }
#[test] #[test]
fn unit_circle_stability() { fn unit_circle_stability() {
test_samples(2, UnitCircle, &[ test_samples(
[-0.9965658683520504f64, -0.08280380447614634], 2,
[-0.9790853270389644, -0.20345004884984505], UnitCircle,
[-0.8449189758898707, 0.5348943112253227], &[
]); [-0.9965658683520504f64, -0.08280380447614634],
[-0.9790853270389644, -0.20345004884984505],
[-0.8449189758898707, 0.5348943112253227],
],
);
} }
#[test] #[test]
fn unit_sphere_stability() { fn unit_sphere_stability() {
test_samples(2, UnitSphere, &[ test_samples(
[0.03247542860231647f64, -0.7830477442152738, 0.6211131755296027], 2,
[-0.09978440840914075, 0.9706650829833128, -0.21875184231323952], UnitSphere,
[0.2735582468624679, 0.9435374242279655, -0.1868234852870203], &[
]); [
0.03247542860231647f64,
-0.7830477442152738,
0.6211131755296027,
],
[
-0.09978440840914075,
0.9706650829833128,
-0.21875184231323952,
],
[0.2735582468624679, 0.9435374242279655, -0.1868234852870203],
],
);
} }
#[test] #[test]
fn unit_disc_stability() { fn unit_disc_stability() {
test_samples(2, UnitDisc, &[ test_samples(
[0.018035709265959987f64, -0.4348771383120438], 2,
[-0.07982762085055706, 0.7765329819820659], UnitDisc,
[0.21450745997299503, 0.7398636984333291], &[
]); [0.018035709265959987f64, -0.4348771383120438],
[-0.07982762085055706, 0.7765329819820659],
[0.21450745997299503, 0.7398636984333291],
],
);
} }
#[test] #[test]
fn pareto_stability() { fn pareto_stability() {
test_samples(213, Pareto::new(1.0, 1.0).unwrap(), &[ test_samples(
1.0423688f32, 2.1235929, 4.132709, 1.4679428, 213,
]); Pareto::new(1.0, 1.0).unwrap(),
test_samples(213, Pareto::new(2.0, 0.5).unwrap(), &[ &[1.0423688f32, 2.1235929, 4.132709, 1.4679428],
9.019295276219136f64, );
4.3097126018270595, test_samples(
6.837815045397157, 213,
105.8826669383772, Pareto::new(2.0, 0.5).unwrap(),
]); &[
9.019295276219136f64,
4.3097126018270595,
6.837815045397157,
105.8826669383772,
],
);
} }
#[test] #[test]
fn poisson_stability() { fn poisson_stability() {
test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]); test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]);
test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]); test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]);
test_samples(223, Poisson::new(27.0).unwrap(), &[28.0f32, 32.0, 36.0, 36.0]); test_samples(
223,
Poisson::new(27.0).unwrap(),
&[28.0f32, 32.0, 36.0, 36.0],
);
} }
#[test] #[test]
fn triangular_stability() { fn triangular_stability() {
test_samples(860, Triangular::new(2., 10., 3.).unwrap(), &[ test_samples(
5.74373257511361f64, 860,
7.890059162791258f64, Triangular::new(2., 10., 3.).unwrap(),
4.7256280652553455f64, &[
2.9474808121184077f64, 5.74373257511361f64,
3.058301946314053f64, 7.890059162791258f64,
]); 4.7256280652553455f64,
2.9474808121184077f64,
3.058301946314053f64,
],
);
} }
#[test] #[test]
fn normal_inverse_gaussian_stability() { fn normal_inverse_gaussian_stability() {
test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ test_samples(
0.6568966f32, 1.3744819, 2.216063, 0.11488572, 213,
]); NormalInverseGaussian::new(2.0, 1.0).unwrap(),
test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[ &[0.6568966f32, 1.3744819, 2.216063, 0.11488572],
0.6838707059642927f64, );
2.4447306460569784, test_samples(
0.2361045023235968, 213,
1.7774534624785319, NormalInverseGaussian::new(2.0, 1.0).unwrap(),
]); &[
0.6838707059642927f64,
2.4447306460569784,
0.2361045023235968,
1.7774534624785319,
],
);
} }
#[test] #[test]
fn pert_stability() { fn pert_stability() {
// mean = 4, var = 12/7 // mean = 4, var = 12/7
test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[ test_samples(
4.908681667460367, 860,
4.014196196158352, Pert::new(2., 10., 3.).unwrap(),
2.6489397149197234, &[
3.4569780580044727, 4.908681667460367,
4.242864311947118, 4.014196196158352,
]); 2.6489397149197234,
3.4569780580044727,
4.242864311947118,
],
);
} }
#[test] #[test]
fn inverse_gaussian_stability() { fn inverse_gaussian_stability() {
test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(),&[ test_samples(
0.9339157f32, 1.108113, 0.50864697, 0.39849377, 213,
]); InverseGaussian::new(1.0, 3.0).unwrap(),
test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(), &[ &[0.9339157f32, 1.108113, 0.50864697, 0.39849377],
1.0707604954722476f64, );
0.9628140605340697, test_samples(
0.4069687656468226, 213,
0.660283852985818, InverseGaussian::new(1.0, 3.0).unwrap(),
]); &[
1.0707604954722476f64,
0.9628140605340697,
0.4069687656468226,
0.660283852985818,
],
);
} }
#[test] #[test]
fn gamma_stability() { fn gamma_stability() {
// Gamma has 3 cases: shape == 1, shape < 1, shape > 1 // Gamma has 3 cases: shape == 1, shape < 1, shape > 1
test_samples(223, Gamma::new(1.0, 5.0).unwrap(), &[ test_samples(
5.398085f32, 9.162783, 0.2300583, 1.7235851, 223,
]); Gamma::new(1.0, 5.0).unwrap(),
test_samples(223, Gamma::new(0.8, 5.0).unwrap(), &[ &[5.398085f32, 9.162783, 0.2300583, 1.7235851],
0.5051203f32, 0.9048302, 3.095812, 1.8566116, );
]); test_samples(
test_samples(223, Gamma::new(1.1, 5.0).unwrap(), &[ 223,
7.783878094584059f64, Gamma::new(0.8, 5.0).unwrap(),
1.4939528171618057, &[0.5051203f32, 0.9048302, 3.095812, 1.8566116],
8.638017638857592, );
3.0949337228829004, test_samples(
]); 223,
Gamma::new(1.1, 5.0).unwrap(),
&[
7.783878094584059f64,
1.4939528171618057,
8.638017638857592,
3.0949337228829004,
],
);
// ChiSquared has 2 cases: k == 1, k != 1 // ChiSquared has 2 cases: k == 1, k != 1
test_samples(223, ChiSquared::new(1.0).unwrap(), &[ test_samples(
0.4893526200348249f64, 223,
1.635249736808788, ChiSquared::new(1.0).unwrap(),
0.5013580219361969, &[
0.1457735613733489, 0.4893526200348249f64,
]); 1.635249736808788,
test_samples(223, ChiSquared::new(0.1).unwrap(), &[ 0.5013580219361969,
0.014824404726978617f64, 0.1457735613733489,
0.021602123937134326, ],
0.0000003431429746851693, );
0.00000002291755769542258, test_samples(
]); 223,
test_samples(223, ChiSquared::new(10.0).unwrap(), &[ ChiSquared::new(0.1).unwrap(),
12.693656f32, 6.812016, 11.082001, 12.436167, &[
]); 0.014824404726978617f64,
0.021602123937134326,
0.0000003431429746851693,
0.00000002291755769542258,
],
);
test_samples(
223,
ChiSquared::new(10.0).unwrap(),
&[12.693656f32, 6.812016, 11.082001, 12.436167],
);
// FisherF has same special cases as ChiSquared on each param // FisherF has same special cases as ChiSquared on each param
test_samples(223, FisherF::new(1.0, 13.5).unwrap(), &[ test_samples(
0.32283646f32, 0.048049655, 0.0788893, 1.817178, 223,
]); FisherF::new(1.0, 13.5).unwrap(),
test_samples(223, FisherF::new(1.0, 1.0).unwrap(), &[ &[0.32283646f32, 0.048049655, 0.0788893, 1.817178],
0.29925257f32, 3.4392934, 9.567652, 0.020074, );
]); test_samples(
test_samples(223, FisherF::new(0.7, 13.5).unwrap(), &[ 223,
3.3196593155045124f64, FisherF::new(1.0, 1.0).unwrap(),
0.3409169916262829, &[0.29925257f32, 3.4392934, 9.567652, 0.020074],
0.03377989856426519, );
0.00004041672861036937, test_samples(
]); 223,
FisherF::new(0.7, 13.5).unwrap(),
&[
3.3196593155045124f64,
0.3409169916262829,
0.03377989856426519,
0.00004041672861036937,
],
);
// StudentT has same special cases as ChiSquared // StudentT has same special cases as ChiSquared
test_samples(223, StudentT::new(1.0).unwrap(), &[ test_samples(
0.54703987f32, -1.8545331, 3.093162, -0.14168274, 223,
]); StudentT::new(1.0).unwrap(),
test_samples(223, StudentT::new(1.1).unwrap(), &[ &[0.54703987f32, -1.8545331, 3.093162, -0.14168274],
0.7729195887949754f64, );
1.2606210611616204, test_samples(
-1.7553606501113175, 223,
-2.377641221169782, StudentT::new(1.1).unwrap(),
]); &[
0.7729195887949754f64,
1.2606210611616204,
-1.7553606501113175,
-2.377641221169782,
],
);
// Beta has two special cases: // Beta has two special cases:
// //
// 1. min(alpha, beta) <= 1 // 1. min(alpha, beta) <= 1
// 2. min(alpha, beta) > 1 // 2. min(alpha, beta) > 1
test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[ test_samples(
0.8300703726659456, 223,
0.8134131062097899, Beta::new(1.0, 0.8).unwrap(),
0.47912589330631555, &[
0.25323238071138526, 0.8300703726659456,
]); 0.8134131062097899,
test_samples(223, Beta::new(3.0, 1.2).unwrap(), &[ 0.47912589330631555,
0.49563509121756827, 0.25323238071138526,
0.9551305482256759, ],
0.5151181353461637, );
0.7551732971235077, test_samples(
]); 223,
Beta::new(3.0, 1.2).unwrap(),
&[
0.49563509121756827,
0.9551305482256759,
0.5151181353461637,
0.7551732971235077,
],
);
} }
#[test] #[test]
fn exponential_stability() { fn exponential_stability() {
test_samples(223, Exp1, &[ test_samples(223, Exp1, &[1.079617f32, 1.8325565, 0.04601166, 0.34471703]);
1.079617f32, 1.8325565, 0.04601166, 0.34471703, test_samples(
]); 223,
test_samples(223, Exp1, &[ Exp1,
1.0796170642388276f64, &[
1.8325565304274, 1.0796170642388276f64,
0.04601166186842716, 1.8325565304274,
0.3447170217100157, 0.04601166186842716,
]); 0.3447170217100157,
],
);
test_samples(223, Exp::new(2.0).unwrap(), &[ test_samples(
0.5398085f32, 0.91627824, 0.02300583, 0.17235851, 223,
]); Exp::new(2.0).unwrap(),
test_samples(223, Exp::new(1.0).unwrap(), &[ &[0.5398085f32, 0.91627824, 0.02300583, 0.17235851],
1.0796170642388276f64, );
1.8325565304274, test_samples(
0.04601166186842716, 223,
0.3447170217100157, Exp::new(1.0).unwrap(),
]); &[
1.0796170642388276f64,
1.8325565304274,
0.04601166186842716,
0.3447170217100157,
],
);
} }
#[test] #[test]
fn normal_stability() { fn normal_stability() {
test_samples(213, StandardNormal, &[ test_samples(
-0.11844189f32, 0.781378, 0.06563994, -1.1932899, 213,
]); StandardNormal,
test_samples(213, StandardNormal, &[ &[-0.11844189f32, 0.781378, 0.06563994, -1.1932899],
-0.11844188827977231f64, );
0.7813779637772346, test_samples(
0.06563993969580051, 213,
-1.1932899004186373, StandardNormal,
]); &[
-0.11844188827977231f64,
0.7813779637772346,
0.06563993969580051,
-1.1932899004186373,
],
);
test_samples(213, Normal::new(0.0, 1.0).unwrap(), &[ test_samples(
-0.11844189f32, 0.781378, 0.06563994, -1.1932899, 213,
]); Normal::new(0.0, 1.0).unwrap(),
test_samples(213, Normal::new(2.0, 0.5).unwrap(), &[ &[-0.11844189f32, 0.781378, 0.06563994, -1.1932899],
1.940779055860114f64, );
2.3906889818886174, test_samples(
2.0328199698479, 213,
1.4033550497906813, Normal::new(2.0, 0.5).unwrap(),
]); &[
1.940779055860114f64,
2.3906889818886174,
2.0328199698479,
1.4033550497906813,
],
);
test_samples(213, LogNormal::new(0.0, 1.0).unwrap(), &[ test_samples(
0.88830346f32, 2.1844804, 1.0678421, 0.30322206, 213,
]); LogNormal::new(0.0, 1.0).unwrap(),
test_samples(213, LogNormal::new(2.0, 0.5).unwrap(), &[ &[0.88830346f32, 2.1844804, 1.0678421, 0.30322206],
6.964174338639032f64, );
10.921015733601452, test_samples(
7.6355881556915906, 213,
4.068828213584092, LogNormal::new(2.0, 0.5).unwrap(),
]); &[
6.964174338639032f64,
10.921015733601452,
7.6355881556915906,
4.068828213584092,
],
);
} }
#[test] #[test]
fn weibull_stability() { fn weibull_stability() {
test_samples(213, Weibull::new(1.0, 1.0).unwrap(), &[ test_samples(
0.041495778f32, 0.7531094, 1.4189332, 0.38386202, 213,
]); Weibull::new(1.0, 1.0).unwrap(),
test_samples(213, Weibull::new(2.0, 0.5).unwrap(), &[ &[0.041495778f32, 0.7531094, 1.4189332, 0.38386202],
1.1343478702739669f64, );
0.29470010050655226, test_samples(
0.7556151370284702, 213,
7.877212340241561, Weibull::new(2.0, 0.5).unwrap(),
]); &[
1.1343478702739669f64,
0.29470010050655226,
0.7556151370284702,
7.877212340241561,
],
);
} }
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
@ -351,13 +505,16 @@ fn dirichlet_stability() {
rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()), rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()),
[0.12941567177708177, 0.4702121891675036, 0.4003721390554146] [0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
); );
assert_eq!(rng.sample(Dirichlet::new([8.0; 5]).unwrap()), [ assert_eq!(
0.17684200044809556, rng.sample(Dirichlet::new([8.0; 5]).unwrap()),
0.29915953935953055, [
0.1832858056608014, 0.17684200044809556,
0.1425623503573967, 0.29915953935953055,
0.19815030417417595 0.1832858056608014,
]); 0.1425623503573967,
0.19815030417417595
]
);
// Test stability for the case where all alphas are less than 0.1. // Test stability for the case where all alphas are less than 0.1.
assert_eq!( assert_eq!(
rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()), rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()),
@ -372,12 +529,16 @@ fn dirichlet_stability() {
#[test] #[test]
fn cauchy_stability() { fn cauchy_stability() {
test_samples(353, Cauchy::new(100f64, 10.0).unwrap(), &[ test_samples(
77.93369152808678f64, 353,
90.1606912098641, Cauchy::new(100f64, 10.0).unwrap(),
125.31516221323625, &[
86.10217834773925, 77.93369152808678f64,
]); 90.1606912098641,
125.31516221323625,
86.10217834773925,
],
);
// Unfortunately this test is not fully portable due to reliance on the // Unfortunately this test is not fully portable due to reliance on the
// system's implementation of tanf (see doc on Cauchy struct). // system's implementation of tanf (see doc on Cauchy struct).
@ -386,7 +547,7 @@ fn cauchy_stability() {
let mut rng = get_rng(353); let mut rng = get_rng(353);
let expected = [15.023088, -5.446413, 3.7092876, 3.112482]; let expected = [15.023088, -5.446413, 3.7092876, 3.112482];
for &a in expected.iter() { for &a in expected.iter() {
let b = rng.sample(&distr); let b = rng.sample(distr);
assert_almost_eq!(a, b, 1e-5); assert_almost_eq!(a, b, 1e-5);
} }
} }

View File

@ -15,7 +15,8 @@ const MULTIPLIER: u128 = 0x2360_ED05_1FC6_5DA4_4385_DF64_9FCC_F645;
use core::fmt; use core::fmt;
use rand_core::{impls, le, RngCore, SeedableRng}; use rand_core::{impls, le, RngCore, SeedableRng};
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; #[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// A PCG random number generator (XSL RR 128/64 (LCG) variant). /// A PCG random number generator (XSL RR 128/64 (LCG) variant).
/// ///
@ -153,7 +154,6 @@ impl RngCore for Lcg128Xsl64 {
} }
} }
/// A PCG random number generator (XSL 128/64 (MCG) variant). /// A PCG random number generator (XSL 128/64 (MCG) variant).
/// ///
/// Permuted Congruential Generator with 128-bit state, internal Multiplicative /// Permuted Congruential Generator with 128-bit state, internal Multiplicative

View File

@ -15,7 +15,8 @@ const MULTIPLIER: u64 = 15750249268501108917;
use core::fmt; use core::fmt;
use rand_core::{impls, le, RngCore, SeedableRng}; use rand_core::{impls, le, RngCore, SeedableRng};
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; #[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// A PCG random number generator (CM DXSM 128/64 (LCG) variant). /// A PCG random number generator (CM DXSM 128/64 (LCG) variant).
/// ///

View File

@ -12,7 +12,8 @@
use core::fmt; use core::fmt;
use rand_core::{impls, le, RngCore, SeedableRng}; use rand_core::{impls, le, RngCore, SeedableRng};
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; #[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
// This is the default multiplier used by PCG for 64-bit state. // This is the default multiplier used by PCG for 64-bit state.
const MULTIPLIER: u64 = 6364136223846793005; const MULTIPLIER: u64 = 6364136223846793005;

View File

@ -1,32 +0,0 @@
# This rustfmt file is added for configuration, but in practice much of our
# code is hand-formatted, frequently with more readable results.
# Comments:
normalize_comments = true
wrap_comments = false
comment_width = 90 # small excess is okay but prefer 80
# Arguments:
use_small_heuristics = "Default"
# TODO: single line functions only where short, please?
# https://github.com/rust-lang/rustfmt/issues/3358
fn_single_line = false
fn_args_layout = "Compressed"
overflow_delimited_expr = true
where_single_line = true
# enum_discrim_align_threshold = 20
# struct_field_align_threshold = 20
# Compatibility:
edition = "2021"
# Misc:
inline_attribute_width = 80
blank_lines_upper_bound = 2
reorder_impl_items = true
# report_todo = "Unnumbered"
# report_fixme = "Unnumbered"
# Ignored files:
ignore = []

View File

@ -13,7 +13,7 @@ use crate::Rng;
use core::fmt; use core::fmt;
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// The Bernoulli distribution. /// The Bernoulli distribution.
/// ///
@ -151,7 +151,8 @@ mod test {
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
fn test_serializing_deserializing_bernoulli() { fn test_serializing_deserializing_bernoulli() {
let coin_flip = Bernoulli::new(0.5).unwrap(); let coin_flip = Bernoulli::new(0.5).unwrap();
let de_coin_flip: Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap(); let de_coin_flip: Bernoulli =
bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
assert_eq!(coin_flip.p_int, de_coin_flip.p_int); assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
} }
@ -208,9 +209,10 @@ mod test {
for x in &mut buf { for x in &mut buf {
*x = rng.sample(distr); *x = rng.sample(distr);
} }
assert_eq!(buf, [ assert_eq!(
true, false, false, true, false, false, true, true, true, true buf,
]); [true, false, false, true, false, false, true, true, true, true]
);
} }
#[test] #[test]

View File

@ -10,7 +10,8 @@
//! Distribution trait and associates //! Distribution trait and associates
use crate::Rng; use crate::Rng;
#[cfg(feature = "alloc")] use alloc::string::String; #[cfg(feature = "alloc")]
use alloc::string::String;
use core::iter; use core::iter;
/// Types (distributions) that can be used to create a random instance of `T`. /// Types (distributions) that can be used to create a random instance of `T`.

View File

@ -8,14 +8,15 @@
//! Basic floating-point number distributions //! Basic floating-point number distributions
use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils}; use crate::distributions::utils::{FloatAsSIMD, FloatSIMDUtils, IntAsSIMD};
use crate::distributions::{Distribution, Standard}; use crate::distributions::{Distribution, Standard};
use crate::Rng; use crate::Rng;
use core::mem; use core::mem;
#[cfg(feature = "simd_support")] use core::simd::prelude::*; #[cfg(feature = "simd_support")]
use core::simd::prelude::*;
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// A distribution to sample floating point numbers uniformly in the half-open /// A distribution to sample floating point numbers uniformly in the half-open
/// interval `(0, 1]`, i.e. including 1 but not 0. /// interval `(0, 1]`, i.e. including 1 but not 0.
@ -72,7 +73,6 @@ pub struct OpenClosed01;
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Open01; pub struct Open01;
// This trait is needed by both this lib and rand_distr hence is a hidden export // This trait is needed by both this lib and rand_distr hence is a hidden export
#[doc(hidden)] #[doc(hidden)]
pub trait IntoFloat { pub trait IntoFloat {
@ -146,12 +146,11 @@ macro_rules! float_impls {
// Transmute-based method; 23/52 random bits; (0, 1) interval. // Transmute-based method; 23/52 random bits; (0, 1) interval.
// We use the most significant bits because for simple RNGs // We use the most significant bits because for simple RNGs
// those are usually more random. // those are usually more random.
use core::$f_scalar::EPSILON;
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8; let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
let value: $uty = rng.random(); let value: $uty = rng.random();
let fraction = value >> $uty::splat(float_size - $fraction_bits); let fraction = value >> $uty::splat(float_size - $fraction_bits);
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0) fraction.into_float_with_exponent(0) - $ty::splat(1.0 - $f_scalar::EPSILON / 2.0)
} }
} }
} }
@ -210,9 +209,15 @@ mod tests {
let mut zeros = StepRng::new(0, 0); let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0); let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0)); assert_eq!(
one.sample::<$ty, _>(Open01),
$EPSILON / two * $ty::splat(3.0)
);
let mut max = StepRng::new(!0, 0); let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two); assert_eq!(
max.sample::<$ty, _>(Open01),
$ty::splat(1.0) - $EPSILON / two
);
} }
}; };
} }
@ -252,9 +257,15 @@ mod tests {
let mut zeros = StepRng::new(0, 0); let mut zeros = StepRng::new(0, 0);
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two); assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
let mut one = StepRng::new(1 << 12, 0); let mut one = StepRng::new(1 << 12, 0);
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0)); assert_eq!(
one.sample::<$ty, _>(Open01),
$EPSILON / two * $ty::splat(3.0)
);
let mut max = StepRng::new(!0, 0); let mut max = StepRng::new(!0, 0);
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two); assert_eq!(
max.sample::<$ty, _>(Open01),
$ty::splat(1.0) - $EPSILON / two
);
} }
}; };
} }
@ -269,7 +280,9 @@ mod tests {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>( fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>(
distr: &D, zero: T, expected: &[T], distr: &D,
zero: T,
expected: &[T],
) { ) {
let mut rng = crate::test::rng(0x6f44f5646c2a7334); let mut rng = crate::test::rng(0x6f44f5646c2a7334);
let mut buf = [zero; 3]; let mut buf = [zero; 3];
@ -280,25 +293,25 @@ mod tests {
} }
test_samples(&Standard, 0f32, &[0.0035963655, 0.7346052, 0.09778172]); test_samples(&Standard, 0f32, &[0.0035963655, 0.7346052, 0.09778172]);
test_samples(&Standard, 0f64, &[ test_samples(
0.7346051961657583, &Standard,
0.20298547462974248, 0f64,
0.8166436635290655, &[0.7346051961657583, 0.20298547462974248, 0.8166436635290655],
]); );
test_samples(&OpenClosed01, 0f32, &[0.003596425, 0.73460525, 0.09778178]); test_samples(&OpenClosed01, 0f32, &[0.003596425, 0.73460525, 0.09778178]);
test_samples(&OpenClosed01, 0f64, &[ test_samples(
0.7346051961657584, &OpenClosed01,
0.2029854746297426, 0f64,
0.8166436635290656, &[0.7346051961657584, 0.2029854746297426, 0.8166436635290656],
]); );
test_samples(&Open01, 0f32, &[0.0035963655, 0.73460525, 0.09778172]); test_samples(&Open01, 0f32, &[0.0035963655, 0.73460525, 0.09778172]);
test_samples(&Open01, 0f64, &[ test_samples(
0.7346051961657584, &Open01,
0.20298547462974248, 0f64,
0.8166436635290656, &[0.7346051961657584, 0.20298547462974248, 0.8166436635290656],
]); );
#[cfg(feature = "simd_support")] #[cfg(feature = "simd_support")]
{ {
@ -306,17 +319,25 @@ mod tests {
// non-SIMD types; we assume this pattern continues across all // non-SIMD types; we assume this pattern continues across all
// SIMD types. // SIMD types.
test_samples(&Standard, f32x2::from([0.0, 0.0]), &[ test_samples(
f32x2::from([0.0035963655, 0.7346052]), &Standard,
f32x2::from([0.09778172, 0.20298547]), f32x2::from([0.0, 0.0]),
f32x2::from([0.34296435, 0.81664366]), &[
]); f32x2::from([0.0035963655, 0.7346052]),
f32x2::from([0.09778172, 0.20298547]),
f32x2::from([0.34296435, 0.81664366]),
],
);
test_samples(&Standard, f64x2::from([0.0, 0.0]), &[ test_samples(
f64x2::from([0.7346051961657583, 0.20298547462974248]), &Standard,
f64x2::from([0.8166436635290655, 0.7423708925400552]), f64x2::from([0.0, 0.0]),
f64x2::from([0.16387782224016323, 0.9087068770169618]), &[
]); f64x2::from([0.7346051961657583, 0.20298547462974248]),
f64x2::from([0.8166436635290655, 0.7423708925400552]),
f64x2::from([0.16387782224016323, 0.9087068770169618]),
],
);
} }
} }
} }

View File

@ -19,10 +19,11 @@ use core::arch::x86_64::__m512i;
#[cfg(target_arch = "x86_64")] #[cfg(target_arch = "x86_64")]
use core::arch::x86_64::{__m128i, __m256i}; use core::arch::x86_64::{__m128i, __m256i};
use core::num::{ use core::num::{
NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize,NonZeroU128, NonZeroI128, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroIsize, NonZeroU128,
NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroIsize,NonZeroI128 NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize,
}; };
#[cfg(feature = "simd_support")] use core::simd::*; #[cfg(feature = "simd_support")]
use core::simd::*;
impl Distribution<u8> for Standard { impl Distribution<u8> for Standard {
#[inline] #[inline]
@ -211,7 +212,9 @@ mod tests {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn test_samples<T: Copy + core::fmt::Debug + PartialEq>(zero: T, expected: &[T]) fn test_samples<T: Copy + core::fmt::Debug + PartialEq>(zero: T, expected: &[T])
where Standard: Distribution<T> { where
Standard: Distribution<T>,
{
let mut rng = crate::test::rng(807); let mut rng = crate::test::rng(807);
let mut buf = [zero; 3]; let mut buf = [zero; 3];
for x in &mut buf { for x in &mut buf {
@ -223,24 +226,33 @@ mod tests {
test_samples(0u8, &[9, 247, 111]); test_samples(0u8, &[9, 247, 111]);
test_samples(0u16, &[32265, 42999, 38255]); test_samples(0u16, &[32265, 42999, 38255]);
test_samples(0u32, &[2220326409, 2575017975, 2018088303]); test_samples(0u32, &[2220326409, 2575017975, 2018088303]);
test_samples(0u64, &[ test_samples(
11059617991457472009, 0u64,
16096616328739788143, &[
1487364411147516184, 11059617991457472009,
]); 16096616328739788143,
test_samples(0u128, &[ 1487364411147516184,
296930161868957086625409848350820761097, ],
145644820879247630242265036535529306392, );
111087889832015897993126088499035356354, test_samples(
]); 0u128,
&[
296930161868957086625409848350820761097,
145644820879247630242265036535529306392,
111087889832015897993126088499035356354,
],
);
#[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))] #[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))]
test_samples(0usize, &[2220326409, 2575017975, 2018088303]); test_samples(0usize, &[2220326409, 2575017975, 2018088303]);
#[cfg(target_pointer_width = "64")] #[cfg(target_pointer_width = "64")]
test_samples(0usize, &[ test_samples(
11059617991457472009, 0usize,
16096616328739788143, &[
1487364411147516184, 11059617991457472009,
]); 16096616328739788143,
1487364411147516184,
],
);
test_samples(0i8, &[9, -9, 111]); test_samples(0i8, &[9, -9, 111]);
// Skip further i* types: they are simple reinterpretation of u* samples // Skip further i* types: they are simple reinterpretation of u* samples
@ -249,49 +261,58 @@ mod tests {
{ {
// We only test a sub-set of types here and make assumptions about the rest. // We only test a sub-set of types here and make assumptions about the rest.
test_samples(u8x4::default(), &[ test_samples(
u8x4::from([9, 126, 87, 132]), u8x4::default(),
u8x4::from([247, 167, 123, 153]), &[
u8x4::from([111, 149, 73, 120]), u8x4::from([9, 126, 87, 132]),
]); u8x4::from([247, 167, 123, 153]),
test_samples(u8x8::default(), &[ u8x4::from([111, 149, 73, 120]),
u8x8::from([9, 126, 87, 132, 247, 167, 123, 153]), ],
u8x8::from([111, 149, 73, 120, 68, 171, 98, 223]), );
u8x8::from([24, 121, 1, 50, 13, 46, 164, 20]), test_samples(
]); u8x8::default(),
&[
u8x8::from([9, 126, 87, 132, 247, 167, 123, 153]),
u8x8::from([111, 149, 73, 120, 68, 171, 98, 223]),
u8x8::from([24, 121, 1, 50, 13, 46, 164, 20]),
],
);
test_samples(i64x8::default(), &[ test_samples(
i64x8::from([ i64x8::default(),
-7387126082252079607, &[
-2350127744969763473, i64x8::from([
1487364411147516184, -7387126082252079607,
7895421560427121838, -2350127744969763473,
602190064936008898, 1487364411147516184,
6022086574635100741, 7895421560427121838,
-5080089175222015595, 602190064936008898,
-4066367846667249123, 6022086574635100741,
]), -5080089175222015595,
i64x8::from([ -4066367846667249123,
9180885022207963908, ]),
3095981199532211089, i64x8::from([
6586075293021332726, 9180885022207963908,
419343203796414657, 3095981199532211089,
3186951873057035255, 6586075293021332726,
5287129228749947252, 419343203796414657,
444726432079249540, 3186951873057035255,
-1587028029513790706, 5287129228749947252,
]), 444726432079249540,
i64x8::from([ -1587028029513790706,
6075236523189346388, ]),
1351763722368165432, i64x8::from([
-6192309979959753740, 6075236523189346388,
-7697775502176768592, 1351763722368165432,
-4482022114172078123, -6192309979959753740,
7522501477800909500, -7697775502176768592,
-1837258847956201231, -4482022114172078123,
-586926753024886735, 7522501477800909500,
]), -1837258847956201231,
]); -586926753024886735,
]),
],
);
} }
} }
} }

View File

@ -110,10 +110,10 @@ pub mod hidden_export {
pub mod uniform; pub mod uniform;
pub use self::bernoulli::{Bernoulli, BernoulliError}; pub use self::bernoulli::{Bernoulli, BernoulliError};
pub use self::distribution::{Distribution, DistIter, DistMap};
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
pub use self::distribution::DistString; pub use self::distribution::DistString;
pub use self::distribution::{DistIter, DistMap, Distribution};
pub use self::float::{Open01, OpenClosed01}; pub use self::float::{Open01, OpenClosed01};
pub use self::other::Alphanumeric; pub use self::other::Alphanumeric;
pub use self::slice::Slice; pub use self::slice::Slice;

View File

@ -8,24 +8,23 @@
//! The implementations of the `Standard` distribution for other built-in types. //! The implementations of the `Standard` distribution for other built-in types.
use core::char;
use core::num::Wrapping;
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
use alloc::string::String; use alloc::string::String;
use core::char;
use core::num::Wrapping;
use crate::distributions::{Distribution, Standard, Uniform};
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
use crate::distributions::DistString; use crate::distributions::DistString;
use crate::distributions::{Distribution, Standard, Uniform};
use crate::Rng; use crate::Rng;
#[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize};
use core::mem::{self, MaybeUninit}; use core::mem::{self, MaybeUninit};
#[cfg(feature = "simd_support")] #[cfg(feature = "simd_support")]
use core::simd::prelude::*; use core::simd::prelude::*;
#[cfg(feature = "simd_support")] #[cfg(feature = "simd_support")]
use core::simd::{LaneCount, MaskElement, SupportedLaneCount}; use core::simd::{LaneCount, MaskElement, SupportedLaneCount};
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
// ----- Sampling distributions ----- // ----- Sampling distributions -----
@ -71,7 +70,6 @@ use core::simd::{LaneCount, MaskElement, SupportedLaneCount};
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Alphanumeric; pub struct Alphanumeric;
// ----- Implementations of distributions ----- // ----- Implementations of distributions -----
impl Distribution<char> for Standard { impl Distribution<char> for Standard {
@ -240,7 +238,8 @@ macro_rules! tuple_impls {
tuple_impls! {A B C D E F G H I J K L} tuple_impls! {A B C D E F G H I J K L}
impl<T, const N: usize> Distribution<[T; N]> for Standard impl<T, const N: usize> Distribution<[T; N]> for Standard
where Standard: Distribution<T> where
Standard: Distribution<T>,
{ {
#[inline] #[inline]
fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> [T; N] { fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> [T; N] {
@ -255,7 +254,8 @@ where Standard: Distribution<T>
} }
impl<T> Distribution<Option<T>> for Standard impl<T> Distribution<Option<T>> for Standard
where Standard: Distribution<T> where
Standard: Distribution<T>,
{ {
#[inline] #[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<T> { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<T> {
@ -269,7 +269,8 @@ where Standard: Distribution<T>
} }
impl<T> Distribution<Wrapping<T>> for Standard impl<T> Distribution<Wrapping<T>> for Standard
where Standard: Distribution<T> where
Standard: Distribution<T>,
{ {
#[inline] #[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Wrapping<T> { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Wrapping<T> {
@ -277,7 +278,6 @@ where Standard: Distribution<T>
} }
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -315,9 +315,7 @@ mod tests {
let mut incorrect = false; let mut incorrect = false;
for _ in 0..100 { for _ in 0..100 {
let c: char = rng.sample(Alphanumeric).into(); let c: char = rng.sample(Alphanumeric).into();
incorrect |= !(('0'..='9').contains(&c) || incorrect |= !c.is_ascii_alphanumeric();
('A'..='Z').contains(&c) ||
('a'..='z').contains(&c) );
} }
assert!(!incorrect); assert!(!incorrect);
} }
@ -325,7 +323,9 @@ mod tests {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>( fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>(
distr: &D, zero: T, expected: &[T], distr: &D,
zero: T,
expected: &[T],
) { ) {
let mut rng = crate::test::rng(807); let mut rng = crate::test::rng(807);
let mut buf = [zero; 5]; let mut buf = [zero; 5];
@ -335,54 +335,66 @@ mod tests {
assert_eq!(&buf, expected); assert_eq!(&buf, expected);
} }
test_samples(&Standard, 'a', &[ test_samples(
'\u{8cdac}', &Standard,
'\u{a346a}', 'a',
'\u{80120}', &[
'\u{ed692}', '\u{8cdac}',
'\u{35888}', '\u{a346a}',
]); '\u{80120}',
'\u{ed692}',
'\u{35888}',
],
);
test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]); test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]);
test_samples(&Standard, false, &[true, true, false, true, false]); test_samples(&Standard, false, &[true, true, false, true, false]);
test_samples(&Standard, None as Option<bool>, &[ test_samples(
Some(true), &Standard,
None, None as Option<bool>,
Some(false), &[Some(true), None, Some(false), None, Some(false)],
None, );
Some(false), test_samples(
]); &Standard,
test_samples(&Standard, Wrapping(0i32), &[ Wrapping(0i32),
Wrapping(-2074640887), &[
Wrapping(-1719949321), Wrapping(-2074640887),
Wrapping(2018088303), Wrapping(-1719949321),
Wrapping(-547181756), Wrapping(2018088303),
Wrapping(838957336), Wrapping(-547181756),
]); Wrapping(838957336),
],
);
// We test only sub-sets of tuple and array impls // We test only sub-sets of tuple and array impls
test_samples(&Standard, (), &[(), (), (), (), ()]); test_samples(&Standard, (), &[(), (), (), (), ()]);
test_samples(&Standard, (false,), &[ test_samples(
(true,), &Standard,
(true,),
(false,), (false,),
(true,), &[(true,), (true,), (false,), (true,), (false,)],
(false,), );
]); test_samples(
test_samples(&Standard, (false, false), &[ &Standard,
(true, true),
(false, true),
(false, false), (false, false),
(true, false), &[
(false, false), (true, true),
]); (false, true),
(false, false),
(true, false),
(false, false),
],
);
test_samples(&Standard, [0u8; 0], &[[], [], [], [], []]); test_samples(&Standard, [0u8; 0], &[[], [], [], [], []]);
test_samples(&Standard, [0u8; 3], &[ test_samples(
[9, 247, 111], &Standard,
[68, 24, 13], [0u8; 3],
[174, 19, 194], &[
[172, 69, 213], [9, 247, 111],
[149, 207, 29], [68, 24, 13],
]); [174, 19, 194],
[172, 69, 213],
[149, 207, 29],
],
);
} }
} }

View File

@ -148,7 +148,11 @@ impl<'a> super::DistString for Slice<'a, char> {
// Split the extension of string to reuse the unused capacities. // Split the extension of string to reuse the unused capacities.
// Skip the split for small length or only ascii slice. // Skip the split for small length or only ascii slice.
let mut extend_len = if max_char_len == 1 || len < 100 { len } else { len / 4 }; let mut extend_len = if max_char_len == 1 || len < 100 {
len
} else {
len / 4
};
let mut remain_len = len; let mut remain_len = len;
while extend_len > 0 { while extend_len > 0 {
string.reserve(max_char_len * extend_len); string.reserve(max_char_len * extend_len);

View File

@ -104,18 +104,22 @@
//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow //! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow
use core::fmt; use core::fmt;
use core::time::Duration;
use core::ops::{Range, RangeInclusive}; use core::ops::{Range, RangeInclusive};
use core::time::Duration;
use crate::distributions::float::IntoFloat; use crate::distributions::float::IntoFloat;
use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply}; use crate::distributions::utils::{
BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply,
};
use crate::distributions::Distribution; use crate::distributions::Distribution;
#[cfg(feature = "simd_support")] #[cfg(feature = "simd_support")]
use crate::distributions::Standard; use crate::distributions::Standard;
use crate::{Rng, RngCore}; use crate::{Rng, RngCore};
#[cfg(feature = "simd_support")] use core::simd::prelude::*; #[cfg(feature = "simd_support")]
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SupportedLaneCount}; use core::simd::prelude::*;
#[cfg(feature = "simd_support")]
use core::simd::{LaneCount, SupportedLaneCount};
/// Error type returned from [`Uniform::new`] and `new_inclusive`. /// Error type returned from [`Uniform::new`] and `new_inclusive`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
@ -140,7 +144,7 @@ impl fmt::Display for Error {
impl std::error::Error for Error {} impl std::error::Error for Error {}
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// Sample values uniformly between two bounds. /// Sample values uniformly between two bounds.
/// ///
@ -194,7 +198,10 @@ use serde::{Serialize, Deserialize};
#[derive(Clone, Copy, Debug, PartialEq, Eq)] #[derive(Clone, Copy, Debug, PartialEq, Eq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))] #[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))] #[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))]
#[cfg_attr(feature = "serde1", serde(bound(deserialize = "X::Sampler: Deserialize<'de>")))] #[cfg_attr(
feature = "serde1",
serde(bound(deserialize = "X::Sampler: Deserialize<'de>"))
)]
pub struct Uniform<X: SampleUniform>(X::Sampler); pub struct Uniform<X: SampleUniform>(X::Sampler);
impl<X: SampleUniform> Uniform<X> { impl<X: SampleUniform> Uniform<X> {
@ -297,7 +304,11 @@ pub trait UniformSampler: Sized {
/// <T as SampleUniform>::Sampler::sample_single(lb, ub, &mut rng).unwrap() /// <T as SampleUniform>::Sampler::sample_single(lb, ub, &mut rng).unwrap()
/// } /// }
/// ``` /// ```
fn sample_single<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R) -> Result<Self::X, Error> fn sample_single<R: Rng + ?Sized, B1, B2>(
low: B1,
high: B2,
rng: &mut R,
) -> Result<Self::X, Error>
where where
B1: SampleBorrow<Self::X> + Sized, B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized, B2: SampleBorrow<Self::X> + Sized,
@ -314,10 +325,14 @@ pub trait UniformSampler: Sized {
/// some types more optimal implementations for single usage may be provided /// some types more optimal implementations for single usage may be provided
/// via this method. /// via this method.
/// Results may not be identical. /// Results may not be identical.
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R) fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
-> Result<Self::X, Error> low: B1,
where B1: SampleBorrow<Self::X> + Sized, high: B2,
B2: SampleBorrow<Self::X> + Sized rng: &mut R,
) -> Result<Self::X, Error>
where
B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized,
{ {
let uniform: Self = UniformSampler::new_inclusive(low, high)?; let uniform: Self = UniformSampler::new_inclusive(low, high)?;
Ok(uniform.sample(rng)) Ok(uniform.sample(rng))
@ -340,7 +355,6 @@ impl<X: SampleUniform> TryFrom<RangeInclusive<X>> for Uniform<X> {
} }
} }
/// Helper trait similar to [`Borrow`] but implemented /// Helper trait similar to [`Borrow`] but implemented
/// only for SampleUniform and references to SampleUniform in /// only for SampleUniform and references to SampleUniform in
/// order to resolve ambiguity issues. /// order to resolve ambiguity issues.
@ -353,7 +367,8 @@ pub trait SampleBorrow<Borrowed> {
fn borrow(&self) -> &Borrowed; fn borrow(&self) -> &Borrowed;
} }
impl<Borrowed> SampleBorrow<Borrowed> for Borrowed impl<Borrowed> SampleBorrow<Borrowed> for Borrowed
where Borrowed: SampleUniform where
Borrowed: SampleUniform,
{ {
#[inline(always)] #[inline(always)]
fn borrow(&self) -> &Borrowed { fn borrow(&self) -> &Borrowed {
@ -361,7 +376,8 @@ where Borrowed: SampleUniform
} }
} }
impl<'a, Borrowed> SampleBorrow<Borrowed> for &'a Borrowed impl<'a, Borrowed> SampleBorrow<Borrowed> for &'a Borrowed
where Borrowed: SampleUniform where
Borrowed: SampleUniform,
{ {
#[inline(always)] #[inline(always)]
fn borrow(&self) -> &Borrowed { fn borrow(&self) -> &Borrowed {
@ -405,12 +421,10 @@ impl<T: SampleUniform + PartialOrd> SampleRange<T> for RangeInclusive<T> {
} }
} }
//////////////////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////////////////////
// What follows are all back-ends. // What follows are all back-ends.
/// The back-end implementing [`UniformSampler`] for integer types. /// The back-end implementing [`UniformSampler`] for integer types.
/// ///
/// Unless you are implementing [`UniformSampler`] for your own type, this type /// Unless you are implementing [`UniformSampler`] for your own type, this type
@ -505,7 +519,7 @@ macro_rules! uniform_int_impl {
Ok(UniformInt { Ok(UniformInt {
low, low,
range: range as $ty, // type: $uty range: range as $ty, // type: $uty
thresh: thresh as $uty as $ty, // type: $sample_ty thresh: thresh as $uty as $ty, // type: $sample_ty
}) })
} }
@ -529,7 +543,11 @@ macro_rules! uniform_int_impl {
} }
#[inline] #[inline]
fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error> fn sample_single<R: Rng + ?Sized, B1, B2>(
low_b: B1,
high_b: B2,
rng: &mut R,
) -> Result<Self::X, Error>
where where
B1: SampleBorrow<Self::X> + Sized, B1: SampleBorrow<Self::X> + Sized,
B2: SampleBorrow<Self::X> + Sized, B2: SampleBorrow<Self::X> + Sized,
@ -549,7 +567,9 @@ macro_rules! uniform_int_impl {
#[cfg(not(feature = "unbiased"))] #[cfg(not(feature = "unbiased"))]
#[inline] #[inline]
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>( fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
low_b: B1, high_b: B2, rng: &mut R, low_b: B1,
high_b: B2,
rng: &mut R,
) -> Result<Self::X, Error> ) -> Result<Self::X, Error>
where where
B1: SampleBorrow<Self::X> + Sized, B1: SampleBorrow<Self::X> + Sized,
@ -585,7 +605,9 @@ macro_rules! uniform_int_impl {
#[cfg(feature = "unbiased")] #[cfg(feature = "unbiased")]
#[inline] #[inline]
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>( fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
low_b: B1, high_b: B2, rng: &mut R, low_b: B1,
high_b: B2,
rng: &mut R,
) -> Result<Self::X, Error> ) -> Result<Self::X, Error>
where where
B1: SampleBorrow<$ty> + Sized, B1: SampleBorrow<$ty> + Sized,
@ -599,7 +621,7 @@ macro_rules! uniform_int_impl {
let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty; let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty;
if range == 0 { if range == 0 {
// Range is MAX+1 (unrepresentable), so we need a special case // Range is MAX+1 (unrepresentable), so we need a special case
return Ok(rng.gen()); return Ok(rng.random());
} }
let (mut result, mut lo) = rng.random::<$sample_ty>().wmul(range); let (mut result, mut lo) = rng.random::<$sample_ty>().wmul(range);
@ -844,7 +866,12 @@ impl UniformSampler for UniformChar {
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
impl super::DistString for Uniform<char> { impl super::DistString for Uniform<char> {
fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, string: &mut alloc::string::String, len: usize) { fn append_string<R: Rng + ?Sized>(
&self,
rng: &mut R,
string: &mut alloc::string::String,
len: usize,
) {
// Getting the hi value to assume the required length to reserve in string. // Getting the hi value to assume the required length to reserve in string.
let mut hi = self.0.sampler.low + self.0.sampler.range - 1; let mut hi = self.0.sampler.low + self.0.sampler.range - 1;
if hi >= CHAR_SURROGATE_START { if hi >= CHAR_SURROGATE_START {
@ -911,7 +938,7 @@ macro_rules! uniform_float_impl {
return Err(Error::EmptyRange); return Err(Error::EmptyRange);
} }
let max_rand = <$ty>::splat( let max_rand = <$ty>::splat(
(::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, ($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
); );
let mut scale = high - low; let mut scale = high - low;
@ -947,7 +974,7 @@ macro_rules! uniform_float_impl {
return Err(Error::EmptyRange); return Err(Error::EmptyRange);
} }
let max_rand = <$ty>::splat( let max_rand = <$ty>::splat(
(::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0, ($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
); );
let mut scale = (high - low) / max_rand; let mut scale = (high - low) / max_rand;
@ -1111,7 +1138,6 @@ uniform_float_impl! { feature = "simd_support", f64x4, u64x4, f64, u64, 64 - 52
#[cfg(feature = "simd_support")] #[cfg(feature = "simd_support")]
uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 } uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 }
/// The back-end implementing [`UniformSampler`] for `Duration`. /// The back-end implementing [`UniformSampler`] for `Duration`.
/// ///
/// Unless you are implementing [`UniformSampler`] for your own types, this type /// Unless you are implementing [`UniformSampler`] for your own types, this type
@ -1248,26 +1274,29 @@ impl UniformSampler for UniformDuration {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::rngs::mock::StepRng;
use crate::distributions::utils::FloatSIMDScalarUtils; use crate::distributions::utils::FloatSIMDScalarUtils;
use crate::rngs::mock::StepRng;
#[test] #[test]
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
fn test_serialization_uniform_duration() { fn test_serialization_uniform_duration() {
let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap(); let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap();
let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap(); let de_distr: UniformDuration =
bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap();
assert_eq!(distr, de_distr); assert_eq!(distr, de_distr);
} }
#[test] #[test]
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
fn test_uniform_serialization() { fn test_uniform_serialization() {
let unit_box: Uniform<i32> = Uniform::new(-1, 1).unwrap(); let unit_box: Uniform<i32> = Uniform::new(-1, 1).unwrap();
let de_unit_box: Uniform<i32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); let de_unit_box: Uniform<i32> =
bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();
assert_eq!(unit_box.0, de_unit_box.0); assert_eq!(unit_box.0, de_unit_box.0);
let unit_box: Uniform<f32> = Uniform::new(-1., 1.).unwrap(); let unit_box: Uniform<f32> = Uniform::new(-1., 1.).unwrap();
let de_unit_box: Uniform<f32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap(); let de_unit_box: Uniform<f32> =
bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();
assert_eq!(unit_box.0, de_unit_box.0); assert_eq!(unit_box.0, de_unit_box.0);
} }
@ -1293,10 +1322,6 @@ mod tests {
#[test] #[test]
#[cfg_attr(miri, ignore)] // Miri is too slow #[cfg_attr(miri, ignore)] // Miri is too slow
fn test_integers() { fn test_integers() {
use core::{i128, u128};
use core::{i16, i32, i64, i8, isize};
use core::{u16, u32, u64, u8, usize};
let mut rng = crate::test::rng(251); let mut rng = crate::test::rng(251);
macro_rules! t { macro_rules! t {
($ty:ident, $v:expr, $le:expr, $lt:expr) => {{ ($ty:ident, $v:expr, $le:expr, $lt:expr) => {{
@ -1383,14 +1408,15 @@ mod tests {
let mut max = core::char::from_u32(0).unwrap(); let mut max = core::char::from_u32(0).unwrap();
for _ in 0..100 { for _ in 0..100 {
let c = rng.gen_range('A'..='Z'); let c = rng.gen_range('A'..='Z');
assert!(('A'..='Z').contains(&c)); assert!(c.is_ascii_uppercase());
max = max.max(c); max = max.max(c);
} }
assert_eq!(max, 'Z'); assert_eq!(max, 'Z');
let d = Uniform::new( let d = Uniform::new(
core::char::from_u32(0xD7F0).unwrap(), core::char::from_u32(0xD7F0).unwrap(),
core::char::from_u32(0xE010).unwrap(), core::char::from_u32(0xE010).unwrap(),
).unwrap(); )
.unwrap();
for _ in 0..100 { for _ in 0..100 {
let c = d.sample(&mut rng); let c = d.sample(&mut rng);
assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF); assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF);
@ -1403,12 +1429,16 @@ mod tests {
let string2 = Uniform::new( let string2 = Uniform::new(
core::char::from_u32(0x0000).unwrap(), core::char::from_u32(0x0000).unwrap(),
core::char::from_u32(0x0080).unwrap(), core::char::from_u32(0x0080).unwrap(),
).unwrap().sample_string(&mut rng, 100); )
.unwrap()
.sample_string(&mut rng, 100);
assert_eq!(string2.capacity(), 100); assert_eq!(string2.capacity(), 100);
let string3 = Uniform::new_inclusive( let string3 = Uniform::new_inclusive(
core::char::from_u32(0x0000).unwrap(), core::char::from_u32(0x0000).unwrap(),
core::char::from_u32(0x0080).unwrap(), core::char::from_u32(0x0080).unwrap(),
).unwrap().sample_string(&mut rng, 100); )
.unwrap()
.sample_string(&mut rng, 100);
assert_eq!(string3.capacity(), 200); assert_eq!(string3.capacity(), 200);
} }
} }
@ -1430,8 +1460,8 @@ mod tests {
(-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)), (-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)),
(-<$f_scalar>::from_bits(5), 0.0), (-<$f_scalar>::from_bits(5), 0.0),
(-<$f_scalar>::from_bits(7), -0.0), (-<$f_scalar>::from_bits(7), -0.0),
(0.1 * ::core::$f_scalar::MAX, ::core::$f_scalar::MAX), (0.1 * $f_scalar::MAX, $f_scalar::MAX),
(-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7), (-$f_scalar::MAX * 0.2, $f_scalar::MAX * 0.7),
]; ];
for &(low_scalar, high_scalar) in v.iter() { for &(low_scalar, high_scalar) in v.iter() {
for lane in 0..<$ty>::LEN { for lane in 0..<$ty>::LEN {
@ -1444,27 +1474,47 @@ mod tests {
assert!(low_scalar <= v && v < high_scalar); assert!(low_scalar <= v && v < high_scalar);
let v = rng.sample(my_incl_uniform).extract(lane); let v = rng.sample(my_incl_uniform).extract(lane);
assert!(low_scalar <= v && v <= high_scalar); assert!(low_scalar <= v && v <= high_scalar);
let v = <$ty as SampleUniform>::Sampler let v =
::sample_single(low, high, &mut rng).unwrap().extract(lane); <$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng)
.unwrap()
.extract(lane);
assert!(low_scalar <= v && v < high_scalar); assert!(low_scalar <= v && v < high_scalar);
let v = <$ty as SampleUniform>::Sampler let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(
::sample_single_inclusive(low, high, &mut rng).unwrap().extract(lane); low, high, &mut rng,
)
.unwrap()
.extract(lane);
assert!(low_scalar <= v && v <= high_scalar); assert!(low_scalar <= v && v <= high_scalar);
} }
assert_eq!( assert_eq!(
rng.sample(Uniform::new_inclusive(low, low).unwrap()).extract(lane), rng.sample(Uniform::new_inclusive(low, low).unwrap())
.extract(lane),
low_scalar low_scalar
); );
assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar); assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar);
assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar); assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar);
assert_eq!(<$ty as SampleUniform>::Sampler assert_eq!(
::sample_single(low, high, &mut zero_rng).unwrap() <$ty as SampleUniform>::Sampler::sample_single(
.extract(lane), low_scalar); low,
assert_eq!(<$ty as SampleUniform>::Sampler high,
::sample_single_inclusive(low, high, &mut zero_rng).unwrap() &mut zero_rng
.extract(lane), low_scalar); )
.unwrap()
.extract(lane),
low_scalar
);
assert_eq!(
<$ty as SampleUniform>::Sampler::sample_single_inclusive(
low,
high,
&mut zero_rng
)
.unwrap()
.extract(lane),
low_scalar
);
assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar); assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar); assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
@ -1472,9 +1522,16 @@ mod tests {
// assert!(<$ty as SampleUniform>::Sampler // assert!(<$ty as SampleUniform>::Sampler
// ::sample_single(low, high, &mut max_rng).unwrap() // ::sample_single(low, high, &mut max_rng).unwrap()
// .extract(lane) < high_scalar); // .extract(lane) < high_scalar);
assert!(<$ty as SampleUniform>::Sampler assert!(
::sample_single_inclusive(low, high, &mut max_rng).unwrap() <$ty as SampleUniform>::Sampler::sample_single_inclusive(
.extract(lane) <= high_scalar); low,
high,
&mut max_rng
)
.unwrap()
.extract(lane)
<= high_scalar
);
// Don't run this test for really tiny differences between high and low // Don't run this test for really tiny differences between high and low
// since for those rounding might result in selecting high for a very // since for those rounding might result in selecting high for a very
@ -1485,27 +1542,26 @@ mod tests {
(-1i64 << $bits_shifted) as u64, (-1i64 << $bits_shifted) as u64,
); );
assert!( assert!(
<$ty as SampleUniform>::Sampler <$ty as SampleUniform>::Sampler::sample_single(
::sample_single(low, high, &mut lowering_max_rng).unwrap() low,
.extract(lane) < high_scalar high,
&mut lowering_max_rng
)
.unwrap()
.extract(lane)
< high_scalar
); );
} }
} }
} }
assert_eq!( assert_eq!(
rng.sample(Uniform::new_inclusive( rng.sample(Uniform::new_inclusive($f_scalar::MAX, $f_scalar::MAX).unwrap()),
::core::$f_scalar::MAX, $f_scalar::MAX
::core::$f_scalar::MAX
).unwrap()),
::core::$f_scalar::MAX
); );
assert_eq!( assert_eq!(
rng.sample(Uniform::new_inclusive( rng.sample(Uniform::new_inclusive(-$f_scalar::MAX, -$f_scalar::MAX).unwrap()),
-::core::$f_scalar::MAX, -$f_scalar::MAX
-::core::$f_scalar::MAX
).unwrap()),
-::core::$f_scalar::MAX
); );
}}; }};
} }
@ -1549,21 +1605,18 @@ mod tests {
macro_rules! t { macro_rules! t {
($ty:ident, $f_scalar:ident) => {{ ($ty:ident, $f_scalar:ident) => {{
let v: &[($f_scalar, $f_scalar)] = &[ let v: &[($f_scalar, $f_scalar)] = &[
(::std::$f_scalar::NAN, 0.0), ($f_scalar::NAN, 0.0),
(1.0, ::std::$f_scalar::NAN), (1.0, $f_scalar::NAN),
(::std::$f_scalar::NAN, ::std::$f_scalar::NAN), ($f_scalar::NAN, $f_scalar::NAN),
(1.0, 0.5), (1.0, 0.5),
(::std::$f_scalar::MAX, -::std::$f_scalar::MAX), ($f_scalar::MAX, -$f_scalar::MAX),
(::std::$f_scalar::INFINITY, ::std::$f_scalar::INFINITY), ($f_scalar::INFINITY, $f_scalar::INFINITY),
( ($f_scalar::NEG_INFINITY, $f_scalar::NEG_INFINITY),
::std::$f_scalar::NEG_INFINITY, ($f_scalar::NEG_INFINITY, 5.0),
::std::$f_scalar::NEG_INFINITY, (5.0, $f_scalar::INFINITY),
), ($f_scalar::NAN, $f_scalar::INFINITY),
(::std::$f_scalar::NEG_INFINITY, 5.0), ($f_scalar::NEG_INFINITY, $f_scalar::NAN),
(5.0, ::std::$f_scalar::INFINITY), ($f_scalar::NEG_INFINITY, $f_scalar::INFINITY),
(::std::$f_scalar::NAN, ::std::$f_scalar::INFINITY),
(::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::NAN),
(::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY),
]; ];
for &(low_scalar, high_scalar) in v.iter() { for &(low_scalar, high_scalar) in v.iter() {
for lane in 0..<$ty>::LEN { for lane in 0..<$ty>::LEN {
@ -1593,7 +1646,6 @@ mod tests {
} }
} }
#[test] #[test]
#[cfg_attr(miri, ignore)] // Miri is too slow #[cfg_attr(miri, ignore)] // Miri is too slow
fn test_durations() { fn test_durations() {
@ -1602,10 +1654,7 @@ mod tests {
let v = &[ let v = &[
(Duration::new(10, 50000), Duration::new(100, 1234)), (Duration::new(10, 50000), Duration::new(100, 1234)),
(Duration::new(0, 100), Duration::new(1, 50)), (Duration::new(0, 100), Duration::new(1, 50)),
( (Duration::new(0, 0), Duration::new(u64::MAX, 999_999_999)),
Duration::new(0, 0),
Duration::new(u64::MAX, 999_999_999),
),
]; ];
for &(low, high) in v.iter() { for &(low, high) in v.iter() {
let my_uniform = Uniform::new(low, high).unwrap(); let my_uniform = Uniform::new(low, high).unwrap();
@ -1707,8 +1756,13 @@ mod tests {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn test_samples<T: SampleUniform + Copy + fmt::Debug + PartialEq>( fn test_samples<T: SampleUniform + Copy + fmt::Debug + PartialEq>(
lb: T, ub: T, expected_single: &[T], expected_multiple: &[T], lb: T,
) where Uniform<T>: Distribution<T> { ub: T,
expected_single: &[T],
expected_multiple: &[T],
) where
Uniform<T>: Distribution<T>,
{
let mut rng = crate::test::rng(897); let mut rng = crate::test::rng(897);
let mut buf = [lb; 3]; let mut buf = [lb; 3];
@ -1730,11 +1784,12 @@ mod tests {
test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]); test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]);
test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]); test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]);
test_samples(0f32, 1e-2f32, &[0.0003070104, 0.0026630748, 0.00979833], &[ test_samples(
0.008194133, 0f32,
0.00398172, 1e-2f32,
0.007428536, &[0.0003070104, 0.0026630748, 0.00979833],
]); &[0.008194133, 0.00398172, 0.007428536],
);
test_samples( test_samples(
-1e10f64, -1e10f64,
1e10f64, 1e10f64,
@ -1760,9 +1815,15 @@ mod tests {
#[test] #[test]
fn uniform_distributions_can_be_compared() { fn uniform_distributions_can_be_compared() {
assert_eq!(Uniform::new(1.0, 2.0).unwrap(), Uniform::new(1.0, 2.0).unwrap()); assert_eq!(
Uniform::new(1.0, 2.0).unwrap(),
Uniform::new(1.0, 2.0).unwrap()
);
// To cover UniformInt // To cover UniformInt
assert_eq!(Uniform::new(1_u32, 2_u32).unwrap(), Uniform::new(1_u32, 2_u32).unwrap()); assert_eq!(
Uniform::new(1_u32, 2_u32).unwrap(),
Uniform::new(1_u32, 2_u32).unwrap()
);
} }
} }

View File

@ -8,9 +8,10 @@
//! Math helper functions //! Math helper functions
#[cfg(feature = "simd_support")] use core::simd::prelude::*; #[cfg(feature = "simd_support")]
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SimdElement, SupportedLaneCount}; use core::simd::prelude::*;
#[cfg(feature = "simd_support")]
use core::simd::{LaneCount, SimdElement, SupportedLaneCount};
pub(crate) trait WideningMultiply<RHS = Self> { pub(crate) trait WideningMultiply<RHS = Self> {
type Output; type Output;
@ -146,8 +147,10 @@ wmul_impl_usize! { u64 }
#[cfg(feature = "simd_support")] #[cfg(feature = "simd_support")]
mod simd_wmul { mod simd_wmul {
use super::*; use super::*;
#[cfg(target_arch = "x86")] use core::arch::x86::*; #[cfg(target_arch = "x86")]
#[cfg(target_arch = "x86_64")] use core::arch::x86_64::*; use core::arch::x86::*;
#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::*;
wmul_impl! { wmul_impl! {
(u8x4, u16x4), (u8x4, u16x4),
@ -340,12 +343,12 @@ macro_rules! scalar_float_impl {
scalar_float_impl!(f32, u32); scalar_float_impl!(f32, u32);
scalar_float_impl!(f64, u64); scalar_float_impl!(f64, u64);
#[cfg(feature = "simd_support")] #[cfg(feature = "simd_support")]
macro_rules! simd_impl { macro_rules! simd_impl {
($fty:ident, $uty:ident) => { ($fty:ident, $uty:ident) => {
impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES> impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES>
where LaneCount<LANES>: SupportedLaneCount where
LaneCount<LANES>: SupportedLaneCount,
{ {
type Mask = Mask<<$fty as SimdElement>::Mask, LANES>; type Mask = Mask<<$fty as SimdElement>::Mask, LANES>;
type UInt = Simd<$uty, LANES>; type UInt = Simd<$uty, LANES>;

View File

@ -108,7 +108,11 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
X: Weight, X: Weight,
{ {
let mut iter = weights.into_iter(); let mut iter = weights.into_iter();
let mut total_weight: X = iter.next().ok_or(WeightError::InvalidInput)?.borrow().clone(); let mut total_weight: X = iter
.next()
.ok_or(WeightError::InvalidInput)?
.borrow()
.clone();
let zero = X::ZERO; let zero = X::ZERO;
if !(total_weight >= zero) { if !(total_weight >= zero) {
@ -252,9 +256,9 @@ pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
} }
impl<'a, X> Debug for WeightedIndexIter<'a, X> impl<'a, X> Debug for WeightedIndexIter<'a, X>
where where
X: SampleUniform + PartialOrd + Debug, X: SampleUniform + PartialOrd + Debug,
X::Sampler: Debug, X::Sampler: Debug,
{ {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("WeightedIndexIter") f.debug_struct("WeightedIndexIter")
@ -278,10 +282,7 @@ where
impl<'a, X> Iterator for WeightedIndexIter<'a, X> impl<'a, X> Iterator for WeightedIndexIter<'a, X>
where where
X: for<'b> core::ops::SubAssign<&'b X> X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone,
+ SampleUniform
+ PartialOrd
+ Clone,
{ {
type Item = X; type Item = X;
@ -315,15 +316,16 @@ impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
/// ``` /// ```
pub fn weight(&self, index: usize) -> Option<X> pub fn weight(&self, index: usize) -> Option<X>
where where
X: for<'a> core::ops::SubAssign<&'a X> X: for<'a> core::ops::SubAssign<&'a X>,
{ {
let mut weight = if index < self.cumulative_weights.len() { use core::cmp::Ordering::*;
self.cumulative_weights[index].clone()
} else if index == self.cumulative_weights.len() { let mut weight = match index.cmp(&self.cumulative_weights.len()) {
self.total_weight.clone() Less => self.cumulative_weights[index].clone(),
} else { Equal => self.total_weight.clone(),
return None; Greater => return None,
}; };
if index > 0 { if index > 0 {
weight -= &self.cumulative_weights[index - 1]; weight -= &self.cumulative_weights[index - 1];
} }
@ -348,7 +350,7 @@ impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
/// ``` /// ```
pub fn weights(&self) -> WeightedIndexIter<'_, X> pub fn weights(&self) -> WeightedIndexIter<'_, X>
where where
X: for<'a> core::ops::SubAssign<&'a X> X: for<'a> core::ops::SubAssign<&'a X>,
{ {
WeightedIndexIter { WeightedIndexIter {
weighted_index: self, weighted_index: self,
@ -387,6 +389,7 @@ pub trait Weight: Clone {
/// - `Result::Err`: Returns an error when `Self` cannot represent the /// - `Result::Err`: Returns an error when `Self` cannot represent the
/// result of `self + v` (i.e. overflow). The value of `self` should be /// result of `self + v` (i.e. overflow). The value of `self` should be
/// discarded. /// discarded.
#[allow(clippy::result_unit_err)]
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>; fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>;
} }
@ -417,6 +420,7 @@ macro_rules! impl_weight_float {
($t:ty) => { ($t:ty) => {
impl Weight for $t { impl Weight for $t {
const ZERO: Self = 0.0; const ZERO: Self = 0.0;
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> { fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
// Floats have an explicit representation for overflow // Floats have an explicit representation for overflow
*self += *v; *self += *v;
@ -435,7 +439,7 @@ mod test {
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
#[test] #[test]
fn test_weightedindex_serde1() { fn test_weightedindex_serde1() {
let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap(); let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();
let ser_weighted_index = bincode::serialize(&weighted_index).unwrap(); let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
let de_weighted_index: WeightedIndex<i32> = let de_weighted_index: WeightedIndex<i32> =
@ -451,20 +455,20 @@ mod test {
#[test] #[test]
fn test_accepting_nan() { fn test_accepting_nan() {
assert_eq!( assert_eq!(
WeightedIndex::new(&[f32::NAN, 0.5]).unwrap_err(), WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(),
WeightError::InvalidWeight, WeightError::InvalidWeight,
); );
assert_eq!( assert_eq!(
WeightedIndex::new(&[f32::NAN]).unwrap_err(), WeightedIndex::new([f32::NAN]).unwrap_err(),
WeightError::InvalidWeight, WeightError::InvalidWeight,
); );
assert_eq!( assert_eq!(
WeightedIndex::new(&[0.5, f32::NAN]).unwrap_err(), WeightedIndex::new([0.5, f32::NAN]).unwrap_err(),
WeightError::InvalidWeight, WeightError::InvalidWeight,
); );
assert_eq!( assert_eq!(
WeightedIndex::new(&[0.5, 7.0]) WeightedIndex::new([0.5, 7.0])
.unwrap() .unwrap()
.update_weights(&[(0, &f32::NAN)]) .update_weights(&[(0, &f32::NAN)])
.unwrap_err(), .unwrap_err(),
@ -516,10 +520,10 @@ mod test {
verify(chosen); verify(chosen);
for _ in 0..5 { for _ in 0..5 {
assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); assert_eq!(WeightedIndex::new([0, 1]).unwrap().sample(&mut r), 1);
assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); assert_eq!(WeightedIndex::new([1, 0]).unwrap().sample(&mut r), 0);
assert_eq!( assert_eq!(
WeightedIndex::new(&[0, 0, 0, 0, 10, 0]) WeightedIndex::new([0, 0, 0, 0, 10, 0])
.unwrap() .unwrap()
.sample(&mut r), .sample(&mut r),
4 4
@ -531,19 +535,19 @@ mod test {
WeightError::InvalidInput WeightError::InvalidInput
); );
assert_eq!( assert_eq!(
WeightedIndex::new(&[0]).unwrap_err(), WeightedIndex::new([0]).unwrap_err(),
WeightError::InsufficientNonZero WeightError::InsufficientNonZero
); );
assert_eq!( assert_eq!(
WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedIndex::new([10, 20, -1, 30]).unwrap_err(),
WeightError::InvalidWeight WeightError::InvalidWeight
); );
assert_eq!( assert_eq!(
WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(),
WeightError::InvalidWeight WeightError::InvalidWeight
); );
assert_eq!( assert_eq!(
WeightedIndex::new(&[-10]).unwrap_err(), WeightedIndex::new([-10]).unwrap_err(),
WeightError::InvalidWeight WeightError::InvalidWeight
); );
} }
@ -649,7 +653,9 @@ mod test {
#[test] #[test]
fn value_stability() { fn value_stability() {
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>( fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
weights: I, buf: &mut [usize], expected: &[usize], weights: I,
buf: &mut [usize],
expected: &[usize],
) where ) where
I: IntoIterator, I: IntoIterator,
I::Item: SampleBorrow<X>, I::Item: SampleBorrow<X>,
@ -665,17 +671,17 @@ mod test {
let mut buf = [0; 10]; let mut buf = [0; 10];
test_samples( test_samples(
&[1i32, 1, 1, 1, 1, 1, 1, 1, 1], [1i32, 1, 1, 1, 1, 1, 1, 1, 1],
&mut buf, &mut buf,
&[0, 6, 2, 6, 3, 4, 7, 8, 2, 5], &[0, 6, 2, 6, 3, 4, 7, 8, 2, 5],
); );
test_samples( test_samples(
&[0.7f32, 0.1, 0.1, 0.1], [0.7f32, 0.1, 0.1, 0.1],
&mut buf, &mut buf,
&[0, 0, 0, 1, 0, 0, 2, 3, 0, 0], &[0, 0, 0, 1, 0, 0, 2, 3, 0, 0],
); );
test_samples( test_samples(
&[1.0f64, 0.999, 0.998, 0.997], [1.0f64, 0.999, 0.998, 0.997],
&mut buf, &mut buf,
&[2, 2, 1, 3, 2, 1, 3, 3, 2, 1], &[2, 2, 1, 3, 2, 1, 3, 3, 2, 1],
); );
@ -683,7 +689,7 @@ mod test {
#[test] #[test]
fn weighted_index_distributions_can_be_compared() { fn weighted_index_distributions_can_be_compared() {
assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2])); assert_eq!(WeightedIndex::new([1, 2]), WeightedIndex::new([1, 2]));
} }
#[test] #[test]

View File

@ -50,6 +50,7 @@
#![doc(test(attr(allow(unused_variables), deny(warnings))))] #![doc(test(attr(allow(unused_variables), deny(warnings))))]
#![no_std] #![no_std]
#![cfg_attr(feature = "simd_support", feature(portable_simd))] #![cfg_attr(feature = "simd_support", feature(portable_simd))]
#![allow(unexpected_cfgs)]
#![cfg_attr(doc_cfg, feature(doc_cfg))] #![cfg_attr(doc_cfg, feature(doc_cfg))]
#![allow( #![allow(
clippy::float_cmp, clippy::float_cmp,
@ -57,8 +58,10 @@
clippy::nonminimal_bool clippy::nonminimal_bool
)] )]
#[cfg(feature = "alloc")] extern crate alloc; #[cfg(feature = "alloc")]
#[cfg(feature = "std")] extern crate std; extern crate alloc;
#[cfg(feature = "std")]
extern crate std;
#[allow(unused)] #[allow(unused)]
macro_rules! trace { ($($x:tt)*) => ( macro_rules! trace { ($($x:tt)*) => (
@ -160,7 +163,9 @@ use crate::distributions::{Distribution, Standard};
)] )]
#[inline] #[inline]
pub fn random<T>() -> T pub fn random<T>() -> T
where Standard: Distribution<T> { where
Standard: Distribution<T>,
{
thread_rng().random() thread_rng().random()
} }

View File

@ -18,7 +18,8 @@
//! # let _: f32 = r.random(); //! # let _: f32 = r.random();
//! ``` //! ```
#[doc(no_inline)] pub use crate::distributions::Distribution; #[doc(no_inline)]
pub use crate::distributions::Distribution;
#[cfg(feature = "small_rng")] #[cfg(feature = "small_rng")]
#[doc(no_inline)] #[doc(no_inline)]
pub use crate::rngs::SmallRng; pub use crate::rngs::SmallRng;
@ -33,4 +34,5 @@ pub use crate::seq::{IndexedMutRandom, IndexedRandom, IteratorRandom, SliceRando
#[doc(no_inline)] #[doc(no_inline)]
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] #[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
pub use crate::{random, thread_rng}; pub use crate::{random, thread_rng};
#[doc(no_inline)] pub use crate::{CryptoRng, Rng, RngCore, SeedableRng}; #[doc(no_inline)]
pub use crate::{CryptoRng, Rng, RngCore, SeedableRng};

View File

@ -89,7 +89,9 @@ pub trait Rng: RngCore {
/// [`Standard`]: distributions::Standard /// [`Standard`]: distributions::Standard
#[inline] #[inline]
fn random<T>(&mut self) -> T fn random<T>(&mut self) -> T
where Standard: Distribution<T> { where
Standard: Distribution<T>,
{
Standard.sample(self) Standard.sample(self)
} }
@ -309,7 +311,9 @@ pub trait Rng: RngCore {
note = "Renamed to `random` to avoid conflict with the new `gen` keyword in Rust 2024." note = "Renamed to `random` to avoid conflict with the new `gen` keyword in Rust 2024."
)] )]
fn gen<T>(&mut self) -> T fn gen<T>(&mut self) -> T
where Standard: Distribution<T> { where
Standard: Distribution<T>,
{
self.random() self.random()
} }
} }
@ -402,7 +406,8 @@ impl_fill!(u16, u32, u64, usize, u128,);
impl_fill!(i8, i16, i32, i64, isize, i128,); impl_fill!(i8, i16, i32, i64, isize, i128,);
impl<T, const N: usize> Fill for [T; N] impl<T, const N: usize> Fill for [T; N]
where [T]: Fill where
[T]: Fill,
{ {
fn fill<R: Rng + ?Sized>(&mut self, rng: &mut R) { fn fill<R: Rng + ?Sized>(&mut self, rng: &mut R) {
<[T] as Fill>::fill(self, rng) <[T] as Fill>::fill(self, rng)
@ -414,7 +419,8 @@ mod test {
use super::*; use super::*;
use crate::rngs::mock::StepRng; use crate::rngs::mock::StepRng;
use crate::test::rng; use crate::test::rng;
#[cfg(feature = "alloc")] use alloc::boxed::Box; #[cfg(feature = "alloc")]
use alloc::boxed::Box;
#[test] #[test]
fn test_fill_bytes_default() { fn test_fill_bytes_default() {

View File

@ -10,7 +10,8 @@
use rand_core::{impls, RngCore}; use rand_core::{impls, RngCore};
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; #[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// A mock generator yielding very predictable output /// A mock generator yielding very predictable output
/// ///
@ -78,7 +79,8 @@ rand_core::impl_try_rng_from_rng_core!(StepRng);
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
#[cfg(any(feature = "alloc", feature = "serde1"))] use super::StepRng; #[cfg(any(feature = "alloc", feature = "serde1"))]
use super::StepRng;
#[test] #[test]
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]

View File

@ -102,18 +102,22 @@ pub use reseeding::ReseedingRng;
pub mod mock; // Public so we don't export `StepRng` directly, making it a bit pub mod mock; // Public so we don't export `StepRng` directly, making it a bit
// more clear it is intended for testing. // more clear it is intended for testing.
#[cfg(feature = "small_rng")] mod small; #[cfg(feature = "small_rng")]
mod small;
#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))] #[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))]
mod xoshiro128plusplus; mod xoshiro128plusplus;
#[cfg(all(feature = "small_rng", target_pointer_width = "64"))] #[cfg(all(feature = "small_rng", target_pointer_width = "64"))]
mod xoshiro256plusplus; mod xoshiro256plusplus;
#[cfg(feature = "std_rng")] mod std; #[cfg(feature = "std_rng")]
mod std;
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] #[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
pub(crate) mod thread; pub(crate) mod thread;
#[cfg(feature = "small_rng")] pub use self::small::SmallRng; #[cfg(feature = "small_rng")]
#[cfg(feature = "std_rng")] pub use self::std::StdRng; pub use self::small::SmallRng;
#[cfg(feature = "std_rng")]
pub use self::std::StdRng;
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))] #[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
pub use self::thread::ThreadRng; pub use self::thread::ThreadRng;

View File

@ -33,7 +33,6 @@ use crate::rngs::ReseedingRng;
// `ThreadRng` internally, which is nonsensical anyway. We should also never run // `ThreadRng` internally, which is nonsensical anyway. We should also never run
// `ThreadRng` in destructors of its implementation, which is also nonsensical. // `ThreadRng` in destructors of its implementation, which is also nonsensical.
// Number of generated bytes after which to reseed `ThreadRng`. // Number of generated bytes after which to reseed `ThreadRng`.
// According to benchmarks, reseeding has a noticeable impact with thresholds // According to benchmarks, reseeding has a noticeable impact with thresholds
// of 32 kB and less. We choose 64 kB to avoid significant overhead. // of 32 kB and less. We choose 64 kB to avoid significant overhead.

View File

@ -9,7 +9,8 @@
use rand_core::impls::{fill_bytes_via_next, next_u64_via_u32}; use rand_core::impls::{fill_bytes_via_next, next_u64_via_u32};
use rand_core::le::read_u32_into; use rand_core::le::read_u32_into;
use rand_core::{RngCore, SeedableRng}; use rand_core::{RngCore, SeedableRng};
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; #[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// A xoshiro128++ random number generator. /// A xoshiro128++ random number generator.
/// ///

View File

@ -9,7 +9,8 @@
use rand_core::impls::fill_bytes_via_next; use rand_core::impls::fill_bytes_via_next;
use rand_core::le::read_u64_into; use rand_core::le::read_u64_into;
use rand_core::{RngCore, SeedableRng}; use rand_core::{RngCore, SeedableRng};
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize}; #[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// A xoshiro256++ random number generator. /// A xoshiro256++ random number generator.
/// ///

View File

@ -10,7 +10,7 @@ use crate::RngCore;
pub(crate) struct CoinFlipper<R: RngCore> { pub(crate) struct CoinFlipper<R: RngCore> {
pub rng: R, pub rng: R,
chunk: u32, //TODO(opt): this should depend on RNG word size chunk: u32, // TODO(opt): this should depend on RNG word size
chunk_remaining: u32, chunk_remaining: u32,
} }
@ -92,7 +92,7 @@ impl<R: RngCore> CoinFlipper<R> {
// If n * 2^c > `usize::MAX` we always return `true` anyway // If n * 2^c > `usize::MAX` we always return `true` anyway
n = n.saturating_mul(2_usize.pow(c)); n = n.saturating_mul(2_usize.pow(c));
} else { } else {
//At least one tail // At least one tail
if c == 1 { if c == 1 {
// Calculate 2n - d. // Calculate 2n - d.
// We need to use wrapping as 2n might be greater than `usize::MAX` // We need to use wrapping as 2n might be greater than `usize::MAX`

View File

@ -7,23 +7,30 @@
// except according to those terms. // except according to those terms.
//! Low-level API for sampling indices //! Low-level API for sampling indices
use core::{cmp::Ordering, hash::Hash, ops::AddAssign};
#[cfg(feature = "alloc")] use core::slice; #[cfg(feature = "alloc")]
use core::slice;
#[cfg(feature = "alloc")] use alloc::vec::{self, Vec}; #[cfg(feature = "alloc")]
use alloc::vec::{self, Vec};
// BTreeMap is not as fast in tests, but better than nothing. // BTreeMap is not as fast in tests, but better than nothing.
#[cfg(all(feature = "alloc", not(feature = "std")))] #[cfg(all(feature = "alloc", not(feature = "std")))]
use alloc::collections::BTreeSet; use alloc::collections::BTreeSet;
#[cfg(feature = "std")] use std::collections::HashSet; #[cfg(feature = "std")]
use std::collections::HashSet;
#[cfg(feature = "std")] #[cfg(feature = "std")]
use super::WeightError; use super::WeightError;
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
use crate::{Rng, distributions::{uniform::SampleUniform, Distribution, Uniform}}; use crate::{
distributions::{uniform::SampleUniform, Distribution, Uniform},
Rng,
};
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
use serde::{Serialize, Deserialize}; use serde::{Deserialize, Serialize};
/// A vector of indices. /// A vector of indices.
/// ///
@ -88,8 +95,8 @@ impl IndexVec {
} }
impl IntoIterator for IndexVec { impl IntoIterator for IndexVec {
type Item = usize;
type IntoIter = IndexVecIntoIter; type IntoIter = IndexVecIntoIter;
type Item = usize;
/// Convert into an iterator over the indices as a sequence of `usize` values /// Convert into an iterator over the indices as a sequence of `usize` values
#[inline] #[inline]
@ -196,7 +203,6 @@ impl Iterator for IndexVecIntoIter {
impl ExactSizeIterator for IndexVecIntoIter {} impl ExactSizeIterator for IndexVecIntoIter {}
/// Randomly sample exactly `amount` distinct indices from `0..length`, and /// Randomly sample exactly `amount` distinct indices from `0..length`, and
/// return them in random order (fully shuffled). /// return them in random order (fully shuffled).
/// ///
@ -221,7 +227,9 @@ impl ExactSizeIterator for IndexVecIntoIter {}
/// Panics if `amount > length`. /// Panics if `amount > length`.
#[track_caller] #[track_caller]
pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
where R: Rng + ?Sized { where
R: Rng + ?Sized,
{
if amount > length { if amount > length {
panic!("`amount` of samples must be less than or equal to `length`"); panic!("`amount` of samples must be less than or equal to `length`");
} }
@ -276,7 +284,10 @@ where R: Rng + ?Sized {
#[cfg(feature = "std")] #[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
pub fn sample_weighted<R, F, X>( pub fn sample_weighted<R, F, X>(
rng: &mut R, length: usize, weight: F, amount: usize, rng: &mut R,
length: usize,
weight: F,
amount: usize,
) -> Result<IndexVec, WeightError> ) -> Result<IndexVec, WeightError>
where where
R: Rng + ?Sized, R: Rng + ?Sized,
@ -293,7 +304,6 @@ where
} }
} }
/// Randomly sample exactly `amount` distinct indices from `0..length`, and /// Randomly sample exactly `amount` distinct indices from `0..length`, and
/// return them in an arbitrary order (there is no guarantee of shuffling or /// return them in an arbitrary order (there is no guarantee of shuffling or
/// ordering). The weights are to be provided by the input function `weights`, /// ordering). The weights are to be provided by the input function `weights`,
@ -308,7 +318,10 @@ where
/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive. /// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive.
#[cfg(feature = "std")] #[cfg(feature = "std")]
fn sample_efraimidis_spirakis<R, F, X, N>( fn sample_efraimidis_spirakis<R, F, X, N>(
rng: &mut R, length: N, weight: F, amount: N, rng: &mut R,
length: N,
weight: F,
amount: N,
) -> Result<IndexVec, WeightError> ) -> Result<IndexVec, WeightError>
where where
R: Rng + ?Sized, R: Rng + ?Sized,
@ -325,23 +338,27 @@ where
index: N, index: N,
key: f64, key: f64,
} }
impl<N> PartialOrd for Element<N> { impl<N> PartialOrd for Element<N> {
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> { fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
self.key.partial_cmp(&other.key) Some(self.cmp(other))
} }
} }
impl<N> Ord for Element<N> { impl<N> Ord for Element<N> {
fn cmp(&self, other: &Self) -> core::cmp::Ordering { fn cmp(&self, other: &Self) -> Ordering {
// partial_cmp will always produce a value, // partial_cmp will always produce a value,
// because we check that the weights are not nan // because we check that the weights are not nan
self.partial_cmp(other).unwrap() self.key.partial_cmp(&other.key).unwrap()
} }
} }
impl<N> PartialEq for Element<N> { impl<N> PartialEq for Element<N> {
fn eq(&self, other: &Self) -> bool { fn eq(&self, other: &Self) -> bool {
self.key == other.key self.key == other.key
} }
} }
impl<N> Eq for Element<N> {} impl<N> Eq for Element<N> {}
let mut candidates = Vec::with_capacity(length.as_usize()); let mut candidates = Vec::with_capacity(length.as_usize());
@ -367,8 +384,7 @@ where
// keys. Do this by using `select_nth_unstable` to put the elements with // keys. Do this by using `select_nth_unstable` to put the elements with
// the *smallest* keys at the beginning of the list in `O(n)` time, which // the *smallest* keys at the beginning of the list in `O(n)` time, which
// provides equivalent information about the elements with the *greatest* keys. // provides equivalent information about the elements with the *greatest* keys.
let (_, mid, greater) let (_, mid, greater) = candidates.select_nth_unstable(avail - amount.as_usize());
= candidates.select_nth_unstable(avail - amount.as_usize());
let mut result: Vec<N> = Vec::with_capacity(amount.as_usize()); let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
result.push(mid.index); result.push(mid.index);
@ -385,7 +401,9 @@ where
/// ///
/// This implementation uses `O(amount)` memory and `O(amount^2)` time. /// This implementation uses `O(amount)` memory and `O(amount^2)` time.
fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
where R: Rng + ?Sized { where
R: Rng + ?Sized,
{
// Note that the values returned by `rng.gen_range()` can be // Note that the values returned by `rng.gen_range()` can be
// inferred from the returned vector by working backwards from // inferred from the returned vector by working backwards from
// the last entry. This bijection proves the algorithm fair. // the last entry. This bijection proves the algorithm fair.
@ -414,7 +432,9 @@ where R: Rng + ?Sized {
/// ///
/// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time. /// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time.
fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
where R: Rng + ?Sized { where
R: Rng + ?Sized,
{
debug_assert!(amount <= length); debug_assert!(amount <= length);
let mut indices: Vec<u32> = Vec::with_capacity(length as usize); let mut indices: Vec<u32> = Vec::with_capacity(length as usize);
indices.extend(0..length); indices.extend(0..length);
@ -427,12 +447,12 @@ where R: Rng + ?Sized {
IndexVec::from(indices) IndexVec::from(indices)
} }
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + Hash + AddAssign {
+ core::hash::Hash + core::ops::AddAssign {
fn zero() -> Self; fn zero() -> Self;
fn one() -> Self; fn one() -> Self;
fn as_usize(self) -> usize; fn as_usize(self) -> usize;
} }
impl UInt for u32 { impl UInt for u32 {
#[inline] #[inline]
fn zero() -> Self { fn zero() -> Self {
@ -449,6 +469,7 @@ impl UInt for u32 {
self as usize self as usize
} }
} }
impl UInt for usize { impl UInt for usize {
#[inline] #[inline]
fn zero() -> Self { fn zero() -> Self {
@ -507,19 +528,23 @@ mod test {
#[cfg(feature = "serde1")] #[cfg(feature = "serde1")]
fn test_serialization_index_vec() { fn test_serialization_index_vec() {
let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]); let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]);
let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap(); let de_some_index_vec: IndexVec =
bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap();
match (some_index_vec, de_some_index_vec) { match (some_index_vec, de_some_index_vec) {
(IndexVec::U32(a), IndexVec::U32(b)) => { (IndexVec::U32(a), IndexVec::U32(b)) => {
assert_eq!(a, b); assert_eq!(a, b);
}, }
(IndexVec::USize(a), IndexVec::USize(b)) => { (IndexVec::USize(a), IndexVec::USize(b)) => {
assert_eq!(a, b); assert_eq!(a, b);
}, }
_ => {panic!("failed to seralize/deserialize `IndexVec`")} _ => {
panic!("failed to seralize/deserialize `IndexVec`")
}
} }
} }
#[cfg(feature = "alloc")] use alloc::vec; #[cfg(feature = "alloc")]
use alloc::vec;
#[test] #[test]
fn test_sample_boundaries() { fn test_sample_boundaries() {
@ -593,7 +618,7 @@ mod test {
for &i in &indices { for &i in &indices {
assert!((i as usize) < len); assert!((i as usize) < len);
} }
}, }
IndexVec::USize(_) => panic!("expected `IndexVec::U32`"), IndexVec::USize(_) => panic!("expected `IndexVec::U32`"),
} }
} }
@ -628,11 +653,15 @@ mod test {
do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace
do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace
do_test(1_000_000, 8, &[ do_test(
103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573, 1_000_000,
]); // floyd 8,
do_test(1_000_000, 180, &[ &[103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573],
103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573, ); // floyd
]); // rejection do_test(
1_000_000,
180,
&[103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573],
); // rejection
} }
} }

View File

@ -44,7 +44,8 @@ use alloc::vec::Vec;
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
use crate::distributions::uniform::{SampleBorrow, SampleUniform}; use crate::distributions::uniform::{SampleBorrow, SampleUniform};
#[cfg(feature = "alloc")] use crate::distributions::Weight; #[cfg(feature = "alloc")]
use crate::distributions::Weight;
use crate::Rng; use crate::Rng;
use self::coin_flipper::CoinFlipper; use self::coin_flipper::CoinFlipper;
@ -167,7 +168,9 @@ pub trait IndexedRandom: Index<usize> {
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
fn choose_weighted<R, F, B, X>( fn choose_weighted<R, F, B, X>(
&self, rng: &mut R, weight: F, &self,
rng: &mut R,
weight: F,
) -> Result<&Self::Output, WeightError> ) -> Result<&Self::Output, WeightError>
where where
R: Rng + ?Sized, R: Rng + ?Sized,
@ -212,13 +215,15 @@ pub trait IndexedRandom: Index<usize> {
/// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>()); /// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
/// ``` /// ```
/// [`choose_multiple`]: IndexedRandom::choose_multiple /// [`choose_multiple`]: IndexedRandom::choose_multiple
//
// Note: this is feature-gated on std due to usage of f64::powf. // Note: this is feature-gated on std due to usage of f64::powf.
// If necessary, we may use alloc+libm as an alternative (see PR #1089). // If necessary, we may use alloc+libm as an alternative (see PR #1089).
#[cfg(feature = "std")] #[cfg(feature = "std")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))] #[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
fn choose_multiple_weighted<R, F, X>( fn choose_multiple_weighted<R, F, X>(
&self, rng: &mut R, amount: usize, weight: F, &self,
rng: &mut R,
amount: usize,
weight: F,
) -> Result<SliceChooseIter<Self, Self::Output>, WeightError> ) -> Result<SliceChooseIter<Self, Self::Output>, WeightError>
where where
Self::Output: Sized, Self::Output: Sized,
@ -285,7 +290,9 @@ pub trait IndexedMutRandom: IndexedRandom + IndexMut<usize> {
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))] #[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
fn choose_weighted_mut<R, F, B, X>( fn choose_weighted_mut<R, F, B, X>(
&mut self, rng: &mut R, weight: F, &mut self,
rng: &mut R,
weight: F,
) -> Result<&mut Self::Output, WeightError> ) -> Result<&mut Self::Output, WeightError>
where where
R: Rng + ?Sized, R: Rng + ?Sized,
@ -358,7 +365,9 @@ pub trait SliceRandom: IndexedMutRandom {
/// ///
/// For slices, complexity is `O(m)` where `m = amount`. /// For slices, complexity is `O(m)` where `m = amount`.
fn partial_shuffle<R>( fn partial_shuffle<R>(
&mut self, rng: &mut R, amount: usize, &mut self,
rng: &mut R,
amount: usize,
) -> (&mut [Self::Output], &mut [Self::Output]) ) -> (&mut [Self::Output], &mut [Self::Output])
where where
Self::Output: Sized, Self::Output: Sized,
@ -624,9 +633,7 @@ impl<T> SliceRandom for [T] {
self.partial_shuffle(rng, self.len()); self.partial_shuffle(rng, self.len());
} }
fn partial_shuffle<R>( fn partial_shuffle<R>(&mut self, rng: &mut R, amount: usize) -> (&mut [T], &mut [T])
&mut self, rng: &mut R, amount: usize,
) -> (&mut [T], &mut [T])
where where
R: Rng + ?Sized, R: Rng + ?Sized,
{ {
@ -1294,7 +1301,10 @@ mod test {
fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) { fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
let mut rng = crate::test::rng(412); let mut rng = crate::test::rng(412);
let mut buf = [0u32; 8]; let mut buf = [0u32; 8];
assert_eq!(iter.clone().choose_multiple_fill(&mut rng, &mut buf), v.len()); assert_eq!(
iter.clone().choose_multiple_fill(&mut rng, &mut buf),
v.len()
);
assert_eq!(&buf[0..v.len()], v); assert_eq!(&buf[0..v.len()], v);
#[cfg(feature = "alloc")] #[cfg(feature = "alloc")]