Skip to content

Commit 6c61f6d

Browse files
Added new type logic for A & B matrix elements
1 parent 3adbdbf commit 6c61f6d

3 files changed

Lines changed: 91 additions & 51 deletions

File tree

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -1390,7 +1390,7 @@ class SYCLGen : public SYCLGenBase {
13901390
if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col)
13911391
return SYCLGenError();
13921392

1393-
// Only f16 type is supported for A and B matrix data
1393+
// Data types of D, A, B & C matrices respectively in the PTX instruction
13941394
const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
13951395
const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(1));
13961396
const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(2));
@@ -1418,15 +1418,18 @@ class SYCLGen : public SYCLGenBase {
14181418
// Sizes of A & B matrices
14191419
std::string M, N, K;
14201420

1421-
// Data type used to multiply A & B matrices
1422-
std::string MulType;
1421+
// Data types of A, B & C matrices respectively in the PTX arguments
1422+
std::string InMatrixType[3];
1423+
14231424
if (Inst->hasAttr(InstAttr::m16n8k16)) {
14241425
M = "16";
14251426
N = "8";
14261427
K = "16";
1428+
14271429
// Only f16/s8 types are supported for A and B matrices of m16n8k16
14281430
if (AType->getKind() == InlineAsmBuiltinType::f16) {
1429-
MulType = "sycl::half";
1431+
InMatrixType[0] = "int32_t"; // A type is .f16x2
1432+
InMatrixType[1] = "int32_t"; // B type is .f16x2
14301433

14311434
// If A matrix type is f16, then C&D matrix types can only be f32
14321435
if (CType->getKind() == InlineAsmBuiltinType::f32) {
@@ -1437,7 +1440,8 @@ class SYCLGen : public SYCLGenBase {
14371440
} else
14381441
return SYCLGenError();
14391442
} else if (AType->getKind() == InlineAsmBuiltinType::s8) {
1440-
MulType = "int8_t";
1443+
InMatrixType[0] = "int32_t"; // A type is .s8x4
1444+
InMatrixType[1] = "int32_t"; // B type is .s8x4
14411445

14421446
// If A matrix type is s8, then C&D matrix types can only be s32
14431447
if (CType->getKind() == InlineAsmBuiltinType::s32) {
@@ -1452,6 +1456,8 @@ class SYCLGen : public SYCLGenBase {
14521456
} else
14531457
return SYCLGenError();
14541458

1459+
InMatrixType[2] = CDType;
1460+
14551461
// Check the register sizes for vector elements of A, B, C & D matrices
14561462
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
14571463
InputOp++) {
@@ -1465,13 +1471,9 @@ class SYCLGen : public SYCLGenBase {
14651471
if (DMatVE->getNumElements() != NumVecElements[3])
14661472
return SYCLGenError();
14671473

1468-
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
1469-
OS() << "<";
1470-
OS() << M << ", " << N << ", " << K << ", ";
1471-
OS() << MulType;
1472-
OS() << ">(";
1473-
1474-
// Add D matrix address values to store the MAD result
1474+
// Declare and init an array for storing the addresses of D matrix elements
1475+
OS() << "{\n";
1476+
OS() << CDType << " *DMatrix_ct1[" << DMatVE->getNumElements() << "] = { ";
14751477
for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) {
14761478
if (isa<InlineAsmDiscardExpr>(DMatVE->getElement(Inst)))
14771479
continue;
@@ -1481,25 +1483,44 @@ class SYCLGen : public SYCLGenBase {
14811483
if ((Inst + 1) != DMatVE->getNumElements())
14821484
OS() << ", ";
14831485
}
1486+
OS() << " }";
1487+
endstmt();
14841488

1485-
// Add A, B & C matrix values to compute MAD
1489+
// Declare and init vectors for storing the values of A, B & C matrix elements
1490+
std::string InMatrixName[3] = {"A", "B", "C"};
14861491
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
14871492
InputOp++) {
14881493
if (auto VE =
14891494
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
1495+
OS() << "sycl::vec<" << InMatrixType[InputOp] << ", " << VE->getNumElements() << "> " << InMatrixName[InputOp] << "Matrix_ct1(";
14901496
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
14911497
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
14921498
continue;
1493-
OS() << ", ";
14941499
if (emitStmt(VE->getElement(Inst)))
14951500
return SYCLGenError();
1501+
if ((Inst + 1) != VE->getNumElements())
1502+
OS() << ", ";
14961503
}
1504+
OS() << ")";
1505+
endstmt();
14971506
} else {
14981507
return SYCLGenError();
14991508
}
15001509
}
15011510

1502-
OS() << ");";
1511+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
1512+
OS() << "<";
1513+
OS() << M << ", " << N << ", " << K << ", ";
1514+
OS() << ABType;
1515+
OS() << ">(";
1516+
1517+
OS() << "DMatrix_ct1";
1518+
for (int i = 0; i < 3; i++)
1519+
OS() << ", reinterpret_cast<" << InMatrixType[i] << " *>(&" << InMatrixName[i] << "Matrix_ct1)";
1520+
OS() << ")";
1521+
endstmt();
1522+
OS() << "}";
1523+
endstmt();
15031524

15041525
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
15051526
if (KernelDecl) {

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2218,6 +2218,22 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) {
22182218
ldmatrix(addr, m4, trans, 3);
22192219
}
22202220

2221+
/// Multiplies two matrices (A & B), adds the product to the C matrix and
2222+
/// stores the result in the D matrix (MAD). Requires the sub-group size of
2223+
/// kernel calling this function to be 32.
2224+
/// \tparam [in] M The number of rows of the A, C & D matrices
2225+
/// \tparam [in] N The number of columns of the B, C & D matrices
2226+
/// \tparam [in] K The number of columns of the A matrix and rows of the B matrix
2227+
/// \tparam [in] MulType The element type that the packed A and B values are reinterpreted as when they are multiplied
2228+
/// \tparam [in] ABType The type of the input matrix (A & B) elements
2229+
/// \tparam [in] CDType The type of the output matrix (C & D) elements
2230+
/// \param [out] d Pointers to the elements of the output D matrix that the result is stored to
2231+
/// \param [in] a The elements of the input A matrix to be multiplied with B
2232+
/// matrix elements
2233+
/// \param [in] b The elements of the input B matrix to be multiplied with A
2234+
/// matrix elements
2235+
/// \param [in] c The elements of the input C matrix to be added to the result
2236+
/// of A * B
22212237
template <int M, int N, int K, typename MulType, typename ABType,
22222238
typename CDType, typename Op = sycl::bit_and<>>
22232239
void mma(CDType **d, ABType *a, ABType *b, CDType *c, Op op = Op{}) {
@@ -2228,12 +2244,8 @@ void mma(CDType **d, ABType *a, ABType *b, CDType *c, Op op = Op{}) {
22282244
short COL_LOAD_OFFSET = 8 * (lane % 4);
22292245

22302246
if (M == 16 && N == 8 && K == 16) {
2231-
if constexpr (std::is_same_v<CDType, sycl::half>) {
2247+
if constexpr (std::is_floating_point_v<CDType>) {
22322248
// f32.f16.f16.f32
2233-
auto c_h = reinterpret_cast<MulType *>(c);
2234-
2235-
float c_f[4] = {c_h[0], c_h[1], c_h[2], c_h[3]};
2236-
22372249
for (int i = 0; i < 4; i++) {
22382250
ABType recv_a[4], recv_b[4];
22392251

@@ -2245,50 +2257,45 @@ void mma(CDType **d, ABType *a, ABType *b, CDType *c, Op op = Op{}) {
22452257
recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i);
22462258
recv_b[1] = dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i);
22472259
recv_b[2] =
2248-
dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4);
2260+
dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + 4 + i);
22492261
recv_b[3] =
2250-
dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i + 4);
2262+
dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + 4 + i);
22512263

22522264
auto ra = reinterpret_cast<MulType *>(recv_a);
22532265
auto rb = reinterpret_cast<MulType *>(recv_b);
22542266

22552267
for (int j = 0; j < 4; j++) {
2256-
c_f[0] += static_cast<float>(ra[j]) * static_cast<float>(rb[j]);
2257-
c_f[1] += static_cast<float>(ra[j]) * static_cast<float>(rb[j + 4]);
2258-
c_f[2] += static_cast<float>(ra[j + 4]) * static_cast<float>(rb[j]);
2259-
c_f[3] +=
2260-
static_cast<float>(ra[j + 4]) * static_cast<float>(rb[j + 4]);
2268+
c[0] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j]);
2269+
c[1] += static_cast<CDType>(ra[j]) * static_cast<CDType>(rb[j + 4]);
2270+
c[2] += static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j]);
2271+
c[3] +=
2272+
static_cast<CDType>(ra[j + 4]) * static_cast<CDType>(rb[j + 4]);
22612273
}
22622274
}
22632275

