Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 171 additions & 5 deletions clang/lib/DPCT/RulesAsm/AsmMigration.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -514,14 +514,17 @@ bool SYCLGenBase::emitType(const InlineAsmType *T) {
bool SYCLGenBase::emitBuiltinType(const InlineAsmBuiltinType *T) {
switch (T->getKind()) {
// clang-format off
case InlineAsmBuiltinType::b1: OS() << "uint8_t"; break;
case InlineAsmBuiltinType::b8: OS() << "uint8_t"; break;
case InlineAsmBuiltinType::b16: OS() << "uint16_t"; break;
case InlineAsmBuiltinType::b32: OS() << "uint32_t"; break;
case InlineAsmBuiltinType::b64: OS() << "uint64_t"; break;
case InlineAsmBuiltinType::u4: OS() << "uint8_t"; break;
case InlineAsmBuiltinType::u8: OS() << "uint8_t"; break;
case InlineAsmBuiltinType::u16: OS() << "uint16_t"; break;
case InlineAsmBuiltinType::u32: OS() << "uint32_t"; break;
case InlineAsmBuiltinType::u64: OS() << "uint64_t"; break;
case InlineAsmBuiltinType::s4: OS() << "int8_t"; break;
case InlineAsmBuiltinType::s8: OS() << "int8_t"; break;
case InlineAsmBuiltinType::s16: OS() << "int16_t"; break;
case InlineAsmBuiltinType::s32: OS() << "int32_t"; break;
Expand Down Expand Up @@ -559,6 +562,9 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
case InlineAsmVectorType::x1:
OS() << 1;
break;
case InlineAsmVectorType::v1:
OS() << 1;
break;
case InlineAsmVectorType::v2:
case InlineAsmVectorType::x2:
OS() << 2;
Expand Down Expand Up @@ -1370,6 +1376,167 @@ class SYCLGen : public SYCLGenBase {
return SYCLGenSuccess();
}

bool handle_mma(const InlineAsmInstruction *Inst) override {
if (Inst->getNumInputOperands() != 3)
return SYCLGenError();

const InlineAsmVectorExpr *DMatVE =
dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
if (!DMatVE)
return SYCLGenError();

// Only row Layout is supported for of A matrix and
// only col Layout is supported for of B matrix
if (Inst->getAttr(3) != InstAttr::row || Inst->getAttr(4) != InstAttr::col)
return SYCLGenError();

// Data types of D, A, B & C matrices respectively in the PTX instruction
const auto *DType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(0));
const auto *AType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(1));
const auto *BType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(2));
const auto *CType = dyn_cast<InlineAsmBuiltinType>(Inst->getType(3));

if (!(AType && BType && CType && DType))
return SYCLGenError();

// Data types of matrix elements for A&B and C&D matrices should be same
if ((AType->getKind() != BType->getKind()) ||
(CType->getKind() != DType->getKind()))
return SYCLGenError();

// Check the validity of AB & CD types
std::string ABType, CDType;
if (tryEmitType(ABType, AType))
return SYCLGenError();

if (tryEmitType(CDType, CType))
return SYCLGenError();

// Register sizes for vector elements of A, B, C & D matrices
unsigned NumVecElements[4] = {0};

// Sizes of A & B matrices
std::string M, N, K;

// Data types of A, B & C matrices respectively in the PTX arguments
std::string InMatrixType[3];

if (Inst->hasAttr(InstAttr::m16n8k16)) {
M = "16";
N = "8";
K = "16";

// Only f16/s8 types are supported for A and B matrices of m16n8k16
if (AType->getKind() == InlineAsmBuiltinType::f16) {
InMatrixType[0] = "uint32_t"; // A type is .f16x2
InMatrixType[1] = "uint32_t"; // B type is .f16x2

// If A matrix type is f16, then C&D matrix types can only be f32
if (CType->getKind() == InlineAsmBuiltinType::f32) {
NumVecElements[0] = 4; // A
NumVecElements[1] = 2; // B
NumVecElements[2] = 4; // C
NumVecElements[3] = 4; // D
} else
return SYCLGenError();
} else if (AType->getKind() == InlineAsmBuiltinType::s8) {
InMatrixType[0] = "uint32_t"; // A type is .f16x2
InMatrixType[1] = "uint32_t"; // B type is .f16x2

// If A matrix type is s8, then C&D matrix types can only be s32
if (CType->getKind() == InlineAsmBuiltinType::s32) {
NumVecElements[0] = 2; // A
NumVecElements[1] = 1; // B
NumVecElements[2] = 4; // C
NumVecElements[3] = 4; // D
} else
return SYCLGenError();
} else
return SYCLGenError();
} else
return SYCLGenError();

InMatrixType[2] = CDType;

// Check the register sizes for vector elements of A, B, C & D matrices
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
InputOp++) {
if (auto VE =
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
if (VE->getNumElements() != NumVecElements[InputOp])
return SYCLGenError();
} else
return SYCLGenError();
}
if (DMatVE->getNumElements() != NumVecElements[3])
return SYCLGenError();

