core/portable-simd/crates/core_simd/src/masks/full_masks.rs

1//! Masks that take up full SIMD vector registers.
2
3use crate::simd::{LaneCount, MaskElement, Simd, SupportedLaneCount};
4
/// A mask that occupies a full SIMD register: each lane is a whole element of
/// the wrapped vector, holding either `T::TRUE` (all ones) or `T::FALSE`
/// (all zeros).
///
/// `repr(transparent)` gives this type exactly the layout of `Simd<T, N>`.
#[repr(transparent)]
pub(crate) struct Mask<T, const N: usize>(Simd<T, N>)
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount;
10
// Marker impl: a mask is a plain integer vector, so it can be bitwise-copied.
// (`Copy` has no methods, hence the empty body.)
impl<T, const N: usize> Copy for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
}
17
18impl<T, const N: usize> Clone for Mask<T, N>
19where
20    T: MaskElement,
21    LaneCount<N>: SupportedLaneCount,
22{
23    #[inline]
24    fn clone(&self) -> Self {
25        *self
26    }
27}
28
29impl<T, const N: usize> PartialEq for Mask<T, N>
30where
31    T: MaskElement + PartialEq,
32    LaneCount<N>: SupportedLaneCount,
33{
34    #[inline]
35    fn eq(&self, other: &Self) -> bool {
36        self.0.eq(&other.0)
37    }
38}
39
40impl<T, const N: usize> PartialOrd for Mask<T, N>
41where
42    T: MaskElement + PartialOrd,
43    LaneCount<N>: SupportedLaneCount,
44{
45    #[inline]
46    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
47        self.0.partial_cmp(&other.0)
48    }
49}
50
// Marker impl: mask equality is total whenever `T`'s equality is, so the
// `T: Eq` bound simply forwards that requirement.
impl<T, const N: usize> Eq for Mask<T, N>
where
    T: MaskElement + Eq,
    LaneCount<N>: SupportedLaneCount,
{
}
57
58impl<T, const N: usize> Ord for Mask<T, N>
59where
60    T: MaskElement + Ord,
61    LaneCount<N>: SupportedLaneCount,
62{
63    #[inline]
64    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
65        self.0.cmp(&other.0)
66    }
67}
68
// Used for the bitmask bit-order workaround on big-endian targets.
pub(crate) trait ReverseBits {
    /// Reverse the least significant `n` bits of `self`, returning them in
    /// the low bits of the result.
    ///
    /// All bits above the low `n` must already be 0.
    fn reverse_bits(self, n: usize) -> Self;
}

macro_rules! impl_reverse_bits {
    { $($int:ty),* } => {
        $(
        impl ReverseBits for $int {
            #[inline(always)]
            fn reverse_bits(self, n: usize) -> Self {
                // Reverse the whole integer, then shift the interesting bits
                // back down to the least significant positions.
                let rev = <$int>::reverse_bits(self);
                let bitsize = size_of::<$int>() * 8;
                if n == 0 {
                    // The contract requires all bits above the low `n` to be
                    // 0, so `self` (and thus the result) is 0. Returning it
                    // directly avoids the shift-by-`bitsize` overflow the
                    // general case below would hit (a panic in debug builds).
                    0
                } else if n < bitsize {
                    // Shift things back to the right
                    rev >> (bitsize - n)
                } else {
                    rev
                }
            }
        }
        )*
    }
}

