miri/intrinsics/
simd.rs

1use either::Either;
2use rand::Rng;
3use rustc_abi::{Endian, HasDataLayout};
4use rustc_apfloat::{Float, Round};
5use rustc_middle::ty::FloatTy;
6use rustc_middle::ty::layout::LayoutOf;
7use rustc_middle::{mir, ty};
8use rustc_span::{Symbol, sym};
9
10use crate::helpers::{
11    ToHost, ToSoft, bool_to_simd_element, check_intrinsic_arg_count, simd_element_to_bool,
12};
13use crate::*;
14
/// Selects which extremum a SIMD min/max operation computes.
///
/// Used in this file by the element-wise `fmin`/`fmax` intrinsics and by the
/// `reduce_min`/`reduce_max` reductions to share one code path for both directions.
15#[derive(Copy, Clone)]
16pub(crate) enum MinMax {
    /// Keep the smaller of the two operands.
17    Min,
    /// Keep the larger of the two operands.
18    Max,
19}
20
// The impl body is intentionally empty: `MiriInterpCx` picks up the methods of
// `EvalContextExt` from the trait's provided (default) implementations.
21impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
22pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
23    /// Calls the simd intrinsic `intrinsic_name`; the `simd_` prefix has already been removed.
24    /// Returns an `EmulateItemResult` indicating whether the intrinsic was handled.
25    fn emulate_simd_intrinsic(
26        &mut self,
27        intrinsic_name: &str,
28        generic_args: ty::GenericArgsRef<'tcx>,
29        args: &[OpTy<'tcx>],
30        dest: &MPlaceTy<'tcx>,
31    ) -> InterpResult<'tcx, EmulateItemResult> {
32        let this = self.eval_context_mut();
33        match intrinsic_name {
34            #[rustfmt::skip]
35            | "neg"
36            | "fabs"
37            | "ceil"
38            | "floor"
39            | "round"
40            | "trunc"
41            | "fsqrt"
42            | "fsin"
43            | "fcos"
44            | "fexp"
45            | "fexp2"
46            | "flog"
47            | "flog2"
48            | "flog10"
49            | "ctlz"
50            | "ctpop"
51            | "cttz"
52            | "bswap"
53            | "bitreverse"
54            => {
55                let [op] = check_intrinsic_arg_count(args)?;
56                let (op, op_len) = this.project_to_simd(op)?;
57                let (dest, dest_len) = this.project_to_simd(dest)?;
58
59                assert_eq!(dest_len, op_len);
60
61                #[derive(Copy, Clone)]
62                enum Op<'a> {
63                    MirOp(mir::UnOp),
64                    Abs,
65                    Round(rustc_apfloat::Round),
66                    Numeric(Symbol),
67                    HostOp(&'a str),
68                }
69                let which = match intrinsic_name {
70                    "neg" => Op::MirOp(mir::UnOp::Neg),
71                    "fabs" => Op::Abs,
72                    "ceil" => Op::Round(rustc_apfloat::Round::TowardPositive),
73                    "floor" => Op::Round(rustc_apfloat::Round::TowardNegative),
74                    "round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway),
75                    "trunc" => Op::Round(rustc_apfloat::Round::TowardZero),
76                    "ctlz" => Op::Numeric(sym::ctlz),
77                    "ctpop" => Op::Numeric(sym::ctpop),
78                    "cttz" => Op::Numeric(sym::cttz),
79                    "bswap" => Op::Numeric(sym::bswap),
80                    "bitreverse" => Op::Numeric(sym::bitreverse),
81                    _ => Op::HostOp(intrinsic_name),
82                };
83
84                for i in 0..dest_len {
85                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
86                    let dest = this.project_index(&dest, i)?;
87                    let val = match which {
88                        Op::MirOp(mir_op) => {
89                            // This already does NaN adjustments
90                            this.unary_op(mir_op, &op)?.to_scalar()
91                        }
92                        Op::Abs => {
93                            // Works for f32 and f64.
94                            let ty::Float(float_ty) = op.layout.ty.kind() else {
95                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
96                            };
97                            let op = op.to_scalar();
98                            // "Bitwise" operation, no NaN adjustments
99                            match float_ty {
100                                FloatTy::F16 => unimplemented!("f16_f128"),
101                                FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
102                                FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
103                                FloatTy::F128 => unimplemented!("f16_f128"),
104                            }
105                        }
106                        Op::HostOp(host_op) => {
107                            let ty::Float(float_ty) = op.layout.ty.kind() else {
108                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
109                            };
110                            // Using host floats except for sqrt (but it's fine, these operations do not
111                            // have guaranteed precision).
112                            match float_ty {
113                                FloatTy::F16 => unimplemented!("f16_f128"),
114                                FloatTy::F32 => {
115                                    let f = op.to_scalar().to_f32()?;
116                                    let res = match host_op {
117                                        "fsqrt" => math::sqrt(f),
118                                        "fsin" => f.to_host().sin().to_soft(),
119                                        "fcos" => f.to_host().cos().to_soft(),
120                                        "fexp" => f.to_host().exp().to_soft(),
121                                        "fexp2" => f.to_host().exp2().to_soft(),
122                                        "flog" => f.to_host().ln().to_soft(),
123                                        "flog2" => f.to_host().log2().to_soft(),
124                                        "flog10" => f.to_host().log10().to_soft(),
125                                        _ => bug!(),
126                                    };
127                                    let res = this.adjust_nan(res, &[f]);
128                                    Scalar::from(res)
129                                }
130                                FloatTy::F64 => {
131                                    let f = op.to_scalar().to_f64()?;
132                                    let res = match host_op {
133                                        "fsqrt" => math::sqrt(f),
134                                        "fsin" => f.to_host().sin().to_soft(),
135                                        "fcos" => f.to_host().cos().to_soft(),
136                                        "fexp" => f.to_host().exp().to_soft(),
137                                        "fexp2" => f.to_host().exp2().to_soft(),
138                                        "flog" => f.to_host().ln().to_soft(),
139                                        "flog2" => f.to_host().log2().to_soft(),
140                                        "flog10" => f.to_host().log10().to_soft(),
141                                        _ => bug!(),
142                                    };
143                                    let res = this.adjust_nan(res, &[f]);
144                                    Scalar::from(res)
145                                }
146                                FloatTy::F128 => unimplemented!("f16_f128"),
147                            }
148                        }
149                        Op::Round(rounding) => {
150                            let ty::Float(float_ty) = op.layout.ty.kind() else {
151                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
152                            };
153                            match float_ty {
154                                FloatTy::F16 => unimplemented!("f16_f128"),
155                                FloatTy::F32 => {
156                                    let f = op.to_scalar().to_f32()?;
157                                    let res = f.round_to_integral(rounding).value;
158                                    let res = this.adjust_nan(res, &[f]);
159                                    Scalar::from_f32(res)
160                                }
161                                FloatTy::F64 => {
162                                    let f = op.to_scalar().to_f64()?;
163                                    let res = f.round_to_integral(rounding).value;
164                                    let res = this.adjust_nan(res, &[f]);
165                                    Scalar::from_f64(res)
166                                }
167                                FloatTy::F128 => unimplemented!("f16_f128"),
168                            }
169                        }
170                        Op::Numeric(name) => {
171                            this.numeric_intrinsic(name, op.to_scalar(), op.layout, op.layout)?
172                        }
173                    };
174                    this.write_scalar(val, &dest)?;
175                }
176            }
177            #[rustfmt::skip]
178            | "add"
179            | "sub"
180            | "mul"
181            | "div"
182            | "rem"
183            | "shl"
184            | "shr"
185            | "and"
186            | "or"
187            | "xor"
188            | "eq"
189            | "ne"
190            | "lt"
191            | "le"
192            | "gt"
193            | "ge"
194            | "fmax"
195            | "fmin"
196            | "saturating_add"
197            | "saturating_sub"
198            | "arith_offset"
199            => {
200                use mir::BinOp;
201
202                let [left, right] = check_intrinsic_arg_count(args)?;
203                let (left, left_len) = this.project_to_simd(left)?;
204                let (right, right_len) = this.project_to_simd(right)?;
205                let (dest, dest_len) = this.project_to_simd(dest)?;
206
207                assert_eq!(dest_len, left_len);
208                assert_eq!(dest_len, right_len);
209
210                enum Op {
211                    MirOp(BinOp),
212                    SaturatingOp(BinOp),
213                    FMinMax(MinMax),
214                    WrappingOffset,
215                }
216                let which = match intrinsic_name {
217                    "add" => Op::MirOp(BinOp::Add),
218                    "sub" => Op::MirOp(BinOp::Sub),
219                    "mul" => Op::MirOp(BinOp::Mul),
220                    "div" => Op::MirOp(BinOp::Div),
221                    "rem" => Op::MirOp(BinOp::Rem),
222                    "shl" => Op::MirOp(BinOp::ShlUnchecked),
223                    "shr" => Op::MirOp(BinOp::ShrUnchecked),
224                    "and" => Op::MirOp(BinOp::BitAnd),
225                    "or" => Op::MirOp(BinOp::BitOr),
226                    "xor" => Op::MirOp(BinOp::BitXor),
227                    "eq" => Op::MirOp(BinOp::Eq),
228                    "ne" => Op::MirOp(BinOp::Ne),
229                    "lt" => Op::MirOp(BinOp::Lt),
230                    "le" => Op::MirOp(BinOp::Le),
231                    "gt" => Op::MirOp(BinOp::Gt),
232                    "ge" => Op::MirOp(BinOp::Ge),
233                    "fmax" => Op::FMinMax(MinMax::Max),
234                    "fmin" => Op::FMinMax(MinMax::Min),
235                    "saturating_add" => Op::SaturatingOp(BinOp::Add),
236                    "saturating_sub" => Op::SaturatingOp(BinOp::Sub),
237                    "arith_offset" => Op::WrappingOffset,
238                    _ => unreachable!(),
239                };
240
241                for i in 0..dest_len {
242                    let left = this.read_immediate(&this.project_index(&left, i)?)?;
243                    let right = this.read_immediate(&this.project_index(&right, i)?)?;
244                    let dest = this.project_index(&dest, i)?;
245                    let val = match which {
246                        Op::MirOp(mir_op) => {
247                            // This does NaN adjustments.
248                            let val = this.binary_op(mir_op, &left, &right).map_err_kind(|kind| {
249                                match kind {
250                                    InterpErrorKind::UndefinedBehavior(UndefinedBehaviorInfo::ShiftOverflow { shift_amount, .. }) => {
251                                        // This resets the interpreter backtrace, but it's not worth avoiding that.
252                                        let shift_amount = match shift_amount {
253                                            Either::Left(v) => v.to_string(),
254                                            Either::Right(v) => v.to_string(),
255                                        };
256                                        err_ub_format!("overflowing shift by {shift_amount} in `simd_{intrinsic_name}` in lane {i}")
257                                    }
258                                    kind => kind
259                                }
260                            })?;
261                            if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
262                                // Special handling for boolean-returning operations
263                                assert_eq!(val.layout.ty, this.tcx.types.bool);
264                                let val = val.to_scalar().to_bool().unwrap();
265                                bool_to_simd_element(val, dest.layout.size)
266                            } else {
267                                assert_ne!(val.layout.ty, this.tcx.types.bool);
268                                assert_eq!(val.layout.ty, dest.layout.ty);
269                                val.to_scalar()
270                            }
271                        }
272                        Op::SaturatingOp(mir_op) => {
273                            this.saturating_arith(mir_op, &left, &right)?
274                        }
275                        Op::WrappingOffset => {
276                            let ptr = left.to_scalar().to_pointer(this)?;
277                            let offset_count = right.to_scalar().to_target_isize(this)?;
278                            let pointee_ty = left.layout.ty.builtin_deref(true).unwrap();
279
280                            let pointee_size = i64::try_from(this.layout_of(pointee_ty)?.size.bytes()).unwrap();
281                            let offset_bytes = offset_count.wrapping_mul(pointee_size);
282                            let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this);
283                            Scalar::from_maybe_pointer(offset_ptr, this)
284                        }
285                        Op::FMinMax(op) => {
286                            this.fminmax_op(op, &left, &right)?
287                        }
288                    };
289                    this.write_scalar(val, &dest)?;
290                }
291            }
292            "fma" | "relaxed_fma" => {
293                let [a, b, c] = check_intrinsic_arg_count(args)?;
294                let (a, a_len) = this.project_to_simd(a)?;
295                let (b, b_len) = this.project_to_simd(b)?;
296                let (c, c_len) = this.project_to_simd(c)?;
297                let (dest, dest_len) = this.project_to_simd(dest)?;
298
299                assert_eq!(dest_len, a_len);
300                assert_eq!(dest_len, b_len);
301                assert_eq!(dest_len, c_len);
302
303                for i in 0..dest_len {
304                    let a = this.read_scalar(&this.project_index(&a, i)?)?;
305                    let b = this.read_scalar(&this.project_index(&b, i)?)?;
306                    let c = this.read_scalar(&this.project_index(&c, i)?)?;
307                    let dest = this.project_index(&dest, i)?;
308
309                    let fuse: bool = intrinsic_name == "fma" || this.machine.rng.get_mut().random();
310
311                    // Works for f32 and f64.
312                    // FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
313                    let ty::Float(float_ty) = dest.layout.ty.kind() else {
314                        span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
315                    };
316                    let val = match float_ty {
317                        FloatTy::F16 => unimplemented!("f16_f128"),
318                        FloatTy::F32 => {
319                            let a = a.to_f32()?;
320                            let b = b.to_f32()?;
321                            let c = c.to_f32()?;
322                            let res = if fuse {
323                                a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
324                            } else {
325                                ((a * b).value + c).value
326                            };
327                            let res = this.adjust_nan(res, &[a, b, c]);
328                            Scalar::from(res)
329                        }
330                        FloatTy::F64 => {
331                            let a = a.to_f64()?;
332                            let b = b.to_f64()?;
333                            let c = c.to_f64()?;
334                            let res = if fuse {
335                                a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
336                            } else {
337                                ((a * b).value + c).value
338                            };
339                            let res = this.adjust_nan(res, &[a, b, c]);
340                            Scalar::from(res)
341                        }
342                        FloatTy::F128 => unimplemented!("f16_f128"),
343                    };
344                    this.write_scalar(val, &dest)?;
345                }
346            }
347            #[rustfmt::skip]
348            | "reduce_and"
349            | "reduce_or"
350            | "reduce_xor"
351            | "reduce_any"
352            | "reduce_all"
353            | "reduce_max"
354            | "reduce_min" => {
355                use mir::BinOp;
356
357                let [op] = check_intrinsic_arg_count(args)?;
358                let (op, op_len) = this.project_to_simd(op)?;
359
360                let imm_from_bool =
361                    |b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool);
362
363                enum Op {
364                    MirOp(BinOp),
365                    MirOpBool(BinOp),
366                    MinMax(MinMax),
367                }
368                let which = match intrinsic_name {
369                    "reduce_and" => Op::MirOp(BinOp::BitAnd),
370                    "reduce_or" => Op::MirOp(BinOp::BitOr),
371                    "reduce_xor" => Op::MirOp(BinOp::BitXor),
372                    "reduce_any" => Op::MirOpBool(BinOp::BitOr),
373                    "reduce_all" => Op::MirOpBool(BinOp::BitAnd),
374                    "reduce_max" => Op::MinMax(MinMax::Max),
375                    "reduce_min" => Op::MinMax(MinMax::Min),
376                    _ => unreachable!(),
377                };
378
379                // Initialize with first lane, then proceed with the rest.
380                let mut res = this.read_immediate(&this.project_index(&op, 0)?)?;
381                if matches!(which, Op::MirOpBool(_)) {
382                    // Convert to `bool` scalar.
383                    res = imm_from_bool(simd_element_to_bool(res)?);
384                }
385                for i in 1..op_len {
386                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
387                    res = match which {
388                        Op::MirOp(mir_op) => {
389                            this.binary_op(mir_op, &res, &op)?
390                        }
391                        Op::MirOpBool(mir_op) => {
392                            let op = imm_from_bool(simd_element_to_bool(op)?);
393                            this.binary_op(mir_op, &res, &op)?
394                        }
395                        Op::MinMax(mmop) => {
396                            if matches!(res.layout.ty.kind(), ty::Float(_)) {
397                                ImmTy::from_scalar(this.fminmax_op(mmop, &res, &op)?, res.layout)
398                            } else {
399                                // Just boring integers, so no NaNs to worry about
400                                let mirop = match mmop {
401                                    MinMax::Min => BinOp::Le,
402                                    MinMax::Max => BinOp::Ge,
403                                };
404                                if this.binary_op(mirop, &res, &op)?.to_scalar().to_bool()? {
405                                    res
406                                } else {
407                                    op
408                                }
409                            }
410                        }
411                    };
412                }
413                this.write_immediate(*res, dest)?;
414            }
415            #[rustfmt::skip]
416            | "reduce_add_ordered"
417            | "reduce_mul_ordered" => {
418                use mir::BinOp;
419
420                let [op, init] = check_intrinsic_arg_count(args)?;
421                let (op, op_len) = this.project_to_simd(op)?;
422                let init = this.read_immediate(init)?;
423
424                let mir_op = match intrinsic_name {
425                    "reduce_add_ordered" => BinOp::Add,
426                    "reduce_mul_ordered" => BinOp::Mul,
427                    _ => unreachable!(),
428                };
429
430                let mut res = init;
431                for i in 0..op_len {
432                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
433                    res = this.binary_op(mir_op, &res, &op)?;
434                }
435                this.write_immediate(*res, dest)?;
436            }
437            "select" => {
438                let [mask, yes, no] = check_intrinsic_arg_count(args)?;
439                let (mask, mask_len) = this.project_to_simd(mask)?;
440                let (yes, yes_len) = this.project_to_simd(yes)?;
441                let (no, no_len) = this.project_to_simd(no)?;
442                let (dest, dest_len) = this.project_to_simd(dest)?;
443
444                assert_eq!(dest_len, mask_len);
445                assert_eq!(dest_len, yes_len);
446                assert_eq!(dest_len, no_len);
447
448                for i in 0..dest_len {
449                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
450                    let yes = this.read_immediate(&this.project_index(&yes, i)?)?;
451                    let no = this.read_immediate(&this.project_index(&no, i)?)?;
452                    let dest = this.project_index(&dest, i)?;
453
454                    let val = if simd_element_to_bool(mask)? { yes } else { no };
455                    this.write_immediate(*val, &dest)?;
456                }
457            }
458            // Variant of `select` that takes a bitmask rather than a "vector of bool".
459            "select_bitmask" => {
460                let [mask, yes, no] = check_intrinsic_arg_count(args)?;
461                let (yes, yes_len) = this.project_to_simd(yes)?;
462                let (no, no_len) = this.project_to_simd(no)?;
463                let (dest, dest_len) = this.project_to_simd(dest)?;
464                let bitmask_len = dest_len.next_multiple_of(8);
465                if bitmask_len > 64 {
466                    throw_unsup_format!(
467                        "simd_select_bitmask: vectors larger than 64 elements are currently not supported"
468                    );
469                }
470
471                assert_eq!(dest_len, yes_len);
472                assert_eq!(dest_len, no_len);
473
474                // Read the mask, either as an integer or as an array.
475                let mask: u64 = match mask.layout.ty.kind() {
476                    ty::Uint(_) => {
477                        // Any larger integer type is fine.
478                        assert!(mask.layout.size.bits() >= bitmask_len);
479                        this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap()
480                    }
481                    ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
482                        // The array must have exactly the right size.
483                        assert_eq!(mask.layout.size.bits(), bitmask_len);
484                        // Read the raw bytes.
485                        let mask = mask.assert_mem_place(); // arrays cannot be immediate
486                        let mask_bytes =
487                            this.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?;
488                        // Turn them into a `u64` in the right way.
489                        let mask_size = mask.layout.size.bytes_usize();
490                        let mut mask_arr = [0u8; 8];
491                        match this.data_layout().endian {
492                            Endian::Little => {
493                                // Fill the first N bytes.
494                                mask_arr[..mask_size].copy_from_slice(mask_bytes);
495                                u64::from_le_bytes(mask_arr)
496                            }
497                            Endian::Big => {
498                                // Fill the last N bytes.
499                                let i = mask_arr.len().strict_sub(mask_size);
500                                mask_arr[i..].copy_from_slice(mask_bytes);
501                                u64::from_be_bytes(mask_arr)
502                            }
503                        }
504                    }
505                    _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
506                };
507
508                let dest_len = u32::try_from(dest_len).unwrap();
509                let bitmask_len = u32::try_from(bitmask_len).unwrap();
510                for i in 0..dest_len {
511                    let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
512                    let mask = mask & 1u64.strict_shl(bit_i);
513                    let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
514                    let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
515                    let dest = this.project_index(&dest, i.into())?;
516
517                    let val = if mask != 0 { yes } else { no };
518                    this.write_immediate(*val, &dest)?;
519                }
520                for i in dest_len..bitmask_len {
521                    // If the mask is "padded", ensure that padding is all-zero.
522                    // This deliberately does not use `simd_bitmask_index`; these bits are outside
523                    // the bitmask. It does not matter in which order we check them.
524                    let mask = mask & 1u64.strict_shl(i);
525                    if mask != 0 {
526                        throw_ub_format!(
527                            "a SIMD bitmask less than 8 bits long must be filled with 0s for the remaining bits"
528                        );
529                    }
530                }
531            }
532            // Converts a "vector of bool" into a bitmask.
533            "bitmask" => {
534                let [op] = check_intrinsic_arg_count(args)?;
535                let (op, op_len) = this.project_to_simd(op)?;
536                let bitmask_len = op_len.next_multiple_of(8);
537                if bitmask_len > 64 {
538                    throw_unsup_format!(
539                        "simd_bitmask: vectors larger than 64 elements are currently not supported"
540                    );
541                }
542
543                let op_len = u32::try_from(op_len).unwrap();
544                let mut res = 0u64;
545                for i in 0..op_len {
546                    let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
547                    if simd_element_to_bool(op)? {
548                        let bit_i = simd_bitmask_index(i, op_len, this.data_layout().endian);
549                        res |= 1u64.strict_shl(bit_i);
550                    }
551                }
552                // Write the result, depending on the `dest` type.
553                // Returns either an unsigned integer or array of `u8`.
554                match dest.layout.ty.kind() {
555                    ty::Uint(_) => {
556                        // Any larger integer type is fine, it will be zero-extended.
557                        assert!(dest.layout.size.bits() >= bitmask_len);
558                        this.write_int(res, dest)?;
559                    }
560                    ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
561                        // The array must have exactly the right size.
562                        assert_eq!(dest.layout.size.bits(), bitmask_len);
563                        // We have to write the result byte-for-byte.
564                        let res_size = dest.layout.size.bytes_usize();
565                        let res_bytes;
566                        let res_bytes_slice = match this.data_layout().endian {
567                            Endian::Little => {
568                                res_bytes = res.to_le_bytes();
569                                &res_bytes[..res_size] // take the first N bytes
570                            }
571                            Endian::Big => {
572                                res_bytes = res.to_be_bytes();
573                                &res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes
574                            }
575                        };
576                        this.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?;
577                    }
578                    _ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty),
579                }
580            }
581            "cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
582                let [op] = check_intrinsic_arg_count(args)?;
583                let (op, op_len) = this.project_to_simd(op)?;
584                let (dest, dest_len) = this.project_to_simd(dest)?;
585
586                assert_eq!(dest_len, op_len);
587
588                let unsafe_cast = intrinsic_name == "cast";
589                let safe_cast = intrinsic_name == "as";
590                let ptr_cast = intrinsic_name == "cast_ptr";
591                let expose_cast = intrinsic_name == "expose_provenance";
592                let from_exposed_cast = intrinsic_name == "with_exposed_provenance";
593
594                for i in 0..dest_len {
595                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
596                    let dest = this.project_index(&dest, i)?;
597
598                    let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) {
599                        // Int-to-(int|float): always safe
600                        (ty::Int(_) | ty::Uint(_), ty::Int(_) | ty::Uint(_) | ty::Float(_))
601                            if safe_cast || unsafe_cast =>
602                            this.int_to_int_or_float(&op, dest.layout)?,
603                        // Float-to-float: always safe
604                        (ty::Float(_), ty::Float(_)) if safe_cast || unsafe_cast =>
605                            this.float_to_float_or_int(&op, dest.layout)?,
606                        // Float-to-int in safe mode
607                        (ty::Float(_), ty::Int(_) | ty::Uint(_)) if safe_cast =>
608                            this.float_to_float_or_int(&op, dest.layout)?,
609                        // Float-to-int in unchecked mode
610                        (ty::Float(_), ty::Int(_) | ty::Uint(_)) if unsafe_cast => {
611                            this.float_to_int_checked(&op, dest.layout, Round::TowardZero)?
612                                .ok_or_else(|| {
613                                    err_ub_format!(
614                                        "`simd_cast` intrinsic called on {op} which cannot be represented in target type `{:?}`",
615                                        dest.layout.ty
616                                    )
617                                })?
618                        }
619                        // Ptr-to-ptr cast
620                        (ty::RawPtr(..), ty::RawPtr(..)) if ptr_cast =>
621                            this.ptr_to_ptr(&op, dest.layout)?,
622                        // Ptr/Int casts
623                        (ty::RawPtr(..), ty::Int(_) | ty::Uint(_)) if expose_cast =>
624                            this.pointer_expose_provenance_cast(&op, dest.layout)?,
625                        (ty::Int(_) | ty::Uint(_), ty::RawPtr(..)) if from_exposed_cast =>
626                            this.pointer_with_exposed_provenance_cast(&op, dest.layout)?,
627                        // Error otherwise
628                        _ =>
629                            throw_unsup_format!(
630                                "Unsupported SIMD cast from element type {from_ty} to {to_ty}",
631                                from_ty = op.layout.ty,
632                                to_ty = dest.layout.ty,
633                            ),
634                    };
635                    this.write_immediate(*val, &dest)?;
636                }
637            }
638            "shuffle_const_generic" => {
639                let [left, right] = check_intrinsic_arg_count(args)?;
640                let (left, left_len) = this.project_to_simd(left)?;
641                let (right, right_len) = this.project_to_simd(right)?;
642                let (dest, dest_len) = this.project_to_simd(dest)?;
643
644                let index = generic_args[2].expect_const().to_value().valtree.unwrap_branch();
645                let index_len = index.len();
646
647                assert_eq!(left_len, right_len);
648                assert_eq!(index_len as u64, dest_len);
649
650                for i in 0..dest_len {
651                    let src_index: u64 =
652                        index[usize::try_from(i).unwrap()].unwrap_leaf().to_u32().into();
653                    let dest = this.project_index(&dest, i)?;
654
655                    let val = if src_index < left_len {
656                        this.read_immediate(&this.project_index(&left, src_index)?)?
657                    } else if src_index < left_len.strict_add(right_len) {
658                        let right_idx = src_index.strict_sub(left_len);
659                        this.read_immediate(&this.project_index(&right, right_idx)?)?
660                    } else {
661                        throw_ub_format!(
662                            "`simd_shuffle_const_generic` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}"
663                        );
664                    };
665                    this.write_immediate(*val, &dest)?;
666                }
667            }
668            "shuffle" => {
669                let [left, right, index] = check_intrinsic_arg_count(args)?;
670                let (left, left_len) = this.project_to_simd(left)?;
671                let (right, right_len) = this.project_to_simd(right)?;
672                let (index, index_len) = this.project_to_simd(index)?;
673                let (dest, dest_len) = this.project_to_simd(dest)?;
674
675                assert_eq!(left_len, right_len);
676                assert_eq!(index_len, dest_len);
677
678                for i in 0..dest_len {
679                    let src_index: u64 = this
680                        .read_immediate(&this.project_index(&index, i)?)?
681                        .to_scalar()
682                        .to_u32()?
683                        .into();
684                    let dest = this.project_index(&dest, i)?;
685
686                    let val = if src_index < left_len {
687                        this.read_immediate(&this.project_index(&left, src_index)?)?
688                    } else if src_index < left_len.strict_add(right_len) {
689                        let right_idx = src_index.strict_sub(left_len);
690                        this.read_immediate(&this.project_index(&right, right_idx)?)?
691                    } else {
692                        throw_ub_format!(
693                            "`simd_shuffle` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}"
694                        );
695                    };
696                    this.write_immediate(*val, &dest)?;
697                }
698            }
699            "gather" => {
700                let [passthru, ptrs, mask] = check_intrinsic_arg_count(args)?;
701                let (passthru, passthru_len) = this.project_to_simd(passthru)?;
702                let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
703                let (mask, mask_len) = this.project_to_simd(mask)?;
704                let (dest, dest_len) = this.project_to_simd(dest)?;
705
706                assert_eq!(dest_len, passthru_len);
707                assert_eq!(dest_len, ptrs_len);
708                assert_eq!(dest_len, mask_len);
709
710                for i in 0..dest_len {
711                    let passthru = this.read_immediate(&this.project_index(&passthru, i)?)?;
712                    let ptr = this.read_immediate(&this.project_index(&ptrs, i)?)?;
713                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
714                    let dest = this.project_index(&dest, i)?;
715
716                    let val = if simd_element_to_bool(mask)? {
717                        let place = this.deref_pointer(&ptr)?;
718                        this.read_immediate(&place)?
719                    } else {
720                        passthru
721                    };
722                    this.write_immediate(*val, &dest)?;
723                }
724            }
725            "scatter" => {
726                let [value, ptrs, mask] = check_intrinsic_arg_count(args)?;
727                let (value, value_len) = this.project_to_simd(value)?;
728                let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
729                let (mask, mask_len) = this.project_to_simd(mask)?;
730
731                assert_eq!(ptrs_len, value_len);
732                assert_eq!(ptrs_len, mask_len);
733
734                for i in 0..ptrs_len {
735                    let value = this.read_immediate(&this.project_index(&value, i)?)?;
736                    let ptr = this.read_immediate(&this.project_index(&ptrs, i)?)?;
737                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
738
739                    if simd_element_to_bool(mask)? {
740                        let place = this.deref_pointer(&ptr)?;
741                        this.write_immediate(*value, &place)?;
742                    }
743                }
744            }
745            "masked_load" => {
746                let [mask, ptr, default] = check_intrinsic_arg_count(args)?;
747                let (mask, mask_len) = this.project_to_simd(mask)?;
748                let ptr = this.read_pointer(ptr)?;
749                let (default, default_len) = this.project_to_simd(default)?;
750                let (dest, dest_len) = this.project_to_simd(dest)?;
751
752                assert_eq!(dest_len, mask_len);
753                assert_eq!(dest_len, default_len);
754
755                for i in 0..dest_len {
756                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
757                    let default = this.read_immediate(&this.project_index(&default, i)?)?;
758                    let dest = this.project_index(&dest, i)?;
759
760                    let val = if simd_element_to_bool(mask)? {
761                        // Size * u64 is implemented as always checked
762                        let ptr = ptr.wrapping_offset(dest.layout.size * i, this);
763                        let place = this.ptr_to_mplace(ptr, dest.layout);
764                        this.read_immediate(&place)?
765                    } else {
766                        default
767                    };
768                    this.write_immediate(*val, &dest)?;
769                }
770            }
771            "masked_store" => {
772                let [mask, ptr, vals] = check_intrinsic_arg_count(args)?;
773                let (mask, mask_len) = this.project_to_simd(mask)?;
774                let ptr = this.read_pointer(ptr)?;
775                let (vals, vals_len) = this.project_to_simd(vals)?;
776
777                assert_eq!(mask_len, vals_len);
778
779                for i in 0..vals_len {
780                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
781                    let val = this.read_immediate(&this.project_index(&vals, i)?)?;
782
783                    if simd_element_to_bool(mask)? {
784                        // Size * u64 is implemented as always checked
785                        let ptr = ptr.wrapping_offset(val.layout.size * i, this);
786                        let place = this.ptr_to_mplace(ptr, val.layout);
787                        this.write_immediate(*val, &place)?
788                    };
789                }
790            }
791
792            _ => return interp_ok(EmulateItemResult::NotSupported),
793        }
794        interp_ok(EmulateItemResult::NeedsReturn)
795    }
796
797    fn fminmax_op(
798        &self,
799        op: MinMax,
800        left: &ImmTy<'tcx>,
801        right: &ImmTy<'tcx>,
802    ) -> InterpResult<'tcx, Scalar> {
803        let this = self.eval_context_ref();
804        assert_eq!(left.layout.ty, right.layout.ty);
805        let ty::Float(float_ty) = left.layout.ty.kind() else {
806            bug!("fmax operand is not a float")
807        };
808        let left = left.to_scalar();
809        let right = right.to_scalar();
810        interp_ok(match float_ty {
811            FloatTy::F16 => unimplemented!("f16_f128"),
812            FloatTy::F32 => {
813                let left = left.to_f32()?;
814                let right = right.to_f32()?;
815                let res = match op {
816                    MinMax::Min => left.min(right),
817                    MinMax::Max => left.max(right),
818                };
819                let res = this.adjust_nan(res, &[left, right]);
820                Scalar::from_f32(res)
821            }
822            FloatTy::F64 => {
823                let left = left.to_f64()?;
824                let right = right.to_f64()?;
825                let res = match op {
826                    MinMax::Min => left.min(right),
827                    MinMax::Max => left.max(right),
828                };
829                let res = this.adjust_nan(res, &[left, right]);
830                Scalar::from_f64(res)
831            }
832            FloatTy::F128 => unimplemented!("f16_f128"),
833        })
834    }
835}
836
837fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
838    assert!(idx < vec_len);
839    match endianness {
840        Endian::Little => idx,
841        #[expect(clippy::arithmetic_side_effects)] // idx < vec_len
842        Endian::Big => vec_len - 1 - idx, // reverse order of bits
843    }
844}