// Declare and init an array for storing the addresses of D matrix elements
OS() << "{\n";
OS() << "volatile " << CDType << " *d_mat_frag_ct1["
<< DMatVE->getNumElements() << "] = { ";
for (unsigned Inst = 0; Inst != DMatVE->getNumElements(); ++Inst) {
if (isa<InlineAsmDiscardExpr>(DMatVE->getElement(Inst)))
continue;
OS() << "&";
if (emitStmt(DMatVE->getElement(Inst)))
return SYCLGenError();
if ((Inst + 1) != DMatVE->getNumElements())
OS() << ", ";
}
OS() << " }";
endstmt();

// Declare and init vectors for storing the values of A, B & C matrix
// elements
std::string InMatrixName[3] = {"a", "b", "c"};
for (unsigned InputOp = 0; InputOp < Inst->getNumInputOperands();
InputOp++) {
if (auto VE =
dyn_cast<InlineAsmVectorExpr>(Inst->getInputOperand(InputOp))) {
OS() << "sycl::vec<" << InMatrixType[InputOp] << ", "
<< VE->getNumElements() << "> " << InMatrixName[InputOp]
<< "_mat_frag_ct1(";
for (unsigned Inst = 0; Inst != VE->getNumElements(); ++Inst) {
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
continue;
if (emitStmt(VE->getElement(Inst)))
return SYCLGenError();
if ((Inst + 1) != VE->getNumElements())
OS() << ", ";
}
OS() << ")";
endstmt();
} else {
return SYCLGenError();
}
}

OS() << MapNames::getDpctNamespace() << "experimental::matrix::mma";
OS() << "<";
OS() << M << ", " << N << ", " << K << ", ";
OS() << ABType << ", " << CDType;
OS() << ">(";

OS() << "reinterpret_cast<volatile void **>(d_mat_frag_ct1)";
for (int i = 0; i < 3; i++)
OS() << ", &" << InMatrixName[i] << "_mat_frag_ct1";
OS() << ")";
endstmt();
OS() << "}";
endstmt();

const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
if (KernelDecl) {
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
if (FuncInfo)
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
DpctGlobalInfo::getSubGroup(GAS));
}

return SYCLGenSuccess();
}

