miri/intrinsics/
simd.rs

1use either::Either;
2use rand::Rng;
3use rustc_abi::{Endian, HasDataLayout};
4use rustc_apfloat::{Float, Round};
5use rustc_middle::ty::FloatTy;
6use rustc_middle::ty::layout::LayoutOf;
7use rustc_middle::{mir, ty};
8use rustc_span::{Symbol, sym};
9
10use crate::helpers::{ToHost, ToSoft, bool_to_simd_element, check_arg_count, simd_element_to_bool};
11use crate::*;
12
/// Selects which of the two extremum operations a min/max helper should compute.
///
/// In this file it drives the lane-wise `fmin`/`fmax` intrinsics as well as the
/// `reduce_min`/`reduce_max` reductions.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum MinMax {
    /// Keep the smaller of the two operands.
    Min,
    /// Keep the larger of the two operands.
    Max,
}
18
19impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
20pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
    /// Calls the simd intrinsic `intrinsic_name`; the `simd_` prefix has already been removed.
    /// On success, the returned `EmulateItemResult` indicates whether the intrinsic was handled.
23    fn emulate_simd_intrinsic(
24        &mut self,
25        intrinsic_name: &str,
26        generic_args: ty::GenericArgsRef<'tcx>,
27        args: &[OpTy<'tcx>],
28        dest: &MPlaceTy<'tcx>,
29    ) -> InterpResult<'tcx, EmulateItemResult> {
30        let this = self.eval_context_mut();
31        match intrinsic_name {
32            #[rustfmt::skip]
33            | "neg"
34            | "fabs"
35            | "ceil"
36            | "floor"
37            | "round"
38            | "trunc"
39            | "fsqrt"
40            | "fsin"
41            | "fcos"
42            | "fexp"
43            | "fexp2"
44            | "flog"
45            | "flog2"
46            | "flog10"
47            | "ctlz"
48            | "ctpop"
49            | "cttz"
50            | "bswap"
51            | "bitreverse"
52            => {
53                let [op] = check_arg_count(args)?;
54                let (op, op_len) = this.project_to_simd(op)?;
55                let (dest, dest_len) = this.project_to_simd(dest)?;
56
57                assert_eq!(dest_len, op_len);
58
59                #[derive(Copy, Clone)]
60                enum Op<'a> {
61                    MirOp(mir::UnOp),
62                    Abs,
63                    Round(rustc_apfloat::Round),
64                    Numeric(Symbol),
65                    HostOp(&'a str),
66                }
67                let which = match intrinsic_name {
68                    "neg" => Op::MirOp(mir::UnOp::Neg),
69                    "fabs" => Op::Abs,
70                    "ceil" => Op::Round(rustc_apfloat::Round::TowardPositive),
71                    "floor" => Op::Round(rustc_apfloat::Round::TowardNegative),
72                    "round" => Op::Round(rustc_apfloat::Round::NearestTiesToAway),
73                    "trunc" => Op::Round(rustc_apfloat::Round::TowardZero),
74                    "ctlz" => Op::Numeric(sym::ctlz),
75                    "ctpop" => Op::Numeric(sym::ctpop),
76                    "cttz" => Op::Numeric(sym::cttz),
77                    "bswap" => Op::Numeric(sym::bswap),
78                    "bitreverse" => Op::Numeric(sym::bitreverse),
79                    _ => Op::HostOp(intrinsic_name),
80                };
81
82                for i in 0..dest_len {
83                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
84                    let dest = this.project_index(&dest, i)?;
85                    let val = match which {
86                        Op::MirOp(mir_op) => {
87                            // This already does NaN adjustments
88                            this.unary_op(mir_op, &op)?.to_scalar()
89                        }
90                        Op::Abs => {
91                            // Works for f32 and f64.
92                            let ty::Float(float_ty) = op.layout.ty.kind() else {
93                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
94                            };
95                            let op = op.to_scalar();
96                            // "Bitwise" operation, no NaN adjustments
97                            match float_ty {
98                                FloatTy::F16 => unimplemented!("f16_f128"),
99                                FloatTy::F32 => Scalar::from_f32(op.to_f32()?.abs()),
100                                FloatTy::F64 => Scalar::from_f64(op.to_f64()?.abs()),
101                                FloatTy::F128 => unimplemented!("f16_f128"),
102                            }
103                        }
104                        Op::HostOp(host_op) => {
105                            let ty::Float(float_ty) = op.layout.ty.kind() else {
106                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
107                            };
108                            // Using host floats except for sqrt (but it's fine, these operations do not
109                            // have guaranteed precision).
110                            match float_ty {
111                                FloatTy::F16 => unimplemented!("f16_f128"),
112                                FloatTy::F32 => {
113                                    let f = op.to_scalar().to_f32()?;
114                                    let res = match host_op {
115                                        "fsqrt" => math::sqrt(f),
116                                        "fsin" => f.to_host().sin().to_soft(),
117                                        "fcos" => f.to_host().cos().to_soft(),
118                                        "fexp" => f.to_host().exp().to_soft(),
119                                        "fexp2" => f.to_host().exp2().to_soft(),
120                                        "flog" => f.to_host().ln().to_soft(),
121                                        "flog2" => f.to_host().log2().to_soft(),
122                                        "flog10" => f.to_host().log10().to_soft(),
123                                        _ => bug!(),
124                                    };
125                                    let res = this.adjust_nan(res, &[f]);
126                                    Scalar::from(res)
127                                }
128                                FloatTy::F64 => {
129                                    let f = op.to_scalar().to_f64()?;
130                                    let res = match host_op {
131                                        "fsqrt" => math::sqrt(f),
132                                        "fsin" => f.to_host().sin().to_soft(),
133                                        "fcos" => f.to_host().cos().to_soft(),
134                                        "fexp" => f.to_host().exp().to_soft(),
135                                        "fexp2" => f.to_host().exp2().to_soft(),
136                                        "flog" => f.to_host().ln().to_soft(),
137                                        "flog2" => f.to_host().log2().to_soft(),
138                                        "flog10" => f.to_host().log10().to_soft(),
139                                        _ => bug!(),
140                                    };
141                                    let res = this.adjust_nan(res, &[f]);
142                                    Scalar::from(res)
143                                }
144                                FloatTy::F128 => unimplemented!("f16_f128"),
145                            }
146                        }
147                        Op::Round(rounding) => {
148                            let ty::Float(float_ty) = op.layout.ty.kind() else {
149                                span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
150                            };
151                            match float_ty {
152                                FloatTy::F16 => unimplemented!("f16_f128"),
153                                FloatTy::F32 => {
154                                    let f = op.to_scalar().to_f32()?;
155                                    let res = f.round_to_integral(rounding).value;
156                                    let res = this.adjust_nan(res, &[f]);
157                                    Scalar::from_f32(res)
158                                }
159                                FloatTy::F64 => {
160                                    let f = op.to_scalar().to_f64()?;
161                                    let res = f.round_to_integral(rounding).value;
162                                    let res = this.adjust_nan(res, &[f]);
163                                    Scalar::from_f64(res)
164                                }
165                                FloatTy::F128 => unimplemented!("f16_f128"),
166                            }
167                        }
168                        Op::Numeric(name) => {
169                            this.numeric_intrinsic(name, op.to_scalar(), op.layout, op.layout)?
170                        }
171                    };
172                    this.write_scalar(val, &dest)?;
173                }
174            }
175            #[rustfmt::skip]
176            | "add"
177            | "sub"
178            | "mul"
179            | "div"
180            | "rem"
181            | "shl"
182            | "shr"
183            | "and"
184            | "or"
185            | "xor"
186            | "eq"
187            | "ne"
188            | "lt"
189            | "le"
190            | "gt"
191            | "ge"
192            | "fmax"
193            | "fmin"
194            | "saturating_add"
195            | "saturating_sub"
196            | "arith_offset"
197            => {
198                use mir::BinOp;
199
200                let [left, right] = check_arg_count(args)?;
201                let (left, left_len) = this.project_to_simd(left)?;
202                let (right, right_len) = this.project_to_simd(right)?;
203                let (dest, dest_len) = this.project_to_simd(dest)?;
204
205                assert_eq!(dest_len, left_len);
206                assert_eq!(dest_len, right_len);
207
208                enum Op {
209                    MirOp(BinOp),
210                    SaturatingOp(BinOp),
211                    FMinMax(MinMax),
212                    WrappingOffset,
213                }
214                let which = match intrinsic_name {
215                    "add" => Op::MirOp(BinOp::Add),
216                    "sub" => Op::MirOp(BinOp::Sub),
217                    "mul" => Op::MirOp(BinOp::Mul),
218                    "div" => Op::MirOp(BinOp::Div),
219                    "rem" => Op::MirOp(BinOp::Rem),
220                    "shl" => Op::MirOp(BinOp::ShlUnchecked),
221                    "shr" => Op::MirOp(BinOp::ShrUnchecked),
222                    "and" => Op::MirOp(BinOp::BitAnd),
223                    "or" => Op::MirOp(BinOp::BitOr),
224                    "xor" => Op::MirOp(BinOp::BitXor),
225                    "eq" => Op::MirOp(BinOp::Eq),
226                    "ne" => Op::MirOp(BinOp::Ne),
227                    "lt" => Op::MirOp(BinOp::Lt),
228                    "le" => Op::MirOp(BinOp::Le),
229                    "gt" => Op::MirOp(BinOp::Gt),
230                    "ge" => Op::MirOp(BinOp::Ge),
231                    "fmax" => Op::FMinMax(MinMax::Max),
232                    "fmin" => Op::FMinMax(MinMax::Min),
233                    "saturating_add" => Op::SaturatingOp(BinOp::Add),
234                    "saturating_sub" => Op::SaturatingOp(BinOp::Sub),
235                    "arith_offset" => Op::WrappingOffset,
236                    _ => unreachable!(),
237                };
238
239                for i in 0..dest_len {
240                    let left = this.read_immediate(&this.project_index(&left, i)?)?;
241                    let right = this.read_immediate(&this.project_index(&right, i)?)?;
242                    let dest = this.project_index(&dest, i)?;
243                    let val = match which {
244                        Op::MirOp(mir_op) => {
245                            // This does NaN adjustments.
246                            let val = this.binary_op(mir_op, &left, &right).map_err_kind(|kind| {
247                                match kind {
248                                    InterpErrorKind::UndefinedBehavior(UndefinedBehaviorInfo::ShiftOverflow { shift_amount, .. }) => {
249                                        // This resets the interpreter backtrace, but it's not worth avoiding that.
250                                        let shift_amount = match shift_amount {
251                                            Either::Left(v) => v.to_string(),
252                                            Either::Right(v) => v.to_string(),
253                                        };
254                                        err_ub_format!("overflowing shift by {shift_amount} in `simd_{intrinsic_name}` in lane {i}")
255                                    }
256                                    kind => kind
257                                }
258                            })?;
259                            if matches!(mir_op, BinOp::Eq | BinOp::Ne | BinOp::Lt | BinOp::Le | BinOp::Gt | BinOp::Ge) {
260                                // Special handling for boolean-returning operations
261                                assert_eq!(val.layout.ty, this.tcx.types.bool);
262                                let val = val.to_scalar().to_bool().unwrap();
263                                bool_to_simd_element(val, dest.layout.size)
264                            } else {
265                                assert_ne!(val.layout.ty, this.tcx.types.bool);
266                                assert_eq!(val.layout.ty, dest.layout.ty);
267                                val.to_scalar()
268                            }
269                        }
270                        Op::SaturatingOp(mir_op) => {
271                            this.saturating_arith(mir_op, &left, &right)?
272                        }
273                        Op::WrappingOffset => {
274                            let ptr = left.to_scalar().to_pointer(this)?;
275                            let offset_count = right.to_scalar().to_target_isize(this)?;
276                            let pointee_ty = left.layout.ty.builtin_deref(true).unwrap();
277
278                            let pointee_size = i64::try_from(this.layout_of(pointee_ty)?.size.bytes()).unwrap();
279                            let offset_bytes = offset_count.wrapping_mul(pointee_size);
280                            let offset_ptr = ptr.wrapping_signed_offset(offset_bytes, this);
281                            Scalar::from_maybe_pointer(offset_ptr, this)
282                        }
283                        Op::FMinMax(op) => {
284                            this.fminmax_op(op, &left, &right)?
285                        }
286                    };
287                    this.write_scalar(val, &dest)?;
288                }
289            }
290            "fma" | "relaxed_fma" => {
291                let [a, b, c] = check_arg_count(args)?;
292                let (a, a_len) = this.project_to_simd(a)?;
293                let (b, b_len) = this.project_to_simd(b)?;
294                let (c, c_len) = this.project_to_simd(c)?;
295                let (dest, dest_len) = this.project_to_simd(dest)?;
296
297                assert_eq!(dest_len, a_len);
298                assert_eq!(dest_len, b_len);
299                assert_eq!(dest_len, c_len);
300
301                for i in 0..dest_len {
302                    let a = this.read_scalar(&this.project_index(&a, i)?)?;
303                    let b = this.read_scalar(&this.project_index(&b, i)?)?;
304                    let c = this.read_scalar(&this.project_index(&c, i)?)?;
305                    let dest = this.project_index(&dest, i)?;
306
307                    let fuse: bool = intrinsic_name == "fma" || this.machine.rng.get_mut().random();
308
309                    // Works for f32 and f64.
310                    // FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
311                    let ty::Float(float_ty) = dest.layout.ty.kind() else {
312                        span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
313                    };
314                    let val = match float_ty {
315                        FloatTy::F16 => unimplemented!("f16_f128"),
316                        FloatTy::F32 => {
317                            let a = a.to_f32()?;
318                            let b = b.to_f32()?;
319                            let c = c.to_f32()?;
320                            let res = if fuse {
321                                a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
322                            } else {
323                                ((a * b).value + c).value
324                            };
325                            let res = this.adjust_nan(res, &[a, b, c]);
326                            Scalar::from(res)
327                        }
328                        FloatTy::F64 => {
329                            let a = a.to_f64()?;
330                            let b = b.to_f64()?;
331                            let c = c.to_f64()?;
332                            let res = if fuse {
333                                a.to_host().mul_add(b.to_host(), c.to_host()).to_soft()
334                            } else {
335                                ((a * b).value + c).value
336                            };
337                            let res = this.adjust_nan(res, &[a, b, c]);
338                            Scalar::from(res)
339                        }
340                        FloatTy::F128 => unimplemented!("f16_f128"),
341                    };
342                    this.write_scalar(val, &dest)?;
343                }
344            }
345            #[rustfmt::skip]
346            | "reduce_and"
347            | "reduce_or"
348            | "reduce_xor"
349            | "reduce_any"
350            | "reduce_all"
351            | "reduce_max"
352            | "reduce_min" => {
353                use mir::BinOp;
354
355                let [op] = check_arg_count(args)?;
356                let (op, op_len) = this.project_to_simd(op)?;
357
358                let imm_from_bool =
359                    |b| ImmTy::from_scalar(Scalar::from_bool(b), this.machine.layouts.bool);
360
361                enum Op {
362                    MirOp(BinOp),
363                    MirOpBool(BinOp),
364                    MinMax(MinMax),
365                }
366                let which = match intrinsic_name {
367                    "reduce_and" => Op::MirOp(BinOp::BitAnd),
368                    "reduce_or" => Op::MirOp(BinOp::BitOr),
369                    "reduce_xor" => Op::MirOp(BinOp::BitXor),
370                    "reduce_any" => Op::MirOpBool(BinOp::BitOr),
371                    "reduce_all" => Op::MirOpBool(BinOp::BitAnd),
372                    "reduce_max" => Op::MinMax(MinMax::Max),
373                    "reduce_min" => Op::MinMax(MinMax::Min),
374                    _ => unreachable!(),
375                };
376
377                // Initialize with first lane, then proceed with the rest.
378                let mut res = this.read_immediate(&this.project_index(&op, 0)?)?;
379                if matches!(which, Op::MirOpBool(_)) {
380                    // Convert to `bool` scalar.
381                    res = imm_from_bool(simd_element_to_bool(res)?);
382                }
383                for i in 1..op_len {
384                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
385                    res = match which {
386                        Op::MirOp(mir_op) => {
387                            this.binary_op(mir_op, &res, &op)?
388                        }
389                        Op::MirOpBool(mir_op) => {
390                            let op = imm_from_bool(simd_element_to_bool(op)?);
391                            this.binary_op(mir_op, &res, &op)?
392                        }
393                        Op::MinMax(mmop) => {
394                            if matches!(res.layout.ty.kind(), ty::Float(_)) {
395                                ImmTy::from_scalar(this.fminmax_op(mmop, &res, &op)?, res.layout)
396                            } else {
                                // Just boring integers, so no NaNs to worry about
398                                let mirop = match mmop {
399                                    MinMax::Min => BinOp::Le,
400                                    MinMax::Max => BinOp::Ge,
401                                };
402                                if this.binary_op(mirop, &res, &op)?.to_scalar().to_bool()? {
403                                    res
404                                } else {
405                                    op
406                                }
407                            }
408                        }
409                    };
410                }
411                this.write_immediate(*res, dest)?;
412            }
413            #[rustfmt::skip]
414            | "reduce_add_ordered"
415            | "reduce_mul_ordered" => {
416                use mir::BinOp;
417
418                let [op, init] = check_arg_count(args)?;
419                let (op, op_len) = this.project_to_simd(op)?;
420                let init = this.read_immediate(init)?;
421
422                let mir_op = match intrinsic_name {
423                    "reduce_add_ordered" => BinOp::Add,
424                    "reduce_mul_ordered" => BinOp::Mul,
425                    _ => unreachable!(),
426                };
427
428                let mut res = init;
429                for i in 0..op_len {
430                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
431                    res = this.binary_op(mir_op, &res, &op)?;
432                }
433                this.write_immediate(*res, dest)?;
434            }
435            "select" => {
436                let [mask, yes, no] = check_arg_count(args)?;
437                let (mask, mask_len) = this.project_to_simd(mask)?;
438                let (yes, yes_len) = this.project_to_simd(yes)?;
439                let (no, no_len) = this.project_to_simd(no)?;
440                let (dest, dest_len) = this.project_to_simd(dest)?;
441
442                assert_eq!(dest_len, mask_len);
443                assert_eq!(dest_len, yes_len);
444                assert_eq!(dest_len, no_len);
445
446                for i in 0..dest_len {
447                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
448                    let yes = this.read_immediate(&this.project_index(&yes, i)?)?;
449                    let no = this.read_immediate(&this.project_index(&no, i)?)?;
450                    let dest = this.project_index(&dest, i)?;
451
452                    let val = if simd_element_to_bool(mask)? { yes } else { no };
453                    this.write_immediate(*val, &dest)?;
454                }
455            }
456            // Variant of `select` that takes a bitmask rather than a "vector of bool".
457            "select_bitmask" => {
458                let [mask, yes, no] = check_arg_count(args)?;
459                let (yes, yes_len) = this.project_to_simd(yes)?;
460                let (no, no_len) = this.project_to_simd(no)?;
461                let (dest, dest_len) = this.project_to_simd(dest)?;
462                let bitmask_len = dest_len.next_multiple_of(8);
463                if bitmask_len > 64 {
464                    throw_unsup_format!(
465                        "simd_select_bitmask: vectors larger than 64 elements are currently not supported"
466                    );
467                }
468
469                assert_eq!(dest_len, yes_len);
470                assert_eq!(dest_len, no_len);
471
472                // Read the mask, either as an integer or as an array.
473                let mask: u64 = match mask.layout.ty.kind() {
474                    ty::Uint(_) => {
475                        // Any larger integer type is fine.
476                        assert!(mask.layout.size.bits() >= bitmask_len);
477                        this.read_scalar(mask)?.to_bits(mask.layout.size)?.try_into().unwrap()
478                    }
479                    ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
480                        // The array must have exactly the right size.
481                        assert_eq!(mask.layout.size.bits(), bitmask_len);
482                        // Read the raw bytes.
483                        let mask = mask.assert_mem_place(); // arrays cannot be immediate
484                        let mask_bytes =
485                            this.read_bytes_ptr_strip_provenance(mask.ptr(), mask.layout.size)?;
486                        // Turn them into a `u64` in the right way.
487                        let mask_size = mask.layout.size.bytes_usize();
488                        let mut mask_arr = [0u8; 8];
489                        match this.data_layout().endian {
490                            Endian::Little => {
491                                // Fill the first N bytes.
492                                mask_arr[..mask_size].copy_from_slice(mask_bytes);
493                                u64::from_le_bytes(mask_arr)
494                            }
495                            Endian::Big => {
496                                // Fill the last N bytes.
497                                let i = mask_arr.len().strict_sub(mask_size);
498                                mask_arr[i..].copy_from_slice(mask_bytes);
499                                u64::from_be_bytes(mask_arr)
500                            }
501                        }
502                    }
503                    _ => bug!("simd_select_bitmask: invalid mask type {}", mask.layout.ty),
504                };
505
506                let dest_len = u32::try_from(dest_len).unwrap();
507                let bitmask_len = u32::try_from(bitmask_len).unwrap();
508                for i in 0..dest_len {
509                    let bit_i = simd_bitmask_index(i, dest_len, this.data_layout().endian);
510                    let mask = mask & 1u64.strict_shl(bit_i);
511                    let yes = this.read_immediate(&this.project_index(&yes, i.into())?)?;
512                    let no = this.read_immediate(&this.project_index(&no, i.into())?)?;
513                    let dest = this.project_index(&dest, i.into())?;
514
515                    let val = if mask != 0 { yes } else { no };
516                    this.write_immediate(*val, &dest)?;
517                }
518                for i in dest_len..bitmask_len {
519                    // If the mask is "padded", ensure that padding is all-zero.
520                    // This deliberately does not use `simd_bitmask_index`; these bits are outside
521                    // the bitmask. It does not matter in which order we check them.
522                    let mask = mask & 1u64.strict_shl(i);
523                    if mask != 0 {
524                        throw_ub_format!(
525                            "a SIMD bitmask less than 8 bits long must be filled with 0s for the remaining bits"
526                        );
527                    }
528                }
529            }
530            // Converts a "vector of bool" into a bitmask.
531            "bitmask" => {
532                let [op] = check_arg_count(args)?;
533                let (op, op_len) = this.project_to_simd(op)?;
534                let bitmask_len = op_len.next_multiple_of(8);
535                if bitmask_len > 64 {
536                    throw_unsup_format!(
537                        "simd_bitmask: vectors larger than 64 elements are currently not supported"
538                    );
539                }
540
541                let op_len = u32::try_from(op_len).unwrap();
542                let mut res = 0u64;
543                for i in 0..op_len {
544                    let op = this.read_immediate(&this.project_index(&op, i.into())?)?;
545                    if simd_element_to_bool(op)? {
546                        let bit_i = simd_bitmask_index(i, op_len, this.data_layout().endian);
547                        res |= 1u64.strict_shl(bit_i);
548                    }
549                }
550                // Write the result, depending on the `dest` type.
551                // Returns either an unsigned integer or array of `u8`.
552                match dest.layout.ty.kind() {
553                    ty::Uint(_) => {
554                        // Any larger integer type is fine, it will be zero-extended.
555                        assert!(dest.layout.size.bits() >= bitmask_len);
556                        this.write_int(res, dest)?;
557                    }
558                    ty::Array(elem, _len) if elem == &this.tcx.types.u8 => {
559                        // The array must have exactly the right size.
560                        assert_eq!(dest.layout.size.bits(), bitmask_len);
561                        // We have to write the result byte-for-byte.
562                        let res_size = dest.layout.size.bytes_usize();
563                        let res_bytes;
564                        let res_bytes_slice = match this.data_layout().endian {
565                            Endian::Little => {
566                                res_bytes = res.to_le_bytes();
567                                &res_bytes[..res_size] // take the first N bytes
568                            }
569                            Endian::Big => {
570                                res_bytes = res.to_be_bytes();
571                                &res_bytes[res_bytes.len().strict_sub(res_size)..] // take the last N bytes
572                            }
573                        };
574                        this.write_bytes_ptr(dest.ptr(), res_bytes_slice.iter().cloned())?;
575                    }
576                    _ => bug!("simd_bitmask: invalid return type {}", dest.layout.ty),
577                }
578            }
579            "cast" | "as" | "cast_ptr" | "expose_provenance" | "with_exposed_provenance" => {
580                let [op] = check_arg_count(args)?;
581                let (op, op_len) = this.project_to_simd(op)?;
582                let (dest, dest_len) = this.project_to_simd(dest)?;
583
584                assert_eq!(dest_len, op_len);
585
586                let unsafe_cast = intrinsic_name == "cast";
587                let safe_cast = intrinsic_name == "as";
588                let ptr_cast = intrinsic_name == "cast_ptr";
589                let expose_cast = intrinsic_name == "expose_provenance";
590                let from_exposed_cast = intrinsic_name == "with_exposed_provenance";
591
592                for i in 0..dest_len {
593                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
594                    let dest = this.project_index(&dest, i)?;
595
596                    let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) {
597                        // Int-to-(int|float): always safe
598                        (ty::Int(_) | ty::Uint(_), ty::Int(_) | ty::Uint(_) | ty::Float(_))
599                            if safe_cast || unsafe_cast =>
600                            this.int_to_int_or_float(&op, dest.layout)?,
601                        // Float-to-float: always safe
602                        (ty::Float(_), ty::Float(_)) if safe_cast || unsafe_cast =>
603                            this.float_to_float_or_int(&op, dest.layout)?,
604                        // Float-to-int in safe mode
605                        (ty::Float(_), ty::Int(_) | ty::Uint(_)) if safe_cast =>
606                            this.float_to_float_or_int(&op, dest.layout)?,
607                        // Float-to-int in unchecked mode
608                        (ty::Float(_), ty::Int(_) | ty::Uint(_)) if unsafe_cast => {
609                            this.float_to_int_checked(&op, dest.layout, Round::TowardZero)?
610                                .ok_or_else(|| {
611                                    err_ub_format!(
612                                        "`simd_cast` intrinsic called on {op} which cannot be represented in target type `{:?}`",
613                                        dest.layout.ty
614                                    )
615                                })?
616                        }
617                        // Ptr-to-ptr cast
618                        (ty::RawPtr(..), ty::RawPtr(..)) if ptr_cast =>
619                            this.ptr_to_ptr(&op, dest.layout)?,
620                        // Ptr/Int casts
621                        (ty::RawPtr(..), ty::Int(_) | ty::Uint(_)) if expose_cast =>
622                            this.pointer_expose_provenance_cast(&op, dest.layout)?,
623                        (ty::Int(_) | ty::Uint(_), ty::RawPtr(..)) if from_exposed_cast =>
624                            this.pointer_with_exposed_provenance_cast(&op, dest.layout)?,
625                        // Error otherwise
626                        _ =>
627                            throw_unsup_format!(
628                                "Unsupported SIMD cast from element type {from_ty} to {to_ty}",
629                                from_ty = op.layout.ty,
630                                to_ty = dest.layout.ty,
631                            ),
632                    };
633                    this.write_immediate(*val, &dest)?;
634                }
635            }
636            "shuffle_generic" => {
637                let [left, right] = check_arg_count(args)?;
638                let (left, left_len) = this.project_to_simd(left)?;
639                let (right, right_len) = this.project_to_simd(right)?;
640                let (dest, dest_len) = this.project_to_simd(dest)?;
641
642                let index = generic_args[2].expect_const().to_value().valtree.unwrap_branch();
643                let index_len = index.len();
644
645                assert_eq!(left_len, right_len);
646                assert_eq!(index_len as u64, dest_len);
647
648                for i in 0..dest_len {
649                    let src_index: u64 =
650                        index[usize::try_from(i).unwrap()].unwrap_leaf().to_u32().into();
651                    let dest = this.project_index(&dest, i)?;
652
653                    let val = if src_index < left_len {
654                        this.read_immediate(&this.project_index(&left, src_index)?)?
655                    } else if src_index < left_len.strict_add(right_len) {
656                        let right_idx = src_index.strict_sub(left_len);
657                        this.read_immediate(&this.project_index(&right, right_idx)?)?
658                    } else {
659                        throw_ub_format!(
660                            "`simd_shuffle_generic` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}"
661                        );
662                    };
663                    this.write_immediate(*val, &dest)?;
664                }
665            }
666            "shuffle" => {
667                let [left, right, index] = check_arg_count(args)?;
668                let (left, left_len) = this.project_to_simd(left)?;
669                let (right, right_len) = this.project_to_simd(right)?;
670                let (index, index_len) = this.project_to_simd(index)?;
671                let (dest, dest_len) = this.project_to_simd(dest)?;
672
673                assert_eq!(left_len, right_len);
674                assert_eq!(index_len, dest_len);
675
676                for i in 0..dest_len {
677                    let src_index: u64 = this
678                        .read_immediate(&this.project_index(&index, i)?)?
679                        .to_scalar()
680                        .to_u32()?
681                        .into();
682                    let dest = this.project_index(&dest, i)?;
683
684                    let val = if src_index < left_len {
685                        this.read_immediate(&this.project_index(&left, src_index)?)?
686                    } else if src_index < left_len.strict_add(right_len) {
687                        let right_idx = src_index.strict_sub(left_len);
688                        this.read_immediate(&this.project_index(&right, right_idx)?)?
689                    } else {
690                        throw_ub_format!(
691                            "`simd_shuffle` index {src_index} is out-of-bounds for 2 vectors with length {dest_len}"
692                        );
693                    };
694                    this.write_immediate(*val, &dest)?;
695                }
696            }
697            "gather" => {
698                let [passthru, ptrs, mask] = check_arg_count(args)?;
699                let (passthru, passthru_len) = this.project_to_simd(passthru)?;
700                let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
701                let (mask, mask_len) = this.project_to_simd(mask)?;
702                let (dest, dest_len) = this.project_to_simd(dest)?;
703
704                assert_eq!(dest_len, passthru_len);
705                assert_eq!(dest_len, ptrs_len);
706                assert_eq!(dest_len, mask_len);
707
708                for i in 0..dest_len {
709                    let passthru = this.read_immediate(&this.project_index(&passthru, i)?)?;
710                    let ptr = this.read_immediate(&this.project_index(&ptrs, i)?)?;
711                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
712                    let dest = this.project_index(&dest, i)?;
713
714                    let val = if simd_element_to_bool(mask)? {
715                        let place = this.deref_pointer(&ptr)?;
716                        this.read_immediate(&place)?
717                    } else {
718                        passthru
719                    };
720                    this.write_immediate(*val, &dest)?;
721                }
722            }
723            "scatter" => {
724                let [value, ptrs, mask] = check_arg_count(args)?;
725                let (value, value_len) = this.project_to_simd(value)?;
726                let (ptrs, ptrs_len) = this.project_to_simd(ptrs)?;
727                let (mask, mask_len) = this.project_to_simd(mask)?;
728
729                assert_eq!(ptrs_len, value_len);
730                assert_eq!(ptrs_len, mask_len);
731
732                for i in 0..ptrs_len {
733                    let value = this.read_immediate(&this.project_index(&value, i)?)?;
734                    let ptr = this.read_immediate(&this.project_index(&ptrs, i)?)?;
735                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
736
737                    if simd_element_to_bool(mask)? {
738                        let place = this.deref_pointer(&ptr)?;
739                        this.write_immediate(*value, &place)?;
740                    }
741                }
742            }
743            "masked_load" => {
744                let [mask, ptr, default] = check_arg_count(args)?;
745                let (mask, mask_len) = this.project_to_simd(mask)?;
746                let ptr = this.read_pointer(ptr)?;
747                let (default, default_len) = this.project_to_simd(default)?;
748                let (dest, dest_len) = this.project_to_simd(dest)?;
749
750                assert_eq!(dest_len, mask_len);
751                assert_eq!(dest_len, default_len);
752
753                for i in 0..dest_len {
754                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
755                    let default = this.read_immediate(&this.project_index(&default, i)?)?;
756                    let dest = this.project_index(&dest, i)?;
757
758                    let val = if simd_element_to_bool(mask)? {
759                        // Size * u64 is implemented as always checked
760                        let ptr = ptr.wrapping_offset(dest.layout.size * i, this);
761                        let place = this.ptr_to_mplace(ptr, dest.layout);
762                        this.read_immediate(&place)?
763                    } else {
764                        default
765                    };
766                    this.write_immediate(*val, &dest)?;
767                }
768            }
769            "masked_store" => {
770                let [mask, ptr, vals] = check_arg_count(args)?;
771                let (mask, mask_len) = this.project_to_simd(mask)?;
772                let ptr = this.read_pointer(ptr)?;
773                let (vals, vals_len) = this.project_to_simd(vals)?;
774
775                assert_eq!(mask_len, vals_len);
776
777                for i in 0..vals_len {
778                    let mask = this.read_immediate(&this.project_index(&mask, i)?)?;
779                    let val = this.read_immediate(&this.project_index(&vals, i)?)?;
780
781                    if simd_element_to_bool(mask)? {
782                        // Size * u64 is implemented as always checked
783                        let ptr = ptr.wrapping_offset(val.layout.size * i, this);
784                        let place = this.ptr_to_mplace(ptr, val.layout);
785                        this.write_immediate(*val, &place)?
786                    };
787                }
788            }
789
790            _ => return interp_ok(EmulateItemResult::NotSupported),
791        }
792        interp_ok(EmulateItemResult::NeedsReturn)
793    }
794
795    fn fminmax_op(
796        &self,
797        op: MinMax,
798        left: &ImmTy<'tcx>,
799        right: &ImmTy<'tcx>,
800    ) -> InterpResult<'tcx, Scalar> {
801        let this = self.eval_context_ref();
802        assert_eq!(left.layout.ty, right.layout.ty);
803        let ty::Float(float_ty) = left.layout.ty.kind() else {
804            bug!("fmax operand is not a float")
805        };
806        let left = left.to_scalar();
807        let right = right.to_scalar();
808        interp_ok(match float_ty {
809            FloatTy::F16 => unimplemented!("f16_f128"),
810            FloatTy::F32 => {
811                let left = left.to_f32()?;
812                let right = right.to_f32()?;
813                let res = match op {
814                    MinMax::Min => left.min(right),
815                    MinMax::Max => left.max(right),
816                };
817                let res = this.adjust_nan(res, &[left, right]);
818                Scalar::from_f32(res)
819            }
820            FloatTy::F64 => {
821                let left = left.to_f64()?;
822                let right = right.to_f64()?;
823                let res = match op {
824                    MinMax::Min => left.min(right),
825                    MinMax::Max => left.max(right),
826                };
827                let res = this.adjust_nan(res, &[left, right]);
828                Scalar::from_f64(res)
829            }
830            FloatTy::F128 => unimplemented!("f16_f128"),
831        })
832    }
833}
834
835fn simd_bitmask_index(idx: u32, vec_len: u32, endianness: Endian) -> u32 {
836    assert!(idx < vec_len);
837    match endianness {
838        Endian::Little => idx,
839        #[expect(clippy::arithmetic_side_effects)] // idx < vec_len
840        Endian::Big => vec_len - 1 - idx, // reverse order of bits
841    }
842}