From b1c973c953f1cdd733cd96211abc1488bde7a22d Mon Sep 17 00:00:00 2001 From: jiejanezhang Date: Wed, 30 Jul 2025 17:17:15 +0800 Subject: [PATCH 1/4] Support mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 Support mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 12 +++- clang/runtime/dpct-rt/include/dpct/math.hpp | 66 +++++++++++++++++++++ clang/test/dpct/asm/mma.cu | 18 ++++++ 3 files changed, 94 insertions(+), 2 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index a1e01b6ed238..1b43adc532f1 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1560,13 +1560,21 @@ class SYCLGen : public SYCLGenBase { InMatrixType[0] = "uint32_t"; // A type is .f16/.bf16x2 InMatrixType[1] = "uint32_t"; // B type is .f16/.bf16x2 - // If A matrix type is f16, then C&D matrix types can only be f32 + // If A matrix type is f16, then C&D matrix types can be f32 if (CType->getKind() == InlineAsmBuiltinType::f32) { NumVecElements[0] = 4; // A NumVecElements[1] = 2; // B NumVecElements[2] = 4; // C NumVecElements[3] = 4; // D - } else + } + // C &D matrix types can be f16. + else if (CType->getKind() == InlineAsmBuiltinType::f16) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 2; // C + NumVecElements[3] = 2; // D + } + else return SYCLGenError(); } else if (AType->getKind() == InlineAsmBuiltinType::s8) { InMatrixType[0] = "uint32_t"; // A type is .s8x4 diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index fbb56f0f7cf5..ca2fb1942be1 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2671,6 +2671,72 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, static_cast(ra[j + 4]) * static_cast(rb[j + 4]); } } + } else if constexpr (std::is_same_v) { + // Init D matrix fragment with C matrix fragment + *const_cast(d[0]) = c[0]; + *const_cast(d[1]) = c[1]; + *const_cast(d[2]) = c[2]; + *const_cast(d[3]) = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment + // of matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a[4], recv_b[4]; + + // Load partial fragment from row0 of matrix A ({a0, a1}) + recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row0 of matrix A ({a2, a3}) + recv_a[1] = dpct::select_from_sub_group(sg, a[2], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a0, a1}) + recv_a[2] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a2, a3}) + recv_a[3] = dpct::select_from_sub_group(sg, a[3], row_load_offset + i); + + // Load partial fragment from col0 of matrix B ({b0, b1}) + recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col0 of matrix B ({b2, b3}) + recv_b[1] = dpct::select_from_sub_group(sg, b[1], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1}) + recv_b[2] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i); + // Load partial fragment from col1 of matrix B ({b2, b3}) + recv_b[3] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix + // fragments and adds it to the corresponding D matrix fragment d0 + // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ + // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, + // a2, a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } * + // col1{ b0, b1, b2, b3 } + // for (int j = 0; j < 4; j++) { + // *d[0] += + // static_cast(ra[j]) * static_cast(rb[j]); + // *d[1] += static_cast(ra[j]) * + // static_cast(rb[j + 4]); + // *d[2] += static_cast(ra[j + 4]) * + // static_cast(rb[j]); + // *d[3] += static_cast(ra[j + 4]) * + // static_cast(rb[j + 4]); + + for (int j = 0; j < 4; j++) { + *const_cast(d[0]) += ra[j] * rb[j]; + *const_cast(d[1]) += ra[j] * rb[j + 4]; + *const_cast(d[2]) += ra[j + 4] * rb[j]; + *const_cast(d[3]) += ra[j + 4] * rb[j + 4]; + } + } } else if constexpr (std::is_integral_v) { // Init D matrix with fragments of C matrix *d[0] = c[0]; diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index 2a38e63f9869..ce9ca25fd555 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -100,6 +100,24 @@ __global__ void mma_kernel_m16n8k8(int *a, int *b, float *fc, float *fd) { "f"(fc[0]), "f"(fc[1]), "f"(fc[2]), "f"(fc[3])); } +__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, int *d) { + // CHECK: { + // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1]}; + // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1], a[2], a[3]); + // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0], b[1]); + // CHECK-NEXT: sycl::vec c_mat_frag_ct1(fc[0], fc[1]); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, sycl::half>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); + // CHECK-NEXT: } + asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + " { %0, %1 }, " + " { %2, %3, %4, %5 }, " + " { %6, %7 }, " + " { %0, %1 };" + : "+r"(c[0]), "+r"(c[1]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(d[0]), "r"(d[1])); +} + __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { // CHECK: { // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; From 5d749a53ddbe0a8854c1748a547b4004dd37e81b Mon Sep 17 00:00:00 2001 From: jiejanezhang Date: Thu, 31 Jul 2025 11:02:27 +0800 Subject: [PATCH 2/4] more enhancement more enhancement --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 5 +++-- clang/runtime/dpct-rt/include/dpct/math.hpp | 20 ++++++++++++-------- clang/test/dpct/asm/mma.cu | 3 ++- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 1b43adc532f1..eca11101559a 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1488,6 +1488,8 @@ class SYCLGen : public SYCLGenBase { // Data types of A, B & C matrices respectively in the PTX arguments std::string InMatrixType[3]; + InMatrixType[2] = CDType; + if (Inst->hasAttr(InstAttr::m8n8k4)) { M = "8"; N = "8"; @@ -1573,6 +1575,7 @@ class SYCLGen : public SYCLGenBase { NumVecElements[1] = 2; // B NumVecElements[2] = 2; // C NumVecElements[3] = 2; // D + InMatrixType[2] = "uint32_t"; // C type is f16*2 } else return SYCLGenError(); @@ -1613,8 +1616,6 @@ class SYCLGen : public SYCLGenBase { } else return SYCLGenError(); - InMatrixType[2] = CDType; - // Check the register sizes for vector elements of A, B, C & D matrices for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); InputOp++) { diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index ca2fb1942be1..2d21add9b2fb 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2673,10 +2673,14 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, } } else if constexpr (std::is_same_v) { // Init D matrix fragment with C matrix fragment - *const_cast(d[0]) = c[0]; - *const_cast(d[1]) = c[1]; - *const_cast(d[2]) = c[2]; - *const_cast(d[3]) = c[3]; + sycl::half *d0 = const_cast(d[0]); + sycl::half *d1 = d0 + 1; + sycl::half *d2 = const_cast(d[1]); + sycl::half *d3 = d2 + 1; + *d0 = c[0]; + *d1 = c[1]; + *d2 = c[2]; + *d3 = c[3]; // Each sub-group is responsible for computing a fragment size of 16*8 // elements of matrix D. @@ -2731,10 +2735,10 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, // static_cast(rb[j + 4]); for (int j = 0; j < 4; j++) { - *const_cast(d[0]) += ra[j] * rb[j]; - *const_cast(d[1]) += ra[j] * rb[j + 4]; - *const_cast(d[2]) += ra[j + 4] * rb[j]; - *const_cast(d[3]) += ra[j + 4] * rb[j + 4]; + *d0 += ra[j] * rb[j]; + *d1 += ra[j] * rb[j + 4]; + *d2 += ra[j + 4] * rb[j]; + *d3 += ra[j + 4] * rb[j + 4]; } } } else if constexpr (std::is_integral_v) { diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index ce9ca25fd555..8d01fe93f177 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -19,7 +19,8 @@ m8n8k16 .s8 .s8 .s32 m16n8k8 .f16/.bf16 .f16/.bf16 .f32 m16n8k16 .f16 .f16 .f32 .bf16 .bf16 .f32 - .s8 .s8 .s32 + .s8 .s8 .s32 + .f16 .f16 .f16 m16n8k32 .s8 .s8 .s32 Except for m8n8k4, all other shapes are supported for row/col layout of A/B matrices respectively. From 3714ae71f6aaaad142fda40619a2d830ea340128 Mon Sep 17 00:00:00 2001 From: jiejanezhang Date: Mon, 4 Aug 2025 10:15:49 +0800 Subject: [PATCH 3/4] Minor fixing Minor fixings. --- clang/runtime/dpct-rt/include/dpct/math.hpp | 2 +- clang/test/dpct/asm/mma.cu | 9 +++++---- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 2d21add9b2fb..2079b395eb1a 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2394,7 +2394,7 @@ template struct MMAType { /// - m8n8k4 (f32.f16.f16.f32) /// - m8n8k16 (s32.s8.s8.s32) /// - m16n8k8 (f32.f16.f16.f32 & f32.bf16.bf16.f32) -/// - m16n8k16 (f32.f16.f16.f32 & f32.bf16.bf16.f32 & s32.s8.s8.s32) +/// - m16n8k16 (f32.f16.f16.f32 & f16.f16.f16.f16 & f32.bf16.bf16.f32 & s32.s8.s8.s32) /// - m16n8k32 (s32.s8.s8.s32) /// Here, m, n & k define the shapes of A, B & C matrices respectively /// (A = [m x k], B = [k x n], C = [m x n]). diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index 8d01fe93f177..54a7701bcebc 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -103,7 +103,7 @@ __global__ void mma_kernel_m16n8k8(int *a, int *b, float *fc, float *fd) { __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, int *d) { // CHECK: { - // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &fc[0], &fc[1]}; + // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &d[0], &d[1]}; // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1], a[2], a[3]); // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0], b[1]); // CHECK-NEXT: sycl::vec c_mat_frag_ct1(fc[0], fc[1]); @@ -113,10 +113,11 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, int *d) { " { %0, %1 }, " " { %2, %3, %4, %5 }, " " { %6, %7 }, " - " { %0, %1 };" - : "+r"(c[0]), "+r"(c[1]) + " { %8, %9 };" + : "+r"(d[0]), "+r"(d[1]) : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), - "r"(d[0]), "r"(d[1])); + "r"(b[0]), "r"(b[1]), + "r"(c[0]), "r"(c[1])); } __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { From 008daa8e255dede54fc755adcd23d8a722e8fb56 Mon Sep 17 00:00:00 2001 From: jiejanezhang Date: Mon, 4 Aug 2025 11:45:58 +0800 Subject: [PATCH 4/4] Fix LIT test and refine the comments Fix LIT test and refine the comments --- clang/runtime/dpct-rt/include/dpct/math.hpp | 22 ++++++--------------- clang/test/dpct/asm/mma.cu | 4 ++-- 2 files changed, 8 insertions(+), 18 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 2079b395eb1a..a89f10ebc76b 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2718,22 +2718,12 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, auto ra = reinterpret_cast(recv_a); auto rb = reinterpret_cast(recv_b); - // Each work item calculates a partial product of A & B matrix - // fragments and adds it to the corresponding D matrix fragment d0 - // += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } d1 += row0{ - // a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } d2 += row1{ a0, a1, - // a2, a3 } * col0{ b0, b1, b2, b3 } d3 += row1{ a0, a1, a2, a3 } * - // col1{ b0, b1, b2, b3 } - // for (int j = 0; j < 4; j++) { - // *d[0] += - // static_cast(ra[j]) * static_cast(rb[j]); - // *d[1] += static_cast(ra[j]) * - // static_cast(rb[j + 4]); - // *d[2] += static_cast(ra[j + 4]) * - // static_cast(rb[j]); - // *d[3] += static_cast(ra[j + 4]) * - // static_cast(rb[j + 4]); - + // Each work item calculates a partial product of A & B matrix fragments + // and adds it to the corresponding D matrix fragment + // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } + // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } for (int j = 0; j < 4; j++) { *d0 += ra[j] * rb[j]; *d1 += ra[j] * rb[j + 4]; diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index 54a7701bcebc..88a7e835e1bf 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -103,10 +103,10 @@ __global__ void mma_kernel_m16n8k8(int *a, int *b, float *fc, float *fd) { __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, int *d) { // CHECK: { - // CHECK-NEXT: volatile void *d_mat_frag_ct1[4] = { &d[0], &d[1]}; + // CHECK-NEXT: volatile void *d_mat_frag_ct1[2] = { &d[0], &d[1] }; // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1], a[2], a[3]); // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0], b[1]); - // CHECK-NEXT: sycl::vec c_mat_frag_ct1(fc[0], fc[1]); + // CHECK-NEXT: sycl::vec c_mat_frag_ct1(c[0], c[1]); // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, sycl::half>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 "