Skip to content

Commit 0d61ee3

Browse files
committed
Add AMX-AVX512 BF16 intrinsics
1 parent a989a69 commit 0d61ee3

1 file changed

Lines changed: 178 additions & 0 deletions

File tree

  • crates/core_arch/src/x86_64

crates/core_arch/src/x86_64/amx.rs

Lines changed: 178 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -480,6 +480,72 @@ pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h
480480
tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
481481
}
482482

483+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
484+
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
485+
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
486+
#[inline]
487+
#[rustc_legacy_const_generics(0)]
488+
#[target_feature(enable = "amx-avx512,avx10.2")]
489+
#[cfg_attr(
490+
all(test, any(target_os = "linux", target_env = "msvc")),
491+
assert_instr(tcvtrowps2bf16h, TILE = 0)
492+
)]
493+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
494+
pub unsafe fn _tile_cvtrowps2bf16h<const TILE: i32>(row: u32) -> __m512bh {
495+
static_assert_uimm_bits!(TILE, 3);
496+
tcvtrowps2bf16h(TILE as i8, row).as_m512bh()
497+
}
498+
499+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
500+
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
501+
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
502+
#[inline]
503+
#[rustc_legacy_const_generics(0, 1)]
504+
#[target_feature(enable = "amx-avx512,avx10.2")]
505+
#[cfg_attr(
506+
all(test, any(target_os = "linux", target_env = "msvc")),
507+
assert_instr(tcvtrowps2bf16h, TILE = 0, ROW = 0)
508+
)]
509+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
510+
pub unsafe fn _tile_cvtrowps2bf16hi<const TILE: i32, const ROW: i32>() -> __m512bh {
511+
static_assert_uimm_bits!(TILE, 3);
512+
static_assert_uimm_bits!(ROW, 6);
513+
tcvtrowps2bf16hi(TILE as i8, ROW as u32).as_m512bh()
514+
}
515+
516+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
517+
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
518+
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
519+
#[inline]
520+
#[rustc_legacy_const_generics(0)]
521+
#[target_feature(enable = "amx-avx512,avx10.2")]
522+
#[cfg_attr(
523+
all(test, any(target_os = "linux", target_env = "msvc")),
524+
assert_instr(tcvtrowps2bf16l, TILE = 0)
525+
)]
526+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
527+
pub unsafe fn _tile_cvtrowps2bf16l<const TILE: i32>(row: u32) -> __m512bh {
528+
static_assert_uimm_bits!(TILE, 3);
529+
tcvtrowps2bf16l(TILE as i8, row).as_m512bh()
530+
}
531+
532+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
533+
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
534+
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
535+
#[inline]
536+
#[rustc_legacy_const_generics(0, 1)]
537+
#[target_feature(enable = "amx-avx512,avx10.2")]
538+
#[cfg_attr(
539+
all(test, any(target_os = "linux", target_env = "msvc")),
540+
assert_instr(tcvtrowps2bf16l, TILE = 0, ROW = 0)
541+
)]
542+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
543+
pub unsafe fn _tile_cvtrowps2bf16li<const TILE: i32, const ROW: i32>() -> __m512bh {
544+
static_assert_uimm_bits!(TILE, 3);
545+
static_assert_uimm_bits!(ROW, 6);
546+
tcvtrowps2bf16li(TILE as i8, ROW as u32).as_m512bh()
547+
}
548+
483549
/// Moves one row of tile data into a zmm vector register
484550
#[inline]
485551
#[rustc_legacy_const_generics(0)]
@@ -567,6 +633,14 @@ unsafe extern "C" {
567633
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
568634
#[link_name = "llvm.x86.tcvtrowps2phli"]
569635
fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
636+
#[link_name = "llvm.x86.tcvtrowps2bf16h"]
637+
fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32;
638+
#[link_name = "llvm.x86.tcvtrowps2bf16hi"]
639+
fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32;
640+
#[link_name = "llvm.x86.tcvtrowps2bf16l"]
641+
fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32;
642+
#[link_name = "llvm.x86.tcvtrowps2bf16li"]
643+
fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32;
570644
#[link_name = "llvm.x86.tilemovrow"]
571645
fn tilemovrow(tile: i8, row: u32) -> i32x16;
572646
#[link_name = "llvm.x86.tilemovrowi"]
@@ -1276,6 +1350,110 @@ mod tests {
12761350
}
12771351
}
12781352

