1#[cfg(feature = "simd_support")]
12use core::simd::prelude::*;
13#[cfg(feature = "simd_support")]
14use core::simd::{LaneCount, SimdElement, SupportedLaneCount};
15
/// Full-width ("widening") multiplication.
///
/// The implementations in this file return the product as a `(high, low)`
/// pair whose halves have the same type as the operands.
pub(crate) trait WideningMultiply<Rhs = Self> {
    /// Type that carries the complete product.
    type Output;

    /// Multiplies `self` by `x` and returns the complete product.
    fn wmul(self, x: Rhs) -> Self::Output;
}
21
22macro_rules! wmul_impl {
23    ($ty:ty, $wide:ty, $shift:expr) => {
24        impl WideningMultiply for $ty {
25            type Output = ($ty, $ty);
26
27            #[inline(always)]
28            fn wmul(self, x: $ty) -> Self::Output {
29                let tmp = (self as $wide) * (x as $wide);
30                ((tmp >> $shift) as $ty, tmp as $ty)
31            }
32        }
33    };
34
35    ($(($ty:ident, $wide:ty),)+, $shift:expr) => {
37        $(
38            impl WideningMultiply for $ty {
39                type Output = ($ty, $ty);
40
41                #[inline(always)]
42                fn wmul(self, x: $ty) -> Self::Output {
43                    let y: $wide = self.cast();
48                    let x: $wide = x.cast();
49                    let tmp = y * x;
50                    let hi: $ty = (tmp >> Simd::splat($shift)).cast();
51                    let lo: $ty = tmp.cast();
52                    (hi, lo)
53                }
54            }
55        )+
56    };
57}
58wmul_impl! { u8, u16, 8 }
59wmul_impl! { u16, u32, 16 }
60wmul_impl! { u32, u64, 32 }
61wmul_impl! { u64, u128, 64 }
62
63macro_rules! wmul_impl_large {
70    ($ty:ty, $half:expr) => {
71        impl WideningMultiply for $ty {
72            type Output = ($ty, $ty);
73
74            #[inline(always)]
75            fn wmul(self, b: $ty) -> Self::Output {
76                const LOWER_MASK: $ty = !0 >> $half;
77                let mut low = (self & LOWER_MASK).wrapping_mul(b & LOWER_MASK);
78                let mut t = low >> $half;
79                low &= LOWER_MASK;
80                t += (self >> $half).wrapping_mul(b & LOWER_MASK);
81                low += (t & LOWER_MASK) << $half;
82                let mut high = t >> $half;
83                t = low >> $half;
84                low &= LOWER_MASK;
85                t += (b >> $half).wrapping_mul(self & LOWER_MASK);
86                low += (t & LOWER_MASK) << $half;
87                high += t >> $half;
88                high += (self >> $half).wrapping_mul(b >> $half);
89
90                (high, low)
91            }
92        }
93    };
94
95    (($($ty:ty,)+) $scalar:ty, $half:expr) => {
97        $(
98            impl WideningMultiply for $ty {
99                type Output = ($ty, $ty);
100
101                #[inline(always)]
102                fn wmul(self, b: $ty) -> Self::Output {
103                    let lower_mask = <$ty>::splat(!0 >> $half);
105                    let half = <$ty>::splat($half);
106                    let mut low = (self & lower_mask) * (b & lower_mask);
107                    let mut t = low >> half;
108                    low &= lower_mask;
109                    t += (self >> half) * (b & lower_mask);
110                    low += (t & lower_mask) << half;
111                    let mut high = t >> half;
112                    t = low >> half;
113                    low &= lower_mask;
114                    t += (b >> half) * (self & lower_mask);
115                    low += (t & lower_mask) << half;
116                    high += t >> half;
117                    high += (self >> half) * (b >> half);
118
119                    (high, low)
120                }
121            }
122        )+
123    };
124}
125wmul_impl_large! { u128, 64 }
126
127macro_rules! wmul_impl_usize {
128    ($ty:ty) => {
129        impl WideningMultiply for usize {
130            type Output = (usize, usize);
131
132            #[inline(always)]
133            fn wmul(self, x: usize) -> Self::Output {
134                let (high, low) = (self as $ty).wmul(x as $ty);
135                (high as usize, low as usize)
136            }
137        }
138    };
139}
140#[cfg(target_pointer_width = "16")]
141wmul_impl_usize! { u16 }
142#[cfg(target_pointer_width = "32")]
143wmul_impl_usize! { u32 }
144#[cfg(target_pointer_width = "64")]
145wmul_impl_usize! { u64 }
146
/// SIMD implementations of `WideningMultiply`, compiled only when the
/// `simd_support` feature is enabled.
#[cfg(feature = "simd_support")]
mod simd_wmul {
    use super::*;
    #[cfg(target_arch = "x86")]
    use core::arch::x86::*;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64::*;