impl_reverse_bits! { u8, u16, u32, u64 }
97
impl<T, const N: usize> Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    /// Build a mask with every lane set to `value` (`T::TRUE` for `true`,
    /// `T::FALSE` for `false`).
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub(crate) fn splat(value: bool) -> Self {
        Self(Simd::splat(if value { T::TRUE } else { T::FALSE }))
    }

    /// Return whether lane `lane` is set.
    ///
    /// # Safety
    /// The caller must ensure `lane < N`.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub(crate) unsafe fn test_unchecked(&self, lane: usize) -> bool {
        T::eq(self.0[lane], T::TRUE)
    }

    /// Set lane `lane` to `value`.
    ///
    /// # Safety
    /// The caller must ensure `lane < N`.
    #[inline]
    pub(crate) unsafe fn set_unchecked(&mut self, lane: usize, value: bool) {
        self.0[lane] = if value { T::TRUE } else { T::FALSE }
    }

    /// Expose the underlying integer vector; lanes are `T::TRUE`/`T::FALSE`.
    #[inline]
    #[must_use = "method returns a new vector and does not mutate the original value"]
    pub(crate) fn to_int(self) -> Simd<T, N> {
        self.0
    }

    /// Wrap an integer vector as a mask without validating it.
    ///
    /// # Safety
    /// Every lane of `value` must be `T::TRUE` or `T::FALSE`.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub(crate) unsafe fn from_int_unchecked(value: Simd<T, N>) -> Self {
        Self(value)
    }

    /// Convert to a mask with a different element type, preserving each
    /// lane's true/false value.
    #[inline]
    #[must_use = "method returns a new mask and does not mutate the original value"]
    pub(crate) fn convert<U>(self) -> Mask<U, N>
    where
        U: MaskElement,
    {
        // Safety: masks are simply integer vectors of 0 and -1, and we can cast the element type.
        unsafe { Mask(core::intrinsics::simd::simd_cast(self.0)) }
    }

    /// Pack this mask into the low bits of a `U`, one bit per lane, after
    /// resizing to `M` lanes (added lanes are false).
    ///
    /// # Safety
    /// `U` must be the unsigned integer type with exactly `M` bits (see the
    /// dispatch in `to_bitmask_integer`).
    #[inline]
    unsafe fn to_bitmask_impl<U: ReverseBits, const M: usize>(self) -> U
    where
        LaneCount<M>: SupportedLaneCount,
    {
        let resized = self.to_int().resize::<M>(T::FALSE);

        // Safety: `resized` is an integer vector with length M, which must match T
        let bitmask: U = unsafe { core::intrinsics::simd::simd_bitmask(resized) };

        // LLVM assumes bit order should match endianness, so on big-endian
        // targets the bits come out reversed; undo that here.
        if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        }
    }

    /// Unpack the low `M` bits of `bitmask` into a mask (one lane per bit),
    /// then resize to `N` lanes (added lanes are false).
    ///
    /// # Safety
    /// `U` must be the unsigned integer type with exactly `M` bits (see the
    /// dispatch in `from_bitmask_integer`).
    #[inline]
    unsafe fn from_bitmask_impl<U: ReverseBits, const M: usize>(bitmask: U) -> Self
    where
        LaneCount<M>: SupportedLaneCount,
    {
        // LLVM assumes bit order should match endianness, so pre-reverse on
        // big-endian targets to land each bit in the right lane.
        let bitmask = if cfg!(target_endian = "big") {
            bitmask.reverse_bits(M)
        } else {
            bitmask
        };

        // SAFETY: the caller guarantees `U` is the bitmask type matching a
        // vector of `M` lanes.
        let mask: Simd<T, M> = unsafe {
            core::intrinsics::simd::simd_select_bitmask(
                bitmask,
                Simd::<T, M>::splat(T::TRUE),
                Simd::<T, M>::splat(T::FALSE),
            )
        };

        // SAFETY: `mask` only contains `T::TRUE` or `T::FALSE`
        unsafe { Self::from_int_unchecked(mask.resize::<N>(T::FALSE)) }
    }

    /// Pack this mask into the low `N` bits of a `u64`; bits above `N` are 0.
    #[inline]
    pub(crate) fn to_bitmask_integer(self) -> u64 {
        // TODO modify simd_bitmask to zero-extend output, making this unnecessary
        // Dispatch to the smallest supported bitmask width that fits `N`,
        // then zero-extend to u64.
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u8, 8>() as u64 }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u16, 16>() as u64 }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u32, 32>() as u64 }
        } else {
            // Safety: bitmask matches length
            unsafe { self.to_bitmask_impl::<u64, 64>() }
        }
    }

    /// Build a mask from the low `N` bits of `bitmask`. Higher bits do not
    /// affect the result: they are dropped by the truncating `as` casts
    /// below or by the final resize in `from_bitmask_impl`.
    #[inline]
    pub(crate) fn from_bitmask_integer(bitmask: u64) -> Self {
        // TODO modify simd_bitmask_select to truncate input, making this unnecessary
        if N <= 8 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u8, 8>(bitmask as u8) }
        } else if N <= 16 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u16, 16>(bitmask as u16) }
        } else if N <= 32 {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u32, 32>(bitmask as u32) }
        } else {
            // Safety: bitmask matches length
            unsafe { Self::from_bitmask_impl::<u64, 64>(bitmask) }
        }
    }

    /// True if at least one lane is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub(crate) fn any(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { core::intrinsics::simd::simd_reduce_any(self.to_int()) }
    }

    /// True if every lane is set.
    #[inline]
    #[must_use = "method returns a new bool and does not mutate the original value"]
    pub(crate) fn all(self) -> bool {
        // Safety: use `self` as an integer vector
        unsafe { core::intrinsics::simd::simd_reduce_all(self.to_int()) }
    }
}
235
236impl<T, const N: usize> From<Mask<T, N>> for Simd<T, N>
237where
238    T: MaskElement,
239    LaneCount<N>: SupportedLaneCount,
240{
241    #[inline]
242    fn from(value: Mask<T, N>) -> Self {
243        value.0
244    }
245}
246
impl<T, const N: usize> core::ops::BitAnd for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    /// Lane-wise AND. ANDing lanes that are all-ones/all-zeros yields lanes
    /// of the same form, so the mask invariant is preserved.
    #[inline]
    fn bitand(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_and(self.0, rhs.0)) }
    }
}
259
impl<T, const N: usize> core::ops::BitOr for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    /// Lane-wise OR. ORing lanes that are all-ones/all-zeros yields lanes
    /// of the same form, so the mask invariant is preserved.
    #[inline]
    fn bitor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_or(self.0, rhs.0)) }
    }
}
272
impl<T, const N: usize> core::ops::BitXor for Mask<T, N>
where
    T: MaskElement,
    LaneCount<N>: SupportedLaneCount,
{
    type Output = Self;
    /// Lane-wise XOR. XORing lanes that are all-ones/all-zeros yields lanes
    /// of the same form, so the mask invariant is preserved.
    #[inline]
    fn bitxor(self, rhs: Self) -> Self {
        // Safety: `self` is an integer vector
        unsafe { Self(core::intrinsics::simd::simd_xor(self.0, rhs.0)) }
    }
}
285
286impl<T, const N: usize> core::ops::Not for Mask<T, N>
287where
288    T: MaskElement,
289    LaneCount<N>: SupportedLaneCount,
290{
291    type Output = Self;
292    #[inline]
293    fn not(self) -> Self::Output {
294        Self::splat(true) ^ self
295    }
296}