core/portable-simd/crates/core_simd/src/swizzle_dyn.rs

use crate::simd::{LaneCount, Simd, SupportedLaneCount};
use core::mem;

impl<const N: usize> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    /// Swizzle a vector of bytes according to the index vector.
    /// Indices within range select the corresponding byte.
    /// Indices "out of bounds" instead select 0.
    ///
    /// Note that the current implementation is selected when the standard
    /// library is built, so `cargo build -Zbuild-std` may be necessary
    /// to unlock better performance, especially for larger vectors.
    /// A planned compiler improvement will enable using `#[target_feature]` instead.
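    ///
    /// # Examples
    ///
    /// A minimal illustrative sketch of the behavior described above
    /// (any index `>= N` yields `0`):
    ///
    /// ```
    /// #![feature(portable_simd)]
    /// use core::simd::Simd;
    ///
    /// let bytes = Simd::from_array([10u8, 11, 12, 13, 14, 15, 16, 17]);
    /// // Lane 3 requests index 200, which is out of bounds for 8 lanes, so it becomes 0.
    /// let idxs = Simd::from_array([7u8, 0, 1, 200, 2, 2, 6, 3]);
    /// assert_eq!(bytes.swizzle_dyn(idxs).to_array(), [17, 10, 11, 0, 12, 12, 16, 13]);
    /// ```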
    #[inline]
    pub fn swizzle_dyn(self, idxs: Simd<u8, N>) -> Self {
        #![allow(unused_imports, unused_unsafe)]
        #[cfg(all(
            any(target_arch = "aarch64", target_arch = "arm64ec"),
            target_endian = "little"
        ))]
        use core::arch::aarch64::{uint8x8_t, vqtbl1q_u8, vtbl1_u8};
        #[cfg(all(
            target_arch = "arm",
            target_feature = "v7",
            target_feature = "neon",
            target_endian = "little"
        ))]
        use core::arch::arm::{uint8x8_t, vtbl1_u8};
        #[cfg(target_arch = "wasm32")]
        use core::arch::wasm32 as wasm;
        #[cfg(target_arch = "wasm64")]
        use core::arch::wasm64 as wasm;
        #[cfg(target_arch = "x86")]
        use core::arch::x86;
        #[cfg(target_arch = "x86_64")]
        use core::arch::x86_64 as x86;
        // SAFETY: Intrinsics covered by cfg
        unsafe {
            match N {
                #[cfg(all(
                    any(
                        target_arch = "aarch64",
                        target_arch = "arm64ec",
                        all(target_arch = "arm", target_feature = "v7")
                    ),
                    target_feature = "neon",
                    target_endian = "little"
                ))]
                8 => transize(vtbl1_u8, self, idxs),
                #[cfg(target_feature = "ssse3")]
                16 => transize(x86::_mm_shuffle_epi8, self, zeroing_idxs(idxs)),
                #[cfg(target_feature = "simd128")]
                16 => transize(wasm::i8x16_swizzle, self, idxs),
                #[cfg(all(
                    any(target_arch = "aarch64", target_arch = "arm64ec"),
                    target_feature = "neon",
                    target_endian = "little"
                ))]
                16 => transize(vqtbl1q_u8, self, idxs),
                #[cfg(all(
                    target_arch = "arm",
                    target_feature = "v7",
                    target_feature = "neon",
                    target_endian = "little"
                ))]
                16 => transize(armv7_neon_swizzle_u8x16, self, idxs),
                #[cfg(all(target_feature = "avx2", not(target_feature = "avx512vbmi")))]
                32 => transize(avx2_pshufb, self, idxs),
                #[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
                32 => {
                    // Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
                    let swizzler = |bytes, idxs| {
                        let mask = x86::_mm256_cmp_epu8_mask::<{ x86::_MM_CMPINT_LT }>(
                            idxs,
                            Simd::<u8, 32>::splat(N as u8).into(),
                        );
                        x86::_mm256_maskz_permutexvar_epi8(mask, idxs, bytes)
                    };
                    transize(swizzler, self, idxs)
                }
                // Notable absence: avx512bw pshufb shuffle
                #[cfg(all(target_feature = "avx512vl", target_feature = "avx512vbmi"))]
                64 => {
                    // Unlike vpshufb, vpermb doesn't zero out values in the result based on the index high bit
                    let swizzler = |bytes, idxs| {
                        let mask = x86::_mm512_cmp_epu8_mask::<{ x86::_MM_CMPINT_LT }>(
                            idxs,
                            Simd::<u8, 64>::splat(N as u8).into(),
                        );
                        x86::_mm512_maskz_permutexvar_epi8(mask, idxs, bytes)
                    };
                    transize(swizzler, self, idxs)
                }
                _ => {
                    let mut array = [0; N];
                    for (i, k) in idxs.to_array().into_iter().enumerate() {
                        if (k as usize) < N {
                            array[i] = self[k as usize];
                        }
                    }
                    array.into()
                }
            }
        }
    }
}

/// armv7 neon supports swizzling `u8x16` by swizzling two u8x8 blocks
/// with a u8x8x2 lookup table.
///
/// # Safety
/// This requires armv7 neon to work
#[cfg(all(
    target_arch = "arm",
    target_feature = "v7",
    target_feature = "neon",
    target_endian = "little"
))]
unsafe fn armv7_neon_swizzle_u8x16(bytes: Simd<u8, 16>, idxs: Simd<u8, 16>) -> Simd<u8, 16> {
    use core::arch::arm::{uint8x8x2_t, vcombine_u8, vget_high_u8, vget_low_u8, vtbl2_u8};
    // SAFETY: Caller promised arm neon support
    unsafe {
        let bytes = uint8x8x2_t(vget_low_u8(bytes.into()), vget_high_u8(bytes.into()));
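        // `vtbl2_u8` treats the `uint8x8x2_t` pair as one 16-byte table and returns 0
        // for any index outside 0..16, matching swizzle_dyn's out-of-bounds-is-zero semantics.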
        let lo = vtbl2_u8(bytes, vget_low_u8(idxs.into()));
        let hi = vtbl2_u8(bytes, vget_high_u8(idxs.into()));
        vcombine_u8(lo, hi).into()
    }
}

/// "vpshufb like it was meant to be" on AVX2
///
/// # Safety
/// This requires AVX2 to work
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[target_feature(enable = "avx2")]
#[allow(unused)]
#[inline]
#[allow(clippy::let_and_return)]
unsafe fn avx2_pshufb(bytes: Simd<u8, 32>, idxs: Simd<u8, 32>) -> Simd<u8, 32> {
    use crate::simd::cmp::SimdPartialOrd;
    #[cfg(target_arch = "x86")]
    use core::arch::x86;
    #[cfg(target_arch = "x86_64")]
    use core::arch::x86_64 as x86;
    use x86::_mm256_permute2x128_si256 as avx2_cross_shuffle;
    use x86::_mm256_shuffle_epi8 as avx2_half_pshufb;
    let mid = Simd::splat(16u8);
    let high = mid + mid;
    // SAFETY: Caller promised AVX2
    unsafe {
        // This is ordering sensitive, and LLVM will order these how you put them.
        // Most AVX2 impls use ~5 "ports", and only 1 or 2 are capable of permutes.
        // But the "compose" step will lower to ops that can also use at least 1 other port.
        // So this tries to break up permutes so composition flows through "open" ports.
        // Comparative benches should be done on multiple AVX2 CPUs before reordering this

        let hihi = avx2_cross_shuffle::<0x11>(bytes.into(), bytes.into());
        let hi_shuf = Simd::from(avx2_half_pshufb(
            hihi,        // duplicate the vector's top half
            idxs.into(), // so that using only 4 bits of an index still picks bytes 16-31
        ));
        // A zero-fill during the compose step gives the "all-Neon-like" OOB-is-0 semantics
        let compose = idxs.simd_lt(high).select(hi_shuf, Simd::splat(0));
        let lolo = avx2_cross_shuffle::<0x00>(bytes.into(), bytes.into());
        let lo_shuf = Simd::from(avx2_half_pshufb(lolo, idxs.into()));
        // Repeat, then pick indices < 16, overwriting indices 0-15 from previous compose step
        let compose = idxs.simd_lt(mid).select(lo_shuf, compose);
        compose
    }
}

/// This sets up a call to an architecture-specific function, and in doing so
/// it persuades rustc that everything is the correct size. Which it is.
/// This would not be needed if one could convince Rust that, by matching on N,
/// N is that value, and thus it would be valid to substitute e.g. 16.
///
/// # Safety
/// The correctness of this function hinges on the sizes agreeing in actuality.
#[allow(dead_code)]
#[inline(always)]
unsafe fn transize<T, const N: usize>(
    f: unsafe fn(T, T) -> T,
    a: Simd<u8, N>,
    b: Simd<u8, N>,
) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    // SAFETY: Same obligation to use this function as to use mem::transmute_copy.
    unsafe { mem::transmute_copy(&f(mem::transmute_copy(&a), mem::transmute_copy(&b))) }
}

/// Map out-of-range indices to `0xFF` so that x86's `pshufb`, which zeroes an
/// output byte whenever the high bit of its index byte is set, yields 0 for them.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(unused)]
#[inline(always)]
fn zeroing_idxs<const N: usize>(idxs: Simd<u8, N>) -> Simd<u8, N>
where
    LaneCount<N>: SupportedLaneCount,
{
    use crate::simd::cmp::SimdPartialOrd;
    idxs.simd_lt(Simd::splat(N as u8))
        .select(idxs, Simd::splat(u8::MAX))
}