diff --git a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp index 2f940f84d7f9..32d94724fe0b 100644 --- a/clang/lib/DPCT/RulesAsm/AsmMigration.cpp +++ b/clang/lib/DPCT/RulesAsm/AsmMigration.cpp @@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) { bool SYCLGenBase::emitBuiltinType(const InlineAsmBuiltinType *T) { switch (T->getKind()) { // clang-format off + case InlineAsmBuiltinType::b1: OS() << "uint8_t"; break; case InlineAsmBuiltinType::b8: OS() << "uint8_t"; break; case InlineAsmBuiltinType::b16: OS() << "uint16_t"; break; case InlineAsmBuiltinType::b32: OS() << "uint32_t"; break; case InlineAsmBuiltinType::b64: OS() << "uint64_t"; break; + case InlineAsmBuiltinType::u4: OS() << "uint8_t"; break; case InlineAsmBuiltinType::u8: OS() << "uint8_t"; break; case InlineAsmBuiltinType::u16: OS() << "uint16_t"; break; case InlineAsmBuiltinType::u32: OS() << "uint32_t"; break; case InlineAsmBuiltinType::u64: OS() << "uint64_t"; break; + case InlineAsmBuiltinType::s4: OS() << "int8_t"; break; case InlineAsmBuiltinType::s8: OS() << "int8_t"; break; case InlineAsmBuiltinType::s16: OS() << "int16_t"; break; case InlineAsmBuiltinType::s32: OS() << "int32_t"; break; @@ -559,6 +562,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) { case InlineAsmVectorType::x1: OS() << 1; break; + case InlineAsmVectorType::v1: + OS() << 1; + break; case InlineAsmVectorType::v2: case InlineAsmVectorType::x2: OS() << 2; @@ -1370,6 +1376,167 @@ class SYCLGen : public SYCLGenBase { return SYCLGenSuccess(); } + bool handle_mma(const InlineAsmInstruction *Inst) override { + if (Inst->getNumInputOperands() != 3) + return SYCLGenError(); + + const InlineAsmVectorExpr *DMatVE = + dyn_cast(Inst->getOutputOperand()); + if (!DMatVE) + return SYCLGenError(); + + // Only row Layout is supported for of A matrix and + // only col Layout is supported for of B matrix + if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col) + return SYCLGenError(); + + // Data types of D, A, B & C matrices respectively in the PTX instruction + const auto *DType = dyn_cast(Inst->getType(0)); + const auto *AType = dyn_cast(Inst->getType(1)); + const auto *BType = dyn_cast(Inst->getType(2)); + const auto *CType = dyn_cast(Inst->getType(3)); + + if (!(AType && BType && CType && DType)) + return SYCLGenError(); + + // Data types of matrix elements for A&B and C&D matrices should be same + if ((AType->getKind() != BType->getKind()) || + (CType->getKind() != DType->getKind())) + return SYCLGenError(); + + // Check the validity of AB & CD types + std::string ABType, CDType; + if (tryEmitType(ABType, AType)) + return SYCLGenError(); + + if (tryEmitType(CDType, CType)) + return SYCLGenError(); + + // Register sizes for vector elements of A, B, C & D matrices + unsigned NumVecElements[4] = {0}; + + // Sizes of A & B matrices + std::string M, N, K; + + // Data types of A, B & C matrices respectively in the PTX arguments + std::string InMatrixType[3]; + + if (Inst->hasAttr(InstAttr::m16n8k16)) { + M = "16"; + N = "8"; + K = "16"; + + // Only f16/s8 types are supported for A and B matrices of m16n8k16 + if (AType->getKind() == InlineAsmBuiltinType::f16) { + InMatrixType[0] = "uint32_t"; // A type is .f16x2 + InMatrixType[1] = "uint32_t"; // B type is .f16x2 + + // If A matrix type is f16, then C&D matrix types can only be f32 + if (CType->getKind() == InlineAsmBuiltinType::f32) { + NumVecElements[0] = 4; // A + NumVecElements[1] = 2; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else if (AType->getKind() == InlineAsmBuiltinType::s8) { + InMatrixType[0] = "uint32_t"; // A type is .f16x2 + InMatrixType[1] = "uint32_t"; // B type is .f16x2 + + // If A matrix type is s8, then C&D matrix types can only be s32 + if (CType->getKind() == InlineAsmBuiltinType::s32) { + NumVecElements[0] = 2; // A + NumVecElements[1] = 1; // B + NumVecElements[2] = 4; // C + NumVecElements[3] = 4; // D + } else + return SYCLGenError(); + } else + return SYCLGenError(); + } else + return SYCLGenError(); + + InMatrixType[2] = CDType; + + // Check the register sizes for vector elements of A, B, C & D matrices + for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); + InputOp++) { + if (auto VE = + dyn_cast(Inst->getInputOperand(InputOp))) { + if (VE->getNumElements() != NumVecElements[InputOp]) + return SYCLGenError(); + } else + return SYCLGenError(); + } + if (DMatVE->getNumElements() != NumVecElements[3]) + return SYCLGenError(); + + // Declare and init an array for storing the addresses of D matrix elements + OS() << "{\n"; + OS() << "volatile " << CDType << " *d_mat_frag_ct1[" + << DMatVE->getNumElements() << "] = { "; + for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) { + if (isa(DMatVE->getElement(Inst))) + continue; + OS() << "&"; + if (emitStmt(DMatVE->getElement(Inst))) + return SYCLGenError(); + if ((Inst + 1) != DMatVE->getNumElements()) + OS() << ", "; + } + OS() << " }"; + endstmt(); + + // Declare and init vectors for storing the values of A, B & C matrix + // elements + std::string InMatrixName[3] = {"a", "b", "c"}; + for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands(); + InputOp++) { + if (auto VE = + dyn_cast(Inst->getInputOperand(InputOp))) { + OS() << "sycl::vec<" << InMatrixType[InputOp] << ", " + << VE->getNumElements() << "> " << InMatrixName[InputOp] + << "_mat_frag_ct1("; + for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) { + if (isa(VE->getElement(Inst))) + continue; + if (emitStmt(VE->getElement(Inst))) + return SYCLGenError(); + if ((Inst + 1) != VE->getNumElements()) + OS() << ", "; + } + OS() << ")"; + endstmt(); + } else { + return SYCLGenError(); + } + } + + OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma"; + OS() << "<"; + OS() << M << ", " << N << ", " << K << ", "; + OS() << ABType << ", " << CDType; + OS() << ">("; + + OS() << "reinterpret_cast(d_mat_frag_ct1)"; + for (int i = 0; i < 3; i++) + OS() << ", &" << InMatrixName[i] << "_mat_frag_ct1"; + OS() << ")"; + endstmt(); + OS() << "}"; + endstmt(); + + const auto *KernelDecl = getImmediateOuterFuncDecl(GAS); + if (KernelDecl) { + auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl); + if (FuncInfo) + FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(), + DpctGlobalInfo::getSubGroup(GAS)); + } + + return SYCLGenSuccess(); + } + bool handle_prefetch(const InlineAsmInstruction *Inst) override { if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1) return SYCLGenError(); @@ -2595,11 +2762,10 @@ class SYCLGen : public SYCLGenBase { Op = std::move(NewOp); } - bool HasHalfOrBfloat16 = - SrcType->getKind() == InlineAsmBuiltinType::f16 || - DesType->getKind() == InlineAsmBuiltinType::f16 || - SrcType->getKind() == InlineAsmBuiltinType::bf16 || - DesType->getKind() == InlineAsmBuiltinType::bf16; + bool HasHalfOrBfloat16 = SrcType->getKind() == InlineAsmBuiltinType::f16 || + DesType->getKind() == InlineAsmBuiltinType::f16 || + SrcType->getKind() == InlineAsmBuiltinType::bf16 || + DesType->getKind() == InlineAsmBuiltinType::bf16; if (DpctGlobalInfo::useIntelDeviceMath() && HasHalfOrBfloat16) { insertHeader(HeaderType::HT_SYCL_Math); if (SrcNeedBitCast) diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h index 42cce5902f97..640fd536613b 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h @@ -99,9 +99,9 @@ class InlineAsmBuiltinType : public InlineAsmType { return ((K == Kind) || ...); } template bool isNot(Ks... K) { return ((K != Kind) && ...); } - bool isBit() const { return isOneOf(b8, b16, b32, b64); } - bool isSigned() const { return isOneOf(s8, s16, s32, s64); } - bool isUnsigned() const { return isOneOf(u8, u16, u32, u64); } + bool isBit() const { return isOneOf(b1, b8, b16, b32, b64); } + bool isSigned() const { return isOneOf(s4, s8, s16, s32, s64); } + bool isUnsigned() const { return isOneOf(u4, u8, u16, u32, u64); } bool isInt() const { return isSigned() || isUnsigned(); } bool isFloat() const { return isOneOf(f16, f32, f64); } bool isScalar() const { return isInt() || isFloat(); } @@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType { // This class is used for device asm vector types. class InlineAsmVectorType : public InlineAsmType { public: - enum VecKind { v2, v4, v8, x1, x2, x4 }; + enum VecKind { v1, v2, v4, v8, x1, x2, x4 }; private: VecKind Kind; @@ -322,7 +322,7 @@ class InlineAsmInstruction : public InlineAsmStmt { /// This represents arrtibutes like: comparsion operator, rounding modifiers, /// ... e.g. instruction setp.eq.s32 has a comparsion operator 'eq'. - SmallSet Attributes; + SmallVector Attributes; /// This represents types in instruction, e.g. mov.u32. SmallVector Types; @@ -350,11 +350,11 @@ class InlineAsmInstruction : public InlineAsmStmt { OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) { StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(), AsmStateSpaces.end()); - Attributes.insert(Attrs.begin(), Attrs.end()); + Attributes.insert(Attributes.begin(), Attrs.begin(), Attrs.end()); } using attr_range = - llvm::iterator_range::const_iterator>; + llvm::iterator_range::const_iterator>; using type_range = llvm::iterator_range::const_iterator>; using op_range = @@ -369,12 +369,16 @@ class InlineAsmInstruction : public InlineAsmStmt { } template bool hasAttr(Ts... Attrs) const { - return (Attributes.contains(Attrs) || ...); + return (llvm::is_contained(Attributes, Attrs) || ...); } const InlineAsmIdentifierInfo *getOpcodeID() const { return Opcode; } asmtok::TokenKind getOpcode() const { return Opcode->getTokenID(); } ArrayRef getTypes() const { return Types; } const InlineAsmType *getType(unsigned I) const { return Types[I]; } + InstAttr getAttr(unsigned I) const { + assert(I < Attributes.size() && "Attributes index out of range"); + return Attributes[I]; + } unsigned getNumTypes() const { return Types.size(); } const InlineAsmExpr *getOutputOperand() const { return OutputOp; } const InlineAsmExpr *getPredOutputOperand() const { return PredOutputOp; } diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp index 60e49a9c0c56..dc3b70b8373f 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp @@ -756,6 +756,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef Vec) { } else { // Vector size must be 2, 4, or 8. switch (Vec.size()) { + case 1: + Kind = InlineAsmVectorType::v1; + break; case 2: Kind = InlineAsmVectorType::v2; break; diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h index 8b9a3f5f01ba..4de5db7bf658 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmParser.h @@ -498,7 +498,7 @@ class InlineAsmParser { /// .reg .sreg .const .local .param .shared .tex /// /// vector-specifier: one of - /// .v2 .v4 .v8 + /// .v1 .v2 .v4 .v8 /// /// type-specifier: one of /// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64 diff --git a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def index ea401fb0777c..c7ef34d725f3 100644 --- a/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def +++ b/clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def @@ -240,14 +240,17 @@ SPECIAL_REG(warpid, "%warpid", s64) SPECIAL_REG(WARP_SZ, "WARP_SZ", s64) // Built-in type names +BUILTIN_TYPE(b1, ".b1") BUILTIN_TYPE(b8, ".b8") BUILTIN_TYPE(b16, ".b16") BUILTIN_TYPE(b32, ".b32") BUILTIN_TYPE(b64, ".b64") +BUILTIN_TYPE(u4, ".u4") BUILTIN_TYPE(u8, ".u8") BUILTIN_TYPE(u16, ".u16") BUILTIN_TYPE(u32, ".u32") BUILTIN_TYPE(u64, ".u64") +BUILTIN_TYPE(s4, ".s4") BUILTIN_TYPE(s8, ".s8") BUILTIN_TYPE(s16, ".s16") BUILTIN_TYPE(s32, ".s32") @@ -270,6 +273,7 @@ BUILTIN_TYPE(s16x2, ".s16x2") BUILTIN_TYPE(u16x2, ".u16x2") // Vector modifiers +MODIFIER(v1, ".v1") MODIFIER(v2, ".v2") MODIFIER(v4, ".v4") MODIFIER(v8, ".v8") @@ -279,8 +283,23 @@ MODIFIER(x1, ".x1") MODIFIER(x2, ".x2") MODIFIER(x4, ".x4") +// Matrix modifiers +MODIFIER(row, ".row") +MODIFIER(col, ".col") + // Matrix shape MODIFIER(m8n8, ".m8n8") +MODIFIER(m8n8k4, ".m8n8k4") +MODIFIER(m8n8k16, ".m8n8k16") +MODIFIER(m8n8k32, ".m8n8k32") +MODIFIER(m8n8k128, ".m8n8k128") +MODIFIER(m16n8k4, ".m16n8k4") +MODIFIER(m16n8k8, ".m16n8k8") +MODIFIER(m16n8k16, ".m16n8k16") +MODIFIER(m16n8k32, ".m16n8k32") +MODIFIER(m16n8k64, ".m16n8k64") +MODIFIER(m16n8k128, ".m16n8k128") +MODIFIER(m16n8k256, ".m16n8k256") STATE_SPACE(reg, ".reg") STATE_SPACE(sreg, ".sreg") @@ -376,6 +395,7 @@ MODIFIER(max, ".max") MODIFIER(op_or, ".or") MODIFIER(op_xor, ".xor") MODIFIER(op_and, ".and") +MODIFIER(op_popc, ".popc") MODIFIER(cas, ".cas") MODIFIER(exch, ".exch") MODIFIER(inc, ".inc") diff --git a/clang/runtime/dpct-rt/include/dpct/math.hpp b/clang/runtime/dpct-rt/include/dpct/math.hpp index 69004f702181..67859e68ef69 100644 --- a/clang/runtime/dpct-rt/include/dpct/math.hpp +++ b/clang/runtime/dpct-rt/include/dpct/math.hpp @@ -1636,7 +1636,8 @@ inline constexpr unsigned extend_vcompare2_add(AT a, BT b, unsigned c, /// \returns The extend vectorized average of the two values template inline constexpr RetT extend_vavrg2(AT a, BT b, RetT c) { - return detail::extend_vbinary2(a, b, c, detail::average()); + return detail::extend_vbinary2(a, b, c, + detail::average()); } /// Compute vectorized average of \p a and \p b, with each value treated as a 2 @@ -1933,7 +1934,8 @@ inline constexpr unsigned extend_vcompare4_add(AT a, BT b, unsigned c, /// \returns The extend vectorized average of the two values template inline constexpr RetT extend_vavrg4(AT a, BT b, RetT c) { - return detail::extend_vbinary4(a, b, c, detail::average()); + return detail::extend_vbinary4(a, b, c, + detail::average()); } /// Compute vectorized average of \p a and \p b, with each value treated as a 4 @@ -2216,6 +2218,160 @@ void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4, bool trans = false) { ldmatrix(addr, m4, trans, 3); } +/// A helper struct that defines the pack type for the input matrix fragments +/// of mma() function based on the type of input matrix fragments. +/// The MMAType struct is specialized for different types of input matrices. +/// Currently, the specialization for f16 and s8 types is defined below. +/// \tparam [in] T The type of the input matrix fragments +template struct MMAType { + using PackType = uint32_t; +}; + +/// Each work item of a sub-group (limited to size 32) calling this function +/// calculates a subset fragment for the output matrix D using MAD operation on +/// A, B & C matrix fragments (D = A * B + C). +/// Current supported shapes & types: +/// - m16n8k16 (f32.f16.f16.f32 & s32.s8.s8.s32) +/// Here, m, n & k define the shapes of A, B & C matrices respectively +/// (A = [m x k], B = [k x n], C = [m x n]). +/// \tparam [in] M The rows of A, C & D matrices +/// \tparam [in] N The columns of B, C, D matrices +/// \tparam [in] K The columns & rows of A & B matrices respectively +/// \tparam [in] ABType The type of the input matrix (A & B) fragment +/// \tparam [in] CDType The type of the output matrix (C & D) fragment +/// \param [out] d_mat_frag The fragment of the output matrix D to store the +/// result of A * B + C +/// \param [in] a_mat_frag The fragment of the input matrix A to be multiplied +/// with B matrix fragment +/// \param [in] b_mat_frag The fragment of the input matrix B to be multiplied +/// with A matrix fragment +/// \param [in] c_mat_frag The fragment of the input matrix C to be added with +/// the result of A * B fragments +template +void mma(volatile void **d_mat_frag, void *a_mat_frag, void *b_mat_frag, + void *c_mat_frag) { + auto d = reinterpret_cast(d_mat_frag); + auto a = reinterpret_cast::PackType *>(a_mat_frag); + auto b = reinterpret_cast::PackType *>(b_mat_frag); + auto c = reinterpret_cast(c_mat_frag); + + auto sg = sycl::ext::oneapi::this_work_item::get_sub_group(); + int lane = sg.get_local_linear_id(); + + static_assert(M == 16 && N == 8 && K == 16, + "Only m16n8k16 shape is supported!"); + + short row_load_offset = 4 * (lane >> 2); + short col_load_offset = 8 * (lane % 4); + + if constexpr (M == 16 && N == 8 && K == 16) { + if constexpr (std::is_floating_point_v) { + // Init D matrix fragment with C matrix fragment + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment of + // matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a[4], recv_b[4]; + + // Load partial fragment from row0 of matrix A ({a0, a1}) + recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row0 of matrix A ({a2, a3}) + recv_a[1] = dpct::select_from_sub_group(sg, a[2], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a0, a1}) + recv_a[2] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a2, a3}) + recv_a[3] = dpct::select_from_sub_group(sg, a[3], row_load_offset + i); + + // Load partial fragment from col0 of matrix B ({b0, b1}) + recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col0 of matrix B ({b2, b3}) + recv_b[1] = dpct::select_from_sub_group(sg, b[1], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b0, b1}) + recv_b[2] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + 4 + i); + // Load partial fragment from col1 of matrix B ({b2, b3}) + recv_b[3] = + dpct::select_from_sub_group(sg, b[1], col_load_offset + 4 + i); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix fragments + // and adds it to the corresponding D matrix fragment + // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d1 += row0{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } + // d2 += row1{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d3 += row1{ a0, a1, a2, a3 } * col1{ b0, b1, b2, b3 } + for (int j = 0; j < 4; j++) { + *d[0] += static_cast(ra[j]) * static_cast(rb[j]); + *d[1] += static_cast(ra[j]) * static_cast(rb[j + 4]); + *d[2] += static_cast(ra[j + 4]) * static_cast(rb[j]); + *d[3] += + static_cast(ra[j + 4]) * static_cast(rb[j + 4]); + } + } + } else if constexpr (std::is_integral_v) { + // Init D matrix with fragments of C matrix + *d[0] = c[0]; + *d[1] = c[1]; + *d[2] = c[2]; + *d[3] = c[3]; + + // Each sub-group is responsible for computing a fragment size of 16*8 + // elements of matrix D. + // Each work item computes 4 elements of matrix D by gathering + // their corresponding row & col matrix fragments of length k (8) + // from A & B matrices respectively using below mapping logic: + // row0 = (lane >> 2) & row1 = (lane >> 2) + 8 + // col0 = (lane % 4) * 2 & col1 = (lane % 4) * 2 + 1 + // As each row & col fragment of A & B matrices is distributed across + // 4 work items, each iteration of below loop loads a partial fragment of + // matrix A (row) and matrix B (col) using the row & col offsets. + for (int i = 0; i < 4; i++) { + typename MMAType::PackType recv_a[2], recv_b[2]; + + // Load partial fragment from row0 of matrix A ({a0, a1, a2, a3}) + recv_a[0] = dpct::select_from_sub_group(sg, a[0], row_load_offset + i); + // Load partial fragment from row1 of matrix A ({a4, a5, a6, a7}) + recv_a[1] = dpct::select_from_sub_group(sg, a[1], row_load_offset + i); + // Load partial fragment from col0 of matrix B ({b0, b1, b2, b3}) + recv_b[0] = dpct::select_from_sub_group(sg, b[0], col_load_offset + i); + // Load partial fragment from col1 of matrix B ({b4, b5, b6, b7}) + recv_b[1] = + dpct::select_from_sub_group(sg, b[0], col_load_offset + i + 4); + + auto ra = reinterpret_cast(recv_a); + auto rb = reinterpret_cast(recv_b); + + // Each work item calculates a partial product of A & B matrix fragments + // and adds it to the corresponding D matrix fragment + // d0 += row0{ a0, a1, a2, a3 } * col0{ b0, b1, b2, b3 } + // d1 += row0{ a0, a1, a2, a3 } * col1{ b4, b5, b6, b7 } + // d2 += row1{ a4, a5, a6, a7 } * col0{ b0, b1, b2, b3 } + // d3 += row1{ a4, a5, a6, a7 } * col1{ b4, b5, b6, b7 } + for (int i = 0; i < 4; i++) { + *d[0] += ra[i] * rb[i]; + *d[1] += ra[i] * rb[i + 4]; + *d[2] += ra[i + 4] * rb[i]; + *d[3] += ra[i + 4] * rb[i + 4]; + } + } + } + } +} + } // namespace matrix } // namespace experimental diff --git a/clang/test/dpct/asm/mma.cu b/clang/test/dpct/asm/mma.cu new file mode 100644 index 000000000000..3186623b3ec6 --- /dev/null +++ b/clang/test/dpct/asm/mma.cu @@ -0,0 +1,69 @@ +// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2 +// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2 +// RUN: dpct --format-range=none -out-root %T/mma %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only +// RUN: FileCheck %s --match-full-lines --input-file %T/mma/mma.dp.cpp +// RUN: %if build_lit %{icpx -c -DNO_BUILD_TEST -fsycl %T/mma/mma.dp.cpp -o %T/mma/mma.dp.o %} + +// clang-format off +#include +#include + +/* +As per PTX ASM 8.1, below is the status of supported configurations + +--------- --------- ---------- ----------- ------------- +| Shape | | A | | B | | C / D | | Supported | +--------- --------- ---------- ----------- ------------- +m16n8k16 .f16 .f16 .f16/.f32 Yes + .s8 .s8 .s32 Yes + +A Layout: row +B Layout: col +*/ + +__global__ void mma_kernel_m16n8k16(int *a, int *b, int *c, float *fc, int *d) { + // CHECK: { + // CHECK-NEXT: volatile float *d_mat_frag_ct1[4] = { &fc[0], &fc[1], &fc[2], &fc[3] }; + // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1], a[2], a[3]); + // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0], b[1]); + // CHECK-NEXT: sycl::vec c_mat_frag_ct1(fc[0], fc[1], fc[2], fc[3]); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, sycl::half, float>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); + // CHECK-NEXT: } + asm("mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + " { %0, %1, %2, %3 }, " + " { %4, %5, %6, %7 }, " + " { %8, %9 }, " + " { %0, %1, %2, %3 };" + : "+f"(fc[0]), "+f"(fc[1]), "+f"(fc[2]), "+f"(fc[3]) + : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]), + "r"(b[0]), "r"(b[1])); + + // CHECK: { + // CHECK-NEXT: volatile int32_t *d_mat_frag_ct1[4] = { &d[0], &d[1], &d[2], &d[3] }; + // CHECK-NEXT: sycl::vec a_mat_frag_ct1(a[0], a[1]); + // CHECK-NEXT: sycl::vec b_mat_frag_ct1(b[0]); + // CHECK-NEXT: sycl::vec c_mat_frag_ct1(c[0], c[1], c[2], c[3]); + // CHECK-NEXT: dpct::experimental::matrix::mma<16, 8, 16, int8_t, int32_t>(reinterpret_cast(d_mat_frag_ct1), &a_mat_frag_ct1, &b_mat_frag_ct1, &c_mat_frag_ct1); + // CHECK-NEXT: } + asm("mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + " { %0, %1, %2, %3 }, " + " { %4, %5 }, " + " { %6 }, " + " { %7, %8, %9, %10 };" + : "=r"(d[0]), "=r"(d[1]), "=r"(d[2]), "=r"(d[3]) + : "r"(a[0]), "r"(a[1]), + "r"(b[0]), + "r"(c[0]), "r"(c[1]), "r"(c[2]), "r"(c[3])); +} + + +int main () { + int *int_a, *int_b, *int_c, *int_d; + float *float_c; + + // CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} { + mma_kernel_m16n8k16<<<1, 32>>>(int_a, int_b, int_c, float_c, int_d); + + return 0; +} +// clang-format on diff --git a/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv b/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv index 2f45259b90e7..1b68950a15ae 100644 --- a/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv +++ b/docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv @@ -54,7 +54,7 @@ max,YES, mbarrier,NO, membar,YES, Partial min,YES, -mma,NO, +mma,YES, Partial mov,YES, movmatrix,NO, mul,YES, Partial