miri/shims/x86/avx.rs
1use rustc_abi::CanonAbi;
2use rustc_apfloat::ieee::{Double, Single};
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use super::{
8 FloatBinOp, FloatUnaryOp, bin_op_simd_float_all, conditional_dot_product, convert_float_to_int,
9 mask_load, mask_store, round_all, test_bits_masked, test_high_bits_masked, unary_op_ps,
10};
11use crate::*;
12
13impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
14pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
15 fn emulate_x86_avx_intrinsic(
16 &mut self,
17 link_name: Symbol,
18 abi: &FnAbi<'tcx, Ty<'tcx>>,
19 args: &[OpTy<'tcx>],
20 dest: &MPlaceTy<'tcx>,
21 ) -> InterpResult<'tcx, EmulateItemResult> {
22 let this = self.eval_context_mut();
23 this.expect_target_feature_for_intrinsic(link_name, "avx")?;
24 // Prefix should have already been checked.
25 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap();
26
27 match unprefixed_name {
28 // Used to implement _mm256_min_ps and _mm256_max_ps functions.
29 // Note that the semantics are a bit different from Rust simd_min
30 // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
31 // matches the IEEE min/max operations, while x86 has different
32 // semantics.
33 "min.ps.256" | "max.ps.256" => {
34 let [left, right] =
35 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
36
37 let which = match unprefixed_name {
38 "min.ps.256" => FloatBinOp::Min,
39 "max.ps.256" => FloatBinOp::Max,
40 _ => unreachable!(),
41 };
42
43 bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
44 }
45 // Used to implement _mm256_min_pd and _mm256_max_pd functions.
46 "min.pd.256" | "max.pd.256" => {
47 let [left, right] =
48 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
49
50 let which = match unprefixed_name {
51 "min.pd.256" => FloatBinOp::Min,
52 "max.pd.256" => FloatBinOp::Max,
53 _ => unreachable!(),
54 };
55
56 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
57 }
58 // Used to implement the _mm256_round_ps function.
59 // Rounds the elements of `op` according to `rounding`.
60 "round.ps.256" => {
61 let [op, rounding] =
62 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
63
64 round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
65 }
66 // Used to implement the _mm256_round_pd function.
67 // Rounds the elements of `op` according to `rounding`.
68 "round.pd.256" => {
69 let [op, rounding] =
70 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
71
72 round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
73 }
74 // Used to implement _mm256_{rcp,rsqrt}_ps functions.
75 // Performs the operations on all components of `op`.
76 "rcp.ps.256" | "rsqrt.ps.256" => {
77 let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
78
79 let which = match unprefixed_name {
80 "rcp.ps.256" => FloatUnaryOp::Rcp,
81 "rsqrt.ps.256" => FloatUnaryOp::Rsqrt,
82 _ => unreachable!(),
83 };
84
85 unary_op_ps(this, which, op, dest)?;
86 }
87 // Used to implement the _mm256_dp_ps function.
88 "dp.ps.256" => {
89 let [left, right, imm] =
90 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
91
92 conditional_dot_product(this, left, right, imm, dest)?;
93 }
94 // Used to implement the _mm256_cmp_ps function.
95 // Performs a comparison operation on each component of `left`
96 // and `right`. For each component, returns 0 if false or u32::MAX
97 // if true.
98 "cmp.ps.256" => {
99 let [left, right, imm] =
100 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
101
102 let which =
103 FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
104
105 bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
106 }
107 // Used to implement the _mm256_cmp_pd function.
108 // Performs a comparison operation on each component of `left`
109 // and `right`. For each component, returns 0 if false or u64::MAX
110 // if true.
111 "cmp.pd.256" => {
112 let [left, right, imm] =
113 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
114
115 let which =
116 FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
117
118 bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
119 }
120 // Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32
121 // and _mm256_cvttpd_epi32 functions.
122 // Converts packed f32/f64 to packed i32.
123 "cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => {
124 let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
125
126 let rnd = match unprefixed_name {
127 // "current SSE rounding mode", assume nearest
128 "cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven,
129 // always truncate
130 "cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero,
131 _ => unreachable!(),
132 };
133
134 convert_float_to_int(this, op, rnd, dest)?;
135 }
136 // Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions.
137 // Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit
138 // chunk is shuffled independently: this means that we view the vector as a
139 // sequence of 4-element arrays, and we shuffle each of these arrays, where
140 // `control` determines which element of the current `data` array is written.
141 "vpermilvar.ps" | "vpermilvar.ps.256" => {
142 let [data, control] =
143 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
144
145 let (data, data_len) = this.project_to_simd(data)?;
146 let (control, control_len) = this.project_to_simd(control)?;
147 let (dest, dest_len) = this.project_to_simd(dest)?;
148
149 assert_eq!(dest_len, data_len);
150 assert_eq!(dest_len, control_len);
151
152 for i in 0..dest_len {
153 let control = this.project_index(&control, i)?;
154
155 // Each 128-bit chunk is shuffled independently. Since each chunk contains
156 // four 32-bit elements, only two bits from `control` are used. To read the
157 // value from the current chunk, add the destination index truncated to a multiple
158 // of 4.
159 let chunk_base = i & !0b11;
160 let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11)
161 .strict_add(chunk_base);
162
163 this.copy_op(
164 &this.project_index(&data, src_i)?,
165 &this.project_index(&dest, i)?,
166 )?;
167 }
168 }
169 // Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions.
170 // Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit
171 // chunk is shuffled independently: this means that we view the vector as
172 // a sequence of 2-element arrays, and we shuffle each of these arrays,
173 // where `right` determines which element of the current `left` array is
174 // written.
175 "vpermilvar.pd" | "vpermilvar.pd.256" => {
176 let [data, control] =
177 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
178
179 let (data, data_len) = this.project_to_simd(data)?;
180 let (control, control_len) = this.project_to_simd(control)?;
181 let (dest, dest_len) = this.project_to_simd(dest)?;
182
183 assert_eq!(dest_len, data_len);
184 assert_eq!(dest_len, control_len);
185
186 for i in 0..dest_len {
187 let control = this.project_index(&control, i)?;
188
189 // Each 128-bit chunk is shuffled independently. Since each chunk contains
190 // two 64-bit elements, only the second bit from `control` is used (yes, the
191 // second instead of the first, ask Intel). To read the value from the current
192 // chunk, add the destination index truncated to a multiple of 2.
193 let chunk_base = i & !1;
194 let src_i =
195 ((this.read_scalar(&control)?.to_u64()? >> 1) & 1).strict_add(chunk_base);
196
197 this.copy_op(
198 &this.project_index(&data, src_i)?,
199 &this.project_index(&dest, i)?,
200 )?;
201 }
202 }
203 // Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps
204 // and _mm256_maskload_pd functions.
205 // For the element `i`, if the high bit of the `i`-th element of `mask`
206 // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
207 // loaded.
208 "maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => {
209 let [ptr, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
210
211 mask_load(this, ptr, mask, dest)?;
212 }
213 // Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps
214 // and _mm256_maskstore_pd functions.
215 // For the element `i`, if the high bit of the element `i`-th of `mask`
216 // is one, it is stored into `ptr.wapping_add(i)`.
217 // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
218 "maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => {
219 let [ptr, mask, value] =
220 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
221
222 mask_store(this, ptr, mask, value)?;
223 }
224 // Used to implement the _mm256_lddqu_si256 function.
225 // Reads a 256-bit vector from an unaligned pointer. This intrinsic
226 // is expected to perform better than a regular unaligned read when
227 // the data crosses a cache line, but for Miri this is just a regular
228 // unaligned read.
229 "ldu.dq.256" => {
230 let [src_ptr] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
231 let src_ptr = this.read_pointer(src_ptr)?;
232 let dest = dest.force_mplace(this)?;
233
234 // Unaligned copy, which is what we want.
235 this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
236 }
237 // Used to implement the _mm256_testnzc_si256 function.
238 // Tests `op & mask != 0 && op & mask != mask`
239 "ptestnzc.256" => {
240 let [op, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
241
242 let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
243 let res = !all_zero && !masked_set;
244
245 this.write_scalar(Scalar::from_i32(res.into()), dest)?;
246 }
247 // Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd
248 // _mm_testnzc_pd, _mm256_testz_ps, _mm256_testc_ps, _mm256_testnzc_ps and
249 // _mm_testnzc_ps functions.
250 // Calculates two booleans:
251 // `direct`, which is true when the highest bit of each element of `op & mask` is zero.
252 // `negated`, which is true when the highest bit of each element of `!op & mask` is zero.
253 // Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc)
254 "vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestnzc.pd"
255 | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256" | "vtestnzc.ps" => {
256 let [op, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
257
258 let (direct, negated) = test_high_bits_masked(this, op, mask)?;
259 let res = match unprefixed_name {
260 "vtestz.pd.256" | "vtestz.ps.256" => direct,
261 "vtestc.pd.256" | "vtestc.ps.256" => negated,
262 "vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" =>
263 !direct && !negated,
264 _ => unreachable!(),
265 };
266
267 this.write_scalar(Scalar::from_i32(res.into()), dest)?;
268 }
269 // Used to implement the `_mm256_zeroupper` and `_mm256_zeroall` functions.
270 // These function clear out the upper 128 bits of all avx registers or
271 // zero out all avx registers respectively.
272 "vzeroupper" | "vzeroall" => {
273 // These functions are purely a performance hint for the CPU.
274 // Any registers currently in use will be saved beforehand by the
275 // compiler, making these functions no-ops.
276
277 // The only thing that needs to be ensured is the correct calling convention.
278 let [] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
279 }
280 _ => return interp_ok(EmulateItemResult::NotSupported),
281 }
282 interp_ok(EmulateItemResult::NeedsReturn)
283 }
284}