fix stdsimd, add mask opt notes

This commit is contained in:
TheIronBorn 2022-08-07 17:22:19 -07:00
parent f89f15fc1f
commit 2fab15dcd7
4 changed files with 39 additions and 32 deletions

View File

@ -114,7 +114,7 @@ impl_nzint!(NonZeroU64, NonZeroU64::new);
impl_nzint!(NonZeroU128, NonZeroU128::new);
impl_nzint!(NonZeroUsize, NonZeroUsize::new);
macro_rules! intrinsic_impl {
macro_rules! x86_intrinsic_impl {
($($intrinsic:ident),+) => {$(
/// Available only on x86/64 platforms
impl Distribution<$intrinsic> for Standard {
@ -156,12 +156,12 @@ macro_rules! simd_impl {
simd_impl!(u8, i8, u16, i16, u32, i32, u64, i64, usize, isize);
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
intrinsic_impl!(__m128i, __m256i);
x86_intrinsic_impl!(__m128i, __m256i);
#[cfg(all(
any(target_arch = "x86", target_arch = "x86_64"),
feature = "simd_support"
))]
intrinsic_impl!(__m512i);
x86_intrinsic_impl!(__m512i);
#[cfg(test)]
mod tests {

View File

@ -23,7 +23,7 @@ use serde::{Serialize, Deserialize};
#[cfg(feature = "min_const_gen")]
use core::mem::{self, MaybeUninit};
#[cfg(feature = "simd_support")]
use core::simd::{Mask, Simd, LaneCount, SupportedLaneCount, MaskElement, SimdElement};
use core::simd::*;
// ----- Sampling distributions -----
@ -161,22 +161,29 @@ impl Distribution<bool> for Standard {
/// let x = rng.gen::<mask8x16>().select(b, a);
/// ```
///
/// Since most bits are unused you could also generate only as many bits as you need.
/// Since most bits are unused you could also generate only as many bits as you need, i.e.:
/// ```
/// let x = u16x8::splat(rng.gen::<u8> as u16);
/// let mask = u16x8::splat(1) << u16x8::from([0, 1, 2, 3, 4, 5, 6, 7]);
/// let rand_mask = (x & mask).simd_eq(mask);
/// ```
///
/// [`_mm_blendv_epi8`]: https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_blendv_epi8&ig_expand=514/
/// [`simd_support`]: https://github.com/rust-random/rand#crate-features
#[cfg(feature = "simd_support")]
impl<T, const LANES: usize> Distribution<Mask<T, LANES>> for Standard
where
T: MaskElement + PartialOrd + SimdElement<Mask = T> + Default,
T: MaskElement + Default,
LaneCount<LANES>: SupportedLaneCount,
Standard: Distribution<Simd<T, LANES>>,
Simd<T, LANES>: SimdPartialOrd<Mask = Mask<T, LANES>>,
{
#[inline]
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Mask<T, LANES> {
// `MaskElement` must be a signed integer, so this is equivalent
// to the scalar `i32 < 0` method
rng.gen().lanes_lt(Simd::default())
let var = rng.gen::<Simd<T, LANES>>();
var.simd_lt(Simd::default())
}
}

View File

@ -606,7 +606,7 @@ macro_rules! uniform_simd_int_impl {
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low.lanes_lt(high).all(), "Uniform::new called with `low >= high`");
assert!(low.simd_lt(high).all(), "Uniform::new called with `low >= high`");
UniformSampler::new_inclusive(low, high - Simd::splat(1))
}
@ -618,7 +618,7 @@ macro_rules! uniform_simd_int_impl {
{
let low = *low_b.borrow();
let high = *high_b.borrow();
assert!(low.lanes_le(high).all(),
assert!(low.simd_le(high).all(),
"Uniform::new_inclusive called with `low > high`");
let unsigned_max = Simd::splat(::core::$unsigned::MAX);
@ -626,7 +626,7 @@ macro_rules! uniform_simd_int_impl {
// see https://doc.rust-lang.org/std/simd/struct.Simd.html
let range: Simd<$unsigned, LANES> = ((high - low) + Simd::splat(1)).cast();
// `% 0` will panic at runtime.
let not_full_range = range.lanes_gt(Simd::splat(0));
let not_full_range = range.simd_gt(Simd::splat(0));
// replacing 0 with `unsigned_max` allows a faster `select`
// with bitwise OR
let modulo = not_full_range.select(range, unsigned_max);
@ -660,7 +660,7 @@ macro_rules! uniform_simd_int_impl {
let mut v: Simd<$unsigned, LANES> = rng.gen();
loop {
let (hi, lo) = v.wmul(range);
let mask = lo.lanes_le(zone);
let mask = lo.simd_le(zone);
if mask.all() {
let hi: Simd<$ty, LANES> = hi.cast();
// wrapping addition
@ -669,7 +669,7 @@ macro_rules! uniform_simd_int_impl {
// When `range.eq(0).none()` the compare and blend
// operations are avoided.
let v: Simd<$ty, LANES> = v.cast();
return range.lanes_gt(Simd::splat(0)).select(result, v);
return range.simd_gt(Simd::splat(0)).select(result, v);
}
// Replace only the failing lanes
v = mask.select(v, rng.gen());
@ -1265,8 +1265,8 @@ mod tests {
($ty::splat(10), $ty::splat(127)),
($ty::splat($scalar::MIN), $ty::splat($scalar::MAX)),
],
|x: $ty, y| x.lanes_le(y).all(),
|x: $ty, y| x.lanes_lt(y).all()
|x: $ty, y| x.simd_le(y).all(),
|x: $ty, y| x.simd_lt(y).all()
);)*
}};
}

View File

@ -99,20 +99,20 @@ macro_rules! wmul_impl_large {
#[inline(always)]
fn wmul(self, b: $ty) -> Self::Output {
// needs wrapping multiplication
const LOWER_MASK: $ty = <$ty>::splat(!0 >> $half);
const HALF: $ty = <$ty>::splat($half);
let mut low = (self & LOWER_MASK) * (b & LOWER_MASK);
let mut t = low >> HALF;
low &= LOWER_MASK;
t += (self >> HALF) * (b & LOWER_MASK);
low += (t & LOWER_MASK) << HALF;
let mut high = t >> HALF;
t = low >> HALF;
low &= LOWER_MASK;
t += (b >> HALF) * (self & LOWER_MASK);
low += (t & LOWER_MASK) << HALF;
high += t >> HALF;
high += (self >> HALF) * (b >> HALF);
let lower_mask = <$ty>::splat(!0 >> $half);
let half = <$ty>::splat($half);
let mut low = (self & lower_mask) * (b & lower_mask);
let mut t = low >> half;
low &= lower_mask;
t += (self >> half) * (b & lower_mask);
low += (t & lower_mask) << half;
let mut high = t >> half;
t = low >> half;
low &= lower_mask;
t += (b >> half) * (self & lower_mask);
low += (t & lower_mask) << half;
high += t >> half;
high += (self >> half) * (b >> half);
(high, low)
}
@ -385,12 +385,12 @@ macro_rules! simd_impl {
#[inline(always)]
fn all_lt(self, other: Self) -> bool {
self.lanes_lt(other).all()
self.simd_lt(other).all()
}
#[inline(always)]
fn all_le(self, other: Self) -> bool {
self.lanes_le(other).all()
self.simd_le(other).all()
}
#[inline(always)]
@ -405,12 +405,12 @@ macro_rules! simd_impl {
#[inline(always)]
fn gt_mask(self, other: Self) -> Self::Mask {
self.lanes_gt(other)
self.simd_gt(other)
}
#[inline(always)]
fn ge_mask(self, other: Self) -> Self::Mask {
self.lanes_ge(other)
self.simd_ge(other)
}
#[inline(always)]