// rand/distr/utils.rs

// Copyright 2018 Developers of the Rand project.
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// https://www.apache.org/licenses/LICENSE-2.0> or the MIT license
// <LICENSE-MIT or https://opensource.org/licenses/MIT>, at your
// option. This file may not be copied, modified, or distributed
// except according to those terms.

//! Math helper functions: widening (full-width) integer multiplication and
//! helpers that let scalar and SIMD floating-point types share code.

#[cfg(feature = "simd_support")]
use core::simd::SimdElement;
#[cfg(feature = "simd_support")]
use core::simd::prelude::*;

/// Full-width ("widening") multiplication.
///
/// `a.wmul(b)` computes the product of `a` and `b` without discarding any
/// bits. Every implementation in this module sets `Output = (T, T)` and
/// returns the product split into `(high, low)` halves.
pub(crate) trait WideningMultiply<RHS = Self> {
    /// Result of [`wmul`](Self::wmul); a `(high, low)` pair in this module.
    type Output;

    /// Multiplies `self` by `x`, returning the full-width product.
    fn wmul(self, x: RHS) -> Self::Output;
}

/// Implements [`WideningMultiply`] by casting both operands to a wider type,
/// multiplying there, and splitting the wide product into `(high, low)`.
///
/// Forms:
/// - `wmul_impl! { $ty, $wide, $shift }` — scalar `$ty` widened to `$wide`;
///   `$shift` is the bit width of `$ty` (e.g. `u8, u16, 8`).
/// - `wmul_impl! { ($ty, $wide), ($ty, $wide), ..., , $shift }` — SIMD
///   vectors, where `$wide` has the same lane count as `$ty` but wider lanes
///   and `$shift` is the lane width of `$ty`.
macro_rules! wmul_impl {
    ($ty:ty, $wide:ty, $shift:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, x: $ty) -> Self::Output {
                // Two N-bit values multiply to fewer than 2N bits, so the
                // product in `$wide` is exact.
                let tmp = (self as $wide) * (x as $wide);
                ((tmp >> $shift) as $ty, tmp as $ty)
            }
        }
    };

    // simd bulk implementation
    ($(($ty:ident, $wide:ty),)+, $shift:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // For supported vectors, this should compile to a couple
                    // supported multiply & swizzle instructions (no actual
                    // casting).
                    // TODO: optimize
                    let y: $wide = self.cast();
                    let x: $wide = x.cast();
                    let tmp = y * x;
                    let hi: $ty = (tmp >> Simd::splat($shift)).cast();
                    let lo: $ty = tmp.cast();
                    (hi, lo)
                }
            }
        )+
    };
}
58impl WideningMultiply for u8 {
    type Output = (u8, u8);
    #[inline(always)]
    fn wmul(self, x: u8) -> Self::Output {
        let tmp = (self as u16) * (x as u16);
        ((tmp >> 8) as u8, tmp as u8)
    }
}wmul_impl! { u8, u16, 8 }
59impl WideningMultiply for u16 {
    type Output = (u16, u16);
    #[inline(always)]
    fn wmul(self, x: u16) -> Self::Output {
        let tmp = (self as u32) * (x as u32);
        ((tmp >> 16) as u16, tmp as u16)
    }
}wmul_impl! { u16, u32, 16 }
60impl WideningMultiply for u32 {
    type Output = (u32, u32);
    #[inline(always)]
    fn wmul(self, x: u32) -> Self::Output {
        let tmp = (self as u64) * (x as u64);
        ((tmp >> 32) as u32, tmp as u32)
    }
}wmul_impl! { u32, u64, 32 }
61impl WideningMultiply for u64 {
    type Output = (u64, u64);
    #[inline(always)]
    fn wmul(self, x: u64) -> Self::Output {
        let tmp = (self as u128) * (x as u128);
        ((tmp >> 64) as u64, tmp as u64)
    }
}wmul_impl! { u64, u128, 64 }
62
// This code is a translation of the __mulddi3 function in LLVM's
// compiler-rt. It is an optimised variant of the common method
// `(a + b) * (c + d) = ac + ad + bc + bd`.
//
// For some reason LLVM can optimise the C version very well, but
// keeps shuffling registers in this Rust translation.
//
// Every partial product multiplies two values of at most `$half` bits, and
// each carry added afterwards is also below 2^`$half`, so no intermediate
// step exceeds the type's range. `wrapping_mul` therefore never actually
// wraps; it only avoids debug-mode overflow checks.
//
// Forms:
// - `wmul_impl_large! { $ty, $half }` — scalar `$ty`, `$half` = half its width.
// - `wmul_impl_large! { ($ty, $ty, ...,) $scalar, $half }` — SIMD vectors;
//   `$scalar` (the lane type) is accepted but not used by the expansion.
macro_rules! wmul_impl_large {
    ($ty:ty, $half:expr) => {
        impl WideningMultiply for $ty {
            type Output = ($ty, $ty);

            #[inline(always)]
            fn wmul(self, b: $ty) -> Self::Output {
                const LOWER_MASK: $ty = !0 >> $half;
                // low(a) * low(b)
                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
                let mut t = low >> $half;
                low &= LOWER_MASK;
                // + high(a) * low(b)
                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                let mut high = t >> $half;
                t = low >> $half;
                low &= LOWER_MASK;
                // + high(b) * low(a)
                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
                low += (t & LOWER_MASK) << $half;
                high += t >> $half;
                // + high(a) * high(b), which lands entirely in `high`
                high += (self >> $half).wrapping_mul(b >> $half);

                (high, low)
            }
        }
    };

    // simd bulk implementation
    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
        $(
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, b: $ty) -> Self::Output {
                    // Same algorithm as the scalar arm; relies on the SIMD
                    // `*` having wrapping semantics like `wrapping_mul`.
                    let lower_mask = <$ty>::splat(!0 >> $half);
                    let half = <$ty>::splat($half);
                    let mut low = (self & lower_mask) * (b & lower_mask);
                    let mut t = low >> half;
                    low &= lower_mask;
                    t += (self >> half) * (b & lower_mask);
                    low += (t & lower_mask) << half;
                    let mut high = t >> half;
                    t = low >> half;
                    low &= lower_mask;
                    t += (b >> half) * (self & lower_mask);
                    low += (t & lower_mask) << half;
                    high += t >> half;
                    high += (self >> half) * (b >> half);

                    (high, low)
                }
            }
        )+
    };
}
125impl WideningMultiply for u128 {
    type Output = (u128, u128);
    #[inline(always)]
    fn wmul(self, b: u128) -> Self::Output {
        const LOWER_MASK: u128 = !0 >> 64;
        let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
        let mut t = low >> 64;
        low &= LOWER_MASK;
        t += (self >> 64).wrapping_mul(b & LOWER_MASK);
        low += (t & LOWER_MASK) << 64;
        let mut high = t >> 64;
        t = low >> 64;
        low &= LOWER_MASK;
        t += (b >> 64).wrapping_mul(self & LOWER_MASK);
        low += (t & LOWER_MASK) << 64;
        high += t >> 64;
        high += (self >> 64).wrapping_mul(b >> 64);
        (high, low)
    }
}wmul_impl_large! { u128, 64 }
126
/// Implements [`WideningMultiply`] for `usize` by forwarding to the
/// fixed-width unsigned type `$ty`, which must have the same width as
/// `usize` on the target (chosen with `target_pointer_width` cfgs at the
/// invocation sites).
macro_rules! wmul_impl_usize {
    ($ty:ty) => {
        impl WideningMultiply for usize {
            type Output = (usize, usize);

            #[inline(always)]
            fn wmul(self, x: usize) -> Self::Output {
                let (high, low) = (self as $ty).wmul(x as $ty);
                (high as usize, low as usize)
            }
        }
    };
}
140#[cfg(target_pointer_width = "16")]
141wmul_impl_usize! { u16 }
142#[cfg(target_pointer_width = "32")]
143wmul_impl_usize! { u32 }
144#[cfg(target_pointer_width = "64")]
145impl WideningMultiply for usize {
    type Output = (usize, usize);
    #[inline(always)]
    fn wmul(self, x: usize) -> Self::Output {
        let (high, low) = (self as u64).wmul(x as u64);
        (high as usize, low as usize)
    }
}wmul_impl_usize! { u64 }
146
/// [`WideningMultiply`] implementations for `core::simd` unsigned integer
/// vectors.
#[cfg(feature = "simd_support")]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    wmul_impl! {
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),
        (u8x64, Simd<u16, 64>),,
        8
    }

    wmul_impl! { (u16x2, u32x2),, 16 }
    wmul_impl! { (u16x4, u32x4),, 16 }
    // Each of these generic impls is compiled only when the matching x86
    // intrinsic impl further down is not, so every type gets exactly one impl.
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }
    #[cfg(not(target_feature = "avx512bw"))]
    wmul_impl! { (u16x32, Simd<u32, 32>),, 16 }

    // 16-bit lane widths allow use of the x86 `mulhi` instructions, which
    // means `wmul` can be implemented with only two instructions.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                #[allow(clippy::undocumented_unsafe_blocks)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // SAFETY: every invocation of this macro is gated by a
                    // `#[cfg(target_feature = ...)]` naming the CPU feature
                    // the intrinsic requires, so it is statically available.
                    let hi = unsafe { $mulhi(self.into(), x.into()) }.into();
                    // The low 16 bits of a product are identical for signed
                    // and unsigned operands, so the signed `mullo` is correct.
                    // SAFETY: as above.
                    let lo = unsafe { $mullo(self.into(), x.into()) }.into();
                    (hi, lo)
                }
            }
        };
    }

    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    #[cfg(target_feature = "avx512bw")]
    wmul_impl_16! { u16x32, _mm512_mulhi_epu16, _mm512_mullo_epi16 }

    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),
        (u32x16, Simd<u64, 16>),,
        32
    }

    // 64-bit lanes are multiplied via their 32-bit halves rather than a
    // widening cast.
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}

