miri/shims/x86/avx.rs

1use rustc_apfloat::ieee::{Double, Single};
2use rustc_middle::mir;
3use rustc_middle::ty::Ty;
4use rustc_middle::ty::layout::LayoutOf as _;
5use rustc_span::Symbol;
6use rustc_target::callconv::{Conv, FnAbi};
7
8use super::{
9    FloatBinOp, FloatUnaryOp, bin_op_simd_float_all, conditional_dot_product, convert_float_to_int,
10    horizontal_bin_op, mask_load, mask_store, round_all, test_bits_masked, test_high_bits_masked,
11    unary_op_ps,
12};
13use crate::*;
14
15impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
16pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
17    fn emulate_x86_avx_intrinsic(
18        &mut self,
19        link_name: Symbol,
20        abi: &FnAbi<'tcx, Ty<'tcx>>,
21        args: &[OpTy<'tcx>],
22        dest: &MPlaceTy<'tcx>,
23    ) -> InterpResult<'tcx, EmulateItemResult> {
24        let this = self.eval_context_mut();
25        this.expect_target_feature_for_intrinsic(link_name, "avx")?;
26        // Prefix should have already been checked.
27        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap();
28
29        match unprefixed_name {
30            // Used to implement _mm256_min_ps and _mm256_max_ps functions.
31            // Note that the semantics are a bit different from Rust simd_min
32            // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
33            // matches the IEEE min/max operations, while x86 has different
34            // semantics.
35            "min.ps.256" | "max.ps.256" => {
36                let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
37
38                let which = match unprefixed_name {
39                    "min.ps.256" => FloatBinOp::Min,
40                    "max.ps.256" => FloatBinOp::Max,
41                    _ => unreachable!(),
42                };
43
44                bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
45            }
46            // Used to implement _mm256_min_pd and _mm256_max_pd functions.
47            "min.pd.256" | "max.pd.256" => {
48                let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
49
50                let which = match unprefixed_name {
51                    "min.pd.256" => FloatBinOp::Min,
52                    "max.pd.256" => FloatBinOp::Max,
53                    _ => unreachable!(),
54                };
55
56                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
57            }
58            // Used to implement the _mm256_round_ps function.
59            // Rounds the elements of `op` according to `rounding`.
60            "round.ps.256" => {
61                let [op, rounding] = this.check_shim(abi, Conv::C, link_name, args)?;
62
63                round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
64            }
65            // Used to implement the _mm256_round_pd function.
66            // Rounds the elements of `op` according to `rounding`.
67            "round.pd.256" => {
68                let [op, rounding] = this.check_shim(abi, Conv::C, link_name, args)?;
69
70                round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
71            }
72            // Used to implement _mm256_{rcp,rsqrt}_ps functions.
73            // Performs the operations on all components of `op`.
74            "rcp.ps.256" | "rsqrt.ps.256" => {
75                let [op] = this.check_shim(abi, Conv::C, link_name, args)?;
76
77                let which = match unprefixed_name {
78                    "rcp.ps.256" => FloatUnaryOp::Rcp,
79                    "rsqrt.ps.256" => FloatUnaryOp::Rsqrt,
80                    _ => unreachable!(),
81                };
82
83                unary_op_ps(this, which, op, dest)?;
84            }
85            // Used to implement the _mm256_dp_ps function.
86            "dp.ps.256" => {
87                let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
88
89                conditional_dot_product(this, left, right, imm, dest)?;
90            }
91            // Used to implement the _mm256_h{add,sub}_p{s,d} functions.
92            // Horizontally add/subtract adjacent floating point values
93            // in `left` and `right`.
94            "hadd.ps.256" | "hadd.pd.256" | "hsub.ps.256" | "hsub.pd.256" => {
95                let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
96
97                let which = match unprefixed_name {
98                    "hadd.ps.256" | "hadd.pd.256" => mir::BinOp::Add,
99                    "hsub.ps.256" | "hsub.pd.256" => mir::BinOp::Sub,
100                    _ => unreachable!(),
101                };
102
103                horizontal_bin_op(this, which, /*saturating*/ false, left, right, dest)?;
104            }
105            // Used to implement the _mm256_cmp_ps function.
106            // Performs a comparison operation on each component of `left`
107            // and `right`. For each component, returns 0 if false or u32::MAX
108            // if true.
109            "cmp.ps.256" => {
110                let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
111
112                let which =
113                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
114
115                bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
116            }
117            // Used to implement the _mm256_cmp_pd function.
118            // Performs a comparison operation on each component of `left`
119            // and `right`. For each component, returns 0 if false or u64::MAX
120            // if true.
121            "cmp.pd.256" => {
122                let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
123
124                let which =
125                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
126
127                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
128            }
129            // Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32
130            // and _mm256_cvttpd_epi32 functions.
131            // Converts packed f32/f64 to packed i32.
132            "cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => {
133                let [op] = this.check_shim(abi, Conv::C, link_name, args)?;
134
135                let rnd = match unprefixed_name {
136                    // "current SSE rounding mode", assume nearest
137                    "cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven,
138                    // always truncate
139                    "cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero,
140                    _ => unreachable!(),
141                };
142
143                convert_float_to_int(this, op, rnd, dest)?;
144            }
145            // Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions.
146            // Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit
147            // chunk is shuffled independently: this means that we view the vector as a
148            // sequence of 4-element arrays, and we shuffle each of these arrays, where
149            // `control` determines which element of the current `data` array is written.
150            "vpermilvar.ps" | "vpermilvar.ps.256" => {
151                let [data, control] = this.check_shim(abi, Conv::C, link_name, args)?;
152
153                let (data, data_len) = this.project_to_simd(data)?;
154                let (control, control_len) = this.project_to_simd(control)?;
155                let (dest, dest_len) = this.project_to_simd(dest)?;
156
157                assert_eq!(dest_len, data_len);
158                assert_eq!(dest_len, control_len);
159
160                for i in 0..dest_len {
161                    let control = this.project_index(&control, i)?;
162
163                    // Each 128-bit chunk is shuffled independently. Since each chunk contains
164                    // four 32-bit elements, only two bits from `control` are used. To read the
165                    // value from the current chunk, add the destination index truncated to a multiple
166                    // of 4.
167                    let chunk_base = i & !0b11;
168                    let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11)
169                        .strict_add(chunk_base);
170
171                    this.copy_op(
172                        &this.project_index(&data, src_i)?,
173                        &this.project_index(&dest, i)?,
174                    )?;
175                }
176            }
177            // Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions.
178            // Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit
179            // chunk is shuffled independently: this means that we view the vector as
180            // a sequence of 2-element arrays, and we shuffle each of these arrays,
181            // where `right` determines which element of the current `left` array is
182            // written.
183            "vpermilvar.pd" | "vpermilvar.pd.256" => {
184                let [data, control] = this.check_shim(abi, Conv::C, link_name, args)?;
185
186                let (data, data_len) = this.project_to_simd(data)?;
187                let (control, control_len) = this.project_to_simd(control)?;
188                let (dest, dest_len) = this.project_to_simd(dest)?;
189
190                assert_eq!(dest_len, data_len);
191                assert_eq!(dest_len, control_len);
192
193                for i in 0..dest_len {
194                    let control = this.project_index(&control, i)?;
195
196                    // Each 128-bit chunk is shuffled independently. Since each chunk contains
197                    // two 64-bit elements, only the second bit from `control` is used (yes, the
198                    // second instead of the first, ask Intel). To read the value from the current
199                    // chunk, add the destination index truncated to a multiple of 2.
200                    let chunk_base = i & !1;
201                    let src_i =
202                        ((this.read_scalar(&control)?.to_u64()? >> 1) & 1).strict_add(chunk_base);
203
204                    this.copy_op(
205                        &this.project_index(&data, src_i)?,
206                        &this.project_index(&dest, i)?,
207                    )?;
208                }
209            }
210            // Used to implement the _mm256_permute2f128_ps, _mm256_permute2f128_pd and
211            // _mm256_permute2f128_si256 functions. Regardless of the suffix in the name
212            // thay all can be considered to operate on vectors of 128-bit elements.
213            // For each 128-bit element of `dest`, copies one from `left`, `right` or
214            // zero, according to `imm`.
215            "vperm2f128.ps.256" | "vperm2f128.pd.256" | "vperm2f128.si.256" => {
216                let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
217
218                assert_eq!(dest.layout, left.layout);
219                assert_eq!(dest.layout, right.layout);
220                assert_eq!(dest.layout.size.bits(), 256);
221
222                // Transmute to `[u128; 2]` to process each 128-bit chunk independently.
223                let u128x2_layout =
224                    this.layout_of(Ty::new_array(this.tcx.tcx, this.tcx.types.u128, 2))?;
225                let left = left.transmute(u128x2_layout, this)?;
226                let right = right.transmute(u128x2_layout, this)?;
227                let dest = dest.transmute(u128x2_layout, this)?;
228
229                let imm = this.read_scalar(imm)?.to_u8()?;
230
231                for i in 0..2 {
232                    let dest = this.project_index(&dest, i)?;
233
234                    let imm = match i {
235                        0 => imm & 0xF,
236                        1 => imm >> 4,
237                        _ => unreachable!(),
238                    };
239                    if imm & 0b100 != 0 {
240                        this.write_scalar(Scalar::from_u128(0), &dest)?;
241                    } else {
242                        let src = match imm {
243                            0b00 => this.project_index(&left, 0)?,
244                            0b01 => this.project_index(&left, 1)?,
245                            0b10 => this.project_index(&right, 0)?,
246                            0b11 => this.project_index(&right, 1)?,
247                            _ => unreachable!(),
248                        };
249                        this.copy_op(&src, &dest)?;
250                    }
251                }
252            }
253            // Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps
254            // and _mm256_maskload_pd functions.
255            // For the element `i`, if the high bit of the `i`-th element of `mask`
256            // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
257            // loaded.
258            "maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => {
259                let [ptr, mask] = this.check_shim(abi, Conv::C, link_name, args)?;
260
261                mask_load(this, ptr, mask, dest)?;
262            }
263            // Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps
264            // and _mm256_maskstore_pd functions.
265            // For the element `i`, if the high bit of the element `i`-th of `mask`
266            // is one, it is stored into `ptr.wapping_add(i)`.
267            // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
268            "maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => {
269                let [ptr, mask, value] = this.check_shim(abi, Conv::C, link_name, args)?;
270
271                mask_store(this, ptr, mask, value)?;
272            }
273            // Used to implement the _mm256_lddqu_si256 function.
274            // Reads a 256-bit vector from an unaligned pointer. This intrinsic
275            // is expected to perform better than a regular unaligned read when
276            // the data crosses a cache line, but for Miri this is just a regular
277            // unaligned read.
278            "ldu.dq.256" => {
279                let [src_ptr] = this.check_shim(abi, Conv::C, link_name, args)?;
280                let src_ptr = this.read_pointer(src_ptr)?;
281                let dest = dest.force_mplace(this)?;
282
283                // Unaligned copy, which is what we want.
284                this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
285            }
286            // Used to implement the _mm256_testz_si256, _mm256_testc_si256 and
287            // _mm256_testnzc_si256 functions.
288            // Tests `op & mask == 0`, `op & mask == mask` or
289            // `op & mask != 0 && op & mask != mask`
290            "ptestz.256" | "ptestc.256" | "ptestnzc.256" => {
291                let [op, mask] = this.check_shim(abi, Conv::C, link_name, args)?;
292
293                let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
294                let res = match unprefixed_name {
295                    "ptestz.256" => all_zero,
296                    "ptestc.256" => masked_set,
297                    "ptestnzc.256" => !all_zero && !masked_set,
298                    _ => unreachable!(),
299                };
300
301                this.write_scalar(Scalar::from_i32(res.into()), dest)?;
302            }
303            // Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd
304            // _mm_testz_pd, _mm_testc_pd, _mm_testnzc_pd, _mm256_testz_ps,
305            // _mm256_testc_ps, _mm256_testnzc_ps, _mm_testz_ps, _mm_testc_ps and
306            // _mm_testnzc_ps functions.
307            // Calculates two booleans:
308            // `direct`, which is true when the highest bit of each element of `op & mask` is zero.
309            // `negated`, which is true when the highest bit of each element of `!op & mask` is zero.
310            // Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc)
311            "vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestz.pd" | "vtestc.pd"
312            | "vtestnzc.pd" | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256"
313            | "vtestz.ps" | "vtestc.ps" | "vtestnzc.ps" => {
314                let [op, mask] = this.check_shim(abi, Conv::C, link_name, args)?;
315
316                let (direct, negated) = test_high_bits_masked(this, op, mask)?;
317                let res = match unprefixed_name {
318                    "vtestz.pd.256" | "vtestz.pd" | "vtestz.ps.256" | "vtestz.ps" => direct,
319                    "vtestc.pd.256" | "vtestc.pd" | "vtestc.ps.256" | "vtestc.ps" => negated,
320                    "vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" =>
321                        !direct && !negated,
322                    _ => unreachable!(),
323                };
324
325                this.write_scalar(Scalar::from_i32(res.into()), dest)?;
326            }
327            // Used to implement the `_mm256_zeroupper` and `_mm256_zeroall` functions.
328            // These function clear out the upper 128 bits of all avx registers or
329            // zero out all avx registers respectively.
330            "vzeroupper" | "vzeroall" => {
331                // These functions are purely a performance hint for the CPU.
332                // Any registers currently in use will be saved beforehand by the
333                // compiler, making these functions no-ops.
334
335                // The only thing that needs to be ensured is the correct calling convention.
336                let [] = this.check_shim(abi, Conv::C, link_name, args)?;
337            }
338            _ => return interp_ok(EmulateItemResult::NotSupported),
339        }
340        interp_ok(EmulateItemResult::NeedsReturn)
341    }
342}