alnyan/yggdrasil: patch getrandom url
This commit is contained in:
@@ -26,3 +26,4 @@ serde1 = ["serde", "serde_derive"] # enables serde for BlockRng wrapper
|
|||||||
[dependencies]
|
[dependencies]
|
||||||
serde = { version = "1", optional = true }
|
serde = { version = "1", optional = true }
|
||||||
serde_derive = { version = "^1.0.38", optional = true }
|
serde_derive = { version = "^1.0.38", optional = true }
|
||||||
|
getrandom = { version = "0.2", git = "https://git.alnyan.me/yggdrasil/getrandom.git", branch = "alnyan/yggdrasil", optional = true }
|
||||||
|
|||||||
@@ -14,23 +14,23 @@ use rand_core::{Error, ErrorKind};
|
|||||||
use std::fs::File;
|
use std::fs::File;
|
||||||
use std::io;
|
use std::io;
|
||||||
use std::io::Read;
|
use std::io::Read;
|
||||||
use std::sync::{Once, Mutex, ONCE_INIT};
|
use std::sync::{Mutex, Once};
|
||||||
|
|
||||||
// TODO: remove outer Option when `Mutex::new(None)` is a constant expression
|
// TODO: remove outer Option when `Mutex::new(None)` is a constant expression
|
||||||
static mut READ_RNG_FILE: Option<Mutex<Option<File>>> = None;
|
static mut READ_RNG_FILE: Option<Mutex<Option<File>>> = None;
|
||||||
static READ_RNG_ONCE: Once = ONCE_INIT;
|
static READ_RNG_ONCE: Once = Once::new();
|
||||||
|
|
||||||
#[allow(unused)]
|
#[allow(unused)]
|
||||||
pub fn open<F>(path: &'static str, open_fn: F) -> Result<(), Error>
|
pub fn open<F>(path: &'static str, open_fn: F) -> Result<(), Error>
|
||||||
where F: Fn(&'static str) -> Result<File, io::Error>
|
where
|
||||||
|
F: Fn(&'static str) -> Result<File, io::Error>,
|
||||||
{
|
{
|
||||||
READ_RNG_ONCE.call_once(|| {
|
READ_RNG_ONCE.call_once(|| unsafe { READ_RNG_FILE = Some(Mutex::new(None)) });
|
||||||
unsafe { READ_RNG_FILE = Some(Mutex::new(None)) }
|
|
||||||
});
|
|
||||||
|
|
||||||
// We try opening the file outside the `call_once` fn because we cannot
|
// We try opening the file outside the `call_once` fn because we cannot
|
||||||
// clone the error, thus we must retry on failure.
|
// clone the error, thus we must retry on failure.
|
||||||
|
|
||||||
|
#[allow(static_mut_refs)]
|
||||||
let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() };
|
let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() };
|
||||||
let mut guard = mutex.lock().unwrap();
|
let mut guard = mutex.lock().unwrap();
|
||||||
if (*guard).is_none() {
|
if (*guard).is_none() {
|
||||||
@@ -45,26 +45,27 @@ pub fn read(dest: &mut [u8]) -> Result<(), Error> {
|
|||||||
// We expect this function only to be used after `random_device::open`
|
// We expect this function only to be used after `random_device::open`
|
||||||
// was succesful. Therefore we can assume that our memory was set with a
|
// was succesful. Therefore we can assume that our memory was set with a
|
||||||
// valid object.
|
// valid object.
|
||||||
|
#[allow(static_mut_refs)]
|
||||||
let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() };
|
let mutex = unsafe { READ_RNG_FILE.as_ref().unwrap() };
|
||||||
let mut guard = mutex.lock().unwrap();
|
let mut guard = mutex.lock().unwrap();
|
||||||
let file = (*guard).as_mut().unwrap();
|
let file = (*guard).as_mut().unwrap();
|
||||||
|
|
||||||
// Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`.
|
// Use `std::io::read_exact`, which retries on `ErrorKind::Interrupted`.
|
||||||
file.read_exact(dest).map_err(|err| {
|
file.read_exact(dest).map_err(|err| {
|
||||||
Error::with_cause(ErrorKind::Unavailable,
|
Error::with_cause(ErrorKind::Unavailable, "error reading random device", err)
|
||||||
"error reading random device", err)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn map_err(err: io::Error) -> Error {
|
pub fn map_err(err: io::Error) -> Error {
|
||||||
match err.kind() {
|
match err.kind() {
|
||||||
io::ErrorKind::Interrupted =>
|
io::ErrorKind::Interrupted => Error::new(ErrorKind::Transient, "interrupted"),
|
||||||
Error::new(ErrorKind::Transient, "interrupted"),
|
io::ErrorKind::WouldBlock => {
|
||||||
io::ErrorKind::WouldBlock =>
|
Error::with_cause(ErrorKind::NotReady, "OS RNG not yet seeded", err)
|
||||||
Error::with_cause(ErrorKind::NotReady,
|
}
|
||||||
"OS RNG not yet seeded", err),
|
_ => Error::with_cause(
|
||||||
_ => Error::with_cause(ErrorKind::Unavailable,
|
ErrorKind::Unavailable,
|
||||||
"error while opening random device", err)
|
"error while opening random device",
|
||||||
|
err,
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
+223
-93
@@ -182,50 +182,77 @@
|
|||||||
//! [`Weibull`]: struct.Weibull.html
|
//! [`Weibull`]: struct.Weibull.html
|
||||||
//! [`WeightedIndex`]: struct.WeightedIndex.html
|
//! [`WeightedIndex`]: struct.WeightedIndex.html
|
||||||
|
|
||||||
#[cfg(any(rustc_1_26, features="nightly"))]
|
#[cfg(any(rustc_1_26, features = "nightly"))]
|
||||||
use core::iter;
|
use core::iter;
|
||||||
use Rng;
|
use Rng;
|
||||||
|
|
||||||
pub use self::other::Alphanumeric;
|
|
||||||
#[doc(inline)] pub use self::uniform::Uniform;
|
|
||||||
pub use self::float::{OpenClosed01, Open01};
|
|
||||||
pub use self::bernoulli::Bernoulli;
|
pub use self::bernoulli::Bernoulli;
|
||||||
#[cfg(feature="alloc")] pub use self::weighted::{WeightedIndex, WeightedError};
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] pub use self::unit_sphere::UnitSphereSurface;
|
pub use self::binomial::Binomial;
|
||||||
#[cfg(feature="std")] pub use self::unit_circle::UnitCircle;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] pub use self::gamma::{Gamma, ChiSquared, FisherF,
|
pub use self::cauchy::Cauchy;
|
||||||
StudentT, Beta};
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] pub use self::normal::{Normal, LogNormal, StandardNormal};
|
pub use self::dirichlet::Dirichlet;
|
||||||
#[cfg(feature="std")] pub use self::exponential::{Exp, Exp1};
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] pub use self::pareto::Pareto;
|
pub use self::exponential::{Exp, Exp1};
|
||||||
#[cfg(feature="std")] pub use self::poisson::Poisson;
|
pub use self::float::{Open01, OpenClosed01};
|
||||||
#[cfg(feature="std")] pub use self::binomial::Binomial;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] pub use self::cauchy::Cauchy;
|
pub use self::gamma::{Beta, ChiSquared, FisherF, Gamma, StudentT};
|
||||||
#[cfg(feature="std")] pub use self::dirichlet::Dirichlet;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] pub use self::triangular::Triangular;
|
pub use self::normal::{LogNormal, Normal, StandardNormal};
|
||||||
#[cfg(feature="std")] pub use self::weibull::Weibull;
|
pub use self::other::Alphanumeric;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub use self::pareto::Pareto;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub use self::poisson::Poisson;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub use self::triangular::Triangular;
|
||||||
|
#[doc(inline)]
|
||||||
|
pub use self::uniform::Uniform;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub use self::unit_circle::UnitCircle;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub use self::unit_sphere::UnitSphereSurface;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
pub use self::weibull::Weibull;
|
||||||
|
#[cfg(feature = "alloc")]
|
||||||
|
pub use self::weighted::{WeightedError, WeightedIndex};
|
||||||
|
|
||||||
pub mod uniform;
|
|
||||||
mod bernoulli;
|
mod bernoulli;
|
||||||
#[cfg(feature="alloc")] mod weighted;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] mod unit_sphere;
|
mod binomial;
|
||||||
#[cfg(feature="std")] mod unit_circle;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] mod gamma;
|
mod cauchy;
|
||||||
#[cfg(feature="std")] mod normal;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] mod exponential;
|
mod dirichlet;
|
||||||
#[cfg(feature="std")] mod pareto;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] mod poisson;
|
mod exponential;
|
||||||
#[cfg(feature="std")] mod binomial;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] mod cauchy;
|
mod gamma;
|
||||||
#[cfg(feature="std")] mod dirichlet;
|
#[cfg(feature = "std")]
|
||||||
#[cfg(feature="std")] mod triangular;
|
mod normal;
|
||||||
#[cfg(feature="std")] mod weibull;
|
#[cfg(feature = "std")]
|
||||||
|
mod pareto;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
mod poisson;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
mod triangular;
|
||||||
|
pub mod uniform;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
mod unit_circle;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
mod unit_sphere;
|
||||||
|
#[cfg(feature = "std")]
|
||||||
|
mod weibull;
|
||||||
|
#[cfg(feature = "alloc")]
|
||||||
|
mod weighted;
|
||||||
|
|
||||||
mod float;
|
mod float;
|
||||||
mod integer;
|
mod integer;
|
||||||
mod other;
|
mod other;
|
||||||
mod utils;
|
mod utils;
|
||||||
#[cfg(feature="std")] mod ziggurat_tables;
|
#[cfg(feature = "std")]
|
||||||
|
mod ziggurat_tables;
|
||||||
|
|
||||||
/// Types (distributions) that can be used to create a random instance of `T`.
|
/// Types (distributions) that can be used to create a random instance of `T`.
|
||||||
///
|
///
|
||||||
@@ -269,7 +296,9 @@ pub trait Distribution<T> {
|
|||||||
/// }
|
/// }
|
||||||
/// ```
|
/// ```
|
||||||
fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T>
|
fn sample_iter<'a, R>(&'a self, rng: &'a mut R) -> DistIter<'a, Self, R, T>
|
||||||
where Self: Sized, R: Rng
|
where
|
||||||
|
Self: Sized,
|
||||||
|
R: Rng,
|
||||||
{
|
{
|
||||||
DistIter {
|
DistIter {
|
||||||
distr: self,
|
distr: self,
|
||||||
@@ -285,7 +314,6 @@ impl<'a, T, D: Distribution<T>> Distribution<T> for &'a D {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
/// An iterator that generates random values of `T` with distribution `D`,
|
/// An iterator that generates random values of `T` with distribution `D`,
|
||||||
/// using `R` as the source of randomness.
|
/// using `R` as the source of randomness.
|
||||||
///
|
///
|
||||||
@@ -302,7 +330,9 @@ pub struct DistIter<'a, D: 'a, R: 'a, T> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
|
impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
|
||||||
where D: Distribution<T>, R: Rng + 'a
|
where
|
||||||
|
D: Distribution<T>,
|
||||||
|
R: Rng + 'a,
|
||||||
{
|
{
|
||||||
type Item = T;
|
type Item = T;
|
||||||
|
|
||||||
@@ -318,12 +348,19 @@ impl<'a, D, R, T> Iterator for DistIter<'a, D, R, T>
|
|||||||
|
|
||||||
#[cfg(rustc_1_26)]
|
#[cfg(rustc_1_26)]
|
||||||
impl<'a, D, R, T> iter::FusedIterator for DistIter<'a, D, R, T>
|
impl<'a, D, R, T> iter::FusedIterator for DistIter<'a, D, R, T>
|
||||||
where D: Distribution<T>, R: Rng + 'a {}
|
where
|
||||||
|
D: Distribution<T>,
|
||||||
|
R: Rng + 'a,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(features = "nightly")]
|
#[cfg(features = "nightly")]
|
||||||
impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
|
impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
|
||||||
where D: Distribution<T>, R: Rng + 'a {}
|
where
|
||||||
|
D: Distribution<T>,
|
||||||
|
R: Rng + 'a,
|
||||||
|
{
|
||||||
|
}
|
||||||
|
|
||||||
/// A generic random value distribution, implemented for many primitive types.
|
/// A generic random value distribution, implemented for many primitive types.
|
||||||
/// Usually generates values with a numerically uniform distribution, and with a
|
/// Usually generates values with a numerically uniform distribution, and with a
|
||||||
@@ -385,9 +422,8 @@ impl<'a, D, R, T> iter::TrustedLen for DistIter<'a, D, R, T>
|
|||||||
#[derive(Clone, Copy, Debug)]
|
#[derive(Clone, Copy, Debug)]
|
||||||
pub struct Standard;
|
pub struct Standard;
|
||||||
|
|
||||||
|
|
||||||
/// A value with a particular weight for use with `WeightedChoice`.
|
/// A value with a particular weight for use with `WeightedChoice`.
|
||||||
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
#[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
#[derive(Copy, Clone, Debug)]
|
#[derive(Copy, Clone, Debug)]
|
||||||
pub struct Weighted<T> {
|
pub struct Weighted<T> {
|
||||||
@@ -402,15 +438,15 @@ pub struct Weighted<T> {
|
|||||||
/// Deprecated: use [`WeightedIndex`] instead.
|
/// Deprecated: use [`WeightedIndex`] instead.
|
||||||
///
|
///
|
||||||
/// [`WeightedIndex`]: struct.WeightedIndex.html
|
/// [`WeightedIndex`]: struct.WeightedIndex.html
|
||||||
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
#[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct WeightedChoice<'a, T:'a> {
|
pub struct WeightedChoice<'a, T: 'a> {
|
||||||
items: &'a mut [Weighted<T>],
|
items: &'a mut [Weighted<T>],
|
||||||
weight_range: Uniform<u32>,
|
weight_range: Uniform<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
#[deprecated(since = "0.6.0", note = "use WeightedIndex instead")]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
impl<'a, T: Clone> WeightedChoice<'a, T> {
|
impl<'a, T: Clone> WeightedChoice<'a, T> {
|
||||||
/// Create a new `WeightedChoice`.
|
/// Create a new `WeightedChoice`.
|
||||||
@@ -422,7 +458,10 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
|
|||||||
/// - the total weight is larger than a `u32` can contain.
|
/// - the total weight is larger than a `u32` can contain.
|
||||||
pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
|
pub fn new(items: &'a mut [Weighted<T>]) -> WeightedChoice<'a, T> {
|
||||||
// strictly speaking, this is subsumed by the total weight == 0 case
|
// strictly speaking, this is subsumed by the total weight == 0 case
|
||||||
assert!(!items.is_empty(), "WeightedChoice::new called with no items");
|
assert!(
|
||||||
|
!items.is_empty(),
|
||||||
|
"WeightedChoice::new called with no items"
|
||||||
|
);
|
||||||
|
|
||||||
let mut running_total: u32 = 0;
|
let mut running_total: u32 = 0;
|
||||||
|
|
||||||
@@ -432,24 +471,29 @@ impl<'a, T: Clone> WeightedChoice<'a, T> {
|
|||||||
for item in items.iter_mut() {
|
for item in items.iter_mut() {
|
||||||
running_total = match running_total.checked_add(item.weight) {
|
running_total = match running_total.checked_add(item.weight) {
|
||||||
Some(n) => n,
|
Some(n) => n,
|
||||||
None => panic!("WeightedChoice::new called with a total weight \
|
None => panic!(
|
||||||
larger than a u32 can contain")
|
"WeightedChoice::new called with a total weight \
|
||||||
|
larger than a u32 can contain"
|
||||||
|
),
|
||||||
};
|
};
|
||||||
|
|
||||||
item.weight = running_total;
|
item.weight = running_total;
|
||||||
}
|
}
|
||||||
assert!(running_total != 0, "WeightedChoice::new called with a total weight of 0");
|
assert!(
|
||||||
|
running_total != 0,
|
||||||
|
"WeightedChoice::new called with a total weight of 0"
|
||||||
|
);
|
||||||
|
|
||||||
WeightedChoice {
|
WeightedChoice {
|
||||||
items,
|
items,
|
||||||
// we're likely to be generating numbers in this range
|
// we're likely to be generating numbers in this range
|
||||||
// relatively often, so might as well cache it
|
// relatively often, so might as well cache it
|
||||||
weight_range: Uniform::new(0, running_total)
|
weight_range: Uniform::new(0, running_total),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
// #[deprecated(since="0.6.0", note="use WeightedIndex instead")]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
|
impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> T {
|
||||||
@@ -496,9 +540,9 @@ impl<'a, T: Clone> Distribution<T> for WeightedChoice<'a, T> {
|
|||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use rngs::mock::StepRng;
|
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
use super::{WeightedChoice, Weighted, Distribution};
|
use super::{Distribution, Weighted, WeightedChoice};
|
||||||
|
use rngs::mock::StepRng;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
@@ -511,7 +555,9 @@ mod tests {
|
|||||||
($items:expr, $expected:expr) => {{
|
($items:expr, $expected:expr) => {{
|
||||||
let mut items = $items;
|
let mut items = $items;
|
||||||
let mut total_weight = 0;
|
let mut total_weight = 0;
|
||||||
for item in &items { total_weight += item.weight; }
|
for item in &items {
|
||||||
|
total_weight += item.weight;
|
||||||
|
}
|
||||||
|
|
||||||
let wc = WeightedChoice::new(&mut items);
|
let wc = WeightedChoice::new(&mut items);
|
||||||
let expected = $expected;
|
let expected = $expected;
|
||||||
@@ -524,92 +570,176 @@ mod tests {
|
|||||||
for &val in expected.iter() {
|
for &val in expected.iter() {
|
||||||
assert_eq!(wc.sample(&mut rng), val)
|
assert_eq!(wc.sample(&mut rng), val)
|
||||||
}
|
}
|
||||||
}}
|
}};
|
||||||
}
|
}
|
||||||
|
|
||||||
t!([Weighted { weight: 1, item: 10}], [10]);
|
t!(
|
||||||
|
[Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 10
|
||||||
|
}],
|
||||||
|
[10]
|
||||||
|
);
|
||||||
|
|
||||||
// skip some
|
// skip some
|
||||||
t!([Weighted { weight: 0, item: 20},
|
t!(
|
||||||
Weighted { weight: 2, item: 21},
|
[
|
||||||
Weighted { weight: 0, item: 22},
|
Weighted {
|
||||||
Weighted { weight: 1, item: 23}],
|
weight: 0,
|
||||||
[21, 21, 23]);
|
item: 20
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 2,
|
||||||
|
item: 21
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 0,
|
||||||
|
item: 22
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 23
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[21, 21, 23]
|
||||||
|
);
|
||||||
|
|
||||||
// different weights
|
// different weights
|
||||||
t!([Weighted { weight: 4, item: 30},
|
t!(
|
||||||
Weighted { weight: 3, item: 31}],
|
[
|
||||||
[30, 31, 30, 31, 30, 31, 30]);
|
Weighted {
|
||||||
|
weight: 4,
|
||||||
|
item: 30
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 3,
|
||||||
|
item: 31
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[30, 31, 30, 31, 30, 31, 30]
|
||||||
|
);
|
||||||
|
|
||||||
// check that we're binary searching
|
// check that we're binary searching
|
||||||
// correctly with some vectors of odd
|
// correctly with some vectors of odd
|
||||||
// length.
|
// length.
|
||||||
t!([Weighted { weight: 1, item: 40},
|
t!(
|
||||||
Weighted { weight: 1, item: 41},
|
[
|
||||||
Weighted { weight: 1, item: 42},
|
Weighted {
|
||||||
Weighted { weight: 1, item: 43},
|
weight: 1,
|
||||||
Weighted { weight: 1, item: 44}],
|
item: 40
|
||||||
[40, 41, 42, 43, 44]);
|
},
|
||||||
t!([Weighted { weight: 1, item: 50},
|
Weighted {
|
||||||
Weighted { weight: 1, item: 51},
|
weight: 1,
|
||||||
Weighted { weight: 1, item: 52},
|
item: 41
|
||||||
Weighted { weight: 1, item: 53},
|
},
|
||||||
Weighted { weight: 1, item: 54},
|
Weighted {
|
||||||
Weighted { weight: 1, item: 55},
|
weight: 1,
|
||||||
Weighted { weight: 1, item: 56}],
|
item: 42
|
||||||
[50, 54, 51, 55, 52, 56, 53]);
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 43
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 44
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[40, 41, 42, 43, 44]
|
||||||
|
);
|
||||||
|
t!(
|
||||||
|
[
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 50
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 51
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 52
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 53
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 54
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 55
|
||||||
|
},
|
||||||
|
Weighted {
|
||||||
|
weight: 1,
|
||||||
|
item: 56
|
||||||
|
}
|
||||||
|
],
|
||||||
|
[50, 54, 51, 55, 52, 56, 53]
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
fn test_weighted_clone_initialization() {
|
fn test_weighted_clone_initialization() {
|
||||||
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
|
let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
|
||||||
let clone = initial.clone();
|
let clone = initial.clone();
|
||||||
assert_eq!(initial.weight, clone.weight);
|
assert_eq!(initial.weight, clone.weight);
|
||||||
assert_eq!(initial.item, clone.item);
|
assert_eq!(initial.item, clone.item);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test] #[should_panic]
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
fn test_weighted_clone_change_weight() {
|
fn test_weighted_clone_change_weight() {
|
||||||
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
|
let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
|
||||||
let mut clone = initial.clone();
|
let mut clone = initial.clone();
|
||||||
clone.weight = 5;
|
clone.weight = 5;
|
||||||
assert_eq!(initial.weight, clone.weight);
|
assert_eq!(initial.weight, clone.weight);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test] #[should_panic]
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
fn test_weighted_clone_change_item() {
|
fn test_weighted_clone_change_item() {
|
||||||
let initial : Weighted<u32> = Weighted {weight: 1, item: 1};
|
let initial: Weighted<u32> = Weighted { weight: 1, item: 1 };
|
||||||
let mut clone = initial.clone();
|
let mut clone = initial.clone();
|
||||||
clone.item = 5;
|
clone.item = 5;
|
||||||
assert_eq!(initial.item, clone.item);
|
assert_eq!(initial.item, clone.item);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test] #[should_panic]
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
fn test_weighted_choice_no_items() {
|
fn test_weighted_choice_no_items() {
|
||||||
WeightedChoice::<isize>::new(&mut []);
|
WeightedChoice::<isize>::new(&mut []);
|
||||||
}
|
}
|
||||||
#[test] #[should_panic]
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
fn test_weighted_choice_zero_weight() {
|
fn test_weighted_choice_zero_weight() {
|
||||||
WeightedChoice::new(&mut [Weighted { weight: 0, item: 0},
|
WeightedChoice::new(&mut [
|
||||||
Weighted { weight: 0, item: 1}]);
|
Weighted { weight: 0, item: 0 },
|
||||||
|
Weighted { weight: 0, item: 1 },
|
||||||
|
]);
|
||||||
}
|
}
|
||||||
#[test] #[should_panic]
|
#[test]
|
||||||
|
#[should_panic]
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
fn test_weighted_choice_weight_overflows() {
|
fn test_weighted_choice_weight_overflows() {
|
||||||
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
|
let x = ::core::u32::MAX / 2; // x + x + 2 is the overflow
|
||||||
WeightedChoice::new(&mut [Weighted { weight: x, item: 0 },
|
WeightedChoice::new(&mut [
|
||||||
Weighted { weight: 1, item: 1 },
|
Weighted { weight: x, item: 0 },
|
||||||
Weighted { weight: x, item: 2 },
|
Weighted { weight: 1, item: 1 },
|
||||||
Weighted { weight: 1, item: 3 }]);
|
Weighted { weight: x, item: 2 },
|
||||||
|
Weighted { weight: 1, item: 3 },
|
||||||
|
]);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature="std")]
|
#[cfg(feature = "std")]
|
||||||
#[test]
|
#[test]
|
||||||
fn test_distributions_iter() {
|
fn test_distributions_iter() {
|
||||||
use distributions::Normal;
|
use distributions::Normal;
|
||||||
|
|||||||
@@ -6,14 +6,15 @@
|
|||||||
// option. This file may not be copied, modified, or distributed
|
// option. This file may not be copied, modified, or distributed
|
||||||
// except according to those terms.
|
// except according to those terms.
|
||||||
|
|
||||||
use Rng;
|
|
||||||
use distributions::Distribution;
|
|
||||||
use distributions::uniform::{UniformSampler, SampleUniform, SampleBorrow};
|
|
||||||
use ::core::cmp::PartialOrd;
|
use ::core::cmp::PartialOrd;
|
||||||
use core::fmt;
|
use core::fmt;
|
||||||
|
use distributions::uniform::{SampleBorrow, SampleUniform, UniformSampler};
|
||||||
|
use distributions::Distribution;
|
||||||
|
use Rng;
|
||||||
|
|
||||||
// Note that this whole module is only imported if feature="alloc" is enabled.
|
// Note that this whole module is only imported if feature="alloc" is enabled.
|
||||||
#[cfg(not(feature="std"))] use alloc::vec::Vec;
|
#[cfg(not(feature = "std"))]
|
||||||
|
use alloc::vec::Vec;
|
||||||
|
|
||||||
/// A distribution using weighted sampling to pick a discretely selected
|
/// A distribution using weighted sampling to pick a discretely selected
|
||||||
/// item.
|
/// item.
|
||||||
@@ -87,16 +88,13 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
|
|||||||
/// [`Distribution`]: trait.Distribution.html
|
/// [`Distribution`]: trait.Distribution.html
|
||||||
/// [`Uniform<X>`]: struct.Uniform.html
|
/// [`Uniform<X>`]: struct.Uniform.html
|
||||||
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
|
pub fn new<I>(weights: I) -> Result<WeightedIndex<X>, WeightedError>
|
||||||
where I: IntoIterator,
|
where
|
||||||
I::Item: SampleBorrow<X>,
|
I: IntoIterator,
|
||||||
X: for<'a> ::core::ops::AddAssign<&'a X> +
|
I::Item: SampleBorrow<X>,
|
||||||
Clone +
|
X: for<'a> ::core::ops::AddAssign<&'a X> + Clone + Default,
|
||||||
Default {
|
{
|
||||||
let mut iter = weights.into_iter();
|
let mut iter = weights.into_iter();
|
||||||
let mut total_weight: X = iter.next()
|
let mut total_weight: X = iter.next().ok_or(WeightedError::NoItem)?.borrow().clone();
|
||||||
.ok_or(WeightedError::NoItem)?
|
|
||||||
.borrow()
|
|
||||||
.clone();
|
|
||||||
|
|
||||||
let zero = <X as Default>::default();
|
let zero = <X as Default>::default();
|
||||||
if total_weight < zero {
|
if total_weight < zero {
|
||||||
@@ -117,18 +115,30 @@ impl<X: SampleUniform + PartialOrd> WeightedIndex<X> {
|
|||||||
}
|
}
|
||||||
let distr = X::Sampler::new(zero, total_weight);
|
let distr = X::Sampler::new(zero, total_weight);
|
||||||
|
|
||||||
Ok(WeightedIndex { cumulative_weights: weights, weight_distribution: distr })
|
Ok(WeightedIndex {
|
||||||
|
cumulative_weights: weights,
|
||||||
|
weight_distribution: distr,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<X> Distribution<usize> for WeightedIndex<X> where
|
impl<X> Distribution<usize> for WeightedIndex<X>
|
||||||
X: SampleUniform + PartialOrd {
|
where
|
||||||
|
X: SampleUniform + PartialOrd,
|
||||||
|
{
|
||||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
|
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> usize {
|
||||||
use ::core::cmp::Ordering;
|
use ::core::cmp::Ordering;
|
||||||
let chosen_weight = self.weight_distribution.sample(rng);
|
let chosen_weight = self.weight_distribution.sample(rng);
|
||||||
// Find the first item which has a weight *higher* than the chosen weight.
|
// Find the first item which has a weight *higher* than the chosen weight.
|
||||||
self.cumulative_weights.binary_search_by(
|
self.cumulative_weights
|
||||||
|w| if *w <= chosen_weight { Ordering::Less } else { Ordering::Greater }).unwrap_err()
|
.binary_search_by(|w| {
|
||||||
|
if *w <= chosen_weight {
|
||||||
|
Ordering::Less
|
||||||
|
} else {
|
||||||
|
Ordering::Greater
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.unwrap_err()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,14 +191,34 @@ mod test {
|
|||||||
for _ in 0..5 {
|
for _ in 0..5 {
|
||||||
assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
|
assert_eq!(WeightedIndex::new(&[0, 1]).unwrap().sample(&mut r), 1);
|
||||||
assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
|
assert_eq!(WeightedIndex::new(&[1, 0]).unwrap().sample(&mut r), 0);
|
||||||
assert_eq!(WeightedIndex::new(&[0, 0, 0, 0, 10, 0]).unwrap().sample(&mut r), 4);
|
assert_eq!(
|
||||||
|
WeightedIndex::new(&[0, 0, 0, 0, 10, 0])
|
||||||
|
.unwrap()
|
||||||
|
.sample(&mut r),
|
||||||
|
4
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
assert_eq!(WeightedIndex::new(&[10][0..0]).unwrap_err(), WeightedError::NoItem);
|
assert_eq!(
|
||||||
assert_eq!(WeightedIndex::new(&[0]).unwrap_err(), WeightedError::AllWeightsZero);
|
WeightedIndex::new(&[10][0..0]).unwrap_err(),
|
||||||
assert_eq!(WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(), WeightedError::NegativeWeight);
|
WeightedError::NoItem
|
||||||
assert_eq!(WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(), WeightedError::NegativeWeight);
|
);
|
||||||
assert_eq!(WeightedIndex::new(&[-10]).unwrap_err(), WeightedError::NegativeWeight);
|
assert_eq!(
|
||||||
|
WeightedIndex::new(&[0]).unwrap_err(),
|
||||||
|
WeightedError::AllWeightsZero
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
WeightedIndex::new(&[10, 20, -1, 30]).unwrap_err(),
|
||||||
|
WeightedError::NegativeWeight
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
WeightedIndex::new(&[-10, 20, 1, 30]).unwrap_err(),
|
||||||
|
WeightedError::NegativeWeight
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
WeightedIndex::new(&[-10]).unwrap_err(),
|
||||||
|
WeightedError::NegativeWeight
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -215,7 +245,7 @@ impl WeightedError {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(feature="std")]
|
#[cfg(feature = "std")]
|
||||||
impl ::std::error::Error for WeightedError {
|
impl ::std::error::Error for WeightedError {
|
||||||
fn description(&self) -> &str {
|
fn description(&self) -> &str {
|
||||||
self.msg()
|
self.msg()
|
||||||
|
|||||||
Reference in New Issue
Block a user