miri/shims/x86/
avx512.rs

1use rustc_abi::CanonAbi;
2use rustc_middle::ty::Ty;
3use rustc_span::Symbol;
4use rustc_target::callconv::FnAbi;
5
6use super::{permute, pmaddbw, psadbw, pshufb};
7use crate::*;
8
9impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
10pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
11    fn emulate_x86_avx512_intrinsic(
12        &mut self,
13        link_name: Symbol,
14        abi: &FnAbi<'tcx, Ty<'tcx>>,
15        args: &[OpTy<'tcx>],
16        dest: &MPlaceTy<'tcx>,
17    ) -> InterpResult<'tcx, EmulateItemResult> {
18        let this = self.eval_context_mut();
19        // Prefix should have already been checked.
20        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.avx512.").unwrap();
21
22        match unprefixed_name {
23            // Used by the ternarylogic functions.
24            "pternlog.d.128" | "pternlog.d.256" | "pternlog.d.512" => {
25                this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
26                if matches!(unprefixed_name, "pternlog.d.128" | "pternlog.d.256") {
27                    this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
28                }
29
30                let [a, b, c, imm8] =
31                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
32
33                assert_eq!(dest.layout, a.layout);
34                assert_eq!(dest.layout, b.layout);
35                assert_eq!(dest.layout, c.layout);
36
37                // The signatures of these operations are:
38                //
39                // ```
40                // fn vpternlogd(a: i32x16, b: i32x16, c: i32x16, imm8: i32) -> i32x16;
41                // fn vpternlogd256(a: i32x8, b: i32x8, c: i32x8, imm8: i32) -> i32x8;
42                // fn vpternlogd128(a: i32x4, b: i32x4, c: i32x4, imm8: i32) -> i32x4;
43                // ```
44                //
45                // The element type is always a 32-bit integer, the width varies.
46
47                let (a, _a_len) = this.project_to_simd(a)?;
48                let (b, _b_len) = this.project_to_simd(b)?;
49                let (c, _c_len) = this.project_to_simd(c)?;
50                let (dest, dest_len) = this.project_to_simd(dest)?;
51
52                // Compute one lane with ternary table.
53                let tern = |xa: u32, xb: u32, xc: u32, imm: u32| -> u32 {
54                    let mut out = 0u32;
55                    // At each bit position, select bit from imm8 at index = (a << 2) | (b << 1) | c
56                    for bit in 0..32 {
57                        let ia = (xa >> bit) & 1;
58                        let ib = (xb >> bit) & 1;
59                        let ic = (xc >> bit) & 1;
60                        let idx = (ia << 2) | (ib << 1) | ic;
61                        let v = (imm >> idx) & 1;
62                        out |= v << bit;
63                    }
64                    out
65                };
66
67                let imm8 = this.read_scalar(imm8)?.to_u32()? & 0xFF;
68                for i in 0..dest_len {
69                    let a_lane = this.project_index(&a, i)?;
70                    let b_lane = this.project_index(&b, i)?;
71                    let c_lane = this.project_index(&c, i)?;
72                    let d_lane = this.project_index(&dest, i)?;
73
74                    let va = this.read_scalar(&a_lane)?.to_u32()?;
75                    let vb = this.read_scalar(&b_lane)?.to_u32()?;
76                    let vc = this.read_scalar(&c_lane)?.to_u32()?;
77
78                    let r = tern(va, vb, vc, imm8);
79                    this.write_scalar(Scalar::from_u32(r), &d_lane)?;
80                }
81            }
82            // Used to implement the _mm512_sad_epu8 function.
83            "psad.bw.512" => {
84                this.expect_target_feature_for_intrinsic(link_name, "avx512bw")?;
85
86                let [left, right] =
87                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
88
89                psadbw(this, left, right, dest)?
90            }
91            // Used to implement the _mm512_maddubs_epi16 function.
92            "pmaddubs.w.512" => {
93                let [left, right] =
94                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
95
96                pmaddbw(this, left, right, dest)?;
97            }
98            // Used to implement the _mm512_permutexvar_epi32 function.
99            "permvar.si.512" => {
100                let [left, right] =
101                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
102
103                permute(this, left, right, dest)?;
104            }
105            // Used to implement the _mm512_shuffle_epi8 intrinsic.
106            "pshuf.b.512" => {
107                let [left, right] =
108                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
109
110                pshufb(this, left, right, dest)?;
111            }
112
113            // Used to implement the _mm512_dpbusd_epi32 function.
114            "vpdpbusd.512" | "vpdpbusd.256" | "vpdpbusd.128" => {
115                this.expect_target_feature_for_intrinsic(link_name, "avx512vnni")?;
116                if matches!(unprefixed_name, "vpdpbusd.128" | "vpdpbusd.256") {
117                    this.expect_target_feature_for_intrinsic(link_name, "avx512vl")?;
118                }
119
120                let [src, a, b] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
121
122                vpdpbusd(this, src, a, b, dest)?;
123            }
124            _ => return interp_ok(EmulateItemResult::NotSupported),
125        }
126        interp_ok(EmulateItemResult::NeedsReturn)
127    }
128}
129
130/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in `a` with corresponding signed
131/// 8-bit integers in `b`, producing 4 intermediate signed 16-bit results. Sum these 4 results with
132/// the corresponding 32-bit integer in `src` (using wrapping arighmetic), and store the packed
133/// 32-bit results in `dst`.
134///
135/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_dpbusd_epi32>
136/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm256_dpbusd_epi32>
137/// <https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm512_dpbusd_epi32>
138fn vpdpbusd<'tcx>(
139    ecx: &mut crate::MiriInterpCx<'tcx>,
140    src: &OpTy<'tcx>,
141    a: &OpTy<'tcx>,
142    b: &OpTy<'tcx>,
143    dest: &MPlaceTy<'tcx>,
144) -> InterpResult<'tcx, ()> {
145    let (src, src_len) = ecx.project_to_simd(src)?;
146    let (a, a_len) = ecx.project_to_simd(a)?;
147    let (b, b_len) = ecx.project_to_simd(b)?;
148    let (dest, dest_len) = ecx.project_to_simd(dest)?;
149
150    // fn vpdpbusd(src: i32x16, a: i32x16, b: i32x16) -> i32x16;
151    // fn vpdpbusd256(src: i32x8, a: i32x8, b: i32x8) -> i32x8;
152    // fn vpdpbusd128(src: i32x4, a: i32x4, b: i32x4) -> i32x4;
153    assert_eq!(dest_len, src_len);
154    assert_eq!(dest_len, a_len);
155    assert_eq!(dest_len, b_len);
156
157    for i in 0..dest_len {
158        let src = ecx.read_scalar(&ecx.project_index(&src, i)?)?.to_i32()?;
159        let a = ecx.read_scalar(&ecx.project_index(&a, i)?)?.to_u32()?;
160        let b = ecx.read_scalar(&ecx.project_index(&b, i)?)?.to_u32()?;
161        let dest = ecx.project_index(&dest, i)?;
162
163        let zipped = a.to_le_bytes().into_iter().zip(b.to_le_bytes());
164        let intermediate_sum: i32 = zipped
165            .map(|(a, b)| i32::from(a).strict_mul(i32::from(b.cast_signed())))
166            .fold(0, |x, y| x.strict_add(y));
167
168        // Use `wrapping_add` because `src` is an arbitrary i32 and the addition can overflow.
169        let res = Scalar::from_i32(intermediate_sum.wrapping_add(src));
170        ecx.write_scalar(res, &dest)?;
171    }
172
173    interp_ok(())
174}