miri/shims/x86/
avx.rs

1use rustc_abi::CanonAbi;
2use rustc_apfloat::ieee::{Double, Single};
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use super::{
8    FloatBinOp, FloatUnaryOp, bin_op_simd_float_all, conditional_dot_product, convert_float_to_int,
9    mask_load, mask_store, round_all, test_bits_masked, test_high_bits_masked, unary_op_ps,
10};
11use crate::*;
12
13impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
14pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
15    fn emulate_x86_avx_intrinsic(
16        &mut self,
17        link_name: Symbol,
18        abi: &FnAbi<'tcx, Ty<'tcx>>,
19        args: &[OpTy<'tcx>],
20        dest: &MPlaceTy<'tcx>,
21    ) -> InterpResult<'tcx, EmulateItemResult> {
22        let this = self.eval_context_mut();
23        this.expect_target_feature_for_intrinsic(link_name, "avx")?;
24        // Prefix should have already been checked.
25        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx.").unwrap();
26
27        match unprefixed_name {
28            // Used to implement _mm256_min_ps and _mm256_max_ps functions.
29            // Note that the semantics are a bit different from Rust simd_min
30            // and simd_max intrinsics regarding handling of NaN and -0.0: Rust
31            // matches the IEEE min/max operations, while x86 has different
32            // semantics.
33            "min.ps.256" | "max.ps.256" => {
34                let [left, right] =
35                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
36
37                let which = match unprefixed_name {
38                    "min.ps.256" => FloatBinOp::Min,
39                    "max.ps.256" => FloatBinOp::Max,
40                    _ => unreachable!(),
41                };
42
43                bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
44            }
45            // Used to implement _mm256_min_pd and _mm256_max_pd functions.
46            "min.pd.256" | "max.pd.256" => {
47                let [left, right] =
48                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
49
50                let which = match unprefixed_name {
51                    "min.pd.256" => FloatBinOp::Min,
52                    "max.pd.256" => FloatBinOp::Max,
53                    _ => unreachable!(),
54                };
55
56                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
57            }
58            // Used to implement the _mm256_round_ps function.
59            // Rounds the elements of `op` according to `rounding`.
60            "round.ps.256" => {
61                let [op, rounding] =
62                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
63
64                round_all::<rustc_apfloat::ieee::Single>(this, op, rounding, dest)?;
65            }
66            // Used to implement the _mm256_round_pd function.
67            // Rounds the elements of `op` according to `rounding`.
68            "round.pd.256" => {
69                let [op, rounding] =
70                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
71
72                round_all::<rustc_apfloat::ieee::Double>(this, op, rounding, dest)?;
73            }
74            // Used to implement _mm256_{rcp,rsqrt}_ps functions.
75            // Performs the operations on all components of `op`.
76            "rcp.ps.256" | "rsqrt.ps.256" => {
77                let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
78
79                let which = match unprefixed_name {
80                    "rcp.ps.256" => FloatUnaryOp::Rcp,
81                    "rsqrt.ps.256" => FloatUnaryOp::Rsqrt,
82                    _ => unreachable!(),
83                };
84
85                unary_op_ps(this, which, op, dest)?;
86            }
87            // Used to implement the _mm256_dp_ps function.
88            "dp.ps.256" => {
89                let [left, right, imm] =
90                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
91
92                conditional_dot_product(this, left, right, imm, dest)?;
93            }
94            // Used to implement the _mm256_cmp_ps function.
95            // Performs a comparison operation on each component of `left`
96            // and `right`. For each component, returns 0 if false or u32::MAX
97            // if true.
98            "cmp.ps.256" => {
99                let [left, right, imm] =
100                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
101
102                let which =
103                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
104
105                bin_op_simd_float_all::<Single>(this, which, left, right, dest)?;
106            }
107            // Used to implement the _mm256_cmp_pd function.
108            // Performs a comparison operation on each component of `left`
109            // and `right`. For each component, returns 0 if false or u64::MAX
110            // if true.
111            "cmp.pd.256" => {
112                let [left, right, imm] =
113                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
114
115                let which =
116                    FloatBinOp::cmp_from_imm(this, this.read_scalar(imm)?.to_i8()?, link_name)?;
117
118                bin_op_simd_float_all::<Double>(this, which, left, right, dest)?;
119            }
120            // Used to implement the _mm256_cvtps_epi32, _mm256_cvttps_epi32, _mm256_cvtpd_epi32
121            // and _mm256_cvttpd_epi32 functions.
122            // Converts packed f32/f64 to packed i32.
123            "cvt.ps2dq.256" | "cvtt.ps2dq.256" | "cvt.pd2dq.256" | "cvtt.pd2dq.256" => {
124                let [op] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
125
126                let rnd = match unprefixed_name {
127                    // "current SSE rounding mode", assume nearest
128                    "cvt.ps2dq.256" | "cvt.pd2dq.256" => rustc_apfloat::Round::NearestTiesToEven,
129                    // always truncate
130                    "cvtt.ps2dq.256" | "cvtt.pd2dq.256" => rustc_apfloat::Round::TowardZero,
131                    _ => unreachable!(),
132                };
133
134                convert_float_to_int(this, op, rnd, dest)?;
135            }
136            // Used to implement the _mm_permutevar_ps and _mm256_permutevar_ps functions.
137            // Shuffles 32-bit floats from `data` using `control` as control. Each 128-bit
138            // chunk is shuffled independently: this means that we view the vector as a
139            // sequence of 4-element arrays, and we shuffle each of these arrays, where
140            // `control` determines which element of the current `data` array is written.
141            "vpermilvar.ps" | "vpermilvar.ps.256" => {
142                let [data, control] =
143                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
144
145                let (data, data_len) = this.project_to_simd(data)?;
146                let (control, control_len) = this.project_to_simd(control)?;
147                let (dest, dest_len) = this.project_to_simd(dest)?;
148
149                assert_eq!(dest_len, data_len);
150                assert_eq!(dest_len, control_len);
151
152                for i in 0..dest_len {
153                    let control = this.project_index(&control, i)?;
154
155                    // Each 128-bit chunk is shuffled independently. Since each chunk contains
156                    // four 32-bit elements, only two bits from `control` are used. To read the
157                    // value from the current chunk, add the destination index truncated to a multiple
158                    // of 4.
159                    let chunk_base = i & !0b11;
160                    let src_i = u64::from(this.read_scalar(&control)?.to_u32()? & 0b11)
161                        .strict_add(chunk_base);
162
163                    this.copy_op(
164                        &this.project_index(&data, src_i)?,
165                        &this.project_index(&dest, i)?,
166                    )?;
167                }
168            }
169            // Used to implement the _mm_permutevar_pd and _mm256_permutevar_pd functions.
170            // Shuffles 64-bit floats from `left` using `right` as control. Each 128-bit
171            // chunk is shuffled independently: this means that we view the vector as
172            // a sequence of 2-element arrays, and we shuffle each of these arrays,
173            // where `right` determines which element of the current `left` array is
174            // written.
175            "vpermilvar.pd" | "vpermilvar.pd.256" => {
176                let [data, control] =
177                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
178
179                let (data, data_len) = this.project_to_simd(data)?;
180                let (control, control_len) = this.project_to_simd(control)?;
181                let (dest, dest_len) = this.project_to_simd(dest)?;
182
183                assert_eq!(dest_len, data_len);
184                assert_eq!(dest_len, control_len);
185
186                for i in 0..dest_len {
187                    let control = this.project_index(&control, i)?;
188
189                    // Each 128-bit chunk is shuffled independently. Since each chunk contains
190                    // two 64-bit elements, only the second bit from `control` is used (yes, the
191                    // second instead of the first, ask Intel). To read the value from the current
192                    // chunk, add the destination index truncated to a multiple of 2.
193                    let chunk_base = i & !1;
194                    let src_i =
195                        ((this.read_scalar(&control)?.to_u64()? >> 1) & 1).strict_add(chunk_base);
196
197                    this.copy_op(
198                        &this.project_index(&data, src_i)?,
199                        &this.project_index(&dest, i)?,
200                    )?;
201                }
202            }
203            // Used to implement the _mm_maskload_ps, _mm_maskload_pd, _mm256_maskload_ps
204            // and _mm256_maskload_pd functions.
205            // For the element `i`, if the high bit of the `i`-th element of `mask`
206            // is one, it is loaded from `ptr.wrapping_add(i)`, otherwise zero is
207            // loaded.
208            "maskload.ps" | "maskload.pd" | "maskload.ps.256" | "maskload.pd.256" => {
209                let [ptr, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
210
211                mask_load(this, ptr, mask, dest)?;
212            }
213            // Used to implement the _mm_maskstore_ps, _mm_maskstore_pd, _mm256_maskstore_ps
214            // and _mm256_maskstore_pd functions.
215            // For the element `i`, if the high bit of the element `i`-th of `mask`
216            // is one, it is stored into `ptr.wapping_add(i)`.
217            // Unlike SSE2's _mm_maskmoveu_si128, these are not non-temporal stores.
218            "maskstore.ps" | "maskstore.pd" | "maskstore.ps.256" | "maskstore.pd.256" => {
219                let [ptr, mask, value] =
220                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
221
222                mask_store(this, ptr, mask, value)?;
223            }
224            // Used to implement the _mm256_lddqu_si256 function.
225            // Reads a 256-bit vector from an unaligned pointer. This intrinsic
226            // is expected to perform better than a regular unaligned read when
227            // the data crosses a cache line, but for Miri this is just a regular
228            // unaligned read.
229            "ldu.dq.256" => {
230                let [src_ptr] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
231                let src_ptr = this.read_pointer(src_ptr)?;
232                let dest = dest.force_mplace(this)?;
233
234                // Unaligned copy, which is what we want.
235                this.mem_copy(src_ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
236            }
237            // Used to implement the _mm256_testnzc_si256 function.
238            // Tests `op & mask != 0 && op & mask != mask`
239            "ptestnzc.256" => {
240                let [op, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
241
242                let (all_zero, masked_set) = test_bits_masked(this, op, mask)?;
243                let res = !all_zero && !masked_set;
244
245                this.write_scalar(Scalar::from_i32(res.into()), dest)?;
246            }
247            // Used to implement the _mm256_testz_pd, _mm256_testc_pd, _mm256_testnzc_pd
248            // _mm_testnzc_pd, _mm256_testz_ps, _mm256_testc_ps, _mm256_testnzc_ps and
249            // _mm_testnzc_ps functions.
250            // Calculates two booleans:
251            // `direct`, which is true when the highest bit of each element of `op & mask` is zero.
252            // `negated`, which is true when the highest bit of each element of `!op & mask` is zero.
253            // Return `direct` (testz), `negated` (testc) or `!direct & !negated` (testnzc)
254            "vtestz.pd.256" | "vtestc.pd.256" | "vtestnzc.pd.256" | "vtestnzc.pd"
255            | "vtestz.ps.256" | "vtestc.ps.256" | "vtestnzc.ps.256" | "vtestnzc.ps" => {
256                let [op, mask] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
257
258                let (direct, negated) = test_high_bits_masked(this, op, mask)?;
259                let res = match unprefixed_name {
260                    "vtestz.pd.256" | "vtestz.ps.256" => direct,
261                    "vtestc.pd.256" | "vtestc.ps.256" => negated,
262                    "vtestnzc.pd.256" | "vtestnzc.pd" | "vtestnzc.ps.256" | "vtestnzc.ps" =>
263                        !direct && !negated,
264                    _ => unreachable!(),
265                };
266
267                this.write_scalar(Scalar::from_i32(res.into()), dest)?;
268            }
269            // Used to implement the `_mm256_zeroupper` and `_mm256_zeroall` functions.
270            // These function clear out the upper 128 bits of all avx registers or
271            // zero out all avx registers respectively.
272            "vzeroupper" | "vzeroall" => {
273                // These functions are purely a performance hint for the CPU.
274                // Any registers currently in use will be saved beforehand by the
275                // compiler, making these functions no-ops.
276
277                // The only thing that needs to be ensured is the correct calling convention.
278                let [] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
279            }
280            _ => return interp_ok(EmulateItemResult::NotSupported),
281        }
282        interp_ok(EmulateItemResult::NeedsReturn)
283    }
284}