miri/shims/x86/
avx2.rs

1use rustc_abi::CanonAbi;
2use rustc_middle::mir;
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use super::{
8    ShiftOp, horizontal_bin_op, mask_load, mask_store, mpsadbw, packssdw, packsswb, packusdw,
9    packuswb, pmulhrsw, psadbw, psign, shift_simd_by_scalar, shift_simd_by_simd,
10};
11use crate::*;
12
13impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
14pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
15    fn emulate_x86_avx2_intrinsic(
16        &mut self,
17        link_name: Symbol,
18        abi: &FnAbi<'tcx, Ty<'tcx>>,
19        args: &[OpTy<'tcx>],
20        dest: &MPlaceTy<'tcx>,
21    ) -> InterpResult<'tcx, EmulateItemResult> {
22        let this = self.eval_context_mut();
23        this.expect_target_feature_for_intrinsic(link_name, "avx2")?;
24        // Prefix should have already been checked.
25        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx2.").unwrap();
26
27        match unprefixed_name {
28            // Used to implement the _mm256_h{adds,subs}_epi16 functions.
29            // Horizontally add / subtract with saturation adjacent 16-bit
30            // integer values in `left` and `right`.
31            "phadd.sw" | "phsub.sw" => {
32                let [left, right] =
33                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
34
35                let which = match unprefixed_name {
36                    "phadd.sw" => mir::BinOp::Add,
37                    "phsub.sw" => mir::BinOp::Sub,
38                    _ => unreachable!(),
39                };
40
41                horizontal_bin_op(this, which, /*saturating*/ true, left, right, dest)?;
42            }
43            // Used to implement `_mm{,_mask}_{i32,i64}gather_{epi32,epi64,pd,ps}` functions
44            // Gathers elements from `slice` using `offsets * scale` as indices.
45            // When the highest bit of the corresponding element of `mask` is 0,
46            // the value is copied from `src` instead.
47            "gather.d.d" | "gather.d.d.256" | "gather.d.q" | "gather.d.q.256" | "gather.q.d"
48            | "gather.q.d.256" | "gather.q.q" | "gather.q.q.256" | "gather.d.pd"
49            | "gather.d.pd.256" | "gather.q.pd" | "gather.q.pd.256" | "gather.d.ps"
50            | "gather.d.ps.256" | "gather.q.ps" | "gather.q.ps.256" => {
51                let [src, slice, offsets, mask, scale] =
52                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
53
54                assert_eq!(dest.layout, src.layout);
55
56                let (src, _) = this.project_to_simd(src)?;
57                let (offsets, offsets_len) = this.project_to_simd(offsets)?;
58                let (mask, mask_len) = this.project_to_simd(mask)?;
59                let (dest, dest_len) = this.project_to_simd(dest)?;
60
61                // There are cases like dest: i32x4, offsets: i64x2
62                // If dest has more elements than offset, extra dest elements are filled with zero.
63                // If offsets has more elements than dest, extra offsets are ignored.
64                let actual_len = dest_len.min(offsets_len);
65
66                assert_eq!(dest_len, mask_len);
67
68                let mask_item_size = mask.layout.field(this, 0).size;
69                let high_bit_offset = mask_item_size.bits().strict_sub(1);
70
71                let scale = this.read_scalar(scale)?.to_i8()?;
72                if !matches!(scale, 1 | 2 | 4 | 8) {
73                    panic!("invalid gather scale {scale}");
74                }
75                let scale = i64::from(scale);
76
77                let slice = this.read_pointer(slice)?;
78                for i in 0..actual_len {
79                    let mask = this.project_index(&mask, i)?;
80                    let dest = this.project_index(&dest, i)?;
81
82                    if this.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
83                        let offset = this.project_index(&offsets, i)?;
84                        let offset =
85                            i64::try_from(this.read_scalar(&offset)?.to_int(offset.layout.size)?)
86                                .unwrap();
87                        let ptr = slice.wrapping_signed_offset(offset.strict_mul(scale), &this.tcx);
88                        // Unaligned copy, which is what we want.
89                        this.mem_copy(
90                            ptr,
91                            dest.ptr(),
92                            dest.layout.size,
93                            /*nonoverlapping*/ true,
94                        )?;
95                    } else {
96                        this.copy_op(&this.project_index(&src, i)?, &dest)?;
97                    }
98                }
99                for i in actual_len..dest_len {
100                    let dest = this.project_index(&dest, i)?;
101                    this.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
102                }
103            }
104            // Used to implement the _mm256_maddubs_epi16 function.
105            // Multiplies packed 8-bit unsigned integers from `left` and packed
106            // signed 8-bit integers from `right` into 16-bit signed integers. Then,
107            // the saturating sum of the products with indices `2*i` and `2*i+1`
108            // produces the output at index `i`.
109            "pmadd.ub.sw" => {
110                let [left, right] =
111                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
112
113                let (left, left_len) = this.project_to_simd(left)?;
114                let (right, right_len) = this.project_to_simd(right)?;
115                let (dest, dest_len) = this.project_to_simd(dest)?;
116
117                assert_eq!(left_len, right_len);
118                assert_eq!(dest_len.strict_mul(2), left_len);
119
120                for i in 0..dest_len {
121                    let j1 = i.strict_mul(2);
122                    let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?;
123                    let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?;
124
125                    let j2 = j1.strict_add(1);
126                    let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?;
127                    let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?;
128
129                    let dest = this.project_index(&dest, i)?;
130
131                    // Multiplication of a u8 and an i8 into an i16 cannot overflow.
132                    let mul1 = i16::from(left1).strict_mul(right1.into());
133                    let mul2 = i16::from(left2).strict_mul(right2.into());
134                    let res = mul1.saturating_add(mul2);
135
136                    this.write_scalar(Scalar::from_i16(res), &dest)?;
137                }
138            }
139            // Used to implement the _mm_maskload_epi32, _mm_maskload_epi64,
140            // _mm256_maskload_epi32 and _mm256_maskload_epi64 functions.
141            // For the element `i`, if the high bit of the `i`-th element of `mask`
142            // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
143            // loaded.
144            "maskload.d" | "maskload.q" | "maskload.d.256" | "maskload.q.256" => {
145                let [ptr, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
146
147                mask_load(this, ptr, mask, dest)?;
148            }
149            // Used to implement the _mm_maskstore_epi32, _mm_maskstore_epi64,
150            // _mm256_maskstore_epi32 and _mm256_maskstore_epi64 functions.
151            // For the element `i`, if the high bit of the element `i`-th of `mask`
152            // is one, it is stored into `ptr.wapping_add(i)`.
153            // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
154            "maskstore.d" | "maskstore.q" | "maskstore.d.256" | "maskstore.q.256" => {
155                let [ptr, mask, value] =
156                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
157
158                mask_store(this, ptr, mask, value)?;
159            }
160            // Used to implement the _mm256_mpsadbw_epu8 function.
161            // Compute the sum of absolute differences of quadruplets of unsigned
162            // 8-bit integers in `left` and `right`, and store the 16-bit results
163            // in `right`. Quadruplets are selected from `left` and `right` with
164            // offsets specified in `imm`.
165            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mpsadbw_epu8
166            "mpsadbw" => {
167                let [left, right, imm] =
168                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
169
170                mpsadbw(this, left, right, imm, dest)?;
171            }
172            // Used to implement the _mm256_mulhrs_epi16 function.
173            // Multiplies packed 16-bit signed integer values, truncates the 32-bit
174            // product to the 18 most significant bits by right-shifting, and then
175            // divides the 18-bit value by 2 (rounding to nearest) by first adding
176            // 1 and then taking the bits `1..=16`.
177            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mulhrs_epi16
178            "pmul.hr.sw" => {
179                let [left, right] =
180                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
181
182                pmulhrsw(this, left, right, dest)?;
183            }
184            // Used to implement the _mm256_packs_epi16 function.
185            // Converts two 16-bit integer vectors to a single 8-bit integer
186            // vector with signed saturation.
187            "packsswb" => {
188                let [left, right] =
189                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
190
191                packsswb(this, left, right, dest)?;
192            }
193            // Used to implement the _mm256_packs_epi32 function.
194            // Converts two 32-bit integer vectors to a single 16-bit integer
195            // vector with signed saturation.
196            "packssdw" => {
197                let [left, right] =
198                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
199
200                packssdw(this, left, right, dest)?;
201            }
202            // Used to implement the _mm256_packus_epi16 function.
203            // Converts two 16-bit signed integer vectors to a single 8-bit
204            // unsigned integer vector with saturation.
205            "packuswb" => {
206                let [left, right] =
207                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
208
209                packuswb(this, left, right, dest)?;
210            }
211            // Used to implement the _mm256_packus_epi32 function.
212            // Concatenates two 32-bit signed integer vectors and converts
213            // the result to a 16-bit unsigned integer vector with saturation.
214            "packusdw" => {
215                let [left, right] =
216                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
217
218                packusdw(this, left, right, dest)?;
219            }
220            // Used to implement the _mm256_permutevar8x32_epi32 and
221            // _mm256_permutevar8x32_ps function.
222            // Shuffles `left` using the three low bits of each element of `right`
223            // as indices.
224            "permd" | "permps" => {
225                let [left, right] =
226                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
227
228                let (left, left_len) = this.project_to_simd(left)?;
229                let (right, right_len) = this.project_to_simd(right)?;
230                let (dest, dest_len) = this.project_to_simd(dest)?;
231
232                assert_eq!(dest_len, left_len);
233                assert_eq!(dest_len, right_len);
234
235                for i in 0..dest_len {
236                    let dest = this.project_index(&dest, i)?;
237                    let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?;
238                    let left = this.project_index(&left, (right & 0b111).into())?;
239
240                    this.copy_op(&left, &dest)?;
241                }
242            }
243            // Used to implement the _mm256_sad_epu8 function.
244            "psad.bw" => {
245                let [left, right] =
246                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
247
248                psadbw(this, left, right, dest)?
249            }
250            // Used to implement the _mm256_shuffle_epi8 intrinsic.
251            // Shuffles bytes from `left` using `right` as pattern.
252            // Each 128-bit block is shuffled independently.
253            "pshuf.b" => {
254                let [left, right] =
255                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
256
257                let (left, left_len) = this.project_to_simd(left)?;
258                let (right, right_len) = this.project_to_simd(right)?;
259                let (dest, dest_len) = this.project_to_simd(dest)?;
260
261                assert_eq!(dest_len, left_len);
262                assert_eq!(dest_len, right_len);
263
264                for i in 0..dest_len {
265                    let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
266                    let dest = this.project_index(&dest, i)?;
267
268                    let res = if right & 0x80 == 0 {
269                        // Shuffle each 128-bit (16-byte) block independently.
270                        let j = u64::from(right % 16).strict_add(i & !15);
271                        this.read_scalar(&this.project_index(&left, j)?)?
272                    } else {
273                        // If the highest bit in `right` is 1, write zero.
274                        Scalar::from_u8(0)
275                    };
276
277                    this.write_scalar(res, &dest)?;
278                }
279            }
280            // Used to implement the _mm256_sign_epi{8,16,32} functions.
281            // Negates elements from `left` when the corresponding element in
282            // `right` is negative. If an element from `right` is zero, zero
283            // is writen to the corresponding output element.
284            // Basically, we multiply `left` with `right.signum()`.
285            "psign.b" | "psign.w" | "psign.d" => {
286                let [left, right] =
287                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
288
289                psign(this, left, right, dest)?;
290            }
291            // Used to implement the _mm256_{sll,srl,sra}_epi{16,32,64} functions
292            // (except _mm256_sra_epi64, which is not available in AVX2).
293            // Shifts N-bit packed integers in left by the amount in right.
294            // `right` is as 128-bit vector. but it is interpreted as a single
295            // 64-bit integer (remaining bits are ignored).
296            // For logic shifts, when right is larger than N - 1, zero is produced.
297            // For arithmetic shifts, when right is larger than N - 1, the sign bit
298            // is copied to remaining bits.
299            "psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q"
300            | "psrl.q" => {
301                let [left, right] =
302                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
303
304                let which = match unprefixed_name {
305                    "psll.w" | "psll.d" | "psll.q" => ShiftOp::Left,
306                    "psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic,
307                    "psra.w" | "psra.d" => ShiftOp::RightArith,
308                    _ => unreachable!(),
309                };
310
311                shift_simd_by_scalar(this, left, right, which, dest)?;
312            }
313            // Used to implement the _mm{,256}_{sllv,srlv,srav}_epi{32,64} functions
314            // (except _mm{,256}_srav_epi64, which are not available in AVX2).
315            "psllv.d" | "psllv.d.256" | "psllv.q" | "psllv.q.256" | "psrlv.d" | "psrlv.d.256"
316            | "psrlv.q" | "psrlv.q.256" | "psrav.d" | "psrav.d.256" => {
317                let [left, right] =
318                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
319
320                let which = match unprefixed_name {
321                    "psllv.d" | "psllv.d.256" | "psllv.q" | "psllv.q.256" => ShiftOp::Left,
322                    "psrlv.d" | "psrlv.d.256" | "psrlv.q" | "psrlv.q.256" => ShiftOp::RightLogic,
323                    "psrav.d" | "psrav.d.256" => ShiftOp::RightArith,
324                    _ => unreachable!(),
325                };
326
327                shift_simd_by_simd(this, left, right, which, dest)?;
328            }
329            _ => return interp_ok(EmulateItemResult::NotSupported),
330        }
331        interp_ok(EmulateItemResult::NeedsReturn)
332    }
333}