miri/shims/x86/
mod.rs

1use rustc_abi::{CanonAbi, FieldIdx, Size};
2use rustc_apfloat::Float;
3use rustc_apfloat::ieee::Single;
4use rustc_middle::ty::Ty;
5use rustc_middle::{mir, ty};
6use rustc_span::Symbol;
7use rustc_target::callconv::FnAbi;
8use rustc_target::spec::Arch;
9
10use self::helpers::bool_to_simd_element;
11use crate::*;
12
13mod aesni;
14mod avx;
15mod avx2;
16mod avx512;
17mod bmi;
18mod gfni;
19mod sha;
20mod sse;
21mod sse2;
22mod sse3;
23mod sse41;
24mod sse42;
25mod ssse3;
26
// Empty impl: `emulate_x86_intrinsic` has a default body in the trait, so this
// line is all that is needed to make it callable on the Miri interpreter context.
impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
/// Entry point for emulating `llvm.x86.*` intrinsics.
pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
    /// Emulates the x86 LLVM intrinsic named by `link_name`, writing its result to `dest`.
    ///
    /// `link_name` must start with `llvm.x86.`; the caller is expected to have checked
    /// that already, and this function panics if the prefix is missing. A handful of
    /// intrinsics are emulated right here, all others are forwarded to the
    /// per-extension submodule selected by the name that follows the prefix.
    ///
    /// Returns `EmulateItemResult::NeedsReturn` once an intrinsic handled directly in
    /// this function has been emulated, `EmulateItemResult::NotSupported` when the
    /// name is not recognized (or a 64-bit carry/borrow intrinsic is requested on a
    /// target other than `x86_64`), and otherwise whatever the submodule returns.
    fn emulate_x86_intrinsic(
        &mut self,
        link_name: Symbol,
        abi: &FnAbi<'tcx, Ty<'tcx>>,
        args: &[OpTy<'tcx>],
        dest: &MPlaceTy<'tcx>,
    ) -> InterpResult<'tcx, EmulateItemResult> {
        let this = self.eval_context_mut();
        // Prefix should have already been checked.
        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.").unwrap();
        // Match arms are tried in order: the exact-name arms (e.g. `sse2.pause`) must
        // stay above the prefix-based arms that would otherwise capture them.
        match unprefixed_name {
            // Used to implement the `_addcarry_u{32, 64}` and the `_subborrow_u{32, 64}` functions.
            // Computes a + b or a - b with input and output carry/borrow. The input carry/borrow is an 8-bit
            // value, which is interpreted as 1 if it is non-zero. The output carry/borrow is an 8-bit value that will be 0 or 1.
            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/addcarry-u32-addcarry-u64.html
            // https://www.intel.com/content/www/us/en/docs/cpp-compiler/developer-guide-reference/2021-8/subborrow-u32-subborrow-u64.html
            "addcarry.32" | "addcarry.64" | "subborrow.32" | "subborrow.64" => {
                // The 64-bit variants are only supported on x86_64 targets.
                if unprefixed_name.ends_with("64") && this.tcx.sess.target.arch != Arch::X86_64 {
                    return interp_ok(EmulateItemResult::NotSupported);
                }

                let [cb_in, a, b] =
                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
                let op = if unprefixed_name.starts_with("add") {
                    mir::BinOp::AddWithOverflow
                } else {
                    mir::BinOp::SubWithOverflow
                };

                // `dest` is a two-field aggregate: field 0 receives the output
                // carry/borrow, field 1 the sum/difference.
                let (sum, cb_out) = carrying_add(this, cb_in, a, b, op)?;
                this.write_scalar(cb_out, &this.project_field(dest, FieldIdx::ZERO)?)?;
                this.write_immediate(*sum, &this.project_field(dest, FieldIdx::ONE)?)?;
            }

            // Used to implement the `_mm_pause` function.
            // The intrinsic is used to hint the processor that the code is in a spin-loop.
            // It is compiled down to a `pause` instruction. When SSE2 is not available,
            // the instruction behaves like a no-op, so it is always safe to call the
            // intrinsic.
            "sse2.pause" => {
                let [] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
                // Only exhibit the spin-loop hint behavior when SSE2 is enabled.
                if this.tcx.sess.unstable_target_features.contains(&Symbol::intern("sse2")) {
                    this.yield_active_thread();
                }
            }

            // Handled by the `pclmulqdq` helper. All variants require the `pclmulqdq`
            // target feature; the 256-bit and 512-bit variants additionally require
            // `vpclmulqdq`, and the 512-bit variant also `avx512f`.
            "pclmulqdq" | "pclmulqdq.256" | "pclmulqdq.512" => {
                let mut len = 2; // in units of 64bits
                this.expect_target_feature_for_intrinsic(link_name, "pclmulqdq")?;
                if unprefixed_name.ends_with(".256") {
                    this.expect_target_feature_for_intrinsic(link_name, "vpclmulqdq")?;
                    len = 4;
                } else if unprefixed_name.ends_with(".512") {
                    this.expect_target_feature_for_intrinsic(link_name, "vpclmulqdq")?;
                    this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
                    len = 8;
                }

                let [left, right, imm] =
                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;

                pclmulqdq(this, left, right, imm, dest, len)?;
            }

            // Everything below dispatches on the name prefix and returns the
            // submodule's result directly.
            name if name.starts_with("bmi.") => {
                return bmi::EvalContextExt::emulate_x86_bmi_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            // The GFNI extension does not get its own namespace.
            // Check for instruction names instead.
            name if name.starts_with("vgf2p8affine") || name.starts_with("vgf2p8mulb") => {
                return gfni::EvalContextExt::emulate_x86_gfni_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("sha") => {
                return sha::EvalContextExt::emulate_x86_sha_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("sse.") => {
                return sse::EvalContextExt::emulate_x86_sse_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("sse2.") => {
                return sse2::EvalContextExt::emulate_x86_sse2_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("sse3.") => {
                return sse3::EvalContextExt::emulate_x86_sse3_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("ssse3.") => {
                return ssse3::EvalContextExt::emulate_x86_ssse3_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("sse41.") => {
                return sse41::EvalContextExt::emulate_x86_sse41_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("sse42.") => {
                return sse42::EvalContextExt::emulate_x86_sse42_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("aesni.") => {
                return aesni::EvalContextExt::emulate_x86_aesni_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("avx.") => {
                return avx::EvalContextExt::emulate_x86_avx_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("avx2.") => {
                return avx2::EvalContextExt::emulate_x86_avx2_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }
            name if name.starts_with("avx512.") => {
                return avx512::EvalContextExt::emulate_x86_avx512_intrinsic(
                    this, link_name, abi, args, dest,
                );
            }

            _ => return interp_ok(EmulateItemResult::NotSupported),
        }
        // Only reached by the arms above that emulate the intrinsic in place.
        interp_ok(EmulateItemResult::NeedsReturn)
    }
}
167
/// Binary floating-point operations with x86 (SSE/AVX) semantics.
#[derive(Copy, Clone)]
enum FloatBinOp {
    /// Comparison
    ///
    /// The semantics of this operator is a case distinction: we compare the two operands,
    /// and then we return one of the four booleans `gt`, `lt`, `eq`, `unord` depending on
    /// which class they fall into.
    ///
    /// AVX supports all 16 combinations, SSE only a subset.
    ///
    /// <https://www.felixcloutier.com/x86/cmpss>
    /// <https://www.felixcloutier.com/x86/cmpps>
    /// <https://www.felixcloutier.com/x86/cmpsd>
    /// <https://www.felixcloutier.com/x86/cmppd>
    Cmp {
        /// Result when lhs > rhs
        gt: bool,
        /// Result when lhs < rhs
        lt: bool,
        /// Result when lhs == rhs
        eq: bool,
        /// Result when lhs is NaN or rhs is NaN
        unord: bool,
    },
    /// Minimum value (with SSE semantics)
    ///
    /// <https://www.felixcloutier.com/x86/minss>
    /// <https://www.felixcloutier.com/x86/minps>
    /// <https://www.felixcloutier.com/x86/minsd>
    /// <https://www.felixcloutier.com/x86/minpd>
    Min,
    /// Maximum value (with SSE semantics)
    ///
    /// <https://www.felixcloutier.com/x86/maxss>
    /// <https://www.felixcloutier.com/x86/maxps>
    /// <https://www.felixcloutier.com/x86/maxsd>
    /// <https://www.felixcloutier.com/x86/maxpd>
    Max,
}
207
208impl FloatBinOp {
209    /// Convert from the `imm` argument used to specify the comparison
210    /// operation in intrinsics such as `llvm.x86.sse.cmp.ss`.
211    fn cmp_from_imm<'tcx>(
212        ecx: &crate::MiriInterpCx<'tcx>,
213        imm: i8,
214        intrinsic: Symbol,
215    ) -> InterpResult<'tcx, Self> {
216        // Only bits 0..=4 are used, remaining should be zero.
217        if imm & !0b1_1111 != 0 {
218            panic!("invalid `imm` parameter of {intrinsic}: 0x{imm:x}");
219        }
220        // Bit 4 specifies whether the operation is quiet or signaling, which
221        // we do not care in Miri.
222        // Bits 0..=2 specifies the operation.
223        // `gt` indicates the result to be returned when the LHS is strictly
224        // greater than the RHS, and so on.
225        let (gt, lt, eq, mut unord) = match imm & 0b111 {
226            // Equal
227            0x0 => (false, false, true, false),
228            // Less-than
229            0x1 => (false, true, false, false),
230            // Less-or-equal
231            0x2 => (false, true, true, false),
232            // Unordered (either is NaN)
233            0x3 => (false, false, false, true),
234            // Not equal
235            0x4 => (true, true, false, true),
236            // Not less-than
237            0x5 => (true, false, true, true),
238            // Not less-or-equal
239            0x6 => (true, false, false, true),
240            // Ordered (neither is NaN)
241            0x7 => (true, true, true, false),
242            _ => unreachable!(),
243        };
244        // When bit 3 is 1 (only possible in AVX), unord is toggled.
245        if imm & 0b1000 != 0 {
246            ecx.expect_target_feature_for_intrinsic(intrinsic, "avx")?;
247            unord = !unord;
248        }
249        interp_ok(Self::Cmp { gt, lt, eq, unord })
250    }
251}
252
253/// Performs `which` scalar operation on `left` and `right` and returns
254/// the result.
255fn bin_op_float<'tcx, F: rustc_apfloat::Float>(
256    which: FloatBinOp,
257    left: &ImmTy<'tcx>,
258    right: &ImmTy<'tcx>,
259) -> InterpResult<'tcx, Scalar> {
260    match which {
261        FloatBinOp::Cmp { gt, lt, eq, unord } => {
262            let left = left.to_scalar().to_float::<F>()?;
263            let right = right.to_scalar().to_float::<F>()?;
264
265            let res = match left.partial_cmp(&right) {
266                None => unord,
267                Some(std::cmp::Ordering::Less) => lt,
268                Some(std::cmp::Ordering::Equal) => eq,
269                Some(std::cmp::Ordering::Greater) => gt,
270            };
271            interp_ok(bool_to_simd_element(res, Size::from_bits(F::BITS)))
272        }
273        FloatBinOp::Min => {
274            let left_scalar = left.to_scalar();
275            let left = left_scalar.to_float::<F>()?;
276            let right_scalar = right.to_scalar();
277            let right = right_scalar.to_float::<F>()?;
278            // SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
279            // is true when `x` is either +0 or -0.
280            if (left == F::ZERO && right == F::ZERO)
281                || left.is_nan()
282                || right.is_nan()
283                || left >= right
284            {
285                interp_ok(right_scalar)
286            } else {
287                interp_ok(left_scalar)
288            }
289        }
290        FloatBinOp::Max => {
291            let left_scalar = left.to_scalar();
292            let left = left_scalar.to_float::<F>()?;
293            let right_scalar = right.to_scalar();
294            let right = right_scalar.to_float::<F>()?;
295            // SSE semantics to handle zero and NaN. Note that `x == F::ZERO`
296            // is true when `x` is either +0 or -0.
297            if (left == F::ZERO && right == F::ZERO)
298                || left.is_nan()
299                || right.is_nan()
300                || left <= right
301            {
302                interp_ok(right_scalar)
303            } else {
304                interp_ok(left_scalar)
305            }
306        }
307    }
308}
309
310/// Performs `which` operation on the first component of `left` and `right`
311/// and copies the other components from `left`. The result is stored in `dest`.
312fn bin_op_simd_float_first<'tcx, F: rustc_apfloat::Float>(
313    ecx: &mut crate::MiriInterpCx<'tcx>,
314    which: FloatBinOp,
315    left: &OpTy<'tcx>,
316    right: &OpTy<'tcx>,
317    dest: &MPlaceTy<'tcx>,
318) -> InterpResult<'tcx, ()> {
319    let (left, left_len) = ecx.project_to_simd(left)?;
320    let (right, right_len) = ecx.project_to_simd(right)?;
321    let (dest, dest_len) = ecx.project_to_simd(dest)?;
322
323    assert_eq!(dest_len, left_len);
324    assert_eq!(dest_len, right_len);
325
326    let res0 = bin_op_float::<F>(
327        which,
328        &ecx.read_immediate(&ecx.project_index(&left, 0)?)?,
329        &ecx.read_immediate(&ecx.project_index(&right, 0)?)?,
330    )?;
331    ecx.write_scalar(res0, &ecx.project_index(&dest, 0)?)?;
332
333    for i in 1..dest_len {
334        ecx.copy_op(&ecx.project_index(&left, i)?, &ecx.project_index(&dest, i)?)?;
335    }
336
337    interp_ok(())
338}
339
340/// Performs `which` operation on each component of `left` and
341/// `right`, storing the result is stored in `dest`.
342fn bin_op_simd_float_all<'tcx, F: rustc_apfloat::Float>(
343    ecx: &mut crate::MiriInterpCx<'tcx>,
344    which: FloatBinOp,
345    left: &OpTy<'tcx>,
346    right: &OpTy<'tcx>,
347    dest: &MPlaceTy<'tcx>,
348) -> InterpResult<'tcx, ()> {
349    let (left, left_len) = ecx.project_to_simd(left)?;
350    let (right, right_len) = ecx.project_to_simd(right)?;
351    let (dest, dest_len) = ecx.project_to_simd(dest)?;
352
353    assert_eq!(dest_len, left_len);
354    assert_eq!(dest_len, right_len);
355
356    for i in 0..dest_len {
357        let left = ecx.read_immediate(&ecx.project_index(&left, i)?)?;
358        let right = ecx.read_immediate(&ecx.project_index(&right, i)?)?;
359        let dest = ecx.project_index(&dest, i)?;
360
361        let res = bin_op_float::<F>(which, &left, &right)?;
362        ecx.write_scalar(res, &dest)?;
363    }
364
365    interp_ok(())
366}
367
/// Unary single-precision operations that are emulated as approximations
/// (see `unary_op_f32`, which perturbs the exact result with a random error).
#[derive(Copy, Clone)]
enum FloatUnaryOp {
    /// Approximation of 1/x
    ///
    /// <https://www.felixcloutier.com/x86/rcpss>
    /// <https://www.felixcloutier.com/x86/rcpps>
    Rcp,
    /// Approximation of 1/sqrt(x)
    ///
    /// <https://www.felixcloutier.com/x86/rsqrtss>
    /// <https://www.felixcloutier.com/x86/rsqrtps>
    Rsqrt,
}
381
382/// Performs `which` scalar operation on `op` and returns the result.
383fn unary_op_f32<'tcx>(
384    ecx: &mut crate::MiriInterpCx<'tcx>,
385    which: FloatUnaryOp,
386    op: &ImmTy<'tcx>,
387) -> InterpResult<'tcx, Scalar> {
388    match which {
389        FloatUnaryOp::Rcp => {
390            let op = op.to_scalar().to_f32()?;
391            let div = (Single::from_u128(1).value / op).value;
392            // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
393            // inaccuracy of RCP.
394            let res = math::apply_random_float_error(ecx, div, -12);
395            interp_ok(Scalar::from_f32(res))
396        }
397        FloatUnaryOp::Rsqrt => {
398            let op = op.to_scalar().to_f32()?;
399            let rsqrt = (Single::from_u128(1).value / math::sqrt(op)).value;
400            // Apply a relative error with a magnitude on the order of 2^-12 to simulate the
401            // inaccuracy of RSQRT.
402            let res = math::apply_random_float_error(ecx, rsqrt, -12);
403            interp_ok(Scalar::from_f32(res))
404        }
405    }
406}
407
408/// Performs `which` operation on the first component of `op` and copies
409/// the other components. The result is stored in `dest`.
410fn unary_op_ss<'tcx>(
411    ecx: &mut crate::MiriInterpCx<'tcx>,
412    which: FloatUnaryOp,
413    op: &OpTy<'tcx>,
414    dest: &MPlaceTy<'tcx>,
415) -> InterpResult<'tcx, ()> {
416    let (op, op_len) = ecx.project_to_simd(op)?;
417    let (dest, dest_len) = ecx.project_to_simd(dest)?;
418
419    assert_eq!(dest_len, op_len);
420
421    let res0 = unary_op_f32(ecx, which, &ecx.read_immediate(&ecx.project_index(&op, 0)?)?)?;
422    ecx.write_scalar(res0, &ecx.project_index(&dest, 0)?)?;
423
424    for i in 1..dest_len {
425        ecx.copy_op(&ecx.project_index(&op, i)?, &ecx.project_index(&dest, i)?)?;
426    }
427
428    interp_ok(())
429}
430
431/// Performs `which` operation on each component of `op`, storing the
432/// result is stored in `dest`.
433fn unary_op_ps<'tcx>(
434    ecx: &mut crate::MiriInterpCx<'tcx>,
435    which: FloatUnaryOp,
436    op: &OpTy<'tcx>,
437    dest: &MPlaceTy<'tcx>,
438) -> InterpResult<'tcx, ()> {
439    let (op, op_len) = ecx.project_to_simd(op)?;
440    let (dest, dest_len) = ecx.project_to_simd(dest)?;
441
442    assert_eq!(dest_len, op_len);
443
444    for i in 0..dest_len {
445        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
446        let dest = ecx.project_index(&dest, i)?;
447
448        let res = unary_op_f32(ecx, which, &op)?;
449        ecx.write_scalar(res, &dest)?;
450    }
451
452    interp_ok(())
453}
454
/// Kind of bit shift performed by `shift_simd_by_scalar` and `shift_simd_by_simd`.
enum ShiftOp {
    /// Shift left, logically (shift in zeros) -- same as shift left, arithmetically
    Left,
    /// Shift right, logically (shift in zeros)
    RightLogic,
    /// Shift right, arithmetically (shift in sign)
    RightArith,
}
463
464/// Shifts each element of `left` by a scalar amount. The shift amount
465/// is determined by the lowest 64 bits of `right` (which is a 128-bit vector).
466///
467/// For logic shifts, when right is larger than BITS - 1, zero is produced.
468/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
469/// bit is copied to all bits.
470fn shift_simd_by_scalar<'tcx>(
471    ecx: &mut crate::MiriInterpCx<'tcx>,
472    left: &OpTy<'tcx>,
473    right: &OpTy<'tcx>,
474    which: ShiftOp,
475    dest: &MPlaceTy<'tcx>,
476) -> InterpResult<'tcx, ()> {
477    let (left, left_len) = ecx.project_to_simd(left)?;
478    let (dest, dest_len) = ecx.project_to_simd(dest)?;
479
480    assert_eq!(dest_len, left_len);
481    // `right` may have a different length, and we only care about its
482    // lowest 64bit anyway.
483
484    // Get the 64-bit shift operand and convert it to the type expected
485    // by checked_{shl,shr} (u32).
486    // It is ok to saturate the value to u32::MAX because any value
487    // above BITS - 1 will produce the same result.
488    let shift = u32::try_from(extract_first_u64(ecx, right)?).unwrap_or(u32::MAX);
489
490    for i in 0..dest_len {
491        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?;
492        let dest = ecx.project_index(&dest, i)?;
493
494        let res = match which {
495            ShiftOp::Left => {
496                let left = left.to_uint(dest.layout.size)?;
497                let res = left.checked_shl(shift).unwrap_or(0);
498                // `truncate` is needed as left-shift can make the absolute value larger.
499                Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
500            }
501            ShiftOp::RightLogic => {
502                let left = left.to_uint(dest.layout.size)?;
503                let res = left.checked_shr(shift).unwrap_or(0);
504                // No `truncate` needed as right-shift can only make the absolute value smaller.
505                Scalar::from_uint(res, dest.layout.size)
506            }
507            ShiftOp::RightArith => {
508                let left = left.to_int(dest.layout.size)?;
509                // On overflow, copy the sign bit to the remaining bits
510                let res = left.checked_shr(shift).unwrap_or(left >> 127);
511                // No `truncate` needed as right-shift can only make the absolute value smaller.
512                Scalar::from_int(res, dest.layout.size)
513            }
514        };
515        ecx.write_scalar(res, &dest)?;
516    }
517
518    interp_ok(())
519}
520
521/// Shifts each element of `left` by the corresponding element of `right`.
522///
523/// For logic shifts, when right is larger than BITS - 1, zero is produced.
524/// For arithmetic right-shifts, when right is larger than BITS - 1, the sign
525/// bit is copied to all bits.
526fn shift_simd_by_simd<'tcx>(
527    ecx: &mut crate::MiriInterpCx<'tcx>,
528    left: &OpTy<'tcx>,
529    right: &OpTy<'tcx>,
530    which: ShiftOp,
531    dest: &MPlaceTy<'tcx>,
532) -> InterpResult<'tcx, ()> {
533    let (left, left_len) = ecx.project_to_simd(left)?;
534    let (right, right_len) = ecx.project_to_simd(right)?;
535    let (dest, dest_len) = ecx.project_to_simd(dest)?;
536
537    assert_eq!(dest_len, left_len);
538    assert_eq!(dest_len, right_len);
539
540    for i in 0..dest_len {
541        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?;
542        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?;
543        let dest = ecx.project_index(&dest, i)?;
544
545        // It is ok to saturate the value to u32::MAX because any value
546        // above BITS - 1 will produce the same result.
547        let shift = u32::try_from(right.to_uint(dest.layout.size)?).unwrap_or(u32::MAX);
548
549        let res = match which {
550            ShiftOp::Left => {
551                let left = left.to_uint(dest.layout.size)?;
552                let res = left.checked_shl(shift).unwrap_or(0);
553                // `truncate` is needed as left-shift can make the absolute value larger.
554                Scalar::from_uint(dest.layout.size.truncate(res), dest.layout.size)
555            }
556            ShiftOp::RightLogic => {
557                let left = left.to_uint(dest.layout.size)?;
558                let res = left.checked_shr(shift).unwrap_or(0);
559                // No `truncate` needed as right-shift can only make the absolute value smaller.
560                Scalar::from_uint(res, dest.layout.size)
561            }
562            ShiftOp::RightArith => {
563                let left = left.to_int(dest.layout.size)?;
564                // On overflow, copy the sign bit to the remaining bits
565                let res = left.checked_shr(shift).unwrap_or(left >> 127);
566                // No `truncate` needed as right-shift can only make the absolute value smaller.
567                Scalar::from_int(res, dest.layout.size)
568            }
569        };
570        ecx.write_scalar(res, &dest)?;
571    }
572
573    interp_ok(())
574}
575
576/// Takes a 128-bit vector, transmutes it to `[u64; 2]` and extracts
577/// the first value.
578fn extract_first_u64<'tcx>(
579    ecx: &crate::MiriInterpCx<'tcx>,
580    op: &OpTy<'tcx>,
581) -> InterpResult<'tcx, u64> {
582    // Transmute vector to `[u64; 2]`
583    let array_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u64, 2))?;
584    let op = op.transmute(array_layout, ecx)?;
585
586    // Get the first u64 from the array
587    ecx.read_scalar(&ecx.project_index(&op, 0)?)?.to_u64()
588}
589
590// Rounds the first element of `right` according to `rounding`
591// and copies the remaining elements from `left`.
592fn round_first<'tcx, F: rustc_apfloat::Float>(
593    ecx: &mut crate::MiriInterpCx<'tcx>,
594    left: &OpTy<'tcx>,
595    right: &OpTy<'tcx>,
596    rounding: &OpTy<'tcx>,
597    dest: &MPlaceTy<'tcx>,
598) -> InterpResult<'tcx, ()> {
599    let (left, left_len) = ecx.project_to_simd(left)?;
600    let (right, right_len) = ecx.project_to_simd(right)?;
601    let (dest, dest_len) = ecx.project_to_simd(dest)?;
602
603    assert_eq!(dest_len, left_len);
604    assert_eq!(dest_len, right_len);
605
606    let rounding = rounding_from_imm(ecx.read_scalar(rounding)?.to_i32()?)?;
607
608    let op0: F = ecx.read_scalar(&ecx.project_index(&right, 0)?)?.to_float()?;
609    let res = op0.round_to_integral(rounding).value;
610    ecx.write_scalar(
611        Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
612        &ecx.project_index(&dest, 0)?,
613    )?;
614
615    for i in 1..dest_len {
616        ecx.copy_op(&ecx.project_index(&left, i)?, &ecx.project_index(&dest, i)?)?;
617    }
618
619    interp_ok(())
620}
621
622// Rounds all elements of `op` according to `rounding`.
623fn round_all<'tcx, F: rustc_apfloat::Float>(
624    ecx: &mut crate::MiriInterpCx<'tcx>,
625    op: &OpTy<'tcx>,
626    rounding: &OpTy<'tcx>,
627    dest: &MPlaceTy<'tcx>,
628) -> InterpResult<'tcx, ()> {
629    let (op, op_len) = ecx.project_to_simd(op)?;
630    let (dest, dest_len) = ecx.project_to_simd(dest)?;
631
632    assert_eq!(dest_len, op_len);
633
634    let rounding = rounding_from_imm(ecx.read_scalar(rounding)?.to_i32()?)?;
635
636    for i in 0..dest_len {
637        let op: F = ecx.read_scalar(&ecx.project_index(&op, i)?)?.to_float()?;
638        let res = op.round_to_integral(rounding).value;
639        ecx.write_scalar(
640            Scalar::from_uint(res.to_bits(), Size::from_bits(F::BITS)),
641            &ecx.project_index(&dest, i)?,
642        )?;
643    }
644
645    interp_ok(())
646}
647
648/// Gets equivalent `rustc_apfloat::Round` from rounding mode immediate of
649/// `round.{ss,sd,ps,pd}` intrinsics.
650fn rounding_from_imm<'tcx>(rounding: i32) -> InterpResult<'tcx, rustc_apfloat::Round> {
651    // The fourth bit of `rounding` only affects the SSE status
652    // register, which cannot be accessed from Miri (or from Rust,
653    // for that matter), so we can ignore it.
654    match rounding & !0b1000 {
655        // When the third bit is 0, the rounding mode is determined by the
656        // first two bits.
657        0b000 => interp_ok(rustc_apfloat::Round::NearestTiesToEven),
658        0b001 => interp_ok(rustc_apfloat::Round::TowardNegative),
659        0b010 => interp_ok(rustc_apfloat::Round::TowardPositive),
660        0b011 => interp_ok(rustc_apfloat::Round::TowardZero),
661        // When the third bit is 1, the rounding mode is determined by the
662        // SSE status register. Since we do not support modifying it from
663        // Miri (or Rust), we assume it to be at its default mode (round-to-nearest).
664        0b100..=0b111 => interp_ok(rustc_apfloat::Round::NearestTiesToEven),
665        rounding => panic!("invalid rounding mode 0x{rounding:02x}"),
666    }
667}
668
669/// Converts each element of `op` from floating point to signed integer.
670///
671/// When the input value is NaN or out of range, fall back to minimum value.
672///
673/// If `op` has more elements than `dest`, extra elements are ignored. If `op`
674/// has less elements than `dest`, the rest is filled with zeros.
675fn convert_float_to_int<'tcx>(
676    ecx: &mut crate::MiriInterpCx<'tcx>,
677    op: &OpTy<'tcx>,
678    rnd: rustc_apfloat::Round,
679    dest: &MPlaceTy<'tcx>,
680) -> InterpResult<'tcx, ()> {
681    let (op, op_len) = ecx.project_to_simd(op)?;
682    let (dest, dest_len) = ecx.project_to_simd(dest)?;
683
684    // Output must be *signed* integers.
685    assert!(matches!(dest.layout.field(ecx, 0).ty.kind(), ty::Int(_)));
686
687    for i in 0..op_len.min(dest_len) {
688        let op = ecx.read_immediate(&ecx.project_index(&op, i)?)?;
689        let dest = ecx.project_index(&dest, i)?;
690
691        let res = ecx.float_to_int_checked(&op, dest.layout, rnd)?.unwrap_or_else(|| {
692            // Fallback to minimum according to SSE/AVX semantics.
693            ImmTy::from_int(dest.layout.size.signed_int_min(), dest.layout)
694        });
695        ecx.write_immediate(*res, &dest)?;
696    }
697    // Fill remainder with zeros
698    for i in op_len..dest_len {
699        let dest = ecx.project_index(&dest, i)?;
700        ecx.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
701    }
702
703    interp_ok(())
704}
705
706/// Splits `op` (which must be a SIMD vector) into 128-bit chunks.
707///
708/// Returns a tuple where:
709/// * The first element is the number of 128-bit chunks (let's call it `N`).
710/// * The second element is the number of elements per chunk (let's call it `M`).
711/// * The third element is the `op` vector split into chunks, i.e, it's
712///   type is `[[T; M]; N]` where `T` is the element type of `op`.
713fn split_simd_to_128bit_chunks<'tcx, P: Projectable<'tcx, Provenance>>(
714    ecx: &mut crate::MiriInterpCx<'tcx>,
715    op: &P,
716) -> InterpResult<'tcx, (u64, u64, P)> {
717    let simd_layout = op.layout();
718    let (simd_len, element_ty) = simd_layout.ty.simd_size_and_type(ecx.tcx.tcx);
719
720    assert_eq!(simd_layout.size.bits() % 128, 0);
721    let num_chunks = simd_layout.size.bits() / 128;
722    let items_per_chunk = simd_len.strict_div(num_chunks);
723
724    // Transmute to `[[T; items_per_chunk]; num_chunks]`
725    let chunked_layout = ecx
726        .layout_of(Ty::new_array(
727            ecx.tcx.tcx,
728            Ty::new_array(ecx.tcx.tcx, element_ty, items_per_chunk),
729            num_chunks,
730        ))
731        .unwrap();
732    let chunked_op = op.transmute(chunked_layout, ecx)?;
733
734    interp_ok((num_chunks, items_per_chunk, chunked_op))
735}
736
737/// Horizontally performs `which` operation on adjacent values of
738/// `left` and `right` SIMD vectors and stores the result in `dest`.
739/// "Horizontal" means that the i-th output element is calculated
740/// from the elements 2*i and 2*i+1 of the concatenation of `left` and
741/// `right`.
742///
743/// Each 128-bit chunk is treated independently (i.e., the value for
744/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
745/// 128-bit chunks of `left` and `right`).
746fn horizontal_bin_op<'tcx>(
747    ecx: &mut crate::MiriInterpCx<'tcx>,
748    which: mir::BinOp,
749    saturating: bool,
750    left: &OpTy<'tcx>,
751    right: &OpTy<'tcx>,
752    dest: &MPlaceTy<'tcx>,
753) -> InterpResult<'tcx, ()> {
754    assert_eq!(left.layout, dest.layout);
755    assert_eq!(right.layout, dest.layout);
756
757    let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
758    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
759    let (_, _, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
760
761    let middle = items_per_chunk / 2;
762    for i in 0..num_chunks {
763        let left = ecx.project_index(&left, i)?;
764        let right = ecx.project_index(&right, i)?;
765        let dest = ecx.project_index(&dest, i)?;
766
767        for j in 0..items_per_chunk {
768            // `j` is the index in `dest`
769            // `k` is the index of the 2-item chunk in `src`
770            let (k, src) = if j < middle { (j, &left) } else { (j.strict_sub(middle), &right) };
771            // `base_i` is the index of the first item of the 2-item chunk in `src`
772            let base_i = k.strict_mul(2);
773            let lhs = ecx.read_immediate(&ecx.project_index(src, base_i)?)?;
774            let rhs = ecx.read_immediate(&ecx.project_index(src, base_i.strict_add(1))?)?;
775
776            let res = if saturating {
777                Immediate::from(ecx.saturating_arith(which, &lhs, &rhs)?)
778            } else {
779                *ecx.binary_op(which, &lhs, &rhs)?
780            };
781
782            ecx.write_immediate(res, &ecx.project_index(&dest, j)?)?;
783        }
784    }
785
786    interp_ok(())
787}
788
789/// Conditionally multiplies the packed floating-point elements in
790/// `left` and `right` using the high 4 bits in `imm`, sums the calculated
791/// products (up to 4), and conditionally stores the sum in `dest` using
792/// the low 4 bits of `imm`.
793///
794/// Each 128-bit chunk is treated independently (i.e., the value for
795/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
796/// 128-bit blocks of `left` and `right`).
797fn conditional_dot_product<'tcx>(
798    ecx: &mut crate::MiriInterpCx<'tcx>,
799    left: &OpTy<'tcx>,
800    right: &OpTy<'tcx>,
801    imm: &OpTy<'tcx>,
802    dest: &MPlaceTy<'tcx>,
803) -> InterpResult<'tcx, ()> {
804    assert_eq!(left.layout, dest.layout);
805    assert_eq!(right.layout, dest.layout);
806
807    let (num_chunks, items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
808    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
809    let (_, _, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
810
811    let element_layout = left.layout.field(ecx, 0).field(ecx, 0);
812    assert!(items_per_chunk <= 4);
813
814    // `imm` is a `u8` for SSE4.1 or an `i32` for AVX :/
815    let imm = ecx.read_scalar(imm)?.to_uint(imm.layout.size)?;
816
817    for i in 0..num_chunks {
818        let left = ecx.project_index(&left, i)?;
819        let right = ecx.project_index(&right, i)?;
820        let dest = ecx.project_index(&dest, i)?;
821
822        // Calculate dot product
823        // Elements are floating point numbers, but we can use `from_int`
824        // for the initial value because the representation of 0.0 is all zero bits.
825        let mut sum = ImmTy::from_int(0u8, element_layout);
826        for j in 0..items_per_chunk {
827            if imm & (1 << j.strict_add(4)) != 0 {
828                let left = ecx.read_immediate(&ecx.project_index(&left, j)?)?;
829                let right = ecx.read_immediate(&ecx.project_index(&right, j)?)?;
830
831                let mul = ecx.binary_op(mir::BinOp::Mul, &left, &right)?;
832                sum = ecx.binary_op(mir::BinOp::Add, &sum, &mul)?;
833            }
834        }
835
836        // Write to destination (conditioned to imm)
837        for j in 0..items_per_chunk {
838            let dest = ecx.project_index(&dest, j)?;
839
840            if imm & (1 << j) != 0 {
841                ecx.write_immediate(*sum, &dest)?;
842            } else {
843                ecx.write_scalar(Scalar::from_int(0u8, element_layout.size), &dest)?;
844            }
845        }
846    }
847
848    interp_ok(())
849}
850
851/// Calculates two booleans.
852///
853/// The first is true when all the bits of `op & mask` are zero.
854/// The second is true when `(op & mask) == mask`
855fn test_bits_masked<'tcx>(
856    ecx: &crate::MiriInterpCx<'tcx>,
857    op: &OpTy<'tcx>,
858    mask: &OpTy<'tcx>,
859) -> InterpResult<'tcx, (bool, bool)> {
860    assert_eq!(op.layout, mask.layout);
861
862    let (op, op_len) = ecx.project_to_simd(op)?;
863    let (mask, mask_len) = ecx.project_to_simd(mask)?;
864
865    assert_eq!(op_len, mask_len);
866
867    let mut all_zero = true;
868    let mut masked_set = true;
869    for i in 0..op_len {
870        let op = ecx.project_index(&op, i)?;
871        let mask = ecx.project_index(&mask, i)?;
872
873        let op = ecx.read_scalar(&op)?.to_uint(op.layout.size)?;
874        let mask = ecx.read_scalar(&mask)?.to_uint(mask.layout.size)?;
875        all_zero &= (op & mask) == 0;
876        masked_set &= (op & mask) == mask;
877    }
878
879    interp_ok((all_zero, masked_set))
880}
881
882/// Calculates two booleans.
883///
884/// The first is true when the highest bit of each element of `op & mask` is zero.
885/// The second is true when the highest bit of each element of `!op & mask` is zero.
886fn test_high_bits_masked<'tcx>(
887    ecx: &crate::MiriInterpCx<'tcx>,
888    op: &OpTy<'tcx>,
889    mask: &OpTy<'tcx>,
890) -> InterpResult<'tcx, (bool, bool)> {
891    assert_eq!(op.layout, mask.layout);
892
893    let (op, op_len) = ecx.project_to_simd(op)?;
894    let (mask, mask_len) = ecx.project_to_simd(mask)?;
895
896    assert_eq!(op_len, mask_len);
897
898    let high_bit_offset = op.layout.field(ecx, 0).size.bits().strict_sub(1);
899
900    let mut direct = true;
901    let mut negated = true;
902    for i in 0..op_len {
903        let op = ecx.project_index(&op, i)?;
904        let mask = ecx.project_index(&mask, i)?;
905
906        let op = ecx.read_scalar(&op)?.to_uint(op.layout.size)?;
907        let mask = ecx.read_scalar(&mask)?.to_uint(mask.layout.size)?;
908        direct &= (op & mask) >> high_bit_offset == 0;
909        negated &= (!op & mask) >> high_bit_offset == 0;
910    }
911
912    interp_ok((direct, negated))
913}
914
915/// Conditionally loads from `ptr` according the high bit of each
916/// element of `mask`. `ptr` does not need to be aligned.
917fn mask_load<'tcx>(
918    ecx: &mut crate::MiriInterpCx<'tcx>,
919    ptr: &OpTy<'tcx>,
920    mask: &OpTy<'tcx>,
921    dest: &MPlaceTy<'tcx>,
922) -> InterpResult<'tcx, ()> {
923    let (mask, mask_len) = ecx.project_to_simd(mask)?;
924    let (dest, dest_len) = ecx.project_to_simd(dest)?;
925
926    assert_eq!(dest_len, mask_len);
927
928    let mask_item_size = mask.layout.field(ecx, 0).size;
929    let high_bit_offset = mask_item_size.bits().strict_sub(1);
930
931    let ptr = ecx.read_pointer(ptr)?;
932    for i in 0..dest_len {
933        let mask = ecx.project_index(&mask, i)?;
934        let dest = ecx.project_index(&dest, i)?;
935
936        if ecx.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
937            let ptr = ptr.wrapping_offset(dest.layout.size * i, &ecx.tcx);
938            // Unaligned copy, which is what we want.
939            ecx.mem_copy(ptr, dest.ptr(), dest.layout.size, /*nonoverlapping*/ true)?;
940        } else {
941            ecx.write_scalar(Scalar::from_int(0, dest.layout.size), &dest)?;
942        }
943    }
944
945    interp_ok(())
946}
947
948/// Conditionally stores into `ptr` according the high bit of each
949/// element of `mask`. `ptr` does not need to be aligned.
950fn mask_store<'tcx>(
951    ecx: &mut crate::MiriInterpCx<'tcx>,
952    ptr: &OpTy<'tcx>,
953    mask: &OpTy<'tcx>,
954    value: &OpTy<'tcx>,
955) -> InterpResult<'tcx, ()> {
956    let (mask, mask_len) = ecx.project_to_simd(mask)?;
957    let (value, value_len) = ecx.project_to_simd(value)?;
958
959    assert_eq!(value_len, mask_len);
960
961    let mask_item_size = mask.layout.field(ecx, 0).size;
962    let high_bit_offset = mask_item_size.bits().strict_sub(1);
963
964    let ptr = ecx.read_pointer(ptr)?;
965    for i in 0..value_len {
966        let mask = ecx.project_index(&mask, i)?;
967        let value = ecx.project_index(&value, i)?;
968
969        if ecx.read_scalar(&mask)?.to_uint(mask_item_size)? >> high_bit_offset != 0 {
970            // *Non-inbounds* pointer arithmetic to compute the destination.
971            // (That's why we can't use a place projection.)
972            let ptr = ptr.wrapping_offset(value.layout.size * i, &ecx.tcx);
973            // Deref the pointer *unaligned*, and do the copy.
974            let dest = ecx.ptr_to_mplace_unaligned(ptr, value.layout);
975            ecx.copy_op(&value, &dest)?;
976        }
977    }
978
979    interp_ok(())
980}
981
982/// Compute the sum of absolute differences of quadruplets of unsigned
983/// 8-bit integers in `left` and `right`, and store the 16-bit results
984/// in `right`. Quadruplets are selected from `left` and `right` with
985/// offsets specified in `imm`.
986///
987/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_maddubs_epi16>
988/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mpsadbw_epu8>
989///
990/// Each 128-bit chunk is treated independently (i.e., the value for
991/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
992/// 128-bit chunks of `left` and `right`).
993fn mpsadbw<'tcx>(
994    ecx: &mut crate::MiriInterpCx<'tcx>,
995    left: &OpTy<'tcx>,
996    right: &OpTy<'tcx>,
997    imm: &OpTy<'tcx>,
998    dest: &MPlaceTy<'tcx>,
999) -> InterpResult<'tcx, ()> {
1000    assert_eq!(left.layout, right.layout);
1001    assert_eq!(left.layout.size, dest.layout.size);
1002
1003    let (num_chunks, op_items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
1004    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
1005    let (_, dest_items_per_chunk, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
1006
1007    assert_eq!(op_items_per_chunk, dest_items_per_chunk.strict_mul(2));
1008
1009    let imm = ecx.read_scalar(imm)?.to_uint(imm.layout.size)?;
1010    // Bit 2 of `imm` specifies the offset for indices of `left`.
1011    // The offset is 0 when the bit is 0 or 4 when the bit is 1.
1012    let left_offset = u64::try_from((imm >> 2) & 1).unwrap().strict_mul(4);
1013    // Bits 0..=1 of `imm` specify the offset for indices of
1014    // `right` in blocks of 4 elements.
1015    let right_offset = u64::try_from(imm & 0b11).unwrap().strict_mul(4);
1016
1017    for i in 0..num_chunks {
1018        let left = ecx.project_index(&left, i)?;
1019        let right = ecx.project_index(&right, i)?;
1020        let dest = ecx.project_index(&dest, i)?;
1021
1022        for j in 0..dest_items_per_chunk {
1023            let left_offset = left_offset.strict_add(j);
1024            let mut res: u16 = 0;
1025            for k in 0..4 {
1026                let left = ecx
1027                    .read_scalar(&ecx.project_index(&left, left_offset.strict_add(k))?)?
1028                    .to_u8()?;
1029                let right = ecx
1030                    .read_scalar(&ecx.project_index(&right, right_offset.strict_add(k))?)?
1031                    .to_u8()?;
1032                res = res.strict_add(left.abs_diff(right).into());
1033            }
1034            ecx.write_scalar(Scalar::from_u16(res), &ecx.project_index(&dest, j)?)?;
1035        }
1036    }
1037
1038    interp_ok(())
1039}
1040
1041/// Compute the absolute differences of packed unsigned 8-bit integers
1042/// in `left` and `right`, then horizontally sum each consecutive 8
1043/// differences to produce unsigned 16-bit integers, and pack
1044/// these unsigned 16-bit integers in the low 16 bits of 64-bit elements
1045/// in `dest`.
1046///
1047/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_sad_epu8>
1048/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_sad_epu8>
1049/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_sad_epu8>
1050fn psadbw<'tcx>(
1051    ecx: &mut crate::MiriInterpCx<'tcx>,
1052    left: &OpTy<'tcx>,
1053    right: &OpTy<'tcx>,
1054    dest: &MPlaceTy<'tcx>,
1055) -> InterpResult<'tcx, ()> {
1056    let (left, left_len) = ecx.project_to_simd(left)?;
1057    let (right, right_len) = ecx.project_to_simd(right)?;
1058    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1059
1060    // fn psadbw(a: u8x16, b: u8x16) -> u64x2;
1061    // fn psadbw(a: u8x32, b: u8x32) -> u64x4;
1062    // fn vpsadbw(a: u8x64, b: u8x64) -> u64x8;
1063    assert_eq!(left_len, right_len);
1064    assert_eq!(left_len, left.layout.layout.size().bytes());
1065    assert_eq!(dest_len, left_len.strict_div(8));
1066
1067    for i in 0..dest_len {
1068        let dest = ecx.project_index(&dest, i)?;
1069
1070        let mut acc: u16 = 0;
1071        for j in 0..8 {
1072            let src_index = i.strict_mul(8).strict_add(j);
1073
1074            let left = ecx.project_index(&left, src_index)?;
1075            let left = ecx.read_scalar(&left)?.to_u8()?;
1076
1077            let right = ecx.project_index(&right, src_index)?;
1078            let right = ecx.read_scalar(&right)?.to_u8()?;
1079
1080            acc = acc.strict_add(left.abs_diff(right).into());
1081        }
1082
1083        ecx.write_scalar(Scalar::from_u64(acc.into()), &dest)?;
1084    }
1085
1086    interp_ok(())
1087}
1088
1089/// Multiplies packed 16-bit signed integer values, truncates the 32-bit
1090/// product to the 18 most significant bits by right-shifting, and then
1091/// divides the 18-bit value by 2 (rounding to nearest) by first adding
1092/// 1 and then taking the bits `1..=16`.
1093///
1094/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_mulhrs_epi16>
1095/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_mulhrs_epi16>
1096fn pmulhrsw<'tcx>(
1097    ecx: &mut crate::MiriInterpCx<'tcx>,
1098    left: &OpTy<'tcx>,
1099    right: &OpTy<'tcx>,
1100    dest: &MPlaceTy<'tcx>,
1101) -> InterpResult<'tcx, ()> {
1102    let (left, left_len) = ecx.project_to_simd(left)?;
1103    let (right, right_len) = ecx.project_to_simd(right)?;
1104    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1105
1106    assert_eq!(dest_len, left_len);
1107    assert_eq!(dest_len, right_len);
1108
1109    for i in 0..dest_len {
1110        let left = ecx.read_scalar(&ecx.project_index(&left, i)?)?.to_i16()?;
1111        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_i16()?;
1112        let dest = ecx.project_index(&dest, i)?;
1113
1114        let res = (i32::from(left).strict_mul(right.into()) >> 14).strict_add(1) >> 1;
1115
1116        // The result of this operation can overflow a signed 16-bit integer.
1117        // When `left` and `right` are -0x8000, the result is 0x8000.
1118        #[expect(clippy::as_conversions)]
1119        let res = res as i16;
1120
1121        ecx.write_scalar(Scalar::from_i16(res), &dest)?;
1122    }
1123
1124    interp_ok(())
1125}
1126
1127/// Perform a carry-less multiplication of two 64-bit integers, selected from `left` and `right` according to `imm8`,
1128/// and store the results in `dst`.
1129///
1130/// `left` and `right` are both vectors of type `len` x i64. Only bits 0 and 4 of `imm8` matter;
1131/// they select the element of `left` and `right`, respectively.
1132///
1133/// `len` is the SIMD vector length (in counts of `i64` values). It is expected to be one of
1134/// `2`, `4`, or `8`.
1135///
1136/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_clmulepi64_si128>
1137fn pclmulqdq<'tcx>(
1138    ecx: &mut MiriInterpCx<'tcx>,
1139    left: &OpTy<'tcx>,
1140    right: &OpTy<'tcx>,
1141    imm8: &OpTy<'tcx>,
1142    dest: &MPlaceTy<'tcx>,
1143    len: u64,
1144) -> InterpResult<'tcx, ()> {
1145    assert_eq!(left.layout, right.layout);
1146    assert_eq!(left.layout.size, dest.layout.size);
1147    assert!([2u64, 4, 8].contains(&len));
1148
1149    // Transmute the input into arrays of `[u64; len]`.
1150    // Transmute the output into an array of `[u128, len / 2]`.
1151
1152    let src_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u64, len))?;
1153    let dest_layout = ecx.layout_of(Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u128, len / 2))?;
1154
1155    let left = left.transmute(src_layout, ecx)?;
1156    let right = right.transmute(src_layout, ecx)?;
1157    let dest = dest.transmute(dest_layout, ecx)?;
1158
1159    let imm8 = ecx.read_scalar(imm8)?.to_u8()?;
1160
1161    for i in 0..(len / 2) {
1162        let lo = i.strict_mul(2);
1163        let hi = i.strict_mul(2).strict_add(1);
1164
1165        // select the 64-bit integer from left that the user specified (low or high)
1166        let index = if (imm8 & 0x01) == 0 { lo } else { hi };
1167        let left = ecx.read_scalar(&ecx.project_index(&left, index)?)?.to_u64()?;
1168
1169        // select the 64-bit integer from right that the user specified (low or high)
1170        let index = if (imm8 & 0x10) == 0 { lo } else { hi };
1171        let right = ecx.read_scalar(&ecx.project_index(&right, index)?)?.to_u64()?;
1172
1173        // Perform carry-less multiplication.
1174        //
1175        // This operation is like long multiplication, but ignores all carries.
1176        // That idea corresponds to the xor operator, which is used in the implementation.
1177        //
1178        // Wikipedia has an example https://en.wikipedia.org/wiki/Carry-less_product#Example
1179        let mut result: u128 = 0;
1180
1181        for i in 0..64 {
1182            // if the i-th bit in right is set
1183            if (right & (1 << i)) != 0 {
1184                // xor result with `left` shifted to the left by i positions
1185                result ^= u128::from(left) << i;
1186            }
1187        }
1188
1189        let dest = ecx.project_index(&dest, i)?;
1190        ecx.write_scalar(Scalar::from_u128(result), &dest)?;
1191    }
1192
1193    interp_ok(())
1194}
1195
1196/// Packs two N-bit integer vectors to a single N/2-bit integers.
1197///
1198/// The conversion from N-bit to N/2-bit should be provided by `f`.
1199///
1200/// Each 128-bit chunk is treated independently (i.e., the value for
1201/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1202/// 128-bit chunks of `left` and `right`).
1203fn pack_generic<'tcx>(
1204    ecx: &mut crate::MiriInterpCx<'tcx>,
1205    left: &OpTy<'tcx>,
1206    right: &OpTy<'tcx>,
1207    dest: &MPlaceTy<'tcx>,
1208    f: impl Fn(Scalar) -> InterpResult<'tcx, Scalar>,
1209) -> InterpResult<'tcx, ()> {
1210    assert_eq!(left.layout, right.layout);
1211    assert_eq!(left.layout.size, dest.layout.size);
1212
1213    let (num_chunks, op_items_per_chunk, left) = split_simd_to_128bit_chunks(ecx, left)?;
1214    let (_, _, right) = split_simd_to_128bit_chunks(ecx, right)?;
1215    let (_, dest_items_per_chunk, dest) = split_simd_to_128bit_chunks(ecx, dest)?;
1216
1217    assert_eq!(dest_items_per_chunk, op_items_per_chunk.strict_mul(2));
1218
1219    for i in 0..num_chunks {
1220        let left = ecx.project_index(&left, i)?;
1221        let right = ecx.project_index(&right, i)?;
1222        let dest = ecx.project_index(&dest, i)?;
1223
1224        for j in 0..op_items_per_chunk {
1225            let left = ecx.read_scalar(&ecx.project_index(&left, j)?)?;
1226            let right = ecx.read_scalar(&ecx.project_index(&right, j)?)?;
1227            let left_dest = ecx.project_index(&dest, j)?;
1228            let right_dest = ecx.project_index(&dest, j.strict_add(op_items_per_chunk))?;
1229
1230            let left_res = f(left)?;
1231            let right_res = f(right)?;
1232
1233            ecx.write_scalar(left_res, &left_dest)?;
1234            ecx.write_scalar(right_res, &right_dest)?;
1235        }
1236    }
1237
1238    interp_ok(())
1239}
1240
1241/// Converts two 16-bit integer vectors to a single 8-bit integer
1242/// vector with signed saturation.
1243///
1244/// Each 128-bit chunk is treated independently (i.e., the value for
1245/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1246/// 128-bit chunks of `left` and `right`).
1247fn packsswb<'tcx>(
1248    ecx: &mut crate::MiriInterpCx<'tcx>,
1249    left: &OpTy<'tcx>,
1250    right: &OpTy<'tcx>,
1251    dest: &MPlaceTy<'tcx>,
1252) -> InterpResult<'tcx, ()> {
1253    pack_generic(ecx, left, right, dest, |op| {
1254        let op = op.to_i16()?;
1255        let res = i8::try_from(op).unwrap_or(if op < 0 { i8::MIN } else { i8::MAX });
1256        interp_ok(Scalar::from_i8(res))
1257    })
1258}
1259
1260/// Converts two 16-bit signed integer vectors to a single 8-bit
1261/// unsigned integer vector with saturation.
1262///
1263/// Each 128-bit chunk is treated independently (i.e., the value for
1264/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1265/// 128-bit chunks of `left` and `right`).
1266fn packuswb<'tcx>(
1267    ecx: &mut crate::MiriInterpCx<'tcx>,
1268    left: &OpTy<'tcx>,
1269    right: &OpTy<'tcx>,
1270    dest: &MPlaceTy<'tcx>,
1271) -> InterpResult<'tcx, ()> {
1272    pack_generic(ecx, left, right, dest, |op| {
1273        let op = op.to_i16()?;
1274        let res = u8::try_from(op).unwrap_or(if op < 0 { 0 } else { u8::MAX });
1275        interp_ok(Scalar::from_u8(res))
1276    })
1277}
1278
1279/// Converts two 32-bit integer vectors to a single 16-bit integer
1280/// vector with signed saturation.
1281///
1282/// Each 128-bit chunk is treated independently (i.e., the value for
1283/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1284/// 128-bit chunks of `left` and `right`).
1285fn packssdw<'tcx>(
1286    ecx: &mut crate::MiriInterpCx<'tcx>,
1287    left: &OpTy<'tcx>,
1288    right: &OpTy<'tcx>,
1289    dest: &MPlaceTy<'tcx>,
1290) -> InterpResult<'tcx, ()> {
1291    pack_generic(ecx, left, right, dest, |op| {
1292        let op = op.to_i32()?;
1293        let res = i16::try_from(op).unwrap_or(if op < 0 { i16::MIN } else { i16::MAX });
1294        interp_ok(Scalar::from_i16(res))
1295    })
1296}
1297
1298/// Converts two 32-bit integer vectors to a single 16-bit integer
1299/// vector with unsigned saturation.
1300///
1301/// Each 128-bit chunk is treated independently (i.e., the value for
1302/// the is i-th 128-bit chunk of `dest` is calculated with the i-th
1303/// 128-bit chunks of `left` and `right`).
1304fn packusdw<'tcx>(
1305    ecx: &mut crate::MiriInterpCx<'tcx>,
1306    left: &OpTy<'tcx>,
1307    right: &OpTy<'tcx>,
1308    dest: &MPlaceTy<'tcx>,
1309) -> InterpResult<'tcx, ()> {
1310    pack_generic(ecx, left, right, dest, |op| {
1311        let op = op.to_i32()?;
1312        let res = u16::try_from(op).unwrap_or(if op < 0 { 0 } else { u16::MAX });
1313        interp_ok(Scalar::from_u16(res))
1314    })
1315}
1316
1317/// Negates elements from `left` when the corresponding element in
1318/// `right` is negative. If an element from `right` is zero, zero
1319/// is written to the corresponding output element.
1320/// In other words, multiplies `left` with `right.signum()`.
1321fn psign<'tcx>(
1322    ecx: &mut crate::MiriInterpCx<'tcx>,
1323    left: &OpTy<'tcx>,
1324    right: &OpTy<'tcx>,
1325    dest: &MPlaceTy<'tcx>,
1326) -> InterpResult<'tcx, ()> {
1327    let (left, left_len) = ecx.project_to_simd(left)?;
1328    let (right, right_len) = ecx.project_to_simd(right)?;
1329    let (dest, dest_len) = ecx.project_to_simd(dest)?;
1330
1331    assert_eq!(dest_len, left_len);
1332    assert_eq!(dest_len, right_len);
1333
1334    for i in 0..dest_len {
1335        let dest = ecx.project_index(&dest, i)?;
1336        let left = ecx.read_immediate(&ecx.project_index(&left, i)?)?;
1337        let right = ecx.read_scalar(&ecx.project_index(&right, i)?)?.to_int(dest.layout.size)?;
1338
1339        let res =
1340            ecx.binary_op(mir::BinOp::Mul, &left, &ImmTy::from_int(right.signum(), dest.layout))?;
1341
1342        ecx.write_immediate(*res, &dest)?;
1343    }
1344
1345    interp_ok(())
1346}
1347
1348/// Calcultates either `a + b + cb_in` or `a - b - cb_in` depending on the value
1349/// of `op` and returns both the sum and the overflow bit. `op` is expected to be
1350/// either one of `mir::BinOp::AddWithOverflow` and `mir::BinOp::SubWithOverflow`.
1351fn carrying_add<'tcx>(
1352    ecx: &mut crate::MiriInterpCx<'tcx>,
1353    cb_in: &OpTy<'tcx>,
1354    a: &OpTy<'tcx>,
1355    b: &OpTy<'tcx>,
1356    op: mir::BinOp,
1357) -> InterpResult<'tcx, (ImmTy<'tcx>, Scalar)> {
1358    assert!(op == mir::BinOp::AddWithOverflow || op == mir::BinOp::SubWithOverflow);
1359
1360    let cb_in = ecx.read_scalar(cb_in)?.to_u8()? != 0;
1361    let a = ecx.read_immediate(a)?;
1362    let b = ecx.read_immediate(b)?;
1363
1364    let (sum, overflow1) = ecx.binary_op(op, &a, &b)?.to_pair(ecx);
1365    let (sum, overflow2) =
1366        ecx.binary_op(op, &sum, &ImmTy::from_uint(cb_in, a.layout))?.to_pair(ecx);
1367    let cb_out = overflow1.to_scalar().to_bool()? | overflow2.to_scalar().to_bool()?;
1368
1369    interp_ok((sum, Scalar::from_u8(cb_out.into())))
1370}