Skip to content

Commit 0414f7c

Browse files
authored
Merge pull request #2089 from sayantn/amx-more
Add AMX-AVX512 BF16 intrinsics
2 parents d9ee32c + 7eca5b6 commit 0414f7c

1 file changed

Lines changed: 193 additions & 15 deletions

File tree

  • crates/core_arch/src/x86_64

crates/core_arch/src/x86_64/amx.rs

Lines changed: 193 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -252,7 +252,7 @@ pub unsafe fn _tile_cmmrlfp16ps<const DST: i32, const A: i32, const B: i32>() {
252252
#[rustc_legacy_const_generics(0, 1, 2)]
253253
#[target_feature(enable = "amx-fp8")]
254254
#[cfg_attr(
255-
all(test, any(target_os = "linux", target_env = "msvc")),
255+
all(test, not(target_vendor = "apple")),
256256
assert_instr(tdpbf8ps, DST = 0, A = 1, B = 2)
257257
)]
258258
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -271,7 +271,7 @@ pub unsafe fn _tile_dpbf8ps<const DST: i32, const A: i32, const B: i32>() {
271271
#[rustc_legacy_const_generics(0, 1, 2)]
272272
#[target_feature(enable = "amx-fp8")]
273273
#[cfg_attr(
274-
all(test, any(target_os = "linux", target_env = "msvc")),
274+
all(test, not(target_vendor = "apple")),
275275
assert_instr(tdpbhf8ps, DST = 0, A = 1, B = 2)
276276
)]
277277
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -290,7 +290,7 @@ pub unsafe fn _tile_dpbhf8ps<const DST: i32, const A: i32, const B: i32>() {
290290
#[rustc_legacy_const_generics(0, 1, 2)]
291291
#[target_feature(enable = "amx-fp8")]
292292
#[cfg_attr(
293-
all(test, any(target_os = "linux", target_env = "msvc")),
293+
all(test, not(target_vendor = "apple")),
294294
assert_instr(tdphbf8ps, DST = 0, A = 1, B = 2)
295295
)]
296296
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -309,7 +309,7 @@ pub unsafe fn _tile_dphbf8ps<const DST: i32, const A: i32, const B: i32>() {
309309
#[rustc_legacy_const_generics(0, 1, 2)]
310310
#[target_feature(enable = "amx-fp8")]
311311
#[cfg_attr(
312-
all(test, any(target_os = "linux", target_env = "msvc")),
312+
all(test, not(target_vendor = "apple")),
313313
assert_instr(tdphf8ps, DST = 0, A = 1, B = 2)
314314
)]
315315
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -329,7 +329,7 @@ pub unsafe fn _tile_dphf8ps<const DST: i32, const A: i32, const B: i32>() {
329329
#[rustc_legacy_const_generics(0)]
330330
#[target_feature(enable = "amx-movrs")]
331331
#[cfg_attr(
332-
all(test, any(target_os = "linux", target_env = "msvc")),
332+
all(test, not(target_vendor = "apple")),
333333
assert_instr(tileloaddrs, DST = 0)
334334
)]
335335
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -349,7 +349,7 @@ pub unsafe fn _tile_loaddrs<const DST: i32>(base: *const u8, stride: usize) {
349349
#[rustc_legacy_const_generics(0)]
350350
#[target_feature(enable = "amx-movrs")]
351351
#[cfg_attr(
352-
all(test, any(target_os = "linux", target_env = "msvc")),
352+
all(test, not(target_vendor = "apple")),
353353
assert_instr(tileloaddrst1, DST = 0)
354354
)]
355355
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -372,7 +372,7 @@ pub unsafe fn _tile_stream_loaddrs<const DST: i32>(base: *const u8, stride: usiz
372372
#[rustc_legacy_const_generics(0, 1, 2)]
373373
#[target_feature(enable = "amx-tf32")]
374374
#[cfg_attr(
375-
all(test, any(target_os = "linux", target_env = "msvc")),
375+
all(test, not(target_vendor = "apple")),
376376
assert_instr(tmmultf32ps, DST = 0, A = 1, B = 2)
377377
)]
378378
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -389,7 +389,7 @@ pub unsafe fn _tile_mmultf32ps<const DST: i32, const A: i32, const B: i32>() {
389389
#[rustc_legacy_const_generics(0)]
390390
#[target_feature(enable = "amx-avx512,avx10.2")]
391391
#[cfg_attr(
392-
all(test, any(target_os = "linux", target_env = "msvc")),
392+
all(test, not(target_vendor = "apple")),
393393
assert_instr(tcvtrowd2ps, TILE = 0)
394394
)]
395395
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -404,7 +404,7 @@ pub unsafe fn _tile_cvtrowd2ps<const TILE: i32>(row: u32) -> __m512 {
404404
#[rustc_legacy_const_generics(0, 1)]
405405
#[target_feature(enable = "amx-avx512,avx10.2")]
406406
#[cfg_attr(
407-
all(test, any(target_os = "linux", target_env = "msvc")),
407+
all(test, not(target_vendor = "apple")),
408408
assert_instr(tcvtrowd2ps, TILE = 0, ROW = 0)
409409
)]
410410
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -421,7 +421,7 @@ pub unsafe fn _tile_cvtrowd2psi<const TILE: i32, const ROW: i32>() -> __m512 {
421421
#[rustc_legacy_const_generics(0)]
422422
#[target_feature(enable = "amx-avx512,avx10.2")]
423423
#[cfg_attr(
424-
all(test, any(target_os = "linux", target_env = "msvc")),
424+
all(test, not(target_vendor = "apple")),
425425
assert_instr(tcvtrowps2phh, TILE = 0)
426426
)]
427427
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -437,7 +437,7 @@ pub unsafe fn _tile_cvtrowps2phh<const TILE: i32>(row: u32) -> __m512h {
437437
#[rustc_legacy_const_generics(0, 1)]
438438
#[target_feature(enable = "amx-avx512,avx10.2")]
439439
#[cfg_attr(
440-
all(test, any(target_os = "linux", target_env = "msvc")),
440+
all(test, not(target_vendor = "apple")),
441441
assert_instr(tcvtrowps2phh, TILE = 0, ROW = 0)
442442
)]
443443
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -454,7 +454,7 @@ pub unsafe fn _tile_cvtrowps2phhi<const TILE: i32, const ROW: i32>() -> __m512h
454454
#[rustc_legacy_const_generics(0)]
455455
#[target_feature(enable = "amx-avx512,avx10.2")]
456456
#[cfg_attr(
457-
all(test, any(target_os = "linux", target_env = "msvc")),
457+
all(test, not(target_vendor = "apple")),
458458
assert_instr(tcvtrowps2phl, TILE = 0)
459459
)]
460460
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -470,7 +470,7 @@ pub unsafe fn _tile_cvtrowps2phl<const TILE: i32>(row: u32) -> __m512h {
470470
#[rustc_legacy_const_generics(0, 1)]
471471
#[target_feature(enable = "amx-avx512,avx10.2")]
472472
#[cfg_attr(
473-
all(test, any(target_os = "linux", target_env = "msvc")),
473+
all(test, not(target_vendor = "apple")),
474474
assert_instr(tcvtrowps2phl, TILE = 0, ROW = 0)
475475
)]
476476
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -480,12 +480,78 @@ pub unsafe fn _tile_cvtrowps2phli<const TILE: i32, const ROW: i32>() -> __m512h
480480
tcvtrowps2phli(TILE as i8, ROW as u32).as_m512h()
481481
}
482482

