alnyan/yggdrasil: patch getrandom url
This commit is contained in:
parent
4336232dda
commit
b16da2bdc9
@ -26,3 +26,4 @@ serde1 = ["serde", "serde_derive"] # enables serde for BlockRng wrapper
|
||||
[dependencies]
|
||||
serde = { version = "1", optional = true }
|
||||
serde_derive = { version = "^1.0.38", optional = true }
|
||||
getrandom = { version = "0.2", git = "https://git.alnyan.me/yggdrasil/getrandom.git", branch = "alnyan/yggdrasil", optional = true }
|
||||
|
@ -14,23 +14,23 @@ use rand_core::{Error, ErrorKind};
|
||||
use std::fs::File;
|
||||
use std::io;
|
||||
use std::io::Read;
|
||||
use std::sync::{Once, Mutex, ONCE_INIT};
|
||||
use std::sync::{Mutex, Once};
|
||||
|
||||
// TODO: remove outer Option when `Mutex::new(None)` is a constant expression
|
||||
static mut READ_RNG_FILE: Option<Mutex<Option<File>>> = None;
|
||||
static READ_RNG_ONCE: Once = ONCE_INIT;
|
||||
static READ_RNG_ONCE: Once = Once::new();
|
||||
|
||||
#[allow(unused)]
|
||||
pub fn open<F>(path: &'static str, open_fn: F) -> Result<(), Error>
|
||||
where F: Fn(&'static str) -> Result<File, io::Error>
|
||||
where
|
||||
F: Fn(&'static str) -> Result<File, io::Error>,
|
||||
{
|
||||
READ_RNG_ONCE.call_once(|| {
|
||||
unsafe { READ_RNG_FILE = Some(Mutex::new(None)) }
|
||||
});
|
||||
READ_RNG_ONCE.call_once(|| unsafe { READ_RNG_FILE = Some(Mutex::new(None)) });
|
||||
|
||||
// We try opening the file outside the `call_once` fn because we cannot
|
||||
// clone the error, thus we must retry on failure.
|
||||
|
||||
#[allow(static_mut_refs)]
|
||||
let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() };
|
||||
let mut guard = mutex.lock().unwrap();
|
||||
if (*guard).is_none() {
|
||||
@ -45,26 +45,27 @@ pub fn read(dest: &mut [u8]) -> Result<(), Error> {
|
||||
// We expect this function only to be used after `random_device::open`
|
||||
// was succesful. Therefore we can assume that our memory was set with a
|
||||
// valid object.
|
||||
#[allow(static_mut_refs)]
|
||||
let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() };
|
||||
let mut guard = mutex.lock().unwrap();
|
||||
let file = (*guard).as_mut().unwrap();
|
||||
|
||||
// Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`.
|
||||
file.read_exact(dest).map_err(|err| {
|
||||
Error::with_cause(ErrorKind::Unavailable,
|
||||
"error reading random device", err)
|
||||
Error::with_cause(ErrorKind::Unavailable, "error reading random device", err)
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
pub fn map_err(err: io::Error) -> Error {
|
||||
match err.kind() {
|
||||
io::ErrorKind::Interrupted =>
|
||||
Error::new(ErrorKind::Transient, "interrupted"),
|
||||
io::ErrorKind::WouldBlock =>
|
||||
Error::with_cause(ErrorKind::NotReady,
|
||||
"OS RNG not yet seeded", err),
|
||||
_ => Error::with_cause(ErrorKind::Unavailable,
|
||||
"error while opening random device", err)
|
||||
io::ErrorKind::Interrupted => Error::new(ErrorKind::Transient, "interrupted"),
|
||||
io::ErrorKind::WouldBlock => {
|
||||
Error::with_cause(ErrorKind::NotReady, "OS RNG not yet seeded", err)
|
||||
}
|
||||
_ => Error::with_cause(
|
||||
ErrorKind::Unavailable,
|
||||
"error while opening random device",
|
||||
err,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
@ -182,50 +182,77 @@
|
||||
//! [`Weibull`]: struct.Weibull.html
|
||||
//! [`WeightedIndex`]: struct.WeightedIndex.html
|
||||
|
||||
#[cfg(any(rustc_1_26, features="nightly"))]
|
||||
#[cfg(any(rustc_1_26, features = "nightly"))]
|
||||
use core::iter;
|
||||
use Rng;
|
||||
|
||||
pub use self::other::Alphanumeric;
|
||||
#[doc(inline)] pub use self::uniform::Uniform;
|
||||
pub use self::float::{OpenClosed01, Open01};
|
||||
pub use self::bernoulli::Bernoulli;
|
||||
#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError};
|
||||
#[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface;
|
||||
#[cfg(feature="std")] pub use self::unit_circle::UnitCircle;
|
||||
#[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF,
|
||||
StudentT, Beta};
|
||||
#[cfg(feature="std")] pub use self::normal::{Normal, LogNormal, StandardNormal};
|
||||
#[cfg(feature="std")] pub use self::exponential::{Exp, Exp1};
|
||||
#[cfg(feature="std")] pub use self::pareto::Pareto;
|
||||
#[cfg(feature="std")] pub use self::poisson::Poisson;
|
||||
#[cfg(feature="std")] pub use self::binomial::Binomial;
|
||||
#[cfg(feature="std")] pub use self::cauchy::Cauchy;
|
||||
#[cfg(feature="std")] pub use self::dirichlet::Dirichlet;
|
||||
#[cfg(feature="std")] pub use self::triangular::Triangular;
|
||||
#[cfg(feature="std")] pub use self::weibull::Weibull;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::binomial::Binomial;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::cauchy::Cauchy;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::dirichlet::Dirichlet;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::exponential::{Exp, Exp1};
|
||||
pub use self::float::{Open01, OpenClosed01};
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::gamma::{Beta, ChiSquared, FisherF, Gamma, StudentT};
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::normal::{LogNormal, Normal, StandardNormal};
|
||||
pub use self::other::Alphanumeric;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::pareto::Pareto;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::poisson::Poisson;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::triangular::Triangular;
|
||||
#[doc(inline)]
|
||||
pub use self::uniform::Uniform;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::unit_circle::UnitCircle;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::unit_sphere::UnitSphereSurface;
|
||||
#[cfg(feature = "std")]
|
||||
pub use self::weibull::Weibull;
|
||||
#[cfg(feature = "alloc")]
|
||||
pub use self::weighted::{WeightedError, WeightedIndex};
|
||||
|
||||
pub mod uniform;
|
||||
mod bernoulli;
|
||||
#[cfg(feature="alloc")] mod weighted;
|
||||
#[cfg(feature="std")] mod unit_sphere;
|
||||
#[cfg(feature="std")] mod unit_circle;
|
||||
#[cfg(feature="std")] mod gamma;
|
||||
#[cfg(feature="std")] mod normal;
|
||||
#[cfg(feature="std")] mod exponential;
|
||||
#[cfg(feature="std")] mod pareto;
|
||||
#[cfg(feature="std")] mod poisson;
|
||||
#[cfg(feature="std")] mod binomial;
|
||||
#[cfg(feature="std")] mod cauchy;
|
||||
#[cfg(feature="std")] mod dirichlet;
|
||||
#[cfg(feature="std")] mod triangular;
|
||||
#[cfg(feature="std")] mod weibull;
|
||||
#[cfg(feature = "std")]
|
||||
mod binomial;
|
||||
#[cfg(feature = "std")]
|
||||
mod cauchy;
|
||||
#[cfg(feature = "std")]
|
||||
mod dirichlet;
|
||||
#[cfg(feature = "std")]
|
||||
mod exponential;
|
||||
#[cfg(feature = "std")]
|
||||
mod gamma;
|
||||
#[cfg(feature = "std")]
|
||||
mod normal;
|
||||
#[cfg(feature = "std")]
|
||||
mod pareto;
|
||||
#[cfg(feature = "std")]
|
||||
mod poisson;
|
||||
#[cfg(feature = "std")]
|
||||
mod triangular;
|
||||
pub mod uniform;
|
||||
#[cfg(feature = "std")]
|
||||
mod unit_circle;
|
||||
#[cfg(feature = "std")]
|
||||
mod unit_sphere;
|
||||
#[cfg(feature = "std")]
|
||||
mod weibull;
|
||||
#[cfg(feature = "alloc")]
|
||||
mod weighted;
|
||||
|
||||
mod float;
|
||||
mod integer;
|
||||
mod other;
|
||||
mod utils;
|
||||
#[cfg(feature="std")] mod ziggurat_tables;
|
||||
#[cfg(feature = "std")]
|
||||
mod ziggurat_tables;
|
||||
|
||||
/// Types (distributions) that can be used to create a random instance of `T`.
|
||||
///
|
||||
@ -269,7 +296,9 @@ pub trait Distribution<T> {
|
||||
/// }
|
||||
/// ```
|
||||
fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T>
|
||||
where Self: Sized, R: Rng
|
||||
where
|
||||
Self: Sized,
|
||||
R: Rng,
|
||||
{
|
||||
DistIter {
|
||||
distr: self,
|
||||
@ -285,7 +314,6 @@ impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
/// An iterator that generates random values of `T` with distribution `D`,
|
||||
/// using `R` as the source of randomness.
|
||||
///
|
||||
@ -302,7 +330,9 @@ pub struct DistIter<'a, D: 'a, R: 'a, T> {
|
||||
}
|
||||
|
||||
impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
|
||||
where D: Distribution<T>, R: Rng + 'a
|
||||
where
|
||||
D: Distribution<T>,
|
||||
R: Rng + 'a,
|
||||
{
|
||||
type Item = T;
|
||||
|
||||
@ -318,12 +348,19 @@ impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
|
||||
|
||||
#[cfg(rustc_1_26)]
|
||||
impl<'a, D, R, T> iter::FusedIterator for DistIter<'a, D, R, T>
|
||||
where D: Distribution<T>, R: Rng + 'a {}
|
||||
where
|
||||
D: Distribution<T>,
|
||||
R: Rng + 'a,
|
||||
{
|
||||
}
|
||||
|
||||
#[cfg(features = "nightly")]
|
||||
impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
|
||||
where D: Distribution<T>, R: Rng + 'a {}
|
||||
|
||||
where
|
||||
D: Distribution<T>,
|
||||
R: Rng + 'a,
|
||||
{
|
||||
}
|
||||
|
||||
/// A generic random value distribution, implemented for many primitive types.
|
||||
/// Usually generates values with a numerically uniform distribution, and with a
|
||||
@ -385,9 +422,8 @@ impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
|
||||
#[derive(Clone, Copy, Debug)]
|
||||
pub struct Standard;
|
||||
|
||||
|
||||
/// A value with a particular weight for use with `WeightedChoice`.
|
||||
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
||||
#[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
|
||||
#[allow(deprecated)]
|
||||
#[derive(Copy, Clone, Debug)]
|
||||
pub struct Weighted<T> {
|
||||
@ -402,15 +438,15 @@ pub struct Weighted<T> {
|
||||
/// Deprecated: use [`WeightedIndex`] instead.
|
||||
///
|
||||
/// [`WeightedIndex`]: struct.WeightedIndex.html
|
||||
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
||||
#[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
|
||||
#[allow(deprecated)]
|
||||
#[derive(Debug)]
|
||||
pub struct WeightedChoice<'a, T:'a> {
|
||||
pub struct WeightedChoice<'a, T: 'a> {
|
||||
items: &'a mut [Weighted<T>],
|
||||
weight_range: Uniform<u32>,
|
||||
}
|
||||
|
||||
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
||||
#[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
|
||||
#[allow(deprecated)]
|
||||
impl<'a, T: Clone> WeightedChoice<'a, T> {
|
||||
/// Create a new `WeightedChoice`.
|
||||
@ -422,7 +458,10 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
|
||||
/// - the total weight is larger than a `u32` can contain.
|
||||
pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
|
||||
// strictly speaking, this is subsumed by the total weight == 0 case
|
||||
assert!(!items.is_empty(), "WeightedChoice::new called with no items");
|
||||
assert!(
|
||||
!items.is_empty(),
|
||||
"WeightedChoice::new called with no items"
|
||||
);
|
||||
|
||||
let mut running_total: u32 = 0;
|
||||
|
||||
@ -432,24 +471,29 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
|
||||
for item in items.iter_mut() {
|
||||
running_total = match running_total.checked_add(item.weight) {
|
||||
Some(n) => n,
|
||||
None => panic!("WeightedChoice::new called with a total weight \
|
||||
larger than a u32 can contain")
|
||||
None => panic!(
|
||||
"WeightedChoice::new called with a total weight \
|
||||
larger than a u32 can contain"
|
||||
),
|
||||
};
|
||||
|
||||
item.weight = running_total;
|
||||
}
|
||||
assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");
|
||||
assert!(
|
||||
running_total != 0,
|
||||
"WeightedChoice::new called with a total weight of 0"
|
||||
);
|
||||
|
||||
WeightedChoice {
|
||||
items,
|
||||
// we're likely to be generating numbers in this range
|
||||
// relatively often, so might as well cache it
|
||||
weight_range: Uniform::new(0, running_total)
|
||||
weight_range: Uniform::new(0, running_total),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
||||
// #[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
||||
#[allow(deprecated)]
|
||||
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
|
||||
@ -496,9 +540,9 @@ impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use rngs::mock::StepRng;
|
||||
#[allow(deprecated)]
|
||||
use super::{WeightedChoice, Weighted, Distribution};
|
||||
use super::{Distribution, Weighted, WeightedChoice};
|
||||
use rngs::mock::StepRng;
|
||||
|
||||
#[test]
|
||||
#[allow(deprecated)]
|
||||
@ -511,7 +555,9 @@ mod tests {
|
||||
($items:expr, $expected:expr) => {{
|
||||
let mut items = $items;
|
||||
let mut total_weight = 0;
|
||||
for item in &items { total_weight += item.weight; }
|
||||
for item in &items {
|
||||
total_weight += item.weight;
|
||||
}
|
||||
|
||||
let wc = WeightedChoice::new(&mut items);
|
||||
let expected = $expected;
|
||||
@ -524,92 +570,176 @@ mod tests {
|
||||
for &val in expected.iter() {
|
||||
assert_eq!(wc.sample(&mut rng), val)
|
||||
}
|
||||
}}
|
||||
}};
|
||||
}
|
||||
|
||||
t!([Weighted { weight: 1, item: 10}], [10]);
|
||||
t!(
|
||||
[Weighted {
|
||||
weight: 1,
|
||||
item: 10
|
||||
}],
|
||||
[10]
|
||||
);
|
||||
|
||||
// skip some
|
||||
t!([Weighted { weight: 0, item: 20},
|
||||
Weighted { weight: 2, item: 21},
|
||||
Weighted { weight: 0, item: 22},
|
||||
Weighted { weight: 1, item: 23}],
|
||||
[21, 21, 23]);
|
||||
t!(
|
||||
[
|
||||
Weighted {
|
||||
weight: 0,
|
||||
item: 20
|
||||
},
|
||||
Weighted {
|
||||
weight: 2,
|
||||
item: 21
|
||||
},
|
||||
Weighted {
|
||||
weight: 0,
|
||||
item: 22
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 23
|
||||
}
|
||||
],
|
||||
[21, 21, 23]
|
||||
);
|
||||
|
||||
// different weights
|
||||
t!([Weighted { weight: 4, item: 30},
|
||||
Weighted { weight: 3, item: 31}],
|
||||
[30, 31, 30, 31, 30, 31, 30]);
|
||||
t!(
|
||||
[
|
||||
Weighted {
|
||||
weight: 4,
|
||||
item: 30
|
||||
},
|
||||
Weighted {
|
||||
weight: 3,
|
||||
item: 31
|
||||
}
|
||||
],
|
||||
[30, 31, 30, 31, 30, 31, 30]
|
||||
);
|
||||
|
||||
// check that we're binary searching
|
||||
// correctly with some vectors of odd
|
||||
// length.
|
||||
t!([Weighted { weight: 1, item: 40},
|
||||
Weighted { weight: 1, item: 41},
|
||||
Weighted { weight: 1, item: 42},
|
||||
Weighted { weight: 1, item: 43},
|
||||
Weighted { weight: 1, item: 44}],
|
||||
[40, 41, 42, 43, 44]);
|
||||
t!([Weighted { weight: 1, item: 50},
|
||||
Weighted { weight: 1, item: 51},
|
||||
Weighted { weight: 1, item: 52},
|
||||
Weighted { weight: 1, item: 53},
|
||||
Weighted { weight: 1, item: 54},
|
||||
Weighted { weight: 1, item: 55},
|
||||
Weighted { weight: 1, item: 56}],
|
||||
[50, 54, 51, 55, 52, 56, 53]);
|
||||
t!(
|
||||
[
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 40
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 41
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 42
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 43
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 44
|
||||
}
|
||||
],
|
||||
[40, 41, 42, 43, 44]
|
||||
);
|
||||
t!(
|
||||
[
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 50
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 51
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 52
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 53
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 54
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 55
|
||||
},
|
||||
Weighted {
|
||||
weight: 1,
|
||||
item: 56
|
||||
}
|
||||
],
|
||||
[50, 54, 51, 55, 52, 56, 53]
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[allow(deprecated)]
|
||||
fn test_weighted_clone_initialization() {
|
||||
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
|
||||
let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
|
||||
let clone = initial.clone();
|
||||
assert_eq!(initial.weight, clone.weight);
|
||||
assert_eq!(initial.item, clone.item);
|
||||
}
|
||||
|
||||
#[test] #[should_panic]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
#[allow(deprecated)]
|
||||
fn test_weighted_clone_change_weight() {
|
||||
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
|
||||
let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
|
||||
let mut clone = initial.clone();
|
||||
clone.weight = 5;
|
||||
assert_eq!(initial.weight, clone.weight);
|
||||
}
|
||||
|
||||
#[test] #[should_panic]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
#[allow(deprecated)]
|
||||
fn test_weighted_clone_change_item() {
|
||||
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
|
||||
let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
|
||||
let mut clone = initial.clone();
|
||||
clone.item = 5;
|
||||
assert_eq!(initial.item, clone.item);
|
||||
|
||||
}
|
||||
|
||||
#[test] #[should_panic]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
#[allow(deprecated)]
|
||||
fn test_weighted_choice_no_items() {
|
||||
WeightedChoice::<isize>::new(&mut []);
|
||||
}
|
||||
#[test] #[should_panic]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
#[allow(deprecated)]
|
||||
fn test_weighted_choice_zero_weight() {
|
||||
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
|
||||
Weighted { weight: 0, item: 1}]);
|
||||
WeightedChoice::new(&mut [
|
||||
Weighted { weight: 0, item: 0 },
|
||||
Weighted { weight: 0, item: 1 },
|
||||
]);
|
||||
}
|
||||
#[test] #[should_panic]
|
||||
#[test]
|
||||
#[should_panic]
|
||||
#[allow(deprecated)]
|
||||
fn test_weighted_choice_weight_overflows() {
|
||||
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
|
||||
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
|
||||
Weighted { weight: 1, item: 1 },
|
||||
Weighted { weight: x, item: 2 },
|
||||
Weighted { weight: 1, item: 3 }]);
|
||||
WeightedChoice::new(&mut [
|
||||
Weighted { weight: x, item: 0 },
|
||||
Weighted { weight: 1, item: 1 },
|
||||
Weighted { weight: x, item: 2 },
|
||||
Weighted { weight: 1, item: 3 },
|
||||
]);
|
||||
}
|
||||
|
||||
#[cfg(feature="std")]
|
||||
#[cfg(feature = "std")]
|
||||
#[test]
|
||||
fn test_distributions_iter() {
|
||||
use distributions::Normal;
|
||||
|
@ -6,14 +6,15 @@
|
||||
// option. This file may not be copied, modified, or distributed
|
||||
// except according to those terms.
|
||||
|
||||
use Rng;
|
||||
use distributions::Distribution;
|
||||
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
|
||||
use ::core::cmp::PartialOrd;
|
||||
use core::fmt;
|
||||
use distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
|
||||
use distributions::Distribution;
|
||||
use Rng;
|
||||
|
||||
// Note that this whole module is only imported if feature="alloc" is enabled.
|
||||
#[cfg(not(feature="std"))] use alloc::vec::Vec;
|
||||
#[cfg(not(feature = "std"))]
|
||||
use alloc::vec::Vec;
|
||||
|
||||
/// A distribution using weighted sampling to pick a discretely selected
|
||||
/// item.
|
||||
@ -87,16 +88,13 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
|
||||
/// [`Distribution`]: trait.Distribution.html
|
||||
/// [`Uniform<X>`]: struct.Uniform.html
|
||||
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
|
||||
where I: IntoIterator,
|
||||
I::Item: SampleBorrow<X>,
|
||||
X: for<'a> ::core::ops::AddAssign<&'a X> +
|
||||
Clone +
|
||||
Default {
|
||||
where
|
||||
I: IntoIterator,
|
||||
I::Item: SampleBorrow<X>,
|
||||
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
|
||||
{
|
||||
let mut iter = weights.into_iter();
|
||||
let mut total_weight: X = iter.next()
|
||||
.ok_or(WeightedError::NoItem)?
|
||||
.borrow()
|
||||
.clone();
|
||||
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
|
||||
|
||||
let zero = <X as Default>::default();
|
||||
if total_weight < zero {
|
||||
@ -117,18 +115,30 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
|
||||
}
|
||||
let distr = X::Sampler::new(zero, total_weight);
|
||||
|
||||
Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
|
||||
Ok(WeightedIndex {
|
||||
cumulative_weights: weights,
|
||||
weight_distribution: distr,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl<X> Distribution<usize> for WeightedIndex<X> where
|
||||
X: SampleUniform + PartialOrd {
|
||||
impl<X> Distribution<usize> for WeightedIndex<X>
|
||||
where
|
||||
X: SampleUniform + PartialOrd,
|
||||
{
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
|
||||
use ::core::cmp::Ordering;
|
||||
let chosen_weight = self.weight_distribution.sample(rng);
|
||||
// Find the first item which has a weight *higher* than the chosen weight.
|
||||
self.cumulative_weights.binary_search_by(
|
||||
|w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
|
||||
self.cumulative_weights
|
||||
.binary_search_by(|w| {
|
||||
if *w <= chosen_weight {
|
||||
Ordering::Less
|
||||
} else {
|
||||
Ordering::Greater
|
||||
}
|
||||
})
|
||||
.unwrap_err()
|
||||
}
|
||||
}
|
||||
|
||||
@ -181,14 +191,34 @@ mod test {
|
||||
for _ in 0..5 {
|
||||
assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
|
||||
assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
|
||||
assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4);
|
||||
assert_eq!(
|
||||
WeightedIndex::new(&[0, 0, 0, 0, 10, 0])
|
||||
.unwrap()
|
||||
.sample(&mut r),
|
||||
4
|
||||
);
|
||||
}
|
||||
|
||||
assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem);
|
||||
assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero);
|
||||
assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight);
|
||||
assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight);
|
||||
assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight);
|
||||
assert_eq!(
|
||||
WeightedIndex::new(&[10][0..0]).unwrap_err(),
|
||||
WeightedError::NoItem
|
||||
);
|
||||
assert_eq!(
|
||||
WeightedIndex::new(&[0]).unwrap_err(),
|
||||
WeightedError::AllWeightsZero
|
||||
);
|
||||
assert_eq!(
|
||||
WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(),
|
||||
WeightedError::NegativeWeight
|
||||
);
|
||||
assert_eq!(
|
||||
WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(),
|
||||
WeightedError::NegativeWeight
|
||||
);
|
||||
assert_eq!(
|
||||
WeightedIndex::new(&[-10]).unwrap_err(),
|
||||
WeightedError::NegativeWeight
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@ -215,7 +245,7 @@ impl WeightedError {
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(feature="std")]
|
||||
#[cfg(feature = "std")]
|
||||
impl ::std::error::Error for WeightedError {
|
||||
fn description(&self) -> &str {
|
||||
self.msg()
|
||||
|
Loading…
x
Reference in New Issue
Block a user