Skip to content

Commit b03c620

Browse files
Added support for ldmatrix migration
1 parent 16c5bc3 commit b03c620

12 files changed

Lines changed: 373 additions & 38 deletions

File tree

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -557,12 +557,15 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
557557
OS() << ", ";
558558
switch (T->getKind()) {
559559
case InlineAsmVectorType::v2:
560+
case InlineAsmVectorType::x1:
560561
OS() << 2;
561562
break;
562563
case InlineAsmVectorType::v4:
564+
case InlineAsmVectorType::x2:
563565
OS() << 4;
564566
break;
565567
case InlineAsmVectorType::v8:
568+
case InlineAsmVectorType::x4:
566569
OS() << 8;
567570
break;
568571
}
@@ -591,7 +594,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
591594
// Address expression only support ld/st/red & atom instructions.
592595
if (!CurrInst ||
593596
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
594-
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp)) {
597+
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp,
598+
asmtok::op_ldmatrix)) {
595599
return SYCLGenError();
596600
}
597601
std::string Type;
@@ -624,6 +628,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
624628
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
625629
CanSuppressCast(Dst->getSymbol()))
626630
OS() << llvm::formatv("{0}", Reg);
631+
else if (CurrInst->is(asmtok::op_ldmatrix))
632+
OS() << llvm::formatv("(uintptr_t){0}", Reg);
627633
else
628634
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
629635
break;
@@ -1290,6 +1296,39 @@ class SYCLGen : public SYCLGenBase {
12901296
return SYCLGenSuccess();
12911297
}
12921298

