miri/shims/x86/
avx512.rs

1use rustc_abi::CanonAbi;
2use rustc_middle::ty::Ty;
3use rustc_span::Symbol;
4use rustc_target::callconv::FnAbi;
5
6use super::{packssdw, packsswb, packusdw, packuswb, permute, pmaddbw, pmaddwd, psadbw, pshufb};
7use crate::*;
8
9impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
10pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
11    fn emulate_x86_avx512_intrinsic(
12        &mut self,
13        link_name: Symbol,
14        abi: &FnAbi<'tcx, Ty<'tcx>>,
15        args: &[OpTy<'tcx>],
16        dest: &MPlaceTy<'tcx>,
17    ) -> InterpResult<'tcx, EmulateItemResult> {
18        let this = self.eval_context_mut();
19        // Prefix should have already been checked.
20        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx512.").unwrap();
21
22        match unprefixed_name {
23            // Used by the ternarylogic functions.
24            "pternlog.d.128" | "pternlog.d.256" | "pternlog.d.512" => {
25                this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
26                if matches!(unprefixed_name, "pternlog.d.128" | "pternlog.d.256") {
27                    this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
28                }
29
30                let [a, b, c, imm8] =
31                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
32
33                assert_eq!(dest.layout, a.layout);
34                assert_eq!(dest.layout, b.layout);
35                assert_eq!(dest.layout, c.layout);
36
37                // The signatures of these operations are:
38                //
39                // ```
40                // fn vpternlogd(a: i32x16, b: i32x16, c: i32x16, imm8: i32) -> i32x16;
41                // fn vpternlogd256(a: i32x8, b: i32x8, c: i32x8, imm8: i32) -> i32x8;
42                // fn vpternlogd128(a: i32x4, b: i32x4, c: i32x4, imm8: i32) -> i32x4;
43                // ```
44                //
45                // The element type is always a 32-bit integer, the width varies.
46
47                let (a, _a_len) = this.project_to_simd(a)?;
48                let (b, _b_len) = this.project_to_simd(b)?;
49                let (c, _c_len) = this.project_to_simd(c)?;
50                let (dest, dest_len) = this.project_to_simd(dest)?;
51
52                // Compute one lane with ternary table.
53                let tern = |xa: u32, xb: u32, xc: u32, imm: u32| -> u32 {
54                    let mut out = 0u32;
55                    // At each bit position, select bit from imm8 at index = (a << 2) | (b << 1) | c
56                    for bit in 0..32 {
57                        let ia = (xa >> bit) & 1;
58                        let ib = (xb >> bit) & 1;
59                        let ic = (xc >> bit) & 1;
60                        let idx = (ia << 2) | (ib << 1) | ic;
61                        let v = (imm >> idx) & 1;
62                        out |= v << bit;
63                    }
64                    out
65                };
66
67                let imm8 = this.read_scalar(imm8)?.to_u32()? & 0xFF;
68                for i in 0..dest_len {
69                    let a_lane = this.project_index(&a, i)?;
70                    let b_lane = this.project_index(&b, i)?;
71                    let c_lane = this.project_index(&c, i)?;
72                    let d_lane = this.project_index(&dest, i)?;
73
74                    let va = this.read_scalar(&a_lane)?.to_u32()?;
75                    let vb = this.read_scalar(&b_lane)?.to_u32()?;
76                    let vc = this.read_scalar(&c_lane)?.to_u32()?;
77
78                    let r = tern(va, vb, vc, imm8);
79                    this.write_scalar(Scalar::from_u32(r), &d_lane)?;
80                }
81            }
82            // Used to implement the _mm512_sad_epu8 function.
83            "psad.bw.512" => {
84                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
85
86                let [left, right] =
87                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
88
89                psadbw(this, left, right, dest)?
90            }
91            // Used to implement the _mm512_madd_epi16 function.
92            "pmaddw.d.512" => {
93                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
94
95                let [left, right] =
96                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
97
98                pmaddwd(this, left, right, dest)?;
99            }
100            // Used to implement the _mm512_maddubs_epi16 function.
101            "pmaddubs.w.512" => {
102                let [left, right] =
103                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
104
105                pmaddbw(this, left, right, dest)?;
106            }
107            // Used to implement the _mm512_permutexvar_epi32 function.
108            "permvar.si.512" => {
109                let [left, right] =
110                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
111
112                permute(this, left, right, dest)?;
113            }
114            // Used to implement the _mm512_shuffle_epi8 intrinsic.
115            "pshuf.b.512" => {
116                let [left, right] =
117                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
118
119                pshufb(this, left, right, dest)?;
120            }
121
122            // Used to implement the _mm512_dpbusd_epi32 function.
123            "vpdpbusd.512" | "vpdpbusd.256" | "vpdpbusd.128" => {
124                this.expect_target_feature_for_intrinsic(link_name, "avx512vnni")?;
125                if matches!(unprefixed_name, "vpdpbusd.128" | "vpdpbusd.256") {
126                    this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
127                }
128
129                let [src, a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
130
131                vpdpbusd(this, src, a, b, dest)?;
132            }
133            // Used to implement the _mm512_packs_epi16 function
134            "packsswb.512" => {
135                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
136
137                let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
138
139                packsswb(this, a, b, dest)?;
140            }
141            // Used to implement the _mm512_packus_epi16 function
142            "packuswb.512" => {
143                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
144
145                let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
146
147                packuswb(this, a, b, dest)?;
148            }
149            // Used to implement the _mm512_packs_epi32 function
150            "packssdw.512" => {
151                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
152
153                let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
154
155                packssdw(this, a, b, dest)?;
156            }
157            // Used to implement the _mm512_packus_epi32 function
158            "packusdw.512" => {
159                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
160
161                let [a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
162
163                packusdw(this, a, b, dest)?;
164            }
165            _ => return interp_ok(EmulateItemResult::NotSupported),
166        }
167        interp_ok(EmulateItemResult::NeedsReturn)
168    }
169}
170
171/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in `a` with corresponding signed
172/// 8-bit integers in `b`, producing 4 intermediate signed 16-bit results. Sum these 4 results with
173/// the corresponding 32-bit integer in `src` (using wrapping arighmetic), and store the packed
174/// 32-bit results in `dst`.
175///
176/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_dpbusd_epi32>
177/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_dpbusd_epi32>
178/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_dpbusd_epi32>
179fn vpdpbusd<'tcx>(
180    ecx: &mut crate::MiriInterpCx<'tcx>,
181    src: &OpTy<'tcx>,
182    a: &OpTy<'tcx>,
183    b: &OpTy<'tcx>,
184    dest: &MPlaceTy<'tcx>,
185) -> InterpResult<'tcx, ()> {
186    let (src, src_len) = ecx.project_to_simd(src)?;
187    let (a, a_len) = ecx.project_to_simd(a)?;
188    let (b, b_len) = ecx.project_to_simd(b)?;
189    let (dest, dest_len) = ecx.project_to_simd(dest)?;
190
191    // fn vpdpbusd(src: i32x16, a: i32x16, b: i32x16) -> i32x16;
192    // fn vpdpbusd256(src: i32x8, a: i32x8, b: i32x8) -> i32x8;
193    // fn vpdpbusd128(src: i32x4, a: i32x4, b: i32x4) -> i32x4;
194    assert_eq!(dest_len, src_len);
195    assert_eq!(dest_len, a_len);
196    assert_eq!(dest_len, b_len);
197
198    for i in 0..dest_len {
199        let src = ecx.read_scalar(&ecx.project_index(&src, i)?)?.to_i32()?;
200        let a = ecx.read_scalar(&ecx.project_index(&a, i)?)?.to_u32()?;
201        let b = ecx.read_scalar(&ecx.project_index(&b, i)?)?.to_u32()?;
202        let dest = ecx.project_index(&dest, i)?;
203
204        let zipped = a.to_le_bytes().into_iter().zip(b.to_le_bytes());
205        let intermediate_sum: i32 = zipped
206            .map(|(a, b)| i32::from(a).strict_mul(i32::from(b.cast_signed())))
207            .fold(0, |x, y| x.strict_add(y));
208
209        // Use `wrapping_add` because `src` is an arbitrary i32 and the addition can overflow.
210        let res = Scalar::from_i32(intermediate_sum.wrapping_add(src));
211        ecx.write_scalar(res, &dest)?;
212    }
213
214    interp_ok(())
215}