Skip to content

Commit 9771937

Browse files
jasilvanusFlakebi
authored andcommitted
Add support for visiting op classes
Op classes are sets of ops, providing a common base class in the op class hierarchy, which can be used to group sets of ops with a common property. Before this patch, op class only support `isa<>` checks, but not visiting. This was because visiting relies on `OpDescription::get<OpT>`, which is only supported for concrete ops, but not op classes. This patch changes that by introducing a new `OpDescription::getAll<OpT>` template which returns an `ArrayRef` of OpDescriptions. The existing `get<OpT>` template is implemented in terms of the new `getAll` template. Visitor code is changed to use `getAll` instead, and handle non-trivial array refs (representing op classes) accordingly, using the already existing mechanism for op sets.
1 parent 7724912 commit 9771937

6 files changed

Lines changed: 279 additions & 7 deletions

File tree

include/llvm-dialects/Dialect/OpDescription.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,14 @@ class OpDescription {
6565
return m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads;
6666
}
6767

68+
// Only supported for concrete ops, not for op classes.
6869
template <typename OpT> static const OpDescription &get();
6970

71+
// For concrete ops, returns a 1-element array containing the result of get().
72+
// For op classes, returns an array containing the descriptions of all
73+
// concrete ops that belong to this op class.
74+
template <typename OpT> static llvm::ArrayRef<OpDescription> getAll();
75+
7076
Kind getKind() const { return m_kind; }
7177

7278
unsigned getOpcode() const;

include/llvm-dialects/Dialect/Visitor.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,18 @@ class VisitorKey {
120120
friend class VisitorTemplate;
121121

122122
public:
123+
// OpT may be a concrete dialect op, or an op class.
123124
template <typename OpT> static VisitorKey op() {
124-
VisitorKey key{Kind::OpDescription};
125-
key.m_description = &OpDescription::get<OpT>();
125+
auto const descriptions = OpDescription::getAll<OpT>();
126+
if (descriptions.size() == 1) {
127+
VisitorKey key{Kind::OpDescription};
128+
key.m_description = &descriptions[0];
129+
return key;
130+
}
131+
// OpT is an op class. Resolve it by all concrete sub ops.
132+
static const OpSet set = OpSet::fromOpDescriptions(descriptions);
133+
VisitorKey key{Kind::OpSet};
134+
key.m_set = &set;
126135
return key;
127136
}
128137

lib/Dialect/OpDescription.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
112112
return desc;
113113
}
114114

115+
template <> ArrayRef<OpDescription> OpDescription::getAll<UnaryInstruction>() {
116+
return get<UnaryInstruction>();
117+
}
118+
115119
template <> const OpDescription &OpDescription::get<BinaryOperator>() {
116120
static unsigned opcodes[] = {
117121
#define HANDLE_BINARY_INST(num, opcode, Class) Instruction::opcode,
@@ -121,6 +125,10 @@ template <> const OpDescription &OpDescription::get<BinaryOperator>() {
121125
return desc;
122126
}
123127

128+
template <> ArrayRef<OpDescription> OpDescription::getAll<BinaryOperator>() {
129+
return get<BinaryOperator>();
130+
}
131+
124132
// Generate OpDescription for all dedicate instruction classes.
125133
#define HANDLE_USER_INST(...)
126134
#define HANDLE_UNARY_INST(...)
@@ -129,19 +137,28 @@ template <> const OpDescription &OpDescription::get<BinaryOperator>() {
129137
template <> const OpDescription &OpDescription::get<Class>() { \
130138
static const OpDescription desc{Kind::Core, Instruction::opcode}; \
131139
return desc; \
140+
} \
141+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
142+
return get<Class>(); \
132143
}
133144
#include "llvm/IR/Instruction.def"
134145

135146
#define HANDLE_INTRINSIC_DESC(Class, opcode) \
136147
template <> const OpDescription &OpDescription::get<Class>() { \
137148
static const OpDescription desc{Kind::Intrinsic, Intrinsic::opcode}; \
138149
return desc; \
150+
} \
151+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
152+
return get<Class>(); \
139153
}
140154
#define HANDLE_INTRINSIC_DESC_OPCODE_SET(Class, ...) \
141155
template <> const OpDescription &OpDescription::get<Class>() { \
142156
static unsigned opcodes[] = {__VA_ARGS__}; \
143157
static const OpDescription desc{Kind::Intrinsic, opcodes}; \
144158
return desc; \
159+
} \
160+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
161+
return get<Class>(); \
145162
}
146163