483+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
484+
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
485+
/// 16-bit elements are placed in the high 16 bits within each 32-bit element of the returned vector.
///
/// The source tile register is chosen at compile time by `TILE`; the row to convert is chosen at
/// run time by `row`.
486+
#[inline]
487+
#[rustc_legacy_const_generics(0)]
488+
#[target_feature(enable = "amx-avx512,avx10.2")]
489+
#[cfg_attr(
490+
all(test, not(target_vendor = "apple")),
491+
assert_instr(tcvtrowps2bf16h, TILE = 0)
492+
)]
493+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
494+
pub unsafe fn _tile_cvtrowps2bf16h<const TILE: i32>(row: u32) -> __m512bh {
495+
// Reject, at compile time, any `TILE` that does not fit in 3 unsigned bits.
static_assert_uimm_bits!(TILE, 3);
496+
// Forward to the `llvm.x86.tcvtrowps2bf16h` binding and reinterpret its result as `__m512bh`.
tcvtrowps2bf16h(TILE as i8, row).as_m512bh()
497+
}
498+
499+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
500+
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
501+
/// 16-bit elements are placed in the high 16-bits within each 32-bit element of the returned vector.
502+
#[inline]
503+
#[rustc_legacy_const_generics(0, 1)]
504+
#[target_feature(enable = "amx-avx512,avx10.2")]
505+
#[cfg_attr(
506+
all(test, not(target_vendor = "apple")),
507+
assert_instr(tcvtrowps2bf16h, TILE = 0, ROW = 0)
508+
)]
509+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
510+
pub unsafe fn _tile_cvtrowps2bf16hi<const TILE: i32, const ROW: i32>() -> __m512bh {
511+
static_assert_uimm_bits!(TILE, 3);
512+
static_assert_uimm_bits!(ROW, 6);
513+
tcvtrowps2bf16hi(TILE as i8, ROW as u32).as_m512bh()
514+
}
515+
516+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
517+
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
518+
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
519+
#[inline]
520+
#[rustc_legacy_const_generics(0)]
521+
#[target_feature(enable = "amx-avx512,avx10.2")]
522+
#[cfg_attr(
523+
all(test, not(target_vendor = "apple")),
524+
assert_instr(tcvtrowps2bf16l, TILE = 0)
525+
)]
526+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
527+
pub unsafe fn _tile_cvtrowps2bf16l<const TILE: i32>(row: u32) -> __m512bh {
528+
static_assert_uimm_bits!(TILE, 3);
529+
tcvtrowps2bf16l(TILE as i8, row).as_m512bh()
530+
}
531+
532+
/// Moves a row from a tile register to a zmm register, converting the packed single-precision (32-bit)
533+
/// floating-point elements to packed BF16 (16-bit) floating-point elements. The resulting
534+
/// 16-bit elements are placed in the low 16-bits within each 32-bit element of the returned vector.
535+
#[inline]
536+
#[rustc_legacy_const_generics(0, 1)]
537+
#[target_feature(enable = "amx-avx512,avx10.2")]
538+
#[cfg_attr(
539+
all(test, not(target_vendor = "apple")),
540+
assert_instr(tcvtrowps2bf16l, TILE = 0, ROW = 0)
541+
)]
542+
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
543+
pub unsafe fn _tile_cvtrowps2bf16li<const TILE: i32, const ROW: i32>() -> __m512bh {
544+
static_assert_uimm_bits!(TILE, 3);
545+
static_assert_uimm_bits!(ROW, 6);
546+
tcvtrowps2bf16li(TILE as i8, ROW as u32).as_m512bh()
547+
}
548+
483549
/// Moves one row of tile data into a zmm vector register
484550
#[inline]
485551
#[rustc_legacy_const_generics(0)]
486552
#[target_feature(enable = "amx-avx512,avx10.2")]
487553
#[cfg_attr(
488-
all(test, any(target_os = "linux", target_env = "msvc")),
554+
all(test, not(target_vendor = "apple")),
489555
assert_instr(tilemovrow, TILE = 0)
490556
)]
491557
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -499,7 +565,7 @@ pub unsafe fn _tile_movrow<const TILE: i32>(row: u32) -> __m512i {
499565
#[rustc_legacy_const_generics(0, 1)]
500566
#[target_feature(enable = "amx-avx512,avx10.2")]
501567
#[cfg_attr(
502-
all(test, any(target_os = "linux", target_env = "msvc")),
568+
all(test, not(target_vendor = "apple")),
503569
assert_instr(tilemovrow, TILE = 0, ROW = 0)
504570
)]
505571
#[unstable(feature = "x86_amx_intrinsics", issue = "126622")]
@@ -567,6 +633,14 @@ unsafe extern "C" {
567633
fn tcvtrowps2phl(tile: i8, row: u32) -> f16x32;
568634
#[link_name = "llvm.x86.tcvtrowps2phli"]
569635
fn tcvtrowps2phli(tile: i8, row: u32) -> f16x32;
636+
// LLVM bindings used by the `_tile_cvtrowps2bf16*` wrappers. Each takes a tile index and a row
// and returns 32 16-bit lanes; the bindings with an `i` suffix are the ones called by the
// wrappers whose row is a const generic.
#[link_name = "llvm.x86.tcvtrowps2bf16h"]
637+
fn tcvtrowps2bf16h(tile: i8, row: u32) -> u16x32;
638+
#[link_name = "llvm.x86.tcvtrowps2bf16hi"]
639+
fn tcvtrowps2bf16hi(tile: i8, row: u32) -> u16x32;
640+
#[link_name = "llvm.x86.tcvtrowps2bf16l"]
641+
fn tcvtrowps2bf16l(tile: i8, row: u32) -> u16x32;
642+
#[link_name = "llvm.x86.tcvtrowps2bf16li"]
643+
fn tcvtrowps2bf16li(tile: i8, row: u32) -> u16x32;
570644
#[link_name = "llvm.x86.tilemovrow"]
571645
fn tilemovrow(tile: i8, row: u32) -> i32x16;
572646
#[link_name = "llvm.x86.tilemovrowi"]
@@ -1276,6 +1350,110 @@ mod tests {
12761350
}
12771351
}
12781352