/// Helper trait when dealing with scalar and SIMD floating point types.
///
/// `PartialOrd` for vectors compares lexicographically. We want to compare all
/// the individual SIMD lanes instead, and get the combined result over all
/// lanes. This is possible using something like `a.lt(b).all()`, but we
/// implement it as a trait so we can write the same code for `f32` and `f64`.
/// Only the comparison functions we need are implemented.
pub(crate) trait FloatSIMDUtils {
    /// Returns `true` if every lane of `self` is less than the matching lane
    /// of `other`.
    fn all_lt(self, other: Self) -> bool;
    /// Returns `true` if every lane of `self` is less than or equal to the
    /// matching lane of `other`.
    fn all_le(self, other: Self) -> bool;
    /// Returns `true` if every lane is finite (neither infinite nor NaN).
    fn all_finite(self) -> bool;

    /// Per-lane boolean mask (`bool` for scalar implementations).
    type Mask;
    /// Returns a mask that is `true` in each lane where `self > other`.
    fn gt_mask(self, other: Self) -> Self::Mask;

    /// Decreases every lane whose mask is `true` by subtracting one from the
    /// lane's binary representation. At least one lane must be set.
    ///
    /// For positive values (including `+inf`) this yields the next lower
    /// representable value. It does NOT for `0.0` (the bit pattern
    /// underflows) or for negative values (they move toward zero), so
    /// callers should only pass positive lanes.
    ///
    /// Scalar implementations have a single lane, so `mask` must be `true`;
    /// they only `debug_assert!` this and decrement regardless.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    /// Unsigned integer type with the same lane count and bit width.
    type UInt;
    /// Converts from an integer value, retaining the numerical value (rounded
    /// to the nearest representable float), not the binary representation.
    fn cast_from_int(i: Self::UInt) -> Self;
}

