1use rustc_middle::ty::Ty;
2use rustc_span::Symbol;
3use rustc_target::callconv::{Conv, FnAbi};
4
5use crate::*;
6
7impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
8pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
9 fn emulate_x86_gfni_intrinsic(
10 &mut self,
11 link_name: Symbol,
12 abi: &FnAbi<'tcx, Ty<'tcx>>,
13 args: &[OpTy<'tcx>],
14 dest: &MPlaceTy<'tcx>,
15 ) -> InterpResult<'tcx, EmulateItemResult> {
16 let this = self.eval_context_mut();
17
18 let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.").unwrap();
20
21 this.expect_target_feature_for_intrinsic(link_name, "gfni")?;
22 if unprefixed_name.ends_with(".256") {
23 this.expect_target_feature_for_intrinsic(link_name, "avx")?;
24 } else if unprefixed_name.ends_with(".512") {
25 this.expect_target_feature_for_intrinsic(link_name, "avx512f")?;
26 }
27
28 match unprefixed_name {
29 "vgf2p8affineqb.128" | "vgf2p8affineqb.256" | "vgf2p8affineqb.512" => {
33 let [left, right, imm8] = this.check_shim(abi, Conv::C, link_name, args)?;
34 affine_transform(this, left, right, imm8, dest, false)?;
35 }
36 "vgf2p8affineinvqb.128" | "vgf2p8affineinvqb.256" | "vgf2p8affineinvqb.512" => {
40 let [left, right, imm8] = this.check_shim(abi, Conv::C, link_name, args)?;
41 affine_transform(this, left, right, imm8, dest, true)?;
42 }
43 "vgf2p8mulb.128" | "vgf2p8mulb.256" | "vgf2p8mulb.512" => {
49 let [left, right] = this.check_shim(abi, Conv::C, link_name, args)?;
50 let (left, left_len) = this.project_to_simd(left)?;
51 let (right, right_len) = this.project_to_simd(right)?;
52 let (dest, dest_len) = this.project_to_simd(dest)?;
53
54 assert_eq!(left_len, right_len);
55 assert_eq!(dest_len, right_len);
56
57 for i in 0..dest_len {
58 let left = this.read_scalar(&this.project_index(&left, i)?)?.to_u8()?;
59 let right = this.read_scalar(&this.project_index(&right, i)?)?.to_u8()?;
60 let dest = this.project_index(&dest, i)?;
61 this.write_scalar(Scalar::from_u8(gf2p8_mul(left, right)), &dest)?;
62 }
63 }
64 _ => return interp_ok(EmulateItemResult::NotSupported),
65 }
66 interp_ok(EmulateItemResult::NeedsReturn)
67 }
68}
69
70fn affine_transform<'tcx>(
75 ecx: &mut MiriInterpCx<'tcx>,
76 left: &OpTy<'tcx>,
77 right: &OpTy<'tcx>,
78 imm8: &OpTy<'tcx>,
79 dest: &MPlaceTy<'tcx>,
80 inverse: bool,
81) -> InterpResult<'tcx, ()> {
82 let (left, left_len) = ecx.project_to_simd(left)?;
83 let (right, right_len) = ecx.project_to_simd(right)?;
84 let (dest, dest_len) = ecx.project_to_simd(dest)?;
85
86 assert_eq!(dest_len, right_len);
87 assert_eq!(dest_len, left_len);
88
89 let imm8 = ecx.read_scalar(imm8)?.to_u8()?;
90
91 for i in (0..dest_len).step_by(8) {
94 let mut matrix = [0u8; 8];
96 for j in 0..8 {
97 matrix[usize::try_from(j).unwrap()] =
98 ecx.read_scalar(&ecx.project_index(&right, i.wrapping_add(j))?)?.to_u8()?;
99 }
100
101 for j in 0..8 {
103 let index = i.wrapping_add(j);
104 let left = ecx.read_scalar(&ecx.project_index(&left, index)?)?.to_u8()?;
105 let left = if inverse { TABLE[usize::from(left)] } else { left };
106
107 let mut res = 0;
108
109 for bit in 0u8..8 {
111 let mut b = matrix[usize::from(bit)] & left;
112
113 b = (b & 0b1111) ^ (b >> 4);
115 b = (b & 0b11) ^ (b >> 2);
116 b = (b & 0b1) ^ (b >> 1);
117
118 res |= b << 7u8.wrapping_sub(bit);
119 }
120
121 res ^= imm8;
123
124 let dest = ecx.project_index(&dest, index)?;
125 ecx.write_scalar(Scalar::from_u8(res), &dest)?;
126 }
127 }
128
129 interp_ok(())
130}
131
132#[expect(clippy::cast_possible_truncation)]
137static TABLE: [u8; 256] = {
138 let mut array = [0; 256];
139
140 let mut i = 1;
141 while i < 256 {
142 let mut x = i as u8;
143 let mut y = gf2p8_mul(x, x);
144 x = y;
145 let mut j = 2;
146 while j < 8 {
147 x = gf2p8_mul(x, x);
148 y = gf2p8_mul(x, y);
149 j += 1;
150 }
151 array[i] = y;
152 i += 1;
153 }
154
155 array
156};
157
/// Multiplies two elements of GF(2^8), reducing modulo the polynomial
/// `x^8 + x^4 + x^3 + x + 1` (`0x11b`).
///
/// The operands are first multiplied carry-less (polynomial multiplication over GF(2)).
/// The up-to-15-bit product is then reduced back to 8 bits.
#[expect(clippy::cast_possible_truncation)]
const fn gf2p8_mul(left: u8, right: u8) -> u8 {
    const POLYNOMIAL: u32 = 0x11b;

    let a = left as u32;
    let b = right as u32;

    // Carry-less product: XOR in `b` shifted by each set bit position of `a`.
    let mut product = 0u32;
    let mut shift = 0u32;
    while shift < 8 {
        if (a >> shift) & 1 == 1 {
            product ^= b << shift;
        }
        shift = shift.wrapping_add(1);
    }

    // Clear bits 14 down to 8 by XORing in suitably shifted copies of the polynomial.
    let mut bit = 15u32;
    while bit > 8 {
        bit = bit.wrapping_sub(1);
        if (product >> bit) & 1 == 1 {
            product ^= POLYNOMIAL << bit.wrapping_sub(8);
        }
    }

    // Only bits 0..8 can remain set after reduction, so truncation is lossless.
    product as u8
}