1353+
// Checks `_tile_cvtrowps2bf16h` with a run-time row index: for every row `i` of the loaded tile,
// the odd-indexed 16-bit lanes of the result must equal the bits that `_mm_cvtness_sbh` returns
// for `i`, and the even-indexed lanes must be zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
1354+
fn test_tile_cvtrowps2bf16h() {
1355+
unsafe {
1356+
_init_amx();
1357+
// Row `i` of the source matrix is sixteen copies of `i as f32`.
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1358+
1359+
// Tile 0 is configured as 16 rows of 64 bytes, matching the 16 x 16 `f32` source matrix.
let mut config = __tilecfg::default();
1360+
config.palette = 1;
1361+
config.colsb[0] = 64;
1362+
config.rows[0] = 16;
1363+
_tile_loadconfig(config.as_ptr());
1364+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1365+
for i in 0..16 {
1366+
let row = _tile_cvtrowps2bf16h::<0>(i);
1367+
assert_eq!(
1368+
*row.as_u16x32().as_array(),
1369+
// Expected lanes: zero at even indices, converted value at odd indices.
array::from_fn(|j| if j & 1 == 0 {
1370+
0
1371+
} else {
1372+
_mm_cvtness_sbh(i as _).to_bits()
1373+
})
1374+
);
1375+
}
1376+
}
1377+
}
1378+
1379+
// Checks `_tile_cvtrowps2bf16hi` (row supplied as a const generic): for every row `i` of the
// loaded tile, the odd-indexed 16-bit lanes of the result must equal the bits that
// `_mm_cvtness_sbh` returns for `i`, and the even-indexed lanes must be zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
1380+
fn test_tile_cvtrowps2bf16hi() {
1381+
unsafe {
1382+
_init_amx();
1383+
// Row `i` of the source matrix is sixteen copies of `i as f32`.
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1384+
1385+
// Tile 0 is configured as 16 rows of 64 bytes, matching the 16 x 16 `f32` source matrix.
let mut config = __tilecfg::default();
1386+
config.palette = 1;
1387+
config.colsb[0] = 64;
1388+
config.rows[0] = 16;
1389+
_tile_loadconfig(config.as_ptr());
1390+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1391+
for i in 0..16 {
1392+
// NOTE(review): `wrap_imm4!` is defined outside this view; presumably it maps the
// run-time `i` onto the const `ROW` parameter — confirm against the macro definition.
let row = wrap_imm4!(_tile_cvtrowps2bf16hi::<0>, i);
1393+
assert_eq!(
1394+
*row.as_u16x32().as_array(),
1395+
// Expected lanes: zero at even indices, converted value at odd indices.
array::from_fn(|j| if j & 1 == 0 {
1396+
0
1397+
} else {
1398+
_mm_cvtness_sbh(i as _).to_bits()
1399+
})
1400+
);
1401+
}
1402+
}
1403+
}
1404+
1405+
// Checks `_tile_cvtrowps2bf16l` with a run-time row index: for every row `i` of the loaded tile,
// the even-indexed 16-bit lanes of the result must equal the bits that `_mm_cvtness_sbh` returns
// for `i`, and the odd-indexed lanes must be zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
1406+
fn test_tile_cvtrowps2bf16l() {
1407+
unsafe {
1408+
_init_amx();
1409+
// Row `i` of the source matrix is sixteen copies of `i as f32`.
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1410+
1411+
// Tile 0 is configured as 16 rows of 64 bytes, matching the 16 x 16 `f32` source matrix.
let mut config = __tilecfg::default();
1412+
config.palette = 1;
1413+
config.colsb[0] = 64;
1414+
config.rows[0] = 16;
1415+
_tile_loadconfig(config.as_ptr());
1416+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1417+
for i in 0..16 {
1418+
let row = _tile_cvtrowps2bf16l::<0>(i);
1419+
assert_eq!(
1420+
*row.as_u16x32().as_array(),
1421+
// Expected lanes: converted value at even indices, zero at odd indices.
array::from_fn(|j| if j & 1 == 0 {
1422+
_mm_cvtness_sbh(i as _).to_bits()
1423+
} else {
1424+
0
1425+
})
1426+
);
1427+
}
1428+
}
1429+
}
1430+
1431+
// Checks `_tile_cvtrowps2bf16li` (row supplied as a const generic): for every row `i` of the
// loaded tile, the even-indexed 16-bit lanes of the result must equal the bits that
// `_mm_cvtness_sbh` returns for `i`, and the odd-indexed lanes must be zero.
#[simd_test(enable = "amx-avx512,avx10.2")]
1432+
fn test_tile_cvtrowps2bf16li() {
1433+
unsafe {
1434+
_init_amx();
1435+
// Row `i` of the source matrix is sixteen copies of `i as f32`.
let array: [[f32; 16]; 16] = array::from_fn(|i| [i as _; _]);
1436+
1437+
// Tile 0 is configured as 16 rows of 64 bytes, matching the 16 x 16 `f32` source matrix.
let mut config = __tilecfg::default();
1438+
config.palette = 1;
1439+
config.colsb[0] = 64;
1440+
config.rows[0] = 16;
1441+
_tile_loadconfig(config.as_ptr());
1442+
_tile_loadd::<0>(array.as_ptr().cast(), 64);
1443+
for i in 0..16 {
1444+
// NOTE(review): `wrap_imm4!` is defined outside this view; presumably it maps the
// run-time `i` onto the const `ROW` parameter — confirm against the macro definition.
let row = wrap_imm4!(_tile_cvtrowps2bf16li::<0>, i);
1445+
assert_eq!(
1446+
*row.as_u16x32().as_array(),
1447+
// Expected lanes: converted value at even indices, zero at odd indices.
array::from_fn(|j| if j & 1 == 0 {
1448+
_mm_cvtness_sbh(i as _).to_bits()
1449+
} else {
1450+
0
1451+
})
1452+
);
1453+
}
1454+
}
1455+
}
1456+
12791457
#[simd_test(enable = "amx-tf32")]
12801458
fn test_tile_mmultf32ps() {
12811459
unsafe {

0 commit comments

Comments
 (0)