bool handle_prefetch(const InlineAsmInstruction *Inst) override {
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
return SYCLGenError();
Expand Down Expand Up @@ -2595,11 +2762,10 @@ class SYCLGen : public SYCLGenBase {
Op = std::move(NewOp);
}

bool HasHalfOrBfloat16 =
SrcType->getKind() == InlineAsmBuiltinType::f16 ||
DesType->getKind() == InlineAsmBuiltinType::f16 ||
SrcType->getKind() == InlineAsmBuiltinType::bf16 ||
DesType->getKind() == InlineAsmBuiltinType::bf16;
bool HasHalfOrBfloat16 = SrcType->getKind() == InlineAsmBuiltinType::f16 ||
DesType->getKind() == InlineAsmBuiltinType::f16 ||
SrcType->getKind() == InlineAsmBuiltinType::bf16 ||
DesType->getKind() == InlineAsmBuiltinType::bf16;
if (DpctGlobalInfo::useIntelDeviceMath() && HasHalfOrBfloat16) {
insertHeader(HeaderType::HT_SYCL_Math);
if (SrcNeedBitCast)
Expand Down
20 changes: 12 additions & 8 deletions clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,9 @@ class InlineAsmBuiltinType : public InlineAsmType {
return ((K == Kind) || ...);
}
template <class... Ks> bool isNot(Ks... K) { return ((K != Kind) && ...); }
bool isBit() const { return isOneOf(b8, b16, b32, b64); }
bool isSigned() const { return isOneOf(s8, s16, s32, s64); }
bool isUnsigned() const { return isOneOf(u8, u16, u32, u64); }
bool isBit() const { return isOneOf(b1, b8, b16, b32, b64); }
bool isSigned() const { return isOneOf(s4, s8, s16, s32, s64); }
bool isUnsigned() const { return isOneOf(u4, u8, u16, u32, u64); }
bool isInt() const { return isSigned() || isUnsigned(); }
bool isFloat() const { return isOneOf(f16, f32, f64); }
bool isScalar() const { return isInt() || isFloat(); }
Expand All @@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
// This class is used for device asm vector types.
class InlineAsmVectorType : public InlineAsmType {
public:
enum VecKind { v2, v4, v8, x1, x2, x4 };
enum VecKind { v1, v2, v4, v8, x1, x2, x4 };

private:
VecKind Kind;
Expand Down Expand Up @@ -322,7 +322,7 @@ class InlineAsmInstruction : public InlineAsmStmt {

/// This represents attributes like: comparison operator, rounding modifiers,
/// ... e.g. instruction setp.eq.s32 has a comparison operator 'eq'.
SmallSet<InstAttr, 4> Attributes;
SmallVector<InstAttr, 4> Attributes;

/// This represents types in instruction, e.g. mov.u32.
SmallVector<InlineAsmType *, 4> Types;
Expand Down Expand Up @@ -350,11 +350,11 @@ class InlineAsmInstruction : public InlineAsmStmt {
OutputOp(Out), PredOutputOp(Pred), InputOps(InOps) {
StateSpaces.insert(StateSpaces.begin(), AsmStateSpaces.begin(),
AsmStateSpaces.end());
Attributes.insert(Attrs.begin(), Attrs.end());
Attributes.insert(Attributes.begin(), Attrs.begin(), Attrs.end());
}

using attr_range =
llvm::iterator_range<SmallSet<InstAttr, 4>::const_iterator>;
llvm::iterator_range<SmallVector<InstAttr, 4>::const_iterator>;
using type_range =
llvm::iterator_range<SmallVector<InlineAsmType *, 4>::const_iterator>;
using op_range =
Expand All @@ -369,12 +369,16 @@ class InlineAsmInstruction : public InlineAsmStmt {
}

template <typename... Ts> bool hasAttr(Ts... Attrs) const {
return (Attributes.contains(Attrs) || ...);
return (llvm::is_contained(Attributes, Attrs) || ...);
}
const InlineAsmIdentifierInfo *getOpcodeID() const { return Opcode; }
asmtok::TokenKind getOpcode() const { return Opcode->getTokenID(); }
ArrayRef<InlineAsmType *> getTypes() const { return Types; }
const InlineAsmType *getType(unsigned I) const { return Types[I]; }
/// Returns the attribute stored at position \p I. Attributes are kept in the
/// order in which they were passed to the constructor.
/// \p I must be smaller than the number of stored attributes; this is only
/// checked by the assertion below, so callers must validate the index.
InstAttr getAttr(unsigned I) const {
assert(I < Attributes.size() && "Attributes index out of range");
return Attributes[I];
}
unsigned getNumTypes() const { return Types.size(); }
const InlineAsmExpr *getOutputOperand() const { return OutputOp; }
const InlineAsmExpr *getPredOutputOperand() const { return PredOutputOp; }
Expand Down
3 changes: 3 additions & 0 deletions clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -756,6 +756,9 @@ InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
} else {
// Vector size must be 1, 2, 4, or 8.
switch (Vec.size()) {
case 1:
Kind = InlineAsmVectorType::v1;
break;
case 2:
Kind = InlineAsmVectorType::v2;
break;
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/RulesAsm/Parser/AsmParser.h
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ class InlineAsmParser {
/// .reg .sreg .const .local .param .shared .tex
///
/// vector-specifier: one of
/// .v2 .v4 .v8
/// .v1 .v2 .v4 .v8
///
/// type-specifier: one of
/// .b8 .b16 .b32 .b64 .s8 .s16 .s32 .s64
Expand Down
20 changes: 20 additions & 0 deletions clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def
Original file line number Diff line number Diff line change
Expand Up @@ -240,14 +240,17 @@ SPECIAL_REG(warpid, "%warpid", s64)
SPECIAL_REG(WARP_SZ, "WARP_SZ", s64)

// Built-in type names
BUILTIN_TYPE(b1, ".b1")
BUILTIN_TYPE(b8, ".b8")
BUILTIN_TYPE(b16, ".b16")
BUILTIN_TYPE(b32, ".b32")
BUILTIN_TYPE(b64, ".b64")
BUILTIN_TYPE(u4, ".u4")
BUILTIN_TYPE(u8, ".u8")
BUILTIN_TYPE(u16, ".u16")
BUILTIN_TYPE(u32, ".u32")
BUILTIN_TYPE(u64, ".u64")
BUILTIN_TYPE(s4, ".s4")
BUILTIN_TYPE(s8, ".s8")
BUILTIN_TYPE(s16, ".s16")
BUILTIN_TYPE(s32, ".s32")
Expand All @@ -270,6 +273,7 @@ BUILTIN_TYPE(s16x2, ".s16x2")
BUILTIN_TYPE(u16x2, ".u16x2")

// Vector modifiers
MODIFIER(v1, ".v1")
MODIFIER(v2, ".v2")
MODIFIER(v4, ".v4")
MODIFIER(v8, ".v8")
Expand All @@ -279,8 +283,23 @@ MODIFIER(x1, ".x1")
MODIFIER(x2, ".x2")
MODIFIER(x4, ".x4")

// Matrix modifiers
MODIFIER(row, ".row")
MODIFIER(col, ".col")

// Matrix shape
MODIFIER(m8n8, ".m8n8")
MODIFIER(m8n8k4, ".m8n8k4")
MODIFIER(m8n8k16, ".m8n8k16")
MODIFIER(m8n8k32, ".m8n8k32")
MODIFIER(m8n8k128, ".m8n8k128")
MODIFIER(m16n8k4, ".m16n8k4")
MODIFIER(m16n8k8, ".m16n8k8")
MODIFIER(m16n8k16, ".m16n8k16")
MODIFIER(m16n8k32, ".m16n8k32")
MODIFIER(m16n8k64, ".m16n8k64")
MODIFIER(m16n8k128, ".m16n8k128")
MODIFIER(m16n8k256, ".m16n8k256")

STATE_SPACE(reg, ".reg")
STATE_SPACE(sreg, ".sreg")
Expand Down Expand Up @@ -376,6 +395,7 @@ MODIFIER(max, ".max")
MODIFIER(op_or, ".or")
MODIFIER(op_xor, ".xor")
MODIFIER(op_and, ".and")
MODIFIER(op_popc, ".popc")
MODIFIER(cas, ".cas")
MODIFIER(exch, ".exch")
MODIFIER(inc, ".inc")
Expand Down
Loading