miri/shims/x86/avx.rs
1use rustc_apfloat::ieee::{Double, Single};
2use rustc_middle::mir;
3use rustc_middle::ty::Ty;
4use rustc_middle::ty::layout::LayoutOf as _;
5use rustc_span::Symbol;
6use rustc_target::callconv::{Conv, FnAbi};
7
8use super::{
9 FloatBinOp, FloatUnaryOp, bin_op_simd_float_all, conditional_dot_product, convert_float_to_int,
10 horizontal_bin_op, mask_load, mask_store, round_all, test_bits_masked, test_high_bits_masked,
11 unary_op_ps,
12};
13use crate::*;
14
15impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
16pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
17 fn emulate_x86_avx_intrinsic(
18 &mut self,
19 link_name: Symbol,
20 abi: &FnAbi<'tcx, Ty<'tcx>>,
21 args: &[OpTy<'tcx>],
22 dest: &MPlaceTy<'tcx>,
23 ) -> InterpResult<'tcx, EmulateItemResult> {
24 let this = self.eval_context_mut();
25 this.expect_target_feature_for_intrinsic(link_name, "avx")?;
26 // Prefix should have already been checked.
27 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap();
28
29 match unprefixed_name {
30 // Used to implement _mm256_min_ps and _mm256_max_ps functions.
31 // Note that the semantics are a bit different from Rust simd_min
32 // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
33 // matches the IEEE min/max operations, while x86 has different
34 // semantics.
35 "min.ps.256" | "max.ps.256" => {
36 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
37
38 let which = match unprefixed_name {
39 "min.ps.256" => FloatBinOp::Min,
40 "max.ps.256" => FloatBinOp::Max,
41 _ => unreachable!(),
42 };
43
44 bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
45 }
46 // Used to implement _mm256_min_pd and _mm256_max_pd functions.
47 "min.pd.256" | "max.pd.256" => {
48 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
49
50 let which = match unprefixed_name {
51 "min.pd.256" => FloatBinOp::Min,
52 "max.pd.256" => FloatBinOp::Max,
53 _ => unreachable!(),
54 };
55
56 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
57 }
58 // Used to implement the _mm256_round_ps function.
59 // Rounds the elements of `op` according to `rounding`.
60 "round.ps.256" => {
61 let [op, rounding] = this.check_shim(abi, Conv::C, link_name, args)?;
62
63 round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
64 }
65 // Used to implement the _mm256_round_pd function.
66 // Rounds the elements of `op` according to `rounding`.
67 "round.pd.256" => {
68 let [op, rounding] = this.check_shim(abi, Conv::C, link_name, args)?;
69
70 round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
71 }
72 // Used to implement _mm256_{rcp,rsqrt}_ps functions.
73 // Performs the operations on all components of `op`.
74 "rcp.ps.256" | "rsqrt.ps.256" => {
75 let [op] = this.check_shim(abi, Conv::C, link_name, args)?;
76
77 let which = match unprefixed_name {
78 "rcp.ps.256" => FloatUnaryOp::Rcp,
79 "rsqrt.ps.256" => FloatUnaryOp::Rsqrt,
80 _ => unreachable!(),
81 };
82
83 unary_op_ps(this, which, op, dest)?;
84 }
85 // Used to implement the _mm256_dp_ps function.
86 "dp.ps.256" => {
87 let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
88
89 conditional_dot_product(this, left, right, imm, dest)?;
90 }
91 // Used to implement the _mm256_h{add,sub}_p{s,d} functions.
92 // Horizontally add/subtract adjacent floating point values
93 // in `left` and `right`.
94 "hadd.ps.256" | "hadd.pd.256" | "hsub.ps.256" | "hsub.pd.256" => {
95 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
96
97 let which = match unprefixed_name {
98 "hadd.ps.256" | "hadd.pd.256" => mir::BinOp::Add,
99 "hsub.ps.256" | "hsub.pd.256" => mir::BinOp::Sub,
100 _ => unreachable!(),
101 };
102
103 horizontal_bin_op(this, which, /*saturating*/ false, left, right, dest)?;
104 }
105 // Used to implement the _mm256_cmp_ps function.
106 // Performs a comparison operation on each component of `left`
107 // and `right`. For each component, returns 0 if false or u32::MAX
108 // if true.
109 "cmp.ps.256" => {
110 let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
111
112 let which =
113 FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
114
115 bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
116 }
117 // Used to implement the _mm256_cmp_pd function.
118 // Performs a comparison operation on each component of `left`
119 // and `right`. For each component, returns 0 if false or u64::MAX
120 // if true.
121 "cmp.pd.256" => {
122 let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
123
124 let which =
125 FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
126
127 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
128 }
129 // Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32
130 // and _mm256_cvttpd_epi32 functions.
131 // Converts packed f32/f64 to packed i32.
132 "cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => {
133 let [op] = this.check_shim(abi, Conv::C, link_name, args)?;
134
135 let rnd = match unprefixed_name {
136 // "current SSE rounding mode", assume nearest
137 "cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven,
138 // always truncate
139 "cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero,
140 _ => unreachable!(),
141 };
142
143 convert_float_to_int(this, op, rnd, dest)?;
144 }
145 // Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions.
146 // Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit
147 // chunk is shuffled independently: this means that we view the vector as a
148 // sequence of 4-element arrays, and we shuffle each of these arrays, where
149 // `control` determines which element of the current `data` array is written.
150 "vpermilvar.ps" | "vpermilvar.ps.256" => {
151 let [data, control] = this.check_shim(abi, Conv::C, link_name, args)?;
152
153 let (data, data_len) = this.project_to_simd(data)?;
154 let (control, control_len) = this.project_to_simd(control)?;
155 let (dest, dest_len) = this.project_to_simd(dest)?;
156
157 assert_eq!(dest_len, data_len);
158 assert_eq!(dest_len, control_len);
159
160 for i in 0..dest_len {
161 let control = this.project_index(&control, i)?;
162
163 // Each 128-bit chunk is shuffled independently. Since each chunk contains
164 // four 32-bit elements, only two bits from `control` are used. To read the
165 // value from the current chunk, add the destination index truncated to a multiple
166 // of 4.
167 let chunk_base = i & !0b11;
168 let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11)
169 .strict_add(chunk_base);
170
171 this.copy_op(
172 &this.project_index(&data, src_i)?,
173 &this.project_index(&dest, i)?,
174 )?;
175 }
176 }
177 // Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions.
178 // Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit
179 // chunk is shuffled independently: this means that we view the vector as
180 // a sequence of 2-element arrays, and we shuffle each of these arrays,
181 // where `right` determines which element of the current `left` array is
182 // written.
183 "vpermilvar.pd" | "vpermilvar.pd.256" => {
184 let [data, control] = this.check_shim(abi, Conv::C, link_name, args)?;
185
186 let (data, data_len) = this.project_to_simd(data)?;
187 let (control, control_len) = this.project_to_simd(control)?;
188 let (dest, dest_len) = this.project_to_simd(dest)?;
189
190 assert_eq!(dest_len, data_len);
191 assert_eq!(dest_len, control_len);
192
193 for i in 0..dest_len {
194 let control = this.project_index(&control, i)?;
195
196 // Each 128-bit chunk is shuffled independently. Since each chunk contains
197 // two 64-bit elements, only the second bit from `control` is used (yes, the
198 // second instead of the first, ask Intel). To read the value from the current
199 // chunk, add the destination index truncated to a multiple of 2.
200 let chunk_base = i & !1;
201 let src_i =
202 ((this.read_scalar(&control)?.to_u64()? >> 1) & 1).strict_add(chunk_base);
203
204 this.copy_op(
205 &this.project_index(&data, src_i)?,
206 &this.project_index(&dest, i)?,
207 )?;
208 }
209 }
210 // Used to implement the _mm256_permute2f128_ps, _mm256_permute2f128_pd and
211 // _mm256_permute2f128_si256 functions. Regardless of the suffix in the name
212 // thay all can be considered to operate on vectors of 128-bit elements.
213 // For each 128-bit element of `dest`, copies one from `left`, `right` or
214 // zero, according to `imm`.
215 "vperm2f128.ps.256" | "vperm2f128.pd.256" | "vperm2f128.si.256" => {
216 let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
217
218 assert_eq!(dest.layout, left.layout);
219 assert_eq!(dest.layout, right.layout);
220 assert_eq!(dest.layout.size.bits(), 256);
221
222 // Transmute to `[u128; 2]` to process each 128-bit chunk independently.
223 let u128x2_layout =
224 this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u128, 2))?;
225 let left = left.transmute(u128x2_layout, this)?;
226 let right = right.transmute(u128x2_layout, this)?;
227 let dest = dest.transmute(u128x2_layout, this)?;
228
229 let imm = this.read_scalar(imm)?.to_u8()?;
230
231 for i in 0..2 {
232 let dest = this.project_index(&dest, i)?;
233
234 let imm = match i {
235 0 => imm & 0xF,
236 1 => imm >> 4,
237 _ => unreachable!(),
238 };
239 if imm & 0b100 != 0 {
240 this.write_scalar(Scalar::from_u128(0), &dest)?;
241 } else {
242 let src = match imm {
243 0b00 => this.project_index(&left, 0)?,
244 0b01 => this.project_index(&left, 1)?,
245 0b10 => this.project_index(&right, 0)?,
246 0b11 => this.project_index(&right, 1)?,
247 _ => unreachable!(),
248 };
249 this.copy_op(&src, &dest)?;
250 }
251 }
252 }
253 // Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps
254 // and _mm256_maskload_pd functions.
255 // For the element `i`, if the high bit of the `i`-th element of `mask`
256 // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
257 // loaded.
258 "maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => {
259 let [ptr, mask] = this.check_shim(abi, Conv::C, link_name, args)?;
260
261 mask_load(this, ptr, mask, dest)?;
262 }
263 // Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps
264 // and _mm256_maskstore_pd functions.
265 // For the element `i`, if the high bit of the element `i`-th of `mask`
266 // is one, it is stored into `ptr.wapping_add(i)`.
267 // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
268 "maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => {
269 let [ptr, mask, value] = this.check_shim(abi, Conv::C, link_name, args)?;
270
271 mask_store(this, ptr, mask, value)?;
272 }
273 // Used to implement the _mm256_lddqu_si256 function.
274 // Reads a 256-bit vector from an unaligned pointer. This intrinsic
275 // is expected to perform better than a regular unaligned read when
276 // the data crosses a cache line, but for Miri this is just a regular
277 // unaligned read.
278 "ldu.dq.256" => {
279 let [src_ptr] = this.check_shim(abi, Conv::C, link_name, args)?;
280 let src_ptr = this.read_pointer(src_ptr)?;
281 let dest = dest.force_mplace(this)?;
282
283 // Unaligned copy, which is what we want.
284 this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
285 }
286 // Used to implement the _mm256_testz_si256, _mm256_testc_si256 and
287 // _mm256_testnzc_si256 functions.
288 // Tests `op & mask == 0`, `op & mask == mask` or
289 // `op & mask != 0 && op & mask != mask`
290 "ptestz.256" | "ptestc.256" | "ptestnzc.256" => {
291 let [op, mask] = this.check_shim(abi, Conv::C, link_name, args)?;
292
293 let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
294 let res = match unprefixed_name {
295 "ptestz.256" => all_zero,
296 "ptestc.256" => masked_set,
297 "ptestnzc.256" => !all_zero && !masked_set,
298 _ => unreachable!(),
299 };
300
301 this.write_scalar(Scalar::from_i32(res.into()), dest)?;
302 }
303 // Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd
304 // _mm_testz_pd, _mm_testc_pd, _mm_testnzc_pd, _mm256_testz_ps,
305 // _mm256_testc_ps, _mm256_testnzc_ps, _mm_testz_ps, _mm_testc_ps and
306 // _mm_testnzc_ps functions.
307 // Calculates two booleans:
308 // `direct`, which is true when the highest bit of each element of `op & mask` is zero.
309 // `negated`, which is true when the highest bit of each element of `!op & mask` is zero.
310 // Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc)
311 "vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestz.pd" | "vtestc.pd"
312 | "vtestnzc.pd" | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256"
313 | "vtestz.ps" | "vtestc.ps" | "vtestnzc.ps" => {
314 let [op, mask] = this.check_shim(abi, Conv::C, link_name, args)?;
315
316 let (direct, negated) = test_high_bits_masked(this, op, mask)?;
317 let res = match unprefixed_name {
318 "vtestz.pd.256" | "vtestz.pd" | "vtestz.ps.256" | "vtestz.ps" => direct,
319 "vtestc.pd.256" | "vtestc.pd" | "vtestc.ps.256" | "vtestc.ps" => negated,
320 "vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" =>
321 !direct && !negated,
322 _ => unreachable!(),
323 };
324
325 this.write_scalar(Scalar::from_i32(res.into()), dest)?;
326 }
327 // Used to implement the `_mm256_zeroupper` and `_mm256_zeroall` functions.
328 // These function clear out the upper 128 bits of all avx registers or
329 // zero out all avx registers respectively.
330 "vzeroupper" | "vzeroall" => {
331 // These functions are purely a performance hint for the CPU.
332 // Any registers currently in use will be saved beforehand by the
333 // compiler, making these functions no-ops.
334
335 // The only thing that needs to be ensured is the correct calling convention.
336 let [] = this.check_shim(abi, Conv::C, link_name, args)?;
337 }
338 _ => return interp_ok(EmulateItemResult::NotSupported),
339 }
340 interp_ok(EmulateItemResult::NeedsReturn)
341 }
342}