1use rand::Rng;
2use rustc_apfloat::Float;
3use rustc_middle::ty;
4use rustc_middle::ty::FloatTy;
5
6use super::check_intrinsic_arg_count;
7use crate::helpers::{ToHost, ToSoft};
8use crate::*;
9
10impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
11pub trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
12 fn emulate_simd_intrinsic(
15 &mut self,
16 intrinsic_name: &str,
17 args: &[OpTy<'tcx>],
18 dest: &MPlaceTy<'tcx>,
19 ) -> InterpResult<'tcx, EmulateItemResult> {
20 let this = self.eval_context_mut();
21 match intrinsic_name {
22 #[rustfmt::skip]
23 | "fsqrt"
24 | "fsin"
25 | "fcos"
26 | "fexp"
27 | "fexp2"
28 | "flog"
29 | "flog2"
30 | "flog10"
31 => {
32 let [op] = check_intrinsic_arg_count(args)?;
33 let (op, op_len) = this.project_to_simd(op)?;
34 let (dest, dest_len) = this.project_to_simd(dest)?;
35
36 assert_eq!(dest_len, op_len);
37
38 for i in 0..dest_len {
39 let op = this.read_immediate(&this.project_index(&op, i)?)?;
40 let dest = this.project_index(&dest, i)?;
41 let ty::Float(float_ty) = op.layout.ty.kind() else {
42 span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
43 };
44 let val = match float_ty {
47 FloatTy::F16 => unimplemented!("f16_f128"),
48 FloatTy::F32 => {
49 let f = op.to_scalar().to_f32()?;
50 let res = match intrinsic_name {
51 "fsqrt" => math::sqrt(f),
52 "fsin" => f.to_host().sin().to_soft(),
53 "fcos" => f.to_host().cos().to_soft(),
54 "fexp" => f.to_host().exp().to_soft(),
55 "fexp2" => f.to_host().exp2().to_soft(),
56 "flog" => f.to_host().ln().to_soft(),
57 "flog2" => f.to_host().log2().to_soft(),
58 "flog10" => f.to_host().log10().to_soft(),
59 _ => bug!(),
60 };
61 let res = this.adjust_nan(res, &[f]);
62 Scalar::from(res)
63 }
64 FloatTy::F64 => {
65 let f = op.to_scalar().to_f64()?;
66 let res = match intrinsic_name {
67 "fsqrt" => math::sqrt(f),
68 "fsin" => f.to_host().sin().to_soft(),
69 "fcos" => f.to_host().cos().to_soft(),
70 "fexp" => f.to_host().exp().to_soft(),
71 "fexp2" => f.to_host().exp2().to_soft(),
72 "flog" => f.to_host().ln().to_soft(),
73 "flog2" => f.to_host().log2().to_soft(),
74 "flog10" => f.to_host().log10().to_soft(),
75 _ => bug!(),
76 };
77 let res = this.adjust_nan(res, &[f]);
78 Scalar::from(res)
79 }
80 FloatTy::F128 => unimplemented!("f16_f128"),
81 };
82
83 this.write_scalar(val, &dest)?;
84 }
85 }
86 "fma" | "relaxed_fma" => {
87 let [a, b, c] = check_intrinsic_arg_count(args)?;
88 let (a, a_len) = this.project_to_simd(a)?;
89 let (b, b_len) = this.project_to_simd(b)?;
90 let (c, c_len) = this.project_to_simd(c)?;
91 let (dest, dest_len) = this.project_to_simd(dest)?;
92
93 assert_eq!(dest_len, a_len);
94 assert_eq!(dest_len, b_len);
95 assert_eq!(dest_len, c_len);
96
97 for i in 0..dest_len {
98 let a = this.read_scalar(&this.project_index(&a, i)?)?;
99 let b = this.read_scalar(&this.project_index(&b, i)?)?;
100 let c = this.read_scalar(&this.project_index(&c, i)?)?;
101 let dest = this.project_index(&dest, i)?;
102
103 let fuse: bool = intrinsic_name == "fma"
104 || (this.machine.float_nondet && this.machine.rng.get_mut().random());
105
106 let ty::Float(float_ty) = dest.layout.ty.kind() else {
109 span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
110 };
111 let val = match float_ty {
112 FloatTy::F16 => unimplemented!("f16_f128"),
113 FloatTy::F32 => {
114 let a = a.to_f32()?;
115 let b = b.to_f32()?;
116 let c = c.to_f32()?;
117 let res = if fuse {
118 a.mul_add(b, c).value
119 } else {
120 ((a * b).value + c).value
121 };
122 let res = this.adjust_nan(res, &[a, b, c]);
123 Scalar::from(res)
124 }
125 FloatTy::F64 => {
126 let a = a.to_f64()?;
127 let b = b.to_f64()?;
128 let c = c.to_f64()?;
129 let res = if fuse {
130 a.mul_add(b, c).value
131 } else {
132 ((a * b).value + c).value
133 };
134 let res = this.adjust_nan(res, &[a, b, c]);
135 Scalar::from(res)
136 }
137 FloatTy::F128 => unimplemented!("f16_f128"),
138 };
139 this.write_scalar(val, &dest)?;
140 }
141 }
142 "expose_provenance" => {
143 let [op] = check_intrinsic_arg_count(args)?;
144 let (op, op_len) = this.project_to_simd(op)?;
145 let (dest, dest_len) = this.project_to_simd(dest)?;
146
147 assert_eq!(dest_len, op_len);
148
149 for i in 0..dest_len {
150 let op = this.read_immediate(&this.project_index(&op, i)?)?;
151 let dest = this.project_index(&dest, i)?;
152
153 let val = match (op.layout.ty.kind(), dest.layout.ty.kind()) {
154 (ty::RawPtr(..), ty::Int(_) | ty::Uint(_)) =>
156 this.pointer_expose_provenance_cast(&op, dest.layout)?,
157 _ =>
159 throw_unsup_format!(
160 "Unsupported `simd_expose_provenance` from element type {from_ty} to {to_ty}",
161 from_ty = op.layout.ty,
162 to_ty = dest.layout.ty,
163 ),
164 };
165 this.write_immediate(*val, &dest)?;
166 }
167 }
168
169 _ => return interp_ok(EmulateItemResult::NotSupported),
170 }
171 interp_ok(EmulateItemResult::NeedsReturn)
172 }
173}