diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 898a96018939..a1e01b6ed238 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1554,10 +1554,11 @@ class SYCLGen : public SYCLGenBase { N = "8"; K = "16"; - // Only f16/s8 types are supported for A and B matrices of m16n8k16 - if (AType->getKind() == InlineAsmBuiltinType::f16) { - InMatrixType[0] = "uint32_t"; // A type is .f16x2 - InMatrixType[1] = "uint32_t"; // B type is .f16x2 + // Only f16/s8/bf16 types are supported for A and B matrices of m16n8k16 + if (AType->getKind() == InlineAsmBuiltinType::f16 || + AType->getKind() == InlineAsmBuiltinType::bf16) { + InMatrixType[0] = "uint32_t"; // A type is .f16x2/.bf16x2 + InMatrixType[1] = "uint32_t"; // B type is .f16x2/.bf16x2 // If A matrix type is f16, then C&D matrix types can only be f32 if (CType->getKind() == InlineAsmBuiltinType::f32) { diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index daed5c6f8ffa..fbb56f0f7cf5 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2394,7 +2394,7 @@ template struct MMAType { /// - m8n8k4 (f32.f16.f16.f32) /// - m8n8k16 (s32.s8.s8.s32) /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32) -/// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32) +/// - m16n8k16 (f32.f16.f16.f32 & f32.bf16.bf16.f32 & s32.s8.s8.s32) /// - m16n8k32 (s32.s8.s8.s32) /// Here, m, n & k define the shapes of A, B & C matrices respectively /// (A = [m x k], B = [k x n], C = [m x n]). 
diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index 3d20291aa509..2a38e63f9869 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -17,7 +17,8 @@ As per PTX ASM 8.1, below is the status of supported configurations m8n8k4 .f16 .f16 .f32 m8n8k16 .s8 .s8 .s32 m16n8k8 .f16/.bf16 .f16/.bf16 .f32 -m16n8k16 .f16 .f16 .f32 +m16n8k16 .f16 .f16 .f32 + .bf16 .bf16 .f32 .s8 .s8 .s32 m16n8k32 .s8 .s8 .s32 @@ -116,6 +117,22 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); + // CHECK: { + // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; + // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1], a[2], a[3]); + // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0], b[1]); + // CHECK-NEXT: sycl::vec c_mat_frag_ct1(fc[0], fc[1], fc[2], fc[3]); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::ext::oneapi::bfloat16, float>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); + // CHECK-NEXT: } + asm("mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %0, %1, %2, %3 };" + : "+f"(fc[0]), "+f"(fc[1]), "+f"(fc[2]), "+f"(fc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1])); + // CHECK: { // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &d[0], &d[1], &d[2], &d[3] }; // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1]);