Skip to main content

miri/shims/
aarch64.rs

1use rustc_abi::CanonAbi;
2use rustc_middle::mir::BinOp;
3use rustc_middle::ty::Ty;
4use rustc_span::Symbol;
5use rustc_target::callconv::FnAbi;
6
7use crate::shims::math::compute_crc32;
8use crate::*;
9
10impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
11pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
12    fn emulate_aarch64_intrinsic(
13        &mut self,
14        link_name: Symbol,
15        abi: &FnAbi<'tcx, Ty<'tcx>>,
16        args: &[OpTy<'tcx>],
17        dest: &MPlaceTy<'tcx>,
18    ) -> InterpResult<'tcx, EmulateItemResult> {
19        let this = self.eval_context_mut();
20        // Prefix should have already been checked.
21        let unprefixed_name = link_name.as_str().strip_prefix("llvm.aarch64.").unwrap();
22        match unprefixed_name {
23            // Used to implement the vpmaxq_u8 function.
24            // Computes the maximum of adjacent pairs; the first half of the output is produced from the
25            // `left` input, the second half of the output from the `right` input.
26            // https://developer.arm.com/architectures/instruction-sets/intrinsics/vpmaxq_u8
27            "neon.umaxp.v16i8" => {
28                let [left, right] =
29                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
30
31                let (left, left_len) = this.project_to_simd(left)?;
32                let (right, right_len) = this.project_to_simd(right)?;
33                let (dest, lane_count) = this.project_to_simd(dest)?;
34                assert_eq!(left_len, right_len);
35                assert_eq!(lane_count, left_len);
36
37                for lane_idx in 0..lane_count {
38                    let src = if lane_idx < (lane_count / 2) { &left } else { &right };
39                    let src_idx = lane_idx.strict_rem(lane_count / 2);
40
41                    let lhs_lane =
42                        this.read_immediate(&this.project_index(src, src_idx.strict_mul(2))?)?;
43                    let rhs_lane = this.read_immediate(
44                        &this.project_index(src, src_idx.strict_mul(2).strict_add(1))?,
45                    )?;
46
47                    // Compute `if lhs > rhs { lhs } else { rhs }`, i.e., `max`.
48                    let res_lane = if this
49                        .binary_op(BinOp::Gt, &lhs_lane, &rhs_lane)?
50                        .to_scalar()
51                        .to_bool()?
52                    {
53                        lhs_lane
54                    } else {
55                        rhs_lane
56                    };
57
58                    let dest = this.project_index(&dest, lane_idx)?;
59                    this.write_immediate(*res_lane, &dest)?;
60                }
61            }
62
63            // Wrapping pairwise addition.
64            //
65            // Concatenates the two input vectors and adds adjacent elements. For input vectors `v`
66            // and `w` this computes `[v0 + v1, v2 + v3, ..., w0 + w1, w2 + w3, ...]`, using
67            // wrapping addition for `+`.
68            //
69            // Used by `vpadd_{s8, u8, s16, u16, s32, u32}`.
70            name if name.starts_with("neon.addp.") => {
71                let [left, right] =
72                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
73
74                let (left, left_len) = this.project_to_simd(left)?;
75                let (right, right_len) = this.project_to_simd(right)?;
76                let (dest, dest_len) = this.project_to_simd(dest)?;
77
78                assert_eq!(left_len, right_len);
79                assert_eq!(left_len, dest_len);
80
81                assert_eq!(left.layout, right.layout);
82                assert_eq!(left.layout, dest.layout);
83
84                assert!(dest_len.is_multiple_of(2));
85                let half_len = dest_len.strict_div(2);
86
87                for lane_idx in 0..dest_len {
88                    // The left and right vectors are concatenated.
89                    let (src, src_pair_idx) = if lane_idx < half_len {
90                        (&left, lane_idx)
91                    } else {
92                        (&right, lane_idx.strict_sub(half_len))
93                    };
94                    // Convert "pair index" into "index of first element of the pair".
95                    let i = src_pair_idx.strict_mul(2);
96
97                    let lhs = this.read_immediate(&this.project_index(src, i)?)?;
98                    let rhs = this.read_immediate(&this.project_index(src, i.strict_add(1))?)?;
99
100                    // Wrapping addition on the element type.
101                    let sum = this.binary_op(BinOp::Add, &lhs, &rhs)?;
102
103                    let dst_lane = this.project_index(&dest, lane_idx)?;
104                    this.write_immediate(*sum, &dst_lane)?;
105                }
106            }
107
108            // Widening pairwise addition.
109            //
110            // Takes a single input vector, and an output vector with half as many lanes and double
111            // the element width. Takes adjacent pairs of elements, widens both, and then adds them
112            // together.
113            //
114            // Used by `vpaddl_{u8, u16, u32}` and `vpaddlq_{u8, u16, u32}`.
115            name if name.starts_with("neon.uaddlp.") => {
116                let [src] = this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
117
118                let (src, src_len) = this.project_to_simd(src)?;
119                let (dest, dest_len) = this.project_to_simd(dest)?;
120
121                // Operates pairwise, so src has twice as many lanes.
122                assert_eq!(src_len, dest_len.strict_mul(2));
123
124                let src_elem_size = src.layout.field(this, 0).size;
125                let dest_elem_size = dest.layout.field(this, 0).size;
126
127                // Widens, so dest elements must be exactly twice as wide.
128                assert_eq!(dest_elem_size.bytes(), src_elem_size.bytes().strict_mul(2));
129
130                for dest_idx in 0..dest_len {
131                    let src_idx = dest_idx.strict_mul(2);
132
133                    let a_scalar = this.read_scalar(&this.project_index(&src, src_idx)?)?;
134                    let b_scalar =
135                        this.read_scalar(&this.project_index(&src, src_idx.strict_add(1))?)?;
136
137                    let a_val = a_scalar.to_uint(src_elem_size)?;
138                    let b_val = b_scalar.to_uint(src_elem_size)?;
139
140                    // Use addition on u128 to simulate widening addition for the destination type.
141                    // This cannot wrap since the element type is at most u64.
142                    let sum = a_val.strict_add(b_val);
143
144                    let dst_lane = this.project_index(&dest, dest_idx)?;
145                    this.write_scalar(Scalar::from_uint(sum, dest_elem_size), &dst_lane)?;
146                }
147            }
148
149            // Vector table lookup: each index selects a byte from the 16-byte table, out-of-range -> 0.
150            // Used to implement vtbl1_u8 function.
151            // LLVM does not have a portable shuffle that takes non-const indices
152            // so we need to implement this ourselves.
153            // https://developer.arm.com/architectures/instruction-sets/intrinsics/vtbl1_u8
154            "neon.tbl1.v16i8" => {
155                let [table, indices] =
156                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
157
158                let (table, table_len) = this.project_to_simd(table)?;
159                let (indices, idx_len) = this.project_to_simd(indices)?;
160                let (dest, dest_len) = this.project_to_simd(dest)?;
161                assert_eq!(table_len, 16);
162                assert_eq!(idx_len, dest_len);
163
164                for i in 0..dest_len {
165                    let idx = this.read_immediate(&this.project_index(&indices, i)?)?;
166                    let idx_u = idx.to_scalar().to_u8()?;
167                    let val = if u64::from(idx_u) < table_len {
168                        let t = this.read_immediate(&this.project_index(&table, idx_u.into())?)?;
169                        t.to_scalar()
170                    } else {
171                        Scalar::from_u8(0)
172                    };
173                    this.write_scalar(val, &this.project_index(&dest, i)?)?;
174                }
175            }
176            // Used to implement the __crc32{b,h,w,x} and __crc32c{b,h,w,x} functions.
177            // Polynomial 0x04C11DB7 (standard CRC-32):
178            // https://developer.arm.com/documentation/ddi0602/latest/Base-Instructions/CRC32B--CRC32H--CRC32W--CRC32X--CRC32-checksum-
179            // Polynomial 0x1EDC6F41 (CRC-32C / Castagnoli):
180            // https://developer.arm.com/documentation/ddi0602/latest/Base-Instructions/CRC32CB--CRC32CH--CRC32CW--CRC32CX--CRC32C-checksum-
181            "crc32b" | "crc32h" | "crc32w" | "crc32x" | "crc32cb" | "crc32ch" | "crc32cw"
182            | "crc32cx" => {
183                this.expect_target_feature_for_intrinsic(link_name, "crc")?;
184                // The polynomial constants below include the leading 1 bit
185                // (e.g. 0x104C11DB7 instead of 0x04C11DB7) which the ARM docs
186                // omit but the polynomial division algorithm requires.
187                let (bit_size, polynomial): (u32, u128) = match unprefixed_name {
188                    "crc32b" => (8, 0x104C11DB7),
189                    "crc32h" => (16, 0x104C11DB7),
190                    "crc32w" => (32, 0x104C11DB7),
191                    "crc32x" => (64, 0x104C11DB7),
192                    "crc32cb" => (8, 0x11EDC6F41),
193                    "crc32ch" => (16, 0x11EDC6F41),
194                    "crc32cw" => (32, 0x11EDC6F41),
195                    "crc32cx" => (64, 0x11EDC6F41),
196                    _ => unreachable!(),
197                };
198
199                let [left, right] =
200                    this.check_shim_sig_lenient(abi, CanonAbi::C, link_name, args)?;
201                let left = this.read_scalar(left)?;
202                let right = this.read_scalar(right)?;
203
204                // The CRC accumulator is always u32. The data argument is u32 for
205                // b/h/w variants and u64 for the x variant, per the LLVM intrinsic
206                // definitions (all b/h/w take i32, only x takes i64).
207                // https://github.com/llvm/llvm-project/blob/main/llvm/include/llvm/IR/IntrinsicsAArch64.td
208                // If the higher bits are non-zero, `compute_crc32` will panic. We should probably
209                // raise a proper error instead, but outside stdarch nobody can trigger this anyway.
210                let crc = left.to_u32()?;
211                let data =
212                    if bit_size == 64 { right.to_u64()? } else { u64::from(right.to_u32()?) };
213
214                let result = compute_crc32(crc, data, bit_size, polynomial);
215                this.write_scalar(Scalar::from_u32(result), dest)?;
216            }
217            _ => return interp_ok(EmulateItemResult::NotSupported),
218        }
219        interp_ok(EmulateItemResult::NeedsReturn)
220    }
221}