/// Test-only lane access shared by scalar and SIMD float types.
#[cfg(test)]
pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils {
    /// Type of a single lane (the type itself for scalars).
    type Scalar;

    /// Returns `self` with lane `index` replaced by `new_value`.
    fn replace(self, index: usize, new_value: Self::Scalar) -> Self;
    /// Returns the value stored in lane `index`.
    fn extract_lane(self, index: usize) -> Self::Scalar;
}

/// Implement functions on f32/f64 to give them APIs similar to SIMD types
pub(crate) trait FloatAsSIMD: Sized {
    /// Lane count; a scalar behaves like a one-lane vector.
    #[cfg(test)]
    const LEN: usize = 1;

    /// Scalar counterpart of a SIMD `splat`: broadcasting into one lane is
    /// the identity.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}

/// Gives `u32`/`u64` a SIMD-like `splat` so generic code can treat them as
/// single-lane vectors.
pub(crate) trait IntAsSIMD: Sized {
    /// Returns `scalar` unchanged; the scalar counterpart of a SIMD `splat`.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}

impl IntAsSIMD for u32 {}
impl IntAsSIMD for u64 {}

/// Lets a scalar `bool` stand in for a SIMD mask where `any()` is needed.
pub(crate) trait BoolAsSIMD: Sized {
    /// Returns `true` if any lane is set; for a scalar that is the value itself.
    fn any(self) -> bool;
}

impl BoolAsSIMD for bool {
    #[inline(always)]
    fn any(self) -> bool {
        self
    }
}

/// Implements the float helper traits for a scalar float `$ty`, where `$uty`
/// is the unsigned integer type of the same width. A scalar is treated as a
/// single-lane vector, so its mask type is a plain `bool`.
macro_rules! scalar_float_impl {
    ($ty:ident, $uty:ident) => {
        impl FloatSIMDUtils for $ty {
            type Mask = bool;
            type UInt = $uty;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self < other
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self <= other
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self > other
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // The only lane must be set; release builds skip this check
                // and always decrement. Subtracting one from the bit pattern
                // gives the next lower value for positive inputs (including
                // +inf); it would underflow for `0.0`.
                debug_assert!(mask, "At least one lane must be set");
                <$ty>::from_bits(self.to_bits() - 1)
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i as $ty
            }
        }

        #[cfg(test)]
        impl FloatSIMDScalarUtils for $ty {
            type Scalar = $ty;

            #[inline]
            fn replace(self, index: usize, new_value: Self::Scalar) -> Self {
                // A scalar has exactly one lane.
                debug_assert_eq!(index, 0);
                new_value
            }

            #[inline]
            fn extract_lane(self, index: usize) -> Self::Scalar {
                debug_assert_eq!(index, 0);
                self
            }
        }

        impl FloatAsSIMD for $ty {}
    };
}

334impl FloatSIMDUtils for f32 {
    type Mask = bool;
    type UInt = u32;
    #[inline(always)]
    fn all_lt(self, other: Self) -> bool { self < other }
    #[inline(always)]
    fn all_le(self, other: Self) -> bool { self <= other }
    #[inline(always)]
    fn all_finite(self) -> bool { self.is_finite() }
    #[inline(always)]
    fn gt_mask(self, other: Self) -> Self::Mask { self > other }
    #[inline(always)]
    fn decrease_masked(self, mask: Self::Mask) -> Self {
        if true {
            if !mask {
                {
                    ::core::panicking::panic_fmt(format_args!("At least one lane must be set"));
                }
            };
        };
        <f32>::from_bits(self.to_bits() - 1)
    }
    #[inline]
    fn cast_from_int(i: Self::UInt) -> Self { i as f32 }
}
impl FloatAsSIMD for f32 {}scalar_float_impl!(f32, u32);
335impl FloatSIMDUtils for f64 {
    type Mask = bool;
    type UInt = u64;
    #[inline(always)]
    fn all_lt(self, other: Self) -> bool { self < other }
    #[inline(always)]
    fn all_le(self, other: Self) -> bool { self <= other }
    #[inline(always)]
    fn all_finite(self) -> bool { self.is_finite() }
    #[inline(always)]
    fn gt_mask(self, other: Self) -> Self::Mask { self > other }
    #[inline(always)]
    fn decrease_masked(self, mask: Self::Mask) -> Self {
        if true {
            if !mask {
                {
                    ::core::panicking::panic_fmt(format_args!("At least one lane must be set"));
                }
            };
        };
        <f64>::from_bits(self.to_bits() - 1)
    }
    #[inline]
    fn cast_from_int(i: Self::UInt) -> Self { i as f64 }
}
impl FloatAsSIMD for f64 {}scalar_float_impl!(f64, u64);
336
/// Implements the float helper traits for `Simd<$fty, LANES>` at every lane
/// count, where `$uty` is the unsigned integer type with the same bit width
/// as `$fty`.
#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($fty:ident, $uty:ident) => {
        impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES> {
            type Mask = Mask<<$fty as SimdElement>::Mask, LANES>;
            type UInt = Simd<$uty, LANES>;

            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self.simd_lt(other).all()
            }

            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.simd_le(other).all()
            }

            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite().all()
            }

            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.simd_gt(other)
            }

            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                // Casting a mask into ints will produce all bits set for
                // true, and 0 for false. Adding that to the binary
                // representation of a float means subtracting one from
                // the binary representation, resulting in the next lower
                // value representable by $fty for positive lanes. This works
                // even when the current value is infinity. A selected lane
                // holding 0.0 would wrap to an all-ones (NaN) pattern, and a
                // negative lane would move toward zero, so only positive
                // lanes should be selected.
                debug_assert!(mask.any(), "At least one lane must be set");
                Self::from_bits(self.to_bits() + mask.to_simd().cast())
            }

            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i.cast()
            }
        }

        #[cfg(test)]
        impl<const LANES: usize> FloatSIMDScalarUtils for Simd<$fty, LANES> {
            type Scalar = $fty;

            #[inline]
            fn replace(mut self, index: usize, new_value: Self::Scalar) -> Self {
                self.as_mut_array()[index] = new_value;
                self
            }

            #[inline]
            fn extract_lane(self, index: usize) -> Self::Scalar {
                self.as_array()[index]
            }
        }
    };
}

#[cfg(feature = "simd_support")]
simd_impl!(f32, u32);
#[cfg(feature = "simd_support")]
simd_impl!(f64, u64);