Skip to content

Commit b351040

Browse files
Added init rules infra
1 parent ebff19a commit b351040

12 files changed

Lines changed: 232 additions & 62 deletions

clang/lib/DPCT/ASTTraversal.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "RulesSHMEM/NVSHMEMAPIMigration.h"
2828
#include "RulesSecurity/Homoglyph.h"
2929
#include "RulesSecurity/MisleadingBidirectional.h"
30+
#include "RulesTensor/CUTENSORAPIMigration.h"
3031
#include "TextModification.h"
3132
#include "Utility.h"
3233

@@ -197,5 +198,7 @@ REGISTER_RULE(CuDNNAPIRule, PassKind::PK_Migration, RuleGroupKind::RK_DNN)
197198

198199
REGISTER_RULE(NVSHMEMRule, PassKind::PK_Migration, RuleGroupKind::RK_NVSHMEM)
199200

201+
REGISTER_RULE(CUTENSORRule, PassKind::PK_Migration, RuleGroupKind::RK_CUTENSOR)
202+
200203
} // namespace dpct
201204
} // namespace clang

clang/lib/DPCT/CMakeLists.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,7 @@ add_clang_library(DPCT
203203
RulesLang/CallExprRewriterCG.cpp
204204
RulesLang/CallExprRewriterWmma.cpp
205205
RulesSHMEM/CallExprRewriterNvshmem.cpp
206+
RulesTensor/CallExprRewriterCUTENSOR.cpp
206207
ErrorHandle/CrashRecovery.cpp
207208
Diagnostics/Diagnostics.cpp
208209
ErrorHandle/Error.cpp
@@ -242,6 +243,7 @@ add_clang_library(DPCT
242243
RulesCCL/NCCLAPIMigration.cpp
243244
RuleInfra/TypeLocRewriters.cpp
244245
RulesSHMEM/NVSHMEMAPIMigration.cpp
246+
RulesTensor/CUTENSORAPIMigration.cpp
245247
Linux/AutoComplete.cpp
246248
RulesAsm/AsmMigration.cpp
247249
QueryAPIMapping/QueryAPIMapping.cpp

clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,15 @@ std::optional<std::string> FuncCallExprRewriter::buildRewriteString() {
134134

135135
std::unique_ptr<std::unordered_map<
136136
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>
137-
CallExprRewriterFactoryBase::RewriterMap = std::make_unique<std::unordered_map<
138-
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
137+
CallExprRewriterFactoryBase::RewriterMap =
138+
std::make_unique<std::unordered_map<
139+
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
139140

140141
std::unique_ptr<std::unordered_map<
141142
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>
142-
CallExprRewriterFactoryBase::MethodRewriterMap = std::make_unique<std::unordered_map<
143-
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
143+
CallExprRewriterFactoryBase::MethodRewriterMap =
144+
std::make_unique<std::unordered_map<
145+
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
144146

145147
void CallExprRewriterFactoryBase::initRewriterMap() {
146148
if (DpctGlobalInfo::useSYCLCompat()) {
@@ -162,6 +164,7 @@ void CallExprRewriterFactoryBase::initRewriterMap() {
162164
initRewriterMapMisc();
163165
initRewriterMapNccl();
164166
initRewriterMapNvshmem();
167+
initRewriterMapCUTENSOR();
165168
initRewriterMapStream();
166169
initRewriterMapTexture();
167170
initRewriterMapThrust();

clang/lib/DPCT/RuleInfra/CallExprRewriter.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,7 @@ class CallExprRewriterFactoryBase {
7070
static void initRewriterMapMisc();
7171
static void initRewriterMapNccl();
7272
static void initRewriterMapNvshmem();
73+
static void initRewriterMapCUTENSOR();
7374
static void initRewriterMapStream();
7475
static void initRewriterMapTexture();
7576
static void initRewriterMapThrust();

clang/lib/DPCT/RulesInclude/InclusionHeaders.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ enum class RuleGroupKind : uint8_t {
3636
RK_CUB,
3737
RK_WMMA,
3838
RK_NVSHMEM,
39+
RK_CUTENSOR,
3940
NUM
4041
};
4142

clang/lib/DPCT/RulesInclude/InclusionHeaders.inc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,6 @@ REGIST_INCLUSION("nvshmem.h", FullMatch, NVSHMEM, Replace, false,
112112
HeaderType::HT_SHMEM)
113113
REGIST_INCLUSION("nvshmemx.h", FullMatch, NVSHMEM, Replace, false,
114114
HeaderType::HT_SHMEMX)
115+
116+
REGIST_INCLUSION("cutensor.h", FullMatch, CUTENSOR, Remove, true)
117+
REGIST_INCLUSION("cutensorMg.h", FullMatch, CUTENSOR, Remove, true)
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
//===------------------------ APINamesCUTENSOR.inc ------------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
// Helper Functions
10+
ENTRY_UNSUPPORTED("cutensorCreate", Diagnostics::API_NOT_MIGRATED)
Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
//===---------------------- CUTENSORAPIMigration.cpp ----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===-----------------------------------------------------------------------===//
8+
9+
#include "CUTENSORAPIMigration.h"
10+
#include "RuleInfra/ExprAnalysis.h"
11+
12+
using namespace clang::dpct;
13+
using namespace clang::ast_matchers;
14+
15+
void clang::dpct::CUTENSORRule::registerMatcher(ast_matchers::MatchFinder &MF) {
16+
auto CutensorAPIs = [&]() {
17+
return hasAnyName(
18+
// Helper Functions
19+
"cutensorCreate", "cutensorDestroy", "cutensorCreateTensorDescriptor",
20+
"cutensorDestroyTensorDescriptor", "cutensorGetErrorString",
21+
"cutensorGetVersion", "cutensorGetCudartVersion",
22+
// Element-wise Operations
23+
"cutensorCreateElementwiseTrinary", "cutensorElementwiseTrinaryExecute",
24+
"cutensorCreateElementwiseBinary", "cutensorElementwiseBinaryExecute",
25+
"cutensorCreatePermutation", "cutensorPermute",
26+
// Contraction Operations
27+
"cutensorCreateContraction", "cutensorContract",
28+
"cutensorCreateContractionTrinary", "cutensorContractTrinary",
29+
// Reduction Operations
30+
"cutensorCreateReduction", "cutensorReduce",
31+
// Generic Operation Functions
32+
"cutensorDestroyOperationDescriptor",
33+
"cutensorOperationDescriptorGetAttribute",
34+
"cutensorOperationDescriptorSetAttribute",
35+
"cutensorCreatePlanPreference", "cutensorDestroyPlanPreference",
36+
"cutensorPlanPreferenceSetAttribute", "cutensorEstimateWorkspaceSize",
37+
"cutensorCreatePlan", "cutensorDestroyPlan", "cutensorPlanGetAttribute",
38+
// Cache-related Operations
39+
"cutensorHandleResizePlanCache", "cutensorHandleReadPlanCacheFromFile",
40+
"cutensorHandleWritePlanCacheToFile", "cutensorReadKernelCacheFromFile",
41+
"cutensorWriteKernelCacheToFile",
42+
// Logger Functions
43+
"cutensorLoggerSetCallback", "cutensorLoggerSetFile",
44+
"cutensorLoggerOpenFile", "cutensorLoggerSetLevel",
45+
"cutensorLoggerSetMask", "cutensorLoggerForceDisable",
46+
// cuTENSORMg - General Operations
47+
"cutensorMgCreate", "cutensorMgDestroy",
48+
"cutensorMgCreateTensorDescriptor", "cutensorMgDestroyTensorDescriptor",
49+
"cutensorMgCreateCopyDescriptor", "cutensorMgDestroyCopyDescriptor",
50+
"cutensorMgCopyGetWorkspace", "cutensorMgCreateCopyPlan",
51+
"cutensorMgDestroyCopyPlan", "cutensorMgCopy",
52+
// cuTENSORMg - Contraction Operations
53+
"cutensorMgCreateContractionDescriptor",
54+
"cutensorMgDestroyContractionDescriptor",
55+
"cutensorMgCreateContractionFind", "cutensorMgDestroyContractionFind",
56+
"cutensorMgContractionGetWorkspace", "cutensorMgCreateContractionPlan",
57+
"cutensorMgDestroyContractionPlan", "cutensorMgContraction");
58+
};
59+
60+
llvm::outs() << "[DEBUG] Inside regMatcher\n";
61+
62+
MF.addMatcher(callExpr(callee(functionDecl(CutensorAPIs()))).bind("call"),
63+
this);
64+
}
65+
66+
void clang::dpct::CUTENSORRule::runRule(
67+
const ast_matchers::MatchFinder::MatchResult &Result) {
68+
llvm::outs() << "[DEBUG] Inside runRule\n";
69+
if (const CallExpr *CE = getNodeAsType<CallExpr>(Result, "call")) {
70+
std::string FuncName = "";
71+
const FunctionDecl *FD = CE->getDirectCallee();
72+
if (FD) {
73+
FuncName = FD->getNameInfo().getName().getAsString();
74+
}
75+
76+
report(CE->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false, FuncName);
77+
}
78+
79+
return;
80+
}
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
//===----------------------- CUTENSORAPIMigration.h -----------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#ifndef CUTENSOR_API_MIGRATION_H
10+
#define CUTENSOR_API_MIGRATION_H
11+
12+
#include "ASTTraversal.h"
13+
14+
using namespace clang::ast_matchers;
15+
16+
namespace clang {
17+
namespace dpct {
18+
19+
class CUTENSORRule : public NamedMigrationRule<CUTENSORRule> {
20+
public:
21+
void registerMatcher(ast_matchers::MatchFinder &MF) override;
22+
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
23+
};
24+
25+
} // namespace dpct
26+
} // namespace clang
27+
28+
#endif // CUTENSOR_API_MIGRATION_H
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//===-------------------- CallExprRewriterCUTENSOR.cpp --------------------===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "RuleInfra/CallExprRewriter.h"
10+
#include "RuleInfra/CallExprRewriterCommon.h"
11+
12+
namespace clang {
13+
namespace dpct {
14+
15+
#define REWRITER_FACTORY_ENTRY(FuncName, RewriterFactory, ...) \
16+
{FuncName, std::make_shared<RewriterFactory>(FuncName, __VA_ARGS__)},
17+
#define UNSUPPORTED_FACTORY_ENTRY(FuncName, MsgID) \
18+
REWRITER_FACTORY_ENTRY(FuncName, \
19+
UnsupportFunctionRewriterFactory<std::string>, MsgID, \
20+
FuncName)
21+
#define ENTRY_UNSUPPORTED(SOURCEAPINAME, MSGID) \
22+
UNSUPPORTED_FACTORY_ENTRY(SOURCEAPINAME, MSGID)
23+
24+
void CallExprRewriterFactoryBase::initRewriterMapCUTENSOR() {
25+
RewriterMap->merge(
26+
std::unordered_map<std::string,
27+
std::shared_ptr<CallExprRewriterFactoryBase>>({
28+
#include "APINamesCUTENSOR.inc"
29+
}));
30+
}
31+
32+
#undef ENTRY_UNSUPPORTED
33+
#undef UNSUPPORTED_FACTORY_ENTRY
34+
#undef REWRITER_FACTORY_ENTRY
35+
36+
} // namespace dpct
37+
} // namespace clang

0 commit comments

Comments
 (0)