Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions clang/lib/DPCT/ASTTraversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include "RulesSHMEM/NVSHMEMAPIMigration.h"
#include "RulesSecurity/Homoglyph.h"
#include "RulesSecurity/MisleadingBidirectional.h"
#include "RulesTensor/CUTensorAPIMigration.h"
#include "TextModification.h"
#include "Utility.h"

Expand Down Expand Up @@ -197,5 +198,7 @@ REGISTER_RULE(CuDNNAPIRule, PassKind::PK_Migration, RuleGroupKind::RK_DNN)

REGISTER_RULE(NVSHMEMRule, PassKind::PK_Migration, RuleGroupKind::RK_NVSHMEM)

REGISTER_RULE(CUTensorRule, PassKind::PK_Migration, RuleGroupKind::RK_CUTensor)

} // namespace dpct
} // namespace clang
2 changes: 2 additions & 0 deletions clang/lib/DPCT/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ add_clang_library(DPCT
RulesLang/CallExprRewriterCG.cpp
RulesLang/CallExprRewriterWmma.cpp
RulesSHMEM/CallExprRewriterNvshmem.cpp
RulesTensor/CallExprRewriterCUTensor.cpp
ErrorHandle/CrashRecovery.cpp
Diagnostics/Diagnostics.cpp
ErrorHandle/Error.cpp
Expand Down Expand Up @@ -242,6 +243,7 @@ add_clang_library(DPCT
RulesCCL/NCCLAPIMigration.cpp
RuleInfra/TypeLocRewriters.cpp
RulesSHMEM/NVSHMEMAPIMigration.cpp
RulesTensor/CUTensorAPIMigration.cpp
Linux/AutoComplete.cpp
RulesAsm/AsmMigration.cpp
QueryAPIMapping/QueryAPIMapping.cpp
Expand Down
11 changes: 7 additions & 4 deletions clang/lib/DPCT/RuleInfra/CallExprRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,15 @@ std::optional<std::string> FuncCallExprRewriter::buildRewriteString() {

std::unique_ptr<std::unordered_map<
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>
CallExprRewriterFactoryBase::RewriterMap = std::make_unique<std::unordered_map<
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
CallExprRewriterFactoryBase::RewriterMap =
std::make_unique<std::unordered_map<
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();

std::unique_ptr<std::unordered_map<
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>
CallExprRewriterFactoryBase::MethodRewriterMap = std::make_unique<std::unordered_map<
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();
CallExprRewriterFactoryBase::MethodRewriterMap =
std::make_unique<std::unordered_map<
std::string, std::shared_ptr<CallExprRewriterFactoryBase>>>();

void CallExprRewriterFactoryBase::initRewriterMap() {
if (DpctGlobalInfo::useSYCLCompat()) {
Expand All @@ -162,6 +164,7 @@ void CallExprRewriterFactoryBase::initRewriterMap() {
initRewriterMapMisc();
initRewriterMapNccl();
initRewriterMapNvshmem();
initRewriterMapCUTensor();
initRewriterMapStream();
initRewriterMapTexture();
initRewriterMapThrust();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/RuleInfra/CallExprRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class CallExprRewriterFactoryBase {
static void initRewriterMapMisc();
static void initRewriterMapNccl();
static void initRewriterMapNvshmem();
static void initRewriterMapCUTensor();
static void initRewriterMapStream();
static void initRewriterMapTexture();
static void initRewriterMapThrust();
Expand Down
1 change: 1 addition & 0 deletions clang/lib/DPCT/RulesInclude/InclusionHeaders.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ enum class RuleGroupKind : uint8_t {
RK_CUB,
RK_WMMA,
RK_NVSHMEM,
RK_CUTensor,
NUM
};

Expand Down
3 changes: 3 additions & 0 deletions clang/lib/DPCT/RulesInclude/InclusionHeaders.inc
Original file line number Diff line number Diff line change
Expand Up @@ -112,3 +112,6 @@ REGIST_INCLUSION("nvshmem.h", FullMatch, NVSHMEM, Replace, false,
HeaderType::HT_SHMEM)
REGIST_INCLUSION("nvshmemx.h", FullMatch, NVSHMEM, Replace, false,
HeaderType::HT_SHMEMX)

REGIST_INCLUSION("cutensor.h", FullMatch, CUTensor, Remove, true)
REGIST_INCLUSION("cutensorMg.h", FullMatch, CUTensor, Remove, true)
10 changes: 10 additions & 0 deletions clang/lib/DPCT/RulesTensor/APINamesCUTensor.inc
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
//===------------------------ APINamesCUTensor.inc ------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

// Helper Functions
ENTRY_UNSUPPORTED("cutensorCreate", Diagnostics::API_NOT_MIGRATED)
77 changes: 77 additions & 0 deletions clang/lib/DPCT/RulesTensor/CUTensorAPIMigration.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
//===---------------------- CUTensorAPIMigration.cpp ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===-----------------------------------------------------------------------===//

#include "CUTensorAPIMigration.h"
#include "RuleInfra/ExprAnalysis.h"

using namespace clang::dpct;
using namespace clang::ast_matchers;

void clang::dpct::CUTensorRule::registerMatcher(ast_matchers::MatchFinder &MF) {
auto CutensorAPIs = [&]() {
return hasAnyName(
// Helper Functions
"cutensorCreate", "cutensorDestroy", "cutensorCreateTensorDescriptor",
"cutensorDestroyTensorDescriptor", "cutensorGetErrorString",
"cutensorGetVersion", "cutensorGetCudartVersion",
// Element-wise Operations
"cutensorCreateElementwiseTrinary", "cutensorElementwiseTrinaryExecute",
"cutensorCreateElementwiseBinary", "cutensorElementwiseBinaryExecute",
"cutensorCreatePermutation", "cutensorPermute",
// Contraction Operations
"cutensorCreateContraction", "cutensorContract",
"cutensorCreateContractionTrinary", "cutensorContractTrinary",
// Reduction Operations
"cutensorCreateReduction", "cutensorReduce",
// Generic Operation Functions
"cutensorDestroyOperationDescriptor",
"cutensorOperationDescriptorGetAttribute",
"cutensorOperationDescriptorSetAttribute",
"cutensorCreatePlanPreference", "cutensorDestroyPlanPreference",
"cutensorPlanPreferenceSetAttribute", "cutensorEstimateWorkspaceSize",
"cutensorCreatePlan", "cutensorDestroyPlan", "cutensorPlanGetAttribute",
// Cache-related Operations
"cutensorHandleResizePlanCache", "cutensorHandleReadPlanCacheFromFile",
"cutensorHandleWritePlanCacheToFile", "cutensorReadKernelCacheFromFile",
"cutensorWriteKernelCacheToFile",
// Logger Functions
"cutensorLoggerSetCallback", "cutensorLoggerSetFile",
"cutensorLoggerOpenFile", "cutensorLoggerSetLevel",
"cutensorLoggerSetMask", "cutensorLoggerForceDisable",
// cuTENSORMg - General Operations
"cutensorMgCreate", "cutensorMgDestroy",
"cutensorMgCreateTensorDescriptor", "cutensorMgDestroyTensorDescriptor",
"cutensorMgCreateCopyDescriptor", "cutensorMgDestroyCopyDescriptor",
"cutensorMgCopyGetWorkspace", "cutensorMgCreateCopyPlan",
"cutensorMgDestroyCopyPlan", "cutensorMgCopy",
// cuTENSORMg - Contraction Operations
"cutensorMgCreateContractionDescriptor",
"cutensorMgDestroyContractionDescriptor",
"cutensorMgCreateContractionFind", "cutensorMgDestroyContractionFind",
"cutensorMgContractionGetWorkspace", "cutensorMgCreateContractionPlan",
"cutensorMgDestroyContractionPlan", "cutensorMgContraction");
};

MF.addMatcher(callExpr(callee(functionDecl(CutensorAPIs()))).bind("call"),
this);
}

void clang::dpct::CUTensorRule::runRule(
const ast_matchers::MatchFinder::MatchResult &Result) {
if (const CallExpr *CE = getNodeAsType<CallExpr>(Result, "call")) {
std::string FuncName = "";
const FunctionDecl *FD = CE->getDirectCallee();
if (FD) {
FuncName = FD->getNameInfo().getName().getAsString();
}

report(CE->getBeginLoc(), Diagnostics::API_NOT_MIGRATED, false, FuncName);
}

return;
}
28 changes: 28 additions & 0 deletions clang/lib/DPCT/RulesTensor/CUTensorAPIMigration.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
//===----------------------- CUTensorAPIMigration.h -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef CUTENSOR_API_MIGRATION_H
#define CUTENSOR_API_MIGRATION_H

#include "ASTTraversal.h"

using namespace clang::ast_matchers;

namespace clang {
namespace dpct {

class CUTensorRule : public NamedMigrationRule<CUTensorRule> {
public:
void registerMatcher(ast_matchers::MatchFinder &MF) override;
void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
};

} // namespace dpct
} // namespace clang

#endif // CUTENSOR_API_MIGRATION_H
37 changes: 37 additions & 0 deletions clang/lib/DPCT/RulesTensor/CallExprRewriterCUTensor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
//===-------------------- CallExprRewriterCUTensor.cpp --------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "RuleInfra/CallExprRewriter.h"
#include "RuleInfra/CallExprRewriterCommon.h"

namespace clang {
namespace dpct {

#define REWRITER_FACTORY_ENTRY(FuncName, RewriterFactory, ...) \
{FuncName, std::make_shared<RewriterFactory>(FuncName, __VA_ARGS__)},
#define UNSUPPORTED_FACTORY_ENTRY(FuncName, MsgID) \
REWRITER_FACTORY_ENTRY(FuncName, \
UnsupportFunctionRewriterFactory<std::string>, MsgID, \
FuncName)
#define ENTRY_UNSUPPORTED(SOURCEAPINAME, MSGID) \
UNSUPPORTED_FACTORY_ENTRY(SOURCEAPINAME, MSGID)

void CallExprRewriterFactoryBase::initRewriterMapCUTensor() {
RewriterMap->merge(
std::unordered_map<std::string,
std::shared_ptr<CallExprRewriterFactoryBase>>({
#include "APINamesCUTensor.inc"
}));
}

#undef ENTRY_UNSUPPORTED
#undef UNSUPPORTED_FACTORY_ENTRY
#undef REWRITER_FACTORY_ENTRY

} // namespace dpct
} // namespace clang
Loading