147164
// ============================================================================

lib/TableGen/GenDialect.cpp

Lines changed: 46 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
271271

272272
out << R"(
273273
#include "llvm/Support/raw_ostream.h"
274+
#include <array>
274275
#endif // GET_INCLUDES
275276
276277
#ifdef GET_DIALECT_DEFS
@@ -326,10 +327,10 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
326327
}
327328
328329
bool $Dialect::isDialectOp(::llvm::CallInst& op) {
329-
::llvm::Function *calledFunc = op.getCalledFunction();
330+
::llvm::Function *calledFunc = op.getCalledFunction();
330331
if (!calledFunc)
331332
return false;
332-
333+
333334
return isDialectOp(calledFunc->getName());
334335
}
335336
@@ -448,14 +449,18 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
448449
if (!dialect->cppNamespace.empty())
449450
out << tgfmt("} // namespace $namespace\n", &fmt);
450451

451-
// Define specializations of OpDescription::get for reflection
452+
// Define specializations of OpDescription::{get, getAll} for reflection
452453
for (const auto &opPtr : dialect->operations) {
453454
Operation &op = *opPtr;
454455

455456
FmtContextScope scope{fmt};
456457
fmt.withOp(op.name);
457458
fmt.addSubst("mnemonic", op.mnemonic);
458459

460+
// We'd prefer fully qualifying the llvm_dialects namespace below in getAll
461+
// with leading "::", but this is parsed as oart of the preceding ArrayRef
462+
// type as there are just spaces in between. (gcc/clang/MSVC reject this)
463+
// The get() variant does not have this problem due to the `&` token.
459464
out << tgfmt(R"(
460465
template <>
461466
const ::llvm_dialects::OpDescription &
@@ -464,10 +469,48 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
464469
return desc;
465470
}
466471
472+
template <>
473+
::llvm::ArrayRef<::llvm_dialects::OpDescription>
474+
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
475+
return get<$namespace::$_op>();
476+
}
477+
467478
)",
468479
&fmt, op.haveResultOverloads() ? "true" : "false");
469480
}
470481

482+
// Define specializations of OpDescription::getAll for op classes
483+
for (const auto &opClassPtr : dialect->opClasses) {
484+
OpClass &opClass = *opClassPtr;
485+
486+
FmtContextScope scope{fmt};
487+
fmt.withOp(opClass.name);
488+
489+
// We'd prefer fully qualifying the llvm_dialects namespace below with
490+
// leading "::", but gcc/clang/MSVC reject this as they interpret the
491+
// ::llvm_dialects identifier than within the preceding ArrayRef type. The
492+
// get() variant does not have this problem as the `&` token separates the
493+
// two.
494+
out << tgfmt(R"(
495+
template <>
496+
::llvm::ArrayRef<::llvm_dialects::OpDescription>
497+
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
498+
static const std::array<::llvm_dialects::OpDescription, $0> desc{)",
499+
&fmt, opClass.operations.size());
500+
for (const auto &op : opClass.operations) {
501+
fmt.addSubst("mnemonic", op->mnemonic);
502+
out << tgfmt(R"(
503+
::llvm_dialects::OpDescription{$0, "$dialect.$mnemonic"},)",
504+
&fmt, op->haveResultOverloads() ? "true" : "false");
505+
}
506+
out << tgfmt(R"(
507+
};
508+
return desc;
509+
}
510+
)",
511+
&fmt, opClass.operations.size());
512+
}
513+
471514
out << R"(
472515
#endif // GET_DIALECT_DEFS
473516
)";

0 commit comments

Comments
 (0)