miri/shims/x86/avx2.rs
1use rustc_abi::CanonAbi;
2use rustc_middle::mir;
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use super::{
8 ShiftOp, horizontal_bin_op, mpsadbw, packssdw, packsswb, packusdw, packuswb, permute, pmaddbw,
9 pmulhrsw, psadbw, psign, shift_simd_by_scalar,
10};
11use crate::*;
12
13impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
14pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
15 fn emulate_x86_avx2_intrinsic(
16 &mut self,
17 link_name: Symbol,
18 abi: &FnAbi<'tcx, Ty<'tcx>>,
19 args: &[OpTy<'tcx>],
20 dest: &MPlaceTy<'tcx>,
21 ) -> InterpResult<'tcx, EmulateItemResult> {
22 let this = self.eval_context_mut();
23 this.expect_target_feature_for_intrinsic(link_name, "avx2")?;
24 // Prefix should have already been checked.
25 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx2.").unwrap();
26
27 match unprefixed_name {
28 // Used to implement the _mm256_h{adds,subs}_epi16 functions.
29 // Horizontally add / subtract with saturation adjacent 16-bit
30 // integer values in `left` and `right`.
31 "phadd.sw" | "phsub.sw" => {
32 let [left, right] =
33 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
34
35 let which = match unprefixed_name {
36 "phadd.sw" => mir::BinOp::Add,
37 "phsub.sw" => mir::BinOp::Sub,
38 _ => unreachable!(),
39 };
40
41 horizontal_bin_op(this, which, /*saturating*/ true, left, right, dest)?;
42 }
43 // Used to implement `_mm{,_mask}_{i32,i64}gather_{epi32,epi64,pd,ps}` functions
44 // Gathers elements from `slice` using `offsets * scale` as indices.
45 // When the highest bit of the corresponding element of `mask` is 0,
46 // the value is copied from `src` instead.
47 "gather.d.d" | "gather.d.d.256" | "gather.d.q" | "gather.d.q.256" | "gather.q.d"
48 | "gather.q.d.256" | "gather.q.q" | "gather.q.q.256" | "gather.d.pd"
49 | "gather.d.pd.256" | "gather.q.pd" | "gather.q.pd.256" | "gather.d.ps"
50 | "gather.d.ps.256" | "gather.q.ps" | "gather.q.ps.256" => {
51 let [src, slice, offsets, mask, scale] =
52 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
53
54 assert_eq!(dest.layout, src.layout);
55
56 let (src, _) = this.project_to_simd(src)?;
57 let (offsets, offsets_len) = this.project_to_simd(offsets)?;
58 let (mask, mask_len) = this.project_to_simd(mask)?;
59 let (dest, dest_len) = this.project_to_simd(dest)?;
60
61 // There are cases like dest: i32x4, offsets: i64x2
62 // If dest has more elements than offset, extra dest elements are filled with zero.
63 // If offsets has more elements than dest, extra offsets are ignored.
64 let actual_len = dest_len.min(offsets_len);
65
66 assert_eq!(dest_len, mask_len);
67
68 let mask_item_size = mask.layout.field(this, 0).size;
69 let high_bit_offset = mask_item_size.bits().strict_sub(1);
70
71 let scale = this.read_scalar(scale)?.to_i8()?;
72 if !matches!(scale, 1 | 2 | 4 | 8) {
73 panic!("invalid gather scale {scale}");
74 }
75 let scale = i64::from(scale);
76
77 let slice = this.read_pointer(slice)?;
78 for i in 0..actual_len {
79 let mask = this.project_index(&mask, i)?;
80 let dest = this.project_index(&dest, i)?;
81
82 if this.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
83 let offset = this.project_index(&offsets, i)?;
84 let offset =
85 i64::try_from(this.read_scalar(&offset)?.to_int(offset.layout.size)?)
86 .unwrap();
87 let ptr = slice.wrapping_signed_offset(offset.strict_mul(scale), &this.tcx);
88 // Unaligned copy, which is what we want.
89 this.mem_copy(
90 ptr,
91 dest.ptr(),
92 dest.layout.size,
93 /*nonoverlapping*/ true,
94 )?;
95 } else {
96 this.copy_op(&this.project_index(&src, i)?, &dest)?;
97 }
98 }
99 for i in actual_len..dest_len {
100 let dest = this.project_index(&dest, i)?;
101 this.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
102 }
103 }
104 // Used to implement the _mm256_maddubs_epi16 function.
105 "pmadd.ub.sw" => {
106 let [left, right] =
107 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
108
109 pmaddbw(this, left, right, dest)?;
110 }
111 // Used to implement the _mm256_mpsadbw_epu8 function.
112 // Compute the sum of absolute differences of quadruplets of unsigned
113 // 8-bit integers in `left` and `right`, and store the 16-bit results
114 // in `right`. Quadruplets are selected from `left` and `right` with
115 // offsets specified in `imm`.
116 // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mpsadbw_epu8
117 "mpsadbw" => {
118 let [left, right, imm] =
119 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
120
121 mpsadbw(this, left, right, imm, dest)?;
122 }
123 // Used to implement the _mm256_mulhrs_epi16 function.
124 // Multiplies packed 16-bit signed integer values, truncates the 32-bit
125 // product to the 18 most significant bits by right-shifting, and then
126 // divides the 18-bit value by 2 (rounding to nearest) by first adding
127 // 1 and then taking the bits `1..=16`.
128 // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mulhrs_epi16
129 "pmul.hr.sw" => {
130 let [left, right] =
131 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
132
133 pmulhrsw(this, left, right, dest)?;
134 }
135 // Used to implement the _mm256_packs_epi16 function.
136 // Converts two 16-bit integer vectors to a single 8-bit integer
137 // vector with signed saturation.
138 "packsswb" => {
139 let [left, right] =
140 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
141
142 packsswb(this, left, right, dest)?;
143 }
144 // Used to implement the _mm256_packs_epi32 function.
145 // Converts two 32-bit integer vectors to a single 16-bit integer
146 // vector with signed saturation.
147 "packssdw" => {
148 let [left, right] =
149 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
150
151 packssdw(this, left, right, dest)?;
152 }
153 // Used to implement the _mm256_packus_epi16 function.
154 // Converts two 16-bit signed integer vectors to a single 8-bit
155 // unsigned integer vector with saturation.
156 "packuswb" => {
157 let [left, right] =
158 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
159
160 packuswb(this, left, right, dest)?;
161 }
162 // Used to implement the _mm256_packus_epi32 function.
163 // Concatenates two 32-bit signed integer vectors and converts
164 // the result to a 16-bit unsigned integer vector with saturation.
165 "packusdw" => {
166 let [left, right] =
167 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
168
169 packusdw(this, left, right, dest)?;
170 }
171 // Used to implement _mm256_permutevar8x32_epi32 and _mm256_permutevar8x32_ps.
172 "permd" | "permps" => {
173 let [left, right] =
174 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
175
176 permute(this, left, right, dest)?;
177 }
178 // Used to implement the _mm256_sad_epu8 function.
179 "psad.bw" => {
180 let [left, right] =
181 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
182
183 psadbw(this, left, right, dest)?
184 }
185 // Used to implement the _mm256_shuffle_epi8 intrinsic.
186 // Shuffles bytes from `left` using `right` as pattern.
187 // Each 128-bit block is shuffled independently.
188 "pshuf.b" => {
189 let [left, right] =
190 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
191
192 let (left, left_len) = this.project_to_simd(left)?;
193 let (right, right_len) = this.project_to_simd(right)?;
194 let (dest, dest_len) = this.project_to_simd(dest)?;
195
196 assert_eq!(dest_len, left_len);
197 assert_eq!(dest_len, right_len);
198
199 for i in 0..dest_len {
200 let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
201 let dest = this.project_index(&dest, i)?;
202
203 let res = if right & 0x80 == 0 {
204 // Shuffle each 128-bit (16-byte) block independently.
205 let j = u64::from(right % 16).strict_add(i & !15);
206 this.read_scalar(&this.project_index(&left, j)?)?
207 } else {
208 // If the highest bit in `right` is 1, write zero.
209 Scalar::from_u8(0)
210 };
211
212 this.write_scalar(res, &dest)?;
213 }
214 }
215 // Used to implement the _mm256_sign_epi{8,16,32} functions.
216 // Negates elements from `left` when the corresponding element in
217 // `right` is negative. If an element from `right` is zero, zero
218 // is writen to the corresponding output element.
219 // Basically, we multiply `left` with `right.signum()`.
220 "psign.b" | "psign.w" | "psign.d" => {
221 let [left, right] =
222 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
223
224 psign(this, left, right, dest)?;
225 }
226 // Used to implement the _mm256_{sll,srl,sra}_epi{16,32,64} functions
227 // (except _mm256_sra_epi64, which is not available in AVX2).
228 // Shifts N-bit packed integers in left by the amount in right.
229 // `right` is as 128-bit vector. but it is interpreted as a single
230 // 64-bit integer (remaining bits are ignored).
231 // For logic shifts, when right is larger than N - 1, zero is produced.
232 // For arithmetic shifts, when right is larger than N - 1, the sign bit
233 // is copied to remaining bits.
234 "psll.w" | "psrl.w" | "psra.w" | "psll.d" | "psrl.d" | "psra.d" | "psll.q"
235 | "psrl.q" => {
236 let [left, right] =
237 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
238
239 let which = match unprefixed_name {
240 "psll.w" | "psll.d" | "psll.q" => ShiftOp::Left,
241 "psrl.w" | "psrl.d" | "psrl.q" => ShiftOp::RightLogic,
242 "psra.w" | "psra.d" => ShiftOp::RightArith,
243 _ => unreachable!(),
244 };
245
246 shift_simd_by_scalar(this, left, right, which, dest)?;
247 }
248 _ => return interp_ok(EmulateItemResult::NotSupported),
249 }
250 interp_ok(EmulateItemResult::NeedsReturn)
251 }
252}