    // 8-bit lanes: multiply in 16-bit lanes and split with a shift of 8.
    wmul_impl! {
        (u8x4, u16x4),
        (u8x8, u16x8),
        (u8x16, u16x16),
        (u8x32, u16x32),
        (u8x64, Simd<u16, 64>),,
        8
    }

    // 16-bit lanes: multiply in 32-bit lanes and split with a shift of 16.
    // The 8-, 16- and 32-lane vectors take this generic path only when the
    // corresponding target feature is NOT enabled; when it is enabled,
    // `wmul_impl_16!` below supplies the impl instead, so every vector type
    // ends up with exactly one implementation.
    wmul_impl! { (u16x2, u32x2),, 16 }
    wmul_impl! { (u16x4, u32x4),, 16 }
    #[cfg(not(target_feature = "sse2"))]
    wmul_impl! { (u16x8, u32x8),, 16 }
    #[cfg(not(target_feature = "avx2"))]
    wmul_impl! { (u16x16, u32x16),, 16 }
    #[cfg(not(target_feature = "avx512bw"))]
    wmul_impl! { (u16x32, Simd<u32, 32>),, 16 }

    // Implements `wmul` for a vector of 16-bit lanes with a pair of
    // intrinsics: the result of `$mulhi` becomes the first (high) element of
    // the output and the result of `$mullo` the second (low) element.
    //
    // The macro has no invocation when none of the target features below is
    // enabled, hence `allow(unused_macros)`.
    #[allow(unused_macros)]
    macro_rules! wmul_impl_16 {
        ($ty:ident, $mulhi:ident, $mullo:ident) => {
            impl WideningMultiply for $ty {
                type Output = ($ty, $ty);

                #[inline(always)]
                fn wmul(self, x: $ty) -> Self::Output {
                    // SAFETY: each invocation of this macro in this module is
                    // gated on a `target_feature` cfg (sse2 / avx2 /
                    // avx512bw), so the intrinsic is only compiled when that
                    // feature is statically enabled for the target.
                    // NOTE(review): confirm each intrinsic requires no
                    // feature beyond the one its invocation is gated on.
                    let hi = unsafe { $mulhi(self.into(), x.into()) }.into();
                    let lo = unsafe { $mullo(self.into(), x.into()) }.into();
                    (hi, lo)
                }
            }
        };
    }

    // Intrinsic-based impls; these complement the `cfg(not(...))` generic
    // impls above.
    #[cfg(target_feature = "sse2")]
    wmul_impl_16! { u16x8, _mm_mulhi_epu16, _mm_mullo_epi16 }
    #[cfg(target_feature = "avx2")]
    wmul_impl_16! { u16x16, _mm256_mulhi_epu16, _mm256_mullo_epi16 }
    #[cfg(target_feature = "avx512bw")]
    wmul_impl_16! { u16x32, _mm512_mulhi_epu16, _mm512_mullo_epi16 }

    // 32-bit lanes: multiply in 64-bit lanes and split with a shift of 32.
    wmul_impl! {
        (u32x2, u64x2),
        (u32x4, u64x4),
        (u32x8, u64x8),
        (u32x16, Simd<u64, 16>),,
        32
    }

