1use rustc_abi::CanonAbi;
2use rustc_middle::ty::Ty;
3use rustc_span::Symbol;
4use rustc_target::callconv::FnAbi;
5
6use super::{
7 packssdw, packsswb, packusdw, packuswb, permute, permute2, pmaddbw, pmaddwd, psadbw, pshufb,
8};
9use crate::*;
10
11impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
12pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
13 fn emulate_x86_avx512_intrinsic(
14 &mut self,
15 link_name: Symbol,
16 abi: &FnAbi<'tcx, Ty<'tcx>>,
17 args: &[OpTy<'tcx>],
18 dest: &MPlaceTy<'tcx>,
19 ) -> InterpResult<'tcx, EmulateItemResult> {
20 let this = self.eval_context_mut();
21 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx512.").unwrap();
23
24 match unprefixed_name {
25 "pternlog.d.128" | "pternlog.d.256" | "pternlog.d.512" => {
27 this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
28 if matches!(unprefixed_name, "pternlog.d.128" | "pternlog.d.256") {
29 this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
30 }
31
32 let [a, b, c, imm8] =
33 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
34
35 assert_eq!(dest.layout, a.layout);
36 assert_eq!(dest.layout, b.layout);
37 assert_eq!(dest.layout, c.layout);
38
39 let (a, _a_len) = this.project_to_simd(a)?;
50 let (b, _b_len) = this.project_to_simd(b)?;
51 let (c, _c_len) = this.project_to_simd(c)?;
52 let (dest, dest_len) = this.project_to_simd(dest)?;
53
54 let tern = |xa: u32, xb: u32, xc: u32, imm: u32| -> u32 {
56 let mut out = 0u32;
57 for bit in 0..32 {
59 let ia = (xa >> bit) & 1;
60 let ib = (xb >> bit) & 1;
61 let ic = (xc >> bit) & 1;
62 let idx = (ia << 2) | (ib << 1) | ic;
63 let v = (imm >> idx) & 1;
64 out |= v << bit;
65 }
66 out
67 };
68
69 let imm8 = this.read_scalar(imm8)?.to_u32()? & 0xFF;
70 for i in 0..dest_len {
71 let a_lane = this.project_index(&a, i)?;
72 let b_lane = this.project_index(&b, i)?;
73 let c_lane = this.project_index(&c, i)?;
74 let d_lane = this.project_index(&dest, i)?;
75
76 let va = this.read_scalar(&a_lane)?.to_u32()?;
77 let vb = this.read_scalar(&b_lane)?.to_u32()?;
78 let vc = this.read_scalar(&c_lane)?.to_u32()?;
79
80 let r = tern(va, vb, vc, imm8);
81 this.write_scalar(Scalar::from_u32(r), &d_lane)?;
82 }
83 }
84 "psad.bw.512" => {
86 this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
87
88 let [left, right] =
89 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
90
91 psadbw(this, left, right, dest)?
92 }
93 "pmaddw.d.512" => {
95 this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
96
97 let [left, right] =
98 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
99
100 pmaddwd(this, left, right, dest)?;
101 }
102 "pmaddubs.w.512" => {
104 let [left, right] =
105 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
106
107 pmaddbw(this, left, right, dest)?;
108 }
109 "permvar.si.512" | "permvar.di.512" => {
111 let [left, right] =
112 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
113
114 permute(this, left, right, dest)?;
115 }
116 "vpermi2var.q.512" => {
118 let [left, indices, right] =
119 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
120
121 permute2(this, left, indices, right, dest)?;
122 }
123 "vpermi2var.qi.512" => {
125 this.expect_target_feature_for_intrinsic(link_name, "avx512vbmi")?;
126
127 let [left, indices, right] =
128 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
129
130 permute2(this, left, indices, right, dest)?;
131 }
132 "pshuf.b.512" => {
134 let [left, right] =
135 this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
136
137 pshufb(this, left, right, dest)?;
138 }
139
140 "vpdpbusd.512" | "vpdpbusd.256" | "vpdpbusd.128" => {
142 this.expect_target_feature_for_intrinsic(link_name, "avx512vnni")?;
143 if matches!(unprefixed_name, "vpdpbusd.128" | "vpdpbusd.256") {
144 this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
145 }
146
147 let [src, a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
148
149 vpdpbusd(this, src, a, b, dest)?;
150 }
151 "packsswb.512" => {
153 this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
154
155 let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
156
157 packsswb(this, a, b, dest)?;
158 }
159 "packuswb.512" => {
161 this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
162
163 let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
164
165 packuswb(this, a, b, dest)?;
166 }
167 "packssdw.512" => {
169 this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
170
171 let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
172
173 packssdw(this, a, b, dest)?;
174 }
175 "packusdw.512" => {
177 this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
178
179 let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
180
181 packusdw(this, a, b, dest)?;
182 }
183 _ => return interp_ok(EmulateItemResult::NotSupported),
184 }
185 interp_ok(EmulateItemResult::NeedsReturn)
186 }
187}
188
189fn vpdpbusd<'tcx>(
198 ecx: &mut crate::MiriInterpCx<'tcx>,
199 src: &OpTy<'tcx>,
200 a: &OpTy<'tcx>,
201 b: &OpTy<'tcx>,
202 dest: &MPlaceTy<'tcx>,
203) -> InterpResult<'tcx, ()> {
204 let (src, src_len) = ecx.project_to_simd(src)?;
205 let (a, a_len) = ecx.project_to_simd(a)?;
206 let (b, b_len) = ecx.project_to_simd(b)?;
207 let (dest, dest_len) = ecx.project_to_simd(dest)?;
208
209 assert_eq!(src_len, dest_len);
213 assert_eq!(a_len, dest_len.strict_mul(4));
214 assert_eq!(b_len, a_len);
215
216 for i in 0..dest_len {
217 let src = ecx.read_scalar(&ecx.project_index(&src, i)?)?.to_i32()?;
218 let dest = ecx.project_index(&dest, i)?;
219
220 let mut intermediate_sum: i32 = 0;
221 for j in 0..4 {
222 let idx = i.strict_mul(4).strict_add(j);
223 let a = ecx.read_scalar(&ecx.project_index(&a, idx)?)?.to_u8()?;
224 let b = ecx.read_scalar(&ecx.project_index(&b, idx)?)?.to_i8()?;
225
226 let product = i32::from(a).strict_mul(i32::from(b));
227 intermediate_sum = intermediate_sum.strict_add(product);
228 }
229
230 let res = Scalar::from_i32(intermediate_sum.wrapping_add(src));
232 ecx.write_scalar(res, &dest)?;
233 }
234
235 interp_ok(())
236}