alnyan/yggdrasil: patch getrandom url

This commit is contained in:
2023-12-22 15:26:55 +02:00
parent 4336232dda
commit b16da2bdc9
4 changed files with 296 additions and 134 deletions
+1
View File
@@ -26,3 +26,4 @@ serde1 = ["serde", "serde_derive"] # enables serde for BlockRng wrapper
[dependencies] [dependencies]
serde = { version = "1", optional = true } serde = { version = "1", optional = true }
serde_derive = { version = "^1.0.38", optional = true } serde_derive = { version = "^1.0.38", optional = true }
getrandom = { version = "0.2", git = "https://git.alnyan.me/yggdrasil/getrandom.git", branch = "alnyan/yggdrasil", optional = true }
+17 -16
View File
@@ -14,23 +14,23 @@ use rand_core::{Error, ErrorKind};
use std::fs::File; use std::fs::File;
use std::io; use std::io;
use std::io::Read; use std::io::Read;
use std::sync::{Once, Mutex, ONCE_INIT}; use std::sync::{Mutex, Once};
// TODO: remove outer Option when `Mutex::new(None)` is a constant expression // TODO: remove outer Option when `Mutex::new(None)` is a constant expression
static mut READ_RNG_FILE: Option<Mutex<Option<File>>> = None; static mut READ_RNG_FILE: Option<Mutex<Option<File>>> = None;
static READ_RNG_ONCE: Once = ONCE_INIT; static READ_RNG_ONCE: Once = Once::new();
#[allow(unused)] #[allow(unused)]
pub fn open<F>(path: &'static str, open_fn: F) -> Result<(), Error> pub fn open<F>(path: &'static str, open_fn: F) -> Result<(), Error>
where F: Fn(&'static str) -> Result<File, io::Error> where
F: Fn(&'static str) -> Result<File, io::Error>,
{ {
READ_RNG_ONCE.call_once(|| { READ_RNG_ONCE.call_once(|| unsafe { READ_RNG_FILE = Some(Mutex::new(None)) });
unsafe { READ_RNG_FILE = Some(Mutex::new(None)) }
});
// We try opening the file outside the `call_once` fn because we cannot // We try opening the file outside the `call_once` fn because we cannot
// clone the error, thus we must retry on failure. // clone the error, thus we must retry on failure.
#[allow(static_mut_refs)]
let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() }; let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() };
let mut guard = mutex.lock().unwrap(); let mut guard = mutex.lock().unwrap();
if (*guard).is_none() { if (*guard).is_none() {
@@ -45,26 +45,27 @@ pub fn read(dest: &mut [u8]) -> Result<(), Error> {
// We expect this function only to be used after `random_device::open` // We expect this function only to be used after `random_device::open`
// was succesful. Therefore we can assume that our memory was set with a // was succesful. Therefore we can assume that our memory was set with a
// valid object. // valid object.
#[allow(static_mut_refs)]
let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() }; let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() };
let mut guard = mutex.lock().unwrap(); let mut guard = mutex.lock().unwrap();
let file = (*guard).as_mut().unwrap(); let file = (*guard).as_mut().unwrap();
// Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`. // Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`.
file.read_exact(dest).map_err(|err| { file.read_exact(dest).map_err(|err| {
Error::with_cause(ErrorKind::Unavailable, Error::with_cause(ErrorKind::Unavailable, "error reading random device", err)
"error reading random device", err)
}) })
} }
pub fn map_err(err: io::Error) -> Error { pub fn map_err(err: io::Error) -> Error {
match err.kind() { match err.kind() {
io::ErrorKind::Interrupted => io::ErrorKind::Interrupted => Error::new(ErrorKind::Transient, "interrupted"),
Error::new(ErrorKind::Transient, "interrupted"), io::ErrorKind::WouldBlock => {
io::ErrorKind::WouldBlock => Error::with_cause(ErrorKind::NotReady, "OS RNG not yet seeded", err)
Error::with_cause(ErrorKind::NotReady, }
"OS RNG not yet seeded", err), _ => Error::with_cause(
_ => Error::with_cause(ErrorKind::Unavailable, ErrorKind::Unavailable,
"error while opening random device", err) "error while opening random device",
err,
),
} }
} }
+223 -93
View File
@@ -182,50 +182,77 @@
//! [`Weibull`]: struct.Weibull.html //! [`Weibull`]: struct.Weibull.html
//! [`WeightedIndex`]: struct.WeightedIndex.html //! [`WeightedIndex`]: struct.WeightedIndex.html
#[cfg(any(rustc_1_26, features="nightly"))] #[cfg(any(rustc_1_26, features = "nightly"))]
use core::iter; use core::iter;
use Rng; use Rng;
pub use self::other::Alphanumeric;
#[doc(inline)] pub use self::uniform::Uniform;
pub use self::float::{OpenClosed01, Open01};
pub use self::bernoulli::Bernoulli; pub use self::bernoulli::Bernoulli;
#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError}; #[cfg(feature = "std")]
#[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface; pub use self::binomial::Binomial;
#[cfg(feature="std")] pub use self::unit_circle::UnitCircle; #[cfg(feature = "std")]
#[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF, pub use self::cauchy::Cauchy;
StudentT, Beta}; #[cfg(feature = "std")]
#[cfg(feature="std")] pub use self::normal::{Normal, LogNormal, StandardNormal}; pub use self::dirichlet::Dirichlet;
#[cfg(feature="std")] pub use self::exponential::{Exp, Exp1}; #[cfg(feature = "std")]
#[cfg(feature="std")] pub use self::pareto::Pareto; pub use self::exponential::{Exp, Exp1};
#[cfg(feature="std")] pub use self::poisson::Poisson; pub use self::float::{Open01, OpenClosed01};
#[cfg(feature="std")] pub use self::binomial::Binomial; #[cfg(feature = "std")]
#[cfg(feature="std")] pub use self::cauchy::Cauchy; pub use self::gamma::{Beta, ChiSquared, FisherF, Gamma, StudentT};
#[cfg(feature="std")] pub use self::dirichlet::Dirichlet; #[cfg(feature = "std")]
#[cfg(feature="std")] pub use self::triangular::Triangular; pub use self::normal::{LogNormal, Normal, StandardNormal};
#[cfg(feature="std")] pub use self::weibull::Weibull; pub use self::other::Alphanumeric;
#[cfg(feature = "std")]
pub use self::pareto::Pareto;
#[cfg(feature = "std")]
pub use self::poisson::Poisson;
#[cfg(feature = "std")]
pub use self::triangular::Triangular;
#[doc(inline)]
pub use self::uniform::Uniform;
#[cfg(feature = "std")]
pub use self::unit_circle::UnitCircle;
#[cfg(feature = "std")]
pub use self::unit_sphere::UnitSphereSurface;
#[cfg(feature = "std")]
pub use self::weibull::Weibull;
#[cfg(feature = "alloc")]
pub use self::weighted::{WeightedError, WeightedIndex};
pub mod uniform;
mod bernoulli; mod bernoulli;
#[cfg(feature="alloc")] mod weighted; #[cfg(feature = "std")]
#[cfg(feature="std")] mod unit_sphere; mod binomial;
#[cfg(feature="std")] mod unit_circle; #[cfg(feature = "std")]
#[cfg(feature="std")] mod gamma; mod cauchy;
#[cfg(feature="std")] mod normal; #[cfg(feature = "std")]
#[cfg(feature="std")] mod exponential; mod dirichlet;
#[cfg(feature="std")] mod pareto; #[cfg(feature = "std")]
#[cfg(feature="std")] mod poisson; mod exponential;
#[cfg(feature="std")] mod binomial; #[cfg(feature = "std")]
#[cfg(feature="std")] mod cauchy; mod gamma;
#[cfg(feature="std")] mod dirichlet; #[cfg(feature = "std")]
#[cfg(feature="std")] mod triangular; mod normal;
#[cfg(feature="std")] mod weibull; #[cfg(feature = "std")]
mod pareto;
#[cfg(feature = "std")]
mod poisson;
#[cfg(feature = "std")]
mod triangular;
pub mod uniform;
#[cfg(feature = "std")]
mod unit_circle;
#[cfg(feature = "std")]
mod unit_sphere;
#[cfg(feature = "std")]
mod weibull;
#[cfg(feature = "alloc")]
mod weighted;
mod float; mod float;
mod integer; mod integer;
mod other; mod other;
mod utils; mod utils;
#[cfg(feature="std")] mod ziggurat_tables; #[cfg(feature = "std")]
mod ziggurat_tables;
/// Types (distributions) that can be used to create a random instance of `T`. /// Types (distributions) that can be used to create a random instance of `T`.
/// ///
@@ -269,7 +296,9 @@ pub trait Distribution<T> {
/// } /// }
/// ``` /// ```
fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T> fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T>
where Self: Sized, R: Rng where
Self: Sized,
R: Rng,
{ {
DistIter { DistIter {
distr: self, distr: self,
@@ -285,7 +314,6 @@ impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
} }
} }
/// An iterator that generates random values of `T` with distribution `D`, /// An iterator that generates random values of `T` with distribution `D`,
/// using `R` as the source of randomness. /// using `R` as the source of randomness.
/// ///
@@ -302,7 +330,9 @@ pub struct DistIter<'a, D: 'a, R: 'a, T> {
} }
impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T> impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a where
D: Distribution<T>,
R: Rng + 'a,
{ {
type Item = T; type Item = T;
@@ -318,12 +348,19 @@ impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
#[cfg(rustc_1_26)] #[cfg(rustc_1_26)]
impl<'a, D, R, T> iter::FusedIterator for DistIter<'a, D, R, T> impl<'a, D, R, T> iter::FusedIterator for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a {} where
D: Distribution<T>,
R: Rng + 'a,
{
}
#[cfg(features = "nightly")] #[cfg(features = "nightly")]
impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T> impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
where D: Distribution<T>, R: Rng + 'a {} where
D: Distribution<T>,
R: Rng + 'a,
{
}
/// A generic random value distribution, implemented for many primitive types. /// A generic random value distribution, implemented for many primitive types.
/// Usually generates values with a numerically uniform distribution, and with a /// Usually generates values with a numerically uniform distribution, and with a
@@ -385,9 +422,8 @@ impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
#[derive(Clone, Copy, Debug)] #[derive(Clone, Copy, Debug)]
pub struct Standard; pub struct Standard;
/// A value with a particular weight for use with `WeightedChoice`. /// A value with a particular weight for use with `WeightedChoice`.
#[deprecated(since="0.6.0", note="use WeightedIndex instead")] #[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
#[allow(deprecated)] #[allow(deprecated)]
#[derive(Copy, Clone, Debug)] #[derive(Copy, Clone, Debug)]
pub struct Weighted<T> { pub struct Weighted<T> {
@@ -402,15 +438,15 @@ pub struct Weighted<T> {
/// Deprecated: use [`WeightedIndex`] instead. /// Deprecated: use [`WeightedIndex`] instead.
/// ///
/// [`WeightedIndex`]: struct.WeightedIndex.html /// [`WeightedIndex`]: struct.WeightedIndex.html
#[deprecated(since="0.6.0", note="use WeightedIndex instead")] #[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
#[allow(deprecated)] #[allow(deprecated)]
#[derive(Debug)] #[derive(Debug)]
pub struct WeightedChoice<'a, T:'a> { pub struct WeightedChoice<'a, T: 'a> {
items: &'a mut [Weighted<T>], items: &'a mut [Weighted<T>],
weight_range: Uniform<u32>, weight_range: Uniform<u32>,
} }
#[deprecated(since="0.6.0", note="use WeightedIndex instead")] #[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
#[allow(deprecated)] #[allow(deprecated)]
impl<'a, T: Clone> WeightedChoice<'a, T> { impl<'a, T: Clone> WeightedChoice<'a, T> {
/// Create a new `WeightedChoice`. /// Create a new `WeightedChoice`.
@@ -422,7 +458,10 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
/// - the total weight is larger than a `u32` can contain. /// - the total weight is larger than a `u32` can contain.
pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> { pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
// strictly speaking, this is subsumed by the total weight == 0 case // strictly speaking, this is subsumed by the total weight == 0 case
assert!(!items.is_empty(), "WeightedChoice::new called with no items"); assert!(
!items.is_empty(),
"WeightedChoice::new called with no items"
);
let mut running_total: u32 = 0; let mut running_total: u32 = 0;
@@ -432,24 +471,29 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
for item in items.iter_mut() { for item in items.iter_mut() {
running_total = match running_total.checked_add(item.weight) { running_total = match running_total.checked_add(item.weight) {
Some(n) => n, Some(n) => n,
None => panic!("WeightedChoice::new called with a total weight \ None => panic!(
larger than a u32 can contain") "WeightedChoice::new called with a total weight \
larger than a u32 can contain"
),
}; };
item.weight = running_total; item.weight = running_total;
} }
assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0"); assert!(
running_total != 0,
"WeightedChoice::new called with a total weight of 0"
);
WeightedChoice { WeightedChoice {
items, items,
// we're likely to be generating numbers in this range // we're likely to be generating numbers in this range
// relatively often, so might as well cache it // relatively often, so might as well cache it
weight_range: Uniform::new(0, running_total) weight_range: Uniform::new(0, running_total),
} }
} }
} }
#[deprecated(since="0.6.0", note="use WeightedIndex instead")] // #[deprecated(since="0.6.0", note="use WeightedIndex instead")]
#[allow(deprecated)] #[allow(deprecated)]
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> { impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
@@ -496,9 +540,9 @@ impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use rngs::mock::StepRng;
#[allow(deprecated)] #[allow(deprecated)]
use super::{WeightedChoice, Weighted, Distribution}; use super::{Distribution, Weighted, WeightedChoice};
use rngs::mock::StepRng;
#[test] #[test]
#[allow(deprecated)] #[allow(deprecated)]
@@ -511,7 +555,9 @@ mod tests {
($items:expr, $expected:expr) => {{ ($items:expr, $expected:expr) => {{
let mut items = $items; let mut items = $items;
let mut total_weight = 0; let mut total_weight = 0;
for item in &items { total_weight += item.weight; } for item in &items {
total_weight += item.weight;
}
let wc = WeightedChoice::new(&mut items); let wc = WeightedChoice::new(&mut items);
let expected = $expected; let expected = $expected;
@@ -524,92 +570,176 @@ mod tests {
for &val in expected.iter() { for &val in expected.iter() {
assert_eq!(wc.sample(&mut rng), val) assert_eq!(wc.sample(&mut rng), val)
} }
}} }};
} }
t!([Weighted { weight: 1, item: 10}], [10]); t!(
[Weighted {
weight: 1,
item: 10
}],
[10]
);
// skip some // skip some
t!([Weighted { weight: 0, item: 20}, t!(
Weighted { weight: 2, item: 21}, [
Weighted { weight: 0, item: 22}, Weighted {
Weighted { weight: 1, item: 23}], weight: 0,
[21, 21, 23]); item: 20
},
Weighted {
weight: 2,
item: 21
},
Weighted {
weight: 0,
item: 22
},
Weighted {
weight: 1,
item: 23
}
],
[21, 21, 23]
);
// different weights // different weights
t!([Weighted { weight: 4, item: 30}, t!(
Weighted { weight: 3, item: 31}], [
[30, 31, 30, 31, 30, 31, 30]); Weighted {
weight: 4,
item: 30
},
Weighted {
weight: 3,
item: 31
}
],
[30, 31, 30, 31, 30, 31, 30]
);
// check that we're binary searching // check that we're binary searching
// correctly with some vectors of odd // correctly with some vectors of odd
// length. // length.
t!([Weighted { weight: 1, item: 40}, t!(
Weighted { weight: 1, item: 41}, [
Weighted { weight: 1, item: 42}, Weighted {
Weighted { weight: 1, item: 43}, weight: 1,
Weighted { weight: 1, item: 44}], item: 40
[40, 41, 42, 43, 44]); },
t!([Weighted { weight: 1, item: 50}, Weighted {
Weighted { weight: 1, item: 51}, weight: 1,
Weighted { weight: 1, item: 52}, item: 41
Weighted { weight: 1, item: 53}, },
Weighted { weight: 1, item: 54}, Weighted {
Weighted { weight: 1, item: 55}, weight: 1,
Weighted { weight: 1, item: 56}], item: 42
[50, 54, 51, 55, 52, 56, 53]); },
Weighted {
weight: 1,
item: 43
},
Weighted {
weight: 1,
item: 44
}
],
[40, 41, 42, 43, 44]
);
t!(
[
Weighted {
weight: 1,
item: 50
},
Weighted {
weight: 1,
item: 51
},
Weighted {
weight: 1,
item: 52
},
Weighted {
weight: 1,
item: 53
},
Weighted {
weight: 1,
item: 54
},
Weighted {
weight: 1,
item: 55
},
Weighted {
weight: 1,
item: 56
}
],
[50, 54, 51, 55, 52, 56, 53]
);
} }
#[test] #[test]
#[allow(deprecated)] #[allow(deprecated)]
fn test_weighted_clone_initialization() { fn test_weighted_clone_initialization() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1}; let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
let clone = initial.clone(); let clone = initial.clone();
assert_eq!(initial.weight, clone.weight); assert_eq!(initial.weight, clone.weight);
assert_eq!(initial.item, clone.item); assert_eq!(initial.item, clone.item);
} }
#[test] #[should_panic] #[test]
#[should_panic]
#[allow(deprecated)] #[allow(deprecated)]
fn test_weighted_clone_change_weight() { fn test_weighted_clone_change_weight() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1}; let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
let mut clone = initial.clone(); let mut clone = initial.clone();
clone.weight = 5; clone.weight = 5;
assert_eq!(initial.weight, clone.weight); assert_eq!(initial.weight, clone.weight);
} }
#[test] #[should_panic] #[test]
#[should_panic]
#[allow(deprecated)] #[allow(deprecated)]
fn test_weighted_clone_change_item() { fn test_weighted_clone_change_item() {
let initial : Weighted<u32> = Weighted {weight: 1, item: 1}; let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
let mut clone = initial.clone(); let mut clone = initial.clone();
clone.item = 5; clone.item = 5;
assert_eq!(initial.item, clone.item); assert_eq!(initial.item, clone.item);
} }
#[test] #[should_panic] #[test]
#[should_panic]
#[allow(deprecated)] #[allow(deprecated)]
fn test_weighted_choice_no_items() { fn test_weighted_choice_no_items() {
WeightedChoice::<isize>::new(&mut []); WeightedChoice::<isize>::new(&mut []);
} }
#[test] #[should_panic] #[test]
#[should_panic]
#[allow(deprecated)] #[allow(deprecated)]
fn test_weighted_choice_zero_weight() { fn test_weighted_choice_zero_weight() {
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0}, WeightedChoice::new(&mut [
Weighted { weight: 0, item: 1}]); Weighted { weight: 0, item: 0 },
Weighted { weight: 0, item: 1 },
]);
} }
#[test] #[should_panic] #[test]
#[should_panic]
#[allow(deprecated)] #[allow(deprecated)]
fn test_weighted_choice_weight_overflows() { fn test_weighted_choice_weight_overflows() {
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 }, WeightedChoice::new(&mut [
Weighted { weight: 1, item: 1 }, Weighted { weight: x, item: 0 },
Weighted { weight: x, item: 2 }, Weighted { weight: 1, item: 1 },
Weighted { weight: 1, item: 3 }]); Weighted { weight: x, item: 2 },
Weighted { weight: 1, item: 3 },
]);
} }
#[cfg(feature="std")] #[cfg(feature = "std")]
#[test] #[test]
fn test_distributions_iter() { fn test_distributions_iter() {
use distributions::Normal; use distributions::Normal;
+55 -25
View File
@@ -6,14 +6,15 @@
// option. This file may not be copied, modified, or distributed // option. This file may not be copied, modified, or distributed
// except according to those terms. // except according to those terms.
use Rng;
use distributions::Distribution;
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
use ::core::cmp::PartialOrd; use ::core::cmp::PartialOrd;
use core::fmt; use core::fmt;
use distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
use distributions::Distribution;
use Rng;
// Note that this whole module is only imported if feature="alloc" is enabled. // Note that this whole module is only imported if feature="alloc" is enabled.
#[cfg(not(feature="std"))] use alloc::vec::Vec; #[cfg(not(feature = "std"))]
use alloc::vec::Vec;
/// A distribution using weighted sampling to pick a discretely selected /// A distribution using weighted sampling to pick a discretely selected
/// item. /// item.
@@ -87,16 +88,13 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
/// [`Distribution`]: trait.Distribution.html /// [`Distribution`]: trait.Distribution.html
/// [`Uniform<X>`]: struct.Uniform.html /// [`Uniform<X>`]: struct.Uniform.html
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError> pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
where I: IntoIterator, where
I::Item: SampleBorrow<X>, I: IntoIterator,
X: for<'a> ::core::ops::AddAssign<&'a X> + I::Item: SampleBorrow<X>,
Clone + X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
Default { {
let mut iter = weights.into_iter(); let mut iter = weights.into_iter();
let mut total_weight: X = iter.next() let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
.ok_or(WeightedError::NoItem)?
.borrow()
.clone();
let zero = <X as Default>::default(); let zero = <X as Default>::default();
if total_weight < zero { if total_weight < zero {
@@ -117,18 +115,30 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
} }
let distr = X::Sampler::new(zero, total_weight); let distr = X::Sampler::new(zero, total_weight);
Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr }) Ok(WeightedIndex {
cumulative_weights: weights,
weight_distribution: distr,
})
} }
} }
impl<X> Distribution<usize> for WeightedIndex<X> where impl<X> Distribution<usize> for WeightedIndex<X>
X: SampleUniform + PartialOrd { where
X: SampleUniform + PartialOrd,
{
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize { fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
use ::core::cmp::Ordering; use ::core::cmp::Ordering;
let chosen_weight = self.weight_distribution.sample(rng); let chosen_weight = self.weight_distribution.sample(rng);
// Find the first item which has a weight *higher* than the chosen weight. // Find the first item which has a weight *higher* than the chosen weight.
self.cumulative_weights.binary_search_by( self.cumulative_weights
|w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err() .binary_search_by(|w| {
if *w <= chosen_weight {
Ordering::Less
} else {
Ordering::Greater
}
})
.unwrap_err()
} }
} }
@@ -181,14 +191,34 @@ mod test {
for _ in 0..5 { for _ in 0..5 {
assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1); assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0); assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4); assert_eq!(
WeightedIndex::new(&[0, 0, 0, 0, 10, 0])
.unwrap()
.sample(&mut r),
4
);
} }
assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem); assert_eq!(
assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero); WeightedIndex::new(&[10][0..0]).unwrap_err(),
assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight); WeightedError::NoItem
assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight); );
assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight); assert_eq!(
WeightedIndex::new(&[0]).unwrap_err(),
WeightedError::AllWeightsZero
);
assert_eq!(
WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(),
WeightedError::NegativeWeight
);
assert_eq!(
WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(),
WeightedError::NegativeWeight
);
assert_eq!(
WeightedIndex::new(&[-10]).unwrap_err(),
WeightedError::NegativeWeight
);
} }
} }
@@ -215,7 +245,7 @@ impl WeightedError {
} }
} }
#[cfg(feature="std")] #[cfg(feature = "std")]
impl ::std::error::Error for WeightedError { impl ::std::error::Error for WeightedError {
fn description(&self) -> &str { fn description(&self) -> &str {
self.msg() self.msg()