UniformFloat: allow inclusion of high in all cases (#1462)
Fix #1299 by removing logic specific to ensuring that we emulate a closed range by excluding `high` from the result.
This commit is contained in:
parent
2584f48ace
commit
1e381d13ee
@ -16,6 +16,7 @@ You may also find the [Upgrade Guide](https://rust-random.github.io/book/update.
|
||||
- Move all benchmarks to new `benches` crate (#1439)
|
||||
- Annotate panicking methods with `#[track_caller]` (#1442, #1447)
|
||||
- Enable feature `small_rng` by default (#1455)
|
||||
- Allow `UniformFloat::new` samples and `UniformFloat::sample_single` to yield `high` (#1462)
|
||||
|
||||
## [0.9.0-alpha.1] - 2024-03-18
|
||||
- Add the `Slice::num_choices` method to the Slice distribution (#1402)
|
||||
|
@ -51,7 +51,8 @@
|
||||
//! Those methods should include an assertion to check the range is valid (i.e.
|
||||
//! `low < high`). The example below merely wraps another back-end.
|
||||
//!
|
||||
//! The `new`, `new_inclusive` and `sample_single` functions use arguments of
|
||||
//! The `new`, `new_inclusive`, `sample_single` and `sample_single_inclusive`
|
||||
//! functions use arguments of
|
||||
//! type `SampleBorrow<X>` to support passing in values by reference or
|
||||
//! by value. In the implementation of these functions, you can choose to
|
||||
//! simply use the reference returned by [`SampleBorrow::borrow`], or you can choose
|
||||
@ -207,6 +208,11 @@ impl<X: SampleUniform> Uniform<X> {
|
||||
/// Create a new `Uniform` instance, which samples uniformly from the half
|
||||
/// open range `[low, high)` (excluding `high`).
|
||||
///
|
||||
/// For discrete types (e.g. integers), samples will always be strictly less
|
||||
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
|
||||
/// samples may equal `high` due to loss of precision but may not be
|
||||
/// greater than `high`.
|
||||
///
|
||||
/// Fails if `low >= high`, or if `low`, `high` or the range `high - low` is
|
||||
/// non-finite. In release mode, only the range is checked.
|
||||
pub fn new<B1, B2>(low: B1, high: B2) -> Result<Uniform<X>, Error>
|
||||
@ -265,6 +271,11 @@ pub trait UniformSampler: Sized {
|
||||
|
||||
/// Construct self, with inclusive lower bound and exclusive upper bound `[low, high)`.
|
||||
///
|
||||
/// For discrete types (e.g. integers), samples will always be strictly less
|
||||
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
|
||||
/// samples may equal `high` due to loss of precision but may not be
|
||||
/// greater than `high`.
|
||||
///
|
||||
/// Usually users should not call this directly but prefer to use
|
||||
/// [`Uniform::new`].
|
||||
fn new<B1, B2>(low: B1, high: B2) -> Result<Self, Error>
|
||||
@ -287,6 +298,11 @@ pub trait UniformSampler: Sized {
|
||||
/// Sample a single value uniformly from a range with inclusive lower bound
|
||||
/// and exclusive upper bound `[low, high)`.
|
||||
///
|
||||
/// For discrete types (e.g. integers), samples will always be strictly less
|
||||
/// than `high`. For (approximations of) continuous types (e.g. `f32`, `f64`),
|
||||
/// samples may equal `high` due to loss of precision but may not be
|
||||
/// greater than `high`.
|
||||
///
|
||||
/// By default this is implemented using
|
||||
/// `UniformSampler::new(low, high).sample(rng)`. However, for some types
|
||||
/// more optimal implementations for single usage may be provided via this
|
||||
@ -908,6 +924,33 @@ pub struct UniformFloat<X> {
|
||||
|
||||
macro_rules! uniform_float_impl {
|
||||
($($meta:meta)?, $ty:ty, $uty:ident, $f_scalar:ident, $u_scalar:ident, $bits_to_discard:expr) => {
|
||||
$(#[cfg($meta)])?
|
||||
impl UniformFloat<$ty> {
|
||||
/// Construct, reducing `scale` as required to ensure that rounding
|
||||
/// can never yield values greater than `high`.
|
||||
///
|
||||
/// Note: though it may be tempting to use a variant of this method
|
||||
/// to ensure that samples from `[low, high)` are always strictly
|
||||
/// less than `high`, this approach may be very slow where
|
||||
/// `scale.abs()` is much smaller than `high.abs()`
|
||||
/// (example: `low=0.99999999997819644, high=1.`).
|
||||
fn new_bounded(low: $ty, high: $ty, mut scale: $ty) -> Self {
|
||||
let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);
|
||||
|
||||
loop {
|
||||
let mask = (scale * max_rand + low).gt_mask(high);
|
||||
if !mask.any() {
|
||||
break;
|
||||
}
|
||||
scale = scale.decrease_masked(mask);
|
||||
}
|
||||
|
||||
debug_assert!(<$ty>::splat(0.0).all_le(scale));
|
||||
|
||||
UniformFloat { low, scale }
|
||||
}
|
||||
}
|
||||
|
||||
$(#[cfg($meta)])?
|
||||
impl SampleUniform for $ty {
|
||||
type Sampler = UniformFloat<$ty>;
|
||||
@ -931,26 +974,13 @@ macro_rules! uniform_float_impl {
|
||||
if !(low.all_lt(high)) {
|
||||
return Err(Error::EmptyRange);
|
||||
}
|
||||
let max_rand = <$ty>::splat(
|
||||
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
|
||||
);
|
||||
|
||||
let mut scale = high - low;
|
||||
let scale = high - low;
|
||||
if !(scale.all_finite()) {
|
||||
return Err(Error::NonFinite);
|
||||
}
|
||||
|
||||
loop {
|
||||
let mask = (scale * max_rand + low).ge_mask(high);
|
||||
if !mask.any() {
|
||||
break;
|
||||
}
|
||||
scale = scale.decrease_masked(mask);
|
||||
}
|
||||
|
||||
debug_assert!(<$ty>::splat(0.0).all_le(scale));
|
||||
|
||||
Ok(UniformFloat { low, scale })
|
||||
Ok(Self::new_bounded(low, high, scale))
|
||||
}
|
||||
|
||||
fn new_inclusive<B1, B2>(low_b: B1, high_b: B2) -> Result<Self, Error>
|
||||
@ -967,26 +997,14 @@ macro_rules! uniform_float_impl {
|
||||
if !low.all_le(high) {
|
||||
return Err(Error::EmptyRange);
|
||||
}
|
||||
let max_rand = <$ty>::splat(
|
||||
($u_scalar::MAX >> $bits_to_discard).into_float_with_exponent(0) - 1.0,
|
||||
);
|
||||
|
||||
let mut scale = (high - low) / max_rand;
|
||||
let max_rand = <$ty>::splat(1.0 as $f_scalar - $f_scalar::EPSILON);
|
||||
let scale = (high - low) / max_rand;
|
||||
if !scale.all_finite() {
|
||||
return Err(Error::NonFinite);
|
||||
}
|
||||
|
||||
loop {
|
||||
let mask = (scale * max_rand + low).gt_mask(high);
|
||||
if !mask.any() {
|
||||
break;
|
||||
}
|
||||
scale = scale.decrease_masked(mask);
|
||||
}
|
||||
|
||||
debug_assert!(<$ty>::splat(0.0).all_le(scale));
|
||||
|
||||
Ok(UniformFloat { low, scale })
|
||||
Ok(Self::new_bounded(low, high, scale))
|
||||
}
|
||||
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Self::X {
|
||||
@ -1010,72 +1028,7 @@ macro_rules! uniform_float_impl {
|
||||
B1: SampleBorrow<Self::X> + Sized,
|
||||
B2: SampleBorrow<Self::X> + Sized,
|
||||
{
|
||||
let low = *low_b.borrow();
|
||||
let high = *high_b.borrow();
|
||||
#[cfg(debug_assertions)]
|
||||
if !low.all_finite() || !high.all_finite() {
|
||||
return Err(Error::NonFinite);
|
||||
}
|
||||
if !low.all_lt(high) {
|
||||
return Err(Error::EmptyRange);
|
||||
}
|
||||
let mut scale = high - low;
|
||||
if !scale.all_finite() {
|
||||
return Err(Error::NonFinite);
|
||||
}
|
||||
|
||||
loop {
|
||||
// Generate a value in the range [1, 2)
|
||||
let value1_2 =
|
||||
(rng.random::<$uty>() >> $uty::splat($bits_to_discard)).into_float_with_exponent(0);
|
||||
|
||||
// Get a value in the range [0, 1) to avoid overflow when multiplying by scale
|
||||
let value0_1 = value1_2 - <$ty>::splat(1.0);
|
||||
|
||||
// Doing multiply before addition allows some architectures
|
||||
// to use a single instruction.
|
||||
let res = value0_1 * scale + low;
|
||||
|
||||
debug_assert!(low.all_le(res) || !scale.all_finite());
|
||||
if res.all_lt(high) {
|
||||
return Ok(res);
|
||||
}
|
||||
|
||||
// This handles a number of edge cases.
|
||||
// * `low` or `high` is NaN. In this case `scale` and
|
||||
// `res` are going to end up as NaN.
|
||||
// * `low` is negative infinity and `high` is finite.
|
||||
// `scale` is going to be infinite and `res` will be
|
||||
// NaN.
|
||||
// * `high` is positive infinity and `low` is finite.
|
||||
// `scale` is going to be infinite and `res` will
|
||||
// be infinite or NaN (if value0_1 is 0).
|
||||
// * `low` is negative infinity and `high` is positive
|
||||
// infinity. `scale` will be infinite and `res` will
|
||||
// be NaN.
|
||||
// * `low` and `high` are finite, but `high - low`
|
||||
// overflows to infinite. `scale` will be infinite
|
||||
// and `res` will be infinite or NaN (if value0_1 is 0).
|
||||
// So if `high` or `low` are non-finite, we are guaranteed
|
||||
// to fail the `res < high` check above and end up here.
|
||||
//
|
||||
// While we technically should check for non-finite `low`
|
||||
// and `high` before entering the loop, by doing the checks
|
||||
// here instead, we allow the common case to avoid these
|
||||
// checks. But we are still guaranteed that if `low` or
|
||||
// `high` are non-finite we'll end up here and can do the
|
||||
// appropriate checks.
|
||||
//
|
||||
// Likewise, `high - low` overflowing to infinity is also
|
||||
// rare, so handle it here after the common case.
|
||||
let mask = !scale.finite_mask();
|
||||
if mask.any() {
|
||||
if !(low.all_finite() && high.all_finite()) {
|
||||
return Err(Error::NonFinite);
|
||||
}
|
||||
scale = scale.decrease_masked(mask);
|
||||
}
|
||||
}
|
||||
Self::sample_single_inclusive(low_b, high_b, rng)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
@ -1465,14 +1418,14 @@ mod tests {
|
||||
let my_incl_uniform = Uniform::new_inclusive(low, high).unwrap();
|
||||
for _ in 0..100 {
|
||||
let v = rng.sample(my_uniform).extract(lane);
|
||||
assert!(low_scalar <= v && v < high_scalar);
|
||||
assert!(low_scalar <= v && v <= high_scalar);
|
||||
let v = rng.sample(my_incl_uniform).extract(lane);
|
||||
assert!(low_scalar <= v && v <= high_scalar);
|
||||
let v =
|
||||
<$ty as SampleUniform>::Sampler::sample_single(low, high, &mut rng)
|
||||
.unwrap()
|
||||
.extract(lane);
|
||||
assert!(low_scalar <= v && v < high_scalar);
|
||||
assert!(low_scalar <= v && v <= high_scalar);
|
||||
let v = <$ty as SampleUniform>::Sampler::sample_single_inclusive(
|
||||
low, high, &mut rng,
|
||||
)
|
||||
@ -1510,12 +1463,12 @@ mod tests {
|
||||
low_scalar
|
||||
);
|
||||
|
||||
assert!(max_rng.sample(my_uniform).extract(lane) < high_scalar);
|
||||
assert!(max_rng.sample(my_uniform).extract(lane) <= high_scalar);
|
||||
assert!(max_rng.sample(my_incl_uniform).extract(lane) <= high_scalar);
|
||||
// sample_single cannot cope with max_rng:
|
||||
// assert!(<$ty as SampleUniform>::Sampler
|
||||
// ::sample_single(low, high, &mut max_rng).unwrap()
|
||||
// .extract(lane) < high_scalar);
|
||||
// .extract(lane) <= high_scalar);
|
||||
assert!(
|
||||
<$ty as SampleUniform>::Sampler::sample_single_inclusive(
|
||||
low,
|
||||
@ -1543,7 +1496,7 @@ mod tests {
|
||||
)
|
||||
.unwrap()
|
||||
.extract(lane)
|
||||
< high_scalar
|
||||
<= high_scalar
|
||||
);
|
||||
}
|
||||
}
|
||||
@ -1590,10 +1543,9 @@ mod tests {
|
||||
#[cfg(all(feature = "std", panic = "unwind"))]
|
||||
fn test_float_assertions() {
|
||||
use super::SampleUniform;
|
||||
use std::panic::catch_unwind;
|
||||
fn range<T: SampleUniform>(low: T, high: T) {
|
||||
fn range<T: SampleUniform>(low: T, high: T) -> Result<T, Error> {
|
||||
let mut rng = crate::test::rng(253);
|
||||
T::Sampler::sample_single(low, high, &mut rng).unwrap();
|
||||
T::Sampler::sample_single(low, high, &mut rng)
|
||||
}
|
||||
|
||||
macro_rules! t {
|
||||
@ -1616,10 +1568,9 @@ mod tests {
|
||||
for lane in 0..<$ty>::LEN {
|
||||
let low = <$ty>::splat(0.0 as $f_scalar).replace(lane, low_scalar);
|
||||
let high = <$ty>::splat(1.0 as $f_scalar).replace(lane, high_scalar);
|
||||
assert!(catch_unwind(|| range(low, high)).is_err());
|
||||
assert!(range(low, high).is_err());
|
||||
assert!(Uniform::new(low, high).is_err());
|
||||
assert!(Uniform::new_inclusive(low, high).is_err());
|
||||
assert!(catch_unwind(|| range(low, low)).is_err());
|
||||
assert!(Uniform::new(low, low).is_err());
|
||||
}
|
||||
}
|
||||
|
@ -218,9 +218,7 @@ pub(crate) trait FloatSIMDUtils {
|
||||
fn all_finite(self) -> bool;
|
||||
|
||||
type Mask;
|
||||
fn finite_mask(self) -> Self::Mask;
|
||||
fn gt_mask(self, other: Self) -> Self::Mask;
|
||||
fn ge_mask(self, other: Self) -> Self::Mask;
|
||||
|
||||
// Decrease all lanes where the mask is `true` to the next lower value
|
||||
// representable by the floating-point type. At least one of the lanes
|
||||
@ -292,21 +290,11 @@ macro_rules! scalar_float_impl {
|
||||
self.is_finite()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn finite_mask(self) -> Self::Mask {
|
||||
self.is_finite()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn gt_mask(self, other: Self) -> Self::Mask {
|
||||
self > other
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ge_mask(self, other: Self) -> Self::Mask {
|
||||
self >= other
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn decrease_masked(self, mask: Self::Mask) -> Self {
|
||||
debug_assert!(mask, "At least one lane must be set");
|
||||
@ -368,21 +356,11 @@ macro_rules! simd_impl {
|
||||
self.is_finite().all()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn finite_mask(self) -> Self::Mask {
|
||||
self.is_finite()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn gt_mask(self, other: Self) -> Self::Mask {
|
||||
self.simd_gt(other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ge_mask(self, other: Self) -> Self::Mask {
|
||||
self.simd_ge(other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn decrease_masked(self, mask: Self::Mask) -> Self {
|
||||
// Casting a mask into ints will produce all bits set for
|
||||
|
Loading…
x
Reference in New Issue
Block a user