Skip to main content

core/stdarch/crates/core_arch/src/x86_64/
amx.rs

1use crate::core_arch::{simd::*, x86::*};
2
3#[cfg(test)]
4use stdarch_test::assert_instr;
5
/// Load tile configuration from a 64-byte memory location specified by `mem_addr`.
/// The tile configuration includes the tile type palette, the number of bytes per row,
/// and the number of rows; its format is described in Intel's documentation linked below.
/// If the specified palette id is zero, that signifies the init state for both the tile
/// config and the tile data, and the tiles are zeroed.
/// Any invalid configurations will result in a #GP fault.
///
/// This function is `unsafe` and requires the `amx-tile` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig&ig_expand=6875)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(ldtilecfg))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
    // Forwards directly to the `llvm.x86.ldtilecfg` intrinsic.
    ldtilecfg(mem_addr);
}
20
/// Stores the current tile configuration to a 64-byte memory location specified by `mem_addr`.
/// The tile configuration includes the tile type palette, the number of bytes per row,
/// and the number of rows; its format is described in Intel's documentation linked below.
/// If tiles are not configured, all zeroes will be stored to memory.
///
/// This function is `unsafe` and requires the `amx-tile` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig&ig_expand=6879)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(sttilecfg))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
    // Forwards directly to the `llvm.x86.sttilecfg` intrinsic.
    sttilecfg(mem_addr);
}
33
/// Load tile rows from memory specified by base address and stride into destination tile dst
/// using the tile configuration previously configured via `_tile_loadconfig`.
///
/// The destination tile is selected by the const generic `DST`, which is checked at
/// compile time against a 3-bit unsigned immediate.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd&ig_expand=6877)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    tileloadd64(DST as i8, base, stride);
}
46
/// Release the tile configuration to return to the init state, which releases all storage
/// it currently holds.
///
/// This function is `unsafe` and requires the `amx-tile` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release&ig_expand=6878)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilerelease))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_release() {
    // Forwards directly to the `llvm.x86.tilerelease` intrinsic.
    tilerelease();
}
57
/// Store the tile specified by src to memory specified by base address and stride
/// using the tile configuration previously configured via `_tile_loadconfig`.
///
/// Note that the source tile index is passed through the const generic named `DST`;
/// it is checked at compile time against a 3-bit unsigned immediate.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored&ig_expand=6881)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    tilestored64(DST as i8, base, stride);
}
70
/// Load tile rows from memory specified by base address and stride into destination tile dst
/// using the tile configuration previously configured via `_tile_loadconfig`. This intrinsic
/// provides a hint to the implementation that the data will likely not be reused in the near
/// future and the data caching can be optimized accordingly.
///
/// The destination tile is selected by the const generic `DST`, which is checked at
/// compile time against a 3-bit unsigned immediate.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd&ig_expand=6883)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    tileloaddt164(DST as i8, base, stride);
}
85
/// Zero the tile specified by the const generic `DST`.
///
/// `DST` is checked at compile time against a 3-bit unsigned immediate.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero&ig_expand=6885)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_zero<const DST: i32>() {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    tilezero(DST as i8);
}
98
/// Compute dot-product of BF16 (16-bit) floating-point pairs in tiles a and b,
/// accumulating the intermediate single-precision (32-bit) floating-point elements
/// with elements in dst, and store the 32-bit result back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-bf16` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps&ig_expand=6864)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-bf16")]
#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbf16ps(DST as i8, A as i8, B as i8);
}
115
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-int8` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd&ig_expand=6866)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbssd(DST as i8, A as i8, B as i8);
}
133
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of signed 8-bit integers in a with corresponding
/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-int8` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud&ig_expand=6868)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbsud(DST as i8, A as i8, B as i8);
}
151
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
/// signed 8-bit integers in b, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-int8` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd&ig_expand=6870)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbusd(DST as i8, A as i8, B as i8);
}
169
/// Compute dot-product of bytes in tiles with a source/destination accumulator.
/// Multiply groups of 4 adjacent pairs of unsigned 8-bit integers in a with corresponding
/// unsigned 8-bit integers in b, producing 4 intermediate 32-bit results.
/// Sum these 4 results with the corresponding 32-bit integer in dst, and store the 32-bit result back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-int8` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud&ig_expand=6872)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbuud(DST as i8, A as i8, B as i8);
}
187
/// Compute dot-product of FP16 (16-bit) floating-point pairs in tiles a and b,
/// accumulating the intermediate single-precision (32-bit) floating-point elements
/// with elements in dst, and store the 32-bit result back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-fp16` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps&ig_expand=6874)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp16")]
#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpfp16ps(DST as i8, A as i8, B as i8);
}
204
/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
/// Calculates the imaginary part of the result. For each possible combination of (row of a, column of b),
/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
/// The imaginary part of the a element is multiplied with the real part of the corresponding b element, and the real part of
/// the a element is multiplied with the imaginary part of the corresponding b elements. The two accumulated results are added,
/// and then accumulated into the corresponding row and column of dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-complex` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps&ig_expand=6860)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-complex")]
#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tcmmimfp16ps(DST as i8, A as i8, B as i8);
}
225
/// Perform matrix multiplication of two tiles containing complex elements and accumulate the results into a packed single precision tile.
/// Each dword element in input tiles a and b is interpreted as a complex number with FP16 real part and FP16 imaginary part.
/// Calculates the real part of the result. For each possible combination of (row of a, column of b),
/// it performs a set of multiplication and accumulations on all corresponding complex numbers (one from a and one from b).
/// The real part of the a element is multiplied with the real part of the corresponding b element, and the negated imaginary part of
/// the a element is multiplied with the imaginary part of the corresponding b elements.
/// The two accumulated results are added, and then accumulated into the corresponding row and column of dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-complex` target feature.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps&ig_expand=6862)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-complex")]
#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tcmmrlfp16ps(DST as i8, A as i8, B as i8);
}
246
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and BF8 (8-bit E5M2)
/// floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
/// back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-fp8` target feature.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbf8ps(DST as i8, A as i8, B as i8);
}
265
/// Compute dot-product of BF8 (8-bit E5M2) floating-point elements in tile a and HF8
/// (8-bit E4M3) floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
/// back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-fp8` target feature.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdpbhf8ps(DST as i8, A as i8, B as i8);
}
284
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and BF8
/// (8-bit E5M2) floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
/// back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-fp8` target feature.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdphbf8ps(DST as i8, A as i8, B as i8);
}
303
/// Compute dot-product of HF8 (8-bit E4M3) floating-point elements in tile a and HF8 (8-bit E4M3)
/// floating-point elements in tile b, accumulating the intermediate single-precision
/// (32-bit) floating-point elements with elements in dst, and store the 32-bit result
/// back to tile dst.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-fp8` target feature.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tdphf8ps(DST as i8, A as i8, B as i8);
}
322
/// Load tile rows from memory specified by base address and stride into destination tile dst
/// using the tile configuration previously configured via `_tile_loadconfig`.
/// Additionally, this intrinsic indicates the source memory location is likely to become
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
/// before it is written, assuming it is ever written in the future.
///
/// The destination tile is selected by the const generic `DST`, which is checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-movrs` target feature.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tileloaddrs, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    tileloaddrs64(DST as i8, base, stride);
}
340
/// Load tile rows from memory specified by base address and stride into destination tile dst
/// using the tile configuration previously configured via `_tile_loadconfig`.
/// Provides a hint to the implementation that the data would be reused but does not need
/// to be resident in the nearest cache levels.
/// Additionally, this intrinsic indicates the source memory location is likely to become
/// read-shared by multiple processors, i.e., read in the future by at least one other processor
/// before it is written, assuming it is ever written in the future.
///
/// The destination tile is selected by the const generic `DST`, which is checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-movrs` target feature.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tileloaddrst1, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    tileloaddrst164(DST as i8, base, stride);
}
360
/// Perform matrix multiplication of two tiles a and b, containing packed single precision (32-bit)
/// floating-point elements, which are converted to TF32 (tensor-float32) format, and accumulate the
/// results into a packed single precision tile.
///
/// For each possible combination of (row of a, column of b), it performs
///  - convert to TF32
///  - multiply the corresponding elements of a and b
///  - accumulate the results into the corresponding row and column of dst using
///    round-to-nearest-even rounding mode.
///
/// Output FP32 denormals are always flushed to zero, input single precision denormals are always
/// handled and *not* treated as zero.
///
/// The tiles are selected by the const generics `DST`, `A` and `B`, each checked at
/// compile time against a 3-bit unsigned immediate. Requires the `amx-tf32` target feature.
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-tf32")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits.
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    tmmultf32ps(DST as i8, A as i8, B as i8);
}
385
/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
/// elements to packed single-precision (32-bit) floating-point elements.
///
/// The tile is selected by the const generic `TILE` (checked at compile time against a 3-bit
/// unsigned immediate); the row index is the runtime argument `row`.
/// Requires the `amx-avx512` and `avx10.2` target features.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowd2ps, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    // The intrinsic yields an `f32x16`, which is converted to the public `__m512` type.
    tcvtrowd2ps(TILE as i8, row).as_m512()
}
400
/// Moves a row from a tile register to a zmm register, converting the packed 32-bit signed integer
/// elements to packed single-precision (32-bit) floating-point elements.
///
/// This variant takes both operands as const generics: `TILE` is checked at compile time
/// against a 3-bit unsigned immediate and the row index `ROW` against a 6-bit one.
/// Requires the `amx-avx512` and `avx10.2` target features.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowd2psi<const TILE: i32, const ROW: i32>() -> __m512 {
    // Tile index: 3 unsigned bits; row index: 6 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    // The intrinsic yields an `f32x16`, which is converted to the public `__m512` type.
    tcvtrowd2psi(TILE as i8, ROW as u32).as_m512()
}
416
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// The tile is selected by the const generic `TILE` (checked at compile time against a 3-bit
/// unsigned immediate); the row index is the runtime argument `row`.
/// Requires the `amx-avx512` and `avx10.2` target features.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phh, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    // The intrinsic yields an `f16x32`, which is converted to the public `__m512h` type.
    tcvtrowps2phh(TILE as i8, row).as_m512h()
}
432
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
///
/// This variant takes both operands as const generics: `TILE` is checked at compile time
/// against a 3-bit unsigned immediate and the row index `ROW` against a 6-bit one.
/// Requires the `amx-avx512` and `avx10.2` target features.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phhi<const TILE: i32, const ROW: i32>() -> __m512h {
    // Tile index: 3 unsigned bits; row index: 6 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    // The intrinsic yields an `f16x32`, which is converted to the public `__m512h` type.
    tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h()
}
449
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// The tile is selected by the const generic `TILE` (checked at compile time against a 3-bit
/// unsigned immediate); the row index is the runtime argument `row`.
/// Requires the `amx-avx512` and `avx10.2` target features.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phl, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    // The intrinsic yields an `f16x32`, which is converted to the public `__m512h` type.
    tcvtrowps2phl(TILE as i8, row).as_m512h()
}
465
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
/// floating-point elements to packed half-precision (16-bit) floating-point elements. The resulting
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
///
/// This variant takes both operands as const generics: `TILE` is checked at compile time
/// against a 3-bit unsigned immediate and the row index `ROW` against a 6-bit one.
/// Requires the `amx-avx512` and `avx10.2` target features.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h {
    // Tile index: 3 unsigned bits; row index: 6 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    // The intrinsic yields an `f16x32`, which is converted to the public `__m512h` type.
    tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
}
482
/// Moves one row of tile data into a zmm vector register.
///
/// The tile is selected by the const generic `TILE` (checked at compile time against a 3-bit
/// unsigned immediate); the row index is the runtime argument `row`.
/// Requires the `amx-avx512` and `avx10.2` target features.
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tilemovrow, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
    // Reject tile indices that do not fit in 3 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    // The intrinsic yields an `i32x16`, which is converted to the public `__m512i` type.
    tilemovrow(TILE as i8, row).as_m512i()
}
496
/// Moves one row of tile data into a zmm vector register.
///
/// This variant takes both operands as const generics: `TILE` is checked at compile time
/// against a 3-bit unsigned immediate and the row index `ROW` against a 6-bit one.
/// Requires the `amx-avx512` and `avx10.2` target features.
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tilemovrow, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_movrowi<const TILE: i32, const ROW: i32>() -> __m512i {
    // Tile index: 3 unsigned bits; row index: 6 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    // The intrinsic yields an `i32x16`, which is converted to the public `__m512i` type.
    tilemovrowi(TILE as i8, ROW as u32).as_m512i()
}
511
// Raw LLVM intrinsic declarations backing the public `_tile_*` wrappers in this file.
// Tile operands are passed as `i8` indices; the wrappers validate them with
// `static_assert_uimm_bits!` before calling.
// NOTE(review): `improper_ctypes` is presumably allowed because of the SIMD vector
// return types (`f32x16`, `f16x32`, `i32x16`) — confirm.
#[allow(improper_ctypes)]
unsafe extern "C" {
    // Used by the `amx-tile` wrappers: configuration, load/store, release, zero.
    #[link_name = "llvm.x86.ldtilecfg"]
    fn ldtilecfg(mem_addr: *const u8);
    #[link_name = "llvm.x86.sttilecfg"]
    fn sttilecfg(mem_addr: *mut u8);
    #[link_name = "llvm.x86.tileloadd64"]
    fn tileloadd64(dst: i8, base: *const u8, stride: usize);
    #[link_name = "llvm.x86.tileloaddt164"]
    fn tileloaddt164(dst: i8, base: *const u8, stride: usize);
    #[link_name = "llvm.x86.tilerelease"]
    fn tilerelease();
    #[link_name = "llvm.x86.tilestored64"]
    fn tilestored64(dst: i8, base: *mut u8, stride: usize);
    #[link_name = "llvm.x86.tilezero"]
    fn tilezero(dst: i8);
    // Used by the `amx-bf16` wrapper.
    #[link_name = "llvm.x86.tdpbf16ps"]
    fn tdpbf16ps(dst: i8, a: i8, b: i8);
    // Used by the `amx-int8` wrappers.
    #[link_name = "llvm.x86.tdpbuud"]
    fn tdpbuud(dst: i8, a: i8, b: i8);
    #[link_name = "llvm.x86.tdpbusd"]
    fn tdpbusd(dst: i8, a: i8, b: i8);
    #[link_name = "llvm.x86.tdpbsud"]
    fn tdpbsud(dst: i8, a: i8, b: i8);
    #[link_name = "llvm.x86.tdpbssd"]
    fn tdpbssd(dst: i8, a: i8, b: i8);
    // Used by the `amx-fp16` wrapper.
    #[link_name = "llvm.x86.tdpfp16ps"]
    fn tdpfp16ps(dst: i8, a: i8, b: i8);
    // Used by the `amx-complex` wrappers.
    #[link_name = "llvm.x86.tcmmimfp16ps"]
    fn tcmmimfp16ps(dst: i8, a: i8, b: i8);
    #[link_name = "llvm.x86.tcmmrlfp16ps"]
    fn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
    // Used by the `amx-fp8` wrappers.
    #[link_name = "llvm.x86.tdpbf8ps"]
    fn tdpbf8ps(dst: i8, a: i8, b: i8);
    #[link_name = "llvm.x86.tdpbhf8ps"]
    fn tdpbhf8ps(dst: i8, a: i8, b: i8);
    #[link_name = "llvm.x86.tdphbf8ps"]
    fn tdphbf8ps(dst: i8, a: i8, b: i8);
    #[link_name = "llvm.x86.tdphf8ps"]
    fn tdphf8ps(dst: i8, a: i8, b: i8);
    // Used by the `amx-movrs` wrappers.
    #[link_name = "llvm.x86.tileloaddrs64"]
    fn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
    #[link_name = "llvm.x86.tileloaddrst164"]
    fn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
    // Used by the `amx-tf32` wrapper.
    #[link_name = "llvm.x86.tmmultf32ps"]
    fn tmmultf32ps(dst: i8, a: i8, b: i8);
    // Used by the `amx-avx512` wrappers; the `...i` forms are called with a constant row.
    #[link_name = "llvm.x86.tcvtrowd2ps"]
    fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
    #[link_name = "llvm.x86.tcvtrowd2psi"]
    fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16;
    #[link_name = "llvm.x86.tcvtrowps2phh"]
    fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
    #[link_name = "llvm.x86.tcvtrowps2phhi"]
    fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32;
    #[link_name = "llvm.x86.tcvtrowps2phl"]
    fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
    #[link_name = "llvm.x86.tcvtrowps2phli"]
    fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
    #[link_name = "llvm.x86.tilemovrow"]
    fn tilemovrow(tile: i8, row: u32) -> i32x16;
    #[link_name = "llvm.x86.tilemovrowi"]
    fn tilemovrowi(tile: i8, row: u32) -> i32x16;
}
575
576#[cfg(test)]
577mod tests {
578    use crate::core_arch::x86::_mm_cvtness_sbh;
579    use crate::core_arch::x86_64::*;
580    use core::{array, mem::transmute};
581    use stdarch_test::simd_test;
582    #[cfg(target_os = "linux")]
583    use syscalls::{Sysno, syscall};
584
585    #[allow(non_camel_case_types)]
586    #[repr(C, packed)]
587    #[derive(Copy, Clone, Default, Debug, PartialEq)]
588    struct __tilecfg {
589        /// 0 `or` 1
590        palette: u8,
591        start_row: u8,
592        /// reserved, must be zero
593        reserved_a0: [u8; 14],
594        /// number of bytes of one row in each tile
595        colsb: [u16; 8],
596        /// reserved, must be zero
597        reserved_b0: [u16; 8],
598        /// number of rows in each tile
599        rows: [u8; 8],
600        /// reserved, must be zero
601        reserved_c0: [u8; 8],
602    }
603
604    impl __tilecfg {
605        fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
606            Self {
607                palette,
608                start_row,
609                reserved_a0: [0u8; 14],
610                colsb,
611                reserved_b0: [0u16; 8],
612                rows,
613                reserved_c0: [0u8; 8],
614            }
615        }
616
617        const fn as_ptr(&self) -> *const u8 {
618            self as *const Self as *const u8
619        }
620
621        fn as_mut_ptr(&mut self) -> *mut u8 {
622            self as *mut Self as *mut u8
623        }
624    }
625
    // Non-Linux stand-in for `_init_amx`: it has an empty body, so tests can call
    // `_init_amx()` regardless of which `cfg` variant is compiled in.
    #[cfg(not(target_os = "linux"))]
    #[target_feature(enable = "amx-tile")]
    fn _init_amx() {}
