diff --git a/clang/lib/DPCT/RulesLang/RulesLang.cpp b/clang/lib/DPCT/RulesLang/RulesLang.cpp
index d40666018a31..7c676b00b6d7 100644
--- a/clang/lib/DPCT/RulesLang/RulesLang.cpp
+++ b/clang/lib/DPCT/RulesLang/RulesLang.cpp
@@ -7761,6 +7761,40 @@ void SyncThreadsMigrationRule::registerMatcher(MatchFinder &MF) {
       this);
 }
 
+/// Check whether the call \p CE has no counterpart in any instantiated
+/// specialization of \p FTD.
+///
+/// Each specialization of \p FTD that has a body is searched for calls to a
+/// function named \p FuncName; a match whose spelling location equals that of
+/// \p CE is treated as the corresponding call.
+///
+/// \returns false as soon as such a match is found in any specialization, true
+/// otherwise (including when no specialization has a body).
+bool SyncThreadsMigrationRule::noCorrespondingCEInInstantiatedTemplates(
+    const FunctionTemplateDecl *FTD, const CallExpr *CE,
+    const std::string &FuncName) {
+  const auto &SM = DpctGlobalInfo::getSourceManager();
+  auto CEMatcher = ast_matchers::findAll(
+      ast_matchers::callExpr(callee(functionDecl(hasName(FuncName))))
+          .bind("call"));
+  SourceLocation CELocation = SM.getSpellingLoc(CE->getBeginLoc());
+  for (const auto &Spec : FTD->specializations()) {
+    if (!Spec->hasBody())
+      continue;
+    auto MatchedResults = ast_matchers::match(CEMatcher, *(Spec->getBody()),
+                                              DpctGlobalInfo::getContext());
+    for (const auto &Node : MatchedResults) {
+      if (const auto *MatchedCE = Node.getNodeAs<CallExpr>("call")) {
+        SourceLocation MatchedCELocation =
+            SM.getSpellingLoc(MatchedCE->getBeginLoc());
+        if (CELocation == MatchedCELocation)
+          return false;
+      }
+    }
+  }
+  return true;
+}
+
 void SyncThreadsMigrationRule::runRule(const MatchFinder::MatchResult &Result) {
   static std::map LocationResultMapForTemplate;
   auto emplaceReplacement = [&](BarrierFenceSpaceAnalyzerResult Res,
@@ -7802,7 +7836,8 @@ void SyncThreadsMigrationRule::runRule(const MatchFinder::MatchResult &Result) {
     BarrierFenceSpaceAnalyzer A;
     const FunctionTemplateDecl *FTD = FD->getDescribedFunctionTemplate();
     if (FTD) {
-      if (FTD->specializations().empty()) {
+      if (FTD->specializations().empty() ||
+          noCorrespondingCEInInstantiatedTemplates(FTD, CE, FuncName)) {
         emplaceReplacement(A.analyze(CE), CE);
       }
     } else {
diff --git a/clang/lib/DPCT/RulesLang/RulesLang.h b/clang/lib/DPCT/RulesLang/RulesLang.h
index a9e83884103d..b3493eadabd2 100644
--- a/clang/lib/DPCT/RulesLang/RulesLang.h
+++ b/clang/lib/DPCT/RulesLang/RulesLang.h
@@ -21,6 +21,7 @@
 #include "clang/Frontend/CompilerInstance.h"
 
 #include
+#include
 #include
 
 namespace clang {
@@ -809,6 +810,11 @@ class SyncThreadsMigrationRule
 public:
   void registerMatcher(ast_matchers::MatchFinder &MF) override;
   void runRule(const ast_matchers::MatchFinder::MatchResult &Result);
+  /// Returns true if no specialization of \p FTD with a body contains a call
+  /// to \p FuncName at the same spelling location as \p CE.
+  bool noCorrespondingCEInInstantiatedTemplates(const FunctionTemplateDecl *FTD,
+                                                const CallExpr *CE,
+                                                const std::string &FuncName);
 };
 
 /// Migrate Function Attributes to Sycl kernel info, defined in
diff --git a/clang/test/dpct/syncthreads.cu b/clang/test/dpct/syncthreads.cu
index d4d488190fe0..c6e2ff576f09 100644
--- a/clang/test/dpct/syncthreads.cu
+++ b/clang/test/dpct/syncthreads.cu
@@ -462,3 +462,20 @@ __global__ void test21(float *ptr1, float *ptr2, int step1, int step2) {
     idx2 += step2;
   }
 }
+
+template __device__ void test_22_d() {
+  // CHECK: if constexpr (B) {
+  // CHECK-NEXT: /*
+  // CHECK-NEXT: DPCT1065:{{[0-9]+}}: Consider replacing sycl::nd_item::barrier() with sycl::nd_item::barrier(sycl::access::fence_space::local_space) for better performance if there is no access to global memory.
+  // CHECK-NEXT: */
+  // CHECK-NEXT: sycl::ext::oneapi::this_work_item::get_nd_item<3>().barrier();
+  // CHECK-NEXT: }
+  if constexpr (B) {
+    __syncthreads();
+  }
+}
+
+__global__ void test_22() {
+  test_22_d();
+  test_22_d();
+}