fix stdsimd, add mask opt notes
This commit is contained in:
parent
f89f15fc1f
commit
2fab15dcd7
@ -114,7 +114,7 @@ impl_nzint!(NonZeroU64, NonZeroU64::new);
|
||||
impl_nzint!(NonZeroU128, NonZeroU128::new);
|
||||
impl_nzint!(NonZeroUsize, NonZeroUsize::new);
|
||||
|
||||
macro_rules! intrinsic_impl {
|
||||
macro_rules! x86_intrinsic_impl {
|
||||
($($intrinsic:ident),+) => {$(
|
||||
/// Available only on x86/64 platforms
|
||||
impl Distribution<$intrinsic> for Standard {
|
||||
@ -156,12 +156,12 @@ macro_rules! simd_impl {
|
||||
simd_impl!(u8, i8, u16, i16, u32, i32, u64, i64, usize, isize);
|
||||
|
||||
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
|
||||
intrinsic_impl!(__m128i, __m256i);
|
||||
x86_intrinsic_impl!(__m128i, __m256i);
|
||||
#[cfg(all(
|
||||
any(target_arch = "x86", target_arch = "x86_64"),
|
||||
feature = "simd_support"
|
||||
))]
|
||||
intrinsic_impl!(__m512i);
|
||||
x86_intrinsic_impl!(__m512i);
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
@ -23,7 +23,7 @@ use serde::{Serialize, Deserialize};
|
||||
#[cfg(feature = "min_const_gen")]
|
||||
use core::mem::{self, MaybeUninit};
|
||||
#[cfg(feature = "simd_support")]
|
||||
use core::simd::{Mask, Simd, LaneCount, SupportedLaneCount, MaskElement, SimdElement};
|
||||
use core::simd::*;
|
||||
|
||||
|
||||
// ----- Sampling distributions -----
|
||||
@ -161,22 +161,29 @@ impl Distribution<bool> for Standard {
|
||||
/// let x = rng.gen::<mask8x16>().select(b, a);
|
||||
/// ```
|
||||
///
|
||||
/// Since most bits are unused you could also generate only as many bits as you need.
|
||||
/// Since most bits are unused you could also generate only as many bits as you need, i.e.:
|
||||
/// ```
|
||||
/// let x = u16x8::splat(rng.gen::<u8> as u16);
|
||||
/// let mask = u16x8::splat(1) << u16x8::from([0, 1, 2, 3, 4, 5, 6, 7]);
|
||||
/// let rand_mask = (x & mask).simd_eq(mask);
|
||||
/// ```
|
||||
///
|
||||
/// [`_mm_blendv_epi8`]: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_epi8&ig_expand=514/
|
||||
/// [`simd_support`]: https://github.com/rust-random/rand#crate-features
|
||||
#[cfg(feature = "simd_support")]
|
||||
impl<T, const LANES: usize> Distribution<Mask<T, LANES>> for Standard
|
||||
where
|
||||
T: MaskElement + PartialOrd + SimdElement<Mask = T> + Default,
|
||||
T: MaskElement + Default,
|
||||
LaneCount<LANES>: SupportedLaneCount,
|
||||
Standard: Distribution<Simd<T, LANES>>,
|
||||
Simd<T, LANES>: SimdPartialOrd<Mask = Mask<T, LANES>>,
|
||||
{
|
||||
#[inline]
|
||||
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mask<T, LANES> {
|
||||
// `MaskElement` must be a signed integer, so this is equivalent
|
||||
// to the scalar `i32 < 0` method
|
||||
rng.gen().lanes_lt(Simd::default())
|
||||
let var = rng.gen::<Simd<T, LANES>>();
|
||||
var.simd_lt(Simd::default())
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -606,7 +606,7 @@ macro_rules! uniform_simd_int_impl {
|
||||
{
|
||||
let low = *low_b.borrow();
|
||||
let high = *high_b.borrow();
|
||||
assert!(low.lanes_lt(high).all(), "Uniform::new called with `low >= high`");
|
||||
assert!(low.simd_lt(high).all(), "Uniform::new called with `low >= high`");
|
||||
UniformSampler::new_inclusive(low, high - Simd::splat(1))
|
||||
}
|
||||
|
||||
@ -618,7 +618,7 @@ macro_rules! uniform_simd_int_impl {
|
||||
{
|
||||
let low = *low_b.borrow();
|
||||
let high = *high_b.borrow();
|
||||
assert!(low.lanes_le(high).all(),
|
||||
assert!(low.simd_le(high).all(),
|
||||
"Uniform::new_inclusive called with `low > high`");
|
||||
let unsigned_max = Simd::splat(::core::$unsigned::MAX);
|
||||
|
||||
@ -626,7 +626,7 @@ macro_rules! uniform_simd_int_impl {
|
||||
// see https://doc.rust-lang.org/std/simd/struct.Simd.html
|
||||
let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast();
|
||||
// `% 0` will panic at runtime.
|
||||
let not_full_range = range.lanes_gt(Simd::splat(0));
|
||||
let not_full_range = range.simd_gt(Simd::splat(0));
|
||||
// replacing 0 with `unsigned_max` allows a faster `select`
|
||||
// with bitwise OR
|
||||
let modulo = not_full_range.select(range, unsigned_max);
|
||||
@ -660,7 +660,7 @@ macro_rules! uniform_simd_int_impl {
|
||||
let mut v: Simd<$unsigned, LANES> = rng.gen();
|
||||
loop {
|
||||
let (hi, lo) = v.wmul(range);
|
||||
let mask = lo.lanes_le(zone);
|
||||
let mask = lo.simd_le(zone);
|
||||
if mask.all() {
|
||||
let hi: Simd<$ty, LANES> = hi.cast();
|
||||
// wrapping addition
|
||||
@ -669,7 +669,7 @@ macro_rules! uniform_simd_int_impl {
|
||||
// When `range.eq(0).none()` the compare and blend
|
||||
// operations are avoided.
|
||||
let v: Simd<$ty, LANES> = v.cast();
|
||||
return range.lanes_gt(Simd::splat(0)).select(result, v);
|
||||
return range.simd_gt(Simd::splat(0)).select(result, v);
|
||||
}
|
||||
// Replace only the failing lanes
|
||||
v = mask.select(v, rng.gen());
|
||||
@ -1265,8 +1265,8 @@ mod tests {
|
||||
($ty::splat(10), $ty::splat(127)),
|
||||
($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)),
|
||||
],
|
||||
|x: $ty, y| x.lanes_le(y).all(),
|
||||
|x: $ty, y| x.lanes_lt(y).all()
|
||||
|x: $ty, y| x.simd_le(y).all(),
|
||||
|x: $ty, y| x.simd_lt(y).all()
|
||||
);)*
|
||||
}};
|
||||
}
|
||||
|
@ -99,20 +99,20 @@ macro_rules! wmul_impl_large {
|
||||
#[inline(always)]
|
||||
fn wmul(self, b: $ty) -> Self::Output {
|
||||
// needs wrapping multiplication
|
||||
const LOWER_MASK: $ty = <$ty>::splat(!0 >> $half);
|
||||
const HALF: $ty = <$ty>::splat($half);
|
||||
let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
|
||||
let mut t = low >> HALF;
|
||||
low &= LOWER_MASK;
|
||||
t += (self >> HALF) * (b & LOWER_MASK);
|
||||
low += (t & LOWER_MASK) << HALF;
|
||||
let mut high = t >> HALF;
|
||||
t = low >> HALF;
|
||||
low &= LOWER_MASK;
|
||||
t += (b >> HALF) * (self & LOWER_MASK);
|
||||
low += (t & LOWER_MASK) << HALF;
|
||||
high += t >> HALF;
|
||||
high += (self >> HALF) * (b >> HALF);
|
||||
let lower_mask = <$ty>::splat(!0 >> $half);
|
||||
let half = <$ty>::splat($half);
|
||||
let mut low = (self & lower_mask) * (b & lower_mask);
|
||||
let mut t = low >> half;
|
||||
low &= lower_mask;
|
||||
t += (self >> half) * (b & lower_mask);
|
||||
low += (t & lower_mask) << half;
|
||||
let mut high = t >> half;
|
||||
t = low >> half;
|
||||
low &= lower_mask;
|
||||
t += (b >> half) * (self & lower_mask);
|
||||
low += (t & lower_mask) << half;
|
||||
high += t >> half;
|
||||
high += (self >> half) * (b >> half);
|
||||
|
||||
(high, low)
|
||||
}
|
||||
@ -385,12 +385,12 @@ macro_rules! simd_impl {
|
||||
|
||||
#[inline(always)]
|
||||
fn all_lt(self, other: Self) -> bool {
|
||||
self.lanes_lt(other).all()
|
||||
self.simd_lt(other).all()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn all_le(self, other: Self) -> bool {
|
||||
self.lanes_le(other).all()
|
||||
self.simd_le(other).all()
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
@ -405,12 +405,12 @@ macro_rules! simd_impl {
|
||||
|
||||
#[inline(always)]
|
||||
fn gt_mask(self, other: Self) -> Self::Mask {
|
||||
self.lanes_gt(other)
|
||||
self.simd_gt(other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
fn ge_mask(self, other: Self) -> Self::Mask {
|
||||
self.lanes_ge(other)
|
||||
self.simd_ge(other)
|
||||
}
|
||||
|
||||
#[inline(always)]
|
||||
|
Loading…
x
Reference in New Issue
Block a user