miri/shims/x86/
mod.rs

1use rustc_abi::Size;
2use rustc_apfloat::Float;
3use rustc_apfloat::ieee::Single;
4use rustc_middle::ty::Ty;
5use rustc_middle::ty::layout::LayoutOf as _;
6use rustc_middle::{mir, ty};
7use rustc_span::Symbol;
8use rustc_target::callconv::{Conv, FnAbi};
9
10use self::helpers::bool_to_simd_element;
11use crate::*;
12
13mod aesni;
14mod avx;
15mod avx2;
16mod bmi;
17mod gfni;
18mod sha;
19mod sse;
20mod sse2;
21mod sse3;
22mod sse41;
23mod sse42;
24mod ssse3;
25
26impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
27pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
28    fn emulate_x86_intrinsic(
29        &mut self,
30        link_name: Symbol,
31        abi: &FnAbi<'tcx, Ty<'tcx>>,
32        args: &[OpTy<'tcx>],
33        dest: &MPlaceTy<'tcx>,
34    ) -> InterpResult<'tcx, EmulateItemResult> {
35        let this = self.eval_context_mut();
36        // Prefix should have already been checked.
37        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.").unwrap();
38        match unprefixed_name {
39            // Used to implement the `_addcarry_u{32, 64}` and the `_subborrow_u{32, 64}` functions.
40            // Computes a + b or a - b with input and output carry/borrow. The input carry/borrow is an 8-bit
41            // value, which is interpreted as 1 if it is non-zero. The output carry/borrow is an 8-bit value that will be 0 or 1.
42            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/addcarry-u32-addcarry-u64.html
43            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/subborrow-u32-subborrow-u64.html
44            "addcarry.32" | "addcarry.64" | "subborrow.32" | "subborrow.64" => {
45                if unprefixed_name.ends_with("64") && this.tcx.sess.target.arch != "x86_64" {
46                    return interp_ok(EmulateItemResult::NotSupported);
47                }
48
49                let [cb_in, a, b] = this.check_shim(abi, Conv::C, link_name, args)?;
50                let op = if unprefixed_name.starts_with("add") {
51                    mir::BinOp::AddWithOverflow
52                } else {
53                    mir::BinOp::SubWithOverflow
54                };
55
56                let (sum, cb_out) = carrying_add(this, cb_in, a, b, op)?;
57                this.write_scalar(cb_out, &this.project_field(dest, 0)?)?;
58                this.write_immediate(*sum, &this.project_field(dest, 1)?)?;
59            }
60
61            // Used to implement the `_addcarryx_u{32, 64}` functions. They are semantically identical with the `_addcarry_u{32, 64}` functions,
62            // except for a slightly different type signature and the requirement for the "adx" target feature.
63            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/addcarryx-u32-addcarryx-u64.html
64            "addcarryx.u32" | "addcarryx.u64" => {
65                this.expect_target_feature_for_intrinsic(link_name, "adx")?;
66
67                let is_u64 = unprefixed_name.ends_with("64");
68                if is_u64 && this.tcx.sess.target.arch != "x86_64" {
69                    return interp_ok(EmulateItemResult::NotSupported);
70                }
71                let [c_in, a, b, out] = this.check_shim(abi, Conv::C, link_name, args)?;
72                let out = this.deref_pointer_as(
73                    out,
74                    if is_u64 { this.machine.layouts.u64 } else { this.machine.layouts.u32 },
75                )?;
76
77                let (sum, c_out) = carrying_add(this, c_in, a, b, mir::BinOp::AddWithOverflow)?;
78                this.write_scalar(c_out, dest)?;
79                this.write_immediate(*sum, &out)?;
80            }
81
82            // Used to implement the `_mm_pause` function.
83            // The intrinsic is used to hint the processor that the code is in a spin-loop.
84            // It is compiled down to a `pause` instruction. When SSE2 is not available,
85            // the instruction behaves like a no-op, so it is always safe to call the
86            // intrinsic.
87            "sse2.pause" => {
88                let [] = this.check_shim(abi, Conv::C, link_name, args)?;
89                // Only exhibit the spin-loop hint behavior when SSE2 is enabled.
90                if this.tcx.sess.unstable_target_features.contains(&Symbol::intern("sse2")) {
91                    this.yield_active_thread();
92                }
93            }
94
95            "pclmulqdq" | "pclmulqdq.256" | "pclmulqdq.512" => {
96                let mut len = 2; // in units of 64bits
97                this.expect_target_feature_for_intrinsic(link_name, "pclmulqdq")?;
98                if unprefixed_name.ends_with(".256") {
99                    this.expect_target_feature_for_intrinsic(link_name, "vpclmulqdq")?;
100                    len = 4;
101                } else if unprefixed_name.ends_with(".512") {
102                    this.expect_target_feature_for_intrinsic(link_name, "vpclmulqdq")?;
103                    this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
104                    len = 8;
105                }
106
107                let [left, right, imm] = this.check_shim(abi, Conv::C, link_name, args)?;
108
109                pclmulqdq(this, left, right, imm, dest, len)?;
110            }
111
112            name if name.starts_with("bmi.") => {
113                return bmi::EvalContextExt::emulate_x86_bmi_intrinsic(
114                    this, link_name, abi, args, dest,
115                );
116            }
117            // The GFNI extension does not get its own namespace.
118            // Check for instruction names instead.
119            name if name.starts_with("vgf2p8affine") || name.starts_with("vgf2p8mulb") => {
120                return gfni::EvalContextExt::emulate_x86_gfni_intrinsic(
121                    this, link_name, abi, args, dest,
122                );
123            }
124            name if name.starts_with("sha") => {
125                return sha::EvalContextExt::emulate_x86_sha_intrinsic(
126                    this, link_name, abi, args, dest,
127                );
128            }
129            name if name.starts_with("sse.") => {
130                return sse::EvalContextExt::emulate_x86_sse_intrinsic(
131                    this, link_name, abi, args, dest,
132                );
133            }
134            name if name.starts_with("sse2.") => {
135                return sse2::EvalContextExt::emulate_x86_sse2_intrinsic(
136                    this, link_name, abi, args, dest,
137                );
138            }
139            name if name.starts_with("sse3.") => {
140                return sse3::EvalContextExt::emulate_x86_sse3_intrinsic(
141                    this, link_name, abi, args, dest,
142                );
143            }
144            name if name.starts_with("ssse3.") => {
145                return ssse3::EvalContextExt::emulate_x86_ssse3_intrinsic(
146                    this, link_name, abi, args, dest,
147                );
148            }
149            name if name.starts_with("sse41.") => {
150                return sse41::EvalContextExt::emulate_x86_sse41_intrinsic(
151                    this, link_name, abi, args, dest,
152                );
153            }
154            name if name.starts_with("sse42.") => {
155                return sse42::EvalContextExt::emulate_x86_sse42_intrinsic(
156                    this, link_name, abi, args, dest,
157                );
158            }
159            name if name.starts_with("aesni.") => {
160                return aesni::EvalContextExt::emulate_x86_aesni_intrinsic(
161                    this, link_name, abi, args, dest,
162                );
163            }
164            name if name.starts_with("avx.") => {
165                return avx::EvalContextExt::emulate_x86_avx_intrinsic(
166                    this, link_name, abi, args, dest,
167                );
168            }
169            name if name.starts_with("avx2.") => {
170                return avx2::EvalContextExt::emulate_x86_avx2_intrinsic(
171                    this, link_name, abi, args, dest,
172                );
173            }
174
175            _ => return interp_ok(EmulateItemResult::NotSupported),
176        }
177        interp_ok(EmulateItemResult::NeedsReturn)
178    }
179}
180
/// Binary floating-point operations shared by the SSE/AVX shims.
#[derive(Copy, Clone)]
enum FloatBinOp {
    /// Comparison
    ///
    /// The semantics of this operator is a case distinction: we compare the two operands,
    /// and then we return one of the four booleans `gt`, `lt`, `eq`, `unord` depending on
    /// which class they fall into.
    ///
    /// AVX supports all 16 combinations, SSE only a subset
    ///
    /// <https://www.felixcloutier.com/x86/cmpss>
    /// <https://www.felixcloutier.com/x86/cmpps>
    /// <https://www.felixcloutier.com/x86/cmpsd>
    /// <https://www.felixcloutier.com/x86/cmppd>
    Cmp {
        /// Result when lhs > rhs
        gt: bool,
        /// Result when lhs < rhs
        lt: bool,
        /// Result when lhs == rhs
        eq: bool,
        /// Result when lhs is NaN or rhs is NaN
        unord: bool,
    },
    /// Minimum value (with SSE semantics)
    ///
    /// <https://www.felixcloutier.com/x86/minss>
    /// <https://www.felixcloutier.com/x86/minps>
    /// <https://www.felixcloutier.com/x86/minsd>
    /// <https://www.felixcloutier.com/x86/minpd>
    Min,
    /// Maximum value (with SSE semantics)
    ///
    /// <https://www.felixcloutier.com/x86/maxss>
    /// <https://www.felixcloutier.com/x86/maxps>
    /// <https://www.felixcloutier.com/x86/maxsd>
    /// <https://www.felixcloutier.com/x86/maxpd>
    Max,
}
220
221impl FloatBinOp {
222    /// Convert from the `imm` argument used to specify the comparison
223    /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
224    fn cmp_from_imm<'tcx>(
225        ecx: &crate::MiriInterpCx<'tcx>,
226        imm: i8,
227        intrinsic: Symbol,
228    ) -> InterpResult<'tcx, Self> {
229        // Only bits 0..=4 are used, remaining should be zero.
230        if imm & !0b1_1111 != 0 {
231            panic!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
232        }
233        // Bit 4 specifies whether the operation is quiet or signaling, which
234        // we do not care in Miri.
235        // Bits 0..=2 specifies the operation.
236        // `gt` indicates the result to be returned when the LHS is strictly
237        // greater than the RHS, and so on.
238        let (gt, lt, eq, mut unord) = match imm & 0b111 {
239            // Equal
240            0x0 => (false, false, true, false),
241            // Less-than
242            0x1 => (false, true, false, false),
243            // Less-or-equal
244            0x2 => (false, true, true, false),
245            // Unordered (either is NaN)
246            0x3 => (false, false, false, true),
247            // Not equal
248            0x4 => (true, true, false, true),
249            // Not less-than
250            0x5 => (true, false, true, true),
251            // Not less-or-equal
252            0x6 => (true, false, false, true),
253            // Ordered (neither is NaN)
254            0x7 => (true, true, true, false),
255            _ => unreachable!(),
256        };
257        // When bit 3 is 1 (only possible in AVX), unord is toggled.
258        if imm & 0b1000 != 0 {
259            ecx.expect_target_feature_for_intrinsic(intrinsic, "avx")?;
260            unord = !unord;
261        }
262        interp_ok(Self::Cmp { gt, lt, eq, unord })
263    }
264}
265
266/// Performs `which` scalar operation on `left` and `right` and returns
267/// the result.
268fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
269    which: FloatBinOp,
270    left: &ImmTy<'tcx>,
271    right: &ImmTy<'tcx>,
272) -> InterpResult<'tcx, Scalar> {
273    match which {
274        FloatBinOp::Cmp { gt, lt, eq, unord } => {
275            let left = left.to_scalar().to_float::<F>()?;
276            let right = right.to_scalar().to_float::<F>()?;
277
278            let res = match left.partial_cmp(&right) {
279                None => unord,
280                Some(std::cmp::Ordering::Less) => lt,
281                Some(std::cmp::Ordering::Equal) => eq,
282                Some(std::cmp::Ordering::Greater) => gt,
283            };
284            interp_ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
285        }
286        FloatBinOp::Min => {
287            let left_scalar = left.to_scalar();
288            let left = left_scalar.to_float::<F>()?;
289            let right_scalar = right.to_scalar();
290            let right = right_scalar.to_float::<F>()?;
291            // SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
292            // is true when `x` is either +0 or -0.
293            if (left == F::ZERO && right == F::ZERO)
294                || left.is_nan()
295                || right.is_nan()
296                || left >= right
297            {
298                interp_ok(right_scalar)
299            } else {
300                interp_ok(left_scalar)
301            }
302        }
303        FloatBinOp::Max => {
304            let left_scalar = left.to_scalar();
305            let left = left_scalar.to_float::<F>()?;
306            let right_scalar = right.to_scalar();
307            let right = right_scalar.to_float::<F>()?;
308            // SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
309            // is true when `x` is either +0 or -0.
310            if (left == F::ZERO && right == F::ZERO)
311                || left.is_nan()
312                || right.is_nan()
313                || left <= right
314            {
315                interp_ok(right_scalar)
316            } else {
317                interp_ok(left_scalar)
318            }
319        }
320    }
321}
322
323/// Performs `which` operation on the first component of `left` and `right`
324/// and copies the other components from `left`. The result is stored in `dest`.
325fn bin_op_simd_float_first<'tcx, F: rustc_apfloat::Float>(
326    ecx: &mut crate::MiriInterpCx<'tcx>,
327    which: FloatBinOp,
328    left: &OpTy<'tcx>,
329    right: &OpTy<'tcx>,
330    dest: &MPlaceTy<'tcx>,
331) -> InterpResult<'tcx, ()> {
332    let (left, left_len) = ecx.project_to_simd(left)?;
333    let (right, right_len) = ecx.project_to_simd(right)?;
334    let (dest, dest_len) = ecx.project_to_simd(dest)?;
335
336    assert_eq!(dest_len, left_len);
337    assert_eq!(dest_len, right_len);
338
339    let res0 = bin_op_float::<F>(
340        which,
341        &ecx.read_immediate(&ecx.project_index(&left, 0)?)?,
342        &ecx.read_immediate(&ecx.project_index(&right, 0)?)?,
343    )?;
344    ecx.write_scalar(res0, &ecx.project_index(&dest, 0)?)?;
345
346    for i in 1..dest_len {
347        ecx.copy_op(&ecx.project_index(&left, i)?, &ecx.project_index(&dest, i)?)?;
348    }
349
350    interp_ok(())
351}
352
353/// Performs `which` operation on each component of `left` and
354/// `right`, storing the result is stored in `dest`.
355fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
356    ecx: &mut crate::MiriInterpCx<'tcx>,
357    which: FloatBinOp,
358    left: &OpTy<'tcx>,
359    right: &OpTy<'tcx>,
360    dest: &MPlaceTy<'tcx>,
361) -> InterpResult<'tcx, ()> {
362    let (left, left_len) = ecx.project_to_simd(left)?;
363    let (right, right_len) = ecx.project_to_simd(right)?;
364    let (dest, dest_len) = ecx.project_to_simd(dest)?;
365
366    assert_eq!(dest_len, left_len);
367    assert_eq!(dest_len, right_len);
368
369    for i in 0..dest_len {
370        let left = ecx.read_immediate(&ecx.project_index(&left, i)?)?;
371        let right = ecx.read_immediate(&ecx.project_index(&right, i)?)?;
372        let dest = ecx.project_index(&dest, i)?;
373
374        let res = bin_op_float::<F>(which, &left, &right)?;
375        ecx.write_scalar(res, &dest)?;
376    }
377
378    interp_ok(())
379}
380
/// Unary single-precision operations whose results are only approximate.
#[derive(Copy, Clone)]
enum FloatUnaryOp {
    /// Approximate reciprocal, `1/x`.
    ///
    /// <https://www.felixcloutier.com/x86/rcpss>
    /// <https://www.felixcloutier.com/x86/rcpps>
    Rcp,
    /// Approximate reciprocal square root, `1/sqrt(x)`.
    ///
    /// <https://www.felixcloutier.com/x86/rsqrtss>
    /// <https://www.felixcloutier.com/x86/rsqrtps>
    Rsqrt,
}
394
395/// Performs `which` scalar operation on `op` and returns the result.
396fn unary_op_f32<'tcx>(
397    ecx: &mut crate::MiriInterpCx<'tcx>,
398    which: FloatUnaryOp,
399    op: &ImmTy<'tcx>,
400) -> InterpResult<'tcx, Scalar> {
401    match which {
402        FloatUnaryOp::Rcp => {
403            let op = op.to_scalar().to_f32()?;
404            let div = (Single::from_u128(1).value / op).value;
405            // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
406            // inaccuracy of RCP.
407            let res = math::apply_random_float_error(ecx, div, -12);
408            interp_ok(Scalar::from_f32(res))
409        }
410        FloatUnaryOp::Rsqrt => {
411            let op = op.to_scalar().to_f32()?;
412            let rsqrt = (Single::from_u128(1).value / math::sqrt(op)).value;
413            // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
414            // inaccuracy of RSQRT.
415            let res = math::apply_random_float_error(ecx, rsqrt, -12);
416            interp_ok(Scalar::from_f32(res))
417        }
418    }
419}
420
421/// Performs `which` operation on the first component of `op` and copies
422/// the other components. The result is stored in `dest`.
423fn unary_op_ss<'tcx>(
424    ecx: &mut crate::MiriInterpCx<'tcx>,
425    which: FloatUnaryOp,
426    op: &OpTy<'tcx>,
427    dest: &MPlaceTy<'tcx>,
428) -> InterpResult<'tcx, ()> {
429    let (op, op_len) = ecx.project_to_simd(op)?;
430    let (dest, dest_len) = ecx.project_to_simd(dest)?;
431
432    assert_eq!(dest_len, op_len);
433
434    let res0 = unary_op_f32(ecx, which, &ecx.read_immediate(&ecx.project_index(&op, 0)?)?)?;
435    ecx.write_scalar(res0, &ecx.project_index(&dest, 0)?)?;
436
437    for i in 1..dest_len {
438        ecx.copy_op(&ecx.project_index(&op, i)?, &ecx.project_index(&dest, i)?)?;
439    }
440
441    interp_ok(())
442}
443
444/// Performs `which` operation on each component of `op`, storing the
445/// result is stored in `dest`.
446fn unary_op_ps<'tcx>(
447    ecx: &mut crate::MiriInterpCx<'tcx>,
448    which: FloatUnaryOp,
449    op: &OpTy<'tcx>,
450    dest: &MPlaceTy<'tcx>,
451) -> InterpResult<'tcx, ()> {
452    let (op, op_len) = ecx.project_to_simd(op)?;
453    let (dest, dest_len) = ecx.project_to_simd(dest)?;
454
455    assert_eq!(dest_len, op_len);
456
457    for i in 0..dest_len {
458        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
459        let dest = ecx.project_index(&dest, i)?;
460
461        let res = unary_op_f32(ecx, which, &op)?;
462        ecx.write_scalar(res, &dest)?;
463    }
464
465    interp_ok(())
466}
467
/// Kind of bit shift performed by the SIMD shift helpers.
///
/// Derives `Copy`/`Clone` like the other operator enums in this module, so a
/// value can be passed on without moving it.
#[derive(Copy, Clone)]
enum ShiftOp {
    /// Shift left, logically (shift in zeros) -- same as shift left, arithmetically
    Left,
    /// Shift right, logically (shift in zeros)
    RightLogic,
    /// Shift right, arithmetically (shift in sign)
    RightArith,
}
476
477/// Shifts each element of `left` by a scalar amount. The shift amount
478/// is determined by the lowest 64 bits of `right` (which is a 128-bit vector).
479///
480/// For logic shifts, when right is larger than BITS - 1, zero is produced.
481/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
482/// bit is copied to all bits.
483fn shift_simd_by_scalar<'tcx>(
484    ecx: &mut crate::MiriInterpCx<'tcx>,
485    left: &OpTy<'tcx>,
486    right: &OpTy<'tcx>,
487    which: ShiftOp,
488    dest: &MPlaceTy<'tcx>,
489) -> InterpResult<'tcx, ()> {
490    let (left, left_len) = ecx.project_to_simd(left)?;
491    let (dest, dest_len) = ecx.project_to_simd(dest)?;
492
493    assert_eq!(dest_len, left_len);
494    // `right` may have a different length, and we only care about its
495    // lowest 64bit anyway.
496
497    // Get the 64-bit shift operand and convert it to the type expected
498    // by checked_{shl,shr} (u32).
499    // It is ok to saturate the value to u32::MAX because any value
500    // above BITS - 1 will produce the same result.
501    let shift = u32::try_from(extract_first_u64(ecx, right)?).unwrap_or(u32::MAX);
502
503    for i in 0..dest_len {
504        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?;
505        let dest = ecx.project_index(&dest, i)?;
506
507        let res = match which {
508            ShiftOp::Left => {
509                let left = left.to_uint(dest.layout.size)?;
510                let res = left.checked_shl(shift).unwrap_or(0);
511                // `truncate` is needed as left-shift can make the absolute value larger.
512                Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
513            }
514            ShiftOp::RightLogic => {
515                let left = left.to_uint(dest.layout.size)?;
516                let res = left.checked_shr(shift).unwrap_or(0);
517                // No `truncate` needed as right-shift can only make the absolute value smaller.
518                Scalar::from_uint(res, dest.layout.size)
519            }
520            ShiftOp::RightArith => {
521                let left = left.to_int(dest.layout.size)?;
522                // On overflow, copy the sign bit to the remaining bits
523                let res = left.checked_shr(shift).unwrap_or(left >> 127);
524                // No `truncate` needed as right-shift can only make the absolute value smaller.
525                Scalar::from_int(res, dest.layout.size)
526            }
527        };
528        ecx.write_scalar(res, &dest)?;
529    }
530
531    interp_ok(())
532}
533
534/// Shifts each element of `left` by the corresponding element of `right`.
535///
536/// For logic shifts, when right is larger than BITS - 1, zero is produced.
537/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
538/// bit is copied to all bits.
539fn shift_simd_by_simd<'tcx>(
540    ecx: &mut crate::MiriInterpCx<'tcx>,
541    left: &OpTy<'tcx>,
542    right: &OpTy<'tcx>,
543    which: ShiftOp,
544    dest: &MPlaceTy<'tcx>,
545) -> InterpResult<'tcx, ()> {
546    let (left, left_len) = ecx.project_to_simd(left)?;
547    let (right, right_len) = ecx.project_to_simd(right)?;
548    let (dest, dest_len) = ecx.project_to_simd(dest)?;
549
550    assert_eq!(dest_len, left_len);
551    assert_eq!(dest_len, right_len);
552
553    for i in 0..dest_len {
554        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?;
555        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?;
556        let dest = ecx.project_index(&dest, i)?;
557
558        // It is ok to saturate the value to u32::MAX because any value
559        // above BITS - 1 will produce the same result.
560        let shift = u32::try_from(right.to_uint(dest.layout.size)?).unwrap_or(u32::MAX);
561
562        let res = match which {
563            ShiftOp::Left => {
564                let left = left.to_uint(dest.layout.size)?;
565                let res = left.checked_shl(shift).unwrap_or(0);
566                // `truncate` is needed as left-shift can make the absolute value larger.
567                Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
568            }
569            ShiftOp::RightLogic => {
570                let left = left.to_uint(dest.layout.size)?;
571                let res = left.checked_shr(shift).unwrap_or(0);
572                // No `truncate` needed as right-shift can only make the absolute value smaller.
573                Scalar::from_uint(res, dest.layout.size)
574            }
575            ShiftOp::RightArith => {
576                let left = left.to_int(dest.layout.size)?;
577                // On overflow, copy the sign bit to the remaining bits
578                let res = left.checked_shr(shift).unwrap_or(left >> 127);
579                // No `truncate` needed as right-shift can only make the absolute value smaller.
580                Scalar::from_int(res, dest.layout.size)
581            }
582        };
583        ecx.write_scalar(res, &dest)?;
584    }
585
586    interp_ok(())
587}
588
589/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts
590/// the first value.
591fn extract_first_u64<'tcx>(
592    ecx: &crate::MiriInterpCx<'tcx>,
593    op: &OpTy<'tcx>,
594) -> InterpResult<'tcx, u64> {
595    // Transmute vector to `[u64; 2]`
596    let array_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u64, 2))?;
597    let op = op.transmute(array_layout, ecx)?;
598
599    // Get the first u64 from the array
600    ecx.read_scalar(&ecx.project_index(&op, 0)?)?.to_u64()
601}
602
603// Rounds the first element of `right` according to `rounding`
604// and copies the remaining elements from `left`.
605fn round_first<'tcx, F: rustc_apfloat::Float>(
606    ecx: &mut crate::MiriInterpCx<'tcx>,
607    left: &OpTy<'tcx>,
608    right: &OpTy<'tcx>,
609    rounding: &OpTy<'tcx>,
610    dest: &MPlaceTy<'tcx>,
611) -> InterpResult<'tcx, ()> {
612    let (left, left_len) = ecx.project_to_simd(left)?;
613    let (right, right_len) = ecx.project_to_simd(right)?;
614    let (dest, dest_len) = ecx.project_to_simd(dest)?;
615
616    assert_eq!(dest_len, left_len);
617    assert_eq!(dest_len, right_len);
618
619    let rounding = rounding_from_imm(ecx.read_scalar(rounding)?.to_i32()?)?;
620
621    let op0: F = ecx.read_scalar(&ecx.project_index(&right, 0)?)?.to_float()?;
622    let res = op0.round_to_integral(rounding).value;
623    ecx.write_scalar(
624        Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
625        &ecx.project_index(&dest, 0)?,
626    )?;
627
628    for i in 1..dest_len {
629        ecx.copy_op(&ecx.project_index(&left, i)?, &ecx.project_index(&dest, i)?)?;
630    }
631
632    interp_ok(())
633}
634
635// Rounds all elements of `op` according to `rounding`.
636fn round_all<'tcx, F: rustc_apfloat::Float>(
637    ecx: &mut crate::MiriInterpCx<'tcx>,
638    op: &OpTy<'tcx>,
639    rounding: &OpTy<'tcx>,
640    dest: &MPlaceTy<'tcx>,
641) -> InterpResult<'tcx, ()> {
642    let (op, op_len) = ecx.project_to_simd(op)?;
643    let (dest, dest_len) = ecx.project_to_simd(dest)?;
644
645    assert_eq!(dest_len, op_len);
646
647    let rounding = rounding_from_imm(ecx.read_scalar(rounding)?.to_i32()?)?;
648
649    for i in 0..dest_len {
650        let op: F = ecx.read_scalar(&ecx.project_index(&op, i)?)?.to_float()?;
651        let res = op.round_to_integral(rounding).value;
652        ecx.write_scalar(
653            Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
654            &ecx.project_index(&dest, i)?,
655        )?;
656    }
657
658    interp_ok(())
659}
660
661/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
662/// `round.{ss,sd,ps,pd}` intrinsics.
663fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
664    // The fourth bit of `rounding` only affects the SSE status
665    // register, which cannot be accessed from Miri (or from Rust,
666    // for that matter), so we can ignore it.
667    match rounding & !0b1000 {
668        // When the third bit is 0, the rounding mode is determined by the
669        // first two bits.
670        0b000 => interp_ok(rustc_apfloat::Round::NearestTiesToEven),
671        0b001 => interp_ok(rustc_apfloat::Round::TowardNegative),
672        0b010 => interp_ok(rustc_apfloat::Round::TowardPositive),
673        0b011 => interp_ok(rustc_apfloat::Round::TowardZero),
674        // When the third bit is 1, the rounding mode is determined by the
675        // SSE status register. Since we do not support modifying it from
676        // Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
677        0b100..=0b111 => interp_ok(rustc_apfloat::Round::NearestTiesToEven),
678        rounding => panic!("invalid rounding mode 0x{rounding:02x}"),
679    }
680}
681
682/// Converts each element of `op` from floating point to signed integer.
683///
684/// When the input value is NaN or out of range, fall back to minimum value.
685///
686/// If `op` has more elements than `dest`, extra elements are ignored. If `op`
687/// has less elements than `dest`, the rest is filled with zeros.
688fn convert_float_to_int<'tcx>(
689    ecx: &mut crate::MiriInterpCx<'tcx>,
690    op: &OpTy<'tcx>,
691    rnd: rustc_apfloat::Round,
692    dest: &MPlaceTy<'tcx>,
693) -> InterpResult<'tcx, ()> {
694    let (op, op_len) = ecx.project_to_simd(op)?;
695    let (dest, dest_len) = ecx.project_to_simd(dest)?;
696
697    // Output must be *signed* integers.
698    assert!(matches!(dest.layout.field(ecx, 0).ty.kind(), ty::Int(_)));
699
700    for i in 0..op_len.min(dest_len) {
701        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
702        let dest = ecx.project_index(&dest, i)?;
703
704        let res = ecx.float_to_int_checked(&op, dest.layout, rnd)?.unwrap_or_else(|| {
705            // Fallback to minimum according to SSE/AVX semantics.
706            ImmTy::from_int(dest.layout.size.signed_int_min(), dest.layout)
707        });
708        ecx.write_immediate(*res, &dest)?;
709    }
710    // Fill remainder with zeros
711    for i in op_len..dest_len {
712        let dest = ecx.project_index(&dest, i)?;
713        ecx.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
714    }
715
716    interp_ok(())
717}
718
719/// Calculates absolute value of integers in `op` and stores the result in `dest`.
720///
721/// In case of overflow (when the operand is the minimum value), the operation
722/// will wrap around.
723fn int_abs<'tcx>(
724    ecx: &mut crate::MiriInterpCx<'tcx>,
725    op: &OpTy<'tcx>,
726    dest: &MPlaceTy<'tcx>,
727) -> InterpResult<'tcx, ()> {
728    let (op, op_len) = ecx.project_to_simd(op)?;
729    let (dest, dest_len) = ecx.project_to_simd(dest)?;
730
731    assert_eq!(op_len, dest_len);
732
733    let zero = ImmTy::from_int(0, op.layout.field(ecx, 0));
734
735    for i in 0..dest_len {
736        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
737        let dest = ecx.project_index(&dest, i)?;
738
739        let lt_zero = ecx.binary_op(mir::BinOp::Lt, &op, &zero)?;
740        let res =
741            if lt_zero.to_scalar().to_bool()? { ecx.unary_op(mir::UnOp::Neg, &op)? } else { op };
742
743        ecx.write_immediate(*res, &dest)?;
744    }
745
746    interp_ok(())
747}
748
749/// Splits `op` (which must be a SIMD vector) into 128-bit chunks.
750///
751/// Returns a tuple where:
752/// * The first element is the number of 128-bit chunks (let's call it `N`).
753/// * The second element is the number of elements per chunk (let's call it `M`).
754/// * The third element is the `op` vector split into chunks, i.e, it's
755///   type is `[[T; M]; N]` where `T` is the element type of `op`.
756fn split_simd_to_128bit_chunks<'tcx, P: Projectable<'tcx, Provenance>>(
757    ecx: &mut crate::MiriInterpCx<'tcx>,
758    op: &P,
759) -> InterpResult<'tcx, (u64, u64, P)> {
760    let simd_layout = op.layout();
761    let (simd_len, element_ty) = simd_layout.ty.simd_size_and_type(ecx.tcx.tcx);
762
763    assert_eq!(simd_layout.size.bits() % 128, 0);
764    let num_chunks = simd_layout.size.bits() / 128;
765    let items_per_chunk = simd_len.strict_div(num_chunks);
766
767    // Transmute to `[[T; items_per_chunk]; num_chunks]`
768    let chunked_layout = ecx
769        .layout_of(Ty::new_array(
770            ecx.tcx.tcx,
771            Ty::new_array(ecx.tcx.tcx, element_ty, items_per_chunk),
772            num_chunks,
773        ))
774        .unwrap();
775    let chunked_op = op.transmute(chunked_layout, ecx)?;
776
777    interp_ok((num_chunks, items_per_chunk, chunked_op))
778}
779
780/// Horizontally performs `which` operation on adjacent values of
781/// `left` and `right` SIMD vectors and stores the result in `dest`.
782/// "Horizontal" means that the i-th output element is calculated
783/// from the elements 2*i and 2*i+1 of the concatenation of `left` and
784/// `right`.
785///
786/// Each 128-bit chunk is treated independently (i.e., the value for
787/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
788/// 128-bit chunks of `left` and `right`).
789fn horizontal_bin_op<'tcx>(
790    ecx: &mut crate::MiriInterpCx<'tcx>,
791    which: mir::BinOp,
792    saturating: bool,
793    left: &OpTy<'tcx>,
794    right: &OpTy<'tcx>,
795    dest: &MPlaceTy<'tcx>,
796) -> InterpResult<'tcx, ()> {
797    assert_eq!(left.layout, dest.layout);
798    assert_eq!(right.layout, dest.layout);
799
800    let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
801    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
802    let (_, _, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
803
804    let middle = items_per_chunk / 2;
805    for i in 0..num_chunks {
806        let left = ecx.project_index(&left, i)?;
807        let right = ecx.project_index(&right, i)?;
808        let dest = ecx.project_index(&dest, i)?;
809
810        for j in 0..items_per_chunk {
811            // `j` is the index in `dest`
812            // `k` is the index of the 2-item chunk in `src`
813            let (k, src) = if j < middle { (j, &left) } else { (j.strict_sub(middle), &right) };
814            // `base_i` is the index of the first item of the 2-item chunk in `src`
815            let base_i = k.strict_mul(2);
816            let lhs = ecx.read_immediate(&ecx.project_index(src, base_i)?)?;
817            let rhs = ecx.read_immediate(&ecx.project_index(src, base_i.strict_add(1))?)?;
818
819            let res = if saturating {
820                Immediate::from(ecx.saturating_arith(which, &lhs, &rhs)?)
821            } else {
822                *ecx.binary_op(which, &lhs, &rhs)?
823            };
824
825            ecx.write_immediate(res, &ecx.project_index(&dest, j)?)?;
826        }
827    }
828
829    interp_ok(())
830}
831
832/// Conditionally multiplies the packed floating-point elements in
833/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
834/// products (up to 4), and conditionally stores the sum in `dest` using
835/// the low 4 bits of `imm`.
836///
837/// Each 128-bit chunk is treated independently (i.e., the value for
838/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
839/// 128-bit blocks of `left` and `right`).
840fn conditional_dot_product<'tcx>(
841    ecx: &mut crate::MiriInterpCx<'tcx>,
842    left: &OpTy<'tcx>,
843    right: &OpTy<'tcx>,
844    imm: &OpTy<'tcx>,
845    dest: &MPlaceTy<'tcx>,
846) -> InterpResult<'tcx, ()> {
847    assert_eq!(left.layout, dest.layout);
848    assert_eq!(right.layout, dest.layout);
849
850    let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
851    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
852    let (_, _, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
853
854    let element_layout = left.layout.field(ecx, 0).field(ecx, 0);
855    assert!(items_per_chunk <= 4);
856
857    // `imm` is a `u8` for SSE4.1 or an `i32` for AVX :/
858    let imm = ecx.read_scalar(imm)?.to_uint(imm.layout.size)?;
859
860    for i in 0..num_chunks {
861        let left = ecx.project_index(&left, i)?;
862        let right = ecx.project_index(&right, i)?;
863        let dest = ecx.project_index(&dest, i)?;
864
865        // Calculate dot product
866        // Elements are floating point numbers, but we can use `from_int`
867        // for the initial value because the representation of 0.0 is all zero bits.
868        let mut sum = ImmTy::from_int(0u8, element_layout);
869        for j in 0..items_per_chunk {
870            if imm & (1 << j.strict_add(4)) != 0 {
871                let left = ecx.read_immediate(&ecx.project_index(&left, j)?)?;
872                let right = ecx.read_immediate(&ecx.project_index(&right, j)?)?;
873
874                let mul = ecx.binary_op(mir::BinOp::Mul, &left, &right)?;
875                sum = ecx.binary_op(mir::BinOp::Add, &sum, &mul)?;
876            }
877        }
878
879        // Write to destination (conditioned to imm)
880        for j in 0..items_per_chunk {
881            let dest = ecx.project_index(&dest, j)?;
882
883            if imm & (1 << j) != 0 {
884                ecx.write_immediate(*sum, &dest)?;
885            } else {
886                ecx.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
887            }
888        }
889    }
890
891    interp_ok(())
892}
893
894/// Calculates two booleans.
895///
896/// The first is true when all the bits of `op & mask` are zero.
897/// The second is true when `(op & mask) == mask`
898fn test_bits_masked<'tcx>(
899    ecx: &crate::MiriInterpCx<'tcx>,
900    op: &OpTy<'tcx>,
901    mask: &OpTy<'tcx>,
902) -> InterpResult<'tcx, (bool, bool)> {
903    assert_eq!(op.layout, mask.layout);
904
905    let (op, op_len) = ecx.project_to_simd(op)?;
906    let (mask, mask_len) = ecx.project_to_simd(mask)?;
907
908    assert_eq!(op_len, mask_len);
909
910    let mut all_zero = true;
911    let mut masked_set = true;
912    for i in 0..op_len {
913        let op = ecx.project_index(&op, i)?;
914        let mask = ecx.project_index(&mask, i)?;
915
916        let op = ecx.read_scalar(&op)?.to_uint(op.layout.size)?;
917        let mask = ecx.read_scalar(&mask)?.to_uint(mask.layout.size)?;
918        all_zero &= (op & mask) == 0;
919        masked_set &= (op & mask) == mask;
920    }
921
922    interp_ok((all_zero, masked_set))
923}
924
925/// Calculates two booleans.
926///
927/// The first is true when the highest bit of each element of `op & mask` is zero.
928/// The second is true when the highest bit of each element of `!op & mask` is zero.
929fn test_high_bits_masked<'tcx>(
930    ecx: &crate::MiriInterpCx<'tcx>,
931    op: &OpTy<'tcx>,
932    mask: &OpTy<'tcx>,
933) -> InterpResult<'tcx, (bool, bool)> {
934    assert_eq!(op.layout, mask.layout);
935
936    let (op, op_len) = ecx.project_to_simd(op)?;
937    let (mask, mask_len) = ecx.project_to_simd(mask)?;
938
939    assert_eq!(op_len, mask_len);
940
941    let high_bit_offset = op.layout.field(ecx, 0).size.bits().strict_sub(1);
942
943    let mut direct = true;
944    let mut negated = true;
945    for i in 0..op_len {
946        let op = ecx.project_index(&op, i)?;
947        let mask = ecx.project_index(&mask, i)?;
948
949        let op = ecx.read_scalar(&op)?.to_uint(op.layout.size)?;
950        let mask = ecx.read_scalar(&mask)?.to_uint(mask.layout.size)?;
951        direct &= (op & mask) >> high_bit_offset == 0;
952        negated &= (!op & mask) >> high_bit_offset == 0;
953    }
954
955    interp_ok((direct, negated))
956}
957
958/// Conditionally loads from `ptr` according the high bit of each
959/// element of `mask`. `ptr` does not need to be aligned.
960fn mask_load<'tcx>(
961    ecx: &mut crate::MiriInterpCx<'tcx>,
962    ptr: &OpTy<'tcx>,
963    mask: &OpTy<'tcx>,
964    dest: &MPlaceTy<'tcx>,
965) -> InterpResult<'tcx, ()> {
966    let (mask, mask_len) = ecx.project_to_simd(mask)?;
967    let (dest, dest_len) = ecx.project_to_simd(dest)?;
968
969    assert_eq!(dest_len, mask_len);
970
971    let mask_item_size = mask.layout.field(ecx, 0).size;
972    let high_bit_offset = mask_item_size.bits().strict_sub(1);
973
974    let ptr = ecx.read_pointer(ptr)?;
975    for i in 0..dest_len {
976        let mask = ecx.project_index(&mask, i)?;
977        let dest = ecx.project_index(&dest, i)?;
978
979        if ecx.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
980            let ptr = ptr.wrapping_offset(dest.layout.size * i, &ecx.tcx);
981            // Unaligned copy, which is what we want.
982            ecx.mem_copy(ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
983        } else {
984            ecx.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
985        }
986    }
987
988    interp_ok(())
989}
990
991/// Conditionally stores into `ptr` according the high bit of each
992/// element of `mask`. `ptr` does not need to be aligned.
993fn mask_store<'tcx>(
994    ecx: &mut crate::MiriInterpCx<'tcx>,
995    ptr: &OpTy<'tcx>,
996    mask: &OpTy<'tcx>,
997    value: &OpTy<'tcx>,
998) -> InterpResult<'tcx, ()> {
999    let (mask, mask_len) = ecx.project_to_simd(mask)?;
1000    let (value, value_len) = ecx.project_to_simd(value)?;
1001
1002    assert_eq!(value_len, mask_len);
1003
1004    let mask_item_size = mask.layout.field(ecx, 0).size;
1005    let high_bit_offset = mask_item_size.bits().strict_sub(1);
1006
1007    let ptr = ecx.read_pointer(ptr)?;
1008    for i in 0..value_len {
1009        let mask = ecx.project_index(&mask, i)?;
1010        let value = ecx.project_index(&value, i)?;
1011
1012        if ecx.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
1013            // *Non-inbounds* pointer arithmetic to compute the destination.
1014            // (That's why we can't use a place projection.)
1015            let ptr = ptr.wrapping_offset(value.layout.size * i, &ecx.tcx);
1016            // Deref the pointer *unaligned*, and do the copy.
1017            let dest = ecx.ptr_to_mplace_unaligned(ptr, value.layout);
1018            ecx.copy_op(&value, &dest)?;
1019        }
1020    }
1021
1022    interp_ok(())
1023}
1024
1025/// Compute the sum of absolute differences of quadruplets of unsigned
1026/// 8-bit integers in `left` and `right`, and store the 16-bit results
1027/// in `right`. Quadruplets are selected from `left` and `right` with
1028/// offsets specified in `imm`.
1029///
1030/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16>
1031/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mpsadbw_epu8>
1032///
1033/// Each 128-bit chunk is treated independently (i.e., the value for
1034/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1035/// 128-bit chunks of `left` and `right`).
1036fn mpsadbw<'tcx>(
1037    ecx: &mut crate::MiriInterpCx<'tcx>,
1038    left: &OpTy<'tcx>,
1039    right: &OpTy<'tcx>,
1040    imm: &OpTy<'tcx>,
1041    dest: &MPlaceTy<'tcx>,
1042) -> InterpResult<'tcx, ()> {
1043    assert_eq!(left.layout, right.layout);
1044    assert_eq!(left.layout.size, dest.layout.size);
1045
1046    let (num_chunks, op_items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
1047    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
1048    let (_, dest_items_per_chunk, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
1049
1050    assert_eq!(op_items_per_chunk, dest_items_per_chunk.strict_mul(2));
1051
1052    let imm = ecx.read_scalar(imm)?.to_uint(imm.layout.size)?;
1053    // Bit 2 of `imm` specifies the offset for indices of `left`.
1054    // The offset is 0 when the bit is 0 or 4 when the bit is 1.
1055    let left_offset = u64::try_from((imm >> 2) & 1).unwrap().strict_mul(4);
1056    // Bits 0..=1 of `imm` specify the offset for indices of
1057    // `right` in blocks of 4 elements.
1058    let right_offset = u64::try_from(imm & 0b11).unwrap().strict_mul(4);
1059
1060    for i in 0..num_chunks {
1061        let left = ecx.project_index(&left, i)?;
1062        let right = ecx.project_index(&right, i)?;
1063        let dest = ecx.project_index(&dest, i)?;
1064
1065        for j in 0..dest_items_per_chunk {
1066            let left_offset = left_offset.strict_add(j);
1067            let mut res: u16 = 0;
1068            for k in 0..4 {
1069                let left = ecx
1070                    .read_scalar(&ecx.project_index(&left, left_offset.strict_add(k))?)?
1071                    .to_u8()?;
1072                let right = ecx
1073                    .read_scalar(&ecx.project_index(&right, right_offset.strict_add(k))?)?
1074                    .to_u8()?;
1075                res = res.strict_add(left.abs_diff(right).into());
1076            }
1077            ecx.write_scalar(Scalar::from_u16(res), &ecx.project_index(&dest, j)?)?;
1078        }
1079    }
1080
1081    interp_ok(())
1082}
1083
1084/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
1085/// product to the 18 most significant bits by right-shifting, and then
1086/// divides the 18-bit value by 2 (rounding to nearest) by first adding
1087/// 1 and then taking the bits `1..=16`.
1088///
1089/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mulhrs_epi16>
1090/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mulhrs_epi16>
1091fn pmulhrsw<'tcx>(
1092    ecx: &mut crate::MiriInterpCx<'tcx>,
1093    left: &OpTy<'tcx>,
1094    right: &OpTy<'tcx>,
1095    dest: &MPlaceTy<'tcx>,
1096) -> InterpResult<'tcx, ()> {
1097    let (left, left_len) = ecx.project_to_simd(left)?;
1098    let (right, right_len) = ecx.project_to_simd(right)?;
1099    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1100
1101    assert_eq!(dest_len, left_len);
1102    assert_eq!(dest_len, right_len);
1103
1104    for i in 0..dest_len {
1105        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?.to_i16()?;
1106        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_i16()?;
1107        let dest = ecx.project_index(&dest, i)?;
1108
1109        let res = (i32::from(left).strict_mul(right.into()) >> 14).strict_add(1) >> 1;
1110
1111        // The result of this operation can overflow a signed 16-bit integer.
1112        // When `left` and `right` are -0x8000, the result is 0x8000.
1113        #[expect(clippy::cast_possible_truncation)]
1114        let res = res as i16;
1115
1116        ecx.write_scalar(Scalar::from_i16(res), &dest)?;
1117    }
1118
1119    interp_ok(())
1120}
1121
1122/// Perform a carry-less multiplication of two 64-bit integers, selected from `left` and `right` according to `imm8`,
1123/// and store the results in `dst`.
1124///
1125/// `left` and `right` are both vectors of type `len` x i64. Only bits 0 and 4 of `imm8` matter;
1126/// they select the element of `left` and `right`, respectively.
1127///
1128/// `len` is the SIMD vector length (in counts of `i64` values). It is expected to be one of
1129/// `2`, `4`, or `8`.
1130///
1131/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_clmulepi64_si128>
1132fn pclmulqdq<'tcx>(
1133    ecx: &mut MiriInterpCx<'tcx>,
1134    left: &OpTy<'tcx>,
1135    right: &OpTy<'tcx>,
1136    imm8: &OpTy<'tcx>,
1137    dest: &MPlaceTy<'tcx>,
1138    len: u64,
1139) -> InterpResult<'tcx, ()> {
1140    assert_eq!(left.layout, right.layout);
1141    assert_eq!(left.layout.size, dest.layout.size);
1142    assert!([2u64, 4, 8].contains(&len));
1143
1144    // Transmute the input into arrays of `[u64; len]`.
1145    // Transmute the output into an array of `[u128, len / 2]`.
1146
1147    let src_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u64, len))?;
1148    let dest_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u128, len / 2))?;
1149
1150    let left = left.transmute(src_layout, ecx)?;
1151    let right = right.transmute(src_layout, ecx)?;
1152    let dest = dest.transmute(dest_layout, ecx)?;
1153
1154    let imm8 = ecx.read_scalar(imm8)?.to_u8()?;
1155
1156    for i in 0..(len / 2) {
1157        let lo = i.strict_mul(2);
1158        let hi = i.strict_mul(2).strict_add(1);
1159
1160        // select the 64-bit integer from left that the user specified (low or high)
1161        let index = if (imm8 & 0x01) == 0 { lo } else { hi };
1162        let left = ecx.read_scalar(&ecx.project_index(&left, index)?)?.to_u64()?;
1163
1164        // select the 64-bit integer from right that the user specified (low or high)
1165        let index = if (imm8 & 0x10) == 0 { lo } else { hi };
1166        let right = ecx.read_scalar(&ecx.project_index(&right, index)?)?.to_u64()?;
1167
1168        // Perform carry-less multiplication.
1169        //
1170        // This operation is like long multiplication, but ignores all carries.
1171        // That idea corresponds to the xor operator, which is used in the implementation.
1172        //
1173        // Wikipedia has an example https://en.wikipedia.org/wiki/Carry-less_product#Example
1174        let mut result: u128 = 0;
1175
1176        for i in 0..64 {
1177            // if the i-th bit in right is set
1178            if (right & (1 << i)) != 0 {
1179                // xor result with `left` shifted to the left by i positions
1180                result ^= u128::from(left) << i;
1181            }
1182        }
1183
1184        let dest = ecx.project_index(&dest, i)?;
1185        ecx.write_scalar(Scalar::from_u128(result), &dest)?;
1186    }
1187
1188    interp_ok(())
1189}
1190
1191/// Packs two N-bit integer vectors to a single N/2-bit integers.
1192///
1193/// The conversion from N-bit to N/2-bit should be provided by `f`.
1194///
1195/// Each 128-bit chunk is treated independently (i.e., the value for
1196/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1197/// 128-bit chunks of `left` and `right`).
1198fn pack_generic<'tcx>(
1199    ecx: &mut crate::MiriInterpCx<'tcx>,
1200    left: &OpTy<'tcx>,
1201    right: &OpTy<'tcx>,
1202    dest: &MPlaceTy<'tcx>,
1203    f: impl Fn(Scalar) -> InterpResult<'tcx, Scalar>,
1204) -> InterpResult<'tcx, ()> {
1205    assert_eq!(left.layout, right.layout);
1206    assert_eq!(left.layout.size, dest.layout.size);
1207
1208    let (num_chunks, op_items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
1209    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
1210    let (_, dest_items_per_chunk, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
1211
1212    assert_eq!(dest_items_per_chunk, op_items_per_chunk.strict_mul(2));
1213
1214    for i in 0..num_chunks {
1215        let left = ecx.project_index(&left, i)?;
1216        let right = ecx.project_index(&right, i)?;
1217        let dest = ecx.project_index(&dest, i)?;
1218
1219        for j in 0..op_items_per_chunk {
1220            let left = ecx.read_scalar(&ecx.project_index(&left, j)?)?;
1221            let right = ecx.read_scalar(&ecx.project_index(&right, j)?)?;
1222            let left_dest = ecx.project_index(&dest, j)?;
1223            let right_dest = ecx.project_index(&dest, j.strict_add(op_items_per_chunk))?;
1224
1225            let left_res = f(left)?;
1226            let right_res = f(right)?;
1227
1228            ecx.write_scalar(left_res, &left_dest)?;
1229            ecx.write_scalar(right_res, &right_dest)?;
1230        }
1231    }
1232
1233    interp_ok(())
1234}
1235
1236/// Converts two 16-bit integer vectors to a single 8-bit integer
1237/// vector with signed saturation.
1238///
1239/// Each 128-bit chunk is treated independently (i.e., the value for
1240/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1241/// 128-bit chunks of `left` and `right`).
1242fn packsswb<'tcx>(
1243    ecx: &mut crate::MiriInterpCx<'tcx>,
1244    left: &OpTy<'tcx>,
1245    right: &OpTy<'tcx>,
1246    dest: &MPlaceTy<'tcx>,
1247) -> InterpResult<'tcx, ()> {
1248    pack_generic(ecx, left, right, dest, |op| {
1249        let op = op.to_i16()?;
1250        let res = i8::try_from(op).unwrap_or(if op < 0 { i8::MIN } else { i8::MAX });
1251        interp_ok(Scalar::from_i8(res))
1252    })
1253}
1254
1255/// Converts two 16-bit signed integer vectors to a single 8-bit
1256/// unsigned integer vector with saturation.
1257///
1258/// Each 128-bit chunk is treated independently (i.e., the value for
1259/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1260/// 128-bit chunks of `left` and `right`).
1261fn packuswb<'tcx>(
1262    ecx: &mut crate::MiriInterpCx<'tcx>,
1263    left: &OpTy<'tcx>,
1264    right: &OpTy<'tcx>,
1265    dest: &MPlaceTy<'tcx>,
1266) -> InterpResult<'tcx, ()> {
1267    pack_generic(ecx, left, right, dest, |op| {
1268        let op = op.to_i16()?;
1269        let res = u8::try_from(op).unwrap_or(if op < 0 { 0 } else { u8::MAX });
1270        interp_ok(Scalar::from_u8(res))
1271    })
1272}
1273
1274/// Converts two 32-bit integer vectors to a single 16-bit integer
1275/// vector with signed saturation.
1276///
1277/// Each 128-bit chunk is treated independently (i.e., the value for
1278/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1279/// 128-bit chunks of `left` and `right`).
1280fn packssdw<'tcx>(
1281    ecx: &mut crate::MiriInterpCx<'tcx>,
1282    left: &OpTy<'tcx>,
1283    right: &OpTy<'tcx>,
1284    dest: &MPlaceTy<'tcx>,
1285) -> InterpResult<'tcx, ()> {
1286    pack_generic(ecx, left, right, dest, |op| {
1287        let op = op.to_i32()?;
1288        let res = i16::try_from(op).unwrap_or(if op < 0 { i16::MIN } else { i16::MAX });
1289        interp_ok(Scalar::from_i16(res))
1290    })
1291}
1292
1293/// Converts two 32-bit integer vectors to a single 16-bit integer
1294/// vector with unsigned saturation.
1295///
1296/// Each 128-bit chunk is treated independently (i.e., the value for
1297/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1298/// 128-bit chunks of `left` and `right`).
1299fn packusdw<'tcx>(
1300    ecx: &mut crate::MiriInterpCx<'tcx>,
1301    left: &OpTy<'tcx>,
1302    right: &OpTy<'tcx>,
1303    dest: &MPlaceTy<'tcx>,
1304) -> InterpResult<'tcx, ()> {
1305    pack_generic(ecx, left, right, dest, |op| {
1306        let op = op.to_i32()?;
1307        let res = u16::try_from(op).unwrap_or(if op < 0 { 0 } else { u16::MAX });
1308        interp_ok(Scalar::from_u16(res))
1309    })
1310}
1311
1312/// Negates elements from `left` when the corresponding element in
1313/// `right` is negative. If an element from `right` is zero, zero
1314/// is written to the corresponding output element.
1315/// In other words, multiplies `left` with `right.signum()`.
1316fn psign<'tcx>(
1317    ecx: &mut crate::MiriInterpCx<'tcx>,
1318    left: &OpTy<'tcx>,
1319    right: &OpTy<'tcx>,
1320    dest: &MPlaceTy<'tcx>,
1321) -> InterpResult<'tcx, ()> {
1322    let (left, left_len) = ecx.project_to_simd(left)?;
1323    let (right, right_len) = ecx.project_to_simd(right)?;
1324    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1325
1326    assert_eq!(dest_len, left_len);
1327    assert_eq!(dest_len, right_len);
1328
1329    for i in 0..dest_len {
1330        let dest = ecx.project_index(&dest, i)?;
1331        let left = ecx.read_immediate(&ecx.project_index(&left, i)?)?;
1332        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_int(dest.layout.size)?;
1333
1334        let res =
1335            ecx.binary_op(mir::BinOp::Mul, &left, &ImmTy::from_int(right.signum(), dest.layout))?;
1336
1337        ecx.write_immediate(*res, &dest)?;
1338    }
1339
1340    interp_ok(())
1341}
1342
1343/// Calcultates either `a + b + cb_in` or `a - b - cb_in` depending on the value
1344/// of `op` and returns both the sum and the overflow bit. `op` is expected to be
1345/// either one of `mir::BinOp::AddWithOverflow` and `mir::BinOp::SubWithOverflow`.
1346fn carrying_add<'tcx>(
1347    ecx: &mut crate::MiriInterpCx<'tcx>,
1348    cb_in: &OpTy<'tcx>,
1349    a: &OpTy<'tcx>,
1350    b: &OpTy<'tcx>,
1351    op: mir::BinOp,
1352) -> InterpResult<'tcx, (ImmTy<'tcx>, Scalar)> {
1353    assert!(op == mir::BinOp::AddWithOverflow || op == mir::BinOp::SubWithOverflow);
1354
1355    let cb_in = ecx.read_scalar(cb_in)?.to_u8()? != 0;
1356    let a = ecx.read_immediate(a)?;
1357    let b = ecx.read_immediate(b)?;
1358
1359    let (sum, overflow1) = ecx.binary_op(op, &a, &b)?.to_pair(ecx);
1360    let (sum, overflow2) =
1361        ecx.binary_op(op, &sum, &ImmTy::from_uint(cb_in, a.layout))?.to_pair(ecx);
1362    let cb_out = overflow1.to_scalar().to_bool()? | overflow2.to_scalar().to_bool()?;
1363
1364    interp_ok((sum, Scalar::from_u8(cb_out.into())))
1365}