    // 64-bit lanes have no wider lane type here, so they use the
    // split-into-halves implementation with 32-bit halves.
    wmul_impl_large! { (u64x2, u64x4, u64x8,) u64, 32 }
}
208
/// Helper operations on floating-point values, implemented in this file for
/// both scalar floats and (with the `simd_support` feature) SIMD float
/// vectors. A scalar behaves as a vector with a single lane.
pub(crate) trait FloatSIMDUtils {
    /// Result type of `gt_mask`: `bool` for scalars, a SIMD mask for vectors.
    type Mask;

    /// Unsigned integer type accepted by `cast_from_int`.
    type UInt;

    /// Returns `true` if every lane of `self` is less than the corresponding
    /// lane of `other`.
    fn all_lt(self, other: Self) -> bool;

    /// Returns `true` if every lane of `self` is less than or equal to the
    /// corresponding lane of `other`.
    fn all_le(self, other: Self) -> bool;

    /// Returns `true` if every lane of `self` is finite.
    fn all_finite(self) -> bool;

    /// Compares `self > other` lane-wise and returns the resulting mask.
    fn gt_mask(self, other: Self) -> Self::Mask;

    /// Decrements the bit pattern of the lanes selected by `mask` by one.
    ///
    /// The implementations in this file debug-assert that at least one lane
    /// of `mask` is set.
    fn decrease_masked(self, mask: Self::Mask) -> Self;

    /// Converts an integer (or vector of integers) to `Self` by numeric cast.
    fn cast_from_int(i: Self::UInt) -> Self;
}
233
/// Test-only helpers for reading and writing individual lanes of a value
/// that implements `FloatSIMDUtils`.
#[cfg(test)]
pub(crate) trait FloatSIMDScalarUtils: FloatSIMDUtils {
    /// Type of a single lane.
    type Scalar;

    /// Returns `self` with lane `index` set to `new_value`.
    fn replace(self, index: usize, new_value: Self::Scalar) -> Self;
    /// Returns the value stored in lane `index`.
    fn extract_lane(self, index: usize) -> Self::Scalar;
}
241
/// Gives scalar float types a `splat` constructor (and, in test builds, a
/// `LEN` constant). Implemented in this file for `f32` and `f64`.
pub(crate) trait FloatAsSIMD: Sized {
    /// Defaults to 1; presumably the lane count (a scalar being one lane).
    /// Only available in test builds.
    #[cfg(test)]
    const LEN: usize = 1;

