Apply rustfmt and fix Clippy warnings (#1448)
This commit is contained in:
parent
e93776960e
commit
1b762b2867
23
.github/workflows/benches.yml
vendored
Normal file
23
.github/workflows/benches.yml
vendored
Normal file
@ -0,0 +1,23 @@
|
|||||||
|
name: Benches
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths:
|
||||||
|
- ".github/workflows/benches.yml"
|
||||||
|
- "benches/**"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
benches:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: dtolnay/rust-toolchain@master
|
||||||
|
with:
|
||||||
|
toolchain: nightly
|
||||||
|
components: clippy, rustfmt
|
||||||
|
- name: Rustfmt
|
||||||
|
run: cargo fmt --all -- --check
|
||||||
|
- name: Clippy
|
||||||
|
run: cargo clippy --all --all-targets -- -D warnings
|
||||||
|
- name: Build
|
||||||
|
run: RUSTFLAGS=-Dwarnings cargo build --all-targets
|
2
.github/workflows/test.yml
vendored
2
.github/workflows/test.yml
vendored
@ -72,8 +72,6 @@ jobs:
|
|||||||
if: ${{ matrix.variant == 'minimal_versions' }}
|
if: ${{ matrix.variant == 'minimal_versions' }}
|
||||||
run: |
|
run: |
|
||||||
cargo generate-lockfile -Z minimal-versions
|
cargo generate-lockfile -Z minimal-versions
|
||||||
# Overrides for dependencies with incorrect requirements (may need periodic updating)
|
|
||||||
cargo update -p regex --precise 1.5.1
|
|
||||||
- name: Maybe nightly
|
- name: Maybe nightly
|
||||||
if: ${{ matrix.toolchain == 'nightly' }}
|
if: ${{ matrix.toolchain == 'nightly' }}
|
||||||
run: |
|
run: |
|
||||||
|
33
.github/workflows/workspace.yml
vendored
Normal file
33
.github/workflows/workspace.yml
vendored
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
name: Workspace
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
paths-ignore:
|
||||||
|
- README.md
|
||||||
|
- "benches/**"
|
||||||
|
push:
|
||||||
|
branches: master
|
||||||
|
paths-ignore:
|
||||||
|
- README.md
|
||||||
|
- "benches/**"
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
clippy:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: dtolnay/rust-toolchain@master
|
||||||
|
with:
|
||||||
|
toolchain: 1.78.0
|
||||||
|
components: clippy
|
||||||
|
- run: cargo clippy --all --all-targets -- -D warnings
|
||||||
|
|
||||||
|
rustfmt:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: dtolnay/rust-toolchain@master
|
||||||
|
with:
|
||||||
|
toolchain: stable
|
||||||
|
components: rustfmt
|
||||||
|
- run: cargo fmt --all -- --check
|
@ -58,12 +58,12 @@ unbiased = []
|
|||||||
|
|
||||||
[workspace]
|
[workspace]
|
||||||
members = [
|
members = [
|
||||||
"benches",
|
|
||||||
"rand_core",
|
"rand_core",
|
||||||
"rand_distr",
|
"rand_distr",
|
||||||
"rand_chacha",
|
"rand_chacha",
|
||||||
"rand_pcg",
|
"rand_pcg",
|
||||||
]
|
]
|
||||||
|
exclude = ["benches"]
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
rand_core = { path = "rand_core", version = "=0.9.0-alpha.1", default-features = false }
|
rand_core = { path = "rand_core", version = "=0.9.0-alpha.1", default-features = false }
|
||||||
|
@ -50,7 +50,6 @@ gen_bytes!(gen_bytes_chacha8, ChaCha8Rng::from_os_rng());
|
|||||||
gen_bytes!(gen_bytes_chacha12, ChaCha12Rng::from_os_rng());
|
gen_bytes!(gen_bytes_chacha12, ChaCha12Rng::from_os_rng());
|
||||||
gen_bytes!(gen_bytes_chacha20, ChaCha20Rng::from_os_rng());
|
gen_bytes!(gen_bytes_chacha20, ChaCha20Rng::from_os_rng());
|
||||||
gen_bytes!(gen_bytes_std, StdRng::from_os_rng());
|
gen_bytes!(gen_bytes_std, StdRng::from_os_rng());
|
||||||
#[cfg(feature = "small_rng")]
|
|
||||||
gen_bytes!(gen_bytes_small, SmallRng::from_thread_rng());
|
gen_bytes!(gen_bytes_small, SmallRng::from_thread_rng());
|
||||||
gen_bytes!(gen_bytes_os, UnwrapErr(OsRng));
|
gen_bytes!(gen_bytes_os, UnwrapErr(OsRng));
|
||||||
gen_bytes!(gen_bytes_thread, thread_rng());
|
gen_bytes!(gen_bytes_thread, thread_rng());
|
||||||
@ -81,7 +80,6 @@ gen_uint!(gen_u32_chacha8, u32, ChaCha8Rng::from_os_rng());
|
|||||||
gen_uint!(gen_u32_chacha12, u32, ChaCha12Rng::from_os_rng());
|
gen_uint!(gen_u32_chacha12, u32, ChaCha12Rng::from_os_rng());
|
||||||
gen_uint!(gen_u32_chacha20, u32, ChaCha20Rng::from_os_rng());
|
gen_uint!(gen_u32_chacha20, u32, ChaCha20Rng::from_os_rng());
|
||||||
gen_uint!(gen_u32_std, u32, StdRng::from_os_rng());
|
gen_uint!(gen_u32_std, u32, StdRng::from_os_rng());
|
||||||
#[cfg(feature = "small_rng")]
|
|
||||||
gen_uint!(gen_u32_small, u32, SmallRng::from_thread_rng());
|
gen_uint!(gen_u32_small, u32, SmallRng::from_thread_rng());
|
||||||
gen_uint!(gen_u32_os, u32, UnwrapErr(OsRng));
|
gen_uint!(gen_u32_os, u32, UnwrapErr(OsRng));
|
||||||
gen_uint!(gen_u32_thread, u32, thread_rng());
|
gen_uint!(gen_u32_thread, u32, thread_rng());
|
||||||
@ -95,7 +93,6 @@ gen_uint!(gen_u64_chacha8, u64, ChaCha8Rng::from_os_rng());
|
|||||||
gen_uint!(gen_u64_chacha12, u64, ChaCha12Rng::from_os_rng());
|
gen_uint!(gen_u64_chacha12, u64, ChaCha12Rng::from_os_rng());
|
||||||
gen_uint!(gen_u64_chacha20, u64, ChaCha20Rng::from_os_rng());
|
gen_uint!(gen_u64_chacha20, u64, ChaCha20Rng::from_os_rng());
|
||||||
gen_uint!(gen_u64_std, u64, StdRng::from_os_rng());
|
gen_uint!(gen_u64_std, u64, StdRng::from_os_rng());
|
||||||
#[cfg(feature = "small_rng")]
|
|
||||||
gen_uint!(gen_u64_small, u64, SmallRng::from_thread_rng());
|
gen_uint!(gen_u64_small, u64, SmallRng::from_thread_rng());
|
||||||
gen_uint!(gen_u64_os, u64, UnwrapErr(OsRng));
|
gen_uint!(gen_u64_os, u64, UnwrapErr(OsRng));
|
||||||
gen_uint!(gen_u64_thread, u64, thread_rng());
|
gen_uint!(gen_u64_thread, u64, thread_rng());
|
||||||
|
@ -8,8 +8,10 @@
|
|||||||
|
|
||||||
//! The ChaCha random number generator.
|
//! The ChaCha random number generator.
|
||||||
|
|
||||||
#[cfg(not(feature = "std"))] use core;
|
#[cfg(not(feature = "std"))]
|
||||||
#[cfg(feature = "std")] use std as core;
|
use core;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
use std as core;
|
||||||
|
|
||||||
use self::core::fmt;
|
use self::core::fmt;
|
||||||
use crate::guts::ChaCha;
|
use crate::guts::ChaCha;
|
||||||
@ -27,7 +29,8 @@ const BLOCK_WORDS: u8 = 16;
|
|||||||
#[repr(transparent)]
|
#[repr(transparent)]
|
||||||
pub struct Array64<T>([T; 64]);
|
pub struct Array64<T>([T; 64]);
|
||||||
impl<T> Default for Array64<T>
|
impl<T> Default for Array64<T>
|
||||||
where T: Default
|
where
|
||||||
|
T: Default,
|
||||||
{
|
{
|
||||||
#[rustfmt::skip]
|
#[rustfmt::skip]
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
@ -54,7 +57,8 @@ impl<T> AsMut<[T]> for Array64<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl<T> Clone for Array64<T>
|
impl<T> Clone for Array64<T>
|
||||||
where T: Copy + Default
|
where
|
||||||
|
T: Copy + Default,
|
||||||
{
|
{
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
let mut new = Self::default();
|
let mut new = Self::default();
|
||||||
@ -275,20 +279,25 @@ macro_rules! chacha_impl {
|
|||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
impl Serialize for $ChaChaXRng {
|
impl Serialize for $ChaChaXRng {
|
||||||
fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
|
fn serialize<S>(&self, s: S) -> Result<S::Ok, S::Error>
|
||||||
where S: Serializer {
|
where
|
||||||
|
S: Serializer,
|
||||||
|
{
|
||||||
$abst::$ChaChaXRng::from(self).serialize(s)
|
$abst::$ChaChaXRng::from(self).serialize(s)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
impl<'de> Deserialize<'de> for $ChaChaXRng {
|
impl<'de> Deserialize<'de> for $ChaChaXRng {
|
||||||
fn deserialize<D>(d: D) -> Result<Self, D::Error>
|
fn deserialize<D>(d: D) -> Result<Self, D::Error>
|
||||||
where D: Deserializer<'de> {
|
where
|
||||||
|
D: Deserializer<'de>,
|
||||||
|
{
|
||||||
$abst::$ChaChaXRng::deserialize(d).map(|x| Self::from(&x))
|
$abst::$ChaChaXRng::deserialize(d).map(|x| Self::from(&x))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
mod $abst {
|
mod $abst {
|
||||||
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
// The abstract state of a ChaCha stream, independent of implementation choices. The
|
// The abstract state of a ChaCha stream, independent of implementation choices. The
|
||||||
// comparison and serialization of this object is considered a semver-covered part of
|
// comparison and serialization of this object is considered a semver-covered part of
|
||||||
@ -353,7 +362,8 @@ chacha_impl!(
|
|||||||
mod test {
|
mod test {
|
||||||
use rand_core::{RngCore, SeedableRng};
|
use rand_core::{RngCore, SeedableRng};
|
||||||
|
|
||||||
#[cfg(feature = "serde1")] use super::{ChaCha12Rng, ChaCha20Rng, ChaCha8Rng};
|
#[cfg(feature = "serde1")]
|
||||||
|
use super::{ChaCha12Rng, ChaCha20Rng, ChaCha8Rng};
|
||||||
|
|
||||||
type ChaChaRng = super::ChaCha20Rng;
|
type ChaChaRng = super::ChaCha20Rng;
|
||||||
|
|
||||||
|
@ -12,7 +12,9 @@
|
|||||||
use ppv_lite86::{dispatch, dispatch_light128};
|
use ppv_lite86::{dispatch, dispatch_light128};
|
||||||
|
|
||||||
pub use ppv_lite86::Machine;
|
pub use ppv_lite86::Machine;
|
||||||
use ppv_lite86::{vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector};
|
use ppv_lite86::{
|
||||||
|
vec128_storage, ArithOps, BitOps32, LaneWords4, MultiLane, StoreBytes, Vec4, Vec4Ext, Vector,
|
||||||
|
};
|
||||||
|
|
||||||
pub(crate) const BLOCK: usize = 16;
|
pub(crate) const BLOCK: usize = 16;
|
||||||
pub(crate) const BLOCK64: u64 = BLOCK as u64;
|
pub(crate) const BLOCK64: u64 = BLOCK as u64;
|
||||||
@ -140,14 +142,18 @@ fn add_pos<Mach: Machine>(m: Mach, d: Mach::u32x4, i: u64) -> Mach::u32x4 {
|
|||||||
#[cfg(target_endian = "little")]
|
#[cfg(target_endian = "little")]
|
||||||
fn d0123<Mach: Machine>(m: Mach, d: vec128_storage) -> Mach::u32x4x4 {
|
fn d0123<Mach: Machine>(m: Mach, d: vec128_storage) -> Mach::u32x4x4 {
|
||||||
let d0: Mach::u64x2 = m.unpack(d);
|
let d0: Mach::u64x2 = m.unpack(d);
|
||||||
let incr = Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]);
|
let incr =
|
||||||
|
Mach::u64x2x4::from_lanes([m.vec([0, 0]), m.vec([1, 0]), m.vec([2, 0]), m.vec([3, 0])]);
|
||||||
m.unpack((Mach::u64x2x4::from_lanes([d0, d0, d0, d0]) + incr).into())
|
m.unpack((Mach::u64x2x4::from_lanes([d0, d0, d0, d0]) + incr).into())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::many_single_char_names)]
|
#[allow(clippy::many_single_char_names)]
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn refill_wide_impl<Mach: Machine>(
|
fn refill_wide_impl<Mach: Machine>(
|
||||||
m: Mach, state: &mut ChaCha, drounds: u32, out: &mut [u32; BUFSZ],
|
m: Mach,
|
||||||
|
state: &mut ChaCha,
|
||||||
|
drounds: u32,
|
||||||
|
out: &mut [u32; BUFSZ],
|
||||||
) {
|
) {
|
||||||
let k = m.vec([0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574]);
|
let k = m.vec([0x6170_7865, 0x3320_646e, 0x7962_2d32, 0x6b20_6574]);
|
||||||
let b = m.unpack(state.b);
|
let b = m.unpack(state.b);
|
||||||
|
@ -1,4 +1,5 @@
|
|||||||
#[cfg(feature = "alloc")] use alloc::boxed::Box;
|
#[cfg(feature = "alloc")]
|
||||||
|
use alloc::boxed::Box;
|
||||||
|
|
||||||
use crate::{CryptoRng, RngCore, TryCryptoRng, TryRngCore};
|
use crate::{CryptoRng, RngCore, TryCryptoRng, TryRngCore};
|
||||||
|
|
||||||
|
@ -56,7 +56,8 @@
|
|||||||
use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks};
|
use crate::impls::{fill_via_u32_chunks, fill_via_u64_chunks};
|
||||||
use crate::{CryptoRng, RngCore, SeedableRng, TryRngCore};
|
use crate::{CryptoRng, RngCore, SeedableRng, TryRngCore};
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A trait for RNGs which do not generate random numbers individually, but in
|
/// A trait for RNGs which do not generate random numbers individually, but in
|
||||||
/// blocks (typically `[u32; N]`). This technique is commonly used by
|
/// blocks (typically `[u32; N]`). This technique is commonly used by
|
||||||
|
@ -199,7 +199,7 @@ macro_rules! impl_try_rng_from_rng_core {
|
|||||||
macro_rules! impl_try_crypto_rng_from_crypto_rng {
|
macro_rules! impl_try_crypto_rng_from_crypto_rng {
|
||||||
($t:ty) => {
|
($t:ty) => {
|
||||||
$crate::impl_try_rng_from_rng_core!($t);
|
$crate::impl_try_rng_from_rng_core!($t);
|
||||||
|
|
||||||
impl $crate::TryCryptoRng for $t {}
|
impl $crate::TryCryptoRng for $t {}
|
||||||
|
|
||||||
/// Check at compile time that `$t` implements `CryptoRng`
|
/// Check at compile time that `$t` implements `CryptoRng`
|
||||||
|
@ -32,11 +32,14 @@
|
|||||||
#![deny(missing_docs)]
|
#![deny(missing_docs)]
|
||||||
#![deny(missing_debug_implementations)]
|
#![deny(missing_debug_implementations)]
|
||||||
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
|
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
|
||||||
|
#![allow(unexpected_cfgs)]
|
||||||
#![cfg_attr(doc_cfg, feature(doc_cfg))]
|
#![cfg_attr(doc_cfg, feature(doc_cfg))]
|
||||||
#![no_std]
|
#![no_std]
|
||||||
|
|
||||||
#[cfg(feature = "alloc")] extern crate alloc;
|
#[cfg(feature = "alloc")]
|
||||||
#[cfg(feature = "std")] extern crate std;
|
extern crate alloc;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
extern crate std;
|
||||||
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
|
||||||
@ -44,11 +47,13 @@ mod blanket_impls;
|
|||||||
pub mod block;
|
pub mod block;
|
||||||
pub mod impls;
|
pub mod impls;
|
||||||
pub mod le;
|
pub mod le;
|
||||||
#[cfg(feature = "getrandom")] mod os;
|
#[cfg(feature = "getrandom")]
|
||||||
|
mod os;
|
||||||
#[cfg(feature = "getrandom")] pub use getrandom;
|
|
||||||
#[cfg(feature = "getrandom")] pub use os::OsRng;
|
|
||||||
|
|
||||||
|
#[cfg(feature = "getrandom")]
|
||||||
|
pub use getrandom;
|
||||||
|
#[cfg(feature = "getrandom")]
|
||||||
|
pub use os::OsRng;
|
||||||
|
|
||||||
/// The core of a random number generator.
|
/// The core of a random number generator.
|
||||||
///
|
///
|
||||||
@ -213,14 +218,18 @@ pub trait TryRngCore {
|
|||||||
|
|
||||||
/// Wrap RNG with the [`UnwrapErr`] wrapper.
|
/// Wrap RNG with the [`UnwrapErr`] wrapper.
|
||||||
fn unwrap_err(self) -> UnwrapErr<Self>
|
fn unwrap_err(self) -> UnwrapErr<Self>
|
||||||
where Self: Sized {
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
UnwrapErr(self)
|
UnwrapErr(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Convert an [`RngCore`] to a [`RngReadAdapter`].
|
/// Convert an [`RngCore`] to a [`RngReadAdapter`].
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
fn read_adapter(&mut self) -> RngReadAdapter<'_, Self>
|
fn read_adapter(&mut self) -> RngReadAdapter<'_, Self>
|
||||||
where Self: Sized {
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
RngReadAdapter { inner: self }
|
RngReadAdapter { inner: self }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -10,11 +10,11 @@
|
|||||||
//! The binomial distribution.
|
//! The binomial distribution.
|
||||||
|
|
||||||
use crate::{Distribution, Uniform};
|
use crate::{Distribution, Uniform};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
|
||||||
use core::cmp::Ordering;
|
use core::cmp::Ordering;
|
||||||
|
use core::fmt;
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// The binomial distribution `Binomial(n, p)`.
|
/// The binomial distribution `Binomial(n, p)`.
|
||||||
///
|
///
|
||||||
@ -110,21 +110,21 @@ impl Distribution<u64> for Binomial {
|
|||||||
// Threshold for preferring the BINV algorithm. The paper suggests 10,
|
// Threshold for preferring the BINV algorithm. The paper suggests 10,
|
||||||
// Ranlib uses 30, and GSL uses 14.
|
// Ranlib uses 30, and GSL uses 14.
|
||||||
const BINV_THRESHOLD: f64 = 10.;
|
const BINV_THRESHOLD: f64 = 10.;
|
||||||
|
|
||||||
// Same value as in GSL.
|
// Same value as in GSL.
|
||||||
// It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again.
|
// It is possible for BINV to get stuck, so we break if x > BINV_MAX_X and try again.
|
||||||
// It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant.
|
// It would be safer to set BINV_MAX_X to self.n, but it is extremely unlikely to be relevant.
|
||||||
// When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away.
|
// When n*p < 10, so is n*p*q which is the variance, so a result > 110 would be 100 / sqrt(10) = 31 standard deviations away.
|
||||||
const BINV_MAX_X : u64 = 110;
|
const BINV_MAX_X: u64 = 110;
|
||||||
|
|
||||||
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (i32::MAX as u64) {
|
if (self.n as f64) * p < BINV_THRESHOLD && self.n <= (i32::MAX as u64) {
|
||||||
// Use the BINV algorithm.
|
// Use the BINV algorithm.
|
||||||
let s = p / q;
|
let s = p / q;
|
||||||
let a = ((self.n + 1) as f64) * s;
|
let a = ((self.n + 1) as f64) * s;
|
||||||
|
|
||||||
result = 'outer: loop {
|
result = 'outer: loop {
|
||||||
let mut r = q.powi(self.n as i32);
|
let mut r = q.powi(self.n as i32);
|
||||||
let mut u: f64 = rng.gen();
|
let mut u: f64 = rng.random();
|
||||||
let mut x = 0;
|
let mut x = 0;
|
||||||
|
|
||||||
while u > r {
|
while u > r {
|
||||||
@ -136,7 +136,6 @@ impl Distribution<u64> for Binomial {
|
|||||||
r *= a / (x as f64) - s;
|
r *= a / (x as f64) - s;
|
||||||
}
|
}
|
||||||
break x;
|
break x;
|
||||||
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
// Use the BTPE algorithm.
|
// Use the BTPE algorithm.
|
||||||
@ -238,7 +237,7 @@ impl Distribution<u64> for Binomial {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Ordering::Greater => {
|
Ordering::Greater => {
|
||||||
let mut i = y;
|
let mut i = y;
|
||||||
loop {
|
loop {
|
||||||
@ -248,8 +247,8 @@ impl Distribution<u64> for Binomial {
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
Ordering::Equal => {},
|
Ordering::Equal => {}
|
||||||
}
|
}
|
||||||
if v > f {
|
if v > f {
|
||||||
continue;
|
continue;
|
||||||
@ -366,7 +365,7 @@ mod test {
|
|||||||
fn binomial_distributions_can_be_compared() {
|
fn binomial_distributions_can_be_compared() {
|
||||||
assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
|
assert_eq!(Binomial::new(1, 1.0), Binomial::new(1, 1.0));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn binomial_avoid_infinite_loop() {
|
fn binomial_avoid_infinite_loop() {
|
||||||
let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap();
|
let dist = Binomial::new(16000000, 3.1444753148558566e-10).unwrap();
|
||||||
|
@ -9,10 +9,10 @@
|
|||||||
|
|
||||||
//! The Cauchy distribution.
|
//! The Cauchy distribution.
|
||||||
|
|
||||||
use num_traits::{Float, FloatConst};
|
|
||||||
use crate::{Distribution, Standard};
|
use crate::{Distribution, Standard};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::{Float, FloatConst};
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// The Cauchy distribution `Cauchy(median, scale)`.
|
/// The Cauchy distribution `Cauchy(median, scale)`.
|
||||||
///
|
///
|
||||||
@ -34,7 +34,9 @@ use core::fmt;
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Cauchy<F>
|
pub struct Cauchy<F>
|
||||||
where F: Float + FloatConst, Standard: Distribution<F>
|
where
|
||||||
|
F: Float + FloatConst,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
median: F,
|
median: F,
|
||||||
scale: F,
|
scale: F,
|
||||||
@ -60,7 +62,9 @@ impl fmt::Display for Error {
|
|||||||
impl std::error::Error for Error {}
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
impl<F> Cauchy<F>
|
impl<F> Cauchy<F>
|
||||||
where F: Float + FloatConst, Standard: Distribution<F>
|
where
|
||||||
|
F: Float + FloatConst,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Construct a new `Cauchy` with the given shape parameters
|
/// Construct a new `Cauchy` with the given shape parameters
|
||||||
/// `median` the peak location and `scale` the scale factor.
|
/// `median` the peak location and `scale` the scale factor.
|
||||||
@ -73,7 +77,9 @@ where F: Float + FloatConst, Standard: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Cauchy<F>
|
impl<F> Distribution<F> for Cauchy<F>
|
||||||
where F: Float + FloatConst, Standard: Distribution<F>
|
where
|
||||||
|
F: Float + FloatConst,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
// sample from [0, 1)
|
// sample from [0, 1)
|
||||||
@ -138,7 +144,9 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn gen_samples<F: Float + FloatConst + fmt::Debug>(m: F, s: F, buf: &mut [F])
|
fn gen_samples<F: Float + FloatConst + fmt::Debug>(m: F, s: F, buf: &mut [F])
|
||||||
where Standard: Distribution<F> {
|
where
|
||||||
|
Standard: Distribution<F>,
|
||||||
|
{
|
||||||
let distr = Cauchy::new(m, s).unwrap();
|
let distr = Cauchy::new(m, s).unwrap();
|
||||||
let mut rng = crate::test::rng(353);
|
let mut rng = crate::test::rng(353);
|
||||||
for x in buf {
|
for x in buf {
|
||||||
@ -148,12 +156,15 @@ mod test {
|
|||||||
|
|
||||||
let mut buf = [0.0; 4];
|
let mut buf = [0.0; 4];
|
||||||
gen_samples(100f64, 10.0, &mut buf);
|
gen_samples(100f64, 10.0, &mut buf);
|
||||||
assert_eq!(&buf, &[
|
assert_eq!(
|
||||||
77.93369152808678,
|
&buf,
|
||||||
90.1606912098641,
|
&[
|
||||||
125.31516221323625,
|
77.93369152808678,
|
||||||
86.10217834773925
|
90.1606912098641,
|
||||||
]);
|
125.31516221323625,
|
||||||
|
86.10217834773925
|
||||||
|
]
|
||||||
|
);
|
||||||
|
|
||||||
// Unfortunately this test is not fully portable due to reliance on the
|
// Unfortunately this test is not fully portable due to reliance on the
|
||||||
// system's implementation of tanf (see doc on Cauchy struct).
|
// system's implementation of tanf (see doc on Cauchy struct).
|
||||||
|
@ -13,7 +13,8 @@ use crate::{Beta, Distribution, Exp1, Gamma, Open01, StandardNormal};
|
|||||||
use core::fmt;
|
use core::fmt;
|
||||||
use num_traits::{Float, NumCast};
|
use num_traits::{Float, NumCast};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
#[cfg(feature = "serde_with")] use serde_with::serde_as;
|
#[cfg(feature = "serde_with")]
|
||||||
|
use serde_with::serde_as;
|
||||||
|
|
||||||
use alloc::{boxed::Box, vec, vec::Vec};
|
use alloc::{boxed::Box, vec, vec::Vec};
|
||||||
|
|
||||||
|
@ -10,10 +10,10 @@
|
|||||||
//! The exponential distribution.
|
//! The exponential distribution.
|
||||||
|
|
||||||
use crate::utils::ziggurat;
|
use crate::utils::ziggurat;
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{ziggurat_tables, Distribution};
|
use crate::{ziggurat_tables, Distribution};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// Samples floating-point numbers according to the exponential distribution,
|
/// Samples floating-point numbers according to the exponential distribution,
|
||||||
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
|
/// with rate parameter `λ = 1`. This is equivalent to `Exp::new(1.0)` or
|
||||||
@ -61,7 +61,7 @@ impl Distribution<f64> for Exp1 {
|
|||||||
}
|
}
|
||||||
#[inline]
|
#[inline]
|
||||||
fn zero_case<R: Rng + ?Sized>(rng: &mut R, _u: f64) -> f64 {
|
fn zero_case<R: Rng + ?Sized>(rng: &mut R, _u: f64) -> f64 {
|
||||||
ziggurat_tables::ZIG_EXP_R - rng.gen::<f64>().ln()
|
ziggurat_tables::ZIG_EXP_R - rng.random::<f64>().ln()
|
||||||
}
|
}
|
||||||
|
|
||||||
ziggurat(
|
ziggurat(
|
||||||
@ -94,7 +94,9 @@ impl Distribution<f64> for Exp1 {
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Exp<F>
|
pub struct Exp<F>
|
||||||
where F: Float, Exp1: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Exp1: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// `lambda` stored as `1/lambda`, since this is what we scale by.
|
/// `lambda` stored as `1/lambda`, since this is what we scale by.
|
||||||
lambda_inverse: F,
|
lambda_inverse: F,
|
||||||
@ -120,16 +122,18 @@ impl fmt::Display for Error {
|
|||||||
impl std::error::Error for Error {}
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
impl<F: Float> Exp<F>
|
impl<F: Float> Exp<F>
|
||||||
where F: Float, Exp1: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Exp1: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Construct a new `Exp` with the given shape parameter
|
/// Construct a new `Exp` with the given shape parameter
|
||||||
/// `lambda`.
|
/// `lambda`.
|
||||||
///
|
///
|
||||||
/// # Remarks
|
/// # Remarks
|
||||||
///
|
///
|
||||||
/// For custom types `N` implementing the [`Float`] trait,
|
/// For custom types `N` implementing the [`Float`] trait,
|
||||||
/// the case `lambda = 0` is handled as follows: each sample corresponds
|
/// the case `lambda = 0` is handled as follows: each sample corresponds
|
||||||
/// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types
|
/// to a sample from an `Exp1` multiplied by `1 / 0`. Primitive types
|
||||||
/// yield infinity, since `1 / 0 = infinity`.
|
/// yield infinity, since `1 / 0 = infinity`.
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn new(lambda: F) -> Result<Exp<F>, Error> {
|
pub fn new(lambda: F) -> Result<Exp<F>, Error> {
|
||||||
@ -143,7 +147,9 @@ where F: Float, Exp1: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Exp<F>
|
impl<F> Distribution<F> for Exp<F>
|
||||||
where F: Float, Exp1: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Exp1: Distribution<F>,
|
||||||
{
|
{
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
rng.sample(Exp1) * self.lambda_inverse
|
rng.sample(Exp1) * self.lambda_inverse
|
||||||
|
@ -17,12 +17,12 @@ use self::ChiSquaredRepr::*;
|
|||||||
use self::GammaRepr::*;
|
use self::GammaRepr::*;
|
||||||
|
|
||||||
use crate::normal::StandardNormal;
|
use crate::normal::StandardNormal;
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{Distribution, Exp, Exp1, Open01};
|
use crate::{Distribution, Exp, Exp1, Open01};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// The Gamma distribution `Gamma(shape, scale)` distribution.
|
/// The Gamma distribution `Gamma(shape, scale)` distribution.
|
||||||
///
|
///
|
||||||
@ -566,7 +566,9 @@ where
|
|||||||
F: Float,
|
F: Float,
|
||||||
Open01: Distribution<F>,
|
Open01: Distribution<F>,
|
||||||
{
|
{
|
||||||
a: F, b: F, switched_params: bool,
|
a: F,
|
||||||
|
b: F,
|
||||||
|
switched_params: bool,
|
||||||
algorithm: BetaAlgorithm<F>,
|
algorithm: BetaAlgorithm<F>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -618,15 +620,19 @@ where
|
|||||||
if a > F::one() {
|
if a > F::one() {
|
||||||
// Algorithm BB
|
// Algorithm BB
|
||||||
let alpha = a + b;
|
let alpha = a + b;
|
||||||
let beta = ((alpha - F::from(2.).unwrap())
|
|
||||||
/ (F::from(2.).unwrap()*a*b - alpha)).sqrt();
|
let two = F::from(2.).unwrap();
|
||||||
|
let beta_numer = alpha - two;
|
||||||
|
let beta_denom = two * a * b - alpha;
|
||||||
|
let beta = (beta_numer / beta_denom).sqrt();
|
||||||
|
|
||||||
let gamma = a + F::one() / beta;
|
let gamma = a + F::one() / beta;
|
||||||
|
|
||||||
Ok(Beta {
|
Ok(Beta {
|
||||||
a, b, switched_params,
|
a,
|
||||||
algorithm: BetaAlgorithm::BB(BB {
|
b,
|
||||||
alpha, beta, gamma,
|
switched_params,
|
||||||
})
|
algorithm: BetaAlgorithm::BB(BB { alpha, beta, gamma }),
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
// Algorithm BC
|
// Algorithm BC
|
||||||
@ -637,16 +643,21 @@ where
|
|||||||
let beta = F::one() / b;
|
let beta = F::one() / b;
|
||||||
let delta = F::one() + a - b;
|
let delta = F::one() + a - b;
|
||||||
let kappa1 = delta
|
let kappa1 = delta
|
||||||
* (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap()*b)
|
* (F::from(1. / 18. / 4.).unwrap() + F::from(3. / 18. / 4.).unwrap() * b)
|
||||||
/ (a*beta - F::from(14. / 18.).unwrap());
|
/ (a * beta - F::from(14. / 18.).unwrap());
|
||||||
let kappa2 = F::from(0.25).unwrap()
|
let kappa2 = F::from(0.25).unwrap()
|
||||||
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap()/delta)*b;
|
+ (F::from(0.5).unwrap() + F::from(0.25).unwrap() / delta) * b;
|
||||||
|
|
||||||
Ok(Beta {
|
Ok(Beta {
|
||||||
a, b, switched_params,
|
a,
|
||||||
|
b,
|
||||||
|
switched_params,
|
||||||
algorithm: BetaAlgorithm::BC(BC {
|
algorithm: BetaAlgorithm::BC(BC {
|
||||||
alpha, beta, kappa1, kappa2,
|
alpha,
|
||||||
})
|
beta,
|
||||||
|
kappa1,
|
||||||
|
kappa2,
|
||||||
|
}),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -667,12 +678,11 @@ where
|
|||||||
let u2 = rng.sample(Open01);
|
let u2 = rng.sample(Open01);
|
||||||
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
||||||
w = self.a * v.exp();
|
w = self.a * v.exp();
|
||||||
let z = u1*u1 * u2;
|
let z = u1 * u1 * u2;
|
||||||
let r = algo.gamma * v - F::from(4.).unwrap().ln();
|
let r = algo.gamma * v - F::from(4.).unwrap().ln();
|
||||||
let s = self.a + r - w;
|
let s = self.a + r - w;
|
||||||
// 2.
|
// 2.
|
||||||
if s + F::one() + F::from(5.).unwrap().ln()
|
if s + F::one() + F::from(5.).unwrap().ln() >= F::from(5.).unwrap() * z {
|
||||||
>= F::from(5.).unwrap() * z {
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
// 3.
|
// 3.
|
||||||
@ -685,7 +695,7 @@ where
|
|||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
BetaAlgorithm::BC(algo) => {
|
BetaAlgorithm::BC(algo) => {
|
||||||
loop {
|
loop {
|
||||||
let z;
|
let z;
|
||||||
@ -716,11 +726,13 @@ where
|
|||||||
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
let v = algo.beta * (u1 / (F::one() - u1)).ln();
|
||||||
w = self.a * v.exp();
|
w = self.a * v.exp();
|
||||||
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
|
if !(algo.alpha * ((algo.alpha / (self.b + w)).ln() + v)
|
||||||
- F::from(4.).unwrap().ln() < z.ln()) {
|
- F::from(4.).unwrap().ln()
|
||||||
|
< z.ln())
|
||||||
|
{
|
||||||
break;
|
break;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
};
|
};
|
||||||
// 5. for BB, 6. for BC
|
// 5. for BB, 6. for BC
|
||||||
if !self.switched_params {
|
if !self.switched_params {
|
||||||
|
@ -1,20 +1,20 @@
|
|||||||
//! The geometric distribution.
|
//! The geometric distribution.
|
||||||
|
|
||||||
use crate::Distribution;
|
use crate::Distribution;
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`.
|
/// The geometric distribution `Geometric(p)` bounded to `[0, u64::MAX]`.
|
||||||
///
|
///
|
||||||
/// This is the probability distribution of the number of failures before the
|
/// This is the probability distribution of the number of failures before the
|
||||||
/// first success in a series of Bernoulli trials. It has the density function
|
/// first success in a series of Bernoulli trials. It has the density function
|
||||||
/// `f(k) = (1 - p)^k p` for `k >= 0`, where `p` is the probability of success
|
/// `f(k) = (1 - p)^k p` for `k >= 0`, where `p` is the probability of success
|
||||||
/// on each trial.
|
/// on each trial.
|
||||||
///
|
///
|
||||||
/// This is the discrete analogue of the [exponential distribution](crate::Exp).
|
/// This is the discrete analogue of the [exponential distribution](crate::Exp).
|
||||||
///
|
///
|
||||||
/// Note that [`StandardGeometric`](crate::StandardGeometric) is an optimised
|
/// Note that [`StandardGeometric`](crate::StandardGeometric) is an optimised
|
||||||
/// implementation for `p = 0.5`.
|
/// implementation for `p = 0.5`.
|
||||||
///
|
///
|
||||||
@ -29,11 +29,10 @@ use num_traits::Float;
|
|||||||
/// ```
|
/// ```
|
||||||
#[derive(Copy, Clone, Debug, PartialEq)]
|
#[derive(Copy, Clone, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Geometric
|
pub struct Geometric {
|
||||||
{
|
|
||||||
p: f64,
|
p: f64,
|
||||||
pi: f64,
|
pi: f64,
|
||||||
k: u64
|
k: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Error type returned from `Geometric::new`.
|
/// Error type returned from `Geometric::new`.
|
||||||
@ -46,7 +45,9 @@ pub enum Error {
|
|||||||
impl fmt::Display for Error {
|
impl fmt::Display for Error {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.write_str(match self {
|
f.write_str(match self {
|
||||||
Error::InvalidProbability => "p is NaN or outside the interval [0, 1] in geometric distribution",
|
Error::InvalidProbability => {
|
||||||
|
"p is NaN or outside the interval [0, 1] in geometric distribution"
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -80,21 +81,24 @@ impl Geometric {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Distribution<u64> for Geometric
|
impl Distribution<u64> for Geometric {
|
||||||
{
|
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
|
||||||
if self.p >= 2.0 / 3.0 {
|
if self.p >= 2.0 / 3.0 {
|
||||||
// use the trivial algorithm:
|
// use the trivial algorithm:
|
||||||
let mut failures = 0;
|
let mut failures = 0;
|
||||||
loop {
|
loop {
|
||||||
let u = rng.gen::<f64>();
|
let u = rng.random::<f64>();
|
||||||
if u <= self.p { break; }
|
if u <= self.p {
|
||||||
|
break;
|
||||||
|
}
|
||||||
failures += 1;
|
failures += 1;
|
||||||
}
|
}
|
||||||
return failures;
|
return failures;
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.p == 0.0 { return u64::MAX; }
|
if self.p == 0.0 {
|
||||||
|
return u64::MAX;
|
||||||
|
}
|
||||||
|
|
||||||
let Geometric { p, pi, k } = *self;
|
let Geometric { p, pi, k } = *self;
|
||||||
|
|
||||||
@ -108,7 +112,7 @@ impl Distribution<u64> for Geometric
|
|||||||
// Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k:
|
// Use the trivial algorithm to sample D from Geo(pi) = Geo(p) / 2^k:
|
||||||
let d = {
|
let d = {
|
||||||
let mut failures = 0;
|
let mut failures = 0;
|
||||||
while rng.gen::<f64>() < pi {
|
while rng.random::<f64>() < pi {
|
||||||
failures += 1;
|
failures += 1;
|
||||||
}
|
}
|
||||||
failures
|
failures
|
||||||
@ -116,18 +120,18 @@ impl Distribution<u64> for Geometric
|
|||||||
|
|
||||||
// Use rejection sampling for the remainder M from Geo(p) % 2^k:
|
// Use rejection sampling for the remainder M from Geo(p) % 2^k:
|
||||||
// choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M
|
// choose M uniformly from [0, 2^k), but reject with probability (1 - p)^M
|
||||||
// NOTE: The paper suggests using bitwise sampling here, which is
|
// NOTE: The paper suggests using bitwise sampling here, which is
|
||||||
// currently unsupported, but should improve performance by requiring
|
// currently unsupported, but should improve performance by requiring
|
||||||
// fewer iterations on average. ~ October 28, 2020
|
// fewer iterations on average. ~ October 28, 2020
|
||||||
let m = loop {
|
let m = loop {
|
||||||
let m = rng.gen::<u64>() & ((1 << k) - 1);
|
let m = rng.random::<u64>() & ((1 << k) - 1);
|
||||||
let p_reject = if m <= i32::MAX as u64 {
|
let p_reject = if m <= i32::MAX as u64 {
|
||||||
(1.0 - p).powi(m as i32)
|
(1.0 - p).powi(m as i32)
|
||||||
} else {
|
} else {
|
||||||
(1.0 - p).powf(m as f64)
|
(1.0 - p).powf(m as f64)
|
||||||
};
|
};
|
||||||
|
|
||||||
let u = rng.gen::<f64>();
|
let u = rng.random::<f64>();
|
||||||
if u < p_reject {
|
if u < p_reject {
|
||||||
break m;
|
break m;
|
||||||
}
|
}
|
||||||
@ -140,17 +144,17 @@ impl Distribution<u64> for Geometric
|
|||||||
/// Samples integers according to the geometric distribution with success
|
/// Samples integers according to the geometric distribution with success
|
||||||
/// probability `p = 0.5`. This is equivalent to `Geometeric::new(0.5)`,
|
/// probability `p = 0.5`. This is equivalent to `Geometeric::new(0.5)`,
|
||||||
/// but faster.
|
/// but faster.
|
||||||
///
|
///
|
||||||
/// See [`Geometric`](crate::Geometric) for the general geometric distribution.
|
/// See [`Geometric`](crate::Geometric) for the general geometric distribution.
|
||||||
///
|
///
|
||||||
/// Implemented via iterated
|
/// Implemented via iterated
|
||||||
/// [`Rng::gen::<u64>().leading_zeros()`](Rng::gen::<u64>().leading_zeros()).
|
/// [`Rng::gen::<u64>().leading_zeros()`](Rng::gen::<u64>().leading_zeros()).
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
/// ```
|
/// ```
|
||||||
/// use rand::prelude::*;
|
/// use rand::prelude::*;
|
||||||
/// use rand_distr::StandardGeometric;
|
/// use rand_distr::StandardGeometric;
|
||||||
///
|
///
|
||||||
/// let v = StandardGeometric.sample(&mut thread_rng());
|
/// let v = StandardGeometric.sample(&mut thread_rng());
|
||||||
/// println!("{} is from a Geometric(0.5) distribution", v);
|
/// println!("{} is from a Geometric(0.5) distribution", v);
|
||||||
/// ```
|
/// ```
|
||||||
@ -162,9 +166,11 @@ impl Distribution<u64> for StandardGeometric {
|
|||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
|
||||||
let mut result = 0;
|
let mut result = 0;
|
||||||
loop {
|
loop {
|
||||||
let x = rng.gen::<u64>().leading_zeros() as u64;
|
let x = rng.random::<u64>().leading_zeros() as u64;
|
||||||
result += x;
|
result += x;
|
||||||
if x < 64 { break; }
|
if x < 64 {
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
@ -1,17 +1,20 @@
|
|||||||
//! The hypergeometric distribution.
|
//! The hypergeometric distribution.
|
||||||
|
|
||||||
use crate::Distribution;
|
use crate::Distribution;
|
||||||
use rand::Rng;
|
|
||||||
use rand::distributions::uniform::Uniform;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
|
use rand::distributions::uniform::Uniform;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
enum SamplingMethod {
|
enum SamplingMethod {
|
||||||
InverseTransform{ initial_p: f64, initial_x: i64 },
|
InverseTransform {
|
||||||
RejectionAcceptance{
|
initial_p: f64,
|
||||||
|
initial_x: i64,
|
||||||
|
},
|
||||||
|
RejectionAcceptance {
|
||||||
m: f64,
|
m: f64,
|
||||||
a: f64,
|
a: f64,
|
||||||
lambda_l: f64,
|
lambda_l: f64,
|
||||||
@ -20,24 +23,24 @@ enum SamplingMethod {
|
|||||||
x_r: f64,
|
x_r: f64,
|
||||||
p1: f64,
|
p1: f64,
|
||||||
p2: f64,
|
p2: f64,
|
||||||
p3: f64
|
p3: f64,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
/// The hypergeometric distribution `Hypergeometric(N, K, n)`.
|
/// The hypergeometric distribution `Hypergeometric(N, K, n)`.
|
||||||
///
|
///
|
||||||
/// This is the distribution of successes in samples of size `n` drawn without
|
/// This is the distribution of successes in samples of size `n` drawn without
|
||||||
/// replacement from a population of size `N` containing `K` success states.
|
/// replacement from a population of size `N` containing `K` success states.
|
||||||
/// It has the density function:
|
/// It has the density function:
|
||||||
/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`,
|
/// `f(k) = binomial(K, k) * binomial(N-K, n-k) / binomial(N, n)`,
|
||||||
/// where `binomial(a, b) = a! / (b! * (a - b)!)`.
|
/// where `binomial(a, b) = a! / (b! * (a - b)!)`.
|
||||||
///
|
///
|
||||||
/// The [binomial distribution](crate::Binomial) is the analogous distribution
|
/// The [binomial distribution](crate::Binomial) is the analogous distribution
|
||||||
/// for sampling with replacement. It is a good approximation when the population
|
/// for sampling with replacement. It is a good approximation when the population
|
||||||
/// size is much larger than the sample size.
|
/// size is much larger than the sample size.
|
||||||
///
|
///
|
||||||
/// # Example
|
/// # Example
|
||||||
///
|
///
|
||||||
/// ```
|
/// ```
|
||||||
/// use rand_distr::{Distribution, Hypergeometric};
|
/// use rand_distr::{Distribution, Hypergeometric};
|
||||||
///
|
///
|
||||||
@ -70,9 +73,15 @@ pub enum Error {
|
|||||||
impl fmt::Display for Error {
|
impl fmt::Display for Error {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.write_str(match self {
|
f.write_str(match self {
|
||||||
Error::PopulationTooLarge => "total_population_size is too large causing underflow in geometric distribution",
|
Error::PopulationTooLarge => {
|
||||||
Error::ProbabilityTooLarge => "population_with_feature > total_population_size in geometric distribution",
|
"total_population_size is too large causing underflow in geometric distribution"
|
||||||
Error::SampleSizeTooLarge => "sample_size > total_population_size in geometric distribution",
|
}
|
||||||
|
Error::ProbabilityTooLarge => {
|
||||||
|
"population_with_feature > total_population_size in geometric distribution"
|
||||||
|
}
|
||||||
|
Error::SampleSizeTooLarge => {
|
||||||
|
"sample_size > total_population_size in geometric distribution"
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -97,20 +106,20 @@ fn fraction_of_products_of_factorials(numerator: (u64, u64), denominator: (u64,
|
|||||||
if i <= min_top {
|
if i <= min_top {
|
||||||
result *= i as f64;
|
result *= i as f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
if i <= min_bottom {
|
if i <= min_bottom {
|
||||||
result /= i as f64;
|
result /= i as f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
if i <= max_top {
|
if i <= max_top {
|
||||||
result *= i as f64;
|
result *= i as f64;
|
||||||
}
|
}
|
||||||
|
|
||||||
if i <= max_bottom {
|
if i <= max_bottom {
|
||||||
result /= i as f64;
|
result /= i as f64;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -126,7 +135,11 @@ impl Hypergeometric {
|
|||||||
/// `K = population_with_feature`,
|
/// `K = population_with_feature`,
|
||||||
/// `n = sample_size`.
|
/// `n = sample_size`.
|
||||||
#[allow(clippy::many_single_char_names)] // Same names as in the reference.
|
#[allow(clippy::many_single_char_names)] // Same names as in the reference.
|
||||||
pub fn new(total_population_size: u64, population_with_feature: u64, sample_size: u64) -> Result<Self, Error> {
|
pub fn new(
|
||||||
|
total_population_size: u64,
|
||||||
|
population_with_feature: u64,
|
||||||
|
sample_size: u64,
|
||||||
|
) -> Result<Self, Error> {
|
||||||
if population_with_feature > total_population_size {
|
if population_with_feature > total_population_size {
|
||||||
return Err(Error::ProbabilityTooLarge);
|
return Err(Error::ProbabilityTooLarge);
|
||||||
}
|
}
|
||||||
@ -151,7 +164,7 @@ impl Hypergeometric {
|
|||||||
};
|
};
|
||||||
// when sampling more than half the total population, take the smaller
|
// when sampling more than half the total population, take the smaller
|
||||||
// group as sampled instead (we can then return n1-x instead).
|
// group as sampled instead (we can then return n1-x instead).
|
||||||
//
|
//
|
||||||
// Note: the boundary condition given in the paper is `sample_size < n / 2`;
|
// Note: the boundary condition given in the paper is `sample_size < n / 2`;
|
||||||
// we're deviating here, because when n is even, it doesn't matter whether
|
// we're deviating here, because when n is even, it doesn't matter whether
|
||||||
// we switch here or not, but when n is odd `n/2 < n - n/2`, so switching
|
// we switch here or not, but when n is odd `n/2 < n - n/2`, so switching
|
||||||
@ -167,7 +180,7 @@ impl Hypergeometric {
|
|||||||
// Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`,
|
// Algorithm H2PE has bounded runtime only if `M - max(0, k-n2) >= 10`,
|
||||||
// where `M` is the mode of the distribution.
|
// where `M` is the mode of the distribution.
|
||||||
// Use algorithm HIN for the remaining parameter space.
|
// Use algorithm HIN for the remaining parameter space.
|
||||||
//
|
//
|
||||||
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer
|
// Voratas Kachitvichyanukul and Bruce W. Schmeiser. 1985. Computer
|
||||||
// generation of hypergeometric random variates.
|
// generation of hypergeometric random variates.
|
||||||
// J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145
|
// J. Statist. Comput. Simul. Vol.22 (August 1985), 127-145
|
||||||
@ -176,21 +189,30 @@ impl Hypergeometric {
|
|||||||
let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor();
|
let m = ((k + 1) as f64 * (n1 + 1) as f64 / (n + 2) as f64).floor();
|
||||||
let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD {
|
let sampling_method = if m - f64::max(0.0, k as f64 - n2 as f64) < HIN_THRESHOLD {
|
||||||
let (initial_p, initial_x) = if k < n2 {
|
let (initial_p, initial_x) = if k < n2 {
|
||||||
(fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)), 0)
|
(
|
||||||
|
fraction_of_products_of_factorials((n2, n - k), (n, n2 - k)),
|
||||||
|
0,
|
||||||
|
)
|
||||||
} else {
|
} else {
|
||||||
(fraction_of_products_of_factorials((n1, k), (n, k - n2)), (k - n2) as i64)
|
(
|
||||||
|
fraction_of_products_of_factorials((n1, k), (n, k - n2)),
|
||||||
|
(k - n2) as i64,
|
||||||
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
if initial_p <= 0.0 || !initial_p.is_finite() {
|
if initial_p <= 0.0 || !initial_p.is_finite() {
|
||||||
return Err(Error::PopulationTooLarge);
|
return Err(Error::PopulationTooLarge);
|
||||||
}
|
}
|
||||||
|
|
||||||
SamplingMethod::InverseTransform { initial_p, initial_x }
|
SamplingMethod::InverseTransform {
|
||||||
|
initial_p,
|
||||||
|
initial_x,
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
let a = ln_of_factorial(m) +
|
let a = ln_of_factorial(m)
|
||||||
ln_of_factorial(n1 as f64 - m) +
|
+ ln_of_factorial(n1 as f64 - m)
|
||||||
ln_of_factorial(k as f64 - m) +
|
+ ln_of_factorial(k as f64 - m)
|
||||||
ln_of_factorial((n2 - k) as f64 + m);
|
+ ln_of_factorial((n2 - k) as f64 + m);
|
||||||
|
|
||||||
let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64;
|
let numerator = (n - k) as f64 * k as f64 * n1 as f64 * n2 as f64;
|
||||||
let denominator = (n - 1) as f64 * n as f64 * n as f64;
|
let denominator = (n - 1) as f64 * n as f64 * n as f64;
|
||||||
@ -199,17 +221,19 @@ impl Hypergeometric {
|
|||||||
let x_l = m - d + 0.5;
|
let x_l = m - d + 0.5;
|
||||||
let x_r = m + d + 0.5;
|
let x_r = m + d + 0.5;
|
||||||
|
|
||||||
let k_l = f64::exp(a -
|
let k_l = f64::exp(
|
||||||
ln_of_factorial(x_l) -
|
a - ln_of_factorial(x_l)
|
||||||
ln_of_factorial(n1 as f64 - x_l) -
|
- ln_of_factorial(n1 as f64 - x_l)
|
||||||
ln_of_factorial(k as f64 - x_l) -
|
- ln_of_factorial(k as f64 - x_l)
|
||||||
ln_of_factorial((n2 - k) as f64 + x_l));
|
- ln_of_factorial((n2 - k) as f64 + x_l),
|
||||||
let k_r = f64::exp(a -
|
);
|
||||||
ln_of_factorial(x_r - 1.0) -
|
let k_r = f64::exp(
|
||||||
ln_of_factorial(n1 as f64 - x_r + 1.0) -
|
a - ln_of_factorial(x_r - 1.0)
|
||||||
ln_of_factorial(k as f64 - x_r + 1.0) -
|
- ln_of_factorial(n1 as f64 - x_r + 1.0)
|
||||||
ln_of_factorial((n2 - k) as f64 + x_r - 1.0));
|
- ln_of_factorial(k as f64 - x_r + 1.0)
|
||||||
|
- ln_of_factorial((n2 - k) as f64 + x_r - 1.0),
|
||||||
|
);
|
||||||
|
|
||||||
let numerator = x_l * ((n2 - k) as f64 + x_l);
|
let numerator = x_l * ((n2 - k) as f64 + x_l);
|
||||||
let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0);
|
let denominator = (n1 as f64 - x_l + 1.0) * (k as f64 - x_l + 1.0);
|
||||||
let lambda_l = -((numerator / denominator).ln());
|
let lambda_l = -((numerator / denominator).ln());
|
||||||
@ -225,11 +249,26 @@ impl Hypergeometric {
|
|||||||
let p3 = p2 + k_r / lambda_r;
|
let p3 = p2 + k_r / lambda_r;
|
||||||
|
|
||||||
SamplingMethod::RejectionAcceptance {
|
SamplingMethod::RejectionAcceptance {
|
||||||
m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3
|
m,
|
||||||
|
a,
|
||||||
|
lambda_l,
|
||||||
|
lambda_r,
|
||||||
|
x_l,
|
||||||
|
x_r,
|
||||||
|
p1,
|
||||||
|
p2,
|
||||||
|
p3,
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(Hypergeometric { n1, n2, k, offset_x, sign_x, sampling_method })
|
Ok(Hypergeometric {
|
||||||
|
n1,
|
||||||
|
n2,
|
||||||
|
k,
|
||||||
|
offset_x,
|
||||||
|
sign_x,
|
||||||
|
sampling_method,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -238,25 +277,47 @@ impl Distribution<u64> for Hypergeometric {
|
|||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> u64 {
|
||||||
use SamplingMethod::*;
|
use SamplingMethod::*;
|
||||||
|
|
||||||
let Hypergeometric { n1, n2, k, sign_x, offset_x, sampling_method } = *self;
|
let Hypergeometric {
|
||||||
|
n1,
|
||||||
|
n2,
|
||||||
|
k,
|
||||||
|
sign_x,
|
||||||
|
offset_x,
|
||||||
|
sampling_method,
|
||||||
|
} = *self;
|
||||||
let x = match sampling_method {
|
let x = match sampling_method {
|
||||||
InverseTransform { initial_p: mut p, initial_x: mut x } => {
|
InverseTransform {
|
||||||
let mut u = rng.gen::<f64>();
|
initial_p: mut p,
|
||||||
while u > p && x < k as i64 { // the paper erroneously uses `until n < p`, which doesn't make any sense
|
initial_x: mut x,
|
||||||
|
} => {
|
||||||
|
let mut u = rng.random::<f64>();
|
||||||
|
|
||||||
|
// the paper erroneously uses `until n < p`, which doesn't make any sense
|
||||||
|
while u > p && x < k as i64 {
|
||||||
u -= p;
|
u -= p;
|
||||||
p *= ((n1 as i64 - x) * (k as i64 - x)) as f64;
|
p *= ((n1 as i64 - x) * (k as i64 - x)) as f64;
|
||||||
p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64;
|
p /= ((x + 1) * (n2 as i64 - k as i64 + 1 + x)) as f64;
|
||||||
x += 1;
|
x += 1;
|
||||||
}
|
}
|
||||||
x
|
x
|
||||||
},
|
}
|
||||||
RejectionAcceptance { m, a, lambda_l, lambda_r, x_l, x_r, p1, p2, p3 } => {
|
RejectionAcceptance {
|
||||||
|
m,
|
||||||
|
a,
|
||||||
|
lambda_l,
|
||||||
|
lambda_r,
|
||||||
|
x_l,
|
||||||
|
x_r,
|
||||||
|
p1,
|
||||||
|
p2,
|
||||||
|
p3,
|
||||||
|
} => {
|
||||||
let distr_region_select = Uniform::new(0.0, p3).unwrap();
|
let distr_region_select = Uniform::new(0.0, p3).unwrap();
|
||||||
loop {
|
loop {
|
||||||
let (y, v) = loop {
|
let (y, v) = loop {
|
||||||
let u = distr_region_select.sample(rng);
|
let u = distr_region_select.sample(rng);
|
||||||
let v = rng.gen::<f64>(); // for the accept/reject decision
|
let v = rng.random::<f64>(); // for the accept/reject decision
|
||||||
|
|
||||||
if u <= p1 {
|
if u <= p1 {
|
||||||
// Region 1, central bell
|
// Region 1, central bell
|
||||||
let y = (x_l + u).floor();
|
let y = (x_l + u).floor();
|
||||||
@ -277,7 +338,7 @@ impl Distribution<u64> for Hypergeometric {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// Step 4: Acceptance/Rejection Comparison
|
// Step 4: Acceptance/Rejection Comparison
|
||||||
if m < 100.0 || y <= 50.0 {
|
if m < 100.0 || y <= 50.0 {
|
||||||
// Step 4.1: evaluate f(y) via recursive relationship
|
// Step 4.1: evaluate f(y) via recursive relationship
|
||||||
@ -293,8 +354,10 @@ impl Distribution<u64> for Hypergeometric {
|
|||||||
f /= (n1 - i) as f64 * (k - i) as f64;
|
f /= (n1 - i) as f64 * (k - i) as f64;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if v <= f { break y as i64; }
|
if v <= f {
|
||||||
|
break y as i64;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
// Step 4.2: Squeezing
|
// Step 4.2: Squeezing
|
||||||
let y1 = y + 1.0;
|
let y1 = y + 1.0;
|
||||||
@ -307,24 +370,24 @@ impl Distribution<u64> for Hypergeometric {
|
|||||||
let t = ym / yk;
|
let t = ym / yk;
|
||||||
let e = -ym / nk;
|
let e = -ym / nk;
|
||||||
let g = yn * yk / (y1 * nk) - 1.0;
|
let g = yn * yk / (y1 * nk) - 1.0;
|
||||||
let dg = if g < 0.0 {
|
let dg = if g < 0.0 { 1.0 + g } else { 1.0 };
|
||||||
1.0 + g
|
|
||||||
} else {
|
|
||||||
1.0
|
|
||||||
};
|
|
||||||
let gu = g * (1.0 + g * (-0.5 + g / 3.0));
|
let gu = g * (1.0 + g * (-0.5 + g / 3.0));
|
||||||
let gl = gu - g.powi(4) / (4.0 * dg);
|
let gl = gu - g.powi(4) / (4.0 * dg);
|
||||||
let xm = m + 0.5;
|
let xm = m + 0.5;
|
||||||
let xn = n1 as f64 - m + 0.5;
|
let xn = n1 as f64 - m + 0.5;
|
||||||
let xk = k as f64 - m + 0.5;
|
let xk = k as f64 - m + 0.5;
|
||||||
let nm = n2 as f64 - k as f64 + xm;
|
let nm = n2 as f64 - k as f64 + xm;
|
||||||
let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0)) +
|
let ub = xm * r * (1.0 + r * (-0.5 + r / 3.0))
|
||||||
xn * s * (1.0 + s * (-0.5 + s / 3.0)) +
|
+ xn * s * (1.0 + s * (-0.5 + s / 3.0))
|
||||||
xk * t * (1.0 + t * (-0.5 + t / 3.0)) +
|
+ xk * t * (1.0 + t * (-0.5 + t / 3.0))
|
||||||
nm * e * (1.0 + e * (-0.5 + e / 3.0)) +
|
+ nm * e * (1.0 + e * (-0.5 + e / 3.0))
|
||||||
y * gu - m * gl + 0.0034;
|
+ y * gu
|
||||||
|
- m * gl
|
||||||
|
+ 0.0034;
|
||||||
let av = v.ln();
|
let av = v.ln();
|
||||||
if av > ub { continue; }
|
if av > ub {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
let dr = if r < 0.0 {
|
let dr = if r < 0.0 {
|
||||||
xm * r.powi(4) / (1.0 + r)
|
xm * r.powi(4) / (1.0 + r)
|
||||||
} else {
|
} else {
|
||||||
@ -345,17 +408,17 @@ impl Distribution<u64> for Hypergeometric {
|
|||||||
} else {
|
} else {
|
||||||
nm * e.powi(4)
|
nm * e.powi(4)
|
||||||
};
|
};
|
||||||
|
|
||||||
if av < ub - 0.25*(dr + ds + dt + de) + (y + m)*(gl - gu) - 0.0078 {
|
if av < ub - 0.25 * (dr + ds + dt + de) + (y + m) * (gl - gu) - 0.0078 {
|
||||||
break y as i64;
|
break y as i64;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 4.3: Final Acceptance/Rejection Test
|
// Step 4.3: Final Acceptance/Rejection Test
|
||||||
let av_critical = a -
|
let av_critical = a
|
||||||
ln_of_factorial(y) -
|
- ln_of_factorial(y)
|
||||||
ln_of_factorial(n1 as f64 - y) -
|
- ln_of_factorial(n1 as f64 - y)
|
||||||
ln_of_factorial(k as f64 - y) -
|
- ln_of_factorial(k as f64 - y)
|
||||||
ln_of_factorial((n2 - k) as f64 + y);
|
- ln_of_factorial((n2 - k) as f64 + y);
|
||||||
if v.ln() <= av_critical {
|
if v.ln() <= av_critical {
|
||||||
break y as i64;
|
break y as i64;
|
||||||
}
|
}
|
||||||
@ -380,8 +443,7 @@ mod test {
|
|||||||
assert!(Hypergeometric::new(100, 10, 5).is_ok());
|
assert!(Hypergeometric::new(100, 10, 5).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R)
|
fn test_hypergeometric_mean_and_variance<R: Rng>(n: u64, k: u64, s: u64, rng: &mut R) {
|
||||||
{
|
|
||||||
let distr = Hypergeometric::new(n, k, s).unwrap();
|
let distr = Hypergeometric::new(n, k, s).unwrap();
|
||||||
|
|
||||||
let expected_mean = s as f64 * k as f64 / n as f64;
|
let expected_mean = s as f64 * k as f64 / n as f64;
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::{Distribution, Standard, StandardNormal};
|
use crate::{Distribution, Standard, StandardNormal};
|
||||||
|
use core::fmt;
|
||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use core::fmt;
|
|
||||||
|
|
||||||
/// Error type returned from `InverseGaussian::new`
|
/// Error type returned from `InverseGaussian::new`
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
@ -68,7 +68,9 @@ where
|
|||||||
{
|
{
|
||||||
#[allow(clippy::many_single_char_names)]
|
#[allow(clippy::many_single_char_names)]
|
||||||
fn sample<R>(&self, rng: &mut R) -> F
|
fn sample<R>(&self, rng: &mut R) -> F
|
||||||
where R: Rng + ?Sized {
|
where
|
||||||
|
R: Rng + ?Sized,
|
||||||
|
{
|
||||||
let mu = self.mean;
|
let mu = self.mean;
|
||||||
let l = self.shape;
|
let l = self.shape;
|
||||||
|
|
||||||
@ -79,7 +81,7 @@ where
|
|||||||
|
|
||||||
let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt());
|
let x = mu + mu_2l * (y - (F::from(4.).unwrap() * l * y + y * y).sqrt());
|
||||||
|
|
||||||
let u: F = rng.gen();
|
let u: F = rng.random();
|
||||||
|
|
||||||
if u <= mu / (mu + x) {
|
if u <= mu / (mu + x) {
|
||||||
return x;
|
return x;
|
||||||
@ -112,6 +114,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn inverse_gaussian_distributions_can_be_compared() {
|
fn inverse_gaussian_distributions_can_be_compared() {
|
||||||
assert_eq!(InverseGaussian::new(1.0, 2.0), InverseGaussian::new(1.0, 2.0));
|
assert_eq!(
|
||||||
|
InverseGaussian::new(1.0, 2.0),
|
||||||
|
InverseGaussian::new(1.0, 2.0)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -21,6 +21,7 @@
|
|||||||
)]
|
)]
|
||||||
#![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose
|
#![allow(clippy::neg_cmp_op_on_partial_ord)] // suggested fix too verbose
|
||||||
#![no_std]
|
#![no_std]
|
||||||
|
#![allow(unexpected_cfgs)]
|
||||||
#![cfg_attr(doc_cfg, feature(doc_cfg))]
|
#![cfg_attr(doc_cfg, feature(doc_cfg))]
|
||||||
|
|
||||||
//! Generating random samples from probability distributions.
|
//! Generating random samples from probability distributions.
|
||||||
@ -178,10 +179,14 @@ mod test {
|
|||||||
macro_rules! assert_almost_eq {
|
macro_rules! assert_almost_eq {
|
||||||
($a:expr, $b:expr, $prec:expr) => {
|
($a:expr, $b:expr, $prec:expr) => {
|
||||||
let diff = ($a - $b).abs();
|
let diff = ($a - $b).abs();
|
||||||
assert!(diff <= $prec,
|
assert!(
|
||||||
|
diff <= $prec,
|
||||||
"assertion failed: `abs(left - right) = {:.1e} < {:e}`, \
|
"assertion failed: `abs(left - right) = {:.1e} < {:e}`, \
|
||||||
(left: `{}`, right: `{}`)",
|
(left: `{}`, right: `{}`)",
|
||||||
diff, $prec, $a, $b
|
diff,
|
||||||
|
$prec,
|
||||||
|
$a,
|
||||||
|
$b
|
||||||
);
|
);
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
@ -10,10 +10,10 @@
|
|||||||
//! The normal and derived distributions.
|
//! The normal and derived distributions.
|
||||||
|
|
||||||
use crate::utils::ziggurat;
|
use crate::utils::ziggurat;
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{ziggurat_tables, Distribution, Open01};
|
use crate::{ziggurat_tables, Distribution, Open01};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// Samples floating-point numbers according to the normal distribution
|
/// Samples floating-point numbers according to the normal distribution
|
||||||
/// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to
|
/// `N(0, 1)` (a.k.a. a standard normal, or Gaussian). This is equivalent to
|
||||||
@ -115,7 +115,9 @@ impl Distribution<f64> for StandardNormal {
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Normal<F>
|
pub struct Normal<F>
|
||||||
where F: Float, StandardNormal: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
StandardNormal: Distribution<F>,
|
||||||
{
|
{
|
||||||
mean: F,
|
mean: F,
|
||||||
std_dev: F,
|
std_dev: F,
|
||||||
@ -144,7 +146,9 @@ impl fmt::Display for Error {
|
|||||||
impl std::error::Error for Error {}
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
impl<F> Normal<F>
|
impl<F> Normal<F>
|
||||||
where F: Float, StandardNormal: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
StandardNormal: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Construct, from mean and standard deviation
|
/// Construct, from mean and standard deviation
|
||||||
///
|
///
|
||||||
@ -204,14 +208,15 @@ where F: Float, StandardNormal: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Normal<F>
|
impl<F> Distribution<F> for Normal<F>
|
||||||
where F: Float, StandardNormal: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
StandardNormal: Distribution<F>,
|
||||||
{
|
{
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
self.from_zscore(rng.sample(StandardNormal))
|
self.from_zscore(rng.sample(StandardNormal))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// The log-normal distribution `ln N(mean, std_dev**2)`.
|
/// The log-normal distribution `ln N(mean, std_dev**2)`.
|
||||||
///
|
///
|
||||||
/// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)`
|
/// If `X` is log-normal distributed, then `ln(X)` is `N(mean, std_dev**2)`
|
||||||
@ -230,13 +235,17 @@ where F: Float, StandardNormal: Distribution<F>
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct LogNormal<F>
|
pub struct LogNormal<F>
|
||||||
where F: Float, StandardNormal: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
StandardNormal: Distribution<F>,
|
||||||
{
|
{
|
||||||
norm: Normal<F>,
|
norm: Normal<F>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<F> LogNormal<F>
|
impl<F> LogNormal<F>
|
||||||
where F: Float, StandardNormal: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
StandardNormal: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Construct, from (log-space) mean and standard deviation
|
/// Construct, from (log-space) mean and standard deviation
|
||||||
///
|
///
|
||||||
@ -307,7 +316,9 @@ where F: Float, StandardNormal: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for LogNormal<F>
|
impl<F> Distribution<F> for LogNormal<F>
|
||||||
where F: Float, StandardNormal: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
StandardNormal: Distribution<F>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
@ -348,7 +359,10 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_log_normal_cv() {
|
fn test_log_normal_cv() {
|
||||||
let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap();
|
let lnorm = LogNormal::from_mean_cv(0.0, 0.0).unwrap();
|
||||||
assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (f64::NEG_INFINITY, 0.0));
|
assert_eq!(
|
||||||
|
(lnorm.norm.mean, lnorm.norm.std_dev),
|
||||||
|
(f64::NEG_INFINITY, 0.0)
|
||||||
|
);
|
||||||
|
|
||||||
let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap();
|
let lnorm = LogNormal::from_mean_cv(1.0, 0.0).unwrap();
|
||||||
assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0));
|
assert_eq!((lnorm.norm.mean, lnorm.norm.std_dev), (0.0, 0.0));
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
use crate::{Distribution, InverseGaussian, Standard, StandardNormal};
|
use crate::{Distribution, InverseGaussian, Standard, StandardNormal};
|
||||||
|
use core::fmt;
|
||||||
use num_traits::Float;
|
use num_traits::Float;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use core::fmt;
|
|
||||||
|
|
||||||
/// Error type returned from `NormalInverseGaussian::new`
|
/// Error type returned from `NormalInverseGaussian::new`
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
@ -15,8 +15,12 @@ pub enum Error {
|
|||||||
impl fmt::Display for Error {
|
impl fmt::Display for Error {
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.write_str(match self {
|
f.write_str(match self {
|
||||||
Error::AlphaNegativeOrNull => "alpha <= 0 or is NaN in normal inverse Gaussian distribution",
|
Error::AlphaNegativeOrNull => {
|
||||||
Error::AbsoluteBetaNotLessThanAlpha => "|beta| >= alpha or is NaN in normal inverse Gaussian distribution",
|
"alpha <= 0 or is NaN in normal inverse Gaussian distribution"
|
||||||
|
}
|
||||||
|
Error::AbsoluteBetaNotLessThanAlpha => {
|
||||||
|
"|beta| >= alpha or is NaN in normal inverse Gaussian distribution"
|
||||||
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -75,7 +79,9 @@ where
|
|||||||
Standard: Distribution<F>,
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
fn sample<R>(&self, rng: &mut R) -> F
|
fn sample<R>(&self, rng: &mut R) -> F
|
||||||
where R: Rng + ?Sized {
|
where
|
||||||
|
R: Rng + ?Sized,
|
||||||
|
{
|
||||||
let inv_gauss = rng.sample(self.inverse_gaussian);
|
let inv_gauss = rng.sample(self.inverse_gaussian);
|
||||||
|
|
||||||
self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal)
|
self.beta * inv_gauss + inv_gauss.sqrt() * rng.sample(StandardNormal)
|
||||||
@ -105,6 +111,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn normal_inverse_gaussian_distributions_can_be_compared() {
|
fn normal_inverse_gaussian_distributions_can_be_compared() {
|
||||||
assert_eq!(NormalInverseGaussian::new(1.0, 2.0), NormalInverseGaussian::new(1.0, 2.0));
|
assert_eq!(
|
||||||
|
NormalInverseGaussian::new(1.0, 2.0),
|
||||||
|
NormalInverseGaussian::new(1.0, 2.0)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,10 +8,10 @@
|
|||||||
|
|
||||||
//! The Pareto distribution.
|
//! The Pareto distribution.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{Distribution, OpenClosed01};
|
use crate::{Distribution, OpenClosed01};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// Samples floating-point numbers according to the Pareto distribution
|
/// Samples floating-point numbers according to the Pareto distribution
|
||||||
///
|
///
|
||||||
@ -26,7 +26,9 @@ use core::fmt;
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Pareto<F>
|
pub struct Pareto<F>
|
||||||
where F: Float, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
scale: F,
|
scale: F,
|
||||||
inv_neg_shape: F,
|
inv_neg_shape: F,
|
||||||
@ -55,7 +57,9 @@ impl fmt::Display for Error {
|
|||||||
impl std::error::Error for Error {}
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
impl<F> Pareto<F>
|
impl<F> Pareto<F>
|
||||||
where F: Float, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Construct a new Pareto distribution with given `scale` and `shape`.
|
/// Construct a new Pareto distribution with given `scale` and `shape`.
|
||||||
///
|
///
|
||||||
@ -78,7 +82,9 @@ where F: Float, OpenClosed01: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Pareto<F>
|
impl<F> Distribution<F> for Pareto<F>
|
||||||
where F: Float, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
let u: F = OpenClosed01.sample(rng);
|
let u: F = OpenClosed01.sample(rng);
|
||||||
@ -112,7 +118,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn test_samples<F: Float + Debug + Display + LowerExp, D: Distribution<F>>(
|
fn test_samples<F: Float + Debug + Display + LowerExp, D: Distribution<F>>(
|
||||||
distr: D, thresh: F, expected: &[F],
|
distr: D,
|
||||||
|
thresh: F,
|
||||||
|
expected: &[F],
|
||||||
) {
|
) {
|
||||||
let mut rng = crate::test::rng(213);
|
let mut rng = crate::test::rng(213);
|
||||||
for v in expected {
|
for v in expected {
|
||||||
@ -121,15 +129,21 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
test_samples(Pareto::new(1f32, 1.0).unwrap(), 1e-6, &[
|
test_samples(
|
||||||
1.0423688, 2.1235929, 4.132709, 1.4679428,
|
Pareto::new(1f32, 1.0).unwrap(),
|
||||||
]);
|
1e-6,
|
||||||
test_samples(Pareto::new(2.0, 0.5).unwrap(), 1e-14, &[
|
&[1.0423688, 2.1235929, 4.132709, 1.4679428],
|
||||||
9.019295276219136,
|
);
|
||||||
4.3097126018270595,
|
test_samples(
|
||||||
6.837815045397157,
|
Pareto::new(2.0, 0.5).unwrap(),
|
||||||
105.8826669383772,
|
1e-14,
|
||||||
]);
|
&[
|
||||||
|
9.019295276219136,
|
||||||
|
4.3097126018270595,
|
||||||
|
6.837815045397157,
|
||||||
|
105.8826669383772,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -7,10 +7,10 @@
|
|||||||
// except according to those terms.
|
// except according to those terms.
|
||||||
//! The PERT distribution.
|
//! The PERT distribution.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{Beta, Distribution, Exp1, Open01, StandardNormal};
|
use crate::{Beta, Distribution, Exp1, Open01, StandardNormal};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// The PERT distribution.
|
/// The PERT distribution.
|
||||||
///
|
///
|
||||||
@ -129,20 +129,12 @@ mod test {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_pert() {
|
fn test_pert() {
|
||||||
for &(min, max, mode) in &[
|
for &(min, max, mode) in &[(-1., 1., 0.), (1., 2., 1.), (5., 25., 25.)] {
|
||||||
(-1., 1., 0.),
|
|
||||||
(1., 2., 1.),
|
|
||||||
(5., 25., 25.),
|
|
||||||
] {
|
|
||||||
let _distr = Pert::new(min, max, mode).unwrap();
|
let _distr = Pert::new(min, max, mode).unwrap();
|
||||||
// TODO: test correctness
|
// TODO: test correctness
|
||||||
}
|
}
|
||||||
|
|
||||||
for &(min, max, mode) in &[
|
for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
|
||||||
(-1., 1., 2.),
|
|
||||||
(-1., 1., -2.),
|
|
||||||
(2., 1., 1.),
|
|
||||||
] {
|
|
||||||
assert!(Pert::new(min, max, mode).is_err());
|
assert!(Pert::new(min, max, mode).is_err());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,10 +9,10 @@
|
|||||||
|
|
||||||
//! The Poisson distribution.
|
//! The Poisson distribution.
|
||||||
|
|
||||||
use num_traits::{Float, FloatConst};
|
|
||||||
use crate::{Cauchy, Distribution, Standard};
|
use crate::{Cauchy, Distribution, Standard};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::{Float, FloatConst};
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// The Poisson distribution `Poisson(lambda)`.
|
/// The Poisson distribution `Poisson(lambda)`.
|
||||||
///
|
///
|
||||||
@ -31,7 +31,9 @@ use core::fmt;
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Poisson<F>
|
pub struct Poisson<F>
|
||||||
where F: Float + FloatConst, Standard: Distribution<F>
|
where
|
||||||
|
F: Float + FloatConst,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
lambda: F,
|
lambda: F,
|
||||||
// precalculated values
|
// precalculated values
|
||||||
@ -64,7 +66,9 @@ impl fmt::Display for Error {
|
|||||||
impl std::error::Error for Error {}
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
impl<F> Poisson<F>
|
impl<F> Poisson<F>
|
||||||
where F: Float + FloatConst, Standard: Distribution<F>
|
where
|
||||||
|
F: Float + FloatConst,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Construct a new `Poisson` with the given shape parameter
|
/// Construct a new `Poisson` with the given shape parameter
|
||||||
/// `lambda`.
|
/// `lambda`.
|
||||||
@ -87,7 +91,9 @@ where F: Float + FloatConst, Standard: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Poisson<F>
|
impl<F> Distribution<F> for Poisson<F>
|
||||||
where F: Float + FloatConst, Standard: Distribution<F>
|
where
|
||||||
|
F: Float + FloatConst,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
@ -96,9 +102,9 @@ where F: Float + FloatConst, Standard: Distribution<F>
|
|||||||
// for low expected values use the Knuth method
|
// for low expected values use the Knuth method
|
||||||
if self.lambda < F::from(12.0).unwrap() {
|
if self.lambda < F::from(12.0).unwrap() {
|
||||||
let mut result = F::one();
|
let mut result = F::one();
|
||||||
let mut p = rng.gen::<F>();
|
let mut p = rng.random::<F>();
|
||||||
while p > self.exp_lambda {
|
while p > self.exp_lambda {
|
||||||
p = p*rng.gen::<F>();
|
p = p * rng.random::<F>();
|
||||||
result = result + F::one();
|
result = result + F::one();
|
||||||
}
|
}
|
||||||
result - F::one()
|
result - F::one()
|
||||||
@ -139,7 +145,7 @@ where F: Float + FloatConst, Standard: Distribution<F>
|
|||||||
.exp();
|
.exp();
|
||||||
|
|
||||||
// check with uniform random value - if below the threshold, we are within the target distribution
|
// check with uniform random value - if below the threshold, we are within the target distribution
|
||||||
if rng.gen::<F>() <= check {
|
if rng.random::<F>() <= check {
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -153,7 +159,8 @@ mod test {
|
|||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn test_poisson_avg_gen<F: Float + FloatConst>(lambda: F, tol: F)
|
fn test_poisson_avg_gen<F: Float + FloatConst>(lambda: F, tol: F)
|
||||||
where Standard: Distribution<F>
|
where
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
let poisson = Poisson::new(lambda).unwrap();
|
let poisson = Poisson::new(lambda).unwrap();
|
||||||
let mut rng = crate::test::rng(123);
|
let mut rng = crate::test::rng(123);
|
||||||
@ -173,7 +180,7 @@ mod test {
|
|||||||
test_poisson_avg_gen::<f32>(10.0, 0.1);
|
test_poisson_avg_gen::<f32>(10.0, 0.1);
|
||||||
test_poisson_avg_gen::<f32>(15.0, 0.1);
|
test_poisson_avg_gen::<f32>(15.0, 0.1);
|
||||||
|
|
||||||
//Small lambda will use Knuth's method with exp_lambda == 1.0
|
// Small lambda will use Knuth's method with exp_lambda == 1.0
|
||||||
test_poisson_avg_gen::<f32>(0.00000000000000005, 0.1);
|
test_poisson_avg_gen::<f32>(0.00000000000000005, 0.1);
|
||||||
test_poisson_avg_gen::<f64>(0.00000000000000005, 0.1);
|
test_poisson_avg_gen::<f64>(0.00000000000000005, 0.1);
|
||||||
}
|
}
|
||||||
|
@ -150,9 +150,7 @@ where
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(
|
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
|
||||||
distr: D, zero: F, expected: &[F],
|
|
||||||
) {
|
|
||||||
let mut rng = crate::test::rng(213);
|
let mut rng = crate::test::rng(213);
|
||||||
let mut buf = [zero; 4];
|
let mut buf = [zero; 4];
|
||||||
for x in &mut buf {
|
for x in &mut buf {
|
||||||
@ -222,12 +220,7 @@ mod tests {
|
|||||||
test_samples(
|
test_samples(
|
||||||
SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(),
|
SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(),
|
||||||
0f64,
|
0f64,
|
||||||
&[
|
&[f64::INFINITY, f64::INFINITY, f64::INFINITY, f64::INFINITY],
|
||||||
f64::INFINITY,
|
|
||||||
f64::INFINITY,
|
|
||||||
f64::INFINITY,
|
|
||||||
f64::INFINITY,
|
|
||||||
],
|
|
||||||
);
|
);
|
||||||
test_samples(
|
test_samples(
|
||||||
SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(),
|
SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(),
|
||||||
@ -256,6 +249,9 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn skew_normal_distributions_can_be_compared() {
|
fn skew_normal_distributions_can_be_compared() {
|
||||||
assert_eq!(SkewNormal::new(1.0, 2.0, 3.0), SkewNormal::new(1.0, 2.0, 3.0));
|
assert_eq!(
|
||||||
|
SkewNormal::new(1.0, 2.0, 3.0),
|
||||||
|
SkewNormal::new(1.0, 2.0, 3.0)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -7,10 +7,10 @@
|
|||||||
// except according to those terms.
|
// except according to those terms.
|
||||||
//! The triangular distribution.
|
//! The triangular distribution.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{Distribution, Standard};
|
use crate::{Distribution, Standard};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// The triangular distribution.
|
/// The triangular distribution.
|
||||||
///
|
///
|
||||||
@ -34,7 +34,9 @@ use core::fmt;
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Triangular<F>
|
pub struct Triangular<F>
|
||||||
where F: Float, Standard: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
min: F,
|
min: F,
|
||||||
max: F,
|
max: F,
|
||||||
@ -66,7 +68,9 @@ impl fmt::Display for TriangularError {
|
|||||||
impl std::error::Error for TriangularError {}
|
impl std::error::Error for TriangularError {}
|
||||||
|
|
||||||
impl<F> Triangular<F>
|
impl<F> Triangular<F>
|
||||||
where F: Float, Standard: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Set up the Triangular distribution with defined `min`, `max` and `mode`.
|
/// Set up the Triangular distribution with defined `min`, `max` and `mode`.
|
||||||
#[inline]
|
#[inline]
|
||||||
@ -82,7 +86,9 @@ where F: Float, Standard: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Triangular<F>
|
impl<F> Distribution<F> for Triangular<F>
|
||||||
where F: Float, Standard: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
@ -106,7 +112,7 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_triangular() {
|
fn test_triangular() {
|
||||||
let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0);
|
let mut half_rng = mock::StepRng::new(0x8000_0000_0000_0000, 0);
|
||||||
assert_eq!(half_rng.gen::<f64>(), 0.5);
|
assert_eq!(half_rng.random::<f64>(), 0.5);
|
||||||
for &(min, max, mode, median) in &[
|
for &(min, max, mode, median) in &[
|
||||||
(-1., 1., 0., 0.),
|
(-1., 1., 0., 0.),
|
||||||
(1., 2., 1., 2. - 0.5f64.sqrt()),
|
(1., 2., 1., 2. - 0.5f64.sqrt()),
|
||||||
@ -122,17 +128,16 @@ mod test {
|
|||||||
assert_eq!(distr.sample(&mut half_rng), median);
|
assert_eq!(distr.sample(&mut half_rng), median);
|
||||||
}
|
}
|
||||||
|
|
||||||
for &(min, max, mode) in &[
|
for &(min, max, mode) in &[(-1., 1., 2.), (-1., 1., -2.), (2., 1., 1.)] {
|
||||||
(-1., 1., 2.),
|
|
||||||
(-1., 1., -2.),
|
|
||||||
(2., 1., 1.),
|
|
||||||
] {
|
|
||||||
assert!(Triangular::new(min, max, mode).is_err());
|
assert!(Triangular::new(min, max, mode).is_err());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn triangular_distributions_can_be_compared() {
|
fn triangular_distributions_can_be_compared() {
|
||||||
assert_eq!(Triangular::new(1.0, 3.0, 2.0), Triangular::new(1.0, 3.0, 2.0));
|
assert_eq!(
|
||||||
|
Triangular::new(1.0, 3.0, 2.0),
|
||||||
|
Triangular::new(1.0, 3.0, 2.0)
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -6,8 +6,8 @@
|
|||||||
// option. This file may not be copied, modified, or distributed
|
// option. This file may not be copied, modified, or distributed
|
||||||
// except according to those terms.
|
// except according to those terms.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
||||||
|
use num_traits::Float;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
|
||||||
/// Samples uniformly from the unit ball (surface and interior) in three
|
/// Samples uniformly from the unit ball (surface and interior) in three
|
||||||
|
@ -6,8 +6,8 @@
|
|||||||
// option. This file may not be copied, modified, or distributed
|
// option. This file may not be copied, modified, or distributed
|
||||||
// except according to those terms.
|
// except according to those terms.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
||||||
|
use num_traits::Float;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
|
||||||
/// Samples uniformly from the edge of the unit circle in two dimensions.
|
/// Samples uniformly from the edge of the unit circle in two dimensions.
|
||||||
|
@ -6,8 +6,8 @@
|
|||||||
// option. This file may not be copied, modified, or distributed
|
// option. This file may not be copied, modified, or distributed
|
||||||
// except according to those terms.
|
// except according to those terms.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
||||||
|
use num_traits::Float;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
|
||||||
/// Samples uniformly from the unit disc in two dimensions.
|
/// Samples uniformly from the unit disc in two dimensions.
|
||||||
|
@ -6,8 +6,8 @@
|
|||||||
// option. This file may not be copied, modified, or distributed
|
// option. This file may not be copied, modified, or distributed
|
||||||
// except according to those terms.
|
// except according to those terms.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
||||||
|
use num_traits::Float;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
|
|
||||||
/// Samples uniformly from the surface of the unit sphere in three dimensions.
|
/// Samples uniformly from the surface of the unit sphere in three dimensions.
|
||||||
@ -42,7 +42,11 @@ impl<F: Float + SampleUniform> Distribution<[F; 3]> for UnitSphere {
|
|||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt();
|
let factor = F::from(2.).unwrap() * (F::one() - sum).sqrt();
|
||||||
return [x1 * factor, x2 * factor, F::from(1.).unwrap() - F::from(2.).unwrap() * sum];
|
return [
|
||||||
|
x1 * factor,
|
||||||
|
x2 * factor,
|
||||||
|
F::from(1.).unwrap() - F::from(2.).unwrap() * sum,
|
||||||
|
];
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -9,9 +9,9 @@
|
|||||||
//! Math helper functions
|
//! Math helper functions
|
||||||
|
|
||||||
use crate::ziggurat_tables;
|
use crate::ziggurat_tables;
|
||||||
|
use num_traits::Float;
|
||||||
use rand::distributions::hidden_export::IntoFloat;
|
use rand::distributions::hidden_export::IntoFloat;
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use num_traits::Float;
|
|
||||||
|
|
||||||
/// Calculates ln(gamma(x)) (natural logarithm of the gamma
|
/// Calculates ln(gamma(x)) (natural logarithm of the gamma
|
||||||
/// function) using the Lanczos approximation.
|
/// function) using the Lanczos approximation.
|
||||||
@ -77,7 +77,7 @@ pub(crate) fn ziggurat<R: Rng + ?Sized, P, Z>(
|
|||||||
x_tab: ziggurat_tables::ZigTable,
|
x_tab: ziggurat_tables::ZigTable,
|
||||||
f_tab: ziggurat_tables::ZigTable,
|
f_tab: ziggurat_tables::ZigTable,
|
||||||
mut pdf: P,
|
mut pdf: P,
|
||||||
mut zero_case: Z
|
mut zero_case: Z,
|
||||||
) -> f64
|
) -> f64
|
||||||
where
|
where
|
||||||
P: FnMut(f64) -> f64,
|
P: FnMut(f64) -> f64,
|
||||||
@ -114,7 +114,7 @@ where
|
|||||||
return zero_case(rng, u);
|
return zero_case(rng, u);
|
||||||
}
|
}
|
||||||
// algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
|
// algebraically equivalent to f1 + DRanU()*(f0 - f1) < 1
|
||||||
if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.gen::<f64>() < pdf(x) {
|
if f_tab[i + 1] + (f_tab[i] - f_tab[i + 1]) * rng.random::<f64>() < pdf(x) {
|
||||||
return x;
|
return x;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,10 +8,10 @@
|
|||||||
|
|
||||||
//! The Weibull distribution.
|
//! The Weibull distribution.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{Distribution, OpenClosed01};
|
use crate::{Distribution, OpenClosed01};
|
||||||
use rand::Rng;
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::Float;
|
||||||
|
use rand::Rng;
|
||||||
|
|
||||||
/// Samples floating-point numbers according to the Weibull distribution
|
/// Samples floating-point numbers according to the Weibull distribution
|
||||||
///
|
///
|
||||||
@ -26,7 +26,9 @@ use core::fmt;
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(serde::Serialize, serde::Deserialize))]
|
||||||
pub struct Weibull<F>
|
pub struct Weibull<F>
|
||||||
where F: Float, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
inv_shape: F,
|
inv_shape: F,
|
||||||
scale: F,
|
scale: F,
|
||||||
@ -55,7 +57,9 @@ impl fmt::Display for Error {
|
|||||||
impl std::error::Error for Error {}
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
impl<F> Weibull<F>
|
impl<F> Weibull<F>
|
||||||
where F: Float, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Construct a new `Weibull` distribution with given `scale` and `shape`.
|
/// Construct a new `Weibull` distribution with given `scale` and `shape`.
|
||||||
pub fn new(scale: F, shape: F) -> Result<Weibull<F>, Error> {
|
pub fn new(scale: F, shape: F) -> Result<Weibull<F>, Error> {
|
||||||
@ -73,7 +77,9 @@ where F: Float, OpenClosed01: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Weibull<F>
|
impl<F> Distribution<F> for Weibull<F>
|
||||||
where F: Float, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
let x: F = rng.sample(OpenClosed01);
|
let x: F = rng.sample(OpenClosed01);
|
||||||
@ -106,7 +112,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(
|
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(
|
||||||
distr: D, zero: F, expected: &[F],
|
distr: D,
|
||||||
|
zero: F,
|
||||||
|
expected: &[F],
|
||||||
) {
|
) {
|
||||||
let mut rng = crate::test::rng(213);
|
let mut rng = crate::test::rng(213);
|
||||||
let mut buf = [zero; 4];
|
let mut buf = [zero; 4];
|
||||||
@ -116,18 +124,21 @@ mod tests {
|
|||||||
assert_eq!(buf, expected);
|
assert_eq!(buf, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_samples(Weibull::new(1.0, 1.0).unwrap(), 0f32, &[
|
test_samples(
|
||||||
0.041495778,
|
Weibull::new(1.0, 1.0).unwrap(),
|
||||||
0.7531094,
|
0f32,
|
||||||
1.4189332,
|
&[0.041495778, 0.7531094, 1.4189332, 0.38386202],
|
||||||
0.38386202,
|
);
|
||||||
]);
|
test_samples(
|
||||||
test_samples(Weibull::new(2.0, 0.5).unwrap(), 0f64, &[
|
Weibull::new(2.0, 0.5).unwrap(),
|
||||||
1.1343478702739669,
|
0f64,
|
||||||
0.29470010050655226,
|
&[
|
||||||
0.7556151370284702,
|
1.1343478702739669,
|
||||||
7.877212340241561,
|
0.29470010050655226,
|
||||||
]);
|
0.7556151370284702,
|
||||||
|
7.877212340241561,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -11,13 +11,13 @@
|
|||||||
|
|
||||||
use super::WeightError;
|
use super::WeightError;
|
||||||
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
use crate::{uniform::SampleUniform, Distribution, Uniform};
|
||||||
|
use alloc::{boxed::Box, vec, vec::Vec};
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
use core::iter::Sum;
|
use core::iter::Sum;
|
||||||
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
|
use core::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign};
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use alloc::{boxed::Box, vec, vec::Vec};
|
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A distribution using weighted sampling to pick a discretely selected item.
|
/// A distribution using weighted sampling to pick a discretely selected item.
|
||||||
///
|
///
|
||||||
@ -67,8 +67,14 @@ use serde::{Serialize, Deserialize};
|
|||||||
/// [`Uniform<W>::sample`]: Distribution::sample
|
/// [`Uniform<W>::sample`]: Distribution::sample
|
||||||
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
||||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||||
#[cfg_attr(feature = "serde1", serde(bound(serialize = "W: Serialize, W::Sampler: Serialize")))]
|
#[cfg_attr(
|
||||||
#[cfg_attr(feature = "serde1", serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>")))]
|
feature = "serde1",
|
||||||
|
serde(bound(serialize = "W: Serialize, W::Sampler: Serialize"))
|
||||||
|
)]
|
||||||
|
#[cfg_attr(
|
||||||
|
feature = "serde1",
|
||||||
|
serde(bound(deserialize = "W: Deserialize<'de>, W::Sampler: Deserialize<'de>"))
|
||||||
|
)]
|
||||||
pub struct WeightedAliasIndex<W: AliasableWeight> {
|
pub struct WeightedAliasIndex<W: AliasableWeight> {
|
||||||
aliases: Box<[u32]>,
|
aliases: Box<[u32]>,
|
||||||
no_alias_odds: Box<[W]>,
|
no_alias_odds: Box<[W]>,
|
||||||
@ -257,7 +263,8 @@ where
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
|
impl<W: AliasableWeight> Clone for WeightedAliasIndex<W>
|
||||||
where Uniform<W>: Clone
|
where
|
||||||
|
Uniform<W>: Clone,
|
||||||
{
|
{
|
||||||
fn clone(&self) -> Self {
|
fn clone(&self) -> Self {
|
||||||
Self {
|
Self {
|
||||||
@ -308,7 +315,7 @@ pub trait AliasableWeight:
|
|||||||
macro_rules! impl_weight_for_float {
|
macro_rules! impl_weight_for_float {
|
||||||
($T: ident) => {
|
($T: ident) => {
|
||||||
impl AliasableWeight for $T {
|
impl AliasableWeight for $T {
|
||||||
const MAX: Self = ::core::$T::MAX;
|
const MAX: Self = $T::MAX;
|
||||||
const ZERO: Self = 0.0;
|
const ZERO: Self = 0.0;
|
||||||
|
|
||||||
fn try_from_u32_lossy(n: u32) -> Option<Self> {
|
fn try_from_u32_lossy(n: u32) -> Option<Self> {
|
||||||
@ -337,7 +344,7 @@ fn pairwise_sum<T: AliasableWeight>(values: &[T]) -> T {
|
|||||||
macro_rules! impl_weight_for_int {
|
macro_rules! impl_weight_for_int {
|
||||||
($T: ident) => {
|
($T: ident) => {
|
||||||
impl AliasableWeight for $T {
|
impl AliasableWeight for $T {
|
||||||
const MAX: Self = ::core::$T::MAX;
|
const MAX: Self = $T::MAX;
|
||||||
const ZERO: Self = 0;
|
const ZERO: Self = 0;
|
||||||
|
|
||||||
fn try_from_u32_lossy(n: u32) -> Option<Self> {
|
fn try_from_u32_lossy(n: u32) -> Option<Self> {
|
||||||
@ -444,7 +451,9 @@ mod test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
|
fn test_weighted_index<W: AliasableWeight, F: Fn(W) -> f64>(w_to_f64: F)
|
||||||
where WeightedAliasIndex<W>: fmt::Debug {
|
where
|
||||||
|
WeightedAliasIndex<W>: fmt::Debug,
|
||||||
|
{
|
||||||
const NUM_WEIGHTS: u32 = 10;
|
const NUM_WEIGHTS: u32 = 10;
|
||||||
const ZERO_WEIGHT_INDEX: u32 = 3;
|
const ZERO_WEIGHT_INDEX: u32 = 3;
|
||||||
const NUM_SAMPLES: u32 = 15000;
|
const NUM_SAMPLES: u32 = 15000;
|
||||||
@ -455,7 +464,8 @@ mod test {
|
|||||||
let random_weight_distribution = Uniform::new_inclusive(
|
let random_weight_distribution = Uniform::new_inclusive(
|
||||||
W::ZERO,
|
W::ZERO,
|
||||||
W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
|
W::MAX / W::try_from_u32_lossy(NUM_WEIGHTS).unwrap(),
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
for _ in 0..NUM_WEIGHTS {
|
for _ in 0..NUM_WEIGHTS {
|
||||||
weights.push(rng.sample(&random_weight_distribution));
|
weights.push(rng.sample(&random_weight_distribution));
|
||||||
}
|
}
|
||||||
@ -497,7 +507,11 @@ mod test {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn test_samples<W: AliasableWeight>(weights: Vec<W>, buf: &mut [usize], expected: &[usize]) {
|
fn test_samples<W: AliasableWeight>(
|
||||||
|
weights: Vec<W>,
|
||||||
|
buf: &mut [usize],
|
||||||
|
expected: &[usize],
|
||||||
|
) {
|
||||||
assert_eq!(buf.len(), expected.len());
|
assert_eq!(buf.len(), expected.len());
|
||||||
let distr = WeightedAliasIndex::new(weights).unwrap();
|
let distr = WeightedAliasIndex::new(weights).unwrap();
|
||||||
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
|
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
|
||||||
@ -508,14 +522,20 @@ mod test {
|
|||||||
}
|
}
|
||||||
|
|
||||||
let mut buf = [0; 10];
|
let mut buf = [0; 10];
|
||||||
test_samples(vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1], &mut buf, &[
|
test_samples(
|
||||||
6, 5, 7, 5, 8, 7, 6, 2, 3, 7,
|
vec![1i32, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
]);
|
&mut buf,
|
||||||
test_samples(vec![0.7f32, 0.1, 0.1, 0.1], &mut buf, &[
|
&[6, 5, 7, 5, 8, 7, 6, 2, 3, 7],
|
||||||
2, 0, 0, 0, 0, 0, 0, 0, 1, 3,
|
);
|
||||||
]);
|
test_samples(
|
||||||
test_samples(vec![1.0f64, 0.999, 0.998, 0.997], &mut buf, &[
|
vec![0.7f32, 0.1, 0.1, 0.1],
|
||||||
2, 1, 2, 3, 2, 1, 3, 2, 1, 1,
|
&mut buf,
|
||||||
]);
|
&[2, 0, 0, 0, 0, 0, 0, 0, 1, 3],
|
||||||
|
);
|
||||||
|
test_samples(
|
||||||
|
vec![1.0f64, 0.999, 0.998, 0.997],
|
||||||
|
&mut buf,
|
||||||
|
&[2, 1, 2, 3, 2, 1, 3, 2, 1, 1],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -303,6 +303,7 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_no_item_error() {
|
fn test_no_item_error() {
|
||||||
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
|
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
|
||||||
|
#[allow(clippy::needless_borrows_for_generic_args)]
|
||||||
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
|
let tree = WeightedTreeIndex::<f64>::new(&[]).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tree.try_sample(&mut rng).unwrap_err(),
|
tree.try_sample(&mut rng).unwrap_err(),
|
||||||
@ -313,10 +314,10 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_overflow_error() {
|
fn test_overflow_error() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedTreeIndex::new(&[i32::MAX, 2]),
|
WeightedTreeIndex::new([i32::MAX, 2]),
|
||||||
Err(WeightError::Overflow)
|
Err(WeightError::Overflow)
|
||||||
);
|
);
|
||||||
let mut tree = WeightedTreeIndex::new(&[i32::MAX - 2, 1]).unwrap();
|
let mut tree = WeightedTreeIndex::new([i32::MAX - 2, 1]).unwrap();
|
||||||
assert_eq!(tree.push(3), Err(WeightError::Overflow));
|
assert_eq!(tree.push(3), Err(WeightError::Overflow));
|
||||||
assert_eq!(tree.update(1, 4), Err(WeightError::Overflow));
|
assert_eq!(tree.update(1, 4), Err(WeightError::Overflow));
|
||||||
tree.update(1, 2).unwrap();
|
tree.update(1, 2).unwrap();
|
||||||
@ -324,7 +325,7 @@ mod test {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_all_weights_zero_error() {
|
fn test_all_weights_zero_error() {
|
||||||
let tree = WeightedTreeIndex::<f64>::new(&[0.0, 0.0]).unwrap();
|
let tree = WeightedTreeIndex::<f64>::new([0.0, 0.0]).unwrap();
|
||||||
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
|
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
tree.try_sample(&mut rng).unwrap_err(),
|
tree.try_sample(&mut rng).unwrap_err(),
|
||||||
@ -335,37 +336,36 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_invalid_weight_error() {
|
fn test_invalid_weight_error() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedTreeIndex::<i32>::new(&[1, -1]).unwrap_err(),
|
WeightedTreeIndex::<i32>::new([1, -1]).unwrap_err(),
|
||||||
WeightError::InvalidWeight
|
WeightError::InvalidWeight
|
||||||
);
|
);
|
||||||
|
#[allow(clippy::needless_borrows_for_generic_args)]
|
||||||
let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
|
let mut tree = WeightedTreeIndex::<i32>::new(&[]).unwrap();
|
||||||
assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight);
|
assert_eq!(tree.push(-1).unwrap_err(), WeightError::InvalidWeight);
|
||||||
tree.push(1).unwrap();
|
tree.push(1).unwrap();
|
||||||
assert_eq!(
|
assert_eq!(tree.update(0, -1).unwrap_err(), WeightError::InvalidWeight);
|
||||||
tree.update(0, -1).unwrap_err(),
|
|
||||||
WeightError::InvalidWeight
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_tree_modifications() {
|
fn test_tree_modifications() {
|
||||||
let mut tree = WeightedTreeIndex::new(&[9, 1, 2]).unwrap();
|
let mut tree = WeightedTreeIndex::new([9, 1, 2]).unwrap();
|
||||||
tree.push(3).unwrap();
|
tree.push(3).unwrap();
|
||||||
tree.push(5).unwrap();
|
tree.push(5).unwrap();
|
||||||
tree.update(0, 0).unwrap();
|
tree.update(0, 0).unwrap();
|
||||||
assert_eq!(tree.pop(), Some(5));
|
assert_eq!(tree.pop(), Some(5));
|
||||||
let expected = WeightedTreeIndex::new(&[0, 1, 2, 3]).unwrap();
|
let expected = WeightedTreeIndex::new([0, 1, 2, 3]).unwrap();
|
||||||
assert_eq!(tree, expected);
|
assert_eq!(tree, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
#[allow(clippy::needless_range_loop)]
|
||||||
fn test_sample_counts_match_probabilities() {
|
fn test_sample_counts_match_probabilities() {
|
||||||
let start = 1;
|
let start = 1;
|
||||||
let end = 3;
|
let end = 3;
|
||||||
let samples = 20;
|
let samples = 20;
|
||||||
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
|
let mut rng = crate::test::rng(0x9c9fa0b0580a7031);
|
||||||
let weights: Vec<_> = (0..end).map(|_| rng.gen()).collect();
|
let weights: Vec<f64> = (0..end).map(|_| rng.random()).collect();
|
||||||
let mut tree = WeightedTreeIndex::new(&weights).unwrap();
|
let mut tree = WeightedTreeIndex::new(weights).unwrap();
|
||||||
let mut total_weight = 0.0;
|
let mut total_weight = 0.0;
|
||||||
let mut weights = alloc::vec![0.0; end];
|
let mut weights = alloc::vec![0.0; end];
|
||||||
for i in 0..end {
|
for i in 0..end {
|
||||||
|
@ -8,10 +8,10 @@
|
|||||||
|
|
||||||
//! The Zeta and related distributions.
|
//! The Zeta and related distributions.
|
||||||
|
|
||||||
use num_traits::Float;
|
|
||||||
use crate::{Distribution, Standard};
|
use crate::{Distribution, Standard};
|
||||||
use rand::{Rng, distributions::OpenClosed01};
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use num_traits::Float;
|
||||||
|
use rand::{distributions::OpenClosed01, Rng};
|
||||||
|
|
||||||
/// Samples integers according to the [zeta distribution].
|
/// Samples integers according to the [zeta distribution].
|
||||||
///
|
///
|
||||||
@ -48,7 +48,10 @@ use core::fmt;
|
|||||||
/// [Non-Uniform Random Variate Generation]: https://doi.org/10.1007/978-1-4613-8643-8
|
/// [Non-Uniform Random Variate Generation]: https://doi.org/10.1007/978-1-4613-8643-8
|
||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
pub struct Zeta<F>
|
pub struct Zeta<F>
|
||||||
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
a_minus_1: F,
|
a_minus_1: F,
|
||||||
b: F,
|
b: F,
|
||||||
@ -74,7 +77,10 @@ impl fmt::Display for ZetaError {
|
|||||||
impl std::error::Error for ZetaError {}
|
impl std::error::Error for ZetaError {}
|
||||||
|
|
||||||
impl<F> Zeta<F>
|
impl<F> Zeta<F>
|
||||||
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
/// Construct a new `Zeta` distribution with given `a` parameter.
|
/// Construct a new `Zeta` distribution with given `a` parameter.
|
||||||
#[inline]
|
#[inline]
|
||||||
@ -92,7 +98,10 @@ where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Zeta<F>
|
impl<F> Distribution<F> for Zeta<F>
|
||||||
where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
|
OpenClosed01: Distribution<F>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
@ -144,7 +153,10 @@ where F: Float, Standard: Distribution<F>, OpenClosed01: Distribution<F>
|
|||||||
/// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa
|
/// [1]: https://jasoncrease.medium.com/rejection-sampling-the-zipf-distribution-6b359792cffa
|
||||||
#[derive(Clone, Copy, Debug, PartialEq)]
|
#[derive(Clone, Copy, Debug, PartialEq)]
|
||||||
pub struct Zipf<F>
|
pub struct Zipf<F>
|
||||||
where F: Float, Standard: Distribution<F> {
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
|
{
|
||||||
s: F,
|
s: F,
|
||||||
t: F,
|
t: F,
|
||||||
q: F,
|
q: F,
|
||||||
@ -173,7 +185,10 @@ impl fmt::Display for ZipfError {
|
|||||||
impl std::error::Error for ZipfError {}
|
impl std::error::Error for ZipfError {}
|
||||||
|
|
||||||
impl<F> Zipf<F>
|
impl<F> Zipf<F>
|
||||||
where F: Float, Standard: Distribution<F> {
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
|
{
|
||||||
/// Construct a new `Zipf` distribution for a set with `n` elements and a
|
/// Construct a new `Zipf` distribution for a set with `n` elements and a
|
||||||
/// frequency rank exponent `s`.
|
/// frequency rank exponent `s`.
|
||||||
///
|
///
|
||||||
@ -186,7 +201,7 @@ where F: Float, Standard: Distribution<F> {
|
|||||||
if n < 1 {
|
if n < 1 {
|
||||||
return Err(ZipfError::NTooSmall);
|
return Err(ZipfError::NTooSmall);
|
||||||
}
|
}
|
||||||
let n = F::from(n).unwrap(); // This does not fail.
|
let n = F::from(n).unwrap(); // This does not fail.
|
||||||
let q = if s != F::one() {
|
let q = if s != F::one() {
|
||||||
// Make sure to calculate the division only once.
|
// Make sure to calculate the division only once.
|
||||||
F::one() / (F::one() - s)
|
F::one() / (F::one() - s)
|
||||||
@ -200,9 +215,7 @@ where F: Float, Standard: Distribution<F> {
|
|||||||
F::one() + n.ln()
|
F::one() + n.ln()
|
||||||
};
|
};
|
||||||
debug_assert!(t > F::zero());
|
debug_assert!(t > F::zero());
|
||||||
Ok(Zipf {
|
Ok(Zipf { s, t, q })
|
||||||
s, t, q
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Inverse cumulative density function
|
/// Inverse cumulative density function
|
||||||
@ -221,7 +234,9 @@ where F: Float, Standard: Distribution<F> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<F> Distribution<F> for Zipf<F>
|
impl<F> Distribution<F> for Zipf<F>
|
||||||
where F: Float, Standard: Distribution<F>
|
where
|
||||||
|
F: Float,
|
||||||
|
Standard: Distribution<F>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
|
||||||
@ -246,9 +261,7 @@ where F: Float, Standard: Distribution<F>
|
|||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(
|
fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
|
||||||
distr: D, zero: F, expected: &[F],
|
|
||||||
) {
|
|
||||||
let mut rng = crate::test::rng(213);
|
let mut rng = crate::test::rng(213);
|
||||||
let mut buf = [zero; 4];
|
let mut buf = [zero; 4];
|
||||||
for x in &mut buf {
|
for x in &mut buf {
|
||||||
@ -293,12 +306,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn zeta_value_stability() {
|
fn zeta_value_stability() {
|
||||||
test_samples(Zeta::new(1.5).unwrap(), 0f32, &[
|
test_samples(Zeta::new(1.5).unwrap(), 0f32, &[1.0, 2.0, 1.0, 1.0]);
|
||||||
1.0, 2.0, 1.0, 1.0,
|
test_samples(Zeta::new(2.0).unwrap(), 0f64, &[2.0, 1.0, 1.0, 1.0]);
|
||||||
]);
|
|
||||||
test_samples(Zeta::new(2.0).unwrap(), 0f64, &[
|
|
||||||
2.0, 1.0, 1.0, 1.0,
|
|
||||||
]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
@ -363,12 +372,8 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn zipf_value_stability() {
|
fn zipf_value_stability() {
|
||||||
test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[
|
test_samples(Zipf::new(10, 0.5).unwrap(), 0f32, &[10.0, 2.0, 6.0, 7.0]);
|
||||||
10.0, 2.0, 6.0, 7.0
|
test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[1.0, 2.0, 3.0, 2.0]);
|
||||||
]);
|
|
||||||
test_samples(Zipf::new(10, 2.0).unwrap(), 0f64, &[
|
|
||||||
1.0, 2.0, 3.0, 2.0
|
|
||||||
]);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -57,7 +57,7 @@ fn normal() {
|
|||||||
|
|
||||||
let mut diff = [0.; HIST_LEN];
|
let mut diff = [0.; HIST_LEN];
|
||||||
for (i, n) in hist.normalized_bins().enumerate() {
|
for (i, n) in hist.normalized_bins().enumerate() {
|
||||||
let bin = (n as f64) / (N_SAMPLES as f64);
|
let bin = n / (N_SAMPLES as f64);
|
||||||
diff[i] = (bin - expected[i]).abs();
|
diff[i] = (bin - expected[i]).abs();
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ fn skew_normal() {
|
|||||||
|
|
||||||
let mut diff = [0.; HIST_LEN];
|
let mut diff = [0.; HIST_LEN];
|
||||||
for (i, n) in hist.normalized_bins().enumerate() {
|
for (i, n) in hist.normalized_bins().enumerate() {
|
||||||
let bin = (n as f64) / (N_SAMPLES as f64);
|
let bin = n / (N_SAMPLES as f64);
|
||||||
diff[i] = (bin - expected[i]).abs();
|
diff[i] = (bin - expected[i]).abs();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -16,7 +16,7 @@ pub fn render_u64(data: &[u64], buffer: &mut String) {
|
|||||||
match data.len() {
|
match data.len() {
|
||||||
0 => {
|
0 => {
|
||||||
return;
|
return;
|
||||||
},
|
}
|
||||||
1 => {
|
1 => {
|
||||||
if data[0] == 0 {
|
if data[0] == 0 {
|
||||||
buffer.push(TICKS[0]);
|
buffer.push(TICKS[0]);
|
||||||
@ -24,8 +24,8 @@ pub fn render_u64(data: &[u64], buffer: &mut String) {
|
|||||||
buffer.push(TICKS[N - 1]);
|
buffer.push(TICKS[N - 1]);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
},
|
}
|
||||||
_ => {},
|
_ => {}
|
||||||
}
|
}
|
||||||
let max = data.iter().max().unwrap();
|
let max = data.iter().max().unwrap();
|
||||||
let min = data.iter().min().unwrap();
|
let min = data.iter().min().unwrap();
|
||||||
@ -56,7 +56,7 @@ pub fn render_f64(data: &[f64], buffer: &mut String) {
|
|||||||
match data.len() {
|
match data.len() {
|
||||||
0 => {
|
0 => {
|
||||||
return;
|
return;
|
||||||
},
|
}
|
||||||
1 => {
|
1 => {
|
||||||
if data[0] == 0. {
|
if data[0] == 0. {
|
||||||
buffer.push(TICKS[0]);
|
buffer.push(TICKS[0]);
|
||||||
@ -64,16 +64,14 @@ pub fn render_f64(data: &[f64], buffer: &mut String) {
|
|||||||
buffer.push(TICKS[N - 1]);
|
buffer.push(TICKS[N - 1]);
|
||||||
}
|
}
|
||||||
return;
|
return;
|
||||||
},
|
}
|
||||||
_ => {},
|
_ => {}
|
||||||
}
|
}
|
||||||
for x in data {
|
for x in data {
|
||||||
assert!(x.is_finite(), "can only render finite values");
|
assert!(x.is_finite(), "can only render finite values");
|
||||||
}
|
}
|
||||||
let max = data.iter().fold(
|
let max = data.iter().fold(f64::NEG_INFINITY, |a, &b| a.max(b));
|
||||||
f64::NEG_INFINITY, |a, &b| a.max(b));
|
let min = data.iter().fold(f64::INFINITY, |a, &b| a.min(b));
|
||||||
let min = data.iter().fold(
|
|
||||||
f64::INFINITY, |a, &b| a.min(b));
|
|
||||||
let scale = ((N - 1) as f64) / (max - min);
|
let scale = ((N - 1) as f64) / (max - min);
|
||||||
for x in data {
|
for x in data {
|
||||||
let tick = ((x - min) * scale) as usize;
|
let tick = ((x - min) * scale) as usize;
|
||||||
|
@ -53,9 +53,7 @@ impl<T: ApproxEq> ApproxEq for [T; 3] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_samples<F: Debug + ApproxEq, D: Distribution<F>>(
|
fn test_samples<F: Debug + ApproxEq, D: Distribution<F>>(seed: u64, distr: D, expected: &[F]) {
|
||||||
seed: u64, distr: D, expected: &[F],
|
|
||||||
) {
|
|
||||||
let mut rng = get_rng(seed);
|
let mut rng = get_rng(seed);
|
||||||
for val in expected {
|
for val in expected {
|
||||||
let x = rng.sample(&distr);
|
let x = rng.sample(&distr);
|
||||||
@ -68,16 +66,28 @@ fn binomial_stability() {
|
|||||||
// We have multiple code paths: np < 10, p > 0.5
|
// We have multiple code paths: np < 10, p > 0.5
|
||||||
test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]);
|
test_samples(353, Binomial::new(2, 0.7).unwrap(), &[1, 1, 2, 1]);
|
||||||
test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]);
|
test_samples(353, Binomial::new(20, 0.3).unwrap(), &[7, 7, 5, 7]);
|
||||||
test_samples(353, Binomial::new(2000, 0.6).unwrap(), &[1194, 1208, 1192, 1210]);
|
test_samples(
|
||||||
|
353,
|
||||||
|
Binomial::new(2000, 0.6).unwrap(),
|
||||||
|
&[1194, 1208, 1192, 1210],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn geometric_stability() {
|
fn geometric_stability() {
|
||||||
test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]);
|
test_samples(464, StandardGeometric, &[3, 0, 1, 0, 0, 3, 2, 1, 2, 0]);
|
||||||
|
|
||||||
test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]);
|
test_samples(464, Geometric::new(0.5).unwrap(), &[2, 1, 1, 0, 0, 1, 0, 1]);
|
||||||
test_samples(464, Geometric::new(0.05).unwrap(), &[24, 51, 81, 67, 27, 11, 7, 6]);
|
test_samples(
|
||||||
test_samples(464, Geometric::new(0.95).unwrap(), &[0, 0, 0, 0, 1, 0, 0, 0]);
|
464,
|
||||||
|
Geometric::new(0.05).unwrap(),
|
||||||
|
&[24, 51, 81, 67, 27, 11, 7, 6],
|
||||||
|
);
|
||||||
|
test_samples(
|
||||||
|
464,
|
||||||
|
Geometric::new(0.95).unwrap(),
|
||||||
|
&[0, 0, 0, 0, 1, 0, 0, 0],
|
||||||
|
);
|
||||||
|
|
||||||
// expect non-random behaviour for series of pre-determined trials
|
// expect non-random behaviour for series of pre-determined trials
|
||||||
test_samples(464, Geometric::new(0.0).unwrap(), &[u64::MAX; 100][..]);
|
test_samples(464, Geometric::new(0.0).unwrap(), &[u64::MAX; 100][..]);
|
||||||
@ -87,260 +97,404 @@ fn geometric_stability() {
|
|||||||
#[test]
|
#[test]
|
||||||
fn hypergeometric_stability() {
|
fn hypergeometric_stability() {
|
||||||
// We have multiple code paths based on the distribution's mode and sample_size
|
// We have multiple code paths based on the distribution's mode and sample_size
|
||||||
test_samples(7221, Hypergeometric::new(99, 33, 8).unwrap(), &[4, 3, 2, 2, 3, 2, 3, 1]); // Algorithm HIN
|
test_samples(
|
||||||
test_samples(7221, Hypergeometric::new(100, 50, 50).unwrap(), &[23, 27, 26, 27, 22, 24, 31, 22]); // Algorithm H2PE
|
7221,
|
||||||
|
Hypergeometric::new(99, 33, 8).unwrap(),
|
||||||
|
&[4, 3, 2, 2, 3, 2, 3, 1],
|
||||||
|
); // Algorithm HIN
|
||||||
|
test_samples(
|
||||||
|
7221,
|
||||||
|
Hypergeometric::new(100, 50, 50).unwrap(),
|
||||||
|
&[23, 27, 26, 27, 22, 24, 31, 22],
|
||||||
|
); // Algorithm H2PE
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn unit_ball_stability() {
|
fn unit_ball_stability() {
|
||||||
test_samples(2, UnitBall, &[
|
test_samples(
|
||||||
[0.018035709265959987f64, -0.4348771383120438, -0.07982762085055706],
|
2,
|
||||||
[0.10588569388223945, -0.4734350111375454, -0.7392104908825501],
|
UnitBall,
|
||||||
[0.11060237642041049, -0.16065642822852677, -0.8444043930440075]
|
&[
|
||||||
]);
|
[
|
||||||
|
0.018035709265959987f64,
|
||||||
|
-0.4348771383120438,
|
||||||
|
-0.07982762085055706,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.10588569388223945,
|
||||||
|
-0.4734350111375454,
|
||||||
|
-0.7392104908825501,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
0.11060237642041049,
|
||||||
|
-0.16065642822852677,
|
||||||
|
-0.8444043930440075,
|
||||||
|
],
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn unit_circle_stability() {
|
fn unit_circle_stability() {
|
||||||
test_samples(2, UnitCircle, &[
|
test_samples(
|
||||||
[-0.9965658683520504f64, -0.08280380447614634],
|
2,
|
||||||
[-0.9790853270389644, -0.20345004884984505],
|
UnitCircle,
|
||||||
[-0.8449189758898707, 0.5348943112253227],
|
&[
|
||||||
]);
|
[-0.9965658683520504f64, -0.08280380447614634],
|
||||||
|
[-0.9790853270389644, -0.20345004884984505],
|
||||||
|
[-0.8449189758898707, 0.5348943112253227],
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn unit_sphere_stability() {
|
fn unit_sphere_stability() {
|
||||||
test_samples(2, UnitSphere, &[
|
test_samples(
|
||||||
[0.03247542860231647f64, -0.7830477442152738, 0.6211131755296027],
|
2,
|
||||||
[-0.09978440840914075, 0.9706650829833128, -0.21875184231323952],
|
UnitSphere,
|
||||||
[0.2735582468624679, 0.9435374242279655, -0.1868234852870203],
|
&[
|
||||||
]);
|
[
|
||||||
|
0.03247542860231647f64,
|
||||||
|
-0.7830477442152738,
|
||||||
|
0.6211131755296027,
|
||||||
|
],
|
||||||
|
[
|
||||||
|
-0.09978440840914075,
|
||||||
|
0.9706650829833128,
|
||||||
|
-0.21875184231323952,
|
||||||
|
],
|
||||||
|
[0.2735582468624679, 0.9435374242279655, -0.1868234852870203],
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn unit_disc_stability() {
|
fn unit_disc_stability() {
|
||||||
test_samples(2, UnitDisc, &[
|
test_samples(
|
||||||
[0.018035709265959987f64, -0.4348771383120438],
|
2,
|
||||||
[-0.07982762085055706, 0.7765329819820659],
|
UnitDisc,
|
||||||
[0.21450745997299503, 0.7398636984333291],
|
&[
|
||||||
]);
|
[0.018035709265959987f64, -0.4348771383120438],
|
||||||
|
[-0.07982762085055706, 0.7765329819820659],
|
||||||
|
[0.21450745997299503, 0.7398636984333291],
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pareto_stability() {
|
fn pareto_stability() {
|
||||||
test_samples(213, Pareto::new(1.0, 1.0).unwrap(), &[
|
test_samples(
|
||||||
1.0423688f32, 2.1235929, 4.132709, 1.4679428,
|
213,
|
||||||
]);
|
Pareto::new(1.0, 1.0).unwrap(),
|
||||||
test_samples(213, Pareto::new(2.0, 0.5).unwrap(), &[
|
&[1.0423688f32, 2.1235929, 4.132709, 1.4679428],
|
||||||
9.019295276219136f64,
|
);
|
||||||
4.3097126018270595,
|
test_samples(
|
||||||
6.837815045397157,
|
213,
|
||||||
105.8826669383772,
|
Pareto::new(2.0, 0.5).unwrap(),
|
||||||
]);
|
&[
|
||||||
|
9.019295276219136f64,
|
||||||
|
4.3097126018270595,
|
||||||
|
6.837815045397157,
|
||||||
|
105.8826669383772,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn poisson_stability() {
|
fn poisson_stability() {
|
||||||
test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]);
|
test_samples(223, Poisson::new(7.0).unwrap(), &[5.0f32, 11.0, 6.0, 5.0]);
|
||||||
test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]);
|
test_samples(223, Poisson::new(7.0).unwrap(), &[9.0f64, 5.0, 7.0, 6.0]);
|
||||||
test_samples(223, Poisson::new(27.0).unwrap(), &[28.0f32, 32.0, 36.0, 36.0]);
|
test_samples(
|
||||||
|
223,
|
||||||
|
Poisson::new(27.0).unwrap(),
|
||||||
|
&[28.0f32, 32.0, 36.0, 36.0],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn triangular_stability() {
|
fn triangular_stability() {
|
||||||
test_samples(860, Triangular::new(2., 10., 3.).unwrap(), &[
|
test_samples(
|
||||||
5.74373257511361f64,
|
860,
|
||||||
7.890059162791258f64,
|
Triangular::new(2., 10., 3.).unwrap(),
|
||||||
4.7256280652553455f64,
|
&[
|
||||||
2.9474808121184077f64,
|
5.74373257511361f64,
|
||||||
3.058301946314053f64,
|
7.890059162791258f64,
|
||||||
]);
|
4.7256280652553455f64,
|
||||||
|
2.9474808121184077f64,
|
||||||
|
3.058301946314053f64,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn normal_inverse_gaussian_stability() {
|
fn normal_inverse_gaussian_stability() {
|
||||||
test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[
|
test_samples(
|
||||||
0.6568966f32, 1.3744819, 2.216063, 0.11488572,
|
213,
|
||||||
]);
|
NormalInverseGaussian::new(2.0, 1.0).unwrap(),
|
||||||
test_samples(213, NormalInverseGaussian::new(2.0, 1.0).unwrap(), &[
|
&[0.6568966f32, 1.3744819, 2.216063, 0.11488572],
|
||||||
0.6838707059642927f64,
|
);
|
||||||
2.4447306460569784,
|
test_samples(
|
||||||
0.2361045023235968,
|
213,
|
||||||
1.7774534624785319,
|
NormalInverseGaussian::new(2.0, 1.0).unwrap(),
|
||||||
]);
|
&[
|
||||||
|
0.6838707059642927f64,
|
||||||
|
2.4447306460569784,
|
||||||
|
0.2361045023235968,
|
||||||
|
1.7774534624785319,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pert_stability() {
|
fn pert_stability() {
|
||||||
// mean = 4, var = 12/7
|
// mean = 4, var = 12/7
|
||||||
test_samples(860, Pert::new(2., 10., 3.).unwrap(), &[
|
test_samples(
|
||||||
4.908681667460367,
|
860,
|
||||||
4.014196196158352,
|
Pert::new(2., 10., 3.).unwrap(),
|
||||||
2.6489397149197234,
|
&[
|
||||||
3.4569780580044727,
|
4.908681667460367,
|
||||||
4.242864311947118,
|
4.014196196158352,
|
||||||
]);
|
2.6489397149197234,
|
||||||
|
3.4569780580044727,
|
||||||
|
4.242864311947118,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn inverse_gaussian_stability() {
|
fn inverse_gaussian_stability() {
|
||||||
test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(),&[
|
test_samples(
|
||||||
0.9339157f32, 1.108113, 0.50864697, 0.39849377,
|
213,
|
||||||
]);
|
InverseGaussian::new(1.0, 3.0).unwrap(),
|
||||||
test_samples(213, InverseGaussian::new(1.0, 3.0).unwrap(), &[
|
&[0.9339157f32, 1.108113, 0.50864697, 0.39849377],
|
||||||
1.0707604954722476f64,
|
);
|
||||||
0.9628140605340697,
|
test_samples(
|
||||||
0.4069687656468226,
|
213,
|
||||||
0.660283852985818,
|
InverseGaussian::new(1.0, 3.0).unwrap(),
|
||||||
]);
|
&[
|
||||||
|
1.0707604954722476f64,
|
||||||
|
0.9628140605340697,
|
||||||
|
0.4069687656468226,
|
||||||
|
0.660283852985818,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn gamma_stability() {
|
fn gamma_stability() {
|
||||||
// Gamma has 3 cases: shape == 1, shape < 1, shape > 1
|
// Gamma has 3 cases: shape == 1, shape < 1, shape > 1
|
||||||
test_samples(223, Gamma::new(1.0, 5.0).unwrap(), &[
|
test_samples(
|
||||||
5.398085f32, 9.162783, 0.2300583, 1.7235851,
|
223,
|
||||||
]);
|
Gamma::new(1.0, 5.0).unwrap(),
|
||||||
test_samples(223, Gamma::new(0.8, 5.0).unwrap(), &[
|
&[5.398085f32, 9.162783, 0.2300583, 1.7235851],
|
||||||
0.5051203f32, 0.9048302, 3.095812, 1.8566116,
|
);
|
||||||
]);
|
test_samples(
|
||||||
test_samples(223, Gamma::new(1.1, 5.0).unwrap(), &[
|
223,
|
||||||
7.783878094584059f64,
|
Gamma::new(0.8, 5.0).unwrap(),
|
||||||
1.4939528171618057,
|
&[0.5051203f32, 0.9048302, 3.095812, 1.8566116],
|
||||||
8.638017638857592,
|
);
|
||||||
3.0949337228829004,
|
test_samples(
|
||||||
]);
|
223,
|
||||||
|
Gamma::new(1.1, 5.0).unwrap(),
|
||||||
|
&[
|
||||||
|
7.783878094584059f64,
|
||||||
|
1.4939528171618057,
|
||||||
|
8.638017638857592,
|
||||||
|
3.0949337228829004,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
// ChiSquared has 2 cases: k == 1, k != 1
|
// ChiSquared has 2 cases: k == 1, k != 1
|
||||||
test_samples(223, ChiSquared::new(1.0).unwrap(), &[
|
test_samples(
|
||||||
0.4893526200348249f64,
|
223,
|
||||||
1.635249736808788,
|
ChiSquared::new(1.0).unwrap(),
|
||||||
0.5013580219361969,
|
&[
|
||||||
0.1457735613733489,
|
0.4893526200348249f64,
|
||||||
]);
|
1.635249736808788,
|
||||||
test_samples(223, ChiSquared::new(0.1).unwrap(), &[
|
0.5013580219361969,
|
||||||
0.014824404726978617f64,
|
0.1457735613733489,
|
||||||
0.021602123937134326,
|
],
|
||||||
0.0000003431429746851693,
|
);
|
||||||
0.00000002291755769542258,
|
test_samples(
|
||||||
]);
|
223,
|
||||||
test_samples(223, ChiSquared::new(10.0).unwrap(), &[
|
ChiSquared::new(0.1).unwrap(),
|
||||||
12.693656f32, 6.812016, 11.082001, 12.436167,
|
&[
|
||||||
]);
|
0.014824404726978617f64,
|
||||||
|
0.021602123937134326,
|
||||||
|
0.0000003431429746851693,
|
||||||
|
0.00000002291755769542258,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
test_samples(
|
||||||
|
223,
|
||||||
|
ChiSquared::new(10.0).unwrap(),
|
||||||
|
&[12.693656f32, 6.812016, 11.082001, 12.436167],
|
||||||
|
);
|
||||||
|
|
||||||
// FisherF has same special cases as ChiSquared on each param
|
// FisherF has same special cases as ChiSquared on each param
|
||||||
test_samples(223, FisherF::new(1.0, 13.5).unwrap(), &[
|
test_samples(
|
||||||
0.32283646f32, 0.048049655, 0.0788893, 1.817178,
|
223,
|
||||||
]);
|
FisherF::new(1.0, 13.5).unwrap(),
|
||||||
test_samples(223, FisherF::new(1.0, 1.0).unwrap(), &[
|
&[0.32283646f32, 0.048049655, 0.0788893, 1.817178],
|
||||||
0.29925257f32, 3.4392934, 9.567652, 0.020074,
|
);
|
||||||
]);
|
test_samples(
|
||||||
test_samples(223, FisherF::new(0.7, 13.5).unwrap(), &[
|
223,
|
||||||
3.3196593155045124f64,
|
FisherF::new(1.0, 1.0).unwrap(),
|
||||||
0.3409169916262829,
|
&[0.29925257f32, 3.4392934, 9.567652, 0.020074],
|
||||||
0.03377989856426519,
|
);
|
||||||
0.00004041672861036937,
|
test_samples(
|
||||||
]);
|
223,
|
||||||
|
FisherF::new(0.7, 13.5).unwrap(),
|
||||||
|
&[
|
||||||
|
3.3196593155045124f64,
|
||||||
|
0.3409169916262829,
|
||||||
|
0.03377989856426519,
|
||||||
|
0.00004041672861036937,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
// StudentT has same special cases as ChiSquared
|
// StudentT has same special cases as ChiSquared
|
||||||
test_samples(223, StudentT::new(1.0).unwrap(), &[
|
test_samples(
|
||||||
0.54703987f32, -1.8545331, 3.093162, -0.14168274,
|
223,
|
||||||
]);
|
StudentT::new(1.0).unwrap(),
|
||||||
test_samples(223, StudentT::new(1.1).unwrap(), &[
|
&[0.54703987f32, -1.8545331, 3.093162, -0.14168274],
|
||||||
0.7729195887949754f64,
|
);
|
||||||
1.2606210611616204,
|
test_samples(
|
||||||
-1.7553606501113175,
|
223,
|
||||||
-2.377641221169782,
|
StudentT::new(1.1).unwrap(),
|
||||||
]);
|
&[
|
||||||
|
0.7729195887949754f64,
|
||||||
|
1.2606210611616204,
|
||||||
|
-1.7553606501113175,
|
||||||
|
-2.377641221169782,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
// Beta has two special cases:
|
// Beta has two special cases:
|
||||||
//
|
//
|
||||||
// 1. min(alpha, beta) <= 1
|
// 1. min(alpha, beta) <= 1
|
||||||
// 2. min(alpha, beta) > 1
|
// 2. min(alpha, beta) > 1
|
||||||
test_samples(223, Beta::new(1.0, 0.8).unwrap(), &[
|
test_samples(
|
||||||
0.8300703726659456,
|
223,
|
||||||
0.8134131062097899,
|
Beta::new(1.0, 0.8).unwrap(),
|
||||||
0.47912589330631555,
|
&[
|
||||||
0.25323238071138526,
|
0.8300703726659456,
|
||||||
]);
|
0.8134131062097899,
|
||||||
test_samples(223, Beta::new(3.0, 1.2).unwrap(), &[
|
0.47912589330631555,
|
||||||
0.49563509121756827,
|
0.25323238071138526,
|
||||||
0.9551305482256759,
|
],
|
||||||
0.5151181353461637,
|
);
|
||||||
0.7551732971235077,
|
test_samples(
|
||||||
]);
|
223,
|
||||||
|
Beta::new(3.0, 1.2).unwrap(),
|
||||||
|
&[
|
||||||
|
0.49563509121756827,
|
||||||
|
0.9551305482256759,
|
||||||
|
0.5151181353461637,
|
||||||
|
0.7551732971235077,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn exponential_stability() {
|
fn exponential_stability() {
|
||||||
test_samples(223, Exp1, &[
|
test_samples(223, Exp1, &[1.079617f32, 1.8325565, 0.04601166, 0.34471703]);
|
||||||
1.079617f32, 1.8325565, 0.04601166, 0.34471703,
|
test_samples(
|
||||||
]);
|
223,
|
||||||
test_samples(223, Exp1, &[
|
Exp1,
|
||||||
1.0796170642388276f64,
|
&[
|
||||||
1.8325565304274,
|
1.0796170642388276f64,
|
||||||
0.04601166186842716,
|
1.8325565304274,
|
||||||
0.3447170217100157,
|
0.04601166186842716,
|
||||||
]);
|
0.3447170217100157,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
test_samples(223, Exp::new(2.0).unwrap(), &[
|
test_samples(
|
||||||
0.5398085f32, 0.91627824, 0.02300583, 0.17235851,
|
223,
|
||||||
]);
|
Exp::new(2.0).unwrap(),
|
||||||
test_samples(223, Exp::new(1.0).unwrap(), &[
|
&[0.5398085f32, 0.91627824, 0.02300583, 0.17235851],
|
||||||
1.0796170642388276f64,
|
);
|
||||||
1.8325565304274,
|
test_samples(
|
||||||
0.04601166186842716,
|
223,
|
||||||
0.3447170217100157,
|
Exp::new(1.0).unwrap(),
|
||||||
]);
|
&[
|
||||||
|
1.0796170642388276f64,
|
||||||
|
1.8325565304274,
|
||||||
|
0.04601166186842716,
|
||||||
|
0.3447170217100157,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn normal_stability() {
|
fn normal_stability() {
|
||||||
test_samples(213, StandardNormal, &[
|
test_samples(
|
||||||
-0.11844189f32, 0.781378, 0.06563994, -1.1932899,
|
213,
|
||||||
]);
|
StandardNormal,
|
||||||
test_samples(213, StandardNormal, &[
|
&[-0.11844189f32, 0.781378, 0.06563994, -1.1932899],
|
||||||
-0.11844188827977231f64,
|
);
|
||||||
0.7813779637772346,
|
test_samples(
|
||||||
0.06563993969580051,
|
213,
|
||||||
-1.1932899004186373,
|
StandardNormal,
|
||||||
]);
|
&[
|
||||||
|
-0.11844188827977231f64,
|
||||||
|
0.7813779637772346,
|
||||||
|
0.06563993969580051,
|
||||||
|
-1.1932899004186373,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
test_samples(213, Normal::new(0.0, 1.0).unwrap(), &[
|
test_samples(
|
||||||
-0.11844189f32, 0.781378, 0.06563994, -1.1932899,
|
213,
|
||||||
]);
|
Normal::new(0.0, 1.0).unwrap(),
|
||||||
test_samples(213, Normal::new(2.0, 0.5).unwrap(), &[
|
&[-0.11844189f32, 0.781378, 0.06563994, -1.1932899],
|
||||||
1.940779055860114f64,
|
);
|
||||||
2.3906889818886174,
|
test_samples(
|
||||||
2.0328199698479,
|
213,
|
||||||
1.4033550497906813,
|
Normal::new(2.0, 0.5).unwrap(),
|
||||||
]);
|
&[
|
||||||
|
1.940779055860114f64,
|
||||||
|
2.3906889818886174,
|
||||||
|
2.0328199698479,
|
||||||
|
1.4033550497906813,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
test_samples(213, LogNormal::new(0.0, 1.0).unwrap(), &[
|
test_samples(
|
||||||
0.88830346f32, 2.1844804, 1.0678421, 0.30322206,
|
213,
|
||||||
]);
|
LogNormal::new(0.0, 1.0).unwrap(),
|
||||||
test_samples(213, LogNormal::new(2.0, 0.5).unwrap(), &[
|
&[0.88830346f32, 2.1844804, 1.0678421, 0.30322206],
|
||||||
6.964174338639032f64,
|
);
|
||||||
10.921015733601452,
|
test_samples(
|
||||||
7.6355881556915906,
|
213,
|
||||||
4.068828213584092,
|
LogNormal::new(2.0, 0.5).unwrap(),
|
||||||
]);
|
&[
|
||||||
|
6.964174338639032f64,
|
||||||
|
10.921015733601452,
|
||||||
|
7.6355881556915906,
|
||||||
|
4.068828213584092,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn weibull_stability() {
|
fn weibull_stability() {
|
||||||
test_samples(213, Weibull::new(1.0, 1.0).unwrap(), &[
|
test_samples(
|
||||||
0.041495778f32, 0.7531094, 1.4189332, 0.38386202,
|
213,
|
||||||
]);
|
Weibull::new(1.0, 1.0).unwrap(),
|
||||||
test_samples(213, Weibull::new(2.0, 0.5).unwrap(), &[
|
&[0.041495778f32, 0.7531094, 1.4189332, 0.38386202],
|
||||||
1.1343478702739669f64,
|
);
|
||||||
0.29470010050655226,
|
test_samples(
|
||||||
0.7556151370284702,
|
213,
|
||||||
7.877212340241561,
|
Weibull::new(2.0, 0.5).unwrap(),
|
||||||
]);
|
&[
|
||||||
|
1.1343478702739669f64,
|
||||||
|
0.29470010050655226,
|
||||||
|
0.7556151370284702,
|
||||||
|
7.877212340241561,
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
@ -351,13 +505,16 @@ fn dirichlet_stability() {
|
|||||||
rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()),
|
rng.sample(Dirichlet::new([1.0, 2.0, 3.0]).unwrap()),
|
||||||
[0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
|
[0.12941567177708177, 0.4702121891675036, 0.4003721390554146]
|
||||||
);
|
);
|
||||||
assert_eq!(rng.sample(Dirichlet::new([8.0; 5]).unwrap()), [
|
assert_eq!(
|
||||||
0.17684200044809556,
|
rng.sample(Dirichlet::new([8.0; 5]).unwrap()),
|
||||||
0.29915953935953055,
|
[
|
||||||
0.1832858056608014,
|
0.17684200044809556,
|
||||||
0.1425623503573967,
|
0.29915953935953055,
|
||||||
0.19815030417417595
|
0.1832858056608014,
|
||||||
]);
|
0.1425623503573967,
|
||||||
|
0.19815030417417595
|
||||||
|
]
|
||||||
|
);
|
||||||
// Test stability for the case where all alphas are less than 0.1.
|
// Test stability for the case where all alphas are less than 0.1.
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()),
|
rng.sample(Dirichlet::new([0.05, 0.025, 0.075, 0.05]).unwrap()),
|
||||||
@ -372,12 +529,16 @@ fn dirichlet_stability() {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn cauchy_stability() {
|
fn cauchy_stability() {
|
||||||
test_samples(353, Cauchy::new(100f64, 10.0).unwrap(), &[
|
test_samples(
|
||||||
77.93369152808678f64,
|
353,
|
||||||
90.1606912098641,
|
Cauchy::new(100f64, 10.0).unwrap(),
|
||||||
125.31516221323625,
|
&[
|
||||||
86.10217834773925,
|
77.93369152808678f64,
|
||||||
]);
|
90.1606912098641,
|
||||||
|
125.31516221323625,
|
||||||
|
86.10217834773925,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
// Unfortunately this test is not fully portable due to reliance on the
|
// Unfortunately this test is not fully portable due to reliance on the
|
||||||
// system's implementation of tanf (see doc on Cauchy struct).
|
// system's implementation of tanf (see doc on Cauchy struct).
|
||||||
@ -386,7 +547,7 @@ fn cauchy_stability() {
|
|||||||
let mut rng = get_rng(353);
|
let mut rng = get_rng(353);
|
||||||
let expected = [15.023088, -5.446413, 3.7092876, 3.112482];
|
let expected = [15.023088, -5.446413, 3.7092876, 3.112482];
|
||||||
for &a in expected.iter() {
|
for &a in expected.iter() {
|
||||||
let b = rng.sample(&distr);
|
let b = rng.sample(distr);
|
||||||
assert_almost_eq!(a, b, 1e-5);
|
assert_almost_eq!(a, b, 1e-5);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -15,7 +15,8 @@ const MULTIPLIER: u128 = 0x2360_ED05_1FC6_5DA4_4385_DF64_9FCC_F645;
|
|||||||
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
use rand_core::{impls, le, RngCore, SeedableRng};
|
use rand_core::{impls, le, RngCore, SeedableRng};
|
||||||
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A PCG random number generator (XSL RR 128/64 (LCG) variant).
|
/// A PCG random number generator (XSL RR 128/64 (LCG) variant).
|
||||||
///
|
///
|
||||||
@ -153,7 +154,6 @@ impl RngCore for Lcg128Xsl64 {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// A PCG random number generator (XSL 128/64 (MCG) variant).
|
/// A PCG random number generator (XSL 128/64 (MCG) variant).
|
||||||
///
|
///
|
||||||
/// Permuted Congruential Generator with 128-bit state, internal Multiplicative
|
/// Permuted Congruential Generator with 128-bit state, internal Multiplicative
|
||||||
|
@ -15,7 +15,8 @@ const MULTIPLIER: u64 = 15750249268501108917;
|
|||||||
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
use rand_core::{impls, le, RngCore, SeedableRng};
|
use rand_core::{impls, le, RngCore, SeedableRng};
|
||||||
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A PCG random number generator (CM DXSM 128/64 (LCG) variant).
|
/// A PCG random number generator (CM DXSM 128/64 (LCG) variant).
|
||||||
///
|
///
|
||||||
|
@ -12,7 +12,8 @@
|
|||||||
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
use rand_core::{impls, le, RngCore, SeedableRng};
|
use rand_core::{impls, le, RngCore, SeedableRng};
|
||||||
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
// This is the default multiplier used by PCG for 64-bit state.
|
// This is the default multiplier used by PCG for 64-bit state.
|
||||||
const MULTIPLIER: u64 = 6364136223846793005;
|
const MULTIPLIER: u64 = 6364136223846793005;
|
||||||
|
32
rustfmt.toml
32
rustfmt.toml
@ -1,32 +0,0 @@
|
|||||||
# This rustfmt file is added for configuration, but in practice much of our
|
|
||||||
# code is hand-formatted, frequently with more readable results.
|
|
||||||
|
|
||||||
# Comments:
|
|
||||||
normalize_comments = true
|
|
||||||
wrap_comments = false
|
|
||||||
comment_width = 90 # small excess is okay but prefer 80
|
|
||||||
|
|
||||||
# Arguments:
|
|
||||||
use_small_heuristics = "Default"
|
|
||||||
# TODO: single line functions only where short, please?
|
|
||||||
# https://github.com/rust-lang/rustfmt/issues/3358
|
|
||||||
fn_single_line = false
|
|
||||||
fn_args_layout = "Compressed"
|
|
||||||
overflow_delimited_expr = true
|
|
||||||
where_single_line = true
|
|
||||||
|
|
||||||
# enum_discrim_align_threshold = 20
|
|
||||||
# struct_field_align_threshold = 20
|
|
||||||
|
|
||||||
# Compatibility:
|
|
||||||
edition = "2021"
|
|
||||||
|
|
||||||
# Misc:
|
|
||||||
inline_attribute_width = 80
|
|
||||||
blank_lines_upper_bound = 2
|
|
||||||
reorder_impl_items = true
|
|
||||||
# report_todo = "Unnumbered"
|
|
||||||
# report_fixme = "Unnumbered"
|
|
||||||
|
|
||||||
# Ignored files:
|
|
||||||
ignore = []
|
|
@ -13,7 +13,7 @@ use crate::Rng;
|
|||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// The Bernoulli distribution.
|
/// The Bernoulli distribution.
|
||||||
///
|
///
|
||||||
@ -151,7 +151,8 @@ mod test {
|
|||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
fn test_serializing_deserializing_bernoulli() {
|
fn test_serializing_deserializing_bernoulli() {
|
||||||
let coin_flip = Bernoulli::new(0.5).unwrap();
|
let coin_flip = Bernoulli::new(0.5).unwrap();
|
||||||
let de_coin_flip: Bernoulli = bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
|
let de_coin_flip: Bernoulli =
|
||||||
|
bincode::deserialize(&bincode::serialize(&coin_flip).unwrap()).unwrap();
|
||||||
|
|
||||||
assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
|
assert_eq!(coin_flip.p_int, de_coin_flip.p_int);
|
||||||
}
|
}
|
||||||
@ -208,9 +209,10 @@ mod test {
|
|||||||
for x in &mut buf {
|
for x in &mut buf {
|
||||||
*x = rng.sample(distr);
|
*x = rng.sample(distr);
|
||||||
}
|
}
|
||||||
assert_eq!(buf, [
|
assert_eq!(
|
||||||
true, false, false, true, false, false, true, true, true, true
|
buf,
|
||||||
]);
|
[true, false, false, true, false, false, true, true, true, true]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
@ -10,7 +10,8 @@
|
|||||||
//! Distribution trait and associates
|
//! Distribution trait and associates
|
||||||
|
|
||||||
use crate::Rng;
|
use crate::Rng;
|
||||||
#[cfg(feature = "alloc")] use alloc::string::String;
|
#[cfg(feature = "alloc")]
|
||||||
|
use alloc::string::String;
|
||||||
use core::iter;
|
use core::iter;
|
||||||
|
|
||||||
/// Types (distributions) that can be used to create a random instance of `T`.
|
/// Types (distributions) that can be used to create a random instance of `T`.
|
||||||
|
@ -8,14 +8,15 @@
|
|||||||
|
|
||||||
//! Basic floating-point number distributions
|
//! Basic floating-point number distributions
|
||||||
|
|
||||||
use crate::distributions::utils::{IntAsSIMD, FloatAsSIMD, FloatSIMDUtils};
|
use crate::distributions::utils::{FloatAsSIMD, FloatSIMDUtils, IntAsSIMD};
|
||||||
use crate::distributions::{Distribution, Standard};
|
use crate::distributions::{Distribution, Standard};
|
||||||
use crate::Rng;
|
use crate::Rng;
|
||||||
use core::mem;
|
use core::mem;
|
||||||
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
|
#[cfg(feature = "simd_support")]
|
||||||
|
use core::simd::prelude::*;
|
||||||
|
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A distribution to sample floating point numbers uniformly in the half-open
|
/// A distribution to sample floating point numbers uniformly in the half-open
|
||||||
/// interval `(0, 1]`, i.e. including 1 but not 0.
|
/// interval `(0, 1]`, i.e. including 1 but not 0.
|
||||||
@ -72,7 +73,6 @@ pub struct OpenClosed01;
|
|||||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||||
pub struct Open01;
|
pub struct Open01;
|
||||||
|
|
||||||
|
|
||||||
// This trait is needed by both this lib and rand_distr hence is a hidden export
|
// This trait is needed by both this lib and rand_distr hence is a hidden export
|
||||||
#[doc(hidden)]
|
#[doc(hidden)]
|
||||||
pub trait IntoFloat {
|
pub trait IntoFloat {
|
||||||
@ -146,12 +146,11 @@ macro_rules! float_impls {
|
|||||||
// Transmute-based method; 23/52 random bits; (0, 1) interval.
|
// Transmute-based method; 23/52 random bits; (0, 1) interval.
|
||||||
// We use the most significant bits because for simple RNGs
|
// We use the most significant bits because for simple RNGs
|
||||||
// those are usually more random.
|
// those are usually more random.
|
||||||
use core::$f_scalar::EPSILON;
|
|
||||||
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
|
let float_size = mem::size_of::<$f_scalar>() as $u_scalar * 8;
|
||||||
|
|
||||||
let value: $uty = rng.random();
|
let value: $uty = rng.random();
|
||||||
let fraction = value >> $uty::splat(float_size - $fraction_bits);
|
let fraction = value >> $uty::splat(float_size - $fraction_bits);
|
||||||
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - EPSILON / 2.0)
|
fraction.into_float_with_exponent(0) - $ty::splat(1.0 - $f_scalar::EPSILON / 2.0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -210,9 +209,15 @@ mod tests {
|
|||||||
let mut zeros = StepRng::new(0, 0);
|
let mut zeros = StepRng::new(0, 0);
|
||||||
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
|
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
|
||||||
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
|
let mut one = StepRng::new(1 << 9 | 1 << (9 + 32), 0);
|
||||||
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
|
assert_eq!(
|
||||||
|
one.sample::<$ty, _>(Open01),
|
||||||
|
$EPSILON / two * $ty::splat(3.0)
|
||||||
|
);
|
||||||
let mut max = StepRng::new(!0, 0);
|
let mut max = StepRng::new(!0, 0);
|
||||||
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
|
assert_eq!(
|
||||||
|
max.sample::<$ty, _>(Open01),
|
||||||
|
$ty::splat(1.0) - $EPSILON / two
|
||||||
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -252,9 +257,15 @@ mod tests {
|
|||||||
let mut zeros = StepRng::new(0, 0);
|
let mut zeros = StepRng::new(0, 0);
|
||||||
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
|
assert_eq!(zeros.sample::<$ty, _>(Open01), $ZERO + $EPSILON / two);
|
||||||
let mut one = StepRng::new(1 << 12, 0);
|
let mut one = StepRng::new(1 << 12, 0);
|
||||||
assert_eq!(one.sample::<$ty, _>(Open01), $EPSILON / two * $ty::splat(3.0));
|
assert_eq!(
|
||||||
|
one.sample::<$ty, _>(Open01),
|
||||||
|
$EPSILON / two * $ty::splat(3.0)
|
||||||
|
);
|
||||||
let mut max = StepRng::new(!0, 0);
|
let mut max = StepRng::new(!0, 0);
|
||||||
assert_eq!(max.sample::<$ty, _>(Open01), $ty::splat(1.0) - $EPSILON / two);
|
assert_eq!(
|
||||||
|
max.sample::<$ty, _>(Open01),
|
||||||
|
$ty::splat(1.0) - $EPSILON / two
|
||||||
|
);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
@ -269,7 +280,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>(
|
fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>(
|
||||||
distr: &D, zero: T, expected: &[T],
|
distr: &D,
|
||||||
|
zero: T,
|
||||||
|
expected: &[T],
|
||||||
) {
|
) {
|
||||||
let mut rng = crate::test::rng(0x6f44f5646c2a7334);
|
let mut rng = crate::test::rng(0x6f44f5646c2a7334);
|
||||||
let mut buf = [zero; 3];
|
let mut buf = [zero; 3];
|
||||||
@ -280,25 +293,25 @@ mod tests {
|
|||||||
}
|
}
|
||||||
|
|
||||||
test_samples(&Standard, 0f32, &[0.0035963655, 0.7346052, 0.09778172]);
|
test_samples(&Standard, 0f32, &[0.0035963655, 0.7346052, 0.09778172]);
|
||||||
test_samples(&Standard, 0f64, &[
|
test_samples(
|
||||||
0.7346051961657583,
|
&Standard,
|
||||||
0.20298547462974248,
|
0f64,
|
||||||
0.8166436635290655,
|
&[0.7346051961657583, 0.20298547462974248, 0.8166436635290655],
|
||||||
]);
|
);
|
||||||
|
|
||||||
test_samples(&OpenClosed01, 0f32, &[0.003596425, 0.73460525, 0.09778178]);
|
test_samples(&OpenClosed01, 0f32, &[0.003596425, 0.73460525, 0.09778178]);
|
||||||
test_samples(&OpenClosed01, 0f64, &[
|
test_samples(
|
||||||
0.7346051961657584,
|
&OpenClosed01,
|
||||||
0.2029854746297426,
|
0f64,
|
||||||
0.8166436635290656,
|
&[0.7346051961657584, 0.2029854746297426, 0.8166436635290656],
|
||||||
]);
|
);
|
||||||
|
|
||||||
test_samples(&Open01, 0f32, &[0.0035963655, 0.73460525, 0.09778172]);
|
test_samples(&Open01, 0f32, &[0.0035963655, 0.73460525, 0.09778172]);
|
||||||
test_samples(&Open01, 0f64, &[
|
test_samples(
|
||||||
0.7346051961657584,
|
&Open01,
|
||||||
0.20298547462974248,
|
0f64,
|
||||||
0.8166436635290656,
|
&[0.7346051961657584, 0.20298547462974248, 0.8166436635290656],
|
||||||
]);
|
);
|
||||||
|
|
||||||
#[cfg(feature = "simd_support")]
|
#[cfg(feature = "simd_support")]
|
||||||
{
|
{
|
||||||
@ -306,17 +319,25 @@ mod tests {
|
|||||||
// non-SIMD types; we assume this pattern continues across all
|
// non-SIMD types; we assume this pattern continues across all
|
||||||
// SIMD types.
|
// SIMD types.
|
||||||
|
|
||||||
test_samples(&Standard, f32x2::from([0.0, 0.0]), &[
|
test_samples(
|
||||||
f32x2::from([0.0035963655, 0.7346052]),
|
&Standard,
|
||||||
f32x2::from([0.09778172, 0.20298547]),
|
f32x2::from([0.0, 0.0]),
|
||||||
f32x2::from([0.34296435, 0.81664366]),
|
&[
|
||||||
]);
|
f32x2::from([0.0035963655, 0.7346052]),
|
||||||
|
f32x2::from([0.09778172, 0.20298547]),
|
||||||
|
f32x2::from([0.34296435, 0.81664366]),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
test_samples(&Standard, f64x2::from([0.0, 0.0]), &[
|
test_samples(
|
||||||
f64x2::from([0.7346051961657583, 0.20298547462974248]),
|
&Standard,
|
||||||
f64x2::from([0.8166436635290655, 0.7423708925400552]),
|
f64x2::from([0.0, 0.0]),
|
||||||
f64x2::from([0.16387782224016323, 0.9087068770169618]),
|
&[
|
||||||
]);
|
f64x2::from([0.7346051961657583, 0.20298547462974248]),
|
||||||
|
f64x2::from([0.8166436635290655, 0.7423708925400552]),
|
||||||
|
f64x2::from([0.16387782224016323, 0.9087068770169618]),
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -19,10 +19,11 @@ use core::arch::x86_64::__m512i;
|
|||||||
#[cfg(target_arch = "x86_64")]
|
#[cfg(target_arch = "x86_64")]
|
||||||
use core::arch::x86_64::{__m128i, __m256i};
|
use core::arch::x86_64::{__m128i, __m256i};
|
||||||
use core::num::{
|
use core::num::{
|
||||||
NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize,NonZeroU128,
|
NonZeroI128, NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroIsize, NonZeroU128,
|
||||||
NonZeroI16, NonZeroI32, NonZeroI64, NonZeroI8, NonZeroIsize,NonZeroI128
|
NonZeroU16, NonZeroU32, NonZeroU64, NonZeroU8, NonZeroUsize,
|
||||||
};
|
};
|
||||||
#[cfg(feature = "simd_support")] use core::simd::*;
|
#[cfg(feature = "simd_support")]
|
||||||
|
use core::simd::*;
|
||||||
|
|
||||||
impl Distribution<u8> for Standard {
|
impl Distribution<u8> for Standard {
|
||||||
#[inline]
|
#[inline]
|
||||||
@ -211,7 +212,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn test_samples<T: Copy + core::fmt::Debug + PartialEq>(zero: T, expected: &[T])
|
fn test_samples<T: Copy + core::fmt::Debug + PartialEq>(zero: T, expected: &[T])
|
||||||
where Standard: Distribution<T> {
|
where
|
||||||
|
Standard: Distribution<T>,
|
||||||
|
{
|
||||||
let mut rng = crate::test::rng(807);
|
let mut rng = crate::test::rng(807);
|
||||||
let mut buf = [zero; 3];
|
let mut buf = [zero; 3];
|
||||||
for x in &mut buf {
|
for x in &mut buf {
|
||||||
@ -223,24 +226,33 @@ mod tests {
|
|||||||
test_samples(0u8, &[9, 247, 111]);
|
test_samples(0u8, &[9, 247, 111]);
|
||||||
test_samples(0u16, &[32265, 42999, 38255]);
|
test_samples(0u16, &[32265, 42999, 38255]);
|
||||||
test_samples(0u32, &[2220326409, 2575017975, 2018088303]);
|
test_samples(0u32, &[2220326409, 2575017975, 2018088303]);
|
||||||
test_samples(0u64, &[
|
test_samples(
|
||||||
11059617991457472009,
|
0u64,
|
||||||
16096616328739788143,
|
&[
|
||||||
1487364411147516184,
|
11059617991457472009,
|
||||||
]);
|
16096616328739788143,
|
||||||
test_samples(0u128, &[
|
1487364411147516184,
|
||||||
296930161868957086625409848350820761097,
|
],
|
||||||
145644820879247630242265036535529306392,
|
);
|
||||||
111087889832015897993126088499035356354,
|
test_samples(
|
||||||
]);
|
0u128,
|
||||||
|
&[
|
||||||
|
296930161868957086625409848350820761097,
|
||||||
|
145644820879247630242265036535529306392,
|
||||||
|
111087889832015897993126088499035356354,
|
||||||
|
],
|
||||||
|
);
|
||||||
#[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))]
|
#[cfg(any(target_pointer_width = "32", target_pointer_width = "16"))]
|
||||||
test_samples(0usize, &[2220326409, 2575017975, 2018088303]);
|
test_samples(0usize, &[2220326409, 2575017975, 2018088303]);
|
||||||
#[cfg(target_pointer_width = "64")]
|
#[cfg(target_pointer_width = "64")]
|
||||||
test_samples(0usize, &[
|
test_samples(
|
||||||
11059617991457472009,
|
0usize,
|
||||||
16096616328739788143,
|
&[
|
||||||
1487364411147516184,
|
11059617991457472009,
|
||||||
]);
|
16096616328739788143,
|
||||||
|
1487364411147516184,
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
test_samples(0i8, &[9, -9, 111]);
|
test_samples(0i8, &[9, -9, 111]);
|
||||||
// Skip further i* types: they are simple reinterpretation of u* samples
|
// Skip further i* types: they are simple reinterpretation of u* samples
|
||||||
@ -249,49 +261,58 @@ mod tests {
|
|||||||
{
|
{
|
||||||
// We only test a sub-set of types here and make assumptions about the rest.
|
// We only test a sub-set of types here and make assumptions about the rest.
|
||||||
|
|
||||||
test_samples(u8x4::default(), &[
|
test_samples(
|
||||||
u8x4::from([9, 126, 87, 132]),
|
u8x4::default(),
|
||||||
u8x4::from([247, 167, 123, 153]),
|
&[
|
||||||
u8x4::from([111, 149, 73, 120]),
|
u8x4::from([9, 126, 87, 132]),
|
||||||
]);
|
u8x4::from([247, 167, 123, 153]),
|
||||||
test_samples(u8x8::default(), &[
|
u8x4::from([111, 149, 73, 120]),
|
||||||
u8x8::from([9, 126, 87, 132, 247, 167, 123, 153]),
|
],
|
||||||
u8x8::from([111, 149, 73, 120, 68, 171, 98, 223]),
|
);
|
||||||
u8x8::from([24, 121, 1, 50, 13, 46, 164, 20]),
|
test_samples(
|
||||||
]);
|
u8x8::default(),
|
||||||
|
&[
|
||||||
|
u8x8::from([9, 126, 87, 132, 247, 167, 123, 153]),
|
||||||
|
u8x8::from([111, 149, 73, 120, 68, 171, 98, 223]),
|
||||||
|
u8x8::from([24, 121, 1, 50, 13, 46, 164, 20]),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
test_samples(i64x8::default(), &[
|
test_samples(
|
||||||
i64x8::from([
|
i64x8::default(),
|
||||||
-7387126082252079607,
|
&[
|
||||||
-2350127744969763473,
|
i64x8::from([
|
||||||
1487364411147516184,
|
-7387126082252079607,
|
||||||
7895421560427121838,
|
-2350127744969763473,
|
||||||
602190064936008898,
|
1487364411147516184,
|
||||||
6022086574635100741,
|
7895421560427121838,
|
||||||
-5080089175222015595,
|
602190064936008898,
|
||||||
-4066367846667249123,
|
6022086574635100741,
|
||||||
]),
|
-5080089175222015595,
|
||||||
i64x8::from([
|
-4066367846667249123,
|
||||||
9180885022207963908,
|
]),
|
||||||
3095981199532211089,
|
i64x8::from([
|
||||||
6586075293021332726,
|
9180885022207963908,
|
||||||
419343203796414657,
|
3095981199532211089,
|
||||||
3186951873057035255,
|
6586075293021332726,
|
||||||
5287129228749947252,
|
419343203796414657,
|
||||||
444726432079249540,
|
3186951873057035255,
|
||||||
-1587028029513790706,
|
5287129228749947252,
|
||||||
]),
|
444726432079249540,
|
||||||
i64x8::from([
|
-1587028029513790706,
|
||||||
6075236523189346388,
|
]),
|
||||||
1351763722368165432,
|
i64x8::from([
|
||||||
-6192309979959753740,
|
6075236523189346388,
|
||||||
-7697775502176768592,
|
1351763722368165432,
|
||||||
-4482022114172078123,
|
-6192309979959753740,
|
||||||
7522501477800909500,
|
-7697775502176768592,
|
||||||
-1837258847956201231,
|
-4482022114172078123,
|
||||||
-586926753024886735,
|
7522501477800909500,
|
||||||
]),
|
-1837258847956201231,
|
||||||
]);
|
-586926753024886735,
|
||||||
|
]),
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -110,10 +110,10 @@ pub mod hidden_export {
|
|||||||
pub mod uniform;
|
pub mod uniform;
|
||||||
|
|
||||||
pub use self::bernoulli::{Bernoulli, BernoulliError};
|
pub use self::bernoulli::{Bernoulli, BernoulliError};
|
||||||
pub use self::distribution::{Distribution, DistIter, DistMap};
|
|
||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
||||||
pub use self::distribution::DistString;
|
pub use self::distribution::DistString;
|
||||||
|
pub use self::distribution::{DistIter, DistMap, Distribution};
|
||||||
pub use self::float::{Open01, OpenClosed01};
|
pub use self::float::{Open01, OpenClosed01};
|
||||||
pub use self::other::Alphanumeric;
|
pub use self::other::Alphanumeric;
|
||||||
pub use self::slice::Slice;
|
pub use self::slice::Slice;
|
||||||
|
@ -8,24 +8,23 @@
|
|||||||
|
|
||||||
//! The implementations of the `Standard` distribution for other built-in types.
|
//! The implementations of the `Standard` distribution for other built-in types.
|
||||||
|
|
||||||
use core::char;
|
|
||||||
use core::num::Wrapping;
|
|
||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
use alloc::string::String;
|
use alloc::string::String;
|
||||||
|
use core::char;
|
||||||
|
use core::num::Wrapping;
|
||||||
|
|
||||||
use crate::distributions::{Distribution, Standard, Uniform};
|
|
||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
use crate::distributions::DistString;
|
use crate::distributions::DistString;
|
||||||
|
use crate::distributions::{Distribution, Standard, Uniform};
|
||||||
use crate::Rng;
|
use crate::Rng;
|
||||||
|
|
||||||
#[cfg(feature = "serde1")]
|
|
||||||
use serde::{Serialize, Deserialize};
|
|
||||||
use core::mem::{self, MaybeUninit};
|
use core::mem::{self, MaybeUninit};
|
||||||
#[cfg(feature = "simd_support")]
|
#[cfg(feature = "simd_support")]
|
||||||
use core::simd::prelude::*;
|
use core::simd::prelude::*;
|
||||||
#[cfg(feature = "simd_support")]
|
#[cfg(feature = "simd_support")]
|
||||||
use core::simd::{LaneCount, MaskElement, SupportedLaneCount};
|
use core::simd::{LaneCount, MaskElement, SupportedLaneCount};
|
||||||
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
// ----- Sampling distributions -----
|
// ----- Sampling distributions -----
|
||||||
|
|
||||||
@ -71,7 +70,6 @@ use core::simd::{LaneCount, MaskElement, SupportedLaneCount};
|
|||||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||||
pub struct Alphanumeric;
|
pub struct Alphanumeric;
|
||||||
|
|
||||||
|
|
||||||
// ----- Implementations of distributions -----
|
// ----- Implementations of distributions -----
|
||||||
|
|
||||||
impl Distribution<char> for Standard {
|
impl Distribution<char> for Standard {
|
||||||
@ -240,7 +238,8 @@ macro_rules! tuple_impls {
|
|||||||
tuple_impls! {A B C D E F G H I J K L}
|
tuple_impls! {A B C D E F G H I J K L}
|
||||||
|
|
||||||
impl<T, const N: usize> Distribution<[T; N]> for Standard
|
impl<T, const N: usize> Distribution<[T; N]> for Standard
|
||||||
where Standard: Distribution<T>
|
where
|
||||||
|
Standard: Distribution<T>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> [T; N] {
|
fn sample<R: Rng + ?Sized>(&self, _rng: &mut R) -> [T; N] {
|
||||||
@ -255,7 +254,8 @@ where Standard: Distribution<T>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Distribution<Option<T>> for Standard
|
impl<T> Distribution<Option<T>> for Standard
|
||||||
where Standard: Distribution<T>
|
where
|
||||||
|
Standard: Distribution<T>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<T> {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Option<T> {
|
||||||
@ -269,7 +269,8 @@ where Standard: Distribution<T>
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<T> Distribution<Wrapping<T>> for Standard
|
impl<T> Distribution<Wrapping<T>> for Standard
|
||||||
where Standard: Distribution<T>
|
where
|
||||||
|
Standard: Distribution<T>,
|
||||||
{
|
{
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Wrapping<T> {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Wrapping<T> {
|
||||||
@ -277,7 +278,6 @@ where Standard: Distribution<T>
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
@ -315,9 +315,7 @@ mod tests {
|
|||||||
let mut incorrect = false;
|
let mut incorrect = false;
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
let c: char = rng.sample(Alphanumeric).into();
|
let c: char = rng.sample(Alphanumeric).into();
|
||||||
incorrect |= !(('0'..='9').contains(&c) ||
|
incorrect |= !c.is_ascii_alphanumeric();
|
||||||
('A'..='Z').contains(&c) ||
|
|
||||||
('a'..='z').contains(&c) );
|
|
||||||
}
|
}
|
||||||
assert!(!incorrect);
|
assert!(!incorrect);
|
||||||
}
|
}
|
||||||
@ -325,7 +323,9 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>(
|
fn test_samples<T: Copy + core::fmt::Debug + PartialEq, D: Distribution<T>>(
|
||||||
distr: &D, zero: T, expected: &[T],
|
distr: &D,
|
||||||
|
zero: T,
|
||||||
|
expected: &[T],
|
||||||
) {
|
) {
|
||||||
let mut rng = crate::test::rng(807);
|
let mut rng = crate::test::rng(807);
|
||||||
let mut buf = [zero; 5];
|
let mut buf = [zero; 5];
|
||||||
@ -335,54 +335,66 @@ mod tests {
|
|||||||
assert_eq!(&buf, expected);
|
assert_eq!(&buf, expected);
|
||||||
}
|
}
|
||||||
|
|
||||||
test_samples(&Standard, 'a', &[
|
test_samples(
|
||||||
'\u{8cdac}',
|
&Standard,
|
||||||
'\u{a346a}',
|
'a',
|
||||||
'\u{80120}',
|
&[
|
||||||
'\u{ed692}',
|
'\u{8cdac}',
|
||||||
'\u{35888}',
|
'\u{a346a}',
|
||||||
]);
|
'\u{80120}',
|
||||||
|
'\u{ed692}',
|
||||||
|
'\u{35888}',
|
||||||
|
],
|
||||||
|
);
|
||||||
test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]);
|
test_samples(&Alphanumeric, 0, &[104, 109, 101, 51, 77]);
|
||||||
test_samples(&Standard, false, &[true, true, false, true, false]);
|
test_samples(&Standard, false, &[true, true, false, true, false]);
|
||||||
test_samples(&Standard, None as Option<bool>, &[
|
test_samples(
|
||||||
Some(true),
|
&Standard,
|
||||||
None,
|
None as Option<bool>,
|
||||||
Some(false),
|
&[Some(true), None, Some(false), None, Some(false)],
|
||||||
None,
|
);
|
||||||
Some(false),
|
test_samples(
|
||||||
]);
|
&Standard,
|
||||||
test_samples(&Standard, Wrapping(0i32), &[
|
Wrapping(0i32),
|
||||||
Wrapping(-2074640887),
|
&[
|
||||||
Wrapping(-1719949321),
|
Wrapping(-2074640887),
|
||||||
Wrapping(2018088303),
|
Wrapping(-1719949321),
|
||||||
Wrapping(-547181756),
|
Wrapping(2018088303),
|
||||||
Wrapping(838957336),
|
Wrapping(-547181756),
|
||||||
]);
|
Wrapping(838957336),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
// We test only sub-sets of tuple and array impls
|
// We test only sub-sets of tuple and array impls
|
||||||
test_samples(&Standard, (), &[(), (), (), (), ()]);
|
test_samples(&Standard, (), &[(), (), (), (), ()]);
|
||||||
test_samples(&Standard, (false,), &[
|
test_samples(
|
||||||
(true,),
|
&Standard,
|
||||||
(true,),
|
|
||||||
(false,),
|
(false,),
|
||||||
(true,),
|
&[(true,), (true,), (false,), (true,), (false,)],
|
||||||
(false,),
|
);
|
||||||
]);
|
test_samples(
|
||||||
test_samples(&Standard, (false, false), &[
|
&Standard,
|
||||||
(true, true),
|
|
||||||
(false, true),
|
|
||||||
(false, false),
|
(false, false),
|
||||||
(true, false),
|
&[
|
||||||
(false, false),
|
(true, true),
|
||||||
]);
|
(false, true),
|
||||||
|
(false, false),
|
||||||
|
(true, false),
|
||||||
|
(false, false),
|
||||||
|
],
|
||||||
|
);
|
||||||
|
|
||||||
test_samples(&Standard, [0u8; 0], &[[], [], [], [], []]);
|
test_samples(&Standard, [0u8; 0], &[[], [], [], [], []]);
|
||||||
test_samples(&Standard, [0u8; 3], &[
|
test_samples(
|
||||||
[9, 247, 111],
|
&Standard,
|
||||||
[68, 24, 13],
|
[0u8; 3],
|
||||||
[174, 19, 194],
|
&[
|
||||||
[172, 69, 213],
|
[9, 247, 111],
|
||||||
[149, 207, 29],
|
[68, 24, 13],
|
||||||
]);
|
[174, 19, 194],
|
||||||
|
[172, 69, 213],
|
||||||
|
[149, 207, 29],
|
||||||
|
],
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -148,7 +148,11 @@ impl<'a> super::DistString for Slice<'a, char> {
|
|||||||
|
|
||||||
// Split the extension of string to reuse the unused capacities.
|
// Split the extension of string to reuse the unused capacities.
|
||||||
// Skip the split for small length or only ascii slice.
|
// Skip the split for small length or only ascii slice.
|
||||||
let mut extend_len = if max_char_len == 1 || len < 100 { len } else { len / 4 };
|
let mut extend_len = if max_char_len == 1 || len < 100 {
|
||||||
|
len
|
||||||
|
} else {
|
||||||
|
len / 4
|
||||||
|
};
|
||||||
let mut remain_len = len;
|
let mut remain_len = len;
|
||||||
while extend_len > 0 {
|
while extend_len > 0 {
|
||||||
string.reserve(max_char_len * extend_len);
|
string.reserve(max_char_len * extend_len);
|
||||||
|
@ -104,18 +104,22 @@
|
|||||||
//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow
|
//! [`SampleBorrow::borrow`]: crate::distributions::uniform::SampleBorrow::borrow
|
||||||
|
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
use core::time::Duration;
|
|
||||||
use core::ops::{Range, RangeInclusive};
|
use core::ops::{Range, RangeInclusive};
|
||||||
|
use core::time::Duration;
|
||||||
|
|
||||||
use crate::distributions::float::IntoFloat;
|
use crate::distributions::float::IntoFloat;
|
||||||
use crate::distributions::utils::{BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply};
|
use crate::distributions::utils::{
|
||||||
|
BoolAsSIMD, FloatAsSIMD, FloatSIMDUtils, IntAsSIMD, WideningMultiply,
|
||||||
|
};
|
||||||
use crate::distributions::Distribution;
|
use crate::distributions::Distribution;
|
||||||
#[cfg(feature = "simd_support")]
|
#[cfg(feature = "simd_support")]
|
||||||
use crate::distributions::Standard;
|
use crate::distributions::Standard;
|
||||||
use crate::{Rng, RngCore};
|
use crate::{Rng, RngCore};
|
||||||
|
|
||||||
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
|
#[cfg(feature = "simd_support")]
|
||||||
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SupportedLaneCount};
|
use core::simd::prelude::*;
|
||||||
|
#[cfg(feature = "simd_support")]
|
||||||
|
use core::simd::{LaneCount, SupportedLaneCount};
|
||||||
|
|
||||||
/// Error type returned from [`Uniform::new`] and `new_inclusive`.
|
/// Error type returned from [`Uniform::new`] and `new_inclusive`.
|
||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
@ -140,7 +144,7 @@ impl fmt::Display for Error {
|
|||||||
impl std::error::Error for Error {}
|
impl std::error::Error for Error {}
|
||||||
|
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// Sample values uniformly between two bounds.
|
/// Sample values uniformly between two bounds.
|
||||||
///
|
///
|
||||||
@ -194,7 +198,10 @@ use serde::{Serialize, Deserialize};
|
|||||||
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
|
||||||
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
|
||||||
#[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))]
|
#[cfg_attr(feature = "serde1", serde(bound(serialize = "X::Sampler: Serialize")))]
|
||||||
#[cfg_attr(feature = "serde1", serde(bound(deserialize = "X::Sampler: Deserialize<'de>")))]
|
#[cfg_attr(
|
||||||
|
feature = "serde1",
|
||||||
|
serde(bound(deserialize = "X::Sampler: Deserialize<'de>"))
|
||||||
|
)]
|
||||||
pub struct Uniform<X: SampleUniform>(X::Sampler);
|
pub struct Uniform<X: SampleUniform>(X::Sampler);
|
||||||
|
|
||||||
impl<X: SampleUniform> Uniform<X> {
|
impl<X: SampleUniform> Uniform<X> {
|
||||||
@ -297,7 +304,11 @@ pub trait UniformSampler: Sized {
|
|||||||
/// <T as SampleUniform>::Sampler::sample_single(lb, ub, &mut rng).unwrap()
|
/// <T as SampleUniform>::Sampler::sample_single(lb, ub, &mut rng).unwrap()
|
||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
fn sample_single<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R) -> Result<Self::X, Error>
|
fn sample_single<R: Rng + ?Sized, B1, B2>(
|
||||||
|
low: B1,
|
||||||
|
high: B2,
|
||||||
|
rng: &mut R,
|
||||||
|
) -> Result<Self::X, Error>
|
||||||
where
|
where
|
||||||
B1: SampleBorrow<Self::X> + Sized,
|
B1: SampleBorrow<Self::X> + Sized,
|
||||||
B2: SampleBorrow<Self::X> + Sized,
|
B2: SampleBorrow<Self::X> + Sized,
|
||||||
@ -314,10 +325,14 @@ pub trait UniformSampler: Sized {
|
|||||||
/// some types more optimal implementations for single usage may be provided
|
/// some types more optimal implementations for single usage may be provided
|
||||||
/// via this method.
|
/// via this method.
|
||||||
/// Results may not be identical.
|
/// Results may not be identical.
|
||||||
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(low: B1, high: B2, rng: &mut R)
|
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
|
||||||
-> Result<Self::X, Error>
|
low: B1,
|
||||||
where B1: SampleBorrow<Self::X> + Sized,
|
high: B2,
|
||||||
B2: SampleBorrow<Self::X> + Sized
|
rng: &mut R,
|
||||||
|
) -> Result<Self::X, Error>
|
||||||
|
where
|
||||||
|
B1: SampleBorrow<Self::X> + Sized,
|
||||||
|
B2: SampleBorrow<Self::X> + Sized,
|
||||||
{
|
{
|
||||||
let uniform: Self = UniformSampler::new_inclusive(low, high)?;
|
let uniform: Self = UniformSampler::new_inclusive(low, high)?;
|
||||||
Ok(uniform.sample(rng))
|
Ok(uniform.sample(rng))
|
||||||
@ -340,7 +355,6 @@ impl<X: SampleUniform> TryFrom<RangeInclusive<X>> for Uniform<X> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Helper trait similar to [`Borrow`] but implemented
|
/// Helper trait similar to [`Borrow`] but implemented
|
||||||
/// only for SampleUniform and references to SampleUniform in
|
/// only for SampleUniform and references to SampleUniform in
|
||||||
/// order to resolve ambiguity issues.
|
/// order to resolve ambiguity issues.
|
||||||
@ -353,7 +367,8 @@ pub trait SampleBorrow<Borrowed> {
|
|||||||
fn borrow(&self) -> &Borrowed;
|
fn borrow(&self) -> &Borrowed;
|
||||||
}
|
}
|
||||||
impl<Borrowed> SampleBorrow<Borrowed> for Borrowed
|
impl<Borrowed> SampleBorrow<Borrowed> for Borrowed
|
||||||
where Borrowed: SampleUniform
|
where
|
||||||
|
Borrowed: SampleUniform,
|
||||||
{
|
{
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn borrow(&self) -> &Borrowed {
|
fn borrow(&self) -> &Borrowed {
|
||||||
@ -361,7 +376,8 @@ where Borrowed: SampleUniform
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl<'a, Borrowed> SampleBorrow<Borrowed> for &'a Borrowed
|
impl<'a, Borrowed> SampleBorrow<Borrowed> for &'a Borrowed
|
||||||
where Borrowed: SampleUniform
|
where
|
||||||
|
Borrowed: SampleUniform,
|
||||||
{
|
{
|
||||||
#[inline(always)]
|
#[inline(always)]
|
||||||
fn borrow(&self) -> &Borrowed {
|
fn borrow(&self) -> &Borrowed {
|
||||||
@ -405,12 +421,10 @@ impl<T: SampleUniform + PartialOrd> SampleRange<T> for RangeInclusive<T> {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
////////////////////////////////////////////////////////////////////////////////
|
////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
||||||
// What follows are all back-ends.
|
// What follows are all back-ends.
|
||||||
|
|
||||||
|
|
||||||
/// The back-end implementing [`UniformSampler`] for integer types.
|
/// The back-end implementing [`UniformSampler`] for integer types.
|
||||||
///
|
///
|
||||||
/// Unless you are implementing [`UniformSampler`] for your own type, this type
|
/// Unless you are implementing [`UniformSampler`] for your own type, this type
|
||||||
@ -505,7 +519,7 @@ macro_rules! uniform_int_impl {
|
|||||||
|
|
||||||
Ok(UniformInt {
|
Ok(UniformInt {
|
||||||
low,
|
low,
|
||||||
range: range as $ty, // type: $uty
|
range: range as $ty, // type: $uty
|
||||||
thresh: thresh as $uty as $ty, // type: $sample_ty
|
thresh: thresh as $uty as $ty, // type: $sample_ty
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@ -529,7 +543,11 @@ macro_rules! uniform_int_impl {
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample_single<R: Rng + ?Sized, B1, B2>(low_b: B1, high_b: B2, rng: &mut R) -> Result<Self::X, Error>
|
fn sample_single<R: Rng + ?Sized, B1, B2>(
|
||||||
|
low_b: B1,
|
||||||
|
high_b: B2,
|
||||||
|
rng: &mut R,
|
||||||
|
) -> Result<Self::X, Error>
|
||||||
where
|
where
|
||||||
B1: SampleBorrow<Self::X> + Sized,
|
B1: SampleBorrow<Self::X> + Sized,
|
||||||
B2: SampleBorrow<Self::X> + Sized,
|
B2: SampleBorrow<Self::X> + Sized,
|
||||||
@ -549,7 +567,9 @@ macro_rules! uniform_int_impl {
|
|||||||
#[cfg(not(feature = "unbiased"))]
|
#[cfg(not(feature = "unbiased"))]
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
|
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
|
||||||
low_b: B1, high_b: B2, rng: &mut R,
|
low_b: B1,
|
||||||
|
high_b: B2,
|
||||||
|
rng: &mut R,
|
||||||
) -> Result<Self::X, Error>
|
) -> Result<Self::X, Error>
|
||||||
where
|
where
|
||||||
B1: SampleBorrow<Self::X> + Sized,
|
B1: SampleBorrow<Self::X> + Sized,
|
||||||
@ -585,7 +605,9 @@ macro_rules! uniform_int_impl {
|
|||||||
#[cfg(feature = "unbiased")]
|
#[cfg(feature = "unbiased")]
|
||||||
#[inline]
|
#[inline]
|
||||||
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
|
fn sample_single_inclusive<R: Rng + ?Sized, B1, B2>(
|
||||||
low_b: B1, high_b: B2, rng: &mut R,
|
low_b: B1,
|
||||||
|
high_b: B2,
|
||||||
|
rng: &mut R,
|
||||||
) -> Result<Self::X, Error>
|
) -> Result<Self::X, Error>
|
||||||
where
|
where
|
||||||
B1: SampleBorrow<$ty> + Sized,
|
B1: SampleBorrow<$ty> + Sized,
|
||||||
@ -599,7 +621,7 @@ macro_rules! uniform_int_impl {
|
|||||||
let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty;
|
let range = high.wrapping_sub(low).wrapping_add(1) as $uty as $sample_ty;
|
||||||
if range == 0 {
|
if range == 0 {
|
||||||
// Range is MAX+1 (unrepresentable), so we need a special case
|
// Range is MAX+1 (unrepresentable), so we need a special case
|
||||||
return Ok(rng.gen());
|
return Ok(rng.random());
|
||||||
}
|
}
|
||||||
|
|
||||||
let (mut result, mut lo) = rng.random::<$sample_ty>().wmul(range);
|
let (mut result, mut lo) = rng.random::<$sample_ty>().wmul(range);
|
||||||
@ -844,7 +866,12 @@ impl UniformSampler for UniformChar {
|
|||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
||||||
impl super::DistString for Uniform<char> {
|
impl super::DistString for Uniform<char> {
|
||||||
fn append_string<R: Rng + ?Sized>(&self, rng: &mut R, string: &mut alloc::string::String, len: usize) {
|
fn append_string<R: Rng + ?Sized>(
|
||||||
|
&self,
|
||||||
|
rng: &mut R,
|
||||||
|
string: &mut alloc::string::String,
|
||||||
|
len: usize,
|
||||||
|
) {
|
||||||
// Getting the hi value to assume the required length to reserve in string.
|
// Getting the hi value to assume the required length to reserve in string.
|
||||||
let mut hi = self.0.sampler.low + self.0.sampler.range - 1;
|
let mut hi = self.0.sampler.low + self.0.sampler.range - 1;
|
||||||
if hi >= CHAR_SURROGATE_START {
|
if hi >= CHAR_SURROGATE_START {
|
||||||
@ -911,7 +938,7 @@ macro_rules! uniform_float_impl {
|
|||||||
return Err(Error::EmptyRange);
|
return Err(Error::EmptyRange);
|
||||||
}
|
}
|
||||||
let max_rand = <$ty>::splat(
|
let max_rand = <$ty>::splat(
|
||||||
(::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
|
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut scale = high - low;
|
let mut scale = high - low;
|
||||||
@ -947,7 +974,7 @@ macro_rules! uniform_float_impl {
|
|||||||
return Err(Error::EmptyRange);
|
return Err(Error::EmptyRange);
|
||||||
}
|
}
|
||||||
let max_rand = <$ty>::splat(
|
let max_rand = <$ty>::splat(
|
||||||
(::core::$u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
|
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut scale = (high - low) / max_rand;
|
let mut scale = (high - low) / max_rand;
|
||||||
@ -1111,7 +1138,6 @@ uniform_float_impl! { feature = "simd_support", f64x4, u64x4, f64, u64, 64 - 52
|
|||||||
#[cfg(feature = "simd_support")]
|
#[cfg(feature = "simd_support")]
|
||||||
uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 }
|
uniform_float_impl! { feature = "simd_support", f64x8, u64x8, f64, u64, 64 - 52 }
|
||||||
|
|
||||||
|
|
||||||
/// The back-end implementing [`UniformSampler`] for `Duration`.
|
/// The back-end implementing [`UniformSampler`] for `Duration`.
|
||||||
///
|
///
|
||||||
/// Unless you are implementing [`UniformSampler`] for your own types, this type
|
/// Unless you are implementing [`UniformSampler`] for your own types, this type
|
||||||
@ -1248,26 +1274,29 @@ impl UniformSampler for UniformDuration {
|
|||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use crate::rngs::mock::StepRng;
|
|
||||||
use crate::distributions::utils::FloatSIMDScalarUtils;
|
use crate::distributions::utils::FloatSIMDScalarUtils;
|
||||||
|
use crate::rngs::mock::StepRng;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
fn test_serialization_uniform_duration() {
|
fn test_serialization_uniform_duration() {
|
||||||
let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap();
|
let distr = UniformDuration::new(Duration::from_secs(10), Duration::from_secs(60)).unwrap();
|
||||||
let de_distr: UniformDuration = bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap();
|
let de_distr: UniformDuration =
|
||||||
|
bincode::deserialize(&bincode::serialize(&distr).unwrap()).unwrap();
|
||||||
assert_eq!(distr, de_distr);
|
assert_eq!(distr, de_distr);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
fn test_uniform_serialization() {
|
fn test_uniform_serialization() {
|
||||||
let unit_box: Uniform<i32> = Uniform::new(-1, 1).unwrap();
|
let unit_box: Uniform<i32> = Uniform::new(-1, 1).unwrap();
|
||||||
let de_unit_box: Uniform<i32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();
|
let de_unit_box: Uniform<i32> =
|
||||||
|
bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();
|
||||||
assert_eq!(unit_box.0, de_unit_box.0);
|
assert_eq!(unit_box.0, de_unit_box.0);
|
||||||
|
|
||||||
let unit_box: Uniform<f32> = Uniform::new(-1., 1.).unwrap();
|
let unit_box: Uniform<f32> = Uniform::new(-1., 1.).unwrap();
|
||||||
let de_unit_box: Uniform<f32> = bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();
|
let de_unit_box: Uniform<f32> =
|
||||||
|
bincode::deserialize(&bincode::serialize(&unit_box).unwrap()).unwrap();
|
||||||
assert_eq!(unit_box.0, de_unit_box.0);
|
assert_eq!(unit_box.0, de_unit_box.0);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -1293,10 +1322,6 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
#[cfg_attr(miri, ignore)] // Miri is too slow
|
#[cfg_attr(miri, ignore)] // Miri is too slow
|
||||||
fn test_integers() {
|
fn test_integers() {
|
||||||
use core::{i128, u128};
|
|
||||||
use core::{i16, i32, i64, i8, isize};
|
|
||||||
use core::{u16, u32, u64, u8, usize};
|
|
||||||
|
|
||||||
let mut rng = crate::test::rng(251);
|
let mut rng = crate::test::rng(251);
|
||||||
macro_rules! t {
|
macro_rules! t {
|
||||||
($ty:ident, $v:expr, $le:expr, $lt:expr) => {{
|
($ty:ident, $v:expr, $le:expr, $lt:expr) => {{
|
||||||
@ -1383,14 +1408,15 @@ mod tests {
|
|||||||
let mut max = core::char::from_u32(0).unwrap();
|
let mut max = core::char::from_u32(0).unwrap();
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
let c = rng.gen_range('A'..='Z');
|
let c = rng.gen_range('A'..='Z');
|
||||||
assert!(('A'..='Z').contains(&c));
|
assert!(c.is_ascii_uppercase());
|
||||||
max = max.max(c);
|
max = max.max(c);
|
||||||
}
|
}
|
||||||
assert_eq!(max, 'Z');
|
assert_eq!(max, 'Z');
|
||||||
let d = Uniform::new(
|
let d = Uniform::new(
|
||||||
core::char::from_u32(0xD7F0).unwrap(),
|
core::char::from_u32(0xD7F0).unwrap(),
|
||||||
core::char::from_u32(0xE010).unwrap(),
|
core::char::from_u32(0xE010).unwrap(),
|
||||||
).unwrap();
|
)
|
||||||
|
.unwrap();
|
||||||
for _ in 0..100 {
|
for _ in 0..100 {
|
||||||
let c = d.sample(&mut rng);
|
let c = d.sample(&mut rng);
|
||||||
assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF);
|
assert!((c as u32) < 0xD800 || (c as u32) > 0xDFFF);
|
||||||
@ -1403,12 +1429,16 @@ mod tests {
|
|||||||
let string2 = Uniform::new(
|
let string2 = Uniform::new(
|
||||||
core::char::from_u32(0x0000).unwrap(),
|
core::char::from_u32(0x0000).unwrap(),
|
||||||
core::char::from_u32(0x0080).unwrap(),
|
core::char::from_u32(0x0080).unwrap(),
|
||||||
).unwrap().sample_string(&mut rng, 100);
|
)
|
||||||
|
.unwrap()
|
||||||
|
.sample_string(&mut rng, 100);
|
||||||
assert_eq!(string2.capacity(), 100);
|
assert_eq!(string2.capacity(), 100);
|
||||||
let string3 = Uniform::new_inclusive(
|
let string3 = Uniform::new_inclusive(
|
||||||
core::char::from_u32(0x0000).unwrap(),
|
core::char::from_u32(0x0000).unwrap(),
|
||||||
core::char::from_u32(0x0080).unwrap(),
|
core::char::from_u32(0x0080).unwrap(),
|
||||||
).unwrap().sample_string(&mut rng, 100);
|
)
|
||||||
|
.unwrap()
|
||||||
|
.sample_string(&mut rng, 100);
|
||||||
assert_eq!(string3.capacity(), 200);
|
assert_eq!(string3.capacity(), 200);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -1430,8 +1460,8 @@ mod tests {
|
|||||||
(-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)),
|
(-<$f_scalar>::from_bits(10), -<$f_scalar>::from_bits(1)),
|
||||||
(-<$f_scalar>::from_bits(5), 0.0),
|
(-<$f_scalar>::from_bits(5), 0.0),
|
||||||
(-<$f_scalar>::from_bits(7), -0.0),
|
(-<$f_scalar>::from_bits(7), -0.0),
|
||||||
(0.1 * ::core::$f_scalar::MAX, ::core::$f_scalar::MAX),
|
(0.1 * $f_scalar::MAX, $f_scalar::MAX),
|
||||||
(-::core::$f_scalar::MAX * 0.2, ::core::$f_scalar::MAX * 0.7),
|
(-$f_scalar::MAX * 0.2, $f_scalar::MAX * 0.7),
|
||||||
];
|
];
|
||||||
for &(low_scalar, high_scalar) in v.iter() {
|
for &(low_scalar, high_scalar) in v.iter() {
|
||||||
for lane in 0..<$ty>::LEN {
|
for lane in 0..<$ty>::LEN {
|
||||||
@ -1444,27 +1474,47 @@ mod tests {
|
|||||||
assert!(low_scalar <= v && v < high_scalar);
|
assert!(low_scalar <= v && v < high_scalar);
|
||||||
let v = rng.sample(my_incl_uniform).extract(lane);
|
let v = rng.sample(my_incl_uniform).extract(lane);
|
||||||
assert!(low_scalar <= v && v <= high_scalar);
|
assert!(low_scalar <= v && v <= high_scalar);
|
||||||
let v = <$ty as SampleUniform>::Sampler
|
let v =
|
||||||
::sample_single(low, high, &mut rng).unwrap().extract(lane);
|
<$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng)
|
||||||
|
.unwrap()
|
||||||
|
.extract(lane);
|
||||||
assert!(low_scalar <= v && v < high_scalar);
|
assert!(low_scalar <= v && v < high_scalar);
|
||||||
let v = <$ty as SampleUniform>::Sampler
|
let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(
|
||||||
::sample_single_inclusive(low, high, &mut rng).unwrap().extract(lane);
|
low, high, &mut rng,
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.extract(lane);
|
||||||
assert!(low_scalar <= v && v <= high_scalar);
|
assert!(low_scalar <= v && v <= high_scalar);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
rng.sample(Uniform::new_inclusive(low, low).unwrap()).extract(lane),
|
rng.sample(Uniform::new_inclusive(low, low).unwrap())
|
||||||
|
.extract(lane),
|
||||||
low_scalar
|
low_scalar
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar);
|
assert_eq!(zero_rng.sample(my_uniform).extract(lane), low_scalar);
|
||||||
assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar);
|
assert_eq!(zero_rng.sample(my_incl_uniform).extract(lane), low_scalar);
|
||||||
assert_eq!(<$ty as SampleUniform>::Sampler
|
assert_eq!(
|
||||||
::sample_single(low, high, &mut zero_rng).unwrap()
|
<$ty as SampleUniform>::Sampler::sample_single(
|
||||||
.extract(lane), low_scalar);
|
low,
|
||||||
assert_eq!(<$ty as SampleUniform>::Sampler
|
high,
|
||||||
::sample_single_inclusive(low, high, &mut zero_rng).unwrap()
|
&mut zero_rng
|
||||||
.extract(lane), low_scalar);
|
)
|
||||||
|
.unwrap()
|
||||||
|
.extract(lane),
|
||||||
|
low_scalar
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
<$ty as SampleUniform>::Sampler::sample_single_inclusive(
|
||||||
|
low,
|
||||||
|
high,
|
||||||
|
&mut zero_rng
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.extract(lane),
|
||||||
|
low_scalar
|
||||||
|
);
|
||||||
|
|
||||||
assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
|
assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
|
||||||
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
|
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
|
||||||
@ -1472,9 +1522,16 @@ mod tests {
|
|||||||
// assert!(<$ty as SampleUniform>::Sampler
|
// assert!(<$ty as SampleUniform>::Sampler
|
||||||
// ::sample_single(low, high, &mut max_rng).unwrap()
|
// ::sample_single(low, high, &mut max_rng).unwrap()
|
||||||
// .extract(lane) < high_scalar);
|
// .extract(lane) < high_scalar);
|
||||||
assert!(<$ty as SampleUniform>::Sampler
|
assert!(
|
||||||
::sample_single_inclusive(low, high, &mut max_rng).unwrap()
|
<$ty as SampleUniform>::Sampler::sample_single_inclusive(
|
||||||
.extract(lane) <= high_scalar);
|
low,
|
||||||
|
high,
|
||||||
|
&mut max_rng
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.extract(lane)
|
||||||
|
<= high_scalar
|
||||||
|
);
|
||||||
|
|
||||||
// Don't run this test for really tiny differences between high and low
|
// Don't run this test for really tiny differences between high and low
|
||||||
// since for those rounding might result in selecting high for a very
|
// since for those rounding might result in selecting high for a very
|
||||||
@ -1485,27 +1542,26 @@ mod tests {
|
|||||||
(-1i64 << $bits_shifted) as u64,
|
(-1i64 << $bits_shifted) as u64,
|
||||||
);
|
);
|
||||||
assert!(
|
assert!(
|
||||||
<$ty as SampleUniform>::Sampler
|
<$ty as SampleUniform>::Sampler::sample_single(
|
||||||
::sample_single(low, high, &mut lowering_max_rng).unwrap()
|
low,
|
||||||
.extract(lane) < high_scalar
|
high,
|
||||||
|
&mut lowering_max_rng
|
||||||
|
)
|
||||||
|
.unwrap()
|
||||||
|
.extract(lane)
|
||||||
|
< high_scalar
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
rng.sample(Uniform::new_inclusive(
|
rng.sample(Uniform::new_inclusive($f_scalar::MAX, $f_scalar::MAX).unwrap()),
|
||||||
::core::$f_scalar::MAX,
|
$f_scalar::MAX
|
||||||
::core::$f_scalar::MAX
|
|
||||||
).unwrap()),
|
|
||||||
::core::$f_scalar::MAX
|
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
rng.sample(Uniform::new_inclusive(
|
rng.sample(Uniform::new_inclusive(-$f_scalar::MAX, -$f_scalar::MAX).unwrap()),
|
||||||
-::core::$f_scalar::MAX,
|
-$f_scalar::MAX
|
||||||
-::core::$f_scalar::MAX
|
|
||||||
).unwrap()),
|
|
||||||
-::core::$f_scalar::MAX
|
|
||||||
);
|
);
|
||||||
}};
|
}};
|
||||||
}
|
}
|
||||||
@ -1549,21 +1605,18 @@ mod tests {
|
|||||||
macro_rules! t {
|
macro_rules! t {
|
||||||
($ty:ident, $f_scalar:ident) => {{
|
($ty:ident, $f_scalar:ident) => {{
|
||||||
let v: &[($f_scalar, $f_scalar)] = &[
|
let v: &[($f_scalar, $f_scalar)] = &[
|
||||||
(::std::$f_scalar::NAN, 0.0),
|
($f_scalar::NAN, 0.0),
|
||||||
(1.0, ::std::$f_scalar::NAN),
|
(1.0, $f_scalar::NAN),
|
||||||
(::std::$f_scalar::NAN, ::std::$f_scalar::NAN),
|
($f_scalar::NAN, $f_scalar::NAN),
|
||||||
(1.0, 0.5),
|
(1.0, 0.5),
|
||||||
(::std::$f_scalar::MAX, -::std::$f_scalar::MAX),
|
($f_scalar::MAX, -$f_scalar::MAX),
|
||||||
(::std::$f_scalar::INFINITY, ::std::$f_scalar::INFINITY),
|
($f_scalar::INFINITY, $f_scalar::INFINITY),
|
||||||
(
|
($f_scalar::NEG_INFINITY, $f_scalar::NEG_INFINITY),
|
||||||
::std::$f_scalar::NEG_INFINITY,
|
($f_scalar::NEG_INFINITY, 5.0),
|
||||||
::std::$f_scalar::NEG_INFINITY,
|
(5.0, $f_scalar::INFINITY),
|
||||||
),
|
($f_scalar::NAN, $f_scalar::INFINITY),
|
||||||
(::std::$f_scalar::NEG_INFINITY, 5.0),
|
($f_scalar::NEG_INFINITY, $f_scalar::NAN),
|
||||||
(5.0, ::std::$f_scalar::INFINITY),
|
($f_scalar::NEG_INFINITY, $f_scalar::INFINITY),
|
||||||
(::std::$f_scalar::NAN, ::std::$f_scalar::INFINITY),
|
|
||||||
(::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::NAN),
|
|
||||||
(::std::$f_scalar::NEG_INFINITY, ::std::$f_scalar::INFINITY),
|
|
||||||
];
|
];
|
||||||
for &(low_scalar, high_scalar) in v.iter() {
|
for &(low_scalar, high_scalar) in v.iter() {
|
||||||
for lane in 0..<$ty>::LEN {
|
for lane in 0..<$ty>::LEN {
|
||||||
@ -1593,7 +1646,6 @@ mod tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg_attr(miri, ignore)] // Miri is too slow
|
#[cfg_attr(miri, ignore)] // Miri is too slow
|
||||||
fn test_durations() {
|
fn test_durations() {
|
||||||
@ -1602,10 +1654,7 @@ mod tests {
|
|||||||
let v = &[
|
let v = &[
|
||||||
(Duration::new(10, 50000), Duration::new(100, 1234)),
|
(Duration::new(10, 50000), Duration::new(100, 1234)),
|
||||||
(Duration::new(0, 100), Duration::new(1, 50)),
|
(Duration::new(0, 100), Duration::new(1, 50)),
|
||||||
(
|
(Duration::new(0, 0), Duration::new(u64::MAX, 999_999_999)),
|
||||||
Duration::new(0, 0),
|
|
||||||
Duration::new(u64::MAX, 999_999_999),
|
|
||||||
),
|
|
||||||
];
|
];
|
||||||
for &(low, high) in v.iter() {
|
for &(low, high) in v.iter() {
|
||||||
let my_uniform = Uniform::new(low, high).unwrap();
|
let my_uniform = Uniform::new(low, high).unwrap();
|
||||||
@ -1707,8 +1756,13 @@ mod tests {
|
|||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn test_samples<T: SampleUniform + Copy + fmt::Debug + PartialEq>(
|
fn test_samples<T: SampleUniform + Copy + fmt::Debug + PartialEq>(
|
||||||
lb: T, ub: T, expected_single: &[T], expected_multiple: &[T],
|
lb: T,
|
||||||
) where Uniform<T>: Distribution<T> {
|
ub: T,
|
||||||
|
expected_single: &[T],
|
||||||
|
expected_multiple: &[T],
|
||||||
|
) where
|
||||||
|
Uniform<T>: Distribution<T>,
|
||||||
|
{
|
||||||
let mut rng = crate::test::rng(897);
|
let mut rng = crate::test::rng(897);
|
||||||
let mut buf = [lb; 3];
|
let mut buf = [lb; 3];
|
||||||
|
|
||||||
@ -1730,11 +1784,12 @@ mod tests {
|
|||||||
test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]);
|
test_samples(11u8, 219, &[17, 66, 214], &[181, 93, 165]);
|
||||||
test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]);
|
test_samples(11u32, 219, &[17, 66, 214], &[181, 93, 165]);
|
||||||
|
|
||||||
test_samples(0f32, 1e-2f32, &[0.0003070104, 0.0026630748, 0.00979833], &[
|
test_samples(
|
||||||
0.008194133,
|
0f32,
|
||||||
0.00398172,
|
1e-2f32,
|
||||||
0.007428536,
|
&[0.0003070104, 0.0026630748, 0.00979833],
|
||||||
]);
|
&[0.008194133, 0.00398172, 0.007428536],
|
||||||
|
);
|
||||||
test_samples(
|
test_samples(
|
||||||
-1e10f64,
|
-1e10f64,
|
||||||
1e10f64,
|
1e10f64,
|
||||||
@ -1760,9 +1815,15 @@ mod tests {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn uniform_distributions_can_be_compared() {
|
fn uniform_distributions_can_be_compared() {
|
||||||
assert_eq!(Uniform::new(1.0, 2.0).unwrap(), Uniform::new(1.0, 2.0).unwrap());
|
assert_eq!(
|
||||||
|
Uniform::new(1.0, 2.0).unwrap(),
|
||||||
|
Uniform::new(1.0, 2.0).unwrap()
|
||||||
|
);
|
||||||
|
|
||||||
// To cover UniformInt
|
// To cover UniformInt
|
||||||
assert_eq!(Uniform::new(1_u32, 2_u32).unwrap(), Uniform::new(1_u32, 2_u32).unwrap());
|
assert_eq!(
|
||||||
|
Uniform::new(1_u32, 2_u32).unwrap(),
|
||||||
|
Uniform::new(1_u32, 2_u32).unwrap()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -8,9 +8,10 @@
|
|||||||
|
|
||||||
//! Math helper functions
|
//! Math helper functions
|
||||||
|
|
||||||
#[cfg(feature = "simd_support")] use core::simd::prelude::*;
|
#[cfg(feature = "simd_support")]
|
||||||
#[cfg(feature = "simd_support")] use core::simd::{LaneCount, SimdElement, SupportedLaneCount};
|
use core::simd::prelude::*;
|
||||||
|
#[cfg(feature = "simd_support")]
|
||||||
|
use core::simd::{LaneCount, SimdElement, SupportedLaneCount};
|
||||||
|
|
||||||
pub(crate) trait WideningMultiply<RHS = Self> {
|
pub(crate) trait WideningMultiply<RHS = Self> {
|
||||||
type Output;
|
type Output;
|
||||||
@ -146,8 +147,10 @@ wmul_impl_usize! { u64 }
|
|||||||
#[cfg(feature = "simd_support")]
|
#[cfg(feature = "simd_support")]
|
||||||
mod simd_wmul {
|
mod simd_wmul {
|
||||||
use super::*;
|
use super::*;
|
||||||
#[cfg(target_arch = "x86")] use core::arch::x86::*;
|
#[cfg(target_arch = "x86")]
|
||||||
#[cfg(target_arch = "x86_64")] use core::arch::x86_64::*;
|
use core::arch::x86::*;
|
||||||
|
#[cfg(target_arch = "x86_64")]
|
||||||
|
use core::arch::x86_64::*;
|
||||||
|
|
||||||
wmul_impl! {
|
wmul_impl! {
|
||||||
(u8x4, u16x4),
|
(u8x4, u16x4),
|
||||||
@ -340,12 +343,12 @@ macro_rules! scalar_float_impl {
|
|||||||
scalar_float_impl!(f32, u32);
|
scalar_float_impl!(f32, u32);
|
||||||
scalar_float_impl!(f64, u64);
|
scalar_float_impl!(f64, u64);
|
||||||
|
|
||||||
|
|
||||||
#[cfg(feature = "simd_support")]
|
#[cfg(feature = "simd_support")]
|
||||||
macro_rules! simd_impl {
|
macro_rules! simd_impl {
|
||||||
($fty:ident, $uty:ident) => {
|
($fty:ident, $uty:ident) => {
|
||||||
impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES>
|
impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES>
|
||||||
where LaneCount<LANES>: SupportedLaneCount
|
where
|
||||||
|
LaneCount<LANES>: SupportedLaneCount,
|
||||||
{
|
{
|
||||||
type Mask = Mask<<$fty as SimdElement>::Mask, LANES>;
|
type Mask = Mask<<$fty as SimdElement>::Mask, LANES>;
|
||||||
type UInt = Simd<$uty, LANES>;
|
type UInt = Simd<$uty, LANES>;
|
||||||
|
@ -108,7 +108,11 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
|
|||||||
X: Weight,
|
X: Weight,
|
||||||
{
|
{
|
||||||
let mut iter = weights.into_iter();
|
let mut iter = weights.into_iter();
|
||||||
let mut total_weight: X = iter.next().ok_or(WeightError::InvalidInput)?.borrow().clone();
|
let mut total_weight: X = iter
|
||||||
|
.next()
|
||||||
|
.ok_or(WeightError::InvalidInput)?
|
||||||
|
.borrow()
|
||||||
|
.clone();
|
||||||
|
|
||||||
let zero = X::ZERO;
|
let zero = X::ZERO;
|
||||||
if !(total_weight >= zero) {
|
if !(total_weight >= zero) {
|
||||||
@ -252,9 +256,9 @@ pub struct WeightedIndexIter<'a, X: SampleUniform + PartialOrd> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, X> Debug for WeightedIndexIter<'a, X>
|
impl<'a, X> Debug for WeightedIndexIter<'a, X>
|
||||||
where
|
where
|
||||||
X: SampleUniform + PartialOrd + Debug,
|
X: SampleUniform + PartialOrd + Debug,
|
||||||
X::Sampler: Debug,
|
X::Sampler: Debug,
|
||||||
{
|
{
|
||||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
f.debug_struct("WeightedIndexIter")
|
f.debug_struct("WeightedIndexIter")
|
||||||
@ -278,10 +282,7 @@ where
|
|||||||
|
|
||||||
impl<'a, X> Iterator for WeightedIndexIter<'a, X>
|
impl<'a, X> Iterator for WeightedIndexIter<'a, X>
|
||||||
where
|
where
|
||||||
X: for<'b> core::ops::SubAssign<&'b X>
|
X: for<'b> core::ops::SubAssign<&'b X> + SampleUniform + PartialOrd + Clone,
|
||||||
+ SampleUniform
|
|
||||||
+ PartialOrd
|
|
||||||
+ Clone,
|
|
||||||
{
|
{
|
||||||
type Item = X;
|
type Item = X;
|
||||||
|
|
||||||
@ -315,15 +316,16 @@ impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn weight(&self, index: usize) -> Option<X>
|
pub fn weight(&self, index: usize) -> Option<X>
|
||||||
where
|
where
|
||||||
X: for<'a> core::ops::SubAssign<&'a X>
|
X: for<'a> core::ops::SubAssign<&'a X>,
|
||||||
{
|
{
|
||||||
let mut weight = if index < self.cumulative_weights.len() {
|
use core::cmp::Ordering::*;
|
||||||
self.cumulative_weights[index].clone()
|
|
||||||
} else if index == self.cumulative_weights.len() {
|
let mut weight = match index.cmp(&self.cumulative_weights.len()) {
|
||||||
self.total_weight.clone()
|
Less => self.cumulative_weights[index].clone(),
|
||||||
} else {
|
Equal => self.total_weight.clone(),
|
||||||
return None;
|
Greater => return None,
|
||||||
};
|
};
|
||||||
|
|
||||||
if index > 0 {
|
if index > 0 {
|
||||||
weight -= &self.cumulative_weights[index - 1];
|
weight -= &self.cumulative_weights[index - 1];
|
||||||
}
|
}
|
||||||
@ -348,7 +350,7 @@ impl<X: SampleUniform + PartialOrd + Clone> WeightedIndex<X> {
|
|||||||
/// ```
|
/// ```
|
||||||
pub fn weights(&self) -> WeightedIndexIter<'_, X>
|
pub fn weights(&self) -> WeightedIndexIter<'_, X>
|
||||||
where
|
where
|
||||||
X: for<'a> core::ops::SubAssign<&'a X>
|
X: for<'a> core::ops::SubAssign<&'a X>,
|
||||||
{
|
{
|
||||||
WeightedIndexIter {
|
WeightedIndexIter {
|
||||||
weighted_index: self,
|
weighted_index: self,
|
||||||
@ -387,6 +389,7 @@ pub trait Weight: Clone {
|
|||||||
/// - `Result::Err`: Returns an error when `Self` cannot represent the
|
/// - `Result::Err`: Returns an error when `Self` cannot represent the
|
||||||
/// result of `self + v` (i.e. overflow). The value of `self` should be
|
/// result of `self + v` (i.e. overflow). The value of `self` should be
|
||||||
/// discarded.
|
/// discarded.
|
||||||
|
#[allow(clippy::result_unit_err)]
|
||||||
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>;
|
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()>;
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -417,6 +420,7 @@ macro_rules! impl_weight_float {
|
|||||||
($t:ty) => {
|
($t:ty) => {
|
||||||
impl Weight for $t {
|
impl Weight for $t {
|
||||||
const ZERO: Self = 0.0;
|
const ZERO: Self = 0.0;
|
||||||
|
|
||||||
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
|
fn checked_add_assign(&mut self, v: &Self) -> Result<(), ()> {
|
||||||
// Floats have an explicit representation for overflow
|
// Floats have an explicit representation for overflow
|
||||||
*self += *v;
|
*self += *v;
|
||||||
@ -435,7 +439,7 @@ mod test {
|
|||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_weightedindex_serde1() {
|
fn test_weightedindex_serde1() {
|
||||||
let weighted_index = WeightedIndex::new(&[1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();
|
let weighted_index = WeightedIndex::new([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]).unwrap();
|
||||||
|
|
||||||
let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
|
let ser_weighted_index = bincode::serialize(&weighted_index).unwrap();
|
||||||
let de_weighted_index: WeightedIndex<i32> =
|
let de_weighted_index: WeightedIndex<i32> =
|
||||||
@ -451,20 +455,20 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
fn test_accepting_nan() {
|
fn test_accepting_nan() {
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[f32::NAN, 0.5]).unwrap_err(),
|
WeightedIndex::new([f32::NAN, 0.5]).unwrap_err(),
|
||||||
WeightError::InvalidWeight,
|
WeightError::InvalidWeight,
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[f32::NAN]).unwrap_err(),
|
WeightedIndex::new([f32::NAN]).unwrap_err(),
|
||||||
WeightError::InvalidWeight,
|
WeightError::InvalidWeight,
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[0.5, f32::NAN]).unwrap_err(),
|
WeightedIndex::new([0.5, f32::NAN]).unwrap_err(),
|
||||||
WeightError::InvalidWeight,
|
WeightError::InvalidWeight,
|
||||||
);
|
);
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[0.5, 7.0])
|
WeightedIndex::new([0.5, 7.0])
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.update_weights(&[(0, &f32::NAN)])
|
.update_weights(&[(0, &f32::NAN)])
|
||||||
.unwrap_err(),
|
.unwrap_err(),
|
||||||
@ -516,10 +520,10 @@ mod test {
|
|||||||
verify(chosen);
|
verify(chosen);
|
||||||
|
|
||||||
for _ in 0..5 {
|
for _ in 0..5 {
|
||||||
assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
|
assert_eq!(WeightedIndex::new([0, 1]).unwrap().sample(&mut r), 1);
|
||||||
assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
|
assert_eq!(WeightedIndex::new([1, 0]).unwrap().sample(&mut r), 0);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[0, 0, 0, 0, 10, 0])
|
WeightedIndex::new([0, 0, 0, 0, 10, 0])
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.sample(&mut r),
|
.sample(&mut r),
|
||||||
4
|
4
|
||||||
@ -531,19 +535,19 @@ mod test {
|
|||||||
WeightError::InvalidInput
|
WeightError::InvalidInput
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[0]).unwrap_err(),
|
WeightedIndex::new([0]).unwrap_err(),
|
||||||
WeightError::InsufficientNonZero
|
WeightError::InsufficientNonZero
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(),
|
WeightedIndex::new([10, 20, -1, 30]).unwrap_err(),
|
||||||
WeightError::InvalidWeight
|
WeightError::InvalidWeight
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(),
|
WeightedIndex::new([-10, 20, 1, 30]).unwrap_err(),
|
||||||
WeightError::InvalidWeight
|
WeightError::InvalidWeight
|
||||||
);
|
);
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
WeightedIndex::new(&[-10]).unwrap_err(),
|
WeightedIndex::new([-10]).unwrap_err(),
|
||||||
WeightError::InvalidWeight
|
WeightError::InvalidWeight
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
@ -649,7 +653,9 @@ mod test {
|
|||||||
#[test]
|
#[test]
|
||||||
fn value_stability() {
|
fn value_stability() {
|
||||||
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
|
fn test_samples<X: Weight + SampleUniform + PartialOrd, I>(
|
||||||
weights: I, buf: &mut [usize], expected: &[usize],
|
weights: I,
|
||||||
|
buf: &mut [usize],
|
||||||
|
expected: &[usize],
|
||||||
) where
|
) where
|
||||||
I: IntoIterator,
|
I: IntoIterator,
|
||||||
I::Item: SampleBorrow<X>,
|
I::Item: SampleBorrow<X>,
|
||||||
@ -665,17 +671,17 @@ mod test {
|
|||||||
|
|
||||||
let mut buf = [0; 10];
|
let mut buf = [0; 10];
|
||||||
test_samples(
|
test_samples(
|
||||||
&[1i32, 1, 1, 1, 1, 1, 1, 1, 1],
|
[1i32, 1, 1, 1, 1, 1, 1, 1, 1],
|
||||||
&mut buf,
|
&mut buf,
|
||||||
&[0, 6, 2, 6, 3, 4, 7, 8, 2, 5],
|
&[0, 6, 2, 6, 3, 4, 7, 8, 2, 5],
|
||||||
);
|
);
|
||||||
test_samples(
|
test_samples(
|
||||||
&[0.7f32, 0.1, 0.1, 0.1],
|
[0.7f32, 0.1, 0.1, 0.1],
|
||||||
&mut buf,
|
&mut buf,
|
||||||
&[0, 0, 0, 1, 0, 0, 2, 3, 0, 0],
|
&[0, 0, 0, 1, 0, 0, 2, 3, 0, 0],
|
||||||
);
|
);
|
||||||
test_samples(
|
test_samples(
|
||||||
&[1.0f64, 0.999, 0.998, 0.997],
|
[1.0f64, 0.999, 0.998, 0.997],
|
||||||
&mut buf,
|
&mut buf,
|
||||||
&[2, 2, 1, 3, 2, 1, 3, 3, 2, 1],
|
&[2, 2, 1, 3, 2, 1, 3, 3, 2, 1],
|
||||||
);
|
);
|
||||||
@ -683,7 +689,7 @@ mod test {
|
|||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn weighted_index_distributions_can_be_compared() {
|
fn weighted_index_distributions_can_be_compared() {
|
||||||
assert_eq!(WeightedIndex::new(&[1, 2]), WeightedIndex::new(&[1, 2]));
|
assert_eq!(WeightedIndex::new([1, 2]), WeightedIndex::new([1, 2]));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
11
src/lib.rs
11
src/lib.rs
@ -50,6 +50,7 @@
|
|||||||
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
|
#![doc(test(attr(allow(unused_variables), deny(warnings))))]
|
||||||
#![no_std]
|
#![no_std]
|
||||||
#![cfg_attr(feature = "simd_support", feature(portable_simd))]
|
#![cfg_attr(feature = "simd_support", feature(portable_simd))]
|
||||||
|
#![allow(unexpected_cfgs)]
|
||||||
#![cfg_attr(doc_cfg, feature(doc_cfg))]
|
#![cfg_attr(doc_cfg, feature(doc_cfg))]
|
||||||
#![allow(
|
#![allow(
|
||||||
clippy::float_cmp,
|
clippy::float_cmp,
|
||||||
@ -57,8 +58,10 @@
|
|||||||
clippy::nonminimal_bool
|
clippy::nonminimal_bool
|
||||||
)]
|
)]
|
||||||
|
|
||||||
#[cfg(feature = "alloc")] extern crate alloc;
|
#[cfg(feature = "alloc")]
|
||||||
#[cfg(feature = "std")] extern crate std;
|
extern crate alloc;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
extern crate std;
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
macro_rules! trace { ($($x:tt)*) => (
|
macro_rules! trace { ($($x:tt)*) => (
|
||||||
@ -160,7 +163,9 @@ use crate::distributions::{Distribution, Standard};
|
|||||||
)]
|
)]
|
||||||
#[inline]
|
#[inline]
|
||||||
pub fn random<T>() -> T
|
pub fn random<T>() -> T
|
||||||
where Standard: Distribution<T> {
|
where
|
||||||
|
Standard: Distribution<T>,
|
||||||
|
{
|
||||||
thread_rng().random()
|
thread_rng().random()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,7 +18,8 @@
|
|||||||
//! # let _: f32 = r.random();
|
//! # let _: f32 = r.random();
|
||||||
//! ```
|
//! ```
|
||||||
|
|
||||||
#[doc(no_inline)] pub use crate::distributions::Distribution;
|
#[doc(no_inline)]
|
||||||
|
pub use crate::distributions::Distribution;
|
||||||
#[cfg(feature = "small_rng")]
|
#[cfg(feature = "small_rng")]
|
||||||
#[doc(no_inline)]
|
#[doc(no_inline)]
|
||||||
pub use crate::rngs::SmallRng;
|
pub use crate::rngs::SmallRng;
|
||||||
@ -33,4 +34,5 @@ pub use crate::seq::{IndexedMutRandom, IndexedRandom, IteratorRandom, SliceRando
|
|||||||
#[doc(no_inline)]
|
#[doc(no_inline)]
|
||||||
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
|
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
|
||||||
pub use crate::{random, thread_rng};
|
pub use crate::{random, thread_rng};
|
||||||
#[doc(no_inline)] pub use crate::{CryptoRng, Rng, RngCore, SeedableRng};
|
#[doc(no_inline)]
|
||||||
|
pub use crate::{CryptoRng, Rng, RngCore, SeedableRng};
|
||||||
|
14
src/rng.rs
14
src/rng.rs
@ -89,7 +89,9 @@ pub trait Rng: RngCore {
|
|||||||
/// [`Standard`]: distributions::Standard
|
/// [`Standard`]: distributions::Standard
|
||||||
#[inline]
|
#[inline]
|
||||||
fn random<T>(&mut self) -> T
|
fn random<T>(&mut self) -> T
|
||||||
where Standard: Distribution<T> {
|
where
|
||||||
|
Standard: Distribution<T>,
|
||||||
|
{
|
||||||
Standard.sample(self)
|
Standard.sample(self)
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -309,7 +311,9 @@ pub trait Rng: RngCore {
|
|||||||
note = "Renamed to `random` to avoid conflict with the new `gen` keyword in Rust 2024."
|
note = "Renamed to `random` to avoid conflict with the new `gen` keyword in Rust 2024."
|
||||||
)]
|
)]
|
||||||
fn gen<T>(&mut self) -> T
|
fn gen<T>(&mut self) -> T
|
||||||
where Standard: Distribution<T> {
|
where
|
||||||
|
Standard: Distribution<T>,
|
||||||
|
{
|
||||||
self.random()
|
self.random()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -402,7 +406,8 @@ impl_fill!(u16, u32, u64, usize, u128,);
|
|||||||
impl_fill!(i8, i16, i32, i64, isize, i128,);
|
impl_fill!(i8, i16, i32, i64, isize, i128,);
|
||||||
|
|
||||||
impl<T, const N: usize> Fill for [T; N]
|
impl<T, const N: usize> Fill for [T; N]
|
||||||
where [T]: Fill
|
where
|
||||||
|
[T]: Fill,
|
||||||
{
|
{
|
||||||
fn fill<R: Rng + ?Sized>(&mut self, rng: &mut R) {
|
fn fill<R: Rng + ?Sized>(&mut self, rng: &mut R) {
|
||||||
<[T] as Fill>::fill(self, rng)
|
<[T] as Fill>::fill(self, rng)
|
||||||
@ -414,7 +419,8 @@ mod test {
|
|||||||
use super::*;
|
use super::*;
|
||||||
use crate::rngs::mock::StepRng;
|
use crate::rngs::mock::StepRng;
|
||||||
use crate::test::rng;
|
use crate::test::rng;
|
||||||
#[cfg(feature = "alloc")] use alloc::boxed::Box;
|
#[cfg(feature = "alloc")]
|
||||||
|
use alloc::boxed::Box;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_fill_bytes_default() {
|
fn test_fill_bytes_default() {
|
||||||
|
@ -10,7 +10,8 @@
|
|||||||
|
|
||||||
use rand_core::{impls, RngCore};
|
use rand_core::{impls, RngCore};
|
||||||
|
|
||||||
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A mock generator yielding very predictable output
|
/// A mock generator yielding very predictable output
|
||||||
///
|
///
|
||||||
@ -78,7 +79,8 @@ rand_core::impl_try_rng_from_rng_core!(StepRng);
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
#[cfg(any(feature = "alloc", feature = "serde1"))] use super::StepRng;
|
#[cfg(any(feature = "alloc", feature = "serde1"))]
|
||||||
|
use super::StepRng;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
|
@ -102,18 +102,22 @@ pub use reseeding::ReseedingRng;
|
|||||||
pub mod mock; // Public so we don't export `StepRng` directly, making it a bit
|
pub mod mock; // Public so we don't export `StepRng` directly, making it a bit
|
||||||
// more clear it is intended for testing.
|
// more clear it is intended for testing.
|
||||||
|
|
||||||
#[cfg(feature = "small_rng")] mod small;
|
#[cfg(feature = "small_rng")]
|
||||||
|
mod small;
|
||||||
#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))]
|
#[cfg(all(feature = "small_rng", not(target_pointer_width = "64")))]
|
||||||
mod xoshiro128plusplus;
|
mod xoshiro128plusplus;
|
||||||
#[cfg(all(feature = "small_rng", target_pointer_width = "64"))]
|
#[cfg(all(feature = "small_rng", target_pointer_width = "64"))]
|
||||||
mod xoshiro256plusplus;
|
mod xoshiro256plusplus;
|
||||||
|
|
||||||
#[cfg(feature = "std_rng")] mod std;
|
#[cfg(feature = "std_rng")]
|
||||||
|
mod std;
|
||||||
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
|
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
|
||||||
pub(crate) mod thread;
|
pub(crate) mod thread;
|
||||||
|
|
||||||
#[cfg(feature = "small_rng")] pub use self::small::SmallRng;
|
#[cfg(feature = "small_rng")]
|
||||||
#[cfg(feature = "std_rng")] pub use self::std::StdRng;
|
pub use self::small::SmallRng;
|
||||||
|
#[cfg(feature = "std_rng")]
|
||||||
|
pub use self::std::StdRng;
|
||||||
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
|
#[cfg(all(feature = "std", feature = "std_rng", feature = "getrandom"))]
|
||||||
pub use self::thread::ThreadRng;
|
pub use self::thread::ThreadRng;
|
||||||
|
|
||||||
|
@ -33,7 +33,6 @@ use crate::rngs::ReseedingRng;
|
|||||||
// `ThreadRng` internally, which is nonsensical anyway. We should also never run
|
// `ThreadRng` internally, which is nonsensical anyway. We should also never run
|
||||||
// `ThreadRng` in destructors of its implementation, which is also nonsensical.
|
// `ThreadRng` in destructors of its implementation, which is also nonsensical.
|
||||||
|
|
||||||
|
|
||||||
// Number of generated bytes after which to reseed `ThreadRng`.
|
// Number of generated bytes after which to reseed `ThreadRng`.
|
||||||
// According to benchmarks, reseeding has a noticeable impact with thresholds
|
// According to benchmarks, reseeding has a noticeable impact with thresholds
|
||||||
// of 32 kB and less. We choose 64 kB to avoid significant overhead.
|
// of 32 kB and less. We choose 64 kB to avoid significant overhead.
|
||||||
|
@ -9,7 +9,8 @@
|
|||||||
use rand_core::impls::{fill_bytes_via_next, next_u64_via_u32};
|
use rand_core::impls::{fill_bytes_via_next, next_u64_via_u32};
|
||||||
use rand_core::le::read_u32_into;
|
use rand_core::le::read_u32_into;
|
||||||
use rand_core::{RngCore, SeedableRng};
|
use rand_core::{RngCore, SeedableRng};
|
||||||
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A xoshiro128++ random number generator.
|
/// A xoshiro128++ random number generator.
|
||||||
///
|
///
|
||||||
|
@ -9,7 +9,8 @@
|
|||||||
use rand_core::impls::fill_bytes_via_next;
|
use rand_core::impls::fill_bytes_via_next;
|
||||||
use rand_core::le::read_u64_into;
|
use rand_core::le::read_u64_into;
|
||||||
use rand_core::{RngCore, SeedableRng};
|
use rand_core::{RngCore, SeedableRng};
|
||||||
#[cfg(feature = "serde1")] use serde::{Deserialize, Serialize};
|
#[cfg(feature = "serde1")]
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A xoshiro256++ random number generator.
|
/// A xoshiro256++ random number generator.
|
||||||
///
|
///
|
||||||
|
@ -10,7 +10,7 @@ use crate::RngCore;
|
|||||||
|
|
||||||
pub(crate) struct CoinFlipper<R: RngCore> {
|
pub(crate) struct CoinFlipper<R: RngCore> {
|
||||||
pub rng: R,
|
pub rng: R,
|
||||||
chunk: u32, //TODO(opt): this should depend on RNG word size
|
chunk: u32, // TODO(opt): this should depend on RNG word size
|
||||||
chunk_remaining: u32,
|
chunk_remaining: u32,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -92,7 +92,7 @@ impl<R: RngCore> CoinFlipper<R> {
|
|||||||
// If n * 2^c > `usize::MAX` we always return `true` anyway
|
// If n * 2^c > `usize::MAX` we always return `true` anyway
|
||||||
n = n.saturating_mul(2_usize.pow(c));
|
n = n.saturating_mul(2_usize.pow(c));
|
||||||
} else {
|
} else {
|
||||||
//At least one tail
|
// At least one tail
|
||||||
if c == 1 {
|
if c == 1 {
|
||||||
// Calculate 2n - d.
|
// Calculate 2n - d.
|
||||||
// We need to use wrapping as 2n might be greater than `usize::MAX`
|
// We need to use wrapping as 2n might be greater than `usize::MAX`
|
||||||
|
@ -7,23 +7,30 @@
|
|||||||
// except according to those terms.
|
// except according to those terms.
|
||||||
|
|
||||||
//! Low-level API for sampling indices
|
//! Low-level API for sampling indices
|
||||||
|
use core::{cmp::Ordering, hash::Hash, ops::AddAssign};
|
||||||
|
|
||||||
#[cfg(feature = "alloc")] use core::slice;
|
#[cfg(feature = "alloc")]
|
||||||
|
use core::slice;
|
||||||
|
|
||||||
#[cfg(feature = "alloc")] use alloc::vec::{self, Vec};
|
#[cfg(feature = "alloc")]
|
||||||
|
use alloc::vec::{self, Vec};
|
||||||
// BTreeMap is not as fast in tests, but better than nothing.
|
// BTreeMap is not as fast in tests, but better than nothing.
|
||||||
#[cfg(all(feature = "alloc", not(feature = "std")))]
|
#[cfg(all(feature = "alloc", not(feature = "std")))]
|
||||||
use alloc::collections::BTreeSet;
|
use alloc::collections::BTreeSet;
|
||||||
#[cfg(feature = "std")] use std::collections::HashSet;
|
#[cfg(feature = "std")]
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
use super::WeightError;
|
use super::WeightError;
|
||||||
|
|
||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
use crate::{Rng, distributions::{uniform::SampleUniform, Distribution, Uniform}};
|
use crate::{
|
||||||
|
distributions::{uniform::SampleUniform, Distribution, Uniform},
|
||||||
|
Rng,
|
||||||
|
};
|
||||||
|
|
||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
use serde::{Serialize, Deserialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
/// A vector of indices.
|
/// A vector of indices.
|
||||||
///
|
///
|
||||||
@ -88,8 +95,8 @@ impl IndexVec {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl IntoIterator for IndexVec {
|
impl IntoIterator for IndexVec {
|
||||||
type Item = usize;
|
|
||||||
type IntoIter = IndexVecIntoIter;
|
type IntoIter = IndexVecIntoIter;
|
||||||
|
type Item = usize;
|
||||||
|
|
||||||
/// Convert into an iterator over the indices as a sequence of `usize` values
|
/// Convert into an iterator over the indices as a sequence of `usize` values
|
||||||
#[inline]
|
#[inline]
|
||||||
@ -196,7 +203,6 @@ impl Iterator for IndexVecIntoIter {
|
|||||||
|
|
||||||
impl ExactSizeIterator for IndexVecIntoIter {}
|
impl ExactSizeIterator for IndexVecIntoIter {}
|
||||||
|
|
||||||
|
|
||||||
/// Randomly sample exactly `amount` distinct indices from `0..length`, and
|
/// Randomly sample exactly `amount` distinct indices from `0..length`, and
|
||||||
/// return them in random order (fully shuffled).
|
/// return them in random order (fully shuffled).
|
||||||
///
|
///
|
||||||
@ -221,7 +227,9 @@ impl ExactSizeIterator for IndexVecIntoIter {}
|
|||||||
/// Panics if `amount > length`.
|
/// Panics if `amount > length`.
|
||||||
#[track_caller]
|
#[track_caller]
|
||||||
pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
|
pub fn sample<R>(rng: &mut R, length: usize, amount: usize) -> IndexVec
|
||||||
where R: Rng + ?Sized {
|
where
|
||||||
|
R: Rng + ?Sized,
|
||||||
|
{
|
||||||
if amount > length {
|
if amount > length {
|
||||||
panic!("`amount` of samples must be less than or equal to `length`");
|
panic!("`amount` of samples must be less than or equal to `length`");
|
||||||
}
|
}
|
||||||
@ -276,7 +284,10 @@ where R: Rng + ?Sized {
|
|||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
|
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
|
||||||
pub fn sample_weighted<R, F, X>(
|
pub fn sample_weighted<R, F, X>(
|
||||||
rng: &mut R, length: usize, weight: F, amount: usize,
|
rng: &mut R,
|
||||||
|
length: usize,
|
||||||
|
weight: F,
|
||||||
|
amount: usize,
|
||||||
) -> Result<IndexVec, WeightError>
|
) -> Result<IndexVec, WeightError>
|
||||||
where
|
where
|
||||||
R: Rng + ?Sized,
|
R: Rng + ?Sized,
|
||||||
@ -293,7 +304,6 @@ where
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Randomly sample exactly `amount` distinct indices from `0..length`, and
|
/// Randomly sample exactly `amount` distinct indices from `0..length`, and
|
||||||
/// return them in an arbitrary order (there is no guarantee of shuffling or
|
/// return them in an arbitrary order (there is no guarantee of shuffling or
|
||||||
/// ordering). The weights are to be provided by the input function `weights`,
|
/// ordering). The weights are to be provided by the input function `weights`,
|
||||||
@ -308,7 +318,10 @@ where
|
|||||||
/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive.
|
/// - [`WeightError::InsufficientNonZero`] when fewer than `amount` weights are positive.
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
fn sample_efraimidis_spirakis<R, F, X, N>(
|
fn sample_efraimidis_spirakis<R, F, X, N>(
|
||||||
rng: &mut R, length: N, weight: F, amount: N,
|
rng: &mut R,
|
||||||
|
length: N,
|
||||||
|
weight: F,
|
||||||
|
amount: N,
|
||||||
) -> Result<IndexVec, WeightError>
|
) -> Result<IndexVec, WeightError>
|
||||||
where
|
where
|
||||||
R: Rng + ?Sized,
|
R: Rng + ?Sized,
|
||||||
@ -325,23 +338,27 @@ where
|
|||||||
index: N,
|
index: N,
|
||||||
key: f64,
|
key: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<N> PartialOrd for Element<N> {
|
impl<N> PartialOrd for Element<N> {
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
||||||
self.key.partial_cmp(&other.key)
|
Some(self.cmp(other))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<N> Ord for Element<N> {
|
impl<N> Ord for Element<N> {
|
||||||
fn cmp(&self, other: &Self) -> core::cmp::Ordering {
|
fn cmp(&self, other: &Self) -> Ordering {
|
||||||
// partial_cmp will always produce a value,
|
// partial_cmp will always produce a value,
|
||||||
// because we check that the weights are not nan
|
// because we check that the weights are not nan
|
||||||
self.partial_cmp(other).unwrap()
|
self.key.partial_cmp(&other.key).unwrap()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<N> PartialEq for Element<N> {
|
impl<N> PartialEq for Element<N> {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool {
|
||||||
self.key == other.key
|
self.key == other.key
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<N> Eq for Element<N> {}
|
impl<N> Eq for Element<N> {}
|
||||||
|
|
||||||
let mut candidates = Vec::with_capacity(length.as_usize());
|
let mut candidates = Vec::with_capacity(length.as_usize());
|
||||||
@ -367,8 +384,7 @@ where
|
|||||||
// keys. Do this by using `select_nth_unstable` to put the elements with
|
// keys. Do this by using `select_nth_unstable` to put the elements with
|
||||||
// the *smallest* keys at the beginning of the list in `O(n)` time, which
|
// the *smallest* keys at the beginning of the list in `O(n)` time, which
|
||||||
// provides equivalent information about the elements with the *greatest* keys.
|
// provides equivalent information about the elements with the *greatest* keys.
|
||||||
let (_, mid, greater)
|
let (_, mid, greater) = candidates.select_nth_unstable(avail - amount.as_usize());
|
||||||
= candidates.select_nth_unstable(avail - amount.as_usize());
|
|
||||||
|
|
||||||
let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
|
let mut result: Vec<N> = Vec::with_capacity(amount.as_usize());
|
||||||
result.push(mid.index);
|
result.push(mid.index);
|
||||||
@ -385,7 +401,9 @@ where
|
|||||||
///
|
///
|
||||||
/// This implementation uses `O(amount)` memory and `O(amount^2)` time.
|
/// This implementation uses `O(amount)` memory and `O(amount^2)` time.
|
||||||
fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
|
fn sample_floyd<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
|
||||||
where R: Rng + ?Sized {
|
where
|
||||||
|
R: Rng + ?Sized,
|
||||||
|
{
|
||||||
// Note that the values returned by `rng.gen_range()` can be
|
// Note that the values returned by `rng.gen_range()` can be
|
||||||
// inferred from the returned vector by working backwards from
|
// inferred from the returned vector by working backwards from
|
||||||
// the last entry. This bijection proves the algorithm fair.
|
// the last entry. This bijection proves the algorithm fair.
|
||||||
@ -414,7 +432,9 @@ where R: Rng + ?Sized {
|
|||||||
///
|
///
|
||||||
/// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time.
|
/// Set-up is `O(length)` time and memory and shuffling is `O(amount)` time.
|
||||||
fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
|
fn sample_inplace<R>(rng: &mut R, length: u32, amount: u32) -> IndexVec
|
||||||
where R: Rng + ?Sized {
|
where
|
||||||
|
R: Rng + ?Sized,
|
||||||
|
{
|
||||||
debug_assert!(amount <= length);
|
debug_assert!(amount <= length);
|
||||||
let mut indices: Vec<u32> = Vec::with_capacity(length as usize);
|
let mut indices: Vec<u32> = Vec::with_capacity(length as usize);
|
||||||
indices.extend(0..length);
|
indices.extend(0..length);
|
||||||
@ -427,12 +447,12 @@ where R: Rng + ?Sized {
|
|||||||
IndexVec::from(indices)
|
IndexVec::from(indices)
|
||||||
}
|
}
|
||||||
|
|
||||||
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform
|
trait UInt: Copy + PartialOrd + Ord + PartialEq + Eq + SampleUniform + Hash + AddAssign {
|
||||||
+ core::hash::Hash + core::ops::AddAssign {
|
|
||||||
fn zero() -> Self;
|
fn zero() -> Self;
|
||||||
fn one() -> Self;
|
fn one() -> Self;
|
||||||
fn as_usize(self) -> usize;
|
fn as_usize(self) -> usize;
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UInt for u32 {
|
impl UInt for u32 {
|
||||||
#[inline]
|
#[inline]
|
||||||
fn zero() -> Self {
|
fn zero() -> Self {
|
||||||
@ -449,6 +469,7 @@ impl UInt for u32 {
|
|||||||
self as usize
|
self as usize
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl UInt for usize {
|
impl UInt for usize {
|
||||||
#[inline]
|
#[inline]
|
||||||
fn zero() -> Self {
|
fn zero() -> Self {
|
||||||
@ -507,19 +528,23 @@ mod test {
|
|||||||
#[cfg(feature = "serde1")]
|
#[cfg(feature = "serde1")]
|
||||||
fn test_serialization_index_vec() {
|
fn test_serialization_index_vec() {
|
||||||
let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]);
|
let some_index_vec = IndexVec::from(vec![254_usize, 234, 2, 1]);
|
||||||
let de_some_index_vec: IndexVec = bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap();
|
let de_some_index_vec: IndexVec =
|
||||||
|
bincode::deserialize(&bincode::serialize(&some_index_vec).unwrap()).unwrap();
|
||||||
match (some_index_vec, de_some_index_vec) {
|
match (some_index_vec, de_some_index_vec) {
|
||||||
(IndexVec::U32(a), IndexVec::U32(b)) => {
|
(IndexVec::U32(a), IndexVec::U32(b)) => {
|
||||||
assert_eq!(a, b);
|
assert_eq!(a, b);
|
||||||
},
|
}
|
||||||
(IndexVec::USize(a), IndexVec::USize(b)) => {
|
(IndexVec::USize(a), IndexVec::USize(b)) => {
|
||||||
assert_eq!(a, b);
|
assert_eq!(a, b);
|
||||||
},
|
}
|
||||||
_ => {panic!("failed to seralize/deserialize `IndexVec`")}
|
_ => {
|
||||||
|
panic!("failed to seralize/deserialize `IndexVec`")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature = "alloc")] use alloc::vec;
|
#[cfg(feature = "alloc")]
|
||||||
|
use alloc::vec;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn test_sample_boundaries() {
|
fn test_sample_boundaries() {
|
||||||
@ -593,7 +618,7 @@ mod test {
|
|||||||
for &i in &indices {
|
for &i in &indices {
|
||||||
assert!((i as usize) < len);
|
assert!((i as usize) < len);
|
||||||
}
|
}
|
||||||
},
|
}
|
||||||
IndexVec::USize(_) => panic!("expected `IndexVec::U32`"),
|
IndexVec::USize(_) => panic!("expected `IndexVec::U32`"),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -628,11 +653,15 @@ mod test {
|
|||||||
do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace
|
do_test(300, 80, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace
|
||||||
do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace
|
do_test(300, 180, &[31, 289, 248, 154, 221, 243, 7, 192]); // inplace
|
||||||
|
|
||||||
do_test(1_000_000, 8, &[
|
do_test(
|
||||||
103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573,
|
1_000_000,
|
||||||
]); // floyd
|
8,
|
||||||
do_test(1_000_000, 180, &[
|
&[103717, 963485, 826422, 509101, 736394, 807035, 5327, 632573],
|
||||||
103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573,
|
); // floyd
|
||||||
]); // rejection
|
do_test(
|
||||||
|
1_000_000,
|
||||||
|
180,
|
||||||
|
&[103718, 963490, 826426, 509103, 736396, 807036, 5327, 632573],
|
||||||
|
); // rejection
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -44,7 +44,8 @@ use alloc::vec::Vec;
|
|||||||
|
|
||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
use crate::distributions::uniform::{SampleBorrow, SampleUniform};
|
use crate::distributions::uniform::{SampleBorrow, SampleUniform};
|
||||||
#[cfg(feature = "alloc")] use crate::distributions::Weight;
|
#[cfg(feature = "alloc")]
|
||||||
|
use crate::distributions::Weight;
|
||||||
use crate::Rng;
|
use crate::Rng;
|
||||||
|
|
||||||
use self::coin_flipper::CoinFlipper;
|
use self::coin_flipper::CoinFlipper;
|
||||||
@ -167,7 +168,9 @@ pub trait IndexedRandom: Index<usize> {
|
|||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
||||||
fn choose_weighted<R, F, B, X>(
|
fn choose_weighted<R, F, B, X>(
|
||||||
&self, rng: &mut R, weight: F,
|
&self,
|
||||||
|
rng: &mut R,
|
||||||
|
weight: F,
|
||||||
) -> Result<&Self::Output, WeightError>
|
) -> Result<&Self::Output, WeightError>
|
||||||
where
|
where
|
||||||
R: Rng + ?Sized,
|
R: Rng + ?Sized,
|
||||||
@ -212,13 +215,15 @@ pub trait IndexedRandom: Index<usize> {
|
|||||||
/// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
|
/// println!("{:?}", choices.choose_multiple_weighted(&mut rng, 2, |item| item.1).unwrap().collect::<Vec<_>>());
|
||||||
/// ```
|
/// ```
|
||||||
/// [`choose_multiple`]: IndexedRandom::choose_multiple
|
/// [`choose_multiple`]: IndexedRandom::choose_multiple
|
||||||
//
|
|
||||||
// Note: this is feature-gated on std due to usage of f64::powf.
|
// Note: this is feature-gated on std due to usage of f64::powf.
|
||||||
// If necessary, we may use alloc+libm as an alternative (see PR #1089).
|
// If necessary, we may use alloc+libm as an alternative (see PR #1089).
|
||||||
#[cfg(feature = "std")]
|
#[cfg(feature = "std")]
|
||||||
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
|
#[cfg_attr(doc_cfg, doc(cfg(feature = "std")))]
|
||||||
fn choose_multiple_weighted<R, F, X>(
|
fn choose_multiple_weighted<R, F, X>(
|
||||||
&self, rng: &mut R, amount: usize, weight: F,
|
&self,
|
||||||
|
rng: &mut R,
|
||||||
|
amount: usize,
|
||||||
|
weight: F,
|
||||||
) -> Result<SliceChooseIter<Self, Self::Output>, WeightError>
|
) -> Result<SliceChooseIter<Self, Self::Output>, WeightError>
|
||||||
where
|
where
|
||||||
Self::Output: Sized,
|
Self::Output: Sized,
|
||||||
@ -285,7 +290,9 @@ pub trait IndexedMutRandom: IndexedRandom + IndexMut<usize> {
|
|||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
#[cfg_attr(doc_cfg, doc(cfg(feature = "alloc")))]
|
||||||
fn choose_weighted_mut<R, F, B, X>(
|
fn choose_weighted_mut<R, F, B, X>(
|
||||||
&mut self, rng: &mut R, weight: F,
|
&mut self,
|
||||||
|
rng: &mut R,
|
||||||
|
weight: F,
|
||||||
) -> Result<&mut Self::Output, WeightError>
|
) -> Result<&mut Self::Output, WeightError>
|
||||||
where
|
where
|
||||||
R: Rng + ?Sized,
|
R: Rng + ?Sized,
|
||||||
@ -358,7 +365,9 @@ pub trait SliceRandom: IndexedMutRandom {
|
|||||||
///
|
///
|
||||||
/// For slices, complexity is `O(m)` where `m = amount`.
|
/// For slices, complexity is `O(m)` where `m = amount`.
|
||||||
fn partial_shuffle<R>(
|
fn partial_shuffle<R>(
|
||||||
&mut self, rng: &mut R, amount: usize,
|
&mut self,
|
||||||
|
rng: &mut R,
|
||||||
|
amount: usize,
|
||||||
) -> (&mut [Self::Output], &mut [Self::Output])
|
) -> (&mut [Self::Output], &mut [Self::Output])
|
||||||
where
|
where
|
||||||
Self::Output: Sized,
|
Self::Output: Sized,
|
||||||
@ -624,9 +633,7 @@ impl<T> SliceRandom for [T] {
|
|||||||
self.partial_shuffle(rng, self.len());
|
self.partial_shuffle(rng, self.len());
|
||||||
}
|
}
|
||||||
|
|
||||||
fn partial_shuffle<R>(
|
fn partial_shuffle<R>(&mut self, rng: &mut R, amount: usize) -> (&mut [T], &mut [T])
|
||||||
&mut self, rng: &mut R, amount: usize,
|
|
||||||
) -> (&mut [T], &mut [T])
|
|
||||||
where
|
where
|
||||||
R: Rng + ?Sized,
|
R: Rng + ?Sized,
|
||||||
{
|
{
|
||||||
@ -1294,7 +1301,10 @@ mod test {
|
|||||||
fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
|
fn do_test<I: Clone + Iterator<Item = u32>>(iter: I, v: &[u32]) {
|
||||||
let mut rng = crate::test::rng(412);
|
let mut rng = crate::test::rng(412);
|
||||||
let mut buf = [0u32; 8];
|
let mut buf = [0u32; 8];
|
||||||
assert_eq!(iter.clone().choose_multiple_fill(&mut rng, &mut buf), v.len());
|
assert_eq!(
|
||||||
|
iter.clone().choose_multiple_fill(&mut rng, &mut buf),
|
||||||
|
v.len()
|
||||||
|
);
|
||||||
assert_eq!(&buf[0..v.len()], v);
|
assert_eq!(&buf[0..v.len()], v);
|
||||||
|
|
||||||
#[cfg(feature = "alloc")]
|
#[cfg(feature = "alloc")]
|
||||||
|
Loading…
x
Reference in New Issue
Block a user