// miri/intrinsics/simd.rs

1use rand::Rng;
2use rustc_apfloat::Float;
3use rustc_middle::ty;
4use rustc_middle::ty::FloatTy;
5
6use super::check_intrinsic_arg_count;
7use crate::helpers::{ToHost, ToSoft};
8use crate::*;
9
10impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
11pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
12    /// Calls the simd intrinsic `intrinsic`; the `simd_` prefix has already been removed.
13    /// Returns `Ok(true)` if the intrinsic was handled.
14    fn emulate_simd_intrinsic(
15        &mut self,
16        intrinsic_name: &str,
17        args: &[OpTy<'tcx>],
18        dest: &MPlaceTy<'tcx>,
19    ) -> InterpResult<'tcx, EmulateItemResult> {
20        let this = self.eval_context_mut();
21        match intrinsic_name {
22            #[rustfmt::skip]
23            | "fsqrt"
24            | "fsin"
25            | "fcos"
26            | "fexp"
27            | "fexp2"
28            | "flog"
29            | "flog2"
30            | "flog10"
31            => {
32                let [op] = check_intrinsic_arg_count(args)?;
33                let (op, op_len) = this.project_to_simd(op)?;
34                let (dest, dest_len) = this.project_to_simd(dest)?;
35
36                assert_eq!(dest_len, op_len);
37
38                for i in 0..dest_len {
39                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
40                    let dest = this.project_index(&dest, i)?;
41                    let ty::Float(float_ty) = op.layout.ty.kind() else {
42                        span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
43                    };
44                    // Using host floats except for sqrt (but it's fine, these operations do not
45                    // have guaranteed precision).
46                    let val = match float_ty {
47                        FloatTy::F16 => unimplemented!("f16_f128"),
48                        FloatTy::F32 => {
49                            let f = op.to_scalar().to_f32()?;
50                            let res = match intrinsic_name {
51                                "fsqrt" => math::sqrt(f),
52                                "fsin" => f.to_host().sin().to_soft(),
53                                "fcos" => f.to_host().cos().to_soft(),
54                                "fexp" => f.to_host().exp().to_soft(),
55                                "fexp2" => f.to_host().exp2().to_soft(),
56                                "flog" => f.to_host().ln().to_soft(),
57                                "flog2" => f.to_host().log2().to_soft(),
58                                "flog10" => f.to_host().log10().to_soft(),
59                                _ => bug!(),
60                            };
61                            let res = this.adjust_nan(res, &[f]);
62                            Scalar::from(res)
63                        }
64                        FloatTy::F64 => {
65                            let f = op.to_scalar().to_f64()?;
66                            let res = match intrinsic_name {
67                                "fsqrt" => math::sqrt(f),
68                                "fsin" => f.to_host().sin().to_soft(),
69                                "fcos" => f.to_host().cos().to_soft(),
70                                "fexp" => f.to_host().exp().to_soft(),
71                                "fexp2" => f.to_host().exp2().to_soft(),
72                                "flog" => f.to_host().ln().to_soft(),
73                                "flog2" => f.to_host().log2().to_soft(),
74                                "flog10" => f.to_host().log10().to_soft(),
75                                _ => bug!(),
76                            };
77                            let res = this.adjust_nan(res, &[f]);
78                            Scalar::from(res)
79                        }
80                        FloatTy::F128 => unimplemented!("f16_f128"),
81                    };
82
83                    this.write_scalar(val, &dest)?;
84                }
85            }
86            "fma" | "relaxed_fma" => {
87                let [a, b, c] = check_intrinsic_arg_count(args)?;
88                let (a, a_len) = this.project_to_simd(a)?;
89                let (b, b_len) = this.project_to_simd(b)?;
90                let (c, c_len) = this.project_to_simd(c)?;
91                let (dest, dest_len) = this.project_to_simd(dest)?;
92
93                assert_eq!(dest_len, a_len);
94                assert_eq!(dest_len, b_len);
95                assert_eq!(dest_len, c_len);
96
97                for i in 0..dest_len {
98                    let a = this.read_scalar(&this.project_index(&a, i)?)?;
99                    let b = this.read_scalar(&this.project_index(&b, i)?)?;
100                    let c = this.read_scalar(&this.project_index(&c, i)?)?;
101                    let dest = this.project_index(&dest, i)?;
102
103                    let fuse: bool = intrinsic_name == "fma"
104                        || (this.machine.float_nondet && this.machine.rng.get_mut().random());
105
106                    // Works for f32 and f64.
107                    // FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
108                    let ty::Float(float_ty) = dest.layout.ty.kind() else {
109                        span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
110                    };
111                    let val = match float_ty {
112                        FloatTy::F16 => unimplemented!("f16_f128"),
113                        FloatTy::F32 => {
114                            let a = a.to_f32()?;
115                            let b = b.to_f32()?;
116                            let c = c.to_f32()?;
117                            let res = if fuse {
118                                a.mul_add(b, c).value
119                            } else {
120                                ((a * b).value + c).value
121                            };
122                            let res = this.adjust_nan(res, &[a, b, c]);
123                            Scalar::from(res)
124                        }
125                        FloatTy::F64 => {
126                            let a = a.to_f64()?;
127                            let b = b.to_f64()?;
128                            let c = c.to_f64()?;
129                            let res = if fuse {
130                                a.mul_add(b, c).value
131                            } else {
132                                ((a * b).value + c).value
133                            };
134                            let res = this.adjust_nan(res, &[a, b, c]);
135                            Scalar::from(res)
136                        }
137                        FloatTy::F128 => unimplemented!("f16_f128"),
138                    };
139                    this.write_scalar(val, &dest)?;
140                }
141            }
142            "expose_provenance" => {
143                let [op] = check_intrinsic_arg_count(args)?;
144                let (op, op_len) = this.project_to_simd(op)?;
145                let (dest, dest_len) = this.project_to_simd(dest)?;
146
147                assert_eq!(dest_len, op_len);
148
149                for i in 0..dest_len {
150                    let op = this.read_immediate(&this.project_index(&op, i)?)?;
151                    let dest = this.project_index(&dest, i)?;
152
153                    let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) {
154                        // Ptr/Int casts
155                        (ty::RawPtr(..), ty::Int(_) | ty::Uint(_)) =>
156                            this.pointer_expose_provenance_cast(&op, dest.layout)?,
157                        // Error otherwise
158                        _ =>
159                            throw_unsup_format!(
160                                "Unsupported `simd_expose_provenance` from element type {from_ty} to {to_ty}",
161                                from_ty = op.layout.ty,
162                                to_ty = dest.layout.ty,
163                            ),
164                    };
165                    this.write_immediate(*val, &dest)?;
166                }
167            }
168
169            _ => return interp_ok(EmulateItemResult::NotSupported),
170        }
171        interp_ok(EmulateItemResult::NeedsReturn)
172    }
173}