2264-
c_h[0] = c_f[0];
2265-
c_h[1] = c_f[1];
2266-
c_h[2] = c_f[2];
2267-
c_h[3] = c_f[3];
2268-
22692276
*d[0] = c[0];
22702277
*d[1] = c[1];
2278+
*d[2] = c[2];
2279+
*d[3] = c[3];
22712280
} else if constexpr (std::is_integral_v<MulType>) {
22722281
// s32.s8.s8.s32
2273-
ABType recv_a[4 * 2], recv_b[4 * 2];
2274-
22752282
for (int i = 0; i < 4; i++) {
2276-
recv_a[i] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i);
2277-
recv_a[i + 4] =
2278-
dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i);
2283+
ABType recv_a[2], recv_b[2];
22792284

2280-
recv_b[i] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i);
2281-
recv_b[i + 4] =
2285+
recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i);
2286+
recv_a[1] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i);
2287+
recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i);
2288+
recv_b[1] =
22822289
dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4);
2283-
}
22842290

2285-
MulType *a = reinterpret_cast<MulType *>(recv_a);
2286-
MulType *b = reinterpret_cast<MulType *>(recv_b);
2287-
for (int i = 0; i < 16; i++) {
2288-
c[0] += a[i] * b[i];
2289-
c[1] += a[i] * b[i + 16];
2290-
c[2] += a[i + 16] * b[i];
2291-
c[3] += a[i + 16] * b[i + 16];
2291+
auto ra = reinterpret_cast<MulType *>(recv_a);
2292+
auto rb = reinterpret_cast<MulType *>(recv_b);
2293+
for (int i = 0; i < 4; i++) {
2294+
c[0] += ra[i] * rb[i];
2295+
c[1] += ra[i] * rb[i + 4];
2296+
c[2] += ra[i + 4] * rb[i];
2297+
c[3] += ra[i + 4] * rb[i + 4];
2298+
}
22922299
}
22932300

