Skip to main content

miri/shims/
aarch64.rs

1use rustc_abi::CanonAbi;
2use rustc_middle::mir::BinOp;
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use crate::shims::math::compute_crc32;
8use crate::*;
9
10impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
11pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
12    fn emulate_aarch64_intrinsic(
13        &mut self,
14        link_name: Symbol,
15        abi: &FnAbi<'tcx, Ty<'tcx>>,
16        args: &[OpTy<'tcx>],
17        dest: &MPlaceTy<'tcx>,
18    ) -> InterpResult<'tcx, EmulateItemResult> {
19        let this = self.eval_context_mut();
20        // Prefix should have already been checked.
21        let unprefixed_name = link_name.as_str().strip_prefix("llvm.aarch64.").unwrap();
22        match unprefixed_name {
23            // Used to implement the vpmaxq_u8 function.
24            // Computes the maximum of adjacent pairs; the first half of the output is produced from the
25            // `left` input, the second half of the output from the `right` input.
26            // https://developer.arm.com/architectures/instruction-sets/intrinsics/vpmaxq_u8
27            "neon.umaxp.v16i8" => {
28                let [left, right] =
29                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
30
31                let (left, left_len) = this.project_to_simd(left)?;
32                let (right, right_len) = this.project_to_simd(right)?;
33                let (dest, lane_count) = this.project_to_simd(dest)?;
34                assert_eq!(left_len, right_len);
35                assert_eq!(lane_count, left_len);
36
37                for lane_idx in 0..lane_count {
38                    let src = if lane_idx < (lane_count / 2) { &left } else { &right };
39                    let src_idx = lane_idx.strict_rem(lane_count / 2);
40
41                    let lhs_lane =
42                        this.read_immediate(&this.project_index(src, src_idx.strict_mul(2))?)?;
43                    let rhs_lane = this.read_immediate(
44                        &this.project_index(src, src_idx.strict_mul(2).strict_add(1))?,
45                    )?;
46
47                    // Compute `if lhs > rhs { lhs } else { rhs }`, i.e., `max`.
48                    let res_lane = if this
49                        .binary_op(BinOp::Gt, &lhs_lane, &rhs_lane)?
50                        .to_scalar()
51                        .to_bool()?
52                    {
53                        lhs_lane
54                    } else {
55                        rhs_lane
56                    };
57
58                    let dest = this.project_index(&dest, lane_idx)?;
59                    this.write_immediate(*res_lane, &dest)?;
60                }
61            }
62
63            // Wrapping pairwise addition.
64            //
65            // Concatenates the two input vectors and adds adjacent elements. For input vectors `v`
66            // and `w` this computes `[v0 + v1, v2 + v3, ..., w0 + w1, w2 + w3, ...]`, using
67            // wrapping addition for `+`.
68            //
69            // Used by `vpadd_{s8, u8, s16, u16, s32, u32}`.
70            name if name.starts_with("neon.addp.") => {
71                let [left, right] =
72                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
73
74                let (left, left_len) = this.project_to_simd(left)?;
75                let (right, right_len) = this.project_to_simd(right)?;
76                let (dest, dest_len) = this.project_to_simd(dest)?;
77
78                assert_eq!(left_len, right_len);
79                assert_eq!(left_len, dest_len);
80
81                assert_eq!(left.layout, right.layout);
82                assert_eq!(left.layout, dest.layout);
83
84                assert!(dest_len.is_multiple_of(2));
85                let half_len = dest_len.strict_div(2);
86
87                for lane_idx in 0..dest_len {
88                    // The left and right vectors are concatenated.
89                    let (src, src_pair_idx) = if lane_idx < half_len {
90                        (&left, lane_idx)
91                    } else {
92                        (&right, lane_idx.strict_sub(half_len))
93                    };
94                    // Convert "pair index" into "index of first element of the pair".
95                    let i = src_pair_idx.strict_mul(2);
96
97                    let lhs = this.read_immediate(&this.project_index(src, i)?)?;
98                    let rhs = this.read_immediate(&this.project_index(src, i.strict_add(1))?)?;
99
100                    // Wrapping addition on the element type.
101                    let sum = this.binary_op(BinOp::Add, &lhs, &rhs)?;
102
103                    let dst_lane = this.project_index(&dest, lane_idx)?;
104                    this.write_immediate(*sum, &dst_lane)?;
105                }
106            }
107
108            // Widening pairwise addition.
109            //
110            // Takes a single input vector, and an output vector with half as many lanes and double
111            // the element width. Takes adjacent pairs of elements, widens both, and then adds them
112            // together.
113            //
114            // Used by `vpaddl_{u8, u16, u32}` and `vpaddlq_{u8, u16, u32}`.
115            name if name.starts_with("neon.uaddlp.") => {
116                let [src] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
117
118                let (src, src_len) = this.project_to_simd(src)?;
119                let (dest, dest_len) = this.project_to_simd(dest)?;
120
121                // Operates pairwise, so src has twice as many lanes.
122                assert_eq!(src_len, dest_len.strict_mul(2));
123
124                let src_elem_size = src.layout.field(this, 0).size;
125                let dest_elem_size = dest.layout.field(this, 0).size;
126
127                // Widens, so dest elements must be exactly twice as wide.
128                assert_eq!(dest_elem_size.bytes(), src_elem_size.bytes().strict_mul(2));
129
130                for dest_idx in 0..dest_len {
131                    let src_idx = dest_idx.strict_mul(2);
132
133                    let a_scalar = this.read_scalar(&this.project_index(&src, src_idx)?)?;
134                    let b_scalar =
135                        this.read_scalar(&this.project_index(&src, src_idx.strict_add(1))?)?;
136
137                    let a_val = a_scalar.to_uint(src_elem_size)?;
138                    let b_val = b_scalar.to_uint(src_elem_size)?;
139
140                    // Use addition on u128 to simulate widening addition for the destination type.
141                    // This cannot wrap since the element type is at most u64.
142                    let sum = a_val.strict_add(b_val);
143
144                    let dst_lane = this.project_index(&dest, dest_idx)?;
145                    this.write_scalar(Scalar::from_uint(sum, dest_elem_size), &dst_lane)?;
146                }
147            }
148
149            // Signed saturating doubling multiply returning the high half.
150            //
151            // Used by the `vqdmulh*` functions.
152            //
153            // This LLVM intrinsic multiplies the values of corresponding elements of the two source
154            // vector registers (which are signed integers), doubles the results, places the most significant half of the
155            // final results (using a saturating cast to fit the element type) into a vector, and writes the vector to the destination register.
156            //
157            // https://developer.arm.com/architectures/instruction-sets/intrinsics#f:@navigationhierarchiessimdisa=[Neon]&q=vqdmulh
158            name if name.starts_with("neon.sqdmulh.") => {
159                let [left, right] =
160                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
161
162                let (left, left_len) = this.project_to_simd(left)?;
163                let (right, right_len) = this.project_to_simd(right)?;
164                let (dest, dest_len) = this.project_to_simd(dest)?;
165                assert_eq!(left_len, right_len);
166                assert_eq!(left_len, dest_len);
167
168                let elem_size = dest.layout.field(this, 0).size;
169                let bits = elem_size.bits();
170                let min = elem_size.signed_int_min();
171                let max = elem_size.signed_int_max();
172
173                for i in 0..dest_len {
174                    let a = this.read_scalar(&this.project_index(&left, i)?)?.to_int(elem_size)?;
175                    let b = this.read_scalar(&this.project_index(&right, i)?)?.to_int(elem_size)?;
176
177                    // Uses i128 arithmetic, which cannot overflow because the intrinsic takes at most i32.
178                    let doubled = a.strict_mul(b).strict_mul(2);
179                    let res = (doubled >> bits).clamp(min, max);
180
181                    this.write_scalar(
182                        Scalar::from_int(res, elem_size),
183                        &this.project_index(&dest, i)?,
184                    )?;
185                }
186            }
187
188            // Vector table lookup: each index selects a byte from the 16-byte table, out-of-range -> 0.
189            // Used to implement vtbl1_u8 function.
190            // LLVM does not have a portable shuffle that takes non-const indices
191            // so we need to implement this ourselves.
192            // https://developer.arm.com/architectures/instruction-sets/intrinsics/vtbl1_u8
193            "neon.tbl1.v16i8" => {
194                let [table, indices] =
195                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
196
197                let (table, table_len) = this.project_to_simd(table)?;
198                let (indices, idx_len) = this.project_to_simd(indices)?;
199                let (dest, dest_len) = this.project_to_simd(dest)?;
200                assert_eq!(table_len, 16);
201                assert_eq!(idx_len, dest_len);
202
203                for i in 0..dest_len {
204                    let idx = this.read_immediate(&this.project_index(&indices, i)?)?;
205                    let idx_u = idx.to_scalar().to_u8()?;
206                    let val = if u64::from(idx_u) < table_len {
207                        let t = this.read_immediate(&this.project_index(&table, idx_u.into())?)?;
208                        t.to_scalar()
209                    } else {
210                        Scalar::from_u8(0)
211                    };
212                    this.write_scalar(val, &this.project_index(&dest, i)?)?;
213                }
214            }
215            // Used to implement the __crc32{b,h,w,x} and __crc32c{b,h,w,x} functions.
216            // Polynomial 0x04C11DB7 (standard CRC-32):
217            // https://developer.arm.com/documentation/ddi0602/latest/Base-Instructions/CRC32B--CRC32H--CRC32W--CRC32X--CRC32-checksum-
218            // Polynomial 0x1EDC6F41 (CRC-32C / Castagnoli):
219            // https://developer.arm.com/documentation/ddi0602/latest/Base-Instructions/CRC32CB--CRC32CH--CRC32CW--CRC32CX--CRC32C-checksum-
220            "crc32b" | "crc32h" | "crc32w" | "crc32x" | "crc32cb" | "crc32ch" | "crc32cw"
221            | "crc32cx" => {
222                this.expect_target_feature_for_intrinsic(link_name, "crc")?;
223                // The polynomial constants below include the leading 1 bit
224                // (e.g. 0x104C11DB7 instead of 0x04C11DB7) which the ARM docs
225                // omit but the polynomial division algorithm requires.
226                let (bit_size, polynomial): (u32, u128) = match unprefixed_name {
227                    "crc32b" => (8, 0x104C11DB7),
228                    "crc32h" => (16, 0x104C11DB7),
229                    "crc32w" => (32, 0x104C11DB7),
230                    "crc32x" => (64, 0x104C11DB7),
231                    "crc32cb" => (8, 0x11EDC6F41),
232                    "crc32ch" => (16, 0x11EDC6F41),
233                    "crc32cw" => (32, 0x11EDC6F41),
234                    "crc32cx" => (64, 0x11EDC6F41),
235                    _ => unreachable!(),
236                };
237
238                let [crc, data] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
239                let crc = this.read_scalar(crc)?;
240                let data = this.read_scalar(data)?;
241
242                // The CRC accumulator is always u32. The data argument is u32 for
243                // b/h/w variants and u64 for the x variant, per the LLVM intrinsic
244                // definitions (all b/h/w take i32, only x takes i64).
245                // https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/IntrinsicsAArch64.td
246                // If the higher bits are non-zero, `compute_crc32` will panic. We should probably
247                // raise a proper error instead, but outside stdarch nobody can trigger this anyway.
248                let crc = crc.to_u32()?;
249                let data = if bit_size == 64 { data.to_u64()? } else { u64::from(data.to_u32()?) };
250
251                let result = compute_crc32(crc, data, bit_size, polynomial);
252                this.write_scalar(Scalar::from_u32(result), dest)?;
253            }
254            // Polynomial multiply long (64-bit x 64-bit -> 128-bit).
255            //
256            // This is the same as "carryless" multiplication, see
257            // <https://en.wikipedia.org/wiki/Carry-less_product#Multiplication_of_polynomials>.
258            //
259            // Used to implement the vmull_p64 and vmull_high_p64 functions.
260            // https://developer.arm.com/architectures/instruction-sets/intrinsics/vmull_p64
261            "neon.pmull64" => {
262                // LLVM and GCC group pmull with the AES intrinsics.
263                // Also see <https://gcc.gnu.org/pipermail/gcc-patches/2023-February/612088.html>.
264                this.expect_target_feature_for_intrinsic(link_name, "aes")?;
265
266                let [left, right] =
267                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
268                let left = this.read_scalar(left)?.to_u64()?;
269                let right = this.read_scalar(right)?.to_u64()?;
270
271                let result = left.widening_carryless_mul(right);
272
273                // dest is int8x16_t, transmute to u128 for the write.
274                let dest = dest.transmute(this.machine.layouts.u128, this)?;
275                this.write_scalar(Scalar::from_u128(result), &dest)?;
276            }
277
278            _ => return interp_ok(EmulateItemResult::NotSupported),
279        }
280        interp_ok(EmulateItemResult::NeedsReturn)
281    }
282}