From 2f231d6f3fd073e05eeeca8f9095d570d8244156 Mon Sep 17 00:00:00 2001 From: "Jiang, Zhiwei" Date: Tue, 6 May 2025 11:03:52 +0800 Subject: [PATCH] [SYCLomatic] Always add "template" before sycl::vec::convert for 4 APIs cuComplexDoubleToFloat cuComplexDoubleToFloat __float22half2_rn __half22float2 Signed-off-by: Jiang, Zhiwei --- clang/lib/DPCT/RuleInfra/CallExprRewriter.h | 38 ++++----- .../DPCT/RuleInfra/CallExprRewriterCommon.h | 34 ++++---- clang/lib/DPCT/RulesLang/APINamesComplex.inc | 6 +- .../RulesLang/Math/CallExprRewriterMath.cpp | 85 ++++--------------- ...HalfPrecisionConversionAndDataMovement.cpp | 4 +- clang/test/dpct/complex.cu | 4 +- clang/test/dpct/math/cuda-math-need-paren.cu | 2 +- clang/test/dpct/math/half/half.cu | 4 +- .../test/dpct/query_api_mapping/NoLib/test.cu | 4 +- 9 files changed, 62 insertions(+), 119 deletions(-) diff --git a/clang/lib/DPCT/RuleInfra/CallExprRewriter.h b/clang/lib/DPCT/RuleInfra/CallExprRewriter.h index 8142ccd3def9..b8d859a0e357 100644 --- a/clang/lib/DPCT/RuleInfra/CallExprRewriter.h +++ b/clang/lib/DPCT/RuleInfra/CallExprRewriter.h @@ -1038,24 +1038,18 @@ template class TypeNamePrinter { } }; -template +template class MemberExprPrinter { BaseT Base; bool IsArrow; MemberT MemberName; - bool IsBaseDependentType = false; public: MemberExprPrinter(const BaseT &Base, bool IsArrow, MemberT MemberName) - : Base(Base), IsArrow(IsArrow), MemberName(MemberName) { - if constexpr (std::is_same_v) { - IsBaseDependentType = Base->getType()->isDependentType(); - } - } + : Base(Base), IsArrow(IsArrow), MemberName(MemberName) {} template void print(StreamT &Stream) const { - printBase(Stream, Base, IsArrow, - HasExplicitTemplateArg && IsBaseDependentType); + printBase(Stream, Base, IsArrow, NeedDisambiguator); dpct::print(Stream, MemberName); } }; @@ -1074,19 +1068,17 @@ template class StaticMemberExprPrinter { } }; -template class MemberCallPrinter : public CallExprPrinter< - MemberExprPrinter, - CallArgsT...> { + MemberExprPrinter, CallArgsT...> { public: MemberCallPrinter(const BaseT &Base, bool IsArrow, MemberT MemberName, CallArgsT &&...Args) - : CallExprPrinter< - MemberExprPrinter, - CallArgsT...>( - MemberExprPrinter( + : CallExprPrinter, + CallArgsT...>( + MemberExprPrinter( std::move(Base), IsArrow, std::move(MemberName)), std::forward(Args)...) {} }; @@ -1451,25 +1443,25 @@ class MemberExprRewriter C, Source, BaseCreator(C), IsArrow, MemberCreator(C)) {} }; -template +template class MemberCallExprRewriter - : public PrinterRewriter> { + : public PrinterRewriter< + MemberCallPrinter> { public: MemberCallExprRewriter( const CallExpr *C, StringRef Source, const std::function &BaseCreator, bool IsArrow, StringRef Member, const std::function &...ArgsCreator) - : PrinterRewriter>( + : PrinterRewriter< + MemberCallPrinter>( C, Source, BaseCreator(C), IsArrow, Member, ArgsCreator(C)...) {} MemberCallExprRewriter( const CallExpr *C, StringRef Source, const BaseT &BaseCreator, bool IsArrow, StringRef Member, const std::function &...ArgsCreator) - : PrinterRewriter>( + : PrinterRewriter< + MemberCallPrinter>( C, Source, BaseCreator, IsArrow, Member, ArgsCreator(C)...) {} }; diff --git a/clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h b/clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h index 32818f06fff7..f7a2c8e545fe 100644 --- a/clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h +++ b/clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h @@ -439,32 +439,32 @@ inline std::function makeDeviceStr() { }; } -template +template using MemberCallPrinterCreator = PrinterCreator< - MemberCallPrinter, + MemberCallPrinter, std::function, bool, std::string, std::function...>; -template -inline std::function +inline std::function(const CallExpr *)> makeMemberCallCreator(std::function BaseFunc, bool IsArrow, std::string Member, std::function... Args) { - return MemberCallPrinterCreator( + return MemberCallPrinterCreator( BaseFunc, IsArrow, Member, Args...); } -template +template inline std::function< - MemberCallPrinter(const CallExpr *)> + MemberCallPrinter(const CallExpr *)> makeMemberCallCreator(std::function BaseFunc, bool IsArrow, std::function Member) { - return PrinterCreator< - MemberCallPrinter, - std::function, bool, - std::function>(BaseFunc, IsArrow, Member); + return PrinterCreator, + std::function, bool, + std::function>( + BaseFunc, IsArrow, Member); } template @@ -1344,7 +1344,7 @@ createTemplatedCallExprRewriterFactory( /// \p BaseCreator use to get base expr from original call expr. /// \p IsArrow the member operator is arrow or dot as default. /// \p ArgsCreator use to get call args from original call expr. -template +template inline std::shared_ptr createMemberCallExprRewriterFactory( const std::string &SourceName, @@ -1352,7 +1352,7 @@ createMemberCallExprRewriterFactory( std::string MemberName, std::function... ArgsCreator) { return std::make_shared, + MemberCallExprRewriter, std::function, bool, std::string, std::function...>>( SourceName, @@ -1361,7 +1361,7 @@ createMemberCallExprRewriterFactory( std::forward>(ArgsCreator)...); } -template +template inline std::shared_ptr, CallExprRewriterFactoryBase>> createMemberCallExprRewriterFactory( @@ -1369,8 +1369,8 @@ createMemberCallExprRewriterFactory( std::string MemberName, std::function... ArgsCreator) { return std::make_shared, BaseT, - bool, std::string, std::function...>>( + MemberCallExprRewriter, BaseT, bool, + std::string, std::function...>>( SourceName, BaseCreator, IsArrow, MemberName, std::forward>(ArgsCreator)...); } @@ -2252,7 +2252,7 @@ const std::string MipmapNeedBindlessImage = #define MEMBER_CALL_FACTORY_ENTRY(FuncName, ...) \ std::make_pair(FuncName, createMemberCallExprRewriterFactory( \ FuncName, __VA_ARGS__)), -#define MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY(FuncName, ...) \ +#define MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY(FuncName, ...) \ std::make_pair(FuncName, createMemberCallExprRewriterFactory( \ FuncName, __VA_ARGS__)), #define ARRAYSUBSCRIPT_EXPR_FACTORY_ENTRY(FuncName, ...) \ diff --git a/clang/lib/DPCT/RulesLang/APINamesComplex.inc b/clang/lib/DPCT/RulesLang/APINamesComplex.inc index 7152711b6146..df34b915848b 100644 --- a/clang/lib/DPCT/RulesLang/APINamesComplex.inc +++ b/clang/lib/DPCT/RulesLang/APINamesComplex.inc @@ -95,5 +95,7 @@ BINARY_OP_FACTORY_ENTRY("cuCfmaf", BinaryOperatorKind::BO_Add, makeCallArgCreatorWithCall(1)), makeCallArgCreatorWithCall(2)) -MEMBER_CALL_FACTORY_ENTRY("cuComplexDoubleToFloat", ARG(0), false, "convert") -MEMBER_CALL_FACTORY_ENTRY("cuComplexFloatToDouble", ARG(0), false, "convert") +MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY("cuComplexDoubleToFloat", ARG(0), + false, "convert") +MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY("cuComplexFloatToDouble", ARG(0), + false, "convert") diff --git a/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp b/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp index 94e58fc5c6df..6e33118d45c2 100644 --- a/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp +++ b/clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp @@ -288,75 +288,24 @@ std::optional MathTypeCastRewriter::rewrite() { const StringRef &FuncName = SourceCalleeName; std::string ReplStr; llvm::raw_string_ostream OS(ReplStr); - auto MigratedArg0 = getMigratedArgWithExtraParens(0); - if (FuncName == "__float22half2_rn") { - OS << MigratedArg0 - << ".convert<" + MapNames::getClNamespace() + "half, " + - MapNames::getClNamespace() + "rounding_mode::rte>()"; - } else if (FuncName == "__float2half2_rn") { - OS << MapNames::getClNamespace() + "float2{" << MigratedArg0 << "," - << MigratedArg0 - << "}.convert<" + MapNames::getClNamespace() + "half, " + - MapNames::getClNamespace() + "rounding_mode::rte>()"; - } else if (FuncName == "__floats2half2_rn") { - auto MigratedArg1 = getMigratedArg(1); - OS << MapNames::getClNamespace() + "float2{" << MigratedArg0 << "," - << MigratedArg1 - << "}.convert<" + MapNames::getClNamespace() + "half, " + - MapNames::getClNamespace() + "rounding_mode::rte>()"; - } else if (FuncName == "__half22float2") { - OS << MigratedArg0 - << ".convert()"; - } else if (FuncName == "__half2half2") { - OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "," - << MigratedArg0 << "}"; - } else if (FuncName == "__halves2half2") { - auto MigratedArg1 = getMigratedArg(1); - OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "," - << MigratedArg1 << "}"; - } else if (FuncName == "__high2half") { - OS << MigratedArg0 << "[0]"; - } else if (FuncName == "__high2half2") { - OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[0], " - << MigratedArg0 << "[0]}"; - } else if (FuncName == "__highs2half2") { - auto MigratedArg1 = getMigratedArgWithExtraParens(1); - OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[0], " - << MigratedArg1 << "[0]}"; - } else if (FuncName == "__low2half") { - OS << MigratedArg0 << "[1]"; - } else if (FuncName == "__low2half2") { - OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], " - << MigratedArg0 << "[1]}"; - } else if (FuncName == "__lowhigh2highlow") { - OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], " - << MigratedArg0 << "[0]}"; - } else if (FuncName == "__lows2half2") { - auto MigratedArg1 = getMigratedArgWithExtraParens(1); - OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], " - << MigratedArg1 << "[1]}"; - } else { - //__half2short_rd and __half2float - static SSMap TypeMap{{"ll", "long long"}, - {"ull", "unsigned long long"}, - {"ushort", "unsigned short"}, - {"uint", "unsigned int"}, - {"half", MapNames::getClNamespace() + "half"}}; - std::string RoundingMode; - if (FuncName[FuncName.size() - 3] == '_') - RoundingMode = FuncName.substr(FuncName.size() - 2).str(); - auto FN = FuncName.substr(2, FuncName.find('_', 2) - 2).str(); - auto Types = split(FN, '2'); - assert(Types.size() == 2); - MapNames::replaceName(TypeMap, Types[0]); - MapNames::replaceName(TypeMap, Types[1]); - OS << MapNames::getClNamespace() + "vec<" << Types[0] << ", 1>{" - << MigratedArg0 << "}.convert<" << Types[1] - << ", " + MapNames::getClNamespace() + "rounding_mode::" - << RoundingModeMap[RoundingMode] << ">()[0]"; - } + static SSMap TypeMap{{"ll", "long long"}, + {"ull", "unsigned long long"}, + {"ushort", "unsigned short"}, + {"uint", "unsigned int"}, + {"half", MapNames::getClNamespace() + "half"}}; + std::string RoundingMode; + if (FuncName[FuncName.size() - 3] == '_') + RoundingMode = FuncName.substr(FuncName.size() - 2).str(); + auto FN = FuncName.substr(2, FuncName.find('_', 2) - 2).str(); + auto Types = split(FN, '2'); + assert(Types.size() == 2); + MapNames::replaceName(TypeMap, Types[0]); + MapNames::replaceName(TypeMap, Types[1]); + OS << MapNames::getClNamespace() + "vec<" << Types[0] << ", 1>{" + << MigratedArg0 << "}.convert<" << Types[1] + << ", " + MapNames::getClNamespace() + "rounding_mode::" + << RoundingModeMap[RoundingMode] << ">()[0]"; OS.flush(); return ReplStr; } diff --git a/clang/lib/DPCT/RulesLang/Math/RewriterHalfPrecisionConversionAndDataMovement.cpp b/clang/lib/DPCT/RulesLang/Math/RewriterHalfPrecisionConversionAndDataMovement.cpp index 8d6607742957..b153b0050dbc 100644 --- a/clang/lib/DPCT/RulesLang/Math/RewriterHalfPrecisionConversionAndDataMovement.cpp +++ b/clang/lib/DPCT/RulesLang/Math/RewriterHalfPrecisionConversionAndDataMovement.cpp @@ -30,7 +30,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() { CALL(MapNames::getClNamespace() + "ext::intel::math::float2half_rn", MEMBER_CALL(ARG(0), false, "y"))))), - MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY( + MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY( "__float22half2_rn", ARG(0), false, "convert<" + MapNames::getClNamespace() + "half, " + MapNames::getClNamespace() + "rounding_mode::rte>")) @@ -168,7 +168,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() { CALL(MapNames::getClNamespace() + "ext::intel::math::half2float", MEMBER_CALL(ARG(0), false, "y"))))), - MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY( + MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY( "__half22float2", ARG(0), false, "convert")) diff --git a/clang/test/dpct/complex.cu b/clang/test/dpct/complex.cu index 2d4542e4020b..4c62aebe29e2 100644 --- a/clang/test/dpct/complex.cu +++ b/clang/test/dpct/complex.cu @@ -290,9 +290,9 @@ int main() { auto a24 = COMPLEX_F_FMA(f1, f2, f3); r = r && check(a24, expect, index); - // CHECK: f1 = d1.convert(); + // CHECK: f1 = d1.template convert(); f1 = cuComplexDoubleToFloat(d1); - // CHECK: d1 = f1.convert(); + // CHECK: d1 = f1.template convert(); d1 = cuComplexFloatToDouble(f1); int *result = nullptr; diff --git a/clang/test/dpct/math/cuda-math-need-paren.cu b/clang/test/dpct/math/cuda-math-need-paren.cu index f3342fa2bb86..bcbb8a2becb4 100644 --- a/clang/test/dpct/math/cuda-math-need-paren.cu +++ b/clang/test/dpct/math/cuda-math-need-paren.cu @@ -8,7 +8,7 @@ using namespace std; void __global__ kernel() { half2 h2; - // CHECK: (h2 + h2).convert(); + // CHECK: (h2 + h2).template convert(); __half22float2(__hadd2(h2, h2)); } diff --git a/clang/test/dpct/math/half/half.cu b/clang/test/dpct/math/half/half.cu index ab2d3f989543..6b301a5ac899 100644 --- a/clang/test/dpct/math/half/half.cu +++ b/clang/test/dpct/math/half/half.cu @@ -15,7 +15,7 @@ __global__ void kernelFuncHalfConversion() { unsigned u; unsigned long long ull; unsigned short us; - // CHECK: h2 = f2.convert(); + // CHECK: h2 = f2.template convert(); h2 = __float22half2_rn(f2); // CHECK: h = sycl::vec(f).convert()[0]; h = __float2half(f); @@ -31,7 +31,7 @@ __global__ void kernelFuncHalfConversion() { h = __float2half_rz(f); // CHECK: h2 = sycl::float2(f, f).convert(); h2 = __floats2half2_rn(f, f); - // CHECK: f2 = h2.convert(); + // CHECK: f2 = h2.template convert(); f2 = __half22float2(h2); // CHECK: f = sycl::vec(h).convert()[0]; f = __half2float(h); diff --git a/clang/test/dpct/query_api_mapping/NoLib/test.cu b/clang/test/dpct/query_api_mapping/NoLib/test.cu index a6171d9fb5ef..4df53a6da16b 100644 --- a/clang/test/dpct/query_api_mapping/NoLib/test.cu +++ b/clang/test/dpct/query_api_mapping/NoLib/test.cu @@ -89,13 +89,13 @@ // CUCOMPLEXDOUBLETOFLOAT: CUDA API: // CUCOMPLEXDOUBLETOFLOAT-NEXT: cuComplexDoubleToFloat(c /*cuDoubleComplex*/); // CUCOMPLEXDOUBLETOFLOAT-NEXT: Is migrated to: -// CUCOMPLEXDOUBLETOFLOAT-NEXT: c.convert(); +// CUCOMPLEXDOUBLETOFLOAT-NEXT: c.template convert(); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cuComplexFloatToDouble | FileCheck %s -check-prefix=CUCOMPLEXFLOATTODOUBLE // CUCOMPLEXFLOATTODOUBLE: CUDA API: // CUCOMPLEXFLOATTODOUBLE-NEXT: cuComplexFloatToDouble(c /*cuFloatComplex*/); // CUCOMPLEXFLOATTODOUBLE-NEXT: Is migrated to: -// CUCOMPLEXFLOATTODOUBLE-NEXT: c.convert(); +// CUCOMPLEXFLOATTODOUBLE-NEXT: c.template convert(); // RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cuConj | FileCheck %s -check-prefix=CUCONJ // CUCONJ: CUDA API: