Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 15 additions & 23 deletions clang/lib/DPCT/RuleInfra/CallExprRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -1038,24 +1038,18 @@ template <class NameT> class TypeNamePrinter {
}
};

template <class BaseT, class MemberT, bool HasExplicitTemplateArg>
template <class BaseT, class MemberT, bool NeedDisambiguator>
class MemberExprPrinter {
BaseT Base;
bool IsArrow;
MemberT MemberName;
bool IsBaseDependentType = false;

public:
MemberExprPrinter(const BaseT &Base, bool IsArrow, MemberT MemberName)
: Base(Base), IsArrow(IsArrow), MemberName(MemberName) {
if constexpr (std::is_same_v<BaseT, const Expr *>) {
IsBaseDependentType = Base->getType()->isDependentType();
}
}
: Base(Base), IsArrow(IsArrow), MemberName(MemberName) {}

template <class StreamT> void print(StreamT &Stream) const {
printBase(Stream, Base, IsArrow,
HasExplicitTemplateArg && IsBaseDependentType);
printBase(Stream, Base, IsArrow, NeedDisambiguator);
dpct::print(Stream, MemberName);
}
};
Expand All @@ -1074,19 +1068,17 @@ template <class BaseT, class MemberT> class StaticMemberExprPrinter {
}
};

