1use crate::core_arch::{simd::*, x86::*};
2
3#[cfg(test)]
4use stdarch_test::assert_instr;
5
/// Loads the AMX tile configuration from the memory pointed to by `mem_addr`.
///
/// Lowers to the `ldtilecfg` instruction through the `llvm.x86.ldtilecfg` intrinsic.
///
/// # Safety
///
/// The `amx-tile` feature must be available at runtime. `mem_addr` must point to a
/// readable tile-configuration block (presumably the 64-byte palette / bytes-per-row /
/// rows layout defined by Intel — TODO confirm against the Intel SDM).
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadconfig)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(ldtilecfg))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadconfig(mem_addr: *const u8) {
    ldtilecfg(mem_addr);
}
20
/// Stores the currently active AMX tile configuration to the memory pointed to by
/// `mem_addr`.
///
/// Lowers to the `sttilecfg` instruction through the `llvm.x86.sttilecfg` intrinsic.
///
/// # Safety
///
/// The `amx-tile` feature must be available at runtime. `mem_addr` must be writable for a
/// full tile-configuration block (presumably 64 bytes — TODO confirm against the Intel SDM).
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_storeconfig)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(sttilecfg))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_storeconfig(mem_addr: *mut u8) {
    sttilecfg(mem_addr);
}
33
/// Loads tile `DST` from memory starting at `base`.
///
/// `stride` is the distance in bytes between consecutive rows in memory. The tile's
/// shape comes from the active tile configuration (see `_tile_loadconfig`). Lowers to
/// `tileloadd` through `llvm.x86.tileloadd64`.
///
/// `DST` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-tile` feature must be available at runtime and a tile configuration must be
/// loaded. `base` must be readable for every configured row of tile `DST` at the given
/// `stride` (TODO confirm exact bounds against the Intel SDM).
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loadd)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tileloadd, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loadd<const DST: i32>(base: *const u8, stride: usize) {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    // Pass the index as a constant expression; the LLVM intrinsic presumably needs an
    // immediate here, so do not route it through a runtime binding.
    tileloadd64(DST as i8, base, stride);
}
46
/// Releases the AMX tile state.
///
/// Lowers to the `tilerelease` instruction through `llvm.x86.tilerelease`. Presumably
/// this returns tile configuration and data to the init state — TODO confirm against
/// the Intel SDM.
///
/// # Safety
///
/// The `amx-tile` feature must be available at runtime.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_release)
#[inline]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilerelease))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_release() {
    tilerelease();
}
57
/// Stores tile `DST` to memory starting at `base`.
///
/// `stride` is the distance in bytes between consecutive rows in memory. Lowers to
/// `tilestored` through `llvm.x86.tilestored64`.
///
/// `DST` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-tile` feature must be available at runtime and a tile configuration must be
/// loaded. `base` must be writable for every configured row of tile `DST` at the given
/// `stride` (TODO confirm exact bounds against the Intel SDM).
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stored)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilestored, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stored<const DST: i32>(base: *mut u8, stride: usize) {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    // Keep the index a constant expression at the call site (likely an LLVM immarg).
    tilestored64(DST as i8, base, stride);
}
70
/// Loads tile `DST` from memory starting at `base`, like `_tile_loadd`, but using the
/// `tileloaddt1` instruction (through `llvm.x86.tileloaddt164`).
///
/// Presumably `tileloaddt1` is the variant with a T1 / streaming cache hint — TODO
/// confirm against the Intel SDM. `stride` is the byte distance between rows.
///
/// `DST` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// Same requirements as `_tile_loadd`: `amx-tile` available at runtime, a loaded tile
/// configuration, and `base` readable for every configured row at `stride`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loadd)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tileloaddt1, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loadd<const DST: i32>(base: *const u8, stride: usize) {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    // Keep the index a constant expression at the call site (likely an LLVM immarg).
    tileloaddt164(DST as i8, base, stride);
}
85
/// Sets every element of tile `DST` to zero.
///
/// Lowers to the `tilezero` instruction through `llvm.x86.tilezero`.
///
/// `DST` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-tile` feature must be available at runtime and a tile configuration must be
/// loaded.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_zero)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-tile")]
#[cfg_attr(test, assert_instr(tilezero, DST = 0))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_zero<const DST: i32>() {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    // Keep the index a constant expression at the call site (likely an LLVM immarg).
    tilezero(DST as i8);
}
98
/// Computes dot products of BF16 pairs from tiles `A` and `B`.
///
/// The results are accumulated into the single-precision (f32) elements of tile `DST`.
/// Lowers to `tdpbf16ps` through `llvm.x86.tdpbf16ps`.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-bf16` feature must be available at runtime. The three tiles must be
/// configured with compatible shapes (TODO confirm exact shape rules in the Intel SDM).
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf16ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-bf16")]
#[cfg_attr(test, assert_instr(tdpbf16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbf16ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdpbf16ps(DST as i8, A as i8, B as i8);
}
115
/// Computes dot products of signed 8-bit integers from tile `A` with signed 8-bit
/// integers from tile `B`.
///
/// The results are accumulated into the 32-bit integer elements of tile `DST`. Lowers
/// to `tdpbssd` through `llvm.x86.tdpbssd`.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-int8` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbssd)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbssd, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbssd<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdpbssd(DST as i8, A as i8, B as i8);
}
133
/// Computes dot products of signed 8-bit integers from tile `A` with unsigned 8-bit
/// integers from tile `B`.
///
/// The results are accumulated into the 32-bit integer elements of tile `DST`. Lowers
/// to `tdpbsud` through `llvm.x86.tdpbsud`.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-int8` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbsud)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbsud, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbsud<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdpbsud(DST as i8, A as i8, B as i8);
}
151
/// Computes dot products of unsigned 8-bit integers from tile `A` with signed 8-bit
/// integers from tile `B`.
///
/// The results are accumulated into the 32-bit integer elements of tile `DST`. Lowers
/// to `tdpbusd` through `llvm.x86.tdpbusd`.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-int8` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbusd)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbusd, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbusd<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdpbusd(DST as i8, A as i8, B as i8);
}
169
/// Computes dot products of unsigned 8-bit integers from tile `A` with unsigned 8-bit
/// integers from tile `B`.
///
/// The results are accumulated into the 32-bit integer elements of tile `DST`. Lowers
/// to `tdpbuud` through `llvm.x86.tdpbuud`.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-int8` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbuud)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-int8")]
#[cfg_attr(test, assert_instr(tdpbuud, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbuud<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdpbuud(DST as i8, A as i8, B as i8);
}
187
/// Computes dot products of FP16 pairs from tiles `A` and `B`.
///
/// The results are accumulated into the single-precision (f32) elements of tile `DST`.
/// Lowers to `tdpfp16ps` through `llvm.x86.tdpfp16ps`.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-fp16` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpfp16ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp16")]
#[cfg_attr(test, assert_instr(tdpfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdpfp16ps(DST as i8, A as i8, B as i8);
}
204
/// Complex FP16 matrix multiply of tiles `A` and `B`, accumulating the imaginary part of
/// each product into the f32 elements of tile `DST`.
///
/// Adjacent FP16 elements are presumably interpreted as (real, imaginary) pairs — TODO
/// confirm against the Intel SDM. Lowers to `tcmmimfp16ps` through
/// `llvm.x86.tcmmimfp16ps`.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-complex` feature must be available at runtime and the tiles must be
/// configured with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmimfp16ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-complex")]
#[cfg_attr(test, assert_instr(tcmmimfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cmmimfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tcmmimfp16ps(DST as i8, A as i8, B as i8);
}
225
/// Complex FP16 matrix multiply of tiles `A` and `B`, accumulating the real part of each
/// product into the f32 elements of tile `DST`.
///
/// Adjacent FP16 elements are presumably interpreted as (real, imaginary) pairs — TODO
/// confirm against the Intel SDM. Lowers to `tcmmrlfp16ps` through
/// `llvm.x86.tcmmrlfp16ps`.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-complex` feature must be available at runtime and the tiles must be
/// configured with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cmmrlfp16ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-complex")]
#[cfg_attr(test, assert_instr(tcmmrlfp16ps, DST = 0, A = 1, B = 2))]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tcmmrlfp16ps(DST as i8, A as i8, B as i8);
}
246
/// Computes dot products of BF8 elements from tile `A` with BF8 elements from tile `B`.
///
/// The results are accumulated into the f32 elements of tile `DST`. Lowers to
/// `tdpbf8ps` through `llvm.x86.tdpbf8ps`. The instruction assertion only runs on Linux
/// or MSVC targets.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-fp8` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbf8ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdpbf8ps(DST as i8, A as i8, B as i8);
}
265
/// Computes dot products of BF8 elements from tile `A` with HF8 elements from tile `B`.
///
/// The results are accumulated into the f32 elements of tile `DST`. Lowers to
/// `tdpbhf8ps` through `llvm.x86.tdpbhf8ps`. The instruction assertion only runs on
/// Linux or MSVC targets.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-fp8` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dpbhf8ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdpbhf8ps(DST as i8, A as i8, B as i8);
}
284
/// Computes dot products of HF8 elements from tile `A` with BF8 elements from tile `B`.
///
/// The results are accumulated into the f32 elements of tile `DST`. Lowers to
/// `tdphbf8ps` through `llvm.x86.tdphbf8ps`. The instruction assertion only runs on
/// Linux or MSVC targets.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-fp8` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dphbf8ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdphbf8ps(DST as i8, A as i8, B as i8);
}
303
/// Computes dot products of HF8 elements from tile `A` with HF8 elements from tile `B`.
///
/// The results are accumulated into the f32 elements of tile `DST`. Lowers to
/// `tdphf8ps` through `llvm.x86.tdphf8ps`. The instruction assertion only runs on Linux
/// or MSVC targets.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-fp8` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_dphf8ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-fp8")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tdphf8ps(DST as i8, A as i8, B as i8);
}
322
/// Loads tile `DST` from memory starting at `base`, using the `tileloaddrs` instruction
/// (through `llvm.x86.tileloaddrs64`).
///
/// `stride` is the byte distance between consecutive rows. Presumably the "rs" variant
/// carries a read-shared hint — TODO confirm against the Intel SDM. The instruction
/// assertion only runs on Linux or MSVC targets.
///
/// `DST` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-movrs` feature must be available at runtime, a tile configuration must be
/// loaded, and `base` must be readable for every configured row at `stride`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_loaddrs)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tileloaddrs, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    // Keep the index a constant expression at the call site (likely an LLVM immarg).
    tileloaddrs64(DST as i8, base, stride);
}
340
/// Loads tile `DST` from memory starting at `base`, using the `tileloaddrst1` instruction
/// (through `llvm.x86.tileloaddrst164`).
///
/// This is the `t1` counterpart of `_tile_loaddrs`. `stride` is the byte distance
/// between consecutive rows. The instruction assertion only runs on Linux or MSVC
/// targets.
///
/// `DST` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-movrs` feature must be available at runtime, a tile configuration must be
/// loaded, and `base` must be readable for every configured row at `stride`.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_stream_loaddrs)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-movrs")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tileloaddrst1, DST = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    // Keep the index a constant expression at the call site (likely an LLVM immarg).
    tileloaddrst164(DST as i8, base, stride);
}
360
/// Matrix multiply of f32 tiles `A` and `B` (presumably with TF32 input precision —
/// TODO confirm), accumulated into the f32 elements of tile `DST`.
///
/// Lowers to `tmmultf32ps` through `llvm.x86.tmmultf32ps`. The instruction assertion
/// only runs on Linux or MSVC targets.
///
/// `DST`, `A` and `B` must each be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-tf32` feature must be available at runtime and the tiles must be configured
/// with compatible shapes.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_mmultf32ps)
#[inline]
#[rustc_legacy_const_generics(0, 1, 2)]
#[target_feature(enable = "amx-tf32")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
    // Every tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(DST, 3);
    static_assert_uimm_bits!(A, 3);
    static_assert_uimm_bits!(B, 3);
    // Indices stay constant expressions (likely LLVM immargs).
    tmmultf32ps(DST as i8, A as i8, B as i8);
}
385
/// Converts row `row` of tile `TILE` from 32-bit integers to f32 and returns it as an
/// `__m512`.
///
/// `row` is a runtime value. Lowers to `tcvtrowd2ps` through `llvm.x86.tcvtrowd2ps`. The
/// instruction assertion only runs on Linux or MSVC targets.
///
/// `TILE` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-avx512` and `avx10.2` features must be available at runtime, and tile `TILE`
/// must be configured (behaviour for an out-of-range `row` is not established here —
/// TODO confirm against the Intel SDM).
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cvtrowd2ps)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowd2ps, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(TILE, 3);
    // The tile index stays a constant expression; `row` may be dynamic.
    tcvtrowd2ps(TILE as i8, row).as_m512()
}
400
/// Immediate-row form of `_tile_cvtrowd2ps`: converts row `ROW` of tile `TILE` from
/// 32-bit integers to f32 and returns it as an `__m512`.
///
/// Calls `llvm.x86.tcvtrowd2psi`; the instruction assertion checks for `tcvtrowd2ps` and
/// only runs on Linux or MSVC targets.
///
/// `TILE` must be in `0..=7` and `ROW` in `0..=63`; both are checked at compile time.
///
/// # Safety
///
/// The `amx-avx512` and `avx10.2` features must be available at runtime, and tile `TILE`
/// must be configured.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cvtrowd2psi)
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowd2psi<const TILE: i32, const ROW: i32>() -> __m512 {
    // Tile index: 3 unsigned bits; row immediate: 6 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    // Both operands stay constant expressions (likely LLVM immargs).
    tcvtrowd2psi(TILE as i8, ROW as u32).as_m512()
}
416
/// Converts row `row` of tile `TILE` from f32 to f16 and returns it as an `__m512h`.
///
/// The converted values are placed in the "high" half of each 32-bit slot. Presumably
/// this means the odd-indexed f16 lanes, with the even lanes zeroed — TODO confirm
/// against the Intel SDM.
///
/// `row` is a runtime value. Lowers to `tcvtrowps2phh` through `llvm.x86.tcvtrowps2phh`.
/// The instruction assertion only runs on Linux or MSVC targets.
///
/// `TILE` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-avx512` and `avx10.2` features must be available at runtime, and tile `TILE`
/// must be configured.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cvtrowps2phh)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phh, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(TILE, 3);
    // The tile index stays a constant expression; `row` may be dynamic.
    tcvtrowps2phh(TILE as i8, row).as_m512h()
}
432
/// Immediate-row form of `_tile_cvtrowps2phh`: converts row `ROW` of tile `TILE` from
/// f32 to f16, placing the values in the high halves, and returns an `__m512h`.
///
/// Calls `llvm.x86.tcvtrowps2phhi`; the instruction assertion checks for `tcvtrowps2phh`
/// and only runs on Linux or MSVC targets.
///
/// `TILE` must be in `0..=7` and `ROW` in `0..=63`; both are checked at compile time.
///
/// # Safety
///
/// The `amx-avx512` and `avx10.2` features must be available at runtime, and tile `TILE`
/// must be configured.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cvtrowps2phhi)
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phhi<const TILE: i32, const ROW: i32>() -> __m512h {
    // Tile index: 3 unsigned bits; row immediate: 6 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    // Both operands stay constant expressions (likely LLVM immargs).
    tcvtrowps2phhi(TILE as i8, ROW as u32).as_m512h()
}
449
/// Converts row `row` of tile `TILE` from f32 to f16 and returns it as an `__m512h`.
///
/// Presumably this is the "low" counterpart of `_tile_cvtrowps2phh`, placing values in
/// the even-indexed f16 lanes — TODO confirm against the Intel SDM.
///
/// `row` is a runtime value. Lowers to `tcvtrowps2phl` through `llvm.x86.tcvtrowps2phl`.
/// The instruction assertion only runs on Linux or MSVC targets.
///
/// `TILE` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-avx512` and `avx10.2` features must be available at runtime, and tile `TILE`
/// must be configured.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cvtrowps2phl)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phl, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(TILE, 3);
    // The tile index stays a constant expression; `row` may be dynamic.
    tcvtrowps2phl(TILE as i8, row).as_m512h()
}
465
/// Immediate-row form of `_tile_cvtrowps2phl`: converts row `ROW` of tile `TILE` from
/// f32 to f16 and returns an `__m512h`.
///
/// Calls `llvm.x86.tcvtrowps2phli`; the instruction assertion checks for `tcvtrowps2phl`
/// and only runs on Linux or MSVC targets.
///
/// `TILE` must be in `0..=7` and `ROW` in `0..=63`; both are checked at compile time.
///
/// # Safety
///
/// The `amx-avx512` and `avx10.2` features must be available at runtime, and tile `TILE`
/// must be configured.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_cvtrowps2phli)
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h {
    // Tile index: 3 unsigned bits; row immediate: 6 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    // Both operands stay constant expressions (likely LLVM immargs).
    tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
}
482
/// Copies row `row` of tile `TILE` into an `__m512i` without conversion.
///
/// `row` is a runtime value. Lowers to `tilemovrow` through `llvm.x86.tilemovrow`. The
/// instruction assertion only runs on Linux or MSVC targets.
///
/// `TILE` must be in `0..=7`; this is checked at compile time.
///
/// # Safety
///
/// The `amx-avx512` and `avx10.2` features must be available at runtime, and tile `TILE`
/// must be configured.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_movrow)
#[inline]
#[rustc_legacy_const_generics(0)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tilemovrow, TILE = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
    // Tile index must fit in 3 unsigned bits (tiles 0..=7).
    static_assert_uimm_bits!(TILE, 3);
    // The tile index stays a constant expression; `row` may be dynamic.
    tilemovrow(TILE as i8, row).as_m512i()
}
496
/// Immediate-row form of `_tile_movrow`: copies row `ROW` of tile `TILE` into an
/// `__m512i`.
///
/// Calls `llvm.x86.tilemovrowi`; the instruction assertion checks for `tilemovrow` and
/// only runs on Linux or MSVC targets.
///
/// `TILE` must be in `0..=7` and `ROW` in `0..=63`; both are checked at compile time.
///
/// # Safety
///
/// The `amx-avx512` and `avx10.2` features must be available at runtime, and tile `TILE`
/// must be configured.
///
/// [Intel's documentation](https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_tile_movrowi)
#[inline]
#[rustc_legacy_const_generics(0, 1)]
#[target_feature(enable = "amx-avx512,avx10.2")]
#[cfg_attr(
    all(test, any(target_os = "linux", target_env = "msvc")),
    assert_instr(tilemovrow, TILE = 0, ROW = 0)
)]
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
pub unsafe fn _tile_movrowi<const TILE: i32, const ROW: i32>() -> __m512i {
    // Tile index: 3 unsigned bits; row immediate: 6 unsigned bits.
    static_assert_uimm_bits!(TILE, 3);
    static_assert_uimm_bits!(ROW, 6);
    // Both operands stay constant expressions (likely LLVM immargs).
    tilemovrowi(TILE as i8, ROW as u32).as_m512i()
}
511
512#[allow(improper_ctypes)]
513unsafe extern "C" {
514 #[link_name = "llvm.x86.ldtilecfg"]
515 fn ldtilecfg(mem_addr: *const u8);
516 #[link_name = "llvm.x86.sttilecfg"]
517 fn sttilecfg(mem_addr: *mut u8);
518 #[link_name = "llvm.x86.tileloadd64"]
519 fn tileloadd64(dst: i8, base: *const u8, stride: usize);
520 #[link_name = "llvm.x86.tileloaddt164"]
521 fn tileloaddt164(dst: i8, base: *const u8, stride: usize);
522 #[link_name = "llvm.x86.tilerelease"]
523 fn tilerelease();
524 #[link_name = "llvm.x86.tilestored64"]
525 fn tilestored64(dst: i8, base: *mut u8, stride: usize);
526 #[link_name = "llvm.x86.tilezero"]
527 fn tilezero(dst: i8);
528 #[link_name = "llvm.x86.tdpbf16ps"]
529 fn tdpbf16ps(dst: i8, a: i8, b: i8);
530 #[link_name = "llvm.x86.tdpbuud"]
531 fn tdpbuud(dst: i8, a: i8, b: i8);
532 #[link_name = "llvm.x86.tdpbusd"]
533 fn tdpbusd(dst: i8, a: i8, b: i8);
534 #[link_name = "llvm.x86.tdpbsud"]
535 fn tdpbsud(dst: i8, a: i8, b: i8);
536 #[link_name = "llvm.x86.tdpbssd"]
537 fn tdpbssd(dst: i8, a: i8, b: i8);
538 #[link_name = "llvm.x86.tdpfp16ps"]
539 fn tdpfp16ps(dst: i8, a: i8, b: i8);
540 #[link_name = "llvm.x86.tcmmimfp16ps"]
541 fn tcmmimfp16ps(dst: i8, a: i8, b: i8);
542 #[link_name = "llvm.x86.tcmmrlfp16ps"]
543 fn tcmmrlfp16ps(dst: i8, a: i8, b: i8);
544 #[link_name = "llvm.x86.tdpbf8ps"]
545 fn tdpbf8ps(dst: i8, a: i8, b: i8);
546 #[link_name = "llvm.x86.tdpbhf8ps"]
547 fn tdpbhf8ps(dst: i8, a: i8, b: i8);
548 #[link_name = "llvm.x86.tdphbf8ps"]
549 fn tdphbf8ps(dst: i8, a: i8, b: i8);
550 #[link_name = "llvm.x86.tdphf8ps"]
551 fn tdphf8ps(dst: i8, a: i8, b: i8);
552 #[link_name = "llvm.x86.tileloaddrs64"]
553 fn tileloaddrs64(dst: i8, base: *const u8, stride: usize);
554 #[link_name = "llvm.x86.tileloaddrst164"]
555 fn tileloaddrst164(dst: i8, base: *const u8, stride: usize);
556 #[link_name = "llvm.x86.tmmultf32ps"]
557 fn tmmultf32ps(dst: i8, a: i8, b: i8);
558 #[link_name = "llvm.x86.tcvtrowd2ps"]
559 fn tcvtrowd2ps(tile: i8, row: u32) -> f32x16;
560 #[link_name = "llvm.x86.tcvtrowd2psi"]
561 fn tcvtrowd2psi(tile: i8, row: u32) -> f32x16;
562 #[link_name = "llvm.x86.tcvtrowps2phh"]
563 fn tcvtrowps2phh(tile: i8, row: u32) -> f16x32;
564 #[link_name = "llvm.x86.tcvtrowps2phhi"]
565 fn tcvtrowps2phhi(tile: i8, row: u32) -> f16x32;
566 #[link_name = "llvm.x86.tcvtrowps2phl"]
567 fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
568 #[link_name = "llvm.x86.tcvtrowps2phli"]
569 fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
570 #[link_name = "llvm.x86.tilemovrow"]
571 fn tilemovrow(tile: i8, row: u32) -> i32x16;
572 #[link_name = "llvm.x86.tilemovrowi"]
573 fn tilemovrowi(tile: i8, row: u32) -> i32x16;
574}
575
576#[cfg(test)]
577mod tests {
578 use crate::core_arch::x86::_mm_cvtness_sbh;
579 use crate::core_arch::x86_64::*;
580 use core::{array, mem::transmute};
581 use stdarch_test::simd_test;
582 #[cfg(target_os = "linux")]
583 use syscalls::{Sysno, syscall};
584
585 #[allow(non_camel_case_types)]
586 #[repr(C, packed)]
587 #[derive(Copy, Clone, Default, Debug, PartialEq)]
588 struct __tilecfg {
589 palette: u8,
591 start_row: u8,
592 reserved_a0: [u8; 14],
594 colsb: [u16; 8],
596 reserved_b0: [u16; 8],
598 rows: [u8; 8],
600 reserved_c0: [u8; 8],
602 }
603
604 impl __tilecfg {
605 fn new(palette: u8, start_row: u8, colsb: [u16; 8], rows: [u8; 8]) -> Self {
606 Self {
607 palette,
608 start_row,
609 reserved_a0: [0u8; 14],
610 colsb,
611 reserved_b0: [0u16; 8],
612 rows,
613 reserved_c0: [0u8; 8],
614 }
615 }
616
617 const fn as_ptr(&self) -> *const u8 {
618 self as *const Self as *const u8
619 }
620
621 fn as_mut_ptr(&mut self) -> *mut u8 {
622 self as *mut Self as *mut u8
623 }
624 }
625
626 #[cfg(not(target_os = "linux"))]
627 #[target_feature(enable = "amx-tile")]
628 fn _init_amx() {}
629
630 #[cfg(target_os = "linux")]
631 #[target_feature(enable = "amx-tile")]
632 #[inline]
633 unsafe fn _init_amx() {
634 let mut ret: usize;
635 let mut xfeatures: usize = 0;
636 ret = syscall!(Sysno::arch_prctl, 0x1022, &mut xfeatures as *mut usize)
637 .expect("arch_prctl ARCH_GET_XCOMP_PERM syscall failed");
638 if ret != 0 {
639 panic!("Failed to get XFEATURES");
640 } else {
641 match 0b11 & (xfeatures >> 17) {
642 0 => panic!("AMX is not available"),
643 1 => {
644 ret = syscall!(Sysno::arch_prctl, 0x1023, 18)
645 .expect("arch_prctl ARCH_REQ_XCOMP_PERM syscall failed");
646 if ret != 0 {
647 panic!("Failed to enable AMX");
648 }
649 }
650 3 => {}
651 _ => unreachable!(),
652 }
653 }
654 }
655
656 #[simd_test(enable = "amx-tile")]
657 fn test_tile_loadconfig() {
658 unsafe {
659 let config = __tilecfg::default();
660 _tile_loadconfig(config.as_ptr());
661 _tile_release();
662 }
663 }
664
665 #[simd_test(enable = "amx-tile")]
666 fn test_tile_storeconfig() {
667 unsafe {
668 let config = __tilecfg::new(1, 0, [32; 8], [8; 8]);
669 _tile_loadconfig(config.as_ptr());
670 let mut _config = __tilecfg::default();
671 _tile_storeconfig(_config.as_mut_ptr());
672 _tile_release();
673 assert_eq!(config, _config);
674 }
675 }
676
677 #[simd_test(enable = "amx-tile")]
678 fn test_tile_zero() {
679 unsafe {
680 _init_amx();
681 let mut config = __tilecfg::default();
682 config.palette = 1;
683 config.colsb[0] = 64;
684 config.rows[0] = 16;
685 _tile_loadconfig(config.as_ptr());
686 _tile_zero::<0>();
687 let mut out = [[1_i8; 64]; 16];
688 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
689 _tile_release();
690 assert_eq!(out, [[0; 64]; 16]);
691 }
692 }
693
694 #[simd_test(enable = "amx-tile")]
695 fn test_tile_stored() {
696 unsafe {
697 _init_amx();
698 let mut config = __tilecfg::default();
699 config.palette = 1;
700 config.colsb[0] = 64;
701 config.rows[0] = 16;
702 _tile_loadconfig(config.as_ptr());
703 _tile_zero::<0>();
704 let mut out = [[1_i8; 64]; 16];
705 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
706 _tile_release();
707 assert_eq!(out, [[0; 64]; 16]);
708 }
709 }
710
711 #[simd_test(enable = "amx-tile")]
712 fn test_tile_loadd() {
713 unsafe {
714 _init_amx();
715 let mut config = __tilecfg::default();
716 config.palette = 1;
717 config.colsb[0] = 64;
718 config.rows[0] = 16;
719 _tile_loadconfig(config.as_ptr());
720 _tile_zero::<0>();
721 let mat = [1_i8; 1024];
722 _tile_loadd::<0>(&mat as *const i8 as *const u8, 64);
723 let mut out = [[0_i8; 64]; 16];
724 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
725 _tile_release();
726 assert_eq!(out, [[1; 64]; 16]);
727 }
728 }
729
730 #[simd_test(enable = "amx-tile")]
731 fn test_tile_stream_loadd() {
732 unsafe {
733 _init_amx();
734 let mut config = __tilecfg::default();
735 config.palette = 1;
736 config.colsb[0] = 64;
737 config.rows[0] = 16;
738 _tile_loadconfig(config.as_ptr());
739 _tile_zero::<0>();
740 let mat = [1_i8; 1024];
741 _tile_stream_loadd::<0>(&mat as *const i8 as *const u8, 64);
742 let mut out = [[0_i8; 64]; 16];
743 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
744 _tile_release();
745 assert_eq!(out, [[1; 64]; 16]);
746 }
747 }
748
749 #[simd_test(enable = "amx-tile")]
750 fn test_tile_release() {
751 unsafe {
752 _tile_release();
753 }
754 }
755
756 #[simd_test(enable = "amx-bf16,avx512f")]
757 fn test_tile_dpbf16ps() {
758 unsafe {
759 _init_amx();
760 let bf16_1: u16 = _mm_cvtness_sbh(1.0).to_bits();
761 let bf16_2: u16 = _mm_cvtness_sbh(2.0).to_bits();
762 let ones: [u8; 1024] = transmute([bf16_1; 512]);
763 let twos: [u8; 1024] = transmute([bf16_2; 512]);
764 let mut res = [[0f32; 16]; 16];
765 let mut config = __tilecfg::default();
766 config.palette = 1;
767 (0..=2).for_each(|i| {
768 config.colsb[i] = 64;
769 config.rows[i] = 16;
770 });
771 _tile_loadconfig(config.as_ptr());
772 _tile_zero::<0>();
773 _tile_loadd::<1>(&ones as *const u8, 64);
774 _tile_loadd::<2>(&twos as *const u8, 64);
775 _tile_dpbf16ps::<0, 1, 2>();
776 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
777 _tile_release();
778 assert_eq!(res, [[64f32; 16]; 16]);
779 }
780 }
781
782 #[simd_test(enable = "amx-int8")]
783 fn test_tile_dpbssd() {
784 unsafe {
785 _init_amx();
786 let ones = [-1_i8; 1024];
787 let twos = [-2_i8; 1024];
788 let mut res = [[0_i32; 16]; 16];
789 let mut config = __tilecfg::default();
790 config.palette = 1;
791 (0..=2).for_each(|i| {
792 config.colsb[i] = 64;
793 config.rows[i] = 16;
794 });
795 _tile_loadconfig(config.as_ptr());
796 _tile_zero::<0>();
797 _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
798 _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
799 _tile_dpbssd::<0, 1, 2>();
800 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
801 _tile_release();
802 assert_eq!(res, [[128_i32; 16]; 16]);
803 }
804 }
805
806 #[simd_test(enable = "amx-int8")]
807 fn test_tile_dpbsud() {
808 unsafe {
809 _init_amx();
810 let ones = [-1_i8; 1024];
811 let twos = [2_u8; 1024];
812 let mut res = [[0_i32; 16]; 16];
813 let mut config = __tilecfg::default();
814 config.palette = 1;
815 (0..=2).for_each(|i| {
816 config.colsb[i] = 64;
817 config.rows[i] = 16;
818 });
819 _tile_loadconfig(config.as_ptr());
820 _tile_zero::<0>();
821 _tile_loadd::<1>(&ones as *const i8 as *const u8, 64);
822 _tile_loadd::<2>(&twos as *const u8, 64);
823 _tile_dpbsud::<0, 1, 2>();
824 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
825 _tile_release();
826 assert_eq!(res, [[-128_i32; 16]; 16]);
827 }
828 }
829
830 #[simd_test(enable = "amx-int8")]
831 fn test_tile_dpbusd() {
832 unsafe {
833 _init_amx();
834 let ones = [1_u8; 1024];
835 let twos = [-2_i8; 1024];
836 let mut res = [[0_i32; 16]; 16];
837 let mut config = __tilecfg::default();
838 config.palette = 1;
839 (0..=2).for_each(|i| {
840 config.colsb[i] = 64;
841 config.rows[i] = 16;
842 });
843 _tile_loadconfig(config.as_ptr());
844 _tile_zero::<0>();
845 _tile_loadd::<1>(&ones as *const u8, 64);
846 _tile_loadd::<2>(&twos as *const i8 as *const u8, 64);
847 _tile_dpbusd::<0, 1, 2>();
848 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
849 _tile_release();
850 assert_eq!(res, [[-128_i32; 16]; 16]);
851 }
852 }
853
854 #[simd_test(enable = "amx-int8")]
855 fn test_tile_dpbuud() {
856 unsafe {
857 _init_amx();
858 let ones = [1_u8; 1024];
859 let twos = [2_u8; 1024];
860 let mut res = [[0_i32; 16]; 16];
861 let mut config = __tilecfg::default();
862 config.palette = 1;
863 (0..=2).for_each(|i| {
864 config.colsb[i] = 64;
865 config.rows[i] = 16;
866 });
867 _tile_loadconfig(config.as_ptr());
868 _tile_zero::<0>();
869 _tile_loadd::<1>(&ones as *const u8, 64);
870 _tile_loadd::<2>(&twos as *const u8, 64);
871 _tile_dpbuud::<0, 1, 2>();
872 _tile_stored::<0>(&mut res as *mut [i32; 16] as *mut u8, 64);
873 _tile_release();
874 assert_eq!(res, [[128_i32; 16]; 16]);
875 }
876 }
877
878 #[simd_test(enable = "amx-fp16")]
879 fn test_tile_dpfp16ps() {
880 unsafe {
881 _init_amx();
882 let ones = [1f16; 512];
883 let twos = [2f16; 512];
884 let mut res = [[0f32; 16]; 16];
885 let mut config = __tilecfg::default();
886 config.palette = 1;
887 (0..=2).for_each(|i| {
888 config.colsb[i] = 64;
889 config.rows[i] = 16;
890 });
891 _tile_loadconfig(config.as_ptr());
892 _tile_zero::<0>();
893 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
894 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
895 _tile_dpfp16ps::<0, 1, 2>();
896 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
897 _tile_release();
898 assert_eq!(res, [[64f32; 16]; 16]);
899 }
900 }
901
902 #[simd_test(enable = "amx-complex")]
903 fn test_tile_cmmimfp16ps() {
904 unsafe {
905 _init_amx();
906 let ones = [1f16; 512];
907 let twos = [2f16; 512];
908 let mut res = [[0f32; 16]; 16];
909 let mut config = __tilecfg::default();
910 config.palette = 1;
911 (0..=2).for_each(|i| {
912 config.colsb[i] = 64;
913 config.rows[i] = 16;
914 });
915 _tile_loadconfig(config.as_ptr());
916 _tile_zero::<0>();
917 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
918 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
919 _tile_cmmimfp16ps::<0, 1, 2>();
920 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
921 _tile_release();
922 assert_eq!(res, [[64f32; 16]; 16]);
923 }
924 }
925
926 #[simd_test(enable = "amx-complex")]
927 fn test_tile_cmmrlfp16ps() {
928 unsafe {
929 _init_amx();
930 let ones = [1f16; 512];
931 let twos = [2f16; 512];
932 let mut res = [[0f32; 16]; 16];
933 let mut config = __tilecfg::default();
934 config.palette = 1;
935 (0..=2).for_each(|i| {
936 config.colsb[i] = 64;
937 config.rows[i] = 16;
938 });
939 _tile_loadconfig(config.as_ptr());
940 _tile_zero::<0>();
941 _tile_loadd::<1>(&ones as *const f16 as *const u8, 64);
942 _tile_loadd::<2>(&twos as *const f16 as *const u8, 64);
943 _tile_cmmrlfp16ps::<0, 1, 2>();
944 _tile_stored::<0>(&mut res as *mut [f32; 16] as *mut u8, 64);
945 _tile_release();
946 assert_eq!(res, [[0f32; 16]; 16]);
947 }
948 }
949
950 const BF8_ONE: u8 = 0x3c;
951 const BF8_TWO: u8 = 0x40;
952 const HF8_ONE: u8 = 0x38;
953 const HF8_TWO: u8 = 0x40;
954
955 #[simd_test(enable = "amx-fp8")]
956 fn test_tile_dpbf8ps() {
957 unsafe {
958 _init_amx();
959 let ones = [BF8_ONE; 1024];
960 let twos = [BF8_TWO; 1024];
961 let mut res = [[0.0_f32; 16]; 16];
962 let mut config = __tilecfg::default();
963 config.palette = 1;
964 (0..=2).for_each(|i| {
965 config.colsb[i] = 64;
966 config.rows[i] = 16;
967 });
968 _tile_loadconfig(config.as_ptr());
969 _tile_zero::<0>();
970 _tile_loadd::<1>(&ones as *const u8, 64);
971 _tile_loadd::<2>(&twos as *const u8, 64);
972 _tile_dpbf8ps::<0, 1, 2>();
973 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
974 _tile_release();
975 assert_eq!(res, [[128.0_f32; 16]; 16]);
976 }
977 }
978
979 #[simd_test(enable = "amx-fp8")]
980 fn test_tile_dpbhf8ps() {
981 unsafe {
982 _init_amx();
983 let ones = [BF8_ONE; 1024];
984 let twos = [HF8_TWO; 1024];
985 let mut res = [[0.0_f32; 16]; 16];
986 let mut config = __tilecfg::default();
987 config.palette = 1;
988 (0..=2).for_each(|i| {
989 config.colsb[i] = 64;
990 config.rows[i] = 16;
991 });
992 _tile_loadconfig(config.as_ptr());
993 _tile_zero::<0>();
994 _tile_loadd::<1>(&ones as *const u8, 64);
995 _tile_loadd::<2>(&twos as *const u8, 64);
996 _tile_dpbhf8ps::<0, 1, 2>();
997 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
998 _tile_release();
999 assert_eq!(res, [[128.0_f32; 16]; 16]);
1000 }
1001 }
1002
1003 #[simd_test(enable = "amx-fp8")]
1004 fn test_tile_dphbf8ps() {
1005 unsafe {
1006 _init_amx();
1007 let ones = [HF8_ONE; 1024];
1008 let twos = [BF8_TWO; 1024];
1009 let mut res = [[0.0_f32; 16]; 16];
1010 let mut config = __tilecfg::default();
1011 config.palette = 1;
1012 (0..=2).for_each(|i| {
1013 config.colsb[i] = 64;
1014 config.rows[i] = 16;
1015 });
1016 _tile_loadconfig(config.as_ptr());
1017 _tile_zero::<0>();
1018 _tile_loadd::<1>(&ones as *const u8, 64);
1019 _tile_loadd::<2>(&twos as *const u8, 64);
1020 _tile_dphbf8ps::<0, 1, 2>();
1021 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1022 _tile_release();
1023 assert_eq!(res, [[128.0_f32; 16]; 16]);
1024 }
1025 }
1026
1027 #[simd_test(enable = "amx-fp8")]
1028 fn test_tile_dphf8ps() {
1029 unsafe {
1030 _init_amx();
1031 let ones = [HF8_ONE; 1024];
1032 let twos = [HF8_TWO; 1024];
1033 let mut res = [[0.0_f32; 16]; 16];
1034 let mut config = __tilecfg::default();
1035 config.palette = 1;
1036 (0..=2).for_each(|i| {
1037 config.colsb[i] = 64;
1038 config.rows[i] = 16;
1039 });
1040 _tile_loadconfig(config.as_ptr());
1041 _tile_zero::<0>();
1042 _tile_loadd::<1>(&ones as *const u8, 64);
1043 _tile_loadd::<2>(&twos as *const u8, 64);
1044 _tile_dphf8ps::<0, 1, 2>();
1045 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1046 _tile_release();
1047 assert_eq!(res, [[128.0_f32; 16]; 16]);
1048 }
1049 }
1050
1051 #[simd_test(enable = "amx-movrs")]
1052 fn test_tile_loaddrs() {
1053 unsafe {
1054 _init_amx();
1055 let mut config = __tilecfg::default();
1056 config.palette = 1;
1057 config.colsb[0] = 64;
1058 config.rows[0] = 16;
1059 _tile_loadconfig(config.as_ptr());
1060 _tile_zero::<0>();
1061 let mat = [1_i8; 1024];
1062 _tile_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
1063 let mut out = [[0_i8; 64]; 16];
1064 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
1065 _tile_release();
1066 assert_eq!(out, [[1; 64]; 16]);
1067 }
1068 }
1069
1070 #[simd_test(enable = "amx-movrs")]
1071 fn test_tile_stream_loaddrs() {
1072 unsafe {
1073 _init_amx();
1074 let mut config = __tilecfg::default();
1075 config.palette = 1;
1076 config.colsb[0] = 64;
1077 config.rows[0] = 16;
1078 _tile_loadconfig(config.as_ptr());
1079 _tile_zero::<0>();
1080 let mat = [1_i8; 1024];
1081 _tile_stream_loaddrs::<0>(&mat as *const i8 as *const u8, 64);
1082 let mut out = [[0_i8; 64]; 16];
1083 _tile_stored::<0>(&mut out as *mut [i8; 64] as *mut u8, 64);
1084 _tile_release();
1085 assert_eq!(out, [[1; 64]; 16]);
1086 }
1087 }
1088
1089 #[simd_test(enable = "amx-avx512,avx10.2")]
1090 fn test_tile_movrow() {
1091 unsafe {
1092 _init_amx();
1093 let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
1094
1095 let mut config = __tilecfg::default();
1096 config.palette = 1;
1097 config.colsb[0] = 64;
1098 config.rows[0] = 16;
1099 _tile_loadconfig(config.as_ptr());
1100 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1101 for i in 0..16 {
1102 let row = _tile_movrow::<0>(i);
1103 assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
1104 }
1105 }
1106 }
1107
1108 macro_rules! wrap_imm4 {
1109 ($name:ident :: <$TILE:literal>, $row:expr) => {
1110 match $row {
1111 0 => $name::<$TILE, 0>(),
1112 1 => $name::<$TILE, 1>(),
1113 2 => $name::<$TILE, 2>(),
1114 3 => $name::<$TILE, 3>(),
1115 4 => $name::<$TILE, 4>(),
1116 5 => $name::<$TILE, 5>(),
1117 6 => $name::<$TILE, 6>(),
1118 7 => $name::<$TILE, 7>(),
1119 8 => $name::<$TILE, 8>(),
1120 9 => $name::<$TILE, 9>(),
1121 10 => $name::<$TILE, 10>(),
1122 11 => $name::<$TILE, 11>(),
1123 12 => $name::<$TILE, 12>(),
1124 13 => $name::<$TILE, 13>(),
1125 14 => $name::<$TILE, 14>(),
1126 15 => $name::<$TILE, 15>(),
1127 _ => panic!("row index out of range"),
1128 }
1129 };
1130 }
1131
1132 #[simd_test(enable = "amx-avx512,avx10.2")]
1133 fn test_tile_movrowi() {
1134 unsafe {
1135 _init_amx();
1136 let array: [[u8; 64]; 16] = array::from_fn(|i| [i as _; _]);
1137
1138 let mut config = __tilecfg::default();
1139 config.palette = 1;
1140 config.colsb[0] = 64;
1141 config.rows[0] = 16;
1142 _tile_loadconfig(config.as_ptr());
1143 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1144
1145 for i in 0..16 {
1146 let row = wrap_imm4!(_tile_movrowi::<0>, i);
1147 assert_eq!(*row.as_u8x64().as_array(), [i as _; _]);
1148 }
1149 }
1150 }
1151
1152 #[simd_test(enable = "amx-avx512,avx10.2")]
1153 fn test_tile_cvtrowd2ps() {
1154 unsafe {
1155 _init_amx();
1156 let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1157
1158 let mut config = __tilecfg::default();
1159 config.palette = 1;
1160 config.colsb[0] = 64;
1161 config.rows[0] = 16;
1162 _tile_loadconfig(config.as_ptr());
1163 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1164 for i in 0..16 {
1165 let row = _tile_cvtrowd2ps::<0>(i);
1166 assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1167 }
1168 }
1169 }
1170
1171 #[simd_test(enable = "amx-avx512,avx10.2")]
1172 fn test_tile_cvtrowd2psi() {
1173 unsafe {
1174 _init_amx();
1175 let array: [[u32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1176
1177 let mut config = __tilecfg::default();
1178 config.palette = 1;
1179 config.colsb[0] = 64;
1180 config.rows[0] = 16;
1181 _tile_loadconfig(config.as_ptr());
1182 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1183
1184 for i in 0..16 {
1185 let row = wrap_imm4!(_tile_cvtrowd2psi::<0>, i);
1186 assert_eq!(*row.as_f32x16().as_array(), [i as _; _]);
1187 }
1188 }
1189 }
1190
1191 #[simd_test(enable = "amx-avx512,avx10.2")]
1192 fn test_tile_cvtrowps2phh() {
1193 unsafe {
1194 _init_amx();
1195 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1196
1197 let mut config = __tilecfg::default();
1198 config.palette = 1;
1199 config.colsb[0] = 64;
1200 config.rows[0] = 16;
1201 _tile_loadconfig(config.as_ptr());
1202 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1203 for i in 0..16 {
1204 let row = _tile_cvtrowps2phh::<0>(i);
1205 assert_eq!(
1206 *row.as_f16x32().as_array(),
1207 array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1208 );
1209 }
1210 }
1211 }
1212
1213 #[simd_test(enable = "amx-avx512,avx10.2")]
1214 fn test_tile_cvtrowps2phhi() {
1215 unsafe {
1216 _init_amx();
1217 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1218
1219 let mut config = __tilecfg::default();
1220 config.palette = 1;
1221 config.colsb[0] = 64;
1222 config.rows[0] = 16;
1223 _tile_loadconfig(config.as_ptr());
1224 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1225 for i in 0..16 {
1226 let row = wrap_imm4!(_tile_cvtrowps2phhi::<0>, i);
1227 assert_eq!(
1228 *row.as_f16x32().as_array(),
1229 array::from_fn(|j| if j & 1 == 0 { 0.0 } else { i as _ })
1230 );
1231 }
1232 }
1233 }
1234
1235 #[simd_test(enable = "amx-avx512,avx10.2")]
1236 fn test_tile_cvtrowps2phl() {
1237 unsafe {
1238 _init_amx();
1239 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1240
1241 let mut config = __tilecfg::default();
1242 config.palette = 1;
1243 config.colsb[0] = 64;
1244 config.rows[0] = 16;
1245 _tile_loadconfig(config.as_ptr());
1246 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1247 for i in 0..16 {
1248 let row = _tile_cvtrowps2phl::<0>(i);
1249 assert_eq!(
1250 *row.as_f16x32().as_array(),
1251 array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1252 );
1253 }
1254 }
1255 }
1256
1257 #[simd_test(enable = "amx-avx512,avx10.2")]
1258 fn test_tile_cvtrowps2phli() {
1259 unsafe {
1260 _init_amx();
1261 let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1262
1263 let mut config = __tilecfg::default();
1264 config.palette = 1;
1265 config.colsb[0] = 64;
1266 config.rows[0] = 16;
1267 _tile_loadconfig(config.as_ptr());
1268 _tile_loadd::<0>(array.as_ptr().cast(), 64);
1269 for i in 0..16 {
1270 let row = wrap_imm4!(_tile_cvtrowps2phli::<0>, i);
1271 assert_eq!(
1272 *row.as_f16x32().as_array(),
1273 array::from_fn(|j| if j & 1 == 0 { i as _ } else { 0.0 })
1274 );
1275 }
1276 }
1277 }
1278
1279 #[simd_test(enable = "amx-tf32")]
1280 fn test_tile_mmultf32ps() {
1281 unsafe {
1282 _init_amx();
1283 let a: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1284 let b: [[f32; 16]; 16] = [array::from_fn(|j| j as _); _];
1285 let mut res = [[0.0; 16]; 16];
1286
1287 let mut config = __tilecfg::default();
1288 config.palette = 1;
1289 (0..=2).for_each(|i| {
1290 config.colsb[i] = 64;
1291 config.rows[i] = 16;
1292 });
1293 _tile_loadconfig(config.as_ptr());
1294 _tile_zero::<0>();
1295 _tile_loadd::<1>(a.as_ptr().cast(), 64);
1296 _tile_loadd::<2>(b.as_ptr().cast(), 64);
1297 _tile_mmultf32ps::<0, 1, 2>();
1298 _tile_stored::<0>(res.as_mut_ptr().cast(), 64);
1299 _tile_release();
1300
1301 let expected = array::from_fn(|i| array::from_fn(|j| 16.0 * i as f32 * j as f32));
1302 assert_eq!(res, expected);
1303 }
1304 }
1305}