1use rustc_middle::mir;
2use rustc_middle::ty::Ty;
3use rustc_middle::ty::layout::LayoutOf as _;
4use rustc_span::Symbol;
5use rustc_target::callconv::{Conv, FnAbi};
6
7use super::{
8 ShiftOp, horizontal_bin_op, int_abs, mask_load, mask_store, mpsadbw, packssdw, packsswb,
9 packusdw, packuswb, pmulhrsw, psign, shift_simd_by_scalar, shift_simd_by_simd,
10};
11use crate::*;
12
13impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
14pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
15 fn emulate_x86_avx2_intrinsic(
16 &mut self,
17 link_name: Symbol,
18 abi: &FnAbi<'tcx, Ty<'tcx>>,
19 args: &[OpTy<'tcx>],
20 dest: &MPlaceTy<'tcx>,
21 ) -> InterpResult<'tcx, EmulateItemResult> {
22 let this = self.eval_context_mut();
23 this.expect_target_feature_for_intrinsic(link_name, "avx2")?;
24 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx2.").unwrap();
26
27 match unprefixed_name {
28 "pabs.b" | "pabs.w" | "pabs.d" => {
31 let [op] = this.check_shim(abi, Conv::C, link_name, args)?;
32
33 int_abs(this, op, dest)?;
34 }
35 "phadd.w" | "phadd.sw" | "phadd.d" | "phsub.w" | "phsub.sw" | "phsub.d" => {
39 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
40
41 let (which, saturating) = match unprefixed_name {
42 "phadd.w" | "phadd.d" => (mir::BinOp::Add, false),
43 "phadd.sw" => (mir::BinOp::Add, true),
44 "phsub.w" | "phsub.d" => (mir::BinOp::Sub, false),
45 "phsub.sw" => (mir::BinOp::Sub, true),
46 _ => unreachable!(),
47 };
48
49 horizontal_bin_op(this, which, saturating, left, right, dest)?;
50 }
51 "gather.d.d" | "gather.d.d.256" | "gather.d.q" | "gather.d.q.256" | "gather.q.d"
56 | "gather.q.d.256" | "gather.q.q" | "gather.q.q.256" | "gather.d.pd"
57 | "gather.d.pd.256" | "gather.q.pd" | "gather.q.pd.256" | "gather.d.ps"
58 | "gather.d.ps.256" | "gather.q.ps" | "gather.q.ps.256" => {
59 let [src, slice, offsets, mask, scale] =
60 this.check_shim(abi, Conv::C, link_name, args)?;
61
62 assert_eq!(dest.layout, src.layout);
63
64 let (src, _) = this.project_to_simd(src)?;
65 let (offsets, offsets_len) = this.project_to_simd(offsets)?;
66 let (mask, mask_len) = this.project_to_simd(mask)?;
67 let (dest, dest_len) = this.project_to_simd(dest)?;
68
69 let actual_len = dest_len.min(offsets_len);
73
74 assert_eq!(dest_len, mask_len);
75
76 let mask_item_size = mask.layout.field(this, 0).size;
77 let high_bit_offset = mask_item_size.bits().strict_sub(1);
78
79 let scale = this.read_scalar(scale)?.to_i8()?;
80 if !matches!(scale, 1 | 2 | 4 | 8) {
81 panic!("invalid gather scale {scale}");
82 }
83 let scale = i64::from(scale);
84
85 let slice = this.read_pointer(slice)?;
86 for i in 0..actual_len {
87 let mask = this.project_index(&mask, i)?;
88 let dest = this.project_index(&dest, i)?;
89
90 if this.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
91 let offset = this.project_index(&offsets, i)?;
92 let offset =
93 i64::try_from(this.read_scalar(&offset)?.to_int(offset.layout.size)?)
94 .unwrap();
95 let ptr = slice.wrapping_signed_offset(offset.strict_mul(scale), &this.tcx);
96 this.mem_copy(
98 ptr,
99 dest.ptr(),
100 dest.layout.size,
101 true,
102 )?;
103 } else {
104 this.copy_op(&this.project_index(&src, i)?, &dest)?;
105 }
106 }
107 for i in actual_len..dest_len {
108 let dest = this.project_index(&dest, i)?;
109 this.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
110 }
111 }
112 "pmadd.wd" => {
117 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
118
119 let (left, left_len) = this.project_to_simd(left)?;
120 let (right, right_len) = this.project_to_simd(right)?;
121 let (dest, dest_len) = this.project_to_simd(dest)?;
122
123 assert_eq!(left_len, right_len);
124 assert_eq!(dest_len.strict_mul(2), left_len);
125
126 for i in 0..dest_len {
127 let j1 = i.strict_mul(2);
128 let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_i16()?;
129 let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i16()?;
130
131 let j2 = j1.strict_add(1);
132 let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_i16()?;
133 let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i16()?;
134
135 let dest = this.project_index(&dest, i)?;
136
137 let mul1 = i32::from(left1).strict_mul(right1.into());
139 let mul2 = i32::from(left2).strict_mul(right2.into());
140 let res = mul1.wrapping_add(mul2);
143
144 this.write_scalar(Scalar::from_i32(res), &dest)?;
145 }
146 }
147 "pmadd.ub.sw" => {
153 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
154
155 let (left, left_len) = this.project_to_simd(left)?;
156 let (right, right_len) = this.project_to_simd(right)?;
157 let (dest, dest_len) = this.project_to_simd(dest)?;
158
159 assert_eq!(left_len, right_len);
160 assert_eq!(dest_len.strict_mul(2), left_len);
161
162 for i in 0..dest_len {
163 let j1 = i.strict_mul(2);
164 let left1 = this.read_scalar(&this.project_index(&left, j1)?)?.to_u8()?;
165 let right1 = this.read_scalar(&this.project_index(&right, j1)?)?.to_i8()?;
166
167 let j2 = j1.strict_add(1);
168 let left2 = this.read_scalar(&this.project_index(&left, j2)?)?.to_u8()?;
169 let right2 = this.read_scalar(&this.project_index(&right, j2)?)?.to_i8()?;
170
171 let dest = this.project_index(&dest, i)?;
172
173 let mul1 = i16::from(left1).strict_mul(right1.into());
175 let mul2 = i16::from(left2).strict_mul(right2.into());
176 let res = mul1.saturating_add(mul2);
177
178 this.write_scalar(Scalar::from_i16(res), &dest)?;
179 }
180 }
181 "maskload.d" | "maskload.q" | "maskload.d.256" | "maskload.q.256" => {
187 let [ptr, mask] = this.check_shim(abi, Conv::C, link_name, args)?;
188
189 mask_load(this, ptr, mask, dest)?;
190 }
191 "maskstore.d" | "maskstore.q" | "maskstore.d.256" | "maskstore.q.256" => {
197 let [ptr, mask, value] = this.check_shim(abi, Conv::C, link_name, args)?;
198
199 mask_store(this, ptr, mask, value)?;
200 }
201 "mpsadbw" => {
208 let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
209
210 mpsadbw(this, left, right, imm, dest)?;
211 }
212 "pmul.hr.sw" => {
219 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
220
221 pmulhrsw(this, left, right, dest)?;
222 }
223 "packsswb" => {
227 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
228
229 packsswb(this, left, right, dest)?;
230 }
231 "packssdw" => {
235 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
236
237 packssdw(this, left, right, dest)?;
238 }
239 "packuswb" => {
243 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
244
245 packuswb(this, left, right, dest)?;
246 }
247 "packusdw" => {
251 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
252
253 packusdw(this, left, right, dest)?;
254 }
255 "permd" | "permps" => {
260 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
261
262 let (left, left_len) = this.project_to_simd(left)?;
263 let (right, right_len) = this.project_to_simd(right)?;
264 let (dest, dest_len) = this.project_to_simd(dest)?;
265
266 assert_eq!(dest_len, left_len);
267 assert_eq!(dest_len, right_len);
268
269 for i in 0..dest_len {
270 let dest = this.project_index(&dest, i)?;
271 let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u32()?;
272 let left = this.project_index(&left, (right & 0b111).into())?;
273
274 this.copy_op(&left, &dest)?;
275 }
276 }
277 "vperm2i128" => {
280 let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
281
282 assert_eq!(left.layout.size.bits(), 256);
283 assert_eq!(right.layout.size.bits(), 256);
284 assert_eq!(dest.layout.size.bits(), 256);
285
286 let array_layout =
289 this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.i128, 2))?;
290 let left = left.transmute(array_layout, this)?;
291 let right = right.transmute(array_layout, this)?;
292 let dest = dest.transmute(array_layout, this)?;
293
294 let imm = this.read_scalar(imm)?.to_u8()?;
295
296 for i in 0..2 {
297 let dest = this.project_index(&dest, i)?;
298 let src = match (imm >> i.strict_mul(4)) & 0b11 {
299 0 => this.project_index(&left, 0)?,
300 1 => this.project_index(&left, 1)?,
301 2 => this.project_index(&right, 0)?,
302 3 => this.project_index(&right, 1)?,
303 _ => unreachable!(),
304 };
305
306 this.copy_op(&src, &dest)?;
307 }
308 }
309 "psad.bw" => {
317 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
318
319 let (left, left_len) = this.project_to_simd(left)?;
320 let (right, right_len) = this.project_to_simd(right)?;
321 let (dest, dest_len) = this.project_to_simd(dest)?;
322
323 assert_eq!(left_len, right_len);
324 assert_eq!(left_len, dest_len.strict_mul(8));
325
326 for i in 0..dest_len {
327 let dest = this.project_index(&dest, i)?;
328
329 let mut acc: u16 = 0;
330 for j in 0..8 {
331 let src_index = i.strict_mul(8).strict_add(j);
332
333 let left = this.project_index(&left, src_index)?;
334 let left = this.read_scalar(&left)?.to_u8()?;
335
336 let right = this.project_index(&right, src_index)?;
337 let right = this.read_scalar(&right)?.to_u8()?;
338
339 acc = acc.strict_add(left.abs_diff(right).into());
340 }
341
342 this.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
343 }
344 }
345 "pshuf.b" => {
349 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
350
351 let (left, left_len) = this.project_to_simd(left)?;
352 let (right, right_len) = this.project_to_simd(right)?;
353 let (dest, dest_len) = this.project_to_simd(dest)?;
354
355 assert_eq!(dest_len, left_len);
356 assert_eq!(dest_len, right_len);
357
358 for i in 0..dest_len {
359 let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
360 let dest = this.project_index(&dest, i)?;
361
362 let res = if right & 0x80 == 0 {
363 let j = u64::from(right % 16).strict_add(i & !15);
365 this.read_scalar(&this.project_index(&left, j)?)?
366 } else {
367 Scalar::from_u8(0)
369 };
370
371 this.write_scalar(res, &dest)?;
372 }
373 }
374 "psign.b" | "psign.w" | "psign.d" => {
380 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
381
382 psign(this, left, right, dest)?;
383 }
384 "psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q"
393 | "psrl.q" => {
394 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
395
396 let which = match unprefixed_name {
397 "psll.w" | "psll.d" | "psll.q" => ShiftOp::Left,
398 "psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic,
399 "psra.w" | "psra.d" => ShiftOp::RightArith,
400 _ => unreachable!(),
401 };
402
403 shift_simd_by_scalar(this, left, right, which, dest)?;
404 }
405 "psllv.d" | "psllv.d.256" | "psllv.q" | "psllv.q.256" | "psrlv.d" | "psrlv.d.256"
408 | "psrlv.q" | "psrlv.q.256" | "psrav.d" | "psrav.d.256" => {
409 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
410
411 let which = match unprefixed_name {
412 "psllv.d" | "psllv.d.256" | "psllv.q" | "psllv.q.256" => ShiftOp::Left,
413 "psrlv.d" | "psrlv.d.256" | "psrlv.q" | "psrlv.q.256" => ShiftOp::RightLogic,
414 "psrav.d" | "psrav.d.256" => ShiftOp::RightArith,
415 _ => unreachable!(),
416 };
417
418 shift_simd_by_simd(this, left, right, which, dest)?;
419 }
420 _ => return interp_ok(EmulateItemResult::NotSupported),
421 }
422 interp_ok(EmulateItemResult::NeedsReturn)
423 }
424}