1353+
#[simd_test(enable = "amx-avx512,avx10.2")]
1354+
fn test_tile_cvtrowps2bf16h() {
1355+
unsafe {
1356+
_init_amx();
1357+
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1358+
1359+
let mut config = __tilecfg::default();
1360+
config.palette = 1;
1361+
config.colsb[0] = 64;
1362+
config.rows[0] = 16;
1363+
_tile_loadconfig(config.as_ptr());
1364+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1365+
for i in 0..16 {
1366+
let row = _tile_cvtrowps2bf16h::<0>(i);
1367+
assert_eq!(
1368+
*row.as_u16x32().as_array(),
1369+
array::from_fn(|j| if j & 1 == 0 {
1370+
0
1371+
} else {
1372+
_mm_cvtness_sbh(i as _).to_bits()
1373+
})
1374+
);
1375+
}
1376+
}
1377+
}
1378+
1379+
#[simd_test(enable = "amx-avx512,avx10.2")]
1380+
fn test_tile_cvtrowps2bf16hi() {
1381+
unsafe {
1382+
_init_amx();
1383+
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1384+
1385+
let mut config = __tilecfg::default();
1386+
config.palette = 1;
1387+
config.colsb[0] = 64;
1388+
config.rows[0] = 16;
1389+
_tile_loadconfig(config.as_ptr());
1390+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1391+
for i in 0..16 {
1392+
let row = wrap_imm4!(_tile_cvtrowps2bf16hi::<0>, i);
1393+
assert_eq!(
1394+
*row.as_u16x32().as_array(),
1395+
array::from_fn(|j| if j & 1 == 0 {
1396+
0
1397+
} else {
1398+
_mm_cvtness_sbh(i as _).to_bits()
1399+
})
1400+
);
1401+
}
1402+
}
1403+
}
1404+
1405+
#[simd_test(enable = "amx-avx512,avx10.2")]
1406+
fn test_tile_cvtrowps2bf16l() {
1407+
unsafe {
1408+
_init_amx();
1409+
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1410+
1411+
let mut config = __tilecfg::default();
1412+
config.palette = 1;
1413+
config.colsb[0] = 64;
1414+
config.rows[0] = 16;
1415+
_tile_loadconfig(config.as_ptr());
1416+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1417+
for i in 0..16 {
1418+
let row = _tile_cvtrowps2bf16l::<0>(i);
1419+
assert_eq!(
1420+
*row.as_u16x32().as_array(),
1421+
array::from_fn(|j| if j & 1 == 0 {
1422+
_mm_cvtness_sbh(i as _).to_bits()
1423+
} else {
1424+
0
1425+
})
1426+
);
1427+
}
1428+
}
1429+
}
1430+
1431+
#[simd_test(enable = "amx-avx512,avx10.2")]
1432+
fn test_tile_cvtrowps2bf16li() {
1433+
unsafe {
1434+
_init_amx();
1435+
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1436+
1437+
let mut config = __tilecfg::default();
1438+
config.palette = 1;
1439+
config.colsb[0] = 64;
1440+
config.rows[0] = 16;
1441+
_tile_loadconfig(config.as_ptr());
1442+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1443+
for i in 0..16 {
1444+
let row = wrap_imm4!(_tile_cvtrowps2bf16li::<0>, i);
1445+
assert_eq!(
1446+
*row.as_u16x32().as_array(),
1447+
array::from_fn(|j| if j & 1 == 0 {
1448+
_mm_cvtness_sbh(i as _).to_bits()
1449+
} else {
1450+
0
1451+
})
1452+
);
1453+
}
1454+
}
1455+
}
1456+
12791457
#[simd_test(enable = "amx-tf32")]
12801458
fn test_tile_mmultf32ps() {
12811459
unsafe {

0 commit comments

Comments
 (0)