From 11eb1d74652c3207c417ebb4369990a865b3d9a3 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Tue, 25 Mar 2025 19:53:02 +0800 Subject: [PATCH 1/9] Added support for mma m16n8k16 migration * f32.f16.f16.f32 * s32.s8.s8.s32 --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 151 ++- clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h | 20 +- clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp | 5 +- clang/lib/DPCT/RulesAsm/Parser/AsmParser.h | 2 +- .../DPCT/RulesAsm/Parser/AsmTokenKinds.def | 20 + clang/runtime/dpct-rt/include/dpct/math.hpp | 934 +++++++++++++++++- clang/test/dpct/asm/mma.cu | 57 ++ .../ASM_API_migration_status.csv | 2 +- 8 files changed, 1173 insertions(+), 18 deletions(-) create mode 100644 clang/test/dpct/asm/mma.cu diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 2f940f84d7f9..429192f9030c 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) { bool SYCLGenBase::emitBuiltinType(const InlineAsmBuiltinType *T) { switch (T->getKind()) { // clang-format off + case InlineAsmBuiltinType::b1: OS() << "uint8_t"; break; case InlineAsmBuiltinType::b8: OS() << "uint8_t"; break; case InlineAsmBuiltinType::b16: OS() << "uint16_t"; break; case InlineAsmBuiltinType::b32: OS() << "uint32_t"; break; case InlineAsmBuiltinType::b64: OS() << "uint64_t"; break; + case InlineAsmBuiltinType::u4: OS() << "uint8_t"; break; case InlineAsmBuiltinType::u8: OS() << "uint8_t"; break; case InlineAsmBuiltinType::u16: OS() << "uint16_t"; break; case InlineAsmBuiltinType::u32: OS() << "uint32_t"; break; case InlineAsmBuiltinType::u64: OS() << "uint64_t"; break; + case InlineAsmBuiltinType::s4: OS() << "int8_t"; break; case InlineAsmBuiltinType::s8: OS() << "int8_t"; break; case InlineAsmBuiltinType::s16: OS() << "int16_t"; break; case InlineAsmBuiltinType::s32: OS() << "int32_t"; break; @@ -559,6 +562,9 @@ bool 
SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) { case InlineAsmVectorType::x1: OS() << 1; break; + case InlineAsmVectorType::v1: + OS() << 1; + break; case InlineAsmVectorType::v2: case InlineAsmVectorType::x2: OS() << 2; @@ -1370,6 +1376,142 @@ class SYCLGen : public SYCLGenBase { return SYCLGenSuccess(); } + bool handle_mma(const InlineAsmInstruction *Inst) override { + if (Inst->getNumInputOperands() != 3) + return SYCLGenError(); + + const InlineAsmVectorExpr *DMatVE = + dyn_cast(Inst->getOutputOperand()); + if (!DMatVE) + return SYCLGenError(); + + // Only the row layout is supported for the A matrix and + // only the col layout is supported for the B matrix + if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col) + return SYCLGenError(); + + // Element types of the D, A, B and C matrices (type operands 0-3) + const auto *DType = dyn_cast(Inst->getType(0)); + const auto *AType = dyn_cast(Inst->getType(1)); + const auto *BType = dyn_cast(Inst->getType(2)); + const auto *CType = dyn_cast(Inst->getType(3)); + + if (!(AType && BType && CType && DType)) + return SYCLGenError(); + + // Data types of matrix elements for A&B and C&D matrices should be same + if ((AType->getKind() != BType->getKind()) || + (CType->getKind() != DType->getKind())) + return SYCLGenError(); + + // Check the validity of AB & CD types + std::string ABType, CDType; + if (tryEmitType(ABType, AType)) + return SYCLGenError(); + + if (tryEmitType(CDType, CType)) + return SYCLGenError(); + + // Register sizes for vector elements of A, B, C & D matrices + unsigned NumVecElements[4] = {0}; + + // Sizes of A & B matrices + std::string M, N, K; + + // Data type used to multiply A & B matrices + std::string MulType; + if (Inst->hasAttr(InstAttr::m16n8k16)) { + M = "16"; + N = "8"; + K = "16"; + // Only f16/s8 types are supported for A and B matrices of m16n8k16 + if (AType->getKind() == InlineAsmBuiltinType::f16) { + MulType = "sycl::half"; + + // If A matrix type is f16, then C&D matrix types 
can only be f32 + if (CType->getKind() == InlineAsmBuiltinType::f32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::s8) { + MulType = "int8_t"; + + // If A matrix type is s8, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + + // Check the register sizes for vector elements of A, B, C & D matrices + for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); + InputOp++) { + if (auto VE = + dyn_cast(Inst->getInputOperand(InputOp))) { + if (VE->getNumElements() != NumVecElements[InputOp]) + return SYCLGenError(); + } else + return SYCLGenError(); + } + if (DMatVE->getNumElements() != NumVecElements[3]) + return SYCLGenError(); + + OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; + OS() << "<"; + OS() << M << ", " << N << ", " << K << ", "; + OS() << MulType; + OS() << ">("; + + // Add D matrix address values to store the MAD result + for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) { + if (isa(DMatVE->getElement(Inst))) + continue; + OS() << "&"; + if (emitStmt(DMatVE->getElement(Inst))) + return SYCLGenError(); + if ((Inst + 1) != DMatVE->getNumElements()) + OS() << ", "; + } + + // Add A, B & C matrix values to compute MAD + for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); + InputOp++) { + if (auto VE = + dyn_cast(Inst->getInputOperand(InputOp))) { + for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { + if (isa(VE->getElement(Inst))) + continue; + OS() << ", "; + if (emitStmt(VE->getElement(Inst))) + return SYCLGenError(); + } + } else { + return SYCLGenError(); + } + } 
+ + OS() << ");"; + + const auto *KernelDecl = getImmediateOuterFuncDecl(GAS); + if (KernelDecl) { + auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl); + if (FuncInfo) + FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(), + DpctGlobalInfo::getSubGroup(GAS)); + } + + return SYCLGenSuccess(); + } + bool handle_prefetch(const InlineAsmInstruction *Inst) override { if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1) return SYCLGenError(); @@ -2595,11 +2737,10 @@ class SYCLGen : public SYCLGenBase { Op = std::move(NewOp); } - bool HasHalfOrBfloat16 = - SrcType->getKind() == InlineAsmBuiltinType::f16 || - DesType->getKind() == InlineAsmBuiltinType::f16 || - SrcType->getKind() == InlineAsmBuiltinType::bf16 || - DesType->getKind() == InlineAsmBuiltinType::bf16; + bool HasHalfOrBfloat16 = SrcType->getKind() == InlineAsmBuiltinType::f16 || + DesType->getKind() == InlineAsmBuiltinType::f16 || + SrcType->getKind() == InlineAsmBuiltinType::bf16 || + DesType->getKind() == InlineAsmBuiltinType::bf16; if (DpctGlobalInfo::useIntelDeviceMath() && HasHalfOrBfloat16) { insertHeader(HeaderType::HT_SYCL_Math); if (SrcNeedBitCast) diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h index 42cce5902f97..640fd536613b 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h @@ -99,9 +99,9 @@ class InlineAsmBuiltinType : public InlineAsmType { return ((K == Kind) || ...); } template bool isNot(Ks... 
K) { return ((K != Kind) && ...); } - bool isBit() const { return isOneOf(b8, b16, b32, b64); } - bool isSigned() const { return isOneOf(s8, s16, s32, s64); } - bool isUnsigned() const { return isOneOf(u8, u16, u32, u64); } + bool isBit() const { return isOneOf(b1, b8, b16, b32, b64); } + bool isSigned() const { return isOneOf(s4, s8, s16, s32, s64); } + bool isUnsigned() const { return isOneOf(u4, u8, u16, u32, u64); } bool isInt() const { return isSigned() || isUnsigned(); } bool isFloat() const { return isOneOf(f16, f32, f64); } bool isScalar() const { return isInt() || isFloat(); } @@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType { // This class is used for device asm vector types. class InlineAsmVectorType : public InlineAsmType { public: - enum VecKind { v2, v4, v8, x1, x2, x4 }; + enum VecKind { v1, v2, v4, v8, x1, x2, x4 }; private: VecKind Kind; @@ -322,7 +322,7 @@ class InlineAsmInstruction : public InlineAsmStmt { /// This represents arrtibutes like: comparsion operator, rounding modifiers, /// ... e.g. instruction setp.eq.s32 has a comparsion operator 'eq'. - SmallSet Attributes; + SmallVector Attributes; /// This represents types in instruction, e.g. mov.u32. SmallVector Types; @@ -350,11 +350,11 @@ class InlineAsmInstruction : public InlineAsmStmt { OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) { StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(), AsmStateSpaces.end()); - Attributes.insert(Attrs.begin(), Attrs.end()); + Attributes.insert(Attributes.begin(), Attrs.begin(), Attrs.end()); } using attr_range = - llvm::iterator_range::const_iterator>; + llvm::iterator_range::const_iterator>; using type_range = llvm::iterator_range::const_iterator>; using op_range = @@ -369,12 +369,16 @@ class InlineAsmInstruction : public InlineAsmStmt { } template bool hasAttr(Ts... 
Attrs) const { - return (Attributes.contains(Attrs) || ...); + return (llvm::is_contained(Attributes, Attrs) || ...); } const InlineAsmIdentifierInfo *getOpcodeID() const { return Opcode; } asmtok::TokenKind getOpcode() const { return Opcode->getTokenID(); } ArrayRef getTypes() const { return Types; } const InlineAsmType *getType(unsigned I) const { return Types[I]; } + InstAttr getAttr(unsigned I) const { + assert(I < Attributes.size() && "Attributes index out of range"); + return Attributes[I]; + } unsigned getNumTypes() const { return Types.size(); } const InlineAsmExpr *getOutputOperand() const { return OutputOp; } const InlineAsmExpr *getPredOutputOperand() const { return PredOutputOp; } diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp index 60e49a9c0c56..efc522477206 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp @@ -756,7 +756,10 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef Vec) { } else { // Vector size must be 2, 4, or 8. 
switch (Vec.size()) { - case 2: + case 1: + Kind = InlineAsmVectorType::v1; + break; + case 2: Kind = InlineAsmVectorType::v2; break; case 4: diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h index 8b9a3f5f01ba..4de5db7bf658 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h @@ -498,7 +498,7 @@ class InlineAsmParser { /// .reg .sreg .const .local .param .shared .tex /// /// vector-specifier: one of - /// .v2 .v4 .v8 + /// .v1 .v2 .v4 .v8 /// /// type-specifier: one of /// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64 diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def index ea401fb0777c..c7ef34d725f3 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def @@ -240,14 +240,17 @@ SPECIAL_REG(warpid, "%warpid", s64) SPECIAL_REG(WARP_SZ, "WARP_SZ", s64) // Built-in type names +BUILTIN_TYPE(b1, ".b1") BUILTIN_TYPE(b8, ".b8") BUILTIN_TYPE(b16, ".b16") BUILTIN_TYPE(b32, ".b32") BUILTIN_TYPE(b64, ".b64") +BUILTIN_TYPE(u4, ".u4") BUILTIN_TYPE(u8, ".u8") BUILTIN_TYPE(u16, ".u16") BUILTIN_TYPE(u32, ".u32") BUILTIN_TYPE(u64, ".u64") +BUILTIN_TYPE(s4, ".s4") BUILTIN_TYPE(s8, ".s8") BUILTIN_TYPE(s16, ".s16") BUILTIN_TYPE(s32, ".s32") @@ -270,6 +273,7 @@ BUILTIN_TYPE(s16x2, ".s16x2") BUILTIN_TYPE(u16x2, ".u16x2") // Vector modifiers +MODIFIER(v1, ".v1") MODIFIER(v2, ".v2") MODIFIER(v4, ".v4") MODIFIER(v8, ".v8") @@ -279,8 +283,23 @@ MODIFIER(x1, ".x1") MODIFIER(x2, ".x2") MODIFIER(x4, ".x4") +// Matrix modifiers +MODIFIER(row, ".row") +MODIFIER(col, ".col") + // Matrix shape MODIFIER(m8n8, ".m8n8") +MODIFIER(m8n8k4, ".m8n8k4") +MODIFIER(m8n8k16, ".m8n8k16") +MODIFIER(m8n8k32, ".m8n8k32") +MODIFIER(m8n8k128, ".m8n8k128") +MODIFIER(m16n8k4, ".m16n8k4") +MODIFIER(m16n8k8, ".m16n8k8") +MODIFIER(m16n8k16, ".m16n8k16") +MODIFIER(m16n8k32, ".m16n8k32") +MODIFIER(m16n8k64, 
".m16n8k64") +MODIFIER(m16n8k128, ".m16n8k128") +MODIFIER(m16n8k256, ".m16n8k256") STATE_SPACE(reg, ".reg") STATE_SPACE(sreg, ".sreg") @@ -376,6 +395,7 @@ MODIFIER(max, ".max") MODIFIER(op_or, ".or") MODIFIER(op_xor, ".xor") MODIFIER(op_and, ".and") +MODIFIER(op_popc, ".popc") MODIFIER(cas, ".cas") MODIFIER(exch, ".exch") MODIFIER(inc, ".inc") diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 69004f702181..65e55e69c8ba 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -1636,7 +1636,8 @@ inline constexpr unsigned extend_vcompare2_add(AT a, BT b, unsigned c, /// \returns The extend vectorized average of the two values template inline constexpr RetT extend_vavrg2(AT a, BT b, RetT c) { - return detail::extend_vbinary2(a, b, c, detail::average()); + return detail::extend_vbinary2(a, b, c, + detail::average()); } /// Compute vectorized average of \p a and \p b, with each value treated as a 2 @@ -1933,7 +1934,8 @@ inline constexpr unsigned extend_vcompare4_add(AT a, BT b, unsigned c, /// \returns The extend vectorized average of the two values template inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c) { - return detail::extend_vbinary4(a, b, c, detail::average()); + return detail::extend_vbinary4(a, b, c, + detail::average()); } /// Compute vectorized average of \p a and \p b, with each value treated as a 4 @@ -2216,6 +2218,934 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { ldmatrix(addr, m4, trans, 3); } + +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b16 +/// matrix (m8n8k4.row.col.f16.f16.f16.f16) +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// 
\tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 2, 2, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +void mma(volatile CDType *d0, volatile CDType *d1, volatile CDType *d2, + volatile CDType *d3, ABType a0, ABType a1, ABType b0, ABType b1, + CDType c0, CDType c1, CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4); + + if (M == 8 && N == 8 && K == 4) { + ABType recv_a[2], recv_b[4]; + recv_a[0] = a0; + recv_a[1] = a1; + + MulType *ra = reinterpret_cast(recv_a); + MulType *rb = reinterpret_cast(recv_b); + + float c_f[8] = {0.0f}; + + for (int i = 0; i < 4; i++) { + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 + i); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 + i); + + for (int j = 0; j < 4; j++) { + c_f[i] += static_cast(ra[j]) * 
static_cast(rb[j]); + c_f[i + 4] += static_cast(ra[j]) * static_cast(rb[j + 4]); + } + } + + auto c_h = reinterpret_cast(&c0); + c_f[0] += static_cast(c_h[0]); + c_f[1] += static_cast(c_h[1]); + c_h[0] = c_f[0]; + c_h[1] = c_f[1]; + + c_h = reinterpret_cast(&c1); + c_f[2] += static_cast(c_h[0]); + c_f[3] += static_cast(c_h[1]); + c_h[0] = c_f[2]; + c_h[1] = c_f[3]; + + c_h = reinterpret_cast(&c2); + c_f[4] += static_cast(c_h[0]); + c_f[5] += static_cast(c_h[1]); + c_h[0] = c_f[4]; + c_h[1] = c_f[5]; + + c_h = reinterpret_cast(&c3); + c_f[6] += static_cast(c_h[0]); + c_f[7] += static_cast(c_h[1]); + c_h[0] = c_f[6]; + c_h[1] = c_f[7]; + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 +/// matrix (m8n8k4.row.col.f32.f32.f32.f32) +/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 8, 2, 2, 8 +/// \tparam [in] ItemT The type of the sycl::nd_item index space class +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] d4 The 5th element to be written to the output D matrix +/// \param [in] d5 The 6th element to be written to the output D matrix +/// \param [in] d6 The 7th element to be written to the output D matrix +/// \param [in] d7 The 8th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 
2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +/// \param [in] c4 The 5th element from C matrix to be added with d4 +/// \param [in] c5 The 6th element from C matrix to be added with d5 +/// \param [in] c6 The 7th element from C matrix to be added with d6 +/// \param [in] c7 The 8th element from C matrix to be added with d7 +template +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, + CDType *d6, CDType *d7, ABType a0, ABType a1, ABType b0, ABType b1, + CDType c0, CDType c1, CDType c2, CDType c3, CDType c4, CDType c5, + CDType c6, CDType c7) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2) + (lane % 2); + short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4) + 2 * ((lane / 2) % 2); + + if (M == 8 && N == 8 && K == 4) { + ABType recv_a[2 * 2], recv_b[4 * 2]; + + for (int i = 0; i < 2; i++) { + recv_a[2 * i] = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + 2 * i); + recv_a[2 * i + 1] = + dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + 2 * i); + + recv_b[4 * i] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 * i); + recv_b[4 * i + 1] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 * i); + recv_b[4 * i + 2] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 * i + 1); + recv_b[4 * i + 3] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 * i + 1); + } + + MulType *ra = reinterpret_cast(recv_a); + MulType *rb = reinterpret_cast(recv_b); + for 
(int i = 0; i < 4; i++) { + c0 += static_cast(ra[i]) * static_cast(rb[i]); + c1 += static_cast(ra[i]) * static_cast(rb[i + 4]); + c2 += static_cast(ra[i + 4]) * static_cast(rb[i]); + c3 += static_cast(ra[i + 4]) * static_cast(rb[i + 4]); + c4 += static_cast(ra[i]) * static_cast(rb[i + 8]); + c5 += static_cast(ra[i]) * static_cast(rb[i + 12]); + c6 += static_cast(ra[i + 4]) * static_cast(rb[i + 8]); + c7 += static_cast(ra[i + 4]) * static_cast(rb[i + 12]); + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; + *d4 = c4; + *d5 = c5; + *d6 = c6; + *d7 = c7; +} + +/// Multiplies 2 8x4 & 4x8 f64 matrices and accumulates the result to a 8x8 b64 +/// matrix (m8n8k4.row.col.f64.f64.f64.f64). +/// Multiplies 2 8x16 & 16x8 u8/s8 matrices and accumulates the result to a 8x8 +/// s32 matrix (m8n8k16.row.col.s32.u8.u8.s32 / m8n8k16.row.col.s32.s8.s8.s32). +/// Multiplies 2 8x32 & 32x8 u4/s4 matrices and accumulates the result to a 8x8 +/// s32 matrix (m8n8k32.row.col.s32.u4.u4.s32 / m8n8k32.row.col.s32.s4.s4.s32). +/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 +/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 +/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc). +/// Requires the sub-group size of kernel calling this function to be 32. 
+/// In: 2, 1, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template , + typename ABType, typename CDType> +void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, + Op op = Op{}) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 8 && N == 8 && K == 4) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + c0 += recv_a * recv_b; + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + c1 += recv_a * recv_b; + } + } else if (M == 8 && N == 8 && K == 16) { + for (int i = 0; i < 4; i++) { + ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + MulType *a = reinterpret_cast(&recv_a); + MulType *b = reinterpret_cast(&recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + } + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + for (int k = 0; k < 4; k++) { + c1 += a[k] * b[k]; + } + } + } else if (M 
== 8 && N == 8 && K == 32) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + MulType *a = reinterpret_cast(&recv_a); + MulType *b = reinterpret_cast(&recv_b); + + for (int k = 0; k < 4; k++) { + MulType a0 = a[k] >> 4; + MulType a1 = a[k] & 0x0F; + MulType b0 = b[k] >> 4; + MulType b1 = b[k] & 0x0F; + + c0 += a0 * b0; + c0 += a1 * b1; + } + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + for (int k = 0; k < 4; k++) { + MulType a0 = a[k] >> 4; + MulType a1 = a[k] & 0x0F; + MulType b0 = b[k] >> 4; + MulType b1 = b[k] & 0x0F; + + c1 += a0 * b0; + c1 += a1 * b1; + } + } + } + } else if (M == 8 && N == 8 && K == 128) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a = + dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + ABType recv_b = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + + c0 += sycl::popcount(op(recv_a, recv_b)); + + recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c1 += sycl::popcount(op(recv_a, recv_b)); + } + } + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a +/// 16x8 f16 matrix (m16n8k8.row.col.f16.f16.f16.f16) +/// Requires the sub-group size of kernel +/// calling this function to be 32 +/// In: 2, 2, 1, 2 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element 
from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +void mma(CDType *d0, CDType *d1, ABType a0, ABType a1, ABType b0, CDType c0, + CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 8) { + auto c0_h = reinterpret_cast(&c0); + auto c1_h = reinterpret_cast(&c1); + + float c_f[4] = {c0_h[0], c0_h[1], c1_h[0], c1_h[1]}; + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 2; j++) { + c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 2]); + c_f[2] += static_cast(ra[j + 2]) * static_cast(rb[j]); + c_f[3] += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + + c0_h[0] = c_f[0]; + c0_h[1] = c_f[1]; + c1_h[0] = c_f[2]; + c1_h[1] = c_f[3]; + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 16x16 & 16x8 f16 matrices and accumulates the result to a 16x8 +/// f16 matrix (m16n8k16.row.col.f16.f16.f16.f16). 
+/// Requires the sub-group size of kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 2, 4, 2, 2 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +template +void mma(volatile CDType *d0, volatile CDType *d1, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 16) { + auto c0_h = reinterpret_cast(&c0); + auto c1_h = reinterpret_cast(&c1); + + float c_f[4] = {c0_h[0], c0_h[1], c1_h[0], c1_h[1]}; + + for (int i = 0; i < 4; i++) { + ABType recv_a[4], recv_b[4]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[3] = 
dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 4; j++) { + c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); + c_f[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); + c_f[3] += static_cast(ra[j + 4]) * static_cast(rb[j + 4]); + } + } + + c0_h[0] = c_f[0]; + c0_h[1] = c_f[1]; + c1_h[0] = c_f[2]; + c1_h[1] = c_f[3]; + } + + *d0 = c0; + *d1 = c1; +} + +/// Multiplies 2 16x8 & 8x8 f64 matrices and accumulates the result to a 16x8 +/// f64 matrix (m16n8k8.row.col.f64.f64.f64.f64). +/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 +/// matrix (m16n8k16.row.col.f32.f16.f16.f32). +/// Multiplies 2 16x32 & 32x8 u8/s8 matrices and accumulates the result to a +/// 16x8 b32 matrix (m16n8k32.row.col.s32.u8.u8.s32 / +/// m16n8k32.row.col.s32.s8.s8.s32). +/// Multiplies 2 16x64 & 64x8 u4/s4 matrices and +/// accumulates the result to a 16x8 b32 matrix (m16n8k64.row.col.s32.u4.u4.s32 +/// / m16n8k64.row.col.s32.s4.s4.s32). +/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc). +/// Requires the sub-group size of kernel calling this function to be 32. 
+/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 4, 2, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template , + typename ABType, typename CDType> +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, + CDType c2, CDType c3, Op op = Op{}) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 8) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); 
+ recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + } else if (M == 16 && N == 8 && K == 16) { + for (int i = 0; i < 4; i++) { + ABType recv_a[4], recv_b[4]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[3] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 4 + i); + recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 4 + i); + + auto *ra0 = reinterpret_cast(recv_a); + auto *ra1 = reinterpret_cast(recv_a + 2); + auto *rb0 = reinterpret_cast(recv_b); + auto *rb1 = reinterpret_cast(recv_b + 2); + + // Iterate for k (i * j) times + for (int j = 0; j < 4; j++) { + auto a0 = static_cast(ra0[j]); + auto a1 = static_cast(ra1[j]); + auto b0 = static_cast(rb0[j]); + auto b1 = static_cast(rb1[j]); + + c0 += a0 * b0; + c1 += a0 * b1; + c2 += a1 * b0; + c3 += a1 * b1; + } + } + } else if (M == 
16 && N == 8 && K == 32) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + c1 += a[k] * b[k + 4]; + c2 += a[k + 4] * b[k]; + c3 += a[k + 4] * b[k + 4]; + } + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + c0 += a[k] * b[k]; + c1 += a[k] * b[k + 4]; + c2 += a[k + 4] * b[k]; + c3 += a[k + 4] * b[k + 4]; + } + } + } else if (M == 16 && N == 8 && K == 64) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + 
c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + } + } else if (M == 16 && N == 8 && K == 256) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } + + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[1] 
= + dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x16 & 16x8 f64 matrices and accumulates the result to a 16x8 +/// f64 matrix (m16n8k16.row.col.f64.f64.f64.f64) Requires the sub-group size of +/// kernel calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 8, 4, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix +/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix +/// \param [in] a4 The 5th element from A matrix to be multiplied with B matrix +/// \param [in] a5 The 6th element from A matrix to be multiplied with B matrix +/// \param [in] a6 The 7th element from A matrix to be multiplied with B matrix +/// \param [in] a7 The 8th element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix +/// \param [in] 
b2 The 3rd element from B matrix to be multiplied with A matrix +/// \param [in] b3 The 4th element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType a2, ABType a3, ABType a4, ABType a5, ABType a6, ABType a7, + ABType b0, ABType b1, ABType b2, ABType b3, CDType c0, CDType c1, + CDType c2, CDType c3) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 16) { + ABType recv_a[16 * 2], recv_b[16 * 2]; + + for (int i = 0; i < 4; i++) { + recv_a[i] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[i + 4] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); + recv_a[i + 8] = dpct::select_from_sub_group(sg, a4, ROW_LOAD_OFFSET + i); + recv_a[i + 12] = dpct::select_from_sub_group(sg, a6, ROW_LOAD_OFFSET + i); + recv_a[i + 16] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_a[i + 20] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); + recv_a[i + 24] = dpct::select_from_sub_group(sg, a5, ROW_LOAD_OFFSET + i); + recv_a[i + 28] = dpct::select_from_sub_group(sg, a7, ROW_LOAD_OFFSET + i); + + recv_b[i] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[i + 4] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); + recv_b[i + 8] = dpct::select_from_sub_group(sg, b2, COL_LOAD_OFFSET + i); + recv_b[i + 12] = dpct::select_from_sub_group(sg, b3, COL_LOAD_OFFSET + i); + recv_b[i + 16] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + recv_b[i + 20] = + 
dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); + recv_b[i + 24] = + dpct::select_from_sub_group(sg, b2, COL_LOAD_OFFSET + i + 4); + recv_b[i + 28] = + dpct::select_from_sub_group(sg, b3, COL_LOAD_OFFSET + i + 4); + } + + for (int i = 0; i < 16; i++) { + c0 += recv_a[i] * recv_b[i]; + c1 += recv_a[i] * recv_b[i + 16]; + c2 += recv_a[i + 16] * recv_b[i]; + c3 += recv_a[i + 16] * recv_b[i + 16]; + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + +/// Multiplies 2 16x4 & 4x8 f16 matrices and accumulates the result to a +/// 16x8 f32 matrix (m16n8k4.row.col.f16.f16.f16.f16 / +/// m16n8k4.row.col.f32.f16.f16.f32). +/// Multiplies 2 16x4 & 4x8 f64 matrices and accumulates the result to a +/// 16x8 f64 matrix (m16n8k4.row.col.f64.f64.f64.f64). +/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a +/// 16x8 f32 matrix (m16n8k8.row.col.f32.f16.f16.f32). +/// Multiplies 2 16x8 & 8x8 f64 matrices and accumulates the result to a +/// 16x8 f64 matrix (m16n8k8.row.col.f64.f64.f64.f64). +/// Multiplies 2 16x16 & 16x8 u8/s8 matrices and accumulates the result to a +/// 16x8 s32 matrix (m16n8k16.row.col.s32.u8.u8.s32 / +/// m16n8k16.row.col.s32.s8.s8.s32). +/// Multiplies 2 16x32 & 32x8 u4/s4 matrices and accumulates the result to a +/// 16x8 s32 matrix (m16n8k32.row.col.s32.u4.u4.s32 / +/// m16n8k32.row.col.s32.s4.s4.s32). +/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc). +/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 +/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc). +/// Requires the sub-group size of kernel 
+/// calling this function to be 32 +/// \tparam [in] M The rows of A/C/D matrix +/// \tparam [in] N The columns of B/C/D matrix +/// \tparam [in] K The columns/rows of A/B matrix +/// \tparam [in] MulType The type of the multiplication result +/// \tparam [in] ABType The type of the input matrices +/// \tparam [in] CDType The type of the output matrix +/// In: 4, 2, 1, 4 +/// \param [in] d0 The 1st element to be written to the output D matrix +/// \param [in] d1 The 2nd element to be written to the output D matrix +/// \param [in] d2 The 3rd element to be written to the output D matrix +/// \param [in] d3 The 4th element to be written to the output D matrix +/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix +/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix +/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix +/// \param [in] c0 The 1st element from C matrix to be added with d0 +/// \param [in] c1 The 2nd element from C matrix to be added with d1 +/// \param [in] c2 The 3rd element from C matrix to be added with d2 +/// \param [in] c3 The 4th element from C matrix to be added with d3 +template , + typename ABType, typename CDType> +void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, + ABType b0, CDType c0, CDType c1, CDType c2, CDType c3, Op op = Op{}) { + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + short ROW_LOAD_OFFSET = 4 * (lane >> 2); + short COL_LOAD_OFFSET = 8 * (lane % 4); + + if (M == 16 && N == 8 && K == 4) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += recv_a[0] * 
recv_b[0]; + c1 += recv_a[0] * recv_b[1]; + c2 += recv_a[1] * recv_b[0]; + c3 += recv_a[1] * recv_b[1]; + } + } else if (M == 16 && N == 8 && K == 8) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 2; j++) { + c0 += static_cast(ra[j]) * static_cast(rb[j]); + c1 += static_cast(ra[j]) * static_cast(rb[j + 2]); + c2 += static_cast(ra[j + 2]) * static_cast(rb[j]); + c3 += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); + } + } + } else if (M == 16 && N == 8 && K == 16) { + ABType recv_a[4 * 2], recv_b[4 * 2]; + + for (int i = 0; i < 4; i++) { + recv_a[i] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[i + 4] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + + recv_b[i] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[i + 4] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + } + + MulType *a = reinterpret_cast(recv_a); + MulType *b = reinterpret_cast(recv_b); + for (int i = 0; i < 16; i++) { + c0 += a[i] * b[i]; + c1 += a[i] * b[i + 16]; + c2 += a[i + 16] * b[i]; + c3 += a[i + 16] * b[i + 16]; + } + } else if (M == 16 && N == 8 && K == 32) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + MulType *a = reinterpret_cast(recv_a); + MulType *b = 
reinterpret_cast(recv_b); + + for (int k = 0; k < 4; k++) { + MulType a00 = a[k] >> 4; + MulType a01 = a[k] & 0x0F; + MulType a10 = a[k + 4] >> 4; + MulType a11 = a[k + 4] & 0x0F; + MulType b00 = b[k] >> 4; + MulType b01 = b[k] & 0x0F; + MulType b10 = b[k + 4] >> 4; + MulType b11 = b[k + 4] & 0x0F; + + c0 += a00 * b00; + c0 += a01 * b01; + + c1 += a00 * b10; + c1 += a01 * b11; + + c2 += a10 * b00; + c2 += a11 * b01; + + c3 += a10 * b10; + c3 += a11 * b11; + } + } + } + } else if (M == 16 && N == 8 && K == 128) { + if constexpr (std::is_integral_v) { + for (int i = 0; i < 4; i++) { + ABType recv_a[2], recv_b[2]; + + recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); + recv_b[1] = + dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + + c0 += sycl::popcount(op(recv_a[0], recv_b[0])); + c1 += sycl::popcount(op(recv_a[0], recv_b[1])); + c2 += sycl::popcount(op(recv_a[1], recv_b[0])); + c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + } + } + } + + *d0 = c0; + *d1 = c1; + *d2 = c2; + *d3 = c3; +} + } // namespace matrix } // namespace experimental diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu new file mode 100644 index 000000000000..98508d39ae5f --- /dev/null +++ b/clang/test/dpct/asm/mma.cu @@ -0,0 +1,57 @@ +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2 +// RUN: dpct --format-range=none -out-root %T/mma %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only +// RUN: FileCheck %s --match-full-lines --input-file %T/mma/mma.dp.cpp +// RUN: %if build_lit %{icpx -c -DNO_BUILD_TEST -fsycl %T/mma/mma.dp.cpp -o %T/mma/mma.dp.o %} + +// clang-format off +#include +#include + +/* +As per PTX ASM 8.1, below is the status of supported configurations + 
+--------- --------- ---------- ----------- ------------- +| Shape | | A | | B | | C / D | | Supported | +--------- --------- ---------- ----------- ------------- +m16n8k16 .f16/.bf16 .f16/.bf16 .f16/.f32 Partial (.f16.f16.f16.f16 / .f32.f16.f16.f32) + .s8/.u8 .s8/.u8 .s32 Yes + +A Layout: row +B Layout: col +*/ + +__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(&fc[0], &fc[1], &fc[2], &fc[3], a[0], a[1], a[2], a[3], b[0], b[1], fc[0], fc[1], fc[2], fc[3]); + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %0, %1, %2, %3 };" + : "+f"(fc[0]), "+f"(fc[1]), "+f"(fc[2]), "+f"(fc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1])); + + // CHECK: dpct::experimental::matrix::mma<16, 8, 16, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + + +int main () { + int *int_a, *int_b, *int_c, *int_d; + float *float_c; + + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k16<<<1, 32>>>(int_a, int_b, int_c, float_c, int_d); + + return 0; +} +// clang-format on diff --git a/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv b/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv index 2f45259b90e7..1b68950a15ae 100644 --- a/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv +++ b/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv @@ -54,7 +54,7 @@ max,YES, mbarrier,NO, membar,YES, Partial min,YES, -mma,NO, +mma,YES, Partial mov,YES, movmatrix,NO, mul,YES, Partial From 
d32895e51f5b1419c64b7f54631e4cf69e6b8e3a Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Tue, 6 May 2025 15:23:52 +0800 Subject: [PATCH 2/9] Created helper function for m16n8k16 --- clang/runtime/dpct-rt/include/dpct/math.hpp | 957 ++------------------ 1 file changed, 55 insertions(+), 902 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 65e55e69c8ba..c0c692219749 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2218,932 +2218,85 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { ldmatrix(addr, m4, trans, 3); } - -/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b16 -/// matrix (m8n8k4.row.col.f16.f16.f16.f16) -/// Requires the sub-group size of kernel calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 2, 2, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 
The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -void mma(volatile CDType *d0, volatile CDType *d1, volatile CDType *d2, - volatile CDType *d3, ABType a0, ABType a1, ABType b0, ABType b1, - CDType c0, CDType c1, CDType c2, CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4); - - if (M == 8 && N == 8 && K == 4) { - ABType recv_a[2], recv_b[4]; - recv_a[0] = a0; - recv_a[1] = a1; - - MulType *ra = reinterpret_cast(recv_a); - MulType *rb = reinterpret_cast(recv_b); - - float c_f[8] = {0.0f}; - - for (int i = 0; i < 4; i++) { - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 + i); - recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 + i); - - for (int j = 0; j < 4; j++) { - c_f[i] += static_cast(ra[j]) * static_cast(rb[j]); - c_f[i + 4] += static_cast(ra[j]) * static_cast(rb[j + 4]); - } - } - - auto c_h = reinterpret_cast(&c0); - c_f[0] += static_cast(c_h[0]); - c_f[1] += static_cast(c_h[1]); - c_h[0] = c_f[0]; - c_h[1] = c_f[1]; - - c_h = reinterpret_cast(&c1); - c_f[2] += static_cast(c_h[0]); - c_f[3] += static_cast(c_h[1]); - c_h[0] = c_f[2]; - c_h[1] = c_f[3]; - - c_h = reinterpret_cast(&c2); - c_f[4] += static_cast(c_h[0]); - c_f[5] += static_cast(c_h[1]); - c_h[0] = c_f[4]; - c_h[1] = c_f[5]; - - c_h = reinterpret_cast(&c3); - c_f[6] += static_cast(c_h[0]); - c_f[7] += static_cast(c_h[1]); - c_h[0] = c_f[6]; - c_h[1] = c_f[7]; - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - -/// Multiplies 2 8x4 & 4x8 matrices and accumulates the result to a 8x8 b32 -/// matrix (m8n8k4.row.col.f32.f32.f32.f32) -/// Requires the sub-group size of kernel calling this function to be 32 -/// \tparam 
[in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 8, 2, 2, 8 -/// \tparam [in] ItemT The type of the sycl::nd_item index space class -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] d4 The 5th element to be written to the output D matrix -/// \param [in] d5 The 6th element to be written to the output D matrix -/// \param [in] d6 The 7th element to be written to the output D matrix -/// \param [in] d7 The 8th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -/// \param [in] c4 The 5th element from C matrix to be added with d4 -/// \param [in] c5 The 6th element from C matrix to be added with d5 -/// \param [in] c6 The 7th element from C matrix to be added with d6 -/// \param [in] c7 The 8th element from C matrix to be added with d7 template -void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, CDType *d4, CDType *d5, - CDType *d6, CDType *d7, ABType a0, 
ABType a1, ABType b0, ABType b1, - CDType c0, CDType c1, CDType c2, CDType c3, CDType c4, CDType c5, - CDType c6, CDType c7) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2) + (lane % 2); - short COL_LOAD_OFFSET = 4 * ((lane % 16) / 4) + 2 * ((lane / 2) % 2); - - if (M == 8 && N == 8 && K == 4) { - ABType recv_a[2 * 2], recv_b[4 * 2]; - - for (int i = 0; i < 2; i++) { - recv_a[2 * i] = - dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + 2 * i); - recv_a[2 * i + 1] = - dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + 2 * i); - - recv_b[4 * i] = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 * i); - recv_b[4 * i + 1] = - dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 * i); - recv_b[4 * i + 2] = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 16 * i + 1); - recv_b[4 * i + 3] = - dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 16 * i + 1); - } - - MulType *ra = reinterpret_cast(recv_a); - MulType *rb = reinterpret_cast(recv_b); - for (int i = 0; i < 4; i++) { - c0 += static_cast(ra[i]) * static_cast(rb[i]); - c1 += static_cast(ra[i]) * static_cast(rb[i + 4]); - c2 += static_cast(ra[i + 4]) * static_cast(rb[i]); - c3 += static_cast(ra[i + 4]) * static_cast(rb[i + 4]); - c4 += static_cast(ra[i]) * static_cast(rb[i + 8]); - c5 += static_cast(ra[i]) * static_cast(rb[i + 12]); - c6 += static_cast(ra[i + 4]) * static_cast(rb[i + 8]); - c7 += static_cast(ra[i + 4]) * static_cast(rb[i + 12]); - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; - *d4 = c4; - *d5 = c5; - *d6 = c6; - *d7 = c7; -} - -/// Multiplies 2 8x4 & 4x8 f64 matrices and accumulates the result to a 8x8 b64 -/// matrix (m8n8k4.row.col.f64.f64.f64.f64). -/// Multiplies 2 8x16 & 16x8 u8/s8 matrices and accumulates the result to a 8x8 -/// s32 matrix (m8n8k16.row.col.s32.u8.u8.s32 / m8n8k16.row.col.s32.s8.s8.s32). 
-/// Multiplies 2 8x32 & 32x8 u4/s4 matrices and accumulates the result to a 8x8 -/// s32 matrix (m8n8k32.row.col.s32.u4.u4.s32 / m8n8k32.row.col.s32.s4.s4.s32). -/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 -/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc). -/// Multiplies 2 8x128 & 128x8 b1 matrices and accumulates the result to a 8x8 -/// s32 matrix (mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc). -/// Requires the sub-group size of kernel calling this function to be 32. -/// In: 2, 1, 1, 2 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -template , - typename ABType, typename CDType> -void mma(CDType *d0, CDType *d1, ABType a0, ABType b0, CDType c0, CDType c1, - Op op = Op{}) { + typename CDType, typename Op = sycl::bit_and<>> +void mma(CDType **d, ABType *a, ABType *b, CDType *c, Op op = Op{}) { auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); short ROW_LOAD_OFFSET = 4 * (lane >> 2); short COL_LOAD_OFFSET = 8 * (lane % 4); - if (M == 8 && N == 8 && K == 4) { - for (int i = 0; i < 4; i++) { - ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 
i); - c0 += recv_a * recv_b; - - recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - c1 += recv_a * recv_b; - } - } else if (M == 8 && N == 8 && K == 16) { - for (int i = 0; i < 4; i++) { - ABType recv_a = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - ABType recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - - MulType *a = reinterpret_cast(&recv_a); - MulType *b = reinterpret_cast(&recv_b); - - for (int k = 0; k < 4; k++) { - c0 += a[k] * b[k]; - } + if (M == 16 && N == 8 && K == 16) { + if constexpr (std::is_same_v) { + // f32.f16.f16.f32 + auto c_h = reinterpret_cast(c); - recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); + float c_f[4] = {c_h[0], c_h[1], c_h[2], c_h[3]}; - for (int k = 0; k < 4; k++) { - c1 += a[k] * b[k]; - } - } - } else if (M == 8 && N == 8 && K == 32) { - if constexpr (std::is_integral_v) { for (int i = 0; i < 4; i++) { - ABType recv_a = - dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - ABType recv_b = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - - MulType *a = reinterpret_cast(&recv_a); - MulType *b = reinterpret_cast(&recv_b); - - for (int k = 0; k < 4; k++) { - MulType a0 = a[k] >> 4; - MulType a1 = a[k] & 0x0F; - MulType b0 = b[k] >> 4; - MulType b1 = b[k] & 0x0F; - - c0 += a0 * b0; - c0 += a1 * b1; - } - - recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - for (int k = 0; k < 4; k++) { - MulType a0 = a[k] >> 4; - MulType a1 = a[k] & 0x0F; - MulType b0 = b[k] >> 4; - MulType b1 = b[k] & 0x0F; - - c1 += a0 * b0; - c1 += a1 * b1; + ABType recv_a[4], recv_b[4]; + + recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a[2], ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); + recv_a[3] = dpct::select_from_sub_group(sg, a[3], ROW_LOAD_OFFSET + i); + + recv_b[0] = dpct::select_from_sub_group(sg, b[0], 
COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i); + recv_b[2] = + dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4); + recv_b[3] = + dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + for (int j = 0; j < 4; j++) { + c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); + c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); + c_f[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); + c_f[3] += + static_cast(ra[j + 4]) * static_cast(rb[j + 4]); } } - } - } else if (M == 8 && N == 8 && K == 128) { - if constexpr (std::is_integral_v) { - for (int i = 0; i < 4; i++) { - ABType recv_a = - dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - ABType recv_b = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - c0 += sycl::popcount(op(recv_a, recv_b)); + c_h[0] = c_f[0]; + c_h[1] = c_f[1]; + c_h[2] = c_f[2]; + c_h[3] = c_f[3]; - recv_b = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c1 += sycl::popcount(op(recv_a, recv_b)); - } - } - } - - *d0 = c0; - *d1 = c1; -} - -/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a -/// 16x8 f16 matrix (m16n8k8.row.col.f16.f16.f16.f16) -/// Requires the sub-group size of kernel -/// calling this function to be 32 -/// In: 2, 2, 1, 2 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied 
with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -template -void mma(CDType *d0, CDType *d1, ABType a0, ABType a1, ABType b0, CDType c0, - CDType c1) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 8) { - auto c0_h = reinterpret_cast(&c0); - auto c1_h = reinterpret_cast(&c1); - - float c_f[4] = {c0_h[0], c0_h[1], c1_h[0], c1_h[1]}; - - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - auto ra = reinterpret_cast(recv_a); - auto rb = reinterpret_cast(recv_b); - - for (int j = 0; j < 2; j++) { - c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); - c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 2]); - c_f[2] += static_cast(ra[j + 2]) * static_cast(rb[j]); - c_f[3] += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); - } - } - - c0_h[0] = c_f[0]; - c0_h[1] = c_f[1]; - c1_h[0] = c_f[2]; - c1_h[1] = c_f[3]; - } - - *d0 = c0; - *d1 = c1; -} - -/// Multiplies 2 16x16 & 16x8 f16 matrices and accumulates the result to a 16x8 -/// f16 matrix (m16n8k16.row.col.f16.f16.f16.f16). 
-/// Requires the sub-group size of kernel calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 2, 4, 2, 2 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix -/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -template -void mma(volatile CDType *d0, volatile CDType *d1, ABType a0, ABType a1, - ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 16) { - auto c0_h = reinterpret_cast(&c0); - auto c1_h = reinterpret_cast(&c1); - - float c_f[4] = {c0_h[0], c0_h[1], c1_h[0], c1_h[1]}; - - for (int i = 0; i < 4; i++) { - ABType recv_a[4], recv_b[4]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_a[3] = 
dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); - - auto ra = reinterpret_cast(recv_a); - auto rb = reinterpret_cast(recv_b); - - for (int j = 0; j < 4; j++) { - c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); - c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); - c_f[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); - c_f[3] += static_cast(ra[j + 4]) * static_cast(rb[j + 4]); - } - } - - c0_h[0] = c_f[0]; - c0_h[1] = c_f[1]; - c1_h[0] = c_f[2]; - c1_h[1] = c_f[3]; - } + *d[0] = c[0]; + *d[1] = c[1]; + } else if constexpr (std::is_integral_v) { + // s32.s8.s8.s32 + ABType recv_a[4 * 2], recv_b[4 * 2]; - *d0 = c0; - *d1 = c1; -} - -/// Multiplies 2 16x8 & 8x8 u4/s4 matrices and accumulates the result to a 16x8 -/// f64 matrix (m16n8k8.row.col.f64.f64.f64.f64). -/// Multiplies 2 16x16 & 16x8 matrices and accumulates the result to a 16x8 b32 -/// matrix (m16n8k16.row.col.f32.f16.f16.f32). -/// Multiplies 2 16x32 & 32x8 u8/s8 matrices and accumulates the result to a -/// 16x8 b32 matrix (m16n8k32.row.col.s32.u8.u8.s32 / -/// m16n8k32.row.col.s32.s8.s8.s32). -/// Multiplies 2 16x64 & 64x8 u4/s4 matrices and -/// accumulates the result to a 16x8 b32 matrix (m16n8k64.row.col.s32.u4.u4.s32 -/// / m16n8k64.row.col.s32.s4.s4.s32). -/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 -/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc). -/// Multiplies 2 16x256 & 256x8 b1 matrices and accumulates the result to a 16x8 -/// s32 matrix (mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc). -/// Requires the sub-group size of kernel calling this function to be 32. 
-/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 4, 2, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix -/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template , - typename ABType, typename CDType> -void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType a2, ABType a3, ABType b0, ABType b1, CDType c0, CDType c1, - CDType c2, CDType c3, Op op = Op{}) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 8) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); 
- recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += recv_a[0] * recv_b[0]; - c1 += recv_a[0] * recv_b[1]; - c2 += recv_a[1] * recv_b[0]; - c3 += recv_a[1] * recv_b[1]; - } - - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); - - c0 += recv_a[0] * recv_b[0]; - c1 += recv_a[0] * recv_b[1]; - c2 += recv_a[1] * recv_b[0]; - c3 += recv_a[1] * recv_b[1]; - } - } else if (M == 16 && N == 8 && K == 16) { - for (int i = 0; i < 4; i++) { - ABType recv_a[4], recv_b[4]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_a[2] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_a[3] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[2] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + 4 + i); - recv_b[3] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + 4 + i); - - auto *ra0 = reinterpret_cast(recv_a); - auto *ra1 = reinterpret_cast(recv_a + 2); - auto *rb0 = reinterpret_cast(recv_b); - auto *rb1 = reinterpret_cast(recv_b + 2); - - // Iterate for k (i * j) times - for (int j = 0; j < 4; j++) { - auto a0 = static_cast(ra0[j]); - auto a1 = static_cast(ra1[j]); - auto b0 = static_cast(rb0[j]); - auto b1 = static_cast(rb1[j]); - - c0 += a0 * b0; - c1 += a0 * b1; - c2 += a1 * b0; - c3 += a1 * b1; - } - } - } else if (M == 
16 && N == 8 && K == 32) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - MulType *a = reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); + for (int i = 0; i < 4; i++) { + recv_a[i] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); + recv_a[i + 4] = + dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); - for (int k = 0; k < 4; k++) { - c0 += a[k] * b[k]; - c1 += a[k] * b[k + 4]; - c2 += a[k + 4] * b[k]; - c3 += a[k + 4] * b[k + 4]; + recv_b[i] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); + recv_b[i + 4] = + dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4); } - } - - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); MulType *a = reinterpret_cast(recv_a); MulType *b = reinterpret_cast(recv_b); - - for (int k = 0; k < 4; k++) { - c0 += a[k] * b[k]; - c1 += a[k] * b[k + 4]; - c2 += a[k + 4] * b[k]; - c3 += a[k + 4] * b[k + 4]; - } - } - } else if (M == 16 && N == 8 && K == 64) { - if constexpr (std::is_integral_v) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - MulType *a = 
reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); - - for (int k = 0; k < 4; k++) { - MulType a00 = a[k] >> 4; - MulType a01 = a[k] & 0x0F; - MulType a10 = a[k + 4] >> 4; - MulType a11 = a[k + 4] & 0x0F; - MulType b00 = b[k] >> 4; - MulType b01 = b[k] & 0x0F; - MulType b10 = b[k + 4] >> 4; - MulType b11 = b[k + 4] & 0x0F; - - c0 += a00 * b00; - c0 += a01 * b01; - - c1 += a00 * b10; - c1 += a01 * b11; - - c2 += a10 * b00; - c2 += a11 * b01; - - c3 += a10 * b10; - c3 += a11 * b11; - } - } - - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[1] = - dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); - - MulType *a = reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); - - for (int k = 0; k < 4; k++) { - MulType a00 = a[k] >> 4; - MulType a01 = a[k] & 0x0F; - MulType a10 = a[k + 4] >> 4; - MulType a11 = a[k + 4] & 0x0F; - MulType b00 = b[k] >> 4; - MulType b01 = b[k] & 0x0F; - MulType b10 = b[k + 4] >> 4; - MulType b11 = b[k + 4] & 0x0F; - - c0 += a00 * b00; - c0 += a01 * b01; - - c1 += a00 * b10; - c1 += a01 * b11; - - c2 += a10 * b00; - c2 += a11 * b01; - - c3 += a10 * b10; - c3 += a11 * b11; - } - } - } - } else if (M == 16 && N == 8 && K == 256) { - if constexpr (std::is_integral_v) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += sycl::popcount(op(recv_a[0], recv_b[0])); - c1 += sycl::popcount(op(recv_a[0], recv_b[1])); - c2 += sycl::popcount(op(recv_a[1], recv_b[0])); - c3 
+= sycl::popcount(op(recv_a[1], recv_b[1])); - } - - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[1] = - dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); - - c0 += sycl::popcount(op(recv_a[0], recv_b[0])); - c1 += sycl::popcount(op(recv_a[0], recv_b[1])); - c2 += sycl::popcount(op(recv_a[1], recv_b[0])); - c3 += sycl::popcount(op(recv_a[1], recv_b[1])); + for (int i = 0; i < 16; i++) { + c[0] += a[i] * b[i]; + c[1] += a[i] * b[i + 16]; + c[2] += a[i + 16] * b[i]; + c[3] += a[i + 16] * b[i + 16]; } - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - -/// Multiplies 2 16x16 & 16x8 f64 matrices and accumulates the result to a 16x8 -/// f64 matrix (m16n8k16.row.col.f64.f64.f64.f64) Requires the sub-group size of -/// kernel calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 8, 4, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] a2 The 3rd element from A matrix to be multiplied with B matrix -/// \param [in] a3 The 4th element from A matrix to be multiplied with B matrix -/// \param [in] 
a4 The 5th element from A matrix to be multiplied with B matrix -/// \param [in] a5 The 6th element from A matrix to be multiplied with B matrix -/// \param [in] a6 The 7th element from A matrix to be multiplied with B matrix -/// \param [in] a7 The 8th element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] b1 The 2nd element from B matrix to be multiplied with A matrix -/// \param [in] b2 The 3rd element from B matrix to be multiplied with A matrix -/// \param [in] b3 The 4th element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template -void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType a2, ABType a3, ABType a4, ABType a5, ABType a6, ABType a7, - ABType b0, ABType b1, ABType b2, ABType b3, CDType c0, CDType c1, - CDType c2, CDType c3) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 16) { - ABType recv_a[16 * 2], recv_b[16 * 2]; - - for (int i = 0; i < 4; i++) { - recv_a[i] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[i + 4] = dpct::select_from_sub_group(sg, a2, ROW_LOAD_OFFSET + i); - recv_a[i + 8] = dpct::select_from_sub_group(sg, a4, ROW_LOAD_OFFSET + i); - recv_a[i + 12] = dpct::select_from_sub_group(sg, a6, ROW_LOAD_OFFSET + i); - recv_a[i + 16] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_a[i + 20] = dpct::select_from_sub_group(sg, a3, ROW_LOAD_OFFSET + i); - recv_a[i + 24] = dpct::select_from_sub_group(sg, a5, ROW_LOAD_OFFSET + i); 
- recv_a[i + 28] = dpct::select_from_sub_group(sg, a7, ROW_LOAD_OFFSET + i); - - recv_b[i] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[i + 4] = dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i); - recv_b[i + 8] = dpct::select_from_sub_group(sg, b2, COL_LOAD_OFFSET + i); - recv_b[i + 12] = dpct::select_from_sub_group(sg, b3, COL_LOAD_OFFSET + i); - recv_b[i + 16] = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - recv_b[i + 20] = - dpct::select_from_sub_group(sg, b1, COL_LOAD_OFFSET + i + 4); - recv_b[i + 24] = - dpct::select_from_sub_group(sg, b2, COL_LOAD_OFFSET + i + 4); - recv_b[i + 28] = - dpct::select_from_sub_group(sg, b3, COL_LOAD_OFFSET + i + 4); - } - - for (int i = 0; i < 16; i++) { - c0 += recv_a[i] * recv_b[i]; - c1 += recv_a[i] * recv_b[i + 16]; - c2 += recv_a[i + 16] * recv_b[i]; - c3 += recv_a[i + 16] * recv_b[i + 16]; - } - } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; -} - -/// Multiplies 2 16x4 & 4x8 f16 matrices and accumulates the result to a -/// 16x8 f32 matrix (m16n8k4.row.col.f16.f16.f16.f16 / -/// m16n8k4.row.col.f32.f16.f16.f32). -/// Multiplies 2 16x4 & 4x8 f64 matrices and accumulates the result to a -/// 16x8 f64 matrix (m16n8k4.row.col.f64.f64.f64.f64). -/// Multiplies 2 16x8 & 8x8 f16 matrices and accumulates the result to a -/// 16x8 f32 matrix (m16n8k8.row.col.f32.f16.f16.f32). -/// Multiplies 2 16x8 & 8x8 f64 matrices and accumulates the result to a -/// 16x8 f64 matrix (m16n8k8.row.col.f64.f64.f64.f64). -/// Multiplies 2 16x16 & 16x8 u8/s8 matrices and accumulates the result to a -/// 16x8 s32 matrix (m16n8k16.row.col.s32.u8.u8.s32 / -/// m16n8k16.row.col.s32.s8.s8.s32). -/// Multiplies 2 16x32 & 32x8 u4/s4 matrices and accumulates the result to a -/// 16x8 s32 matrix (m16n8k32.row.col.s32.u4.u4.s32 / -/// m16n8k32.row.col.s32.s4.s4.s32). 
-/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 -/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc). -/// Multiplies 2 16x128 & 128x8 b1 matrices and accumulates the result to a 16x8 -/// s32 matrix (mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc). -/// Requires the sub-group size of kernel. -/// calling this function to be 32 -/// \tparam [in] M The rows of A/C/D matrix -/// \tparam [in] N The columns of B/C/D matrix -/// \tparam [in] K The columns/rows of A/B matrix -/// \tparam [in] MulType The type of the multiplication result -/// \tparam [in] ABType The type of the input matrices -/// \tparam [in] CDType The type of the output matrix -/// In: 4, 2, 1, 4 -/// \param [in] d0 The 1st element to be written to the output D matrix -/// \param [in] d1 The 2nd element to be written to the output D matrix -/// \param [in] d2 The 3rd element to be written to the output D matrix -/// \param [in] d3 The 4th element to be written to the output D matrix -/// \param [in] a0 The 1st element from A matrix to be multiplied with B matrix -/// \param [in] a1 The 2nd element from A matrix to be multiplied with B matrix -/// \param [in] b0 The 1st element from B matrix to be multiplied with A matrix -/// \param [in] c0 The 1st element from C matrix to be added with d0 -/// \param [in] c1 The 2nd element from C matrix to be added with d1 -/// \param [in] c2 The 3rd element from C matrix to be added with d2 -/// \param [in] c3 The 4th element from C matrix to be added with d3 -template , - typename ABType, typename CDType> -void mma(CDType *d0, CDType *d1, CDType *d2, CDType *d3, ABType a0, ABType a1, - ABType b0, CDType c0, CDType c1, CDType c2, CDType c3, Op op = Op{}) { - auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); - int lane = sg.get_local_linear_id(); - - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); - - if (M == 16 && N == 8 && K == 4) { - for (int i = 0; 
i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += recv_a[0] * recv_b[0]; - c1 += recv_a[0] * recv_b[1]; - c2 += recv_a[1] * recv_b[0]; - c3 += recv_a[1] * recv_b[1]; - } - } else if (M == 16 && N == 8 && K == 8) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - auto ra = reinterpret_cast(recv_a); - auto rb = reinterpret_cast(recv_b); - - for (int j = 0; j < 2; j++) { - c0 += static_cast(ra[j]) * static_cast(rb[j]); - c1 += static_cast(ra[j]) * static_cast(rb[j + 2]); - c2 += static_cast(ra[j + 2]) * static_cast(rb[j]); - c3 += static_cast(ra[j + 2]) * static_cast(rb[j + 2]); - } - } - } else if (M == 16 && N == 8 && K == 16) { - ABType recv_a[4 * 2], recv_b[4 * 2]; - for (int i = 0; i < 4; i++) { - recv_a[i] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[i + 4] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - - recv_b[i] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[i + 4] = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - } - - MulType *a = reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); - for (int i = 0; i < 16; i++) { - c0 += a[i] * b[i]; - c1 += a[i] * b[i + 16]; - c2 += a[i + 16] * b[i]; - c3 += a[i + 16] * b[i + 16]; - } - } else if (M == 16 && N == 8 && K == 32) { - if constexpr (std::is_integral_v) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - 
- recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - MulType *a = reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); - - for (int k = 0; k < 4; k++) { - MulType a00 = a[k] >> 4; - MulType a01 = a[k] & 0x0F; - MulType a10 = a[k + 4] >> 4; - MulType a11 = a[k + 4] & 0x0F; - MulType b00 = b[k] >> 4; - MulType b01 = b[k] & 0x0F; - MulType b10 = b[k + 4] >> 4; - MulType b11 = b[k + 4] & 0x0F; - - c0 += a00 * b00; - c0 += a01 * b01; - - c1 += a00 * b10; - c1 += a01 * b11; - - c2 += a10 * b00; - c2 += a11 * b01; - - c3 += a10 * b10; - c3 += a11 * b11; - } - } - } - } else if (M == 16 && N == 8 && K == 128) { - if constexpr (std::is_integral_v) { - for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; - - recv_a[0] = dpct::select_from_sub_group(sg, a0, ROW_LOAD_OFFSET + i); - recv_a[1] = dpct::select_from_sub_group(sg, a1, ROW_LOAD_OFFSET + i); - recv_b[0] = dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i); - recv_b[1] = - dpct::select_from_sub_group(sg, b0, COL_LOAD_OFFSET + i + 4); - - c0 += sycl::popcount(op(recv_a[0], recv_b[0])); - c1 += sycl::popcount(op(recv_a[0], recv_b[1])); - c2 += sycl::popcount(op(recv_a[1], recv_b[0])); - c3 += sycl::popcount(op(recv_a[1], recv_b[1])); - } + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; } } - - *d0 = c0; - *d1 = c1; - *d2 = c2; - *d3 = c3; } } // namespace matrix From 2a98af9416fef217ca829c0b7bdc353b8014990a Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Wed, 7 May 2025 16:20:37 +0800 Subject: [PATCH 3/9] Added new type logic for A & B matrix elements --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 51 ++++++++++----- clang/runtime/dpct-rt/include/dpct/math.hpp | 71 +++++++++++---------- clang/test/dpct/asm/mma.cu | 20 ++++-- 3 files 
changed, 91 insertions(+), 51 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 429192f9030c..ac0030d7deb7 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1390,7 +1390,7 @@ class SYCLGen : public SYCLGenBase { if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col) return SYCLGenError(); - // Only f16 type is supported for A and B matrix data + // Data types of D, A, B & C matrices respectively in the PTX instruction const auto *DType = dyn_cast(Inst->getType(0)); const auto *AType = dyn_cast(Inst->getType(1)); const auto *BType = dyn_cast(Inst->getType(2)); @@ -1418,15 +1418,18 @@ class SYCLGen : public SYCLGenBase { // Sizes of A & B matrices std::string M, N, K; - // Data type used to multiply A & B matrices - std::string MulType; + // Data types of A, B & C matrices respectively in the PTX arguments + std::string InMatrixType[3]; + if (Inst->hasAttr(InstAttr::m16n8k16)) { M = "16"; N = "8"; K = "16"; + // Only f16/s8 types are supported for A and B matrices of m16n8k16 if (AType->getKind() == InlineAsmBuiltinType::f16) { - MulType = "sycl::half"; + InMatrixType[0] = "int32_t"; // A type is .f16x2 + InMatrixType[1] = "int32_t"; // B type is .f16x2 // If A matrix type is f16, then C&D matrix types can only be f32 if (CType->getKind() == InlineAsmBuiltinType::f32) { @@ -1437,7 +1440,8 @@ class SYCLGen : public SYCLGenBase { } else return SYCLGenError(); } else if (AType->getKind() == InlineAsmBuiltinType::s8) { - MulType = "int8_t"; + InMatrixType[0] = "int32_t"; // A type is .s8x4 + InMatrixType[1] = "int32_t"; // B type is .s8x4 // If A matrix type is s8, then C&D matrix types can only be s32 if (CType->getKind() == InlineAsmBuiltinType::s32) { @@ -1452,6 +1456,8 @@ class SYCLGen : public SYCLGenBase { } else return SYCLGenError(); + InMatrixType[2] = CDType; + // Check the register sizes for vector elements of A, B, C & 
D matrices for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); InputOp++) { @@ -1465,13 +1471,9 @@ class SYCLGen : public SYCLGenBase { if (DMatVE->getNumElements() != NumVecElements[3]) return SYCLGenError(); - OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; - OS() << "<"; - OS() << M << ", " << N << ", " << K << ", "; - OS() << MulType; - OS() << ">("; - - // Add D matrix address values to store the MAD result + // Declare and init an array for storing the addresses of D matrix elements + OS() << "{\n"; + OS() << CDType << " *DMatrix_ct1[" << DMatVE->getNumElements() << "] = { "; for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) { if (isa(DMatVE->getElement(Inst))) continue; @@ -1481,25 +1483,44 @@ class SYCLGen : public SYCLGenBase { if ((Inst + 1) != DMatVE->getNumElements()) OS() << ", "; } + OS() << " }"; + endstmt(); - // Add A, B & C matrix values to compute MAD + // Declare and init vectors for storing the values of A, B & C matrix elements + std::string InMatrixName[3] = {"A", "B", "C"}; for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); InputOp++) { if (auto VE = dyn_cast(Inst->getInputOperand(InputOp))) { + OS() << "sycl::vec<" << InMatrixType[InputOp] << ", " << VE->getNumElements() << "> " << InMatrixName[InputOp] << "Matrix_ct1("; for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { if (isa(VE->getElement(Inst))) continue; - OS() << ", "; if (emitStmt(VE->getElement(Inst))) return SYCLGenError(); + if ((Inst + 1) != VE->getNumElements()) + OS() << ", "; } + OS() << ")"; + endstmt(); } else { return SYCLGenError(); } } - OS() << ");"; + OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; + OS() << "<"; + OS() << M << ", " << N << ", " << K << ", "; + OS() << ABType; + OS() << ">("; + + OS() << "DMatrix_ct1"; + for (int i = 0; i < 3; i++) + OS() << ", reinterpret_cast<" << InMatrixType[i] << " *>(&" << InMatrixName[i] << "Matrix_ct1)"; + OS() << ")"; + 
endstmt(); + OS() << "}"; + endstmt(); const auto *KernelDecl = getImmediateOuterFuncDecl(GAS); if (KernelDecl) { diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index c0c692219749..53965d12bd4d 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2218,6 +2218,22 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { ldmatrix(addr, m4, trans, 3); } +/// Multiplies 2 matrices (A & B) and adds the result to C matrix and +/// accumulates the result to a D matrix (MAD). Requires the sub-group size of +/// kernel calling this function to be 32. +/// \tparam [in] M The rows of A, C & D matrix +/// \tparam [in] N The columns of B, C, D matrix +/// \tparam [in] K The columns & rows of A & B matrices respectively +/// \tparam [in] MulType The type used to multiply A and B matrix elements as +/// \tparam [in] ABType The type of the input matrix (A & B) elements +/// \tparam [in] CDType The type of the output matrix (C & D) elements +/// \param [in] d The elements of the output D matrix to store the result to +/// \param [in] a The elements of the input A matrix to be multiplied with B +/// matrix elements +/// \param [in] b The elements of the input B matrix to be multiplied with A +/// matrix elements +/// \param [in] c The elements of the input C matrix to be added with the result +/// of A * B template > void mma(CDType **d, ABType *a, ABType *b, CDType *c, Op op = Op{}) { @@ -2228,12 +2244,8 @@ void mma(CDType **d, ABType *a, ABType *b, CDType *c, Op op = Op{}) { short COL_LOAD_OFFSET = 8 * (lane % 4); if (M == 16 && N == 8 && K == 16) { - if constexpr (std::is_same_v) { + if constexpr (std::is_floating_point_v) { // f32.f16.f16.f32 - auto c_h = reinterpret_cast(c); - - float c_f[4] = {c_h[0], c_h[1], c_h[2], c_h[3]}; - for (int i = 0; i < 4; i++) { ABType recv_a[4], recv_b[4]; @@ -2245,50 +2257,45 @@ void mma(CDType **d, ABType *a, 
ABType *b, CDType *c, Op op = Op{}) { recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); recv_b[1] = dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i); recv_b[2] = - dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4); + dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + 4 + i); recv_b[3] = - dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i + 4); + dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + 4 + i); auto ra = reinterpret_cast(recv_a); auto rb = reinterpret_cast(recv_b); for (int j = 0; j < 4; j++) { - c_f[0] += static_cast(ra[j]) * static_cast(rb[j]); - c_f[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); - c_f[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); - c_f[3] += - static_cast(ra[j + 4]) * static_cast(rb[j + 4]); + c[0] += static_cast(ra[j]) * static_cast(rb[j]); + c[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); + c[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); + c[3] += + static_cast(ra[j + 4]) * static_cast(rb[j + 4]); } } - c_h[0] = c_f[0]; - c_h[1] = c_f[1]; - c_h[2] = c_f[2]; - c_h[3] = c_f[3]; - *d[0] = c[0]; *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; } else if constexpr (std::is_integral_v) { // s32.s8.s8.s32 - ABType recv_a[4 * 2], recv_b[4 * 2]; - for (int i = 0; i < 4; i++) { - recv_a[i] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); - recv_a[i + 4] = - dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); + ABType recv_a[2], recv_b[2]; - recv_b[i] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); - recv_b[i + 4] = + recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4); - } - MulType *a = reinterpret_cast(recv_a); - MulType *b = reinterpret_cast(recv_b); - for (int i = 0; i < 16; i++) { - c[0] 
+= a[i] * b[i]; - c[1] += a[i] * b[i + 16]; - c[2] += a[i + 16] * b[i]; - c[3] += a[i + 16] * b[i + 16]; + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + for (int i = 0; i < 4; i++) { + c[0] += ra[i] * rb[i]; + c[1] += ra[i] * rb[i + 4]; + c[2] += ra[i + 4] * rb[i]; + c[3] += ra[i + 4] * rb[i + 4]; + } } *d[0] = c[0]; diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index 98508d39ae5f..5995d2040a11 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -14,15 +14,21 @@ As per PTX ASM 8.1, below is the status of supported configurations --------- --------- ---------- ----------- ------------- | Shape | | A | | B | | C / D | | Supported | --------- --------- ---------- ----------- ------------- -m16n8k16 .f16/.bf16 .f16/.bf16 .f16/.f32 Partial (.f16.f16.f16.f16 / .f32.f16.f16.f32) - .s8/.u8 .s8/.u8 .s32 Yes +m16n8k16 .f16 .f16 .f16/.f32 Yes + .s8 .s8 .s32 Yes A Layout: row B Layout: col */ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { - // CHECK: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(&fc[0], &fc[1], &fc[2], &fc[3], a[0], a[1], a[2], a[3], b[0], b[1], fc[0], fc[1], fc[2], fc[3]); + // CHECK: { + // CHECK-NEXT: float *DMatrix_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; + // CHECK-NEXT: sycl::vec AMatrix_ct1(a[0], a[1], a[2], a[3]); + // CHECK-NEXT: sycl::vec BMatrix_ct1(b[0], b[1]); + // CHECK-NEXT: sycl::vec CMatrix_ct1(fc[0], fc[1], fc[2], fc[3]); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(DMatrix_ct1, reinterpret_cast(&AMatrix_ct1), reinterpret_cast(&BMatrix_ct1), reinterpret_cast(&CMatrix_ct1)); + // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " " { %0, %1, %2, %3 }, " " { %4, %5, %6, %7 }, " @@ -32,7 +38,13 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), "r"(b[0]), "r"(b[1])); - // CHECK: 
dpct::experimental::matrix::mma<16, 8, 16, int8_t>(&d[0], &d[1], &d[2], &d[3], a[0], a[1], b[0], c[0], c[1], c[2], c[3]); + // CHECK: { + // CHECK-NEXT: int32_t *DMatrix_ct1[4] = { &d[0], &d[1], &d[2], &d[3] }; + // CHECK-NEXT: sycl::vec AMatrix_ct1(a[0], a[1]); + // CHECK-NEXT: sycl::vec BMatrix_ct1(b[0]); + // CHECK-NEXT: sycl::vec CMatrix_ct1(c[0], c[1], c[2], c[3]); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t>(DMatrix_ct1, reinterpret_cast(&AMatrix_ct1), reinterpret_cast(&BMatrix_ct1), reinterpret_cast(&CMatrix_ct1)); + // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " " { %0, %1, %2, %3 }, " " { %4, %5 }, " From d385cb3a6cc70eabb069af33dd4dc9cf5198c869 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Wed, 7 May 2025 16:48:23 +0800 Subject: [PATCH 4/9] Fixed format & addressed comments --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 10 +++++++--- clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp | 6 +++--- clang/runtime/dpct-rt/include/dpct/math.hpp | 8 ++++---- 3 files changed, 14 insertions(+), 10 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index ac0030d7deb7..6fc948763755 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1486,13 +1486,16 @@ class SYCLGen : public SYCLGenBase { OS() << " }"; endstmt(); - // Declare and init vectors for storing the values of A, B & C matrix elements + // Declare and init vectors for storing the values of A, B & C matrix + // elements std::string InMatrixName[3] = {"A", "B", "C"}; for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); InputOp++) { if (auto VE = dyn_cast(Inst->getInputOperand(InputOp))) { - OS() << "sycl::vec<" << InMatrixType[InputOp] << ", " << VE->getNumElements() << "> " << InMatrixName[InputOp] << "Matrix_ct1("; + OS() << "sycl::vec<" << InMatrixType[InputOp] << ", " + << VE->getNumElements() << "> " << InMatrixName[InputOp] + << 
"Matrix_ct1("; for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { if (isa(VE->getElement(Inst))) continue; @@ -1516,7 +1519,8 @@ class SYCLGen : public SYCLGenBase { OS() << "DMatrix_ct1"; for (int i = 0; i < 3; i++) - OS() << ", reinterpret_cast<" << InMatrixType[i] << " *>(&" << InMatrixName[i] << "Matrix_ct1)"; + OS() << ", reinterpret_cast<" << InMatrixType[i] << " *>(&" + << InMatrixName[i] << "Matrix_ct1)"; OS() << ")"; endstmt(); OS() << "}"; diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp index efc522477206..dc3b70b8373f 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp @@ -757,9 +757,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef Vec) { // Vector size must be 2, 4, or 8. switch (Vec.size()) { case 1: - Kind = InlineAsmVectorType::v1; - break; - case 2: + Kind = InlineAsmVectorType::v1; + break; + case 2: Kind = InlineAsmVectorType::v2; break; case 4: diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 53965d12bd4d..24c48cd40228 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2227,7 +2227,7 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { /// \tparam [in] MulType The type used to multiply A and B matrix elements as /// \tparam [in] ABType The type of the input matrix (A & B) elements /// \tparam [in] CDType The type of the output matrix (C & D) elements -/// \param [in] d The elements of the output D matrix to store the result to +/// \param [out] d The elements of the output D matrix to store the result to /// \param [in] a The elements of the input A matrix to be multiplied with B /// matrix elements /// \param [in] b The elements of the input B matrix to be multiplied with A @@ -2235,15 +2235,15 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { 
/// \param [in] c The elements of the input C matrix to be added with the result /// of A * B template > -void mma(CDType **d, ABType *a, ABType *b, CDType *c, Op op = Op{}) { + typename CDType> +void mma(CDType **d, ABType *a, ABType *b, CDType *c) { auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); short ROW_LOAD_OFFSET = 4 * (lane >> 2); short COL_LOAD_OFFSET = 8 * (lane % 4); - if (M == 16 && N == 8 && K == 16) { + if constexpr (M == 16 && N == 8 && K == 16) { if constexpr (std::is_floating_point_v) { // f32.f16.f16.f32 for (int i = 0; i < 4; i++) { From e5cf7361376a6db8eaa96d2b396df32dfeef4ef5 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Wed, 7 May 2025 17:24:03 +0800 Subject: [PATCH 5/9] Changed the interface to accept void * --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 7 +++---- clang/runtime/dpct-rt/include/dpct/math.hpp | 23 ++++++++++++++------- clang/test/dpct/asm/mma.cu | 4 ++-- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 6fc948763755..0f9c724df620 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1514,13 +1514,12 @@ class SYCLGen : public SYCLGenBase { OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; OS() << "<"; OS() << M << ", " << N << ", " << K << ", "; - OS() << ABType; + OS() << ABType << ", " << InMatrixType[0] << ", " << InMatrixType[2]; OS() << ">("; - OS() << "DMatrix_ct1"; + OS() << "reinterpret_cast(DMatrix_ct1)"; for (int i = 0; i < 3; i++) - OS() << ", reinterpret_cast<" << InMatrixType[i] << " *>(&" - << InMatrixName[i] << "Matrix_ct1)"; + OS() << ", &" << InMatrixName[i] << "Matrix_ct1"; OS() << ")"; endstmt(); OS() << "}"; diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 24c48cd40228..1893dd1e87af 100644 --- 
a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2227,25 +2227,33 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { /// \tparam [in] MulType The type used to multiply A and B matrix elements as /// \tparam [in] ABType The type of the input matrix (A & B) elements /// \tparam [in] CDType The type of the output matrix (C & D) elements -/// \param [out] d The elements of the output D matrix to store the result to -/// \param [in] a The elements of the input A matrix to be multiplied with B +/// \param [out] d_mat The elements of the output D matrix to store the result +/// of A* B + C +/// \param [in] a_mat The elements of the input A matrix to be multiplied with B /// matrix elements -/// \param [in] b The elements of the input B matrix to be multiplied with A +/// \param [in] b_mat The elements of the input B matrix to be multiplied with A /// matrix elements -/// \param [in] c The elements of the input C matrix to be added with the result -/// of A * B +/// \param [in] c_mat The elements of the input C matrix to be added with the +/// result of A * B template -void mma(CDType **d, ABType *a, ABType *b, CDType *c) { +void mma(void **d_mat, void *a_mat, void *b_mat, void *c_mat) { + auto d = reinterpret_cast(d_mat); + auto a = reinterpret_cast(a_mat); + auto b = reinterpret_cast(b_mat); + auto c = reinterpret_cast(c_mat); + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); + static_assert(M == 16 && N == 8 && K == 16, + "Only m16n8k16 shape is supported!"); + short ROW_LOAD_OFFSET = 4 * (lane >> 2); short COL_LOAD_OFFSET = 8 * (lane % 4); if constexpr (M == 16 && N == 8 && K == 16) { if constexpr (std::is_floating_point_v) { - // f32.f16.f16.f32 for (int i = 0; i < 4; i++) { ABType recv_a[4], recv_b[4]; @@ -2278,7 +2286,6 @@ void mma(CDType **d, ABType *a, ABType *b, CDType *c) { *d[2] = c[2]; *d[3] = c[3]; } else if constexpr 
(std::is_integral_v) { - // s32.s8.s8.s32 for (int i = 0; i < 4; i++) { ABType recv_a[2], recv_b[2]; diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index 5995d2040a11..c0d4ea92e049 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -27,7 +27,7 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { // CHECK-NEXT: sycl::vec AMatrix_ct1(a[0], a[1], a[2], a[3]); // CHECK-NEXT: sycl::vec BMatrix_ct1(b[0], b[1]); // CHECK-NEXT: sycl::vec CMatrix_ct1(fc[0], fc[1], fc[2], fc[3]); - // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half>(DMatrix_ct1, reinterpret_cast(&AMatrix_ct1), reinterpret_cast(&BMatrix_ct1), reinterpret_cast(&CMatrix_ct1)); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, int32_t, float>(reinterpret_cast(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1); // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " " { %0, %1, %2, %3 }, " @@ -43,7 +43,7 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { // CHECK-NEXT: sycl::vec AMatrix_ct1(a[0], a[1]); // CHECK-NEXT: sycl::vec BMatrix_ct1(b[0]); // CHECK-NEXT: sycl::vec CMatrix_ct1(c[0], c[1], c[2], c[3]); - // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t>(DMatrix_ct1, reinterpret_cast(&AMatrix_ct1), reinterpret_cast(&BMatrix_ct1), reinterpret_cast(&CMatrix_ct1)); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t, int32_t>(reinterpret_cast(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1); // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " " { %0, %1, %2, %3 }, " From 685a7b8a0912a4513907c2e31f803a5f221a8a58 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Thu, 8 May 2025 19:33:35 +0800 Subject: [PATCH 6/9] Refined comments --- clang/runtime/dpct-rt/include/dpct/math.hpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp 
b/clang/runtime/dpct-rt/include/dpct/math.hpp index 1893dd1e87af..f09ebf950aef 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2221,6 +2221,8 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { /// Multiplies 2 matrices (A & B) and adds the result to C matrix and /// accumulates the result to a D matrix (MAD). Requires the sub-group size of /// kernel calling this function to be 32. +/// Current supported shapes & types: +/// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32) /// \tparam [in] M The rows of A, C & D matrix /// \tparam [in] N The columns of B, C, D matrix /// \tparam [in] K The columns & rows of A & B matrices respectively From e915a0ba7ad62ae33bb96a3f0bd9613b1344bcba Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Fri, 9 May 2025 18:24:53 +0800 Subject: [PATCH 7/9] Added more inline commenst for loops and added volatile type to D matrix elements --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 5 +- clang/runtime/dpct-rt/include/dpct/math.hpp | 76 ++++++++++++++++----- clang/test/dpct/asm/mma.cu | 8 +-- 3 files changed, 65 insertions(+), 24 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 0f9c724df620..2ec246ede3f9 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1473,7 +1473,8 @@ class SYCLGen : public SYCLGenBase { // Declare and init an array for storing the addresses of D matrix elements OS() << "{\n"; - OS() << CDType << " *DMatrix_ct1[" << DMatVE->getNumElements() << "] = { "; + OS() << "volatile " << CDType << " *DMatrix_ct1[" + << DMatVE->getNumElements() << "] = { "; for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) { if (isa(DMatVE->getElement(Inst))) continue; @@ -1517,7 +1518,7 @@ class SYCLGen : public SYCLGenBase { OS() << ABType << ", " << InMatrixType[0] << ", " << InMatrixType[2]; OS() << ">("; - OS() << 
"reinterpret_cast(DMatrix_ct1)"; + OS() << "reinterpret_cast(DMatrix_ct1)"; for (int i = 0; i < 3; i++) OS() << ", &" << InMatrixName[i] << "Matrix_ct1"; OS() << ")"; diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index f09ebf950aef..7e98db8e12d7 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2226,7 +2226,7 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { /// \tparam [in] M The rows of A, C & D matrix /// \tparam [in] N The columns of B, C, D matrix /// \tparam [in] K The columns & rows of A & B matrices respectively -/// \tparam [in] MulType The type used to multiply A and B matrix elements as +/// \tparam [in] MulType The type of A and B matrices in MMA ASM instruction /// \tparam [in] ABType The type of the input matrix (A & B) elements /// \tparam [in] CDType The type of the output matrix (C & D) elements /// \param [out] d_mat The elements of the output D matrix to store the result @@ -2239,8 +2239,8 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { /// result of A * B template -void mma(void **d_mat, void *a_mat, void *b_mat, void *c_mat) { - auto d = reinterpret_cast(d_mat); +void mma(volatile void **d_mat, void *a_mat, void *b_mat, void *c_mat) { + auto d = reinterpret_cast(d_mat); auto a = reinterpret_cast(a_mat); auto b = reinterpret_cast(b_mat); auto c = reinterpret_cast(c_mat); @@ -2256,61 +2256,101 @@ void mma(void **d_mat, void *a_mat, void *b_mat, void *c_mat) { if constexpr (M == 16 && N == 8 && K == 16) { if constexpr (std::is_floating_point_v) { + // Init D matrix with elements of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Iterate through 4 neighbouring work items to gather the A & B matix + // elements of the rows and cols associated with each work item + // WI0: { row0: a0 .. a15 & row8: a0 .. a15 } and + // { col0: b0 .. 
b15 & col1: b0 .. b15 } for (int i = 0; i < 4; i++) { ABType recv_a[4], recv_b[4]; + // WI0 loads row0: { a0, a1 }, { a2, a3 }, { a4, a5 }, { a6, a7 } recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); + // WI0 loads row0: { a8, a9 }, { a10, a11 }, { a12, a13 }, { a14, a15 } recv_a[1] = dpct::select_from_sub_group(sg, a[2], ROW_LOAD_OFFSET + i); + // WI0 loads row8: { a0, a1 }, { a2, a3 }, { a4, a5 }, { a6, a7 } recv_a[2] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); + // WI0 loads row8: { a8, a9 }, { a10, a11 }, { a12, a13 }, { a14, a15 } recv_a[3] = dpct::select_from_sub_group(sg, a[3], ROW_LOAD_OFFSET + i); + // WI0 loads col0: { b0, b1 }, { b2, b3 }, { b4, b5 }, { b6, b7 } recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); + // WI0 loads col0: { b8, b9 }, { b10, b11 }, { b12, b13 }, { b14, b15 } recv_b[1] = dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i); + // WI0 loads col1: { b0, b1 }, { b2, b3 }, { b4, b5 }, { b6, b7 } recv_b[2] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + 4 + i); + // WI0 loads col1: { b8, b9 }, { b10, b11 }, { b12, b13 }, { b14, b15 } recv_b[3] = dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + 4 + i); auto ra = reinterpret_cast(recv_a); auto rb = reinterpret_cast(recv_b); + // Calculate partial product of 4 A & B matrix elements + // For each iteration of i, work-item calculates D matrix values as + // below: + // d0 += row0{ a0, a1, a8, a9 } * col0{ b0, b1, b8, b9 } + // d1 += row0{ a0, a1, a8, a9 } * col1{ b0, b1, b8, b9 } + // d2 += row8{ a0, a1, a8, a9 } * col0{ b0, b1, b8, b9 } + // d3 += row8{ a1, a1, a8, a9 } * col1{ b0, b1, b8, b9 } for (int j = 0; j < 4; j++) { - c[0] += static_cast(ra[j]) * static_cast(rb[j]); - c[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); - c[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); - c[3] += + *d[0] += static_cast(ra[j]) * static_cast(rb[j]); + *d[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); + *d[2] 
+= static_cast(ra[j + 4]) * static_cast(rb[j]); + *d[3] += static_cast(ra[j + 4]) * static_cast(rb[j + 4]); } } - + } else if constexpr (std::is_integral_v) { + // Init D matrix with elements of C matrix *d[0] = c[0]; *d[1] = c[1]; *d[2] = c[2]; *d[3] = c[3]; - } else if constexpr (std::is_integral_v) { + + // Iterate through 4 neighbouring work items to gather the A & B matix + // elements of the rows and cols associated with each work item + // WI0: { row0: a0 .. a15 & row8: a0 .. a15 } and + // { col0: b0 .. b15 & col1: b0 .. b15 } for (int i = 0; i < 4; i++) { ABType recv_a[2], recv_b[2]; + // WI0 loads row0: { a0, a1, a2, a3 }, { a4, a5, a6, a7 }, + // { a8, a9, a10, a11 }, { a12, a13, a14, a15 } recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); + // WI0 loads row8: { a0, a1, a2, a3 }, { a4, a5, a6, a7 }, + // { a8, a9, a10, a11 }, { a12, a13, a14, a15 } recv_a[1] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); + // WI0 loads col0: { b0, b1, b2, b3 }, { b4, b5, b6, b7 }, + // { b8, b9, b10, b11 }, { b12, b13, b14, b15 } recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); + // WI0 loads col1: { b0, b1, b2, b3 }, { b4, b5, b6, b7 }, + // { b8, b9, b10, b11 }, { b12, b13, b14, b15 } recv_b[1] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4); auto ra = reinterpret_cast(recv_a); auto rb = reinterpret_cast(recv_b); + + // Calculate partial product of 4 A & B matrix elements + // For each iteration of i, WI0 calculates D matrix values as below: + // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } + // d2 += row8{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d3 += row8{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } for (int i = 0; i < 4; i++) { - c[0] += ra[i] * rb[i]; - c[1] += ra[i] * rb[i + 4]; - c[2] += ra[i + 4] * rb[i]; - c[3] += ra[i + 4] * rb[i + 4]; + *d[0] += ra[i] * rb[i]; + *d[1] += ra[i] * rb[i + 4]; + *d[2] += ra[i + 4] * 
rb[i]; + *d[3] += ra[i + 4] * rb[i + 4]; } } - - *d[0] = c[0]; - *d[1] = c[1]; - *d[2] = c[2]; - *d[3] = c[3]; } } } diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index c0d4ea92e049..7f3f1fd1492e 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -23,11 +23,11 @@ B Layout: col __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { // CHECK: { - // CHECK-NEXT: float *DMatrix_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; + // CHECK-NEXT: volatile float *DMatrix_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; // CHECK-NEXT: sycl::vec AMatrix_ct1(a[0], a[1], a[2], a[3]); // CHECK-NEXT: sycl::vec BMatrix_ct1(b[0], b[1]); // CHECK-NEXT: sycl::vec CMatrix_ct1(fc[0], fc[1], fc[2], fc[3]); - // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, int32_t, float>(reinterpret_cast(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, int32_t, float>(reinterpret_cast(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1); // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " " { %0, %1, %2, %3 }, " @@ -39,11 +39,11 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { "r"(b[0]), "r"(b[1])); // CHECK: { - // CHECK-NEXT: int32_t *DMatrix_ct1[4] = { &d[0], &d[1], &d[2], &d[3] }; + // CHECK-NEXT: volatile int32_t *DMatrix_ct1[4] = { &d[0], &d[1], &d[2], &d[3] }; // CHECK-NEXT: sycl::vec AMatrix_ct1(a[0], a[1]); // CHECK-NEXT: sycl::vec BMatrix_ct1(b[0]); // CHECK-NEXT: sycl::vec CMatrix_ct1(c[0], c[1], c[2], c[3]); - // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t, int32_t>(reinterpret_cast(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t, int32_t>(reinterpret_cast(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1); // CHECK-NEXT: } 
asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " " { %0, %1, %2, %3 }, " From 7527809da0b2088be929570faec1941af60372d5 Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Mon, 12 May 2025 14:48:46 +0800 Subject: [PATCH 8/9] Merged MulType and ABType into 1 --- clang/lib/DPCT/RulesAsm/AsmMigration.cpp | 20 +-- clang/runtime/dpct-rt/include/dpct/math.hpp | 141 +++++++++++--------- clang/test/dpct/asm/mma.cu | 20 +-- 3 files changed, 95 insertions(+), 86 deletions(-) diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 2ec246ede3f9..32d94724fe0b 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -1428,8 +1428,8 @@ class SYCLGen : public SYCLGenBase { // Only f16/s8 types are supported for A and B matrices of m16n8k16 if (AType->getKind() == InlineAsmBuiltinType::f16) { - InMatrixType[0] = "int32_t"; // A type is .f16x2 - InMatrixType[1] = "int32_t"; // B type is .f16x2 + InMatrixType[0] = "uint32_t"; // A type is .f16x2 + InMatrixType[1] = "uint32_t"; // B type is .f16x2 // If A matrix type is f16, then C&D matrix types can only be f32 if (CType->getKind() == InlineAsmBuiltinType::f32) { @@ -1440,8 +1440,8 @@ class SYCLGen : public SYCLGenBase { } else return SYCLGenError(); } else if (AType->getKind() == InlineAsmBuiltinType::s8) { - InMatrixType[0] = "int32_t"; // A type is .s8x4 - InMatrixType[1] = "int32_t"; // B type is .s8x4 + InMatrixType[0] = "uint32_t"; // A type is .f16x2 + InMatrixType[1] = "uint32_t"; // B type is .f16x2 // If A matrix type is s8, then C&D matrix types can only be s32 if (CType->getKind() == InlineAsmBuiltinType::s32) { @@ -1473,7 +1473,7 @@ class SYCLGen : public SYCLGenBase { // Declare and init an array for storing the addresses of D matrix elements OS() << "{\n"; - OS() << "volatile " << CDType << " *DMatrix_ct1[" + OS() << "volatile " << CDType << " *d_mat_frag_ct1[" << DMatVE->getNumElements() << "] = { "; for (unsigned Inst = 
0; Inst != DMatVE->getNumElements(); ++Inst) { if (isa(DMatVE->getElement(Inst))) @@ -1489,14 +1489,14 @@ class SYCLGen : public SYCLGenBase { // Declare and init vectors for storing the values of A, B & C matrix // elements - std::string InMatrixName[3] = {"A", "B", "C"}; + std::string InMatrixName[3] = {"a", "b", "c"}; for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); InputOp++) { if (auto VE = dyn_cast(Inst->getInputOperand(InputOp))) { OS() << "sycl::vec<" << InMatrixType[InputOp] << ", " << VE->getNumElements() << "> " << InMatrixName[InputOp] - << "Matrix_ct1("; + << "_mat_frag_ct1("; for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { if (isa(VE->getElement(Inst))) continue; @@ -1515,12 +1515,12 @@ class SYCLGen : public SYCLGenBase { OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; OS() << "<"; OS() << M << ", " << N << ", " << K << ", "; - OS() << ABType << ", " << InMatrixType[0] << ", " << InMatrixType[2]; + OS() << ABType << ", " << CDType; OS() << ">("; - OS() << "reinterpret_cast(DMatrix_ct1)"; + OS() << "reinterpret_cast(d_mat_frag_ct1)"; for (int i = 0; i < 3; i++) - OS() << ", &" << InMatrixName[i] << "Matrix_ct1"; + OS() << ", &" << InMatrixName[i] << "_mat_frag_ct1"; OS() << ")"; endstmt(); OS() << "}"; diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 7e98db8e12d7..0bbe40cee90f 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2218,32 +2218,42 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { ldmatrix(addr, m4, trans, 3); } -/// Multiplies 2 matrices (A & B) and adds the result to C matrix and -/// accumulates the result to a D matrix (MAD). Requires the sub-group size of -/// kernel calling this function to be 32. 
+/// A helper struct that defines the pack type for the input matrix fragments +/// of mma() function based on the type of input matrix fragments. +/// The MMAType struct is specialized for different types of input matrices. +/// Currently, the specialization for f16 and s8 types is defined below. +/// \tparam [in] T The type of the input matrix fragments +template struct MMAType { + using PackType = uint32_t; +}; + +/// Each work item of a sub-group (limited to size 32) calling this function +/// calculates a subset fragment for the output D matrix using MAD operation on +/// A, B & C matrix fragments (D = A * B + C). /// Current supported shapes & types: /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32) +/// Here, m, n & k define the shapes of A, B & C matrices respectively +/// (A = [m x k], B = [k x n], C = [m x n]). /// \tparam [in] M The rows of A, C & D matrix /// \tparam [in] N The columns of B, C, D matrix /// \tparam [in] K The columns & rows of A & B matrices respectively -/// \tparam [in] MulType The type of A and B matrices in MMA ASM instruction -/// \tparam [in] ABType The type of the input matrix (A & B) elements -/// \tparam [in] CDType The type of the output matrix (C & D) elements -/// \param [out] d_mat The elements of the output D matrix to store the result -/// of A* B + C -/// \param [in] a_mat The elements of the input A matrix to be multiplied with B -/// matrix elements -/// \param [in] b_mat The elements of the input B matrix to be multiplied with A -/// matrix elements -/// \param [in] c_mat The elements of the input C matrix to be added with the -/// result of A * B -template -void mma(volatile void **d_mat, void *a_mat, void *b_mat, void *c_mat) { - auto d = reinterpret_cast(d_mat); - auto a = reinterpret_cast(a_mat); - auto b = reinterpret_cast(b_mat); - auto c = reinterpret_cast(c_mat); +/// \tparam [in] ABType The type of the input matrix (A & B) fragment +/// \tparam [in] CDType The type of the output matrix (C & D) fragment +/// 
\param [out] d_mat_frag The fragment of the output D matrix to store the +/// result of A * B + C +/// \param [in] a_mat_frag The fragment of the input A matrix to be multiplied +/// with B matrix fragment +/// \param [in] b_mat_frag The fragment of the input B matrix to be multiplied +/// with A matrix fragment +/// \param [in] c_mat_frag The fragment of the input C matrix to be added with +/// the result of A * B fragments +template +void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, + void *c_mat_frag) { + auto d = reinterpret_cast(d_mat_frag); + auto a = reinterpret_cast::PackType *>(a_mat_frag); + auto b = reinterpret_cast::PackType *>(b_mat_frag); + auto c = reinterpret_cast(c_mat_frag); auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); int lane = sg.get_local_linear_id(); @@ -2256,49 +2266,50 @@ void mma(volatile void **d_mat, void *a_mat, void *b_mat, void *c_mat) { if constexpr (M == 16 && N == 8 && K == 16) { if constexpr (std::is_floating_point_v) { - // Init D matrix with elements of C matrix + // Init D matrix fragment with C matrix fragment *d[0] = c[0]; *d[1] = c[1]; *d[2] = c[2]; *d[3] = c[3]; - // Iterate through 4 neighbouring work items to gather the A & B matix - // elements of the rows and cols associated with each work item - // WI0: { row0: a0 .. a15 & row8: a0 .. a15 } and - // { col0: b0 .. b15 & col1: b0 .. 
b15 } + // Each work item Wi (i=0...31) gathers 2 row & 2 col matrix fragments + // of length k (8) from A & B matrices respectively into recv_a & recv_b + // across 4 iterations using 4 neighboring work items with below mapping + // logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 for (int i = 0; i < 4; i++) { - ABType recv_a[4], recv_b[4]; + typename MMAType::PackType recv_a[4], recv_b[4]; - // WI0 loads row0: { a0, a1 }, { a2, a3 }, { a4, a5 }, { a6, a7 } + // Load partial fragment from row0 of matrix A ({a0, a1}) recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); - // WI0 loads row0: { a8, a9 }, { a10, a11 }, { a12, a13 }, { a14, a15 } + // Load partial fragment from row0 of matrix A ({a2, a3}) recv_a[1] = dpct::select_from_sub_group(sg, a[2], ROW_LOAD_OFFSET + i); - // WI0 loads row8: { a0, a1 }, { a2, a3 }, { a4, a5 }, { a6, a7 } + // Load partial fragment from row1 of matrix A ({a0, a1}) recv_a[2] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); - // WI0 loads row8: { a8, a9 }, { a10, a11 }, { a12, a13 }, { a14, a15 } + // Load partial fragment from row1 of matrix A ({a2, a3}) recv_a[3] = dpct::select_from_sub_group(sg, a[3], ROW_LOAD_OFFSET + i); - // WI0 loads col0: { b0, b1 }, { b2, b3 }, { b4, b5 }, { b6, b7 } + // Load partial fragment from col0 of matrix B ({b0, b1}) recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); - // WI0 loads col0: { b8, b9 }, { b10, b11 }, { b12, b13 }, { b14, b15 } + // Load partial fragment from col0 of matrix B ({b2, b3}) recv_b[1] = dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i); - // WI0 loads col1: { b0, b1 }, { b2, b3 }, { b4, b5 }, { b6, b7 } + // Load partial fragment from col1 of matrix B ({b0, b1}) recv_b[2] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + 4 + i); - // WI0 loads col1: { b8, b9 }, { b10, b11 }, { b12, b13 }, { b14, b15 } + // Load partial fragment from col1 of matrix 
B ({b2, b3}) recv_b[3] = dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + 4 + i); - auto ra = reinterpret_cast(recv_a); - auto rb = reinterpret_cast(recv_b); + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); - // Calculate partial product of 4 A & B matrix elements - // For each iteration of i, work-item calculates D matrix values as - // below: - // d0 += row0{ a0, a1, a8, a9 } * col0{ b0, b1, b8, b9 } - // d1 += row0{ a0, a1, a8, a9 } * col1{ b0, b1, b8, b9 } - // d2 += row8{ a0, a1, a8, a9 } * col0{ b0, b1, b8, b9 } - // d3 += row8{ a1, a1, a8, a9 } * col1{ b0, b1, b8, b9 } + // Each work item calculates a partial product of A & B matrix fragments + // and adds it to the corresponding D matrix fragment + // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } + // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d3 += row1{ a1, a1, a2, a3 } * col1{ b0, b1, b2, b3 } for (int j = 0; j < 4; j++) { *d[0] += static_cast(ra[j]) * static_cast(rb[j]); *d[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); @@ -2307,43 +2318,41 @@ void mma(volatile void **d_mat, void *a_mat, void *b_mat, void *c_mat) { static_cast(ra[j + 4]) * static_cast(rb[j + 4]); } } - } else if constexpr (std::is_integral_v) { - // Init D matrix with elements of C matrix + } else if constexpr (std::is_integral_v) { + // Init D matrix with fragments of C matrix *d[0] = c[0]; *d[1] = c[1]; *d[2] = c[2]; *d[3] = c[3]; - // Iterate through 4 neighbouring work items to gather the A & B matix - // elements of the rows and cols associated with each work item - // WI0: { row0: a0 .. a15 & row8: a0 .. a15 } and - // { col0: b0 .. b15 & col1: b0 .. 
b15 } + // Each work item Wi (i=0...31) gathers 2 row & 2 col matrix fragments + // of length k (8) from A & B matrices respectively into recv_a & recv_b + // across 4 iterations using 4 neighboring work items with below mapping + // logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 for (int i = 0; i < 4; i++) { - ABType recv_a[2], recv_b[2]; + typename MMAType::PackType recv_a[2], recv_b[2]; - // WI0 loads row0: { a0, a1, a2, a3 }, { a4, a5, a6, a7 }, - // { a8, a9, a10, a11 }, { a12, a13, a14, a15 } + // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); - // WI0 loads row8: { a0, a1, a2, a3 }, { a4, a5, a6, a7 }, - // { a8, a9, a10, a11 }, { a12, a13, a14, a15 } + // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7}) recv_a[1] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); - // WI0 loads col0: { b0, b1, b2, b3 }, { b4, b5, b6, b7 }, - // { b8, b9, b10, b11 }, { b12, b13, b14, b15 } + // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); - // WI0 loads col1: { b0, b1, b2, b3 }, { b4, b5, b6, b7 }, - // { b8, b9, b10, b11 }, { b12, b13, b14, b15 } + // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7}) recv_b[1] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4); - auto ra = reinterpret_cast(recv_a); - auto rb = reinterpret_cast(recv_b); + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); - // Calculate partial product of 4 A & B matrix elements - // For each iteration of i, WI0 calculates D matrix values as below: + // Each work item calculates a partial product of A & B matrix fragments + // and adds it to the corresponding D matrix fragment // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } - // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } - // 
d2 += row8{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } - // d3 += row8{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } + // d1 += row0{ a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } + // d2 += row1{ a4, a5, a6, a7 } * col0{ b0, b1, b2, b3 } + // d3 += row1{ a4, a5, a6, a7 } * col1{ b4, b5, b6, b7 } for (int i = 0; i < 4; i++) { *d[0] += ra[i] * rb[i]; *d[1] += ra[i] * rb[i + 4]; diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index 7f3f1fd1492e..dc7ab8d9942d 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -23,11 +23,11 @@ B Layout: col __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { // CHECK: { - // CHECK-NEXT: volatile float *DMatrix_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; - // CHECK-NEXT: sycl::vec AMatrix_ct1(a[0], a[1], a[2], a[3]); - // CHECK-NEXT: sycl::vec BMatrix_ct1(b[0], b[1]); - // CHECK-NEXT: sycl::vec CMatrix_ct1(fc[0], fc[1], fc[2], fc[3]); - // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, int32_t, float>(reinterpret_cast(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1); + // CHECK-NEXT: volatile float *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; + // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1], a[2], a[3]); + // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0], b[1]); + // CHECK-NEXT: sycl::vec c_mat_frag_ct1(fc[0], fc[1], fc[2], fc[3]); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, int32_t, float>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " " { %0, %1, %2, %3 }, " @@ -39,11 +39,11 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { "r"(b[0]), "r"(b[1])); // CHECK: { - // CHECK-NEXT: volatile int32_t *DMatrix_ct1[4] = { &d[0], &d[1], &d[2], &d[3] }; - // CHECK-NEXT: sycl::vec AMatrix_ct1(a[0], a[1]); - // CHECK-NEXT: sycl::vec BMatrix_ct1(b[0]); - // CHECK-NEXT: 
sycl::vec CMatrix_ct1(c[0], c[1], c[2], c[3]); - // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t, int32_t>(reinterpret_cast(DMatrix_ct1), &AMatrix_ct1, &BMatrix_ct1, &CMatrix_ct1); + // CHECK-NEXT: volatile int32_t *d_mat_frag_ct1[4] = { &d[0], &d[1], &d[2], &d[3] }; + // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1]); + // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0]); + // CHECK-NEXT: sycl::vec c_mat_frag_ct1(c[0], c[1], c[2], c[3]); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t, int32_t>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " " { %0, %1, %2, %3 }, " From eced50688548053b4a71cdf06f909f489908c4dd Mon Sep 17 00:00:00 2001 From: Teja Alaghari Date: Tue, 13 May 2025 18:18:01 +0800 Subject: [PATCH 9/9] Added comments to describe the algo better --- clang/runtime/dpct-rt/include/dpct/math.hpp | 68 ++++++++++++--------- clang/test/dpct/asm/mma.cu | 12 ++-- 2 files changed, 44 insertions(+), 36 deletions(-) diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 0bbe40cee90f..67859e68ef69 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -2228,24 +2228,24 @@ template struct MMAType { }; /// Each work item of a sub-group (limited to size 32) calling this function -/// calculates a subset fragment for the output D matrix using MAD operation on +/// calculates a subset fragment for the output matrix D using MAD operation on /// A, B & C matrix fragments (D = A * B + C). /// Current supported shapes & types: /// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32) /// Here, m, n & k define the shapes of A, B & C matrices respectively /// (A = [m x k], B = [k x n], C = [m x n]). 
-/// \tparam [in] M The rows of A, C & D matrix -/// \tparam [in] N The columns of B, C, D matrix +/// \tparam [in] M The rows of A, C & D matrices +/// \tparam [in] N The columns of B, C, D matrices /// \tparam [in] K The columns & rows of A & B matrices respectively /// \tparam [in] ABType The type of the input matrix (A & B) fragment /// \tparam [in] CDType The type of the output matrix (C & D) fragment -/// \param [out] d_mat_frag The fragment of the output D matrix to store the +/// \param [out] d_mat_frag The fragment of the output matrix D to store the /// result of A * B + C -/// \param [in] a_mat_frag The fragment of the input A matrix to be multiplied +/// \param [in] a_mat_frag The fragment of the input matrix A to be multiplied /// with B matrix fragment -/// \param [in] b_mat_frag The fragment of the input B matrix to be multiplied +/// \param [in] b_mat_frag The fragment of the input matrix B to be multiplied /// with A matrix fragment -/// \param [in] c_mat_frag The fragment of the input C matrix to be added with +/// \param [in] c_mat_frag The fragment of the input matrix C to be added with /// the result of A * B fragments template void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, @@ -2261,8 +2261,8 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, static_assert(M == 16 && N == 8 && K == 16, "Only m16n8k16 shape is supported!"); - short ROW_LOAD_OFFSET = 4 * (lane >> 2); - short COL_LOAD_OFFSET = 8 * (lane % 4); + short row_load_offset = 4 * (lane >> 2); + short col_load_offset = 8 * (lane % 4); if constexpr (M == 16 && N == 8 && K == 16) { if constexpr (std::is_floating_point_v) { @@ -2272,34 +2272,38 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, *d[2] = c[2]; *d[3] = c[3]; - // Each work item Wi (i=0...31) gathers 2 row & 2 col matrix fragments - // of length k (8) from A & B matrices respectively into recv_a & recv_b - // across 4 iterations using 4 neighboring work 
items with below mapping - // logic: + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment of + // matrix A (row) and matrix B (col) using the row & col offsets. for (int i = 0; i < 4; i++) { typename MMAType::PackType recv_a[4], recv_b[4]; // Load partial fragment from row0 of matrix A ({a0, a1}) - recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); + recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i); // Load partial fragment from row0 of matrix A ({a2, a3}) - recv_a[1] = dpct::select_from_sub_group(sg, a[2], ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a[2], row_load_offset + i); // Load partial fragment from row1 of matrix A ({a0, a1}) - recv_a[2] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); + recv_a[2] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i); // Load partial fragment from row1 of matrix A ({a2, a3}) - recv_a[3] = dpct::select_from_sub_group(sg, a[3], ROW_LOAD_OFFSET + i); + recv_a[3] = dpct::select_from_sub_group(sg, a[3], row_load_offset + i); // Load partial fragment from col0 of matrix B ({b0, b1}) - recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i); // Load partial fragment from col0 of matrix B ({b2, b3}) - recv_b[1] = dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + i); + recv_b[1] = dpct::select_from_sub_group(sg, b[1], col_load_offset + i); // Load partial fragment from 
col1 of matrix B ({b0, b1}) recv_b[2] = - dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + 4 + i); + dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i); // Load partial fragment from col1 of matrix B ({b2, b3}) recv_b[3] = - dpct::select_from_sub_group(sg, b[1], COL_LOAD_OFFSET + 4 + i); + dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i); auto ra = reinterpret_cast(recv_a); auto rb = reinterpret_cast(recv_b); @@ -2309,7 +2313,7 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } - // d3 += row1{ a1, a1, a2, a3 } * col1{ b0, b1, b2, b3 } + // d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } for (int j = 0; j < 4; j++) { *d[0] += static_cast(ra[j]) * static_cast(rb[j]); *d[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); @@ -2325,24 +2329,28 @@ void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, *d[2] = c[2]; *d[3] = c[3]; - // Each work item Wi (i=0...31) gathers 2 row & 2 col matrix fragments - // of length k (8) from A & B matrices respectively into recv_a & recv_b - // across 4 iterations using 4 neighboring work items with below mapping - // logic: + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment of + // matrix A (row) and matrix B (col) using the row & col offsets. 
for (int i = 0; i < 4; i++) { typename MMAType::PackType recv_a[2], recv_b[2]; // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) - recv_a[0] = dpct::select_from_sub_group(sg, a[0], ROW_LOAD_OFFSET + i); + recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i); // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7}) - recv_a[1] = dpct::select_from_sub_group(sg, a[1], ROW_LOAD_OFFSET + i); + recv_a[1] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i); // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) - recv_b[0] = dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i); + recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i); // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7}) recv_b[1] = - dpct::select_from_sub_group(sg, b[0], COL_LOAD_OFFSET + i + 4); + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); auto ra = reinterpret_cast(recv_a); auto rb = reinterpret_cast(recv_b); diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu index dc7ab8d9942d..3186623b3ec6 100644 --- a/clang/test/dpct/asm/mma.cu +++ b/clang/test/dpct/asm/mma.cu @@ -24,10 +24,10 @@ B Layout: col __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { // CHECK: { // CHECK-NEXT: volatile float *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; - // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1], a[2], a[3]); - // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0], b[1]); + // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1], a[2], a[3]); + // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0], b[1]); // CHECK-NEXT: sycl::vec c_mat_frag_ct1(fc[0], fc[1], fc[2], fc[3]); - // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, int32_t, float>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, float>(reinterpret_cast(d_mat_frag_ct1), 
&a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " " { %0, %1, %2, %3 }, " @@ -40,10 +40,10 @@ __global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { // CHECK: { // CHECK-NEXT: volatile int32_t *d_mat_frag_ct1[4] = { &d[0], &d[1], &d[2], &d[3] }; - // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1]); - // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0]); + // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1]); + // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0]); // CHECK-NEXT: sycl::vec c_mat_frag_ct1(c[0], c[1], c[2], c[3]); - // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t, int32_t>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); // CHECK-NEXT: } asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " " { %0, %1, %2, %3 }, "