629
630    #[cfg(target_os = "linux")]
631    #[target_feature(enable = "amx-tile")]
632    #[inline]
633    unsafe fn _init_amx() {
634        let mut ret: usize;
635        let mut xfeatures: usize = 0;
636        ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize)
637            .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
638        if ret != 0 {
639            panic!("Failed to get XFEATURES");
640        } else {
641            match 0b11 & (xfeatures >> 17) {
642                0 => panic!("AMX is not available"),
643                1 => {
644                    ret = syscall!(Sysno::arch_prctl, 0x1023, 18)
645                        .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
646                    if ret != 0 {
647                        panic!("Failed to enable AMX");
648                    }
649                }
650                3 => {}
651                _ => unreachable!(),
652            }
653        }
654    }
655
656    #[simd_test(enable = "amx-tile")]
657    fn test_tile_loadconfig() {
658        unsafe {
659            let config = __tilecfg::default();
660            _tile_loadconfig(config.as_ptr());
661            _tile_release();
662        }
663    }
664
665    #[simd_test(enable = "amx-tile")]
666    fn test_tile_storeconfig() {
667        unsafe {
668            let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
669            _tile_loadconfig(config.as_ptr());
670            let mut _config = __tilecfg::default();
671            _tile_storeconfig(_config.as_mut_ptr());
672            _tile_release();
673            assert_eq!(config, _config);
674        }
675    }
676
677    #[simd_test(enable = "amx-tile")]
678    fn test_tile_zero() {
679        unsafe {
680            _init_amx();
681            let mut config = __tilecfg::default();
682            config.palette = 1;
683            config.colsb[0] = 64;
684            config.rows[0] = 16;
685            _tile_loadconfig(config.as_ptr());
686            _tile_zero::<0>();
687            let mut out = [[1_i8; 64]; 16];
688            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
689            _tile_release();
690            assert_eq!(out, [[0; 64]; 16]);
691        }
692    }
693
694    #[simd_test(enable = "amx-tile")]
695    fn test_tile_stored() {
696        unsafe {
697            _init_amx();
698            let mut config = __tilecfg::default();
699            config.palette = 1;
700            config.colsb[0] = 64;
701            config.rows[0] = 16;
702            _tile_loadconfig(config.as_ptr());
703            _tile_zero::<0>();
704            let mut out = [[1_i8; 64]; 16];
705            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
706            _tile_release();
707            assert_eq!(out, [[0; 64]; 16]);
708        }
709    }
710
711    #[simd_test(enable = "amx-tile")]
712    fn test_tile_loadd() {
713        unsafe {
714            _init_amx();
715            let mut config = __tilecfg::default();
716            config.palette = 1;
717            config.colsb[0] = 64;
718            config.rows[0] = 16;
719            _tile_loadconfig(config.as_ptr());
720            _tile_zero::<0>();
721            let mat = [1_i8; 1024];
722            _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
723            let mut out = [[0_i8; 64]; 16];
724            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
725            _tile_release();
726            assert_eq!(out, [[1; 64]; 16]);
727        }
728    }
729
730    #[simd_test(enable = "amx-tile")]
731    fn test_tile_stream_loadd() {
732        unsafe {
733            _init_amx();
734            let mut config = __tilecfg::default();
735            config.palette = 1;
736            config.colsb[0] = 64;
737            config.rows[0] = 16;
738            _tile_loadconfig(config.as_ptr());
739            _tile_zero::<0>();
740            let mat = [1_i8; 1024];
741            _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
742            let mut out = [[0_i8; 64]; 16];
743            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
744            _tile_release();
745            assert_eq!(out, [[1; 64]; 16]);
746        }
747    }
748
749    #[simd_test(enable = "amx-tile")]
750    fn test_tile_release() {
751        unsafe {
752            _tile_release();
753        }
754    }
755
756    #[simd_test(enable = "amx-bf16,avx512f")]
757    fn test_tile_dpbf16ps() {
758        unsafe {
759            _init_amx();
760            let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
761            let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
762            let ones: [u8; 1024] = transmute([bf16_1; 512]);
763            let twos: [u8; 1024] = transmute([bf16_2; 512]);
764            let mut res = [[0f32; 16]; 16];
765            let mut config = __tilecfg::default();
766            config.palette = 1;
767            (0..=2).for_each(|i| {
768                config.colsb[i] = 64;
769                config.rows[i] = 16;
770            });
771            _tile_loadconfig(config.as_ptr());
772            _tile_zero::<0>();
773            _tile_loadd::<1>(&ones as *const u8, 64);
774            _tile_loadd::<2>(&twos as *const u8, 64);
775            _tile_dpbf16ps::<0, 1, 2>();
776            _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
777            _tile_release();
778            assert_eq!(res, [[64f32; 16]; 16]);
779        }
780    }
781
782    #[simd_test(enable = "amx-int8")]
783    fn test_tile_dpbssd() {
784        unsafe {
785            _init_amx();
786            let ones = [-1_i8; 1024];
787            let twos = [-2_i8; 1024];
788            let mut res = [[0_i32; 16]; 16];
789            let mut config = __tilecfg::default();
790            config.palette = 1;
791            (0..=2).for_each(|i| {
792                config.colsb[i] = 64;
793                config.rows[i] = 16;
794            });
795            _tile_loadconfig(config.as_ptr());
796            _tile_zero::<0>();
797            _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
798            _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
799            _tile_dpbssd::<0, 1, 2>();
800            _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
801            _tile_release();
802            assert_eq!(res, [[128_i32; 16]; 16]);
803        }
804    }
805
806    #[simd_test(enable = "amx-int8")]
807    fn test_tile_dpbsud() {
808        unsafe {
809            _init_amx();
810            let ones = [-1_i8; 1024];
811            let twos = [2_u8; 1024];
812            let mut res = [[0_i32; 16]; 16];
813            let mut config = __tilecfg::default();
814            config.palette = 1;
815            (0..=2).for_each(|i| {
816                config.colsb[i] = 64;
817                config.rows[i] = 16;
818            });
819            _tile_loadconfig(config.as_ptr());
820            _tile_zero::<0>();
821            _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
822            _tile_loadd::<2>(&twos as *const u8, 64);
823            _tile_dpbsud::<0, 1, 2>();
824            _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
825            _tile_release();
826            assert_eq!(res, [[-128_i32; 16]; 16]);
827        }
828    }
829
830    #[simd_test(enable = "amx-int8")]
831    fn test_tile_dpbusd() {
832        unsafe {
833            _init_amx();
834            let ones = [1_u8; 1024];
835            let twos = [-2_i8; 1024];
836            let mut res = [[0_i32; 16]; 16];
837            let mut config = __tilecfg::default();
838            config.palette = 1;
839            (0..=2).for_each(|i| {
840                config.colsb[i] = 64;
841                config.rows[i] = 16;
842            });
843            _tile_loadconfig(config.as_ptr());
844            _tile_zero::<0>();
845            _tile_loadd::<1>(&ones as *const u8, 64);
846            _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
847            _tile_dpbusd::<0, 1, 2>();
848            _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
849            _tile_release();
850            assert_eq!(res, [[-128_i32; 16]; 16]);
851        }
852    }
853
854    #[simd_test(enable = "amx-int8")]
855    fn test_tile_dpbuud() {
856        unsafe {
857            _init_amx();
858            let ones = [1_u8; 1024];
859            let twos = [2_u8; 1024];
860            let mut res = [[0_i32; 16]; 16];
861            let mut config = __tilecfg::default();
862            config.palette = 1;
863            (0..=2).for_each(|i| {
864                config.colsb[i] = 64;
865                config.rows[i] = 16;
866            });
867            _tile_loadconfig(config.as_ptr());
868            _tile_zero::<0>();
869            _tile_loadd::<1>(&ones as *const u8, 64);
870            _tile_loadd::<2>(&twos as *const u8, 64);
871            _tile_dpbuud::<0, 1, 2>();
872            _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
873            _tile_release();
874            assert_eq!(res, [[128_i32; 16]; 16]);
875        }
876    }
877
878    #[simd_test(enable = "amx-fp16")]
879    fn test_tile_dpfp16ps() {
880        unsafe {
881            _init_amx();
882            let ones = [1f16; 512];
883            let twos = [2f16; 512];
884            let mut res = [[0f32; 16]; 16];
885            let mut config = __tilecfg::default();
886            config.palette = 1;
887            (0..=2).for_each(|i| {
888                config.colsb[i] = 64;
889                config.rows[i] = 16;
890            });
891            _tile_loadconfig(config.as_ptr());
892            _tile_zero::<0>();
893            _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
894            _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
895            _tile_dpfp16ps::<0, 1, 2>();
896            _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
897            _tile_release();
898            assert_eq!(res, [[64f32; 16]; 16]);
899        }
900    }
901
902    #[simd_test(enable = "amx-complex")]
903    fn test_tile_cmmimfp16ps() {
904        unsafe {
905            _init_amx();
906            let ones = [1f16; 512];
907            let twos = [2f16; 512];
908            let mut res = [[0f32; 16]; 16];
909            let mut config = __tilecfg::default();
910            config.palette = 1;
911            (0..=2).for_each(|i| {
912                config.colsb[i] = 64;
913                config.rows[i] = 16;
914            });
915            _tile_loadconfig(config.as_ptr());
916            _tile_zero::<0>();
917            _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
918            _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
919            _tile_cmmimfp16ps::<0, 1, 2>();
920            _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
921            _tile_release();
922            assert_eq!(res, [[64f32; 16]; 16]);
923        }
924    }
925
926    #[simd_test(enable = "amx-complex")]
927    fn test_tile_cmmrlfp16ps() {
928        unsafe {
929            _init_amx();
930            let ones = [1f16; 512];
931            let twos = [2f16; 512];
932            let mut res = [[0f32; 16]; 16];
933            let mut config = __tilecfg::default();
934            config.palette = 1;
935            (0..=2).for_each(|i| {
936                config.colsb[i] = 64;
937                config.rows[i] = 16;
938            });
939            _tile_loadconfig(config.as_ptr());
940            _tile_zero::<0>();
941            _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
942            _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
943            _tile_cmmrlfp16ps::<0, 1, 2>();
944            _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
945            _tile_release();
946            assert_eq!(res, [[0f32; 16]; 16]);
947        }
948    }
949
950    const BF8_ONE: u8 = 0x3c;
951    const BF8_TWO: u8 = 0x40;
952    const HF8_ONE: u8 = 0x38;
953    const HF8_TWO: u8 = 0x40;
954
955    #[simd_test(enable = "amx-fp8")]
956    fn test_tile_dpbf8ps() {
957        unsafe {
958            _init_amx();
959            let ones = [BF8_ONE; 1024];
960            let twos = [BF8_TWO; 1024];
961            let mut res = [[0.0_f32; 16]; 16];
962            let mut config = __tilecfg::default();
963            config.palette = 1;
964            (0..=2).for_each(|i| {
965                config.colsb[i] = 64;
966                config.rows[i] = 16;
967            });
968            _tile_loadconfig(config.as_ptr());
969            _tile_zero::<0>();
970            _tile_loadd::<1>(&ones as *const u8, 64);
971            _tile_loadd::<2>(&twos as *const u8, 64);
972            _tile_dpbf8ps::<0, 1, 2>();
973            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
974            _tile_release();
975            assert_eq!(res, [[128.0_f32; 16]; 16]);
976        }
977    }
978
979    #[simd_test(enable = "amx-fp8")]
980    fn test_tile_dpbhf8ps() {
981        unsafe {
982            _init_amx();
983            let ones = [BF8_ONE; 1024];
984            let twos = [HF8_TWO; 1024];
985            let mut res = [[0.0_f32; 16]; 16];
986            let mut config = __tilecfg::default();
987            config.palette = 1;
988            (0..=2).for_each(|i| {
989                config.colsb[i] = 64;
990                config.rows[i] = 16;
991            });
992            _tile_loadconfig(config.as_ptr());
993            _tile_zero::<0>();
994            _tile_loadd::<1>(&ones as *const u8, 64);
995            _tile_loadd::<2>(&twos as *const u8, 64);
996            _tile_dpbhf8ps::<0, 1, 2>();
997            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
998            _tile_release();
999            assert_eq!(res, [[128.0_f32; 16]; 16]);
1000        }
1001    }
1002
1003    #[simd_test(enable = "amx-fp8")]
1004    fn test_tile_dphbf8ps() {
1005        unsafe {
1006            _init_amx();
1007            let ones = [HF8_ONE; 1024];
1008            let twos = [BF8_TWO; 1024];
1009            let mut res = [[0.0_f32; 16]; 16];
1010            let mut config = __tilecfg::default();
1011            config.palette = 1;
1012            (0..=2).for_each(|i| {
1013                config.colsb[i] = 64;
1014                config.rows[i] = 16;
1015            });
1016            _tile_loadconfig(config.as_ptr());
1017            _tile_zero::<0>();
1018            _tile_loadd::<1>(&ones as *const u8, 64);
1019            _tile_loadd::<2>(&twos as *const u8, 64);
1020            _tile_dphbf8ps::<0, 1, 2>();
1021            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1022            _tile_release();
1023            assert_eq!(res, [[128.0_f32; 16]; 16]);
1024        }
1025    }
1026
1027    #[simd_test(enable = "amx-fp8")]
1028    fn test_tile_dphf8ps() {
1029        unsafe {
1030            _init_amx();
1031            let ones = [HF8_ONE; 1024];
1032            let twos = [HF8_TWO; 1024];
1033            let mut res = [[0.0_f32; 16]; 16];
1034            let mut config = __tilecfg::default();
1035            config.palette = 1;
1036            (0..=2).for_each(|i| {
1037                config.colsb[i] = 64;
1038                config.rows[i] = 16;
1039            });
1040            _tile_loadconfig(config.as_ptr());
1041            _tile_zero::<0>();
1042            _tile_loadd::<1>(&ones as *const u8, 64);
1043            _tile_loadd::<2>(&twos as *const u8, 64);
1044            _tile_dphf8ps::<0, 1, 2>();
1045            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1046            _tile_release();
1047            assert_eq!(res, [[128.0_f32; 16]; 16]);
1048        }
1049    }
1050
1051    #[simd_test(enable = "amx-movrs")]
1052    fn test_tile_loaddrs() {
1053        unsafe {
1054            _init_amx();
1055            let mut config = __tilecfg::default();
1056            config.palette = 1;
1057            config.colsb[0] = 64;
1058            config.rows[0] = 16;
1059            _tile_loadconfig(config.as_ptr());
1060            _tile_zero::<0>();
1061            let mat = [1_i8; 1024];
1062            _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
1063            let mut out = [[0_i8; 64]; 16];
1064            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
1065            _tile_release();
1066            assert_eq!(out, [[1; 64]; 16]);
1067        }
1068    }
1069
1070    #[simd_test(enable = "amx-movrs")]
1071    fn test_tile_stream_loaddrs() {
1072        unsafe {
1073            _init_amx();
1074            let mut config = __tilecfg::default();
1075            config.palette = 1;
1076            config.colsb[0] = 64;
1077            config.rows[0] = 16;
1078            _tile_loadconfig(config.as_ptr());
1079            _tile_zero::<0>();
1080            let mat = [1_i8; 1024];
1081            _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
1082            let mut out = [[0_i8; 64]; 16];
1083            _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
1084            _tile_release();
1085            assert_eq!(out, [[1; 64]; 16]);
1086        }
1087    }
1088
1089    #[simd_test(enable = "amx-avx512,avx10.2")]
1090    fn test_tile_movrow() {
1091        unsafe {
1092            _init_amx();
1093            let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
1094
1095            let mut config = __tilecfg::default();
1096            config.palette = 1;
1097            config.colsb[0] = 64;
1098            config.rows[0] = 16;
1099            _tile_loadconfig(config.as_ptr());
1100            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1101            for i in 0..16 {
1102                let row = _tile_movrow::<0>(i);
1103                assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
1104            }
1105        }
1106    }
1107
1108    macro_rules! wrap_imm4 {
1109        ($name:ident :: <$TILE:literal>, $row:expr) => {
1110            match $row {
1111                0 => $name::<$TILE, 0>(),
1112                1 => $name::<$TILE, 1>(),
1113                2 => $name::<$TILE, 2>(),
1114                3 => $name::<$TILE, 3>(),
1115                4 => $name::<$TILE, 4>(),
1116                5 => $name::<$TILE, 5>(),
1117                6 => $name::<$TILE, 6>(),
1118                7 => $name::<$TILE, 7>(),
1119                8 => $name::<$TILE, 8>(),
1120                9 => $name::<$TILE, 9>(),
1121                10 => $name::<$TILE, 10>(),
1122                11 => $name::<$TILE, 11>(),
1123                12 => $name::<$TILE, 12>(),
1124                13 => $name::<$TILE, 13>(),
1125                14 => $name::<$TILE, 14>(),
1126                15 => $name::<$TILE, 15>(),
1127                _ => panic!("row index out of range"),
1128            }
1129        };
1130    }
1131
1132    #[simd_test(enable = "amx-avx512,avx10.2")]
1133    fn test_tile_movrowi() {
1134        unsafe {
1135            _init_amx();
1136            let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
1137
1138            let mut config = __tilecfg::default();
1139            config.palette = 1;
1140            config.colsb[0] = 64;
1141            config.rows[0] = 16;
1142            _tile_loadconfig(config.as_ptr());
1143            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1144
1145            for i in 0..16 {
1146                let row = wrap_imm4!(_tile_movrowi::<0>, i);
1147                assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
1148            }
1149        }
1150    }
1151
1152    #[simd_test(enable = "amx-avx512,avx10.2")]
1153    fn test_tile_cvtrowd2ps() {
1154        unsafe {
1155            _init_amx();
1156            let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1157
1158            let mut config = __tilecfg::default();
1159            config.palette = 1;
1160            config.colsb[0] = 64;
1161            config.rows[0] = 16;
1162            _tile_loadconfig(config.as_ptr());
1163            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1164            for i in 0..16 {
1165                let row = _tile_cvtrowd2ps::<0>(i);
1166                assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1167            }
1168        }
1169    }
1170
1171    #[simd_test(enable = "amx-avx512,avx10.2")]
1172    fn test_tile_cvtrowd2psi() {
1173        unsafe {
1174            _init_amx();
1175            let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1176
1177            let mut config = __tilecfg::default();
1178            config.palette = 1;
1179            config.colsb[0] = 64;
1180            config.rows[0] = 16;
1181            _tile_loadconfig(config.as_ptr());
1182            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1183
1184            for i in 0..16 {
1185                let row = wrap_imm4!(_tile_cvtrowd2psi::<0>, i);
1186                assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1187            }
1188        }
1189    }
1190
1191    #[simd_test(enable = "amx-avx512,avx10.2")]
1192    fn test_tile_cvtrowps2phh() {
1193        unsafe {
1194            _init_amx();
1195            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1196
1197            let mut config = __tilecfg::default();
1198            config.palette = 1;
1199            config.colsb[0] = 64;
1200            config.rows[0] = 16;
1201            _tile_loadconfig(config.as_ptr());
1202            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1203            for i in 0..16 {
1204                let row = _tile_cvtrowps2phh::<0>(i);
1205                assert_eq!(
1206                    *row.as_f16x32().as_array(),
1207                    array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1208                );
1209            }
1210        }
1211    }
1212
1213    #[simd_test(enable = "amx-avx512,avx10.2")]
1214    fn test_tile_cvtrowps2phhi() {
1215        unsafe {
1216            _init_amx();
1217            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1218
1219            let mut config = __tilecfg::default();
1220            config.palette = 1;
1221            config.colsb[0] = 64;
1222            config.rows[0] = 16;
1223            _tile_loadconfig(config.as_ptr());
1224            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1225            for i in 0..16 {
1226                let row = wrap_imm4!(_tile_cvtrowps2phhi::<0>, i);
1227                assert_eq!(
1228                    *row.as_f16x32().as_array(),
1229                    array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1230                );
1231            }
1232        }
1233    }
1234
1235    #[simd_test(enable = "amx-avx512,avx10.2")]
1236    fn test_tile_cvtrowps2phl() {
1237        unsafe {
1238            _init_amx();
1239            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1240
1241            let mut config = __tilecfg::default();
1242            config.palette = 1;
1243            config.colsb[0] = 64;
1244            config.rows[0] = 16;
1245            _tile_loadconfig(config.as_ptr());
1246            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1247            for i in 0..16 {
1248                let row = _tile_cvtrowps2phl::<0>(i);
1249                assert_eq!(
1250                    *row.as_f16x32().as_array(),
1251                    array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1252                );
1253            }
1254        }
1255    }
1256
1257    #[simd_test(enable = "amx-avx512,avx10.2")]
1258    fn test_tile_cvtrowps2phli() {
1259        unsafe {
1260            _init_amx();
1261            let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1262
1263            let mut config = __tilecfg::default();
1264            config.palette = 1;
1265            config.colsb[0] = 64;
1266            config.rows[0] = 16;
1267            _tile_loadconfig(config.as_ptr());
1268            _tile_loadd::<0>(array.as_ptr().cast(), 64);
1269            for i in 0..16 {
1270                let row = wrap_imm4!(_tile_cvtrowps2phli::<0>, i);
1271                assert_eq!(
1272                    *row.as_f16x32().as_array(),
1273                    array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1274                );
1275            }
1276        }
1277    }
1278
1279    #[simd_test(enable = "amx-tf32")]
1280    fn test_tile_mmultf32ps() {
1281        unsafe {
1282            _init_amx();
1283            let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1284            let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
1285            let mut res = [[0.0; 16]; 16];
1286
1287            let mut config = __tilecfg::default();
1288            config.palette = 1;
1289            (0..=2).for_each(|i| {
1290                config.colsb[i] = 64;
1291                config.rows[i] = 16;
1292            });
1293            _tile_loadconfig(config.as_ptr());
1294            _tile_zero::<0>();
1295            _tile_loadd::<1>(a.as_ptr().cast(), 64);
1296            _tile_loadd::<2>(b.as_ptr().cast(), 64);
1297            _tile_mmultf32ps::<0, 1, 2>();
1298            _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1299            _tile_release();
1300
1301            let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
1302            assert_eq!(res, expected);
1303        }
1304    }
1305}