22942301
*d[0] = c[0];

clang/test/dpct/asm/mma.cu

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,15 +14,21 @@ As per PTX ASM 8.1, below is the status of supported configurations
1414
--------- --------- ---------- ----------- -------------
1515
| Shape | | A | | B | | C / D | | Supported |
1616
--------- --------- ---------- ----------- -------------
17-
m16n8k16 .f16/.bf16 .f16/.bf16 .f16/.f32 Partial (.f16.f16.f16.f16 / .f32.f16.f16.f32)
18-
.s8/.u8 .s8/.u8 .s32 Yes
17+
m16n8k16 .f16 .f16 .f32 Yes
18+
.s8 .s8 .s32 Yes
1919
2020
A Layout: row
2121
B Layout: col
2222
*/
2323

2424
__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
25-
// CHECK: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(&fc[0], &fc[1], &fc[2], &fc[3], a[0], a[1], a[2], a[3], b[0], b[1], fc[0], fc[1], fc[2], fc[3]);
25+
// CHECK: {
26+
// CHECK-NEXT: float *DMatrix_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] };
27+
// CHECK-NEXT: sycl::vec<int32_t, 4> AMatrix_ct1(a[0], a[1], a[2], a[3]);
28+
// CHECK-NEXT: sycl::vec<int32_t, 2> BMatrix_ct1(b[0], b[1]);
29+
// CHECK-NEXT: sycl::vec<float, 4> CMatrix_ct1(fc[0], fc[1], fc[2], fc[3]);
30+
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(DMatrix_ct1, reinterpret_cast<int32_t *>(&AMatrix_ct1), reinterpret_cast<int32_t *>(&BMatrix_ct1), reinterpret_cast<float *>(&CMatrix_ct1));
31+
// CHECK-NEXT: }
2632
asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
2733
" { %0, %1, %2, %3 }, "
2834
" { %4, %5, %6, %7 }, "
@@ -32,7 +38,13 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) {
3238
: "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
3339
"r"(b[0]), "r"(b[1]));
3440

35-
// CHECK: dpct::experimental::matrix::mma<16, 8, 16, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]);
41+
// CHECK: {
42+
// CHECK-NEXT: int32_t *DMatrix_ct1[4] = { &d[0], &d[1], &d[2], &d[3] };
43+
// CHECK-NEXT: sycl::vec<int32_t, 2> AMatrix_ct1(a[0], a[1]);
44+
// CHECK-NEXT: sycl::vec<int32_t, 1> BMatrix_ct1(b[0]);
45+
// CHECK-NEXT: sycl::vec<int32_t, 4> CMatrix_ct1(c[0], c[1], c[2], c[3]);
46+
// CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t>(DMatrix_ct1, reinterpret_cast<int32_t *>(&AMatrix_ct1), reinterpret_cast<int32_t *>(&BMatrix_ct1), reinterpret_cast<int32_t *>(&CMatrix_ct1));
47+
// CHECK-NEXT: }
3648
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 "
3749
" { %0, %1, %2, %3 }, "
3850
" { %4, %5 }, "

0 commit comments

Comments
 (0)