core/portable-simd/crates/core_simd/src/simd/cmp/
ord.rs

1use crate::simd::{
2    LaneCount, Mask, Simd, SupportedLaneCount,
3    cmp::SimdPartialEq,
4    ptr::{SimdConstPtr, SimdMutPtr},
5};
6
7/// Parallel `PartialOrd`.
8pub trait SimdPartialOrd: SimdPartialEq {
9    /// Test if each element is less than the corresponding element in `other`.
10    #[must_use = "method returns a new mask and does not mutate the original value"]
11    fn simd_lt(self, other: Self) -> Self::Mask;
12
13    /// Test if each element is less than or equal to the corresponding element in `other`.
14    #[must_use = "method returns a new mask and does not mutate the original value"]
15    fn simd_le(self, other: Self) -> Self::Mask;
16
17    /// Test if each element is greater than the corresponding element in `other`.
18    #[must_use = "method returns a new mask and does not mutate the original value"]
19    fn simd_gt(self, other: Self) -> Self::Mask;
20
21    /// Test if each element is greater than or equal to the corresponding element in `other`.
22    #[must_use = "method returns a new mask and does not mutate the original value"]
23    fn simd_ge(self, other: Self) -> Self::Mask;
24}
25
26/// Parallel `Ord`.
27pub trait SimdOrd: SimdPartialOrd {
28    /// Returns the element-wise maximum with `other`.
29    #[must_use = "method returns a new vector and does not mutate the original value"]
30    fn simd_max(self, other: Self) -> Self;
31
32    /// Returns the element-wise minimum with `other`.
33    #[must_use = "method returns a new vector and does not mutate the original value"]
34    fn simd_min(self, other: Self) -> Self;
35
36    /// Restrict each element to a certain interval.
37    ///
38    /// For each element, returns `max` if `self` is greater than `max`, and `min` if `self` is
39    /// less than `min`. Otherwise returns `self`.
40    ///
41    /// # Panics
42    ///
43    /// Panics if `min > max` on any element.
44    #[must_use = "method returns a new vector and does not mutate the original value"]
45    fn simd_clamp(self, min: Self, max: Self) -> Self;
46}
47
48macro_rules! impl_integer {
49    { $($integer:ty),* } => {
50        $(
51        impl<const N: usize> SimdPartialOrd for Simd<$integer, N>
52        where
53            LaneCount<N>: SupportedLaneCount,
54        {
55            #[inline]
56            fn simd_lt(self, other: Self) -> Self::Mask {
57                // Safety: `self` is a vector, and the result of the comparison
58                // is always a valid mask.
59                unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_lt(self, other)) }
60            }
61
62            #[inline]
63            fn simd_le(self, other: Self) -> Self::Mask {
64                // Safety: `self` is a vector, and the result of the comparison
65                // is always a valid mask.
66                unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_le(self, other)) }
67            }
68
69            #[inline]
70            fn simd_gt(self, other: Self) -> Self::Mask {
71                // Safety: `self` is a vector, and the result of the comparison
72                // is always a valid mask.
73                unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_gt(self, other)) }
74            }
75
76            #[inline]
77            fn simd_ge(self, other: Self) -> Self::Mask {
78                // Safety: `self` is a vector, and the result of the comparison
79                // is always a valid mask.
80                unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_ge(self, other)) }
81            }
82        }
83
84        impl<const N: usize> SimdOrd for Simd<$integer, N>
85        where
86            LaneCount<N>: SupportedLaneCount,
87        {
88            #[inline]
89            fn simd_max(self, other: Self) -> Self {
90                self.simd_lt(other).select(other, self)
91            }
92
93            #[inline]
94            fn simd_min(self, other: Self) -> Self {
95                self.simd_gt(other).select(other, self)
96            }
97
98            #[inline]
99            #[track_caller]
100            fn simd_clamp(self, min: Self, max: Self) -> Self {
101                assert!(
102                    min.simd_le(max).all(),
103                    "each element in `min` must be less than or equal to the corresponding element in `max`",
104                );
105                self.simd_max(min).simd_min(max)
106            }
107        }
108        )*
109    }
110}
111
112impl_integer! { u8, u16, u32, u64, usize, i8, i16, i32, i64, isize }
113
114macro_rules! impl_float {
115    { $($float:ty),* } => {
116        $(
117        impl<const N: usize> SimdPartialOrd for Simd<$float, N>
118        where
119            LaneCount<N>: SupportedLaneCount,
120        {
121            #[inline]
122            fn simd_lt(self, other: Self) -> Self::Mask {
123                // Safety: `self` is a vector, and the result of the comparison
124                // is always a valid mask.
125                unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_lt(self, other)) }
126            }
127
128            #[inline]
129            fn simd_le(self, other: Self) -> Self::Mask {
130                // Safety: `self` is a vector, and the result of the comparison
131                // is always a valid mask.
132                unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_le(self, other)) }
133            }
134
135            #[inline]
136            fn simd_gt(self, other: Self) -> Self::Mask {
137                // Safety: `self` is a vector, and the result of the comparison
138                // is always a valid mask.
139                unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_gt(self, other)) }
140            }
141
142            #[inline]
143            fn simd_ge(self, other: Self) -> Self::Mask {
144                // Safety: `self` is a vector, and the result of the comparison
145                // is always a valid mask.
146                unsafe { Mask::from_int_unchecked(core::intrinsics::simd::simd_ge(self, other)) }
147            }
148        }
149        )*
150    }
151}
152
153impl_float! { f32, f64 }
154
155macro_rules! impl_mask {
156    { $($integer:ty),* } => {
157        $(
158        impl<const N: usize> SimdPartialOrd for Mask<$integer, N>
159        where
160            LaneCount<N>: SupportedLaneCount,
161        {
162            #[inline]
163            fn simd_lt(self, other: Self) -> Self::Mask {
164                // Safety: `self` is a vector, and the result of the comparison
165                // is always a valid mask.
166                unsafe { Self::from_int_unchecked(core::intrinsics::simd::simd_lt(self.to_int(), other.to_int())) }
167            }
168
169            #[inline]
170            fn simd_le(self, other: Self) -> Self::Mask {
171                // Safety: `self` is a vector, and the result of the comparison
172                // is always a valid mask.
173                unsafe { Self::from_int_unchecked(core::intrinsics::simd::simd_le(self.to_int(), other.to_int())) }
174            }
175
176            #[inline]
177            fn simd_gt(self, other: Self) -> Self::Mask {
178                // Safety: `self` is a vector, and the result of the comparison
179                // is always a valid mask.
180                unsafe { Self::from_int_unchecked(core::intrinsics::simd::simd_gt(self.to_int(), other.to_int())) }
181            }
182
183            #[inline]
184            fn simd_ge(self, other: Self) -> Self::Mask {
185                // Safety: `self` is a vector, and the result of the comparison
186                // is always a valid mask.
187                unsafe { Self::from_int_unchecked(core::intrinsics::simd::simd_ge(self.to_int(), other.to_int())) }
188            }
189        }
190
191        impl<const N: usize> SimdOrd for Mask<$integer, N>
192        where
193            LaneCount<N>: SupportedLaneCount,
194        {
195            #[inline]
196            fn simd_max(self, other: Self) -> Self {
197                self.simd_gt(other).select_mask(other, self)
198            }
199
200            #[inline]
201            fn simd_min(self, other: Self) -> Self {
202                self.simd_lt(other).select_mask(other, self)
203            }
204
205            #[inline]
206            #[track_caller]
207            fn simd_clamp(self, min: Self, max: Self) -> Self {
208                assert!(
209                    min.simd_le(max).all(),
210                    "each element in `min` must be less than or equal to the corresponding element in `max`",
211                );
212                self.simd_max(min).simd_min(max)
213            }
214        }
215        )*
216    }
217}
218
219impl_mask! { i8, i16, i32, i64, isize }
220
221impl<T, const N: usize> SimdPartialOrd for Simd<*const T, N>
222where
223    LaneCount<N>: SupportedLaneCount,
224{
225    #[inline]
226    fn simd_lt(self, other: Self) -> Self::Mask {
227        self.addr().simd_lt(other.addr())
228    }
229
230    #[inline]
231    fn simd_le(self, other: Self) -> Self::Mask {
232        self.addr().simd_le(other.addr())
233    }
234
235    #[inline]
236    fn simd_gt(self, other: Self) -> Self::Mask {
237        self.addr().simd_gt(other.addr())
238    }
239
240    #[inline]
241    fn simd_ge(self, other: Self) -> Self::Mask {
242        self.addr().simd_ge(other.addr())
243    }
244}
245
246impl<T, const N: usize> SimdOrd for Simd<*const T, N>
247where
248    LaneCount<N>: SupportedLaneCount,
249{
250    #[inline]
251    fn simd_max(self, other: Self) -> Self {
252        self.simd_lt(other).select(other, self)
253    }
254
255    #[inline]
256    fn simd_min(self, other: Self) -> Self {
257        self.simd_gt(other).select(other, self)
258    }
259
260    #[inline]
261    #[track_caller]
262    fn simd_clamp(self, min: Self, max: Self) -> Self {
263        assert!(
264            min.simd_le(max).all(),
265            "each element in `min` must be less than or equal to the corresponding element in `max`",
266        );
267        self.simd_max(min).simd_min(max)
268    }
269}
270
271impl<T, const N: usize> SimdPartialOrd for Simd<*mut T, N>
272where
273    LaneCount<N>: SupportedLaneCount,
274{
275    #[inline]
276    fn simd_lt(self, other: Self) -> Self::Mask {
277        self.addr().simd_lt(other.addr())
278    }
279
280    #[inline]
281    fn simd_le(self, other: Self) -> Self::Mask {
282        self.addr().simd_le(other.addr())
283    }
284
285    #[inline]
286    fn simd_gt(self, other: Self) -> Self::Mask {
287        self.addr().simd_gt(other.addr())
288    }
289
290    #[inline]
291    fn simd_ge(self, other: Self) -> Self::Mask {
292        self.addr().simd_ge(other.addr())
293    }
294}
295
296impl<T, const N: usize> SimdOrd for Simd<*mut T, N>
297where
298    LaneCount<N>: SupportedLaneCount,
299{
300    #[inline]
301    fn simd_max(self, other: Self) -> Self {
302        self.simd_lt(other).select(other, self)
303    }
304
305    #[inline]
306    fn simd_min(self, other: Self) -> Self {
307        self.simd_gt(other).select(other, self)
308    }
309
310    #[inline]
311    #[track_caller]
312    fn simd_clamp(self, min: Self, max: Self) -> Self {
313        assert!(
314            min.simd_le(max).all(),
315            "each element in `min` must be less than or equal to the corresponding element in `max`",
316        );
317        self.simd_max(min).simd_min(max)
318    }
319}