1299+
bool handle_ldmatrix(const InlineAsmInstruction *Inst) override {
1300+
if (Inst->getNumInputOperands() != 1)
1301+
return SYCLGenError();
1302+
1303+
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
1304+
CurrInst = Inst;
1305+
const auto *Src =
1306+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getInputOperand(0));
1307+
if (!Src)
1308+
return false;
1309+
1310+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::ldmatrix(";
1311+
if (emitStmt(Src)) {
1312+
return SYCLGenError();
1313+
}
1314+
OS() << ", ";
1315+
const auto *VE = dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1316+
for (unsigned Inst = 0, E = VE->getNumElements(); Inst != E; ++Inst) {
1317+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1318+
continue;
1319+
OS() << "&";
1320+
if (emitStmt(VE->getElement(Inst)))
1321+
return SYCLGenError();
1322+
OS() << ", ";
1323+
}
1324+
OS() << DpctGlobalInfo::getItem(GAS);
1325+
if (Inst->hasAttr(InstAttr::trans))
1326+
OS() << ", true";
1327+
OS() << ");";
1328+
1329+
return SYCLGenSuccess();
1330+
}
1331+
12931332
bool handle_prefetch(const InlineAsmInstruction *Inst) override {
12941333
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
12951334
return SYCLGenError();
@@ -2667,6 +2706,16 @@ class SYCLGen : public SYCLGenBase {
26672706
bool handle_ld(const InlineAsmInstruction *Inst) override {
26682707
if (Inst->getNumInputOperands() != 1)
26692708
return SYCLGenError();
2709+
2710+
OS() << "Size of input ops: " << Inst->getNumInputOperands() << "\n";
2711+
OS() << "Input op(0/1): " << Inst->getInputOperand(0) << "\n";
2712+
llvm::SaveAndRestore<const InlineAsmInstruction *> Store2(CurrInst);
2713+
CurrInst = Inst;
2714+
const auto *Src2 =
2715+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getInputOperand(0));
2716+
OS() << emitStmt(Src2) << "\n";
2717+
OS() << "Output op: " << emitStmt(Inst->getOutputOperand()) << "\n";
2718+
26702719
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
26712720
CurrInst = Inst;
26722721
const auto *Src =

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
116116
// This class is used for device asm vector types.
117117
class InlineAsmVectorType : public InlineAsmType {
118118
public:
119-
enum VecKind { v2, v4, v8 };
119+
enum VecKind { v2, v4, v8, x1, x2, x4 };
120120

121121
private:
122122
VecKind Kind;
@@ -340,6 +340,8 @@ class InlineAsmInstruction : public InlineAsmStmt {
340340
/// therest are input operands.
341341
SmallVector<InlineAsmExpr *, 4> InputOps;
342342

343+
SmallVector<InlineAsmExpr *, 4> OutputOps;
344+
343345
public:
344346
InlineAsmInstruction(InlineAsmIdentifierInfo *Op,
345347
SmallVector<AsmStateSpace, 4> AsmStateSpaces,

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
327327
if (!Tok.getIdentifier() || !Tok.getIdentifier()->isInstruction())
328328
return AsmStmtError();
329329

330-
InlineAsmIdentifierInfo *Opcode = Tok.getIdentifier();
330+
Opcode = Tok.getIdentifier();
331331
ConsumeToken();
332332

333333
SmallVector<InstAttr, 4> Attrs;
@@ -736,20 +736,38 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
736736
InlineAsmExprResult
737737
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
738738

739-
// Vector size must be 2, 4, or 8.
739+
// Vector size for ldmatrix are 1, 2, 4
740+
// size(x) = 2 * sizeof(v).
740741
InlineAsmVectorType::VecKind Kind;
741-
switch (Vec.size()) {
742-
case 2:
743-
Kind = InlineAsmVectorType::v2;
744-
break;
745-
case 4:
746-
Kind = InlineAsmVectorType::v4;
747-
break;
748-
case 8:
749-
Kind = InlineAsmVectorType::v8;
750-
break;
751-
default:
752-
return AsmExprError();
742+
if (Opcode->getTokenID() == asmtok::op_ldmatrix) {
743+
switch (Vec.size()) {
744+
case 1:
745+
Kind = InlineAsmVectorType::x1;
746+
break;
747+
case 2:
748+
Kind = InlineAsmVectorType::x2;
749+
break;
750+
case 4:
751+
Kind = InlineAsmVectorType::x4;
752+
break;
753+
default:
754+
return AsmExprError();
755+
}
756+
} else {
757+
// Vector size must be 2, 4, or 8.
758+
switch (Vec.size()) {
759+
case 2:
760+
Kind = InlineAsmVectorType::v2;
761+
break;
762+
case 4:
763+
Kind = InlineAsmVectorType::v4;
764+
break;
765+
case 8:
766+
Kind = InlineAsmVectorType::v8;
767+
break;
768+
default:
769+
return AsmExprError();
770+
}
753771
}
754772

755773
InlineAsmBuiltinType *ElementType = nullptr;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ class InlineAsmParser {
247247
};
248248

249249
public:
250+
InlineAsmIdentifierInfo *Opcode;
251+
250252
InlineAsmParser(InlineAsmContext &Ctx, SourceMgr &Mgr)
251253
: Lexer(*Mgr.getMemoryBuffer(Mgr.getMainFileID())), Context(Ctx),
252254
SrcMgr(Mgr), CurScope(nullptr) {

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,14 @@ MODIFIER(v2, ".v2")
274274
MODIFIER(v4, ".v4")
275275
MODIFIER(v8, ".v8")
276276

277+
// Matrix modifiers
278+
MODIFIER(x1, ".x1")
279+
MODIFIER(x2, ".x2")
280+
MODIFIER(x4, ".x4")
281+
282+
// Matrix shape
283+
MODIFIER(m8n8, ".m8n8")
284+
277285
STATE_SPACE(reg, ".reg")
278286
STATE_SPACE(sreg, ".sreg")
279287
STATE_SPACE(const, ".const")
@@ -412,7 +420,8 @@ MODIFIER(sc, ".sc")
412420
MODIFIER(gl, ".gl")
413421
MODIFIER(L1, ".L1")
414422
MODIFIER(L2, ".L2")
415-
423+
MODIFIER(aligned, ".aligned")
424+
MODIFIER(trans, ".trans")
416425

417426
#undef LINKAGE
418427
#undef TARGET

clang/lib/DPCT/SrcAPI/APINames_ASM.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ ENTRY("griddepcontrol", "griddepcontrol", false, NO_FLAG, P1, "Comment")
7575
ENTRY("isspacep", "isspacep", false, NO_FLAG, P1, "Comment")
7676
ENTRY("istypep", "istypep", false, NO_FLAG, P1, "Comment")
7777
ENTRY("ld", "ld", true, NO_FLAG, P1, "Partial")
78-
ENTRY("ldmatrix", "ldmatrix", false, NO_FLAG, P1, "Comment")
78+
ENTRY("ldmatrix", "ldmatrix", true, NO_FLAG, P1, "Successful")
7979
ENTRY("ldu", "ldu", false, NO_FLAG, P1, "Comment")
8080
ENTRY("lg2", "lg2", true, NO_FLAG, P1, "Successful")
8181
ENTRY("lop3", "lop3", true, NO_FLAG, P1, "Successful")

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2047,6 +2047,66 @@ class joint_matrix {
20472047
matrix_accessor x;
20482048
const size_t num_elements;
20492049
};
2050+
2051+
template <typename T>
2052+
void ldmatrix(uintptr_t addr, T *m,
2053+
const sycl::nd_item<3> &item_ct1, bool trans = false,
2054+
unsigned mat = 0) {
2055+
int lane = item_ct1.get_local_id(2);
2056+
2057+
int group = lane / 8;
2058+
int sub = lane % 8;
2059+
int src_base = group * 2;
2060+
int src_lane = (sub / 4) ? (src_base + 1) : src_base;
2061+
2062+
if (!trans) {
2063+
// Broadcast the address from the source lane
2064+
auto recv_addr_uintp = dpct::select_from_sub_group(
2065+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2066+
auto recv_addr = reinterpret_cast<sycl::half *>(recv_addr_uintp);
2067+
2068+
// Row-major load
2069+
int index = (lane % 4) * 2;
2070+
sycl::half val0 = recv_addr[index];
2071+
sycl::half val1 = recv_addr[index + 1];
2072+
uint16_t bits0 = sycl::bit_cast<unsigned short, sycl::half>(val0);
2073+
uint16_t bits1 = sycl::bit_cast<unsigned short, sycl::half>(val1);
2074+
*m = ((uint32_t)bits1 << 16) | bits0;
2075+
} else {
2076+
// Broadcast the address from the source lane:
2077+
auto recv_addr_uintp = dpct::select_from_sub_group(
2078+
item_ct1.get_sub_group(), addr, mat * 8);
2079+
auto recv_addr = reinterpret_cast<sycl::half *>(recv_addr_uintp);
2080+
recv_addr += src_lane;
2081+
2082+
// Transposed load
2083+
int index = (lane % 4) * 8 * 2;
2084+
sycl::half val0 = recv_addr[index];
2085+
sycl::half val1 = recv_addr[index + 8];
2086+
uint16_t bits0 = sycl::bit_cast<unsigned short, sycl::half>(val0);
2087+
uint16_t bits1 = sycl::bit_cast<unsigned short, sycl::half>(val1);
2088+
*m = ((uint32_t)bits1 << 16) | bits0;
2089+
}
2090+
}
2091+
2092+
template <typename T>
2093+
void ldmatrix(uintptr_t addr, T *m1, T *m2,
2094+
const sycl::nd_item<3> &item_ct1, bool trans = false) {
2095+
ldmatrix(addr, m1, item_ct1, trans, 0);
2096+
ldmatrix(addr, m2, item_ct1, trans, 1);
2097+
}
2098+
2099+
template <typename T>
2100+
void ldmatrix(uintptr_t addr, T *m1, T *m2,
2101+
T *m3, T *m4,
2102+
const sycl::nd_item<3> &item_ct1, bool trans = false) {
2103+
ldmatrix(addr, m1, item_ct1, trans, 0);
2104+
ldmatrix(addr, m2, item_ct1, trans, 1);
2105+
ldmatrix(addr, m3, item_ct1, trans, 2);
2106+
ldmatrix(addr, m4, item_ct1, trans, 3);
2107+
}
2108+
2109+
20502110
} // namespace matrix
20512111
} // namespace experimental
20522112

clang/test/dpct/asm/ldmatrix.cu

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2
2+
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2
3+
// RUN: dpct --format-range=none -out-root %T/ldmatrix %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
4+
// RUN: FileCheck %s --match-full-lines --input-file %T/ldmatrix/ldmatrix.dp.cpp
5+
// RUN: %if build_lit %{icpx -c -fsycl %T/ldmatrix/ldmatrix.dp.cpp -o %T/ldmatrix/ldmatrix.dp.o %}
6+
7+
// clang-format off
8+
#include <cuda_runtime.h>
9+
10+
/*
11+
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];
12+
13+
Below are the currenly supported configurations:
14+
.shape = {.m8n8};
15+
.num = {.x1, .x2, .x4};
16+
.ss = {.shared{::cta}};
17+
.type = {.b16};
18+
*/
19+
20+
__device__ void load_matrix_x1(void *sh_r_addr, int *r) {
21+
// CHECK: auto addr = sh_r_addr;
22+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
23+
24+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], item_ct1);
25+
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
26+
: "=r"(r[0])
27+
: "r"(addr));
28+
}
29+
30+
__device__ void load_matrix_x2(void *sh_r_addr, int *r) {
31+
// CHECK: auto addr = sh_r_addr;
32+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
33+
34+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], item_ct1);
35+
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
36+
: "=r"(r[0]), "=r"(r[1])
37+
: "r"(addr));
38+
}
39+
40+
__device__ void load_matrix_x4(void *sh_r_addr, int *r) {
41+
// CHECK: auto addr = sh_r_addr;
42+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
43+
44+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], &r[2], &r[3], item_ct1);
45+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
46+
: "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
47+
: "r"(addr));
48+
}
49+
50+
__device__ void load_matrix_x1_trans(void *sh_r_addr, int *r) {
51+
// CHECK: auto addr = sh_r_addr;
52+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
53+
54+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], item_ct1, true);
55+
asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
56+
: "=r"(r[0])
57+
: "r"(addr));
58+
}
59+
60+
__device__ void load_matrix_x2_trans(void *sh_r_addr, int *r) {
61+
// CHECK: auto addr = sh_r_addr;
62+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
63+
64+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], item_ct1, true);
65+
asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
66+
: "=r"(r[0]), "=r"(r[1])
67+
: "r"(addr));
68+
}
69+
70+
__device__ void load_matrix_x4_trans(void *sh_r_addr, int *r) {
71+
// CHECK: auto addr = sh_r_addr;
72+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
73+
74+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], &r[2], &r[3], item_ct1, true);
75+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
76+
: "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
77+
: "r"(addr));
78+
}
79+
80+
// clang-format on
Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1 @@
1-
add_library(target
2-
foo.cpp
3-
layer.cudnn.cpp
4-
foo.h
5-
)
6-
7-
add_library(bar bar.cpp bar.h)
8-
9-
target_compile_features(${TARGET} PUBLIC cxx_std_14)
10-
set(CMAKE_CXX_STANDARD 14)
11-
target_compile_features(culib PRIVATE cxx_std_14)
12-
set_target_properties(target_one PROPERTIES CXX_STANDARD 17)
13-
add_compile_options(-std=c++17)
14-
15-
add_library(chash OBJECT deps/chash/chash.c deps/chash/chash.h)
16-
add_library(cchash OBJECT deps/cchash/cchash.cc deps/cchash/cchash.h)
17-
add_library(cxxhash OBJECT deps/cxxhash/cxxhash.cxx deps/cxxhash/cxxhash.h)
18-
add_library(cpphash OBJECT deps/cpphash/cpphash.cpp deps/cpphash/cpphash.h)
19-
add_library(chash OBJECT deps/chash/foo.c deps/chash/chash.h)
1+
add_library(target foo.cpp)

0 commit comments

Comments
 (0)