    /// Default implementation: returns `scalar` unchanged.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}
252
/// Gives scalar integer types a `splat` constructor. Implemented here for
/// `u32` and `u64`.
pub(crate) trait IntAsSIMD: Sized {
    /// Default implementation: returns `scalar` unchanged.
    #[inline(always)]
    fn splat(scalar: Self) -> Self {
        scalar
    }
}

// Both impls rely on the default `splat`.
impl IntAsSIMD for u32 {}
impl IntAsSIMD for u64 {}
262
/// Gives `bool` an `any` method, so that a scalar comparison result (the
/// `Mask` type of the scalar `FloatSIMDUtils` impls) can be queried with
/// `.any()`.
pub(crate) trait BoolAsSIMD: Sized {
    /// Returns whether any lane is set.
    fn any(self) -> bool;
}

impl BoolAsSIMD for bool {
    /// A `bool` is a single lane, so this returns `self`.
    #[inline(always)]
    fn any(self) -> bool {
        self
    }
}
273
274macro_rules! scalar_float_impl {
275    ($ty:ident, $uty:ident) => {
276        impl FloatSIMDUtils for $ty {
277            type Mask = bool;
278            type UInt = $uty;
279
280            #[inline(always)]
281            fn all_lt(self, other: Self) -> bool {
282                self < other
283            }
284
285            #[inline(always)]
286            fn all_le(self, other: Self) -> bool {
287                self <= other
288            }
289
290            #[inline(always)]
291            fn all_finite(self) -> bool {
292                self.is_finite()
293            }
294
295            #[inline(always)]
296            fn gt_mask(self, other: Self) -> Self::Mask {
297                self > other
298            }
299
300            #[inline(always)]
301            fn decrease_masked(self, mask: Self::Mask) -> Self {
302                debug_assert!(mask, "At least one lane must be set");
303                <$ty>::from_bits(self.to_bits() - 1)
304            }
305
306            #[inline]
307            fn cast_from_int(i: Self::UInt) -> Self {
308                i as $ty
309            }
310        }
311
312        #[cfg(test)]
313        impl FloatSIMDScalarUtils for $ty {
314            type Scalar = $ty;
315
316            #[inline]
317            fn replace(self, index: usize, new_value: Self::Scalar) -> Self {
318                debug_assert_eq!(index, 0);
319                new_value
320            }
321
322            #[inline]
323            fn extract_lane(self, index: usize) -> Self::Scalar {
324                debug_assert_eq!(index, 0);
325                self
326            }
327        }
328
329        impl FloatAsSIMD for $ty {}
330    };
331}
332
333scalar_float_impl!(f32, u32);
334scalar_float_impl!(f64, u64);
335
// Implements `FloatSIMDUtils` and (in test builds) `FloatSIMDScalarUtils` for
// `Simd<$fty, LANES>` for every supported lane count. `$uty` is the unsigned
// integer lane type used as `UInt`.
#[cfg(feature = "simd_support")]
macro_rules! simd_impl {
    ($fty:ident, $uty:ident) => {
        impl<const LANES: usize> FloatSIMDUtils for Simd<$fty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            type Mask = Mask<<$fty as SimdElement>::Mask, LANES>;
            type UInt = Simd<$uty, LANES>;

            // True only if the comparison holds in every lane.
            #[inline(always)]
            fn all_lt(self, other: Self) -> bool {
                self.simd_lt(other).all()
            }

            // True only if the comparison holds in every lane.
            #[inline(always)]
            fn all_le(self, other: Self) -> bool {
                self.simd_le(other).all()
            }

            // True only if every lane is finite.
            #[inline(always)]
            fn all_finite(self) -> bool {
                self.is_finite().all()
            }

            // Lane-wise `self > other`.
            #[inline(always)]
            fn gt_mask(self, other: Self) -> Self::Mask {
                self.simd_gt(other)
            }

            // Decrements the bit pattern of the lanes selected by `mask`.
            //
            // Per the `core::simd` documentation, `Mask::to_int` yields -1
            // (all bits set) in set lanes and 0 elsewhere. Cast to the
            // unsigned lane type and added to the float's bit pattern, this
            // subtracts one from the selected lanes and leaves the others
            // unchanged.
            #[inline(always)]
            fn decrease_masked(self, mask: Self::Mask) -> Self {
                debug_assert!(mask.any(), "At least one lane must be set");
                Self::from_bits(self.to_bits() + mask.to_int().cast())
            }

            // Lane-wise numeric cast from unsigned integer to float.
            #[inline]
            fn cast_from_int(i: Self::UInt) -> Self {
                i.cast()
            }
        }

        // Test-only lane access; indexing panics if `index` is out of range.
        #[cfg(test)]
        impl<const LANES: usize> FloatSIMDScalarUtils for Simd<$fty, LANES>
        where
            LaneCount<LANES>: SupportedLaneCount,
        {
            type Scalar = $fty;

            #[inline]
            fn replace(mut self, index: usize, new_value: Self::Scalar) -> Self {
                self.as_mut_array()[index] = new_value;
                self
            }

            #[inline]
            fn extract_lane(self, index: usize) -> Self::Scalar {
                self.as_array()[index]
            }
        }
    };
}

#[cfg(feature = "simd_support")]
simd_impl!(f32, u32);
#[cfg(feature = "simd_support")]
simd_impl!(f64, u64);