Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 17 additions & 42 deletions clang/lib/DPCT/RuleInfra/CallExprRewriter.h
Original file line number Diff line number Diff line change
Expand Up @@ -702,10 +702,6 @@ template <class StreamT> void printMemberOp(StreamT &Stream, bool IsArrow) {
Stream << ".";
}

template <class StreamT> void printDisambiguator(StreamT &Stream) {
Stream << "template ";
}

template <class StreamT>
void printCapture(StreamT &Stream, bool IsCaptureRef) {
if (IsCaptureRef)
Expand Down Expand Up @@ -941,45 +937,34 @@ class ArgsPrinter<HasPrefixArg, FirstArgT, RestArgsT...>

template <class StreamT>
void printBase(StreamT &Stream, std::pair<const CallExpr *, const Expr *> P,
bool IsArrow, bool NeedDisambiguator) {
bool IsArrow) {
{
std::unique_ptr<ParensPrinter<StreamT>> Paren;
if (needExtraParensInMemberExpr(P.second))
Paren = std::make_unique<ParensPrinter<StreamT>>(Stream);
print(Stream, P);
}
printMemberOp(Stream, IsArrow);
if (NeedDisambiguator)
printDisambiguator(Stream);
}

template <class StreamT>
void printBase(StreamT &Stream, const Expr *E, bool IsArrow,
bool NeedDisambiguator) {
void printBase(StreamT &Stream, const Expr *E, bool IsArrow) {
{
std::unique_ptr<ParensPrinter<StreamT>> Paren;
if (needExtraParensInMemberExpr(E))
Paren = std::make_unique<ParensPrinter<StreamT>>(Stream);
print(Stream, E);
}
printMemberOp(Stream, IsArrow);
if (NeedDisambiguator)
printDisambiguator(Stream);
}
template <class StreamT>
void printBase(StreamT &Stream, const DerefExpr &D, bool,
bool NeedDisambiguator) {
void printBase(StreamT &Stream, const DerefExpr &D, bool) {
D.printMemberBase(Stream);
if (NeedDisambiguator)
printDisambiguator(Stream);
}
template <class StreamT, class T>
void printBase(StreamT &Stream, const T &Val, bool IsArrow,
bool NeedDisambiguator) {
void printBase(StreamT &Stream, const T &Val, bool IsArrow) {
print(Stream, Val);
printMemberOp(Stream, IsArrow);
if (NeedDisambiguator)
printDisambiguator(Stream);
}

template <class CalleeT, class... CallArgsT> class CallExprPrinter {
Expand Down Expand Up @@ -1038,24 +1023,17 @@ template <class NameT> class TypeNamePrinter {
}
};

template <class BaseT, class MemberT, bool HasExplicitTemplateArg>
class MemberExprPrinter {
template <class BaseT, class MemberT> class MemberExprPrinter {
BaseT Base;
bool IsArrow;
MemberT MemberName;
bool IsBaseDependentType = false;

public:
MemberExprPrinter(const BaseT &Base, bool IsArrow, MemberT MemberName)
: Base(Base), IsArrow(IsArrow), MemberName(MemberName) {
if constexpr (std::is_same_v<BaseT, const Expr *>) {
IsBaseDependentType = Base->getType()->isDependentType();
}
}
: Base(Base), IsArrow(IsArrow), MemberName(MemberName) {}

template <class StreamT> void print(StreamT &Stream) const {
printBase(Stream, Base, IsArrow,
HasExplicitTemplateArg && IsBaseDependentType);
printBase(Stream, Base, IsArrow);
dpct::print(Stream, MemberName);
}
};
Expand All @@ -1074,19 +1052,19 @@ template <class BaseT, class MemberT> class StaticMemberExprPrinter {
}
};

template <class BaseT, class MemberT, bool HasExplicitTemplateArg,
template <class BaseT, class MemberT,
class... CallArgsT>
class MemberCallPrinter
: public CallExprPrinter<
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
MemberExprPrinter<BaseT, MemberT>,
CallArgsT...> {
public:
MemberCallPrinter(const BaseT &Base, bool IsArrow, MemberT MemberName,
CallArgsT &&...Args)
: CallExprPrinter<
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
MemberExprPrinter<BaseT, MemberT>,
CallArgsT...>(
MemberExprPrinter<BaseT, MemberT, HasExplicitTemplateArg>(
MemberExprPrinter<BaseT, MemberT>(
std::move(Base), IsArrow, std::move(MemberName)),
std::forward<CallArgsT>(Args)...) {}
};
Expand Down Expand Up @@ -1441,35 +1419,32 @@ class TemplatedCallExprRewriter

template <class BaseT, class MemberT>
class MemberExprRewriter
: public PrinterRewriter<MemberExprPrinter<BaseT, MemberT, false>> {
: public PrinterRewriter<MemberExprPrinter<BaseT, MemberT>> {
public:
MemberExprRewriter(
const CallExpr *C, StringRef Source,
const std::function<BaseT(const CallExpr *)> &BaseCreator, bool IsArrow,
const std::function<MemberT(const CallExpr *)> &MemberCreator)
: PrinterRewriter<MemberExprPrinter<BaseT, MemberT, false>>(
: PrinterRewriter<MemberExprPrinter<BaseT, MemberT>>(
C, Source, BaseCreator(C), IsArrow, MemberCreator(C)) {}
};

template <class BaseT, bool HasExplicitTemplateArg, class... ArgsT>
template <class BaseT, class... ArgsT>
class MemberCallExprRewriter
: public PrinterRewriter<MemberCallPrinter<
BaseT, StringRef, HasExplicitTemplateArg, ArgsT...>> {
: public PrinterRewriter<MemberCallPrinter<BaseT, StringRef, ArgsT...>> {
public:
MemberCallExprRewriter(
const CallExpr *C, StringRef Source,
const std::function<BaseT(const CallExpr *)> &BaseCreator, bool IsArrow,
StringRef Member,
const std::function<ArgsT(const CallExpr *)> &...ArgsCreator)
: PrinterRewriter<MemberCallPrinter<BaseT, StringRef,
HasExplicitTemplateArg, ArgsT...>>(
: PrinterRewriter<MemberCallPrinter<BaseT, StringRef, ArgsT...>>(
C, Source, BaseCreator(C), IsArrow, Member, ArgsCreator(C)...) {}
MemberCallExprRewriter(
const CallExpr *C, StringRef Source, const BaseT &BaseCreator,
bool IsArrow, StringRef Member,
const std::function<ArgsT(const CallExpr *)> &...ArgsCreator)
: PrinterRewriter<MemberCallPrinter<BaseT, StringRef,
HasExplicitTemplateArg, ArgsT...>>(
: PrinterRewriter<MemberCallPrinter<BaseT, StringRef, ArgsT...>>(
C, Source, BaseCreator, IsArrow, Member, ArgsCreator(C)...) {}
};

Expand Down
66 changes: 30 additions & 36 deletions clang/lib/DPCT/RuleInfra/CallExprRewriterCommon.h
Original file line number Diff line number Diff line change
Expand Up @@ -439,32 +439,31 @@ inline std::function<std::string(const CallExpr *)> makeDeviceStr() {
};
}

template <class BaseT, bool HasExplicitTemplateArg, class... CallArgsT>
using MemberCallPrinterCreator = PrinterCreator<
MemberCallPrinter<BaseT, StringRef, HasExplicitTemplateArg, CallArgsT...>,
std::function<BaseT(const CallExpr *)>, bool, std::string,
std::function<CallArgsT(const CallExpr *)>...>;

template <bool HasExplicitTemplateArg, class BaseT, class... CallArgsT>
inline std::function<MemberCallPrinter<BaseT, StringRef, HasExplicitTemplateArg,
CallArgsT...>(const CallExpr *)>
template <class BaseT, class... CallArgsT>
using MemberCallPrinterCreator =
PrinterCreator<MemberCallPrinter<BaseT, StringRef, CallArgsT...>,
std::function<BaseT(const CallExpr *)>, bool, std::string,
std::function<CallArgsT(const CallExpr *)>...>;

template <class BaseT, class... CallArgsT>
inline std::function<
MemberCallPrinter<BaseT, StringRef, CallArgsT...>(const CallExpr *)>
makeMemberCallCreator(std::function<BaseT(const CallExpr *)> BaseFunc,
bool IsArrow, std::string Member,
std::function<CallArgsT(const CallExpr *)>... Args) {
return MemberCallPrinterCreator<BaseT, HasExplicitTemplateArg, CallArgsT...>(
BaseFunc, IsArrow, Member, Args...);
return MemberCallPrinterCreator<BaseT, CallArgsT...>(BaseFunc, IsArrow,
Member, Args...);
}

template <bool HasExplicitTemplateArg, class BaseT, class MemberT>
inline std::function<
MemberCallPrinter<BaseT, MemberT, HasExplicitTemplateArg>(const CallExpr *)>
template <class BaseT, class MemberT>
inline std::function<MemberCallPrinter<BaseT, MemberT>(const CallExpr *)>
makeMemberCallCreator(std::function<BaseT(const CallExpr *)> BaseFunc,
bool IsArrow,
std::function<MemberT(const CallExpr *)> Member) {
return PrinterCreator<
MemberCallPrinter<BaseT, MemberT, HasExplicitTemplateArg>,
std::function<BaseT(const CallExpr *)>, bool,
std::function<MemberT(const CallExpr *)>>(BaseFunc, IsArrow, Member);
return PrinterCreator<MemberCallPrinter<BaseT, MemberT>,
std::function<BaseT(const CallExpr *)>, bool,
std::function<MemberT(const CallExpr *)>>(
BaseFunc, IsArrow, Member);
}

template <class... StmtT>
Expand Down Expand Up @@ -691,10 +690,10 @@ makeArgWithAddressSpaceCast(int ArgIdx) {
}

template <class BaseT, class MemberT>
inline std::function<MemberExprPrinter<BaseT, MemberT, false>(const CallExpr *)>
inline std::function<MemberExprPrinter<BaseT, MemberT>(const CallExpr *)>
makeMemberExprCreator(std::function<BaseT(const CallExpr *)> Base, bool IsArrow,
std::function<MemberT(const CallExpr *)> Member) {
return PrinterCreator<MemberExprPrinter<BaseT, MemberT, false>,
return PrinterCreator<MemberExprPrinter<BaseT, MemberT>,
std::function<BaseT(const CallExpr *)>, bool,
std::function<MemberT(const CallExpr *)>>(Base, IsArrow,
Member);
Expand Down Expand Up @@ -1344,15 +1343,15 @@ createTemplatedCallExprRewriterFactory(
/// \p BaseCreator use to get base expr from original call expr.
/// \p IsArrow the member operator is arrow or dot as default.
/// \p ArgsCreator use to get call args from original call expr.
template <bool HasExplicitTemplateArg, class BaseT, class... ArgsT>
template <class BaseT, class... ArgsT>
inline std::shared_ptr<CallExprRewriterFactoryBase>
createMemberCallExprRewriterFactory(
const std::string &SourceName,
std::function<BaseT(const CallExpr *)> BaseCreator, bool IsArrow,
std::string MemberName,
std::function<ArgsT(const CallExpr *)>... ArgsCreator) {
return std::make_shared<CallExprRewriterFactory<
MemberCallExprRewriter<BaseT, HasExplicitTemplateArg, ArgsT...>,
MemberCallExprRewriter<BaseT, ArgsT...>,
std::function<BaseT(const CallExpr *)>, bool, std::string,
std::function<ArgsT(const CallExpr *)>...>>(
SourceName,
Expand All @@ -1361,16 +1360,16 @@ createMemberCallExprRewriterFactory(
std::forward<std::function<ArgsT(const CallExpr *)>>(ArgsCreator)...);
}

template <bool HasExplicitTemplateArg, class BaseT, class... ArgsT>
template <class BaseT, class... ArgsT>
inline std::shared_ptr<std::enable_if_t<
!std::is_invocable_v<BaseT, const CallExpr *>, CallExprRewriterFactoryBase>>
createMemberCallExprRewriterFactory(
const std::string &SourceName, BaseT BaseCreator, bool IsArrow,
std::string MemberName,
std::function<ArgsT(const CallExpr *)>... ArgsCreator) {
return std::make_shared<CallExprRewriterFactory<
MemberCallExprRewriter<BaseT, HasExplicitTemplateArg, ArgsT...>, BaseT,
bool, std::string, std::function<ArgsT(const CallExpr *)>...>>(
MemberCallExprRewriter<BaseT, ArgsT...>, BaseT, bool, std::string,
std::function<ArgsT(const CallExpr *)>...>>(
SourceName, BaseCreator, IsArrow, MemberName,
std::forward<std::function<ArgsT(const CallExpr *)>>(ArgsCreator)...);
}
Expand Down Expand Up @@ -1635,19 +1634,19 @@ createBindTextureRewriterFactory(const std::string &Source) {

return std::make_shared<ConditionalRewriterFactory>(
makePointerChecker(StartIdx + 0),
createMemberCallExprRewriterFactory<false>(
createMemberCallExprRewriterFactory(
Source, makeDerefExprCreator(StartIdx + 0), true, "attach",
makeCallArgCreator(StartIdx + 1),
makeCallArgCreator(StartIdx + Idx + 1)...,
makeDerefExprCreator(StartIdx + 2)),
std::make_shared<ConditionalRewriterFactory>(
TypeChecker,
createMemberCallExprRewriterFactory<false>(
createMemberCallExprRewriterFactory(
Source, makeCallArgCreatorWithCall(StartIdx + 0), false, "attach",
makeCallArgCreatorWithCall(StartIdx + 1),
makeCallArgCreatorWithCall(StartIdx + Idx + 1)...,
makeCallArgCreatorWithCall(StartIdx + 2)),
createMemberCallExprRewriterFactory<false>(
createMemberCallExprRewriterFactory(
Source, makeCallArgCreatorWithCall(StartIdx + 0), false, "attach",
makeCallArgCreatorWithCall(StartIdx + 1),
makeCallArgCreatorWithCall(StartIdx + Idx)...)));
Expand Down Expand Up @@ -2184,9 +2183,7 @@ const std::string MipmapNeedBindlessImage =
#define UO(Op, E) makeUnaryOperatorCreator<Op>(E)
#define BO(Op, L, R) makeBinaryOperatorCreator<Op>(L, R)
#define PAREN(E) makeParenExprCreator(E)
#define MEMBER_CALL(...) makeMemberCallCreator<false>(__VA_ARGS__)
#define MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG(...) \
makeMemberCallCreator<true>(__VA_ARGS__)
#define MEMBER_CALL(...) makeMemberCallCreator(__VA_ARGS__)
#define MEMBER_EXPR(...) makeMemberExprCreator(__VA_ARGS__)
#define STATIC_MEMBER_EXPR(...) makeStaticMemberExprCreator(__VA_ARGS__)
#define LAMBDA(...) makeLambdaCreator(__VA_ARGS__)
Expand Down Expand Up @@ -2250,11 +2247,8 @@ const std::string MipmapNeedBindlessImage =
#define CALL_FACTORY_ENTRY(FuncName, C) \
std::make_pair(FuncName, createCallExprRewriterFactory(FuncName, C)),
#define MEMBER_CALL_FACTORY_ENTRY(FuncName, ...) \
std::make_pair(FuncName, createMemberCallExprRewriterFactory<false>( \
FuncName, __VA_ARGS__)),
#define MEMBER_CALL_HAS_EXPLICIT_TEMP_ARG_FACTORY_ENTRY(FuncName, ...) \
std::make_pair(FuncName, createMemberCallExprRewriterFactory<true>( \
FuncName, __VA_ARGS__)),
std::make_pair(FuncName, \
createMemberCallExprRewriterFactory(FuncName, __VA_ARGS__)),
#define ARRAYSUBSCRIPT_EXPR_FACTORY_ENTRY(FuncName, ...) \
std::make_pair(FuncName, createArraySubscriptExprRewriterFactory( \
FuncName, __VA_ARGS__)),
Expand Down
2 changes: 1 addition & 1 deletion clang/lib/DPCT/RuleInfra/MemberExprRewriter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ template <class BaseT, class MemberT> class MEMemberExprPrinter {
: Base(Base), IsArrow(IsArrow), MemberName(MemberName) {}

template <class StreamT> void print(StreamT &Stream) const {
printBase(Stream, Base, IsArrow, false);
printBase(Stream, Base, IsArrow);
dpct::print(Stream, MemberName);
}
};
Expand Down
6 changes: 4 additions & 2 deletions clang/lib/DPCT/RulesLang/APINamesComplex.inc
Original file line number Diff line number Diff line change
Expand Up @@ -95,5 +95,7 @@ BINARY_OP_FACTORY_ENTRY("cuCfmaf", BinaryOperatorKind::BO_Add,
makeCallArgCreatorWithCall(1)),
makeCallArgCreatorWithCall(2))

MEMBER_CALL_FACTORY_ENTRY("cuComplexDoubleToFloat", ARG(0), false, "convert<float>")
MEMBER_CALL_FACTORY_ENTRY("cuComplexFloatToDouble", ARG(0), false, "convert<double>")
MEMBER_CALL_FACTORY_ENTRY("cuComplexDoubleToFloat", ARG(0), false,
"template convert<float>")
MEMBER_CALL_FACTORY_ENTRY("cuComplexFloatToDouble", ARG(0), false,
"template convert<double>")
2 changes: 1 addition & 1 deletion clang/lib/DPCT/RulesLang/CallExprRewriterTexture.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class TextureReadRewriterFactory : public CallExprRewriterFactoryBase {
template <class BaseT>
std::shared_ptr<CallExprRewriter>
createRewriter(const CallExpr *C, bool RetAssign, BaseT Base) const {
using ReaderPrinter = decltype(makeMemberCallCreator<false>(
using ReaderPrinter = decltype(makeMemberCallCreator(
std::declval<std::function<BaseT(const CallExpr *)>>(), false,
TargetName, makeCallArgCreatorWithCall(Idx)...)(C));
if (RetAssign) {
Expand Down
Loading
Loading