Skip to content

Commit 9dad6eb

Browse files
Added support for ldmatrix migration
1 parent e9b108d commit 9dad6eb

9 files changed

Lines changed: 257 additions & 21 deletions

File tree

clang/lib/DPCT/RulesAsm/AsmMigration.cpp

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -557,12 +557,15 @@ bool SYCLGenBase::emitVectorType(const InlineAsmVectorType *T) {
557557
OS() << ", ";
558558
switch (T->getKind()) {
559559
case InlineAsmVectorType::v2:
560+
case InlineAsmVectorType::x1:
560561
OS() << 2;
561562
break;
562563
case InlineAsmVectorType::v4:
564+
case InlineAsmVectorType::x2:
563565
OS() << 4;
564566
break;
565567
case InlineAsmVectorType::v8:
568+
case InlineAsmVectorType::x4:
566569
OS() << 8;
567570
break;
568571
}
@@ -589,9 +592,9 @@ bool SYCLGenBase::emitVariableDeclaration(const InlineAsmVarDecl *D) {
589592

590593
bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
591594
// Address expressions are only supported for st/ld/atom/prefetch/red/cp/ldmatrix instructions.
592-
if (!CurrInst ||
593-
!CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
594-
asmtok::op_prefetch, asmtok::op_red, asmtok::op_cp)) {
595+
if (!CurrInst || !CurrInst->is(asmtok::op_st, asmtok::op_ld, asmtok::op_atom,
596+
asmtok::op_prefetch, asmtok::op_red,
597+
asmtok::op_cp, asmtok::op_ldmatrix)) {
595598
return SYCLGenError();
596599
}
597600
std::string Type;
@@ -624,6 +627,8 @@ bool SYCLGenBase::emitAddressExpr(const InlineAsmAddressExpr *Dst) {
624627
if (CurrInst->is(asmtok::op_prefetch, asmtok::op_red) ||
625628
CanSuppressCast(Dst->getSymbol()))
626629
OS() << llvm::formatv("{0}", Reg);
630+
else if (CurrInst->is(asmtok::op_ldmatrix))
631+
OS() << llvm::formatv("(uintptr_t){0}", Reg);
627632
else
628633
OS() << llvm::formatv("(({0} *)(uintptr_t){1})", Type, Reg);
629634
break;
@@ -1305,6 +1310,46 @@ class SYCLGen : public SYCLGenBase {
13051310
return SYCLGenSuccess();
13061311
}
13071312

1313+
bool handle_ldmatrix(const InlineAsmInstruction *Inst) override {
1314+
if (Inst->getNumInputOperands() != 1)
1315+
return SYCLGenError();
1316+
1317+
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
1318+
CurrInst = Inst;
1319+
const auto *Src =
1320+
dyn_cast_or_null<InlineAsmAddressExpr>(Inst->getInputOperand(0));
1321+
if (!Src)
1322+
return false;
1323+
1324+
OS() << MapNames::getDpctNamespace() << "experimental::matrix::ldmatrix(";
1325+
if (emitStmt(Src)) {
1326+
return SYCLGenError();
1327+
}
1328+
OS() << ", ";
1329+
const auto *VE = dyn_cast<InlineAsmVectorExpr>(Inst->getOutputOperand());
1330+
for (unsigned Inst = 0, E = VE->getNumElements(); Inst != E; ++Inst) {
1331+
if (isa<InlineAsmDiscardExpr>(VE->getElement(Inst)))
1332+
continue;
1333+
OS() << "&";
1334+
if (emitStmt(VE->getElement(Inst)))
1335+
return SYCLGenError();
1336+
OS() << ", ";
1337+
}
1338+
OS() << DpctGlobalInfo::getItem(GAS);
1339+
if (Inst->hasAttr(InstAttr::trans))
1340+
OS() << ", true";
1341+
OS() << ");";
1342+
const auto *KernelDecl = getImmediateOuterFuncDecl(GAS);
1343+
if (KernelDecl) {
1344+
auto FuncInfo = DeviceFunctionDecl::LinkRedecls(KernelDecl);
1345+
if (FuncInfo)
1346+
FuncInfo->addSubGroupSizeRequest(32, GAS->getBeginLoc(),
1347+
DpctGlobalInfo::getSubGroup(GAS));
1348+
}
1349+
1350+
return SYCLGenSuccess();
1351+
}
1352+
13081353
bool handle_prefetch(const InlineAsmInstruction *Inst) override {
13091354
if (!DpctGlobalInfo::useExtPrefetch() || Inst->getNumInputOperands() != 1)
13101355
return SYCLGenError();
@@ -2881,6 +2926,7 @@ class SYCLGen : public SYCLGenBase {
28812926
bool handle_ld(const InlineAsmInstruction *Inst) override {
28822927
if (Inst->getNumInputOperands() != 1)
28832928
return SYCLGenError();
2929+
28842930
llvm::SaveAndRestore<const InlineAsmInstruction *> Store(CurrInst);
28852931
CurrInst = Inst;
28862932
const auto *Src =

clang/lib/DPCT/RulesAsm/Parser/AsmNodes.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -116,7 +116,7 @@ class InlineAsmBuiltinType : public InlineAsmType {
116116
// This class is used for device asm vector types.
117117
class InlineAsmVectorType : public InlineAsmType {
118118
public:
119-
enum VecKind { v2, v4, v8 };
119+
// Kinds of vector operand. v2/v4/v8 are chosen by the parser for 2-, 4- and
// 8-element vector expressions; x1/x2/x4 are chosen for 1-, 2- and 4-element
// vector expressions when the current opcode is ldmatrix.
enum VecKind { v2, v4, v8, x1, x2, x4 };
120120

121121
private:
122122
VecKind Kind;
@@ -340,6 +340,8 @@ class InlineAsmInstruction : public InlineAsmStmt {
340340
/// the rest are input operands.
341341
SmallVector<InlineAsmExpr *, 4> InputOps;
342342

343+
SmallVector<InlineAsmExpr *, 4> OutputOps;
344+
343345
public:
344346
InlineAsmInstruction(InlineAsmIdentifierInfo *Op,
345347
SmallVector<AsmStateSpace, 4> AsmStateSpaces,

clang/lib/DPCT/RulesAsm/Parser/AsmParser.cpp

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ InlineAsmStmtResult InlineAsmParser::ParseInstruction() {
327327
if (!Tok.getIdentifier() || !Tok.getIdentifier()->isInstruction())
328328
return AsmStmtError();
329329

330-
InlineAsmIdentifierInfo *Opcode = Tok.getIdentifier();
330+
Opcode = Tok.getIdentifier();
331331
ConsumeToken();
332332

333333
SmallVector<InstAttr, 4> Attrs;
@@ -736,20 +736,38 @@ InlineAsmExprResult InlineAsmParser::ActOnParenExpr(InlineAsmExpr *SubExpr) {
736736
InlineAsmExprResult
737737
InlineAsmParser::ActOnVectorExpr(ArrayRef<InlineAsmExpr *> Vec) {
738738

739-
// Vector size must be 2, 4, or 8.
739+
// Vector sizes allowed for ldmatrix are 1, 2, or 4.
740+
// size(x) = 2 * sizeof(v).
740741
InlineAsmVectorType::VecKind Kind;
741-
switch (Vec.size()) {
742-
case 2:
743-
Kind = InlineAsmVectorType::v2;
744-
break;
745-
case 4:
746-
Kind = InlineAsmVectorType::v4;
747-
break;
748-
case 8:
749-
Kind = InlineAsmVectorType::v8;
750-
break;
751-
default:
752-
return AsmExprError();
742+
if (Opcode->getTokenID() == asmtok::op_ldmatrix) {
743+
switch (Vec.size()) {
744+
case 1:
745+
Kind = InlineAsmVectorType::x1;
746+
break;
747+
case 2:
748+
Kind = InlineAsmVectorType::x2;
749+
break;
750+
case 4:
751+
Kind = InlineAsmVectorType::x4;
752+
break;
753+
default:
754+
return AsmExprError();
755+
}
756+
} else {
757+
// Vector size must be 2, 4, or 8.
758+
switch (Vec.size()) {
759+
case 2:
760+
Kind = InlineAsmVectorType::v2;
761+
break;
762+
case 4:
763+
Kind = InlineAsmVectorType::v4;
764+
break;
765+
case 8:
766+
Kind = InlineAsmVectorType::v8;
767+
break;
768+
default:
769+
return AsmExprError();
770+
}
753771
}
754772

755773
InlineAsmBuiltinType *ElementType = nullptr;

clang/lib/DPCT/RulesAsm/Parser/AsmParser.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,8 @@ class InlineAsmParser {
247247
};
248248

249249
public:
250+
// Identifier of the instruction opcode most recently seen by
// ParseInstruction(); ActOnVectorExpr() reads it to pick the vector kind.
// Default-initialized to nullptr because the constructor's initializer list
// does not set it, which would otherwise leave the pointer indeterminate.
InlineAsmIdentifierInfo *Opcode = nullptr;
251+
250252
InlineAsmParser(InlineAsmContext &Ctx, SourceMgr &Mgr)
251253
: Lexer(*Mgr.getMemoryBuffer(Mgr.getMainFileID())), Context(Ctx),
252254
SrcMgr(Mgr), CurScope(nullptr) {

clang/lib/DPCT/RulesAsm/Parser/AsmTokenKinds.def

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,14 @@ MODIFIER(v2, ".v2")
274274
MODIFIER(v4, ".v4")
275275
MODIFIER(v8, ".v8")
276276

277+
// Matrix modifiers
278+
MODIFIER(x1, ".x1")
279+
MODIFIER(x2, ".x2")
280+
MODIFIER(x4, ".x4")
281+
282+
// Matrix shape
283+
MODIFIER(m8n8, ".m8n8")
284+
277285
STATE_SPACE(reg, ".reg")
278286
STATE_SPACE(sreg, ".sreg")
279287
STATE_SPACE(const, ".const")
@@ -420,6 +428,8 @@ MODIFIER(ecr, ".ecr")
420428
MODIFIER(rc16, ".rc16")
421429
MODIFIER(cs, ".cs")
422430
MODIFIER(to, ".to")
431+
MODIFIER(aligned, ".aligned")
432+
MODIFIER(trans, ".trans")
423433

424434
#undef LINKAGE
425435
#undef TARGET

clang/lib/DPCT/SrcAPI/APINames_ASM.inc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ ENTRY("griddepcontrol", "griddepcontrol", false, NO_FLAG, P1, "Comment")
7575
ENTRY("isspacep", "isspacep", false, NO_FLAG, P1, "Comment")
7676
ENTRY("istypep", "istypep", false, NO_FLAG, P1, "Comment")
7777
ENTRY("ld", "ld", true, NO_FLAG, P1, "Partial")
78-
ENTRY("ldmatrix", "ldmatrix", false, NO_FLAG, P1, "Comment")
78+
ENTRY("ldmatrix", "ldmatrix", true, NO_FLAG, P1, "Partial")
7979
ENTRY("ldu", "ldu", false, NO_FLAG, P1, "Comment")
8080
ENTRY("lg2", "lg2", true, NO_FLAG, P1, "Successful")
8181
ENTRY("lop3", "lop3", true, NO_FLAG, P1, "Successful")

clang/runtime/dpct-rt/include/dpct/math.hpp

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,8 @@
99
#ifndef __DPCT_MATH_HPP__
1010
#define __DPCT_MATH_HPP__
1111

12-
#include <limits>
1312
#include <climits>
13+
#include <limits>
1414
#include <sycl/sycl.hpp>
1515
#include <type_traits>
1616

@@ -2055,6 +2055,64 @@ class joint_matrix {
20552055
matrix_accessor x;
20562056
const size_t num_elements;
20572057
};
2058+
2059+
template <typename T>
2060+
void ldmatrix(uintptr_t addr, T *m, const sycl::nd_item<3> &item_ct1,
2061+
bool trans = false, unsigned mat = 0) {
2062+
int lane = item_ct1.get_local_id(2);
2063+
2064+
int group = lane / 8;
2065+
int sub = lane % 8;
2066+
int src_base = group * 2;
2067+
2068+
if (!trans) {
2069+
// calculate the source lane
2070+
int src_lane = (sub / 4) ? (src_base + 1) : src_base;
2071+
2072+
// Broadcast the address from the source lane
2073+
auto recv_addr_uintp = dpct::select_from_sub_group(
2074+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2075+
auto recv_addr = reinterpret_cast<T *>(recv_addr_uintp);
2076+
2077+
// Non-transposed load
2078+
*m = recv_addr[sub % 4];
2079+
} else {
2080+
// calculate the source lane
2081+
int src_lane = (lane % 4) * 2;
2082+
2083+
// Broadcast the address from the source lane:
2084+
auto recv_addr_uintp_1 = dpct::select_from_sub_group(
2085+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane);
2086+
auto recv_addr_uintp_2 = dpct::select_from_sub_group(
2087+
item_ct1.get_sub_group(), addr, mat * 8 + src_lane + 1);
2088+
auto recv_addr_1 = reinterpret_cast<sycl::half *>(recv_addr_uintp_1);
2089+
auto recv_addr_2 = reinterpret_cast<sycl::half *>(recv_addr_uintp_2);
2090+
2091+
// Transposed load
2092+
int index = (lane / 4);
2093+
sycl::half val0 = recv_addr_1[index];
2094+
sycl::half val1 = recv_addr_2[index];
2095+
sycl::half2 val = sycl::half2(val0, val1);
2096+
*m = *reinterpret_cast<T *>(&val);
2097+
}
2098+
}
2099+
2100+
template <typename T>
2101+
void ldmatrix(uintptr_t addr, T *m1, T *m2, const sycl::nd_item<3> &item_ct1,
2102+
bool trans = false) {
2103+
ldmatrix(addr, m1, item_ct1, trans, 0);
2104+
ldmatrix(addr, m2, item_ct1, trans, 1);
2105+
}
2106+
2107+
template <typename T>
2108+
void ldmatrix(uintptr_t addr, T *m1, T *m2, T *m3, T *m4,
2109+
const sycl::nd_item<3> &item_ct1, bool trans = false) {
2110+
ldmatrix(addr, m1, item_ct1, trans, 0);
2111+
ldmatrix(addr, m2, item_ct1, trans, 1);
2112+
ldmatrix(addr, m3, item_ct1, trans, 2);
2113+
ldmatrix(addr, m4, item_ct1, trans, 3);
2114+
}
2115+
20582116
} // namespace matrix
20592117
} // namespace experimental
20602118

clang/test/dpct/asm/ldmatrix.cu

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
// UNSUPPORTED: cuda-8.0, cuda-9.0, cuda-9.1, cuda-9.2, cuda-10.0, cuda-10.1, cuda-10.2
2+
// UNSUPPORTED: v8.0, v9.0, v9.1, v9.2, v10.0, v10.1, v10.2
3+
// RUN: dpct --format-range=none -out-root %T/ldmatrix %s --cuda-include-path="%cuda-path/include" -- -std=c++14 -x cuda --cuda-host-only
4+
// RUN: FileCheck %s --match-full-lines --input-file %T/ldmatrix/ldmatrix.dp.cpp
5+
// RUN: %if build_lit %{icpx -c -fsycl %T/ldmatrix/ldmatrix.dp.cpp -o %T/ldmatrix/ldmatrix.dp.o %}
6+
7+
// clang-format off
8+
#include <cuda_runtime.h>
9+
#include <cuda_fp16.h>
10+
11+
/*
12+
ldmatrix.sync.aligned.shape.num{.trans}{.ss}.type r, [p];
13+
14+
Below are the currently supported configurations:
15+
.shape = {.m8n8};
16+
.num = {.x1, .x2, .x4};
17+
.ss = {.shared{::cta}};
18+
.type = {.b16};
19+
*/
20+
21+
// Migration case for the .x1 form without .trans: one output register. The
// FileCheck directives below give the expected migrated text.
__device__ void load_matrix_x1(void *sh_r_addr, int *r) {
22+
// CHECK: auto addr = sh_r_addr;
23+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
24+
25+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], item_ct1);
26+
asm volatile("ldmatrix.sync.aligned.m8n8.x1.shared.b16 {%0}, [%1];\n"
27+
: "=r"(r[0])
28+
: "r"(addr));
29+
}
30+
31+
// Migration case for the .x2 form without .trans: two output registers. The
// FileCheck directives below give the expected migrated text.
__device__ void load_matrix_x2(void *sh_r_addr, int *r) {
32+
// CHECK: auto addr = sh_r_addr;
33+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
34+
35+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], item_ct1);
36+
asm volatile("ldmatrix.sync.aligned.m8n8.x2.shared.b16 {%0, %1}, [%2];\n"
37+
: "=r"(r[0]), "=r"(r[1])
38+
: "r"(addr));
39+
}
40+
41+
// Migration case for the .x4 form without .trans: four output registers. The
// FileCheck directives below give the expected migrated text.
__device__ void load_matrix_x4(void *sh_r_addr, int *r) {
42+
// CHECK: auto addr = sh_r_addr;
43+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
44+
45+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], &r[2], &r[3], item_ct1);
46+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.shared.b16 {%0, %1, %2, %3}, [%4];\n"
47+
: "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
48+
: "r"(addr));
49+
}
50+
51+
// Migration case for the .x1 form with .trans: one output register, and the
// expected call carries a trailing `true` argument.
__device__ void load_matrix_x1_trans(void *sh_r_addr, int *r) {
52+
// CHECK: auto addr = sh_r_addr;
53+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
54+
55+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], item_ct1, true);
56+
asm volatile("ldmatrix.sync.aligned.m8n8.x1.trans.shared.b16 {%0}, [%1];\n"
57+
: "=r"(r[0])
58+
: "r"(addr));
59+
}
60+
61+
// Migration case for the .x2 form with .trans: two output registers, and the
// expected call carries a trailing `true` argument.
__device__ void load_matrix_x2_trans(void *sh_r_addr, int *r) {
62+
// CHECK: auto addr = sh_r_addr;
63+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
64+
65+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], item_ct1, true);
66+
asm volatile("ldmatrix.sync.aligned.m8n8.x2.trans.shared.b16 {%0, %1}, [%2];\n"
67+
: "=r"(r[0]), "=r"(r[1])
68+
: "r"(addr));
69+
}
70+
71+
// Migration case for the .x4 form with .trans: four output registers, and the
// expected call carries a trailing `true` argument.
__device__ void load_matrix_x4_trans(void *sh_r_addr, int *r) {
72+
// CHECK: auto addr = sh_r_addr;
73+
uint32_t addr = static_cast<uint32_t>(__cvta_generic_to_shared(sh_r_addr));
74+
75+
// CHECK: dpct::experimental::matrix::ldmatrix((uintptr_t)addr, &r[0], &r[1], &r[2], &r[3], item_ct1, true);
76+
asm volatile("ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16 {%0, %1, %2, %3}, [%4];\n"
77+
: "=r"(r[0]), "=r"(r[1]), "=r"(r[2]), "=r"(r[3])
78+
: "r"(addr));
79+
}
80+
81+
// Kernel that calls every load_matrix_* helper once, passing the same shared
// buffer and the same four-element result array to each of them.
__global__ void load_kernel() {
82+
__shared__ half s_data[1024];
83+
int r[4];
84+
85+
load_matrix_x1(s_data, r);
86+
load_matrix_x2(s_data, r);
87+
load_matrix_x4(s_data, r);
88+
load_matrix_x1_trans(s_data, r);
89+
load_matrix_x2_trans(s_data, r);
90+
load_matrix_x4_trans(s_data, r);
91+
}
92+
93+
// Host entry point: launches load_kernel with one block of 32 threads. The
// FileCheck directive expects the migrated kernel lambda to be annotated with
// a required sub-group size of 32.
int main () {
94+
// CHECK: [=](sycl::nd_item<3> item_ct1) {{\[\[}}sycl::reqd_sub_group_size(32){{\]\]}} {
95+
load_kernel<<<1, 32>>>();
96+
97+
return 0;
98+
}
99+
100+
// clang-format on

docs/dev_guide/api-mapping-status/ASM_API_migration_status.csv

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ griddepcontrol,NO,
4141
isspacep,NO,
4242
istypep,NO,
4343
ld,YES, Partial
44-
ldmatrix,NO,
44+
ldmatrix,YES,Partial
4545
ldu,NO,
4646
lg2,YES,
4747
lop3,YES,

0 commit comments

Comments
 (0)