template <class BaseT, class MemberT, bool HasExplicitTemplateArg,
template <class BaseT, class MemberT, bool NeedDisambiguator,
class... CallArgsT>
class MemberCallPrinter
: public CallExprPrinter<
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
CallArgsT...> {
MemberExprPrinter<BaseT, MemberT, NeedDisambiguator>, CallArgsT...> {
public:
MemberCallPrinter(const BaseT &Base, bool IsArrow, MemberT MemberName,
CallArgsT &&...Args)
: CallExprPrinter<
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
CallArgsT...>(
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>(
: CallExprPrinter<MemberExprPrinter<BaseT, MemberT, NeedDisambiguator>,
CallArgsT...>(
MemberExprPrinter<BaseT, MemberT, NeedDisambiguator>(
std::move(Base), IsArrow, std::move(MemberName)),
std::forward<CallArgsT>(Args)...) {}
};
Expand Down Expand Up @@ -1451,25 +1443,25 @@ class MemberExprRewriter
C, Source, BaseCreator(C), IsArrow, MemberCreator(C)) {}
};

template <class BaseT, bool HasExplicitTemplateArg, class... ArgsT>
template <class BaseT, bool NeedDisambiguator, class... ArgsT>
class MemberCallExprRewriter
: public PrinterRewriter<MemberCallPrinter<
BaseT, StringRef, HasExplicitTemplateArg, ArgsT...>> {
: public PrinterRewriter<
MemberCallPrinter<BaseT, StringRef, NeedDisambiguator, ArgsT...>> {
public:
MemberCallExprRewriter(
const CallExpr *C, StringRef Source,
const std::function<BaseT(const CallExpr *)> &BaseCreator, bool IsArrow,
StringRef Member,
const std::function<ArgsT(const CallExpr *)> &...ArgsCreator)
: PrinterRewriter<MemberCallPrinter<BaseT, StringRef,
HasExplicitTemplateArg, ArgsT...>>(
: PrinterRewriter<
MemberCallPrinter<BaseT, StringRef, NeedDisambiguator, ArgsT...>>(
C, Source, BaseCreator(C), IsArrow, Member, ArgsCreator(C)...) {}
MemberCallExprRewriter(
const CallExpr *C, StringRef Source, const BaseT &BaseCreator,
bool IsArrow, StringRef Member,
const std::function<ArgsT(const CallExpr *)> &...ArgsCreator)
: PrinterRewriter<MemberCallPrinter<BaseT, StringRef,
HasExplicitTemplateArg, ArgsT...>>(
: PrinterRewriter<
MemberCallPrinter<BaseT, StringRef, NeedDisambiguator, ArgsT...>>(
C, Source, BaseCreator, IsArrow, Member, ArgsCreator(C)...) {}
};

Expand Down
34 changes: 17 additions & 17 deletions clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,32 +439,32 @@ inline std::function<std::string(const CallExpr *)> makeDeviceStr() {
};
}

template <class BaseT, bool HasExplicitTemplateArg, class... CallArgsT>
template <class BaseT, bool NeedDisambiguator, class... CallArgsT>
using MemberCallPrinterCreator = PrinterCreator<
MemberCallPrinter<BaseT, StringRef, HasExplicitTemplateArg, CallArgsT...>,
MemberCallPrinter<BaseT, StringRef, NeedDisambiguator, CallArgsT...>,
std::function<BaseT(const CallExpr *)>, bool, std::string,
std::function<CallArgsT(const CallExpr *)>...>;

template <bool HasExplicitTemplateArg, class BaseT, class... CallArgsT>
inline std::function<MemberCallPrinter<BaseT, StringRef, HasExplicitTemplateArg,
template <bool NeedDisambiguator, class BaseT, class... CallArgsT>
inline std::function<MemberCallPrinter<BaseT, StringRef, NeedDisambiguator,
CallArgsT...>(const CallExpr *)>
makeMemberCallCreator(std::function<BaseT(const CallExpr *)> BaseFunc,
bool IsArrow, std::string Member,
std::function<CallArgsT(const CallExpr *)>... Args) {
return MemberCallPrinterCreator<BaseT, HasExplicitTemplateArg, CallArgsT...>(
return MemberCallPrinterCreator<BaseT, NeedDisambiguator, CallArgsT...>(
BaseFunc, IsArrow, Member, Args...);
}

template <bool HasExplicitTemplateArg, class BaseT, class MemberT>
template <bool NeedDisambiguator, class BaseT, class MemberT>
inline std::function<
MemberCallPrinter<BaseT, MemberT, HasExplicitTemplateArg>(const CallExpr *)>
MemberCallPrinter<BaseT, MemberT, NeedDisambiguator>(const CallExpr *)>
makeMemberCallCreator(std::function<BaseT(const CallExpr *)> BaseFunc,
bool IsArrow,
std::function<MemberT(const CallExpr *)> Member) {
return PrinterCreator<
MemberCallPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
std::function<BaseT(const CallExpr *)>, bool,
std::function<MemberT(const CallExpr *)>>(BaseFunc, IsArrow, Member);
return PrinterCreator<MemberCallPrinter<BaseT, MemberT, NeedDisambiguator>,
std::function<BaseT(const CallExpr *)>, bool,
std::function<MemberT(const CallExpr *)>>(
BaseFunc, IsArrow, Member);
}

template <class... StmtT>
Expand Down Expand Up @@ -1344,15 +1344,15 @@ createTemplatedCallExprRewriterFactory(
/// \p BaseCreator use to get base expr from original call expr.
/// \p IsArrow the member operator is arrow or dot as default.
/// \p ArgsCreator use to get call args from original call expr.
template <bool HasExplicitTemplateArg, class BaseT, class... ArgsT>
template <bool NeedDisambiguator, class BaseT, class... ArgsT>
inline std::shared_ptr<CallExprRewriterFactoryBase>
createMemberCallExprRewriterFactory(
const std::string &SourceName,
std::function<BaseT(const CallExpr *)> BaseCreator, bool IsArrow,
std::string MemberName,
std::function<ArgsT(const CallExpr *)>... ArgsCreator) {
return std::make_shared<CallExprRewriterFactory<
MemberCallExprRewriter<BaseT, HasExplicitTemplateArg, ArgsT...>,
MemberCallExprRewriter<BaseT, NeedDisambiguator, ArgsT...>,
std::function<BaseT(const CallExpr *)>, bool, std::string,
std::function<ArgsT(const CallExpr *)>...>>(
SourceName,
Expand All @@ -1361,16 +1361,16 @@ createMemberCallExprRewriterFactory(
std::forward<std::function<ArgsT(const CallExpr *)>>(ArgsCreator)...);
}

template <bool HasExplicitTemplateArg, class BaseT, class... ArgsT>
template <bool NeedDisambiguator, class BaseT, class... ArgsT>
inline std::shared_ptr<std::enable_if_t<
!std::is_invocable_v<BaseT, const CallExpr *>, CallExprRewriterFactoryBase>>
createMemberCallExprRewriterFactory(
const std::string &SourceName, BaseT BaseCreator, bool IsArrow,
std::string MemberName,
std::function<ArgsT(const CallExpr *)>... ArgsCreator) {
return std::make_shared<CallExprRewriterFactory<
MemberCallExprRewriter<BaseT, HasExplicitTemplateArg, ArgsT...>, BaseT,
bool, std::string, std::function<ArgsT(const CallExpr *)>...>>(
MemberCallExprRewriter<BaseT, NeedDisambiguator, ArgsT...>, BaseT, bool,
std::string, std::function<ArgsT(const CallExpr *)>...>>(
SourceName, BaseCreator, IsArrow, MemberName,
std::forward<std::function<ArgsT(const CallExpr *)>>(ArgsCreator)...);
}
Expand Down Expand Up @@ -2252,7 +2252,7 @@ const std::string MipmapNeedBindlessImage =
#define MEMBER_CALL_FACTORY_ENTRY(FuncName, ...) \
std::make_pair(FuncName, createMemberCallExprRewriterFactory<false>( \
FuncName, __VA_ARGS__)),
#define MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY(FuncName, ...) \
#define MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY(FuncName, ...) \
std::make_pair(FuncName, createMemberCallExprRewriterFactory<true>( \
FuncName, __VA_ARGS__)),
#define ARRAYSUBSCRIPT_EXPR_FACTORY_ENTRY(FuncName, ...) \
Expand Down
6 changes: 4 additions & 2 deletions clang/lib/DPCT/RulesLang/APINamesComplex.inc
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,7 @@ BINARY_OP_FACTORY_ENTRY("cuCfmaf", BinaryOperatorKind::BO_Add,
makeCallArgCreatorWithCall(1)),
makeCallArgCreatorWithCall(2))

MEMBER_CALL_FACTORY_ENTRY("cuComplexDoubleToFloat", ARG(0), false, "convert<float>")
MEMBER_CALL_FACTORY_ENTRY("cuComplexFloatToDouble", ARG(0), false, "convert<double>")
MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY("cuComplexDoubleToFloat", ARG(0),
false, "convert<float>")
MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY("cuComplexFloatToDouble", ARG(0),
false, "convert<double>")
85 changes: 17 additions & 68 deletions clang/lib/DPCT/RulesLang/Math/CallExprRewriterMath.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,75 +288,24 @@ std::optional<std::string> MathTypeCastRewriter::rewrite() {
const StringRef &FuncName = SourceCalleeName;
std::string ReplStr;
llvm::raw_string_ostream OS(ReplStr);

auto MigratedArg0 = getMigratedArgWithExtraParens(0);
if (FuncName == "__float22half2_rn") {
OS << MigratedArg0
<< ".convert<" + MapNames::getClNamespace() + "half, " +
MapNames::getClNamespace() + "rounding_mode::rte>()";
} else if (FuncName == "__float2half2_rn") {
OS << MapNames::getClNamespace() + "float2{" << MigratedArg0 << ","
<< MigratedArg0
<< "}.convert<" + MapNames::getClNamespace() + "half, " +
MapNames::getClNamespace() + "rounding_mode::rte>()";
} else if (FuncName == "__floats2half2_rn") {
auto MigratedArg1 = getMigratedArg(1);
OS << MapNames::getClNamespace() + "float2{" << MigratedArg0 << ","
<< MigratedArg1
<< "}.convert<" + MapNames::getClNamespace() + "half, " +
MapNames::getClNamespace() + "rounding_mode::rte>()";
} else if (FuncName == "__half22float2") {
OS << MigratedArg0
<< ".convert<float, " + MapNames::getClNamespace() +
"rounding_mode::automatic>()";
} else if (FuncName == "__half2half2") {
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << ","
<< MigratedArg0 << "}";
} else if (FuncName == "__halves2half2") {
auto MigratedArg1 = getMigratedArg(1);
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << ","
<< MigratedArg1 << "}";
} else if (FuncName == "__high2half") {
OS << MigratedArg0 << "[0]";
} else if (FuncName == "__high2half2") {
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[0], "
<< MigratedArg0 << "[0]}";
} else if (FuncName == "__highs2half2") {
auto MigratedArg1 = getMigratedArgWithExtraParens(1);
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[0], "
<< MigratedArg1 << "[0]}";
} else if (FuncName == "__low2half") {
OS << MigratedArg0 << "[1]";
} else if (FuncName == "__low2half2") {
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], "
<< MigratedArg0 << "[1]}";
} else if (FuncName == "__lowhigh2highlow") {
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], "
<< MigratedArg0 << "[0]}";
} else if (FuncName == "__lows2half2") {
auto MigratedArg1 = getMigratedArgWithExtraParens(1);
OS << MapNames::getClNamespace() + "half2{" << MigratedArg0 << "[1], "
<< MigratedArg1 << "[1]}";
} else {
//__half2short_rd and __half2float
static SSMap TypeMap{{"ll", "long long"},
{"ull", "unsigned long long"},
{"ushort", "unsigned short"},
{"uint", "unsigned int"},
{"half", MapNames::getClNamespace() + "half"}};
std::string RoundingMode;
if (FuncName[FuncName.size() - 3] == '_')
RoundingMode = FuncName.substr(FuncName.size() - 2).str();
auto FN = FuncName.substr(2, FuncName.find('_', 2) - 2).str();
auto Types = split(FN, '2');
assert(Types.size() == 2);
MapNames::replaceName(TypeMap, Types[0]);
MapNames::replaceName(TypeMap, Types[1]);
OS << MapNames::getClNamespace() + "vec<" << Types[0] << ", 1>{"
<< MigratedArg0 << "}.convert<" << Types[1]
<< ", " + MapNames::getClNamespace() + "rounding_mode::"
<< RoundingModeMap[RoundingMode] << ">()[0]";
}
static SSMap TypeMap{{"ll", "long long"},
{"ull", "unsigned long long"},
{"ushort", "unsigned short"},
{"uint", "unsigned int"},
{"half", MapNames::getClNamespace() + "half"}};
std::string RoundingMode;
if (FuncName[FuncName.size() - 3] == '_')
RoundingMode = FuncName.substr(FuncName.size() - 2).str();
auto FN = FuncName.substr(2, FuncName.find('_', 2) - 2).str();
auto Types = split(FN, '2');
assert(Types.size() == 2);
MapNames::replaceName(TypeMap, Types[0]);
MapNames::replaceName(TypeMap, Types[1]);
OS << MapNames::getClNamespace() + "vec<" << Types[0] << ", 1>{"
<< MigratedArg0 << "}.convert<" << Types[1]
<< ", " + MapNames::getClNamespace() + "rounding_mode::"
<< RoundingModeMap[RoundingMode] << ">()[0]";
OS.flush();
return ReplStr;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
CALL(MapNames::getClNamespace() +
"ext::intel::math::float2half_rn",
MEMBER_CALL(ARG(0), false, "y"))))),
MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY(
MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY(
"__float22half2_rn", ARG(0), false,
"convert<" + MapNames::getClNamespace() + "half, " +
MapNames::getClNamespace() + "rounding_mode::rte>"))
Expand Down Expand Up @@ -168,7 +168,7 @@ RewriterMap dpct::createHalfPrecisionConversionAndDataMovementRewriterMap() {
CALL(MapNames::getClNamespace() +
"ext::intel::math::half2float",
MEMBER_CALL(ARG(0), false, "y"))))),
MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY(
MEMBER_CALL_WITH_DISAMBIGUATOR_FACTORY_ENTRY(
"__half22float2", ARG(0), false,
"convert<float, " + MapNames::getClNamespace() +
"rounding_mode::automatic>"))
Expand Down
4 changes: 2 additions & 2 deletions clang/test/dpct/complex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -290,9 +290,9 @@ int main() {
auto a24 = COMPLEX_F_FMA(f1, f2, f3);
r = r && check(a24, expect, index);

// CHECK: f1 = d1.convert<float>();
// CHECK: f1 = d1.template convert<float>();
f1 = cuComplexDoubleToFloat(d1);
// CHECK: d1 = f1.convert<double>();
// CHECK: d1 = f1.template convert<double>();
d1 = cuComplexFloatToDouble(f1);

int *result = nullptr;
Expand Down
2 changes: 1 addition & 1 deletion clang/test/dpct/math/cuda-math-need-paren.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ using namespace std;

void __global__ kernel() {
half2 h2;
// CHECK: (h2 + h2).convert<float, sycl::rounding_mode::automatic>();
// CHECK: (h2 + h2).template convert<float, sycl::rounding_mode::automatic>();
__half22float2(__hadd2(h2, h2));
}

Expand Down
4 changes: 2 additions & 2 deletions clang/test/dpct/math/half/half.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ __global__ void kernelFuncHalfConversion() {
unsigned u;
unsigned long long ull;
unsigned short us;
// CHECK: h2 = f2.convert<sycl::half, sycl::rounding_mode::rte>();
// CHECK: h2 = f2.template convert<sycl::half, sycl::rounding_mode::rte>();
h2 = __float22half2_rn(f2);
// CHECK: h = sycl::vec<float, 1>(f).convert<sycl::half, sycl::rounding_mode::automatic>()[0];
h = __float2half(f);
Expand All @@ -31,7 +31,7 @@ __global__ void kernelFuncHalfConversion() {
h = __float2half_rz(f);
// CHECK: h2 = sycl::float2(f, f).convert<sycl::half, sycl::rounding_mode::rte>();
h2 = __floats2half2_rn(f, f);
// CHECK: f2 = h2.convert<float, sycl::rounding_mode::automatic>();
// CHECK: f2 = h2.template convert<float, sycl::rounding_mode::automatic>();
f2 = __half22float2(h2);
// CHECK: f = sycl::vec<sycl::half, 1>(h).convert<float, sycl::rounding_mode::automatic>()[0];
f = __half2float(h);
Expand Down
4 changes: 2 additions & 2 deletions clang/test/dpct/query_api_mapping/NoLib/test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,13 @@
// CUCOMPLEXDOUBLETOFLOAT: CUDA API:
// CUCOMPLEXDOUBLETOFLOAT-NEXT: cuComplexDoubleToFloat(c /*cuDoubleComplex*/);
// CUCOMPLEXDOUBLETOFLOAT-NEXT: Is migrated to:
// CUCOMPLEXDOUBLETOFLOAT-NEXT: c.convert<float>();
// CUCOMPLEXDOUBLETOFLOAT-NEXT: c.template convert<float>();

// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cuComplexFloatToDouble | FileCheck %s -check-prefix=CUCOMPLEXFLOATTODOUBLE
// CUCOMPLEXFLOATTODOUBLE: CUDA API:
// CUCOMPLEXFLOATTODOUBLE-NEXT: cuComplexFloatToDouble(c /*cuFloatComplex*/);
// CUCOMPLEXFLOATTODOUBLE-NEXT: Is migrated to:
// CUCOMPLEXFLOATTODOUBLE-NEXT: c.convert<double>();
// CUCOMPLEXFLOATTODOUBLE-NEXT: c.template convert<double>();

// RUN: dpct --cuda-include-path="%cuda-path/include" --query-api-mapping=cuConj | FileCheck %s -check-prefix=CUCONJ
// CUCONJ: CUDA API:
Expand Down