Skip to content

Commit fd3fc26

Browse files
committed
Add support for visiting op classes
Op classes are sets of ops, providing a common base class in the op class hierarchy, which can be used to group sets of ops with a common property. Before this patch, op class only support `isa<>` checks, but not visiting. This was because visiting relies on `OpDescription::get<OpT>`, which is only supported for concrete ops, but not op classes. This patch changes that by introducing a new `OpDescription::getAll<OpT>` template which returns an `ArrayRef` of OpDescriptions. The existing `get<OpT>` template is implemented in terms of the new `getAll` template. Visitor code is changed to use `getAll` instead, and handle non-trivial array refs (representing op classes) accordingly, using the already existing mechanism for op sets.
1 parent b836521 commit fd3fc26

6 files changed

Lines changed: 340 additions & 68 deletions

File tree

include/llvm-dialects/Dialect/OpDescription.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,14 @@ class OpDescription {
6565
return m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads;
6666
}
6767

68+
// Only supported for concrete ops, not for op classes.
6869
template <typename OpT> static const OpDescription &get();
6970

71+
// For concrete ops, returns a 1-element array containing the result of get().
72+
// For op classes, returns an array containing the descriptions of all
73+
// concrete ops that belong to this op class.
74+
template <typename OpT> static llvm::ArrayRef<OpDescription> getAll();
75+
7076
Kind getKind() const { return m_kind; }
7177

7278
unsigned getOpcode() const;

include/llvm-dialects/Dialect/Visitor.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,18 @@ class VisitorKey {
120120
friend class VisitorTemplate;
121121

122122
public:
123+
// OpT may be a concrete dialect op, or an op class.
123124
template <typename OpT> static VisitorKey op() {
124-
VisitorKey key{Kind::OpDescription};
125-
key.m_description = &OpDescription::get<OpT>();
125+
auto const descriptions = OpDescription::getAll<OpT>();
126+
if (descriptions.size() == 1) {
127+
VisitorKey key{Kind::OpDescription};
128+
key.m_description = &descriptions[0];
129+
return key;
130+
}
131+
// OpT is an op class. Resolve it by all concrete sub ops.
132+
static const OpSet set = OpSet::fromOpDescriptions(descriptions);
133+
VisitorKey key{Kind::OpSet};
134+
key.m_set = &set;
126135
return key;
127136
}
128137

lib/Dialect/OpDescription.cpp

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ bool OpDescription::matchIntrinsic(unsigned intrinsicId) const {
9797
// ============================================================================
9898
// Descriptions of core instructions.
9999

100-
template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
100+
template <> ArrayRef<OpDescription> OpDescription::getAll<UnaryInstruction>() {
101101
static unsigned opcodes[] = {
102102
Instruction::Alloca,
103103
Instruction::Load,
@@ -112,7 +112,11 @@ template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
112112
return desc;
113113
}
114114

115-
template <> const OpDescription &OpDescription::get<BinaryOperator>() {
115+
template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
116+
return getAll<UnaryInstruction>()[0];
117+
}
118+
119+
template <> ArrayRef<OpDescription> OpDescription::getAll<BinaryOperator>() {
116120
static unsigned opcodes[] = {
117121
#define HANDLE_BINARY_INST(num, opcode, Class) Instruction::opcode,
118122
#include "llvm/IR/Instruction.def"
@@ -121,27 +125,40 @@ template <> const OpDescription &OpDescription::get<BinaryOperator>() {
121125
return desc;
122126
}
123127

128+
template <> const OpDescription &OpDescription::get<BinaryOperator>() {
129+
return getAll<BinaryOperator>()[0];
130+
}
131+
124132
// Generate OpDescription for all dedicate instruction classes.
125133
#define HANDLE_USER_INST(...)
126134
#define HANDLE_UNARY_INST(...)
127135
#define HANDLE_BINARY_INST(...)
128136
#define HANDLE_INST(num, opcode, Class) \
129-
template <> const OpDescription &OpDescription::get<Class>() { \
137+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
130138
static const OpDescription desc{Kind::Core, Instruction::opcode}; \
131139
return desc; \
140+
} \
141+
template <> const OpDescription &OpDescription::get<Class>() { \
142+
return getAll<Class>()[0]; \
132143
}
133144
#include "llvm/IR/Instruction.def"
134145

135146
#define HANDLE_INTRINSIC_DESC(Class, opcode) \
136-
template <> const OpDescription &OpDescription::get<Class>() { \
147+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
137148
static const OpDescription desc{Kind::Intrinsic, Intrinsic::opcode}; \
138149
return desc; \
150+
} \
151+
template <> const OpDescription &OpDescription::get<Class>() { \
152+
return getAll<Class>()[0]; \
139153
}
140154
#define HANDLE_INTRINSIC_DESC_OPCODE_SET(Class, ...) \
141-
template <> const OpDescription &OpDescription::get<Class>() { \
155+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
142156
static unsigned opcodes[] = {__VA_ARGS__}; \
143157
static const OpDescription desc{Kind::Intrinsic, opcodes}; \
144158
return desc; \
159+
} \
160+
template <> const OpDescription &OpDescription::get<Class>() { \
161+
return getAll<Class>()[0]; \
145162
}
146163

147164
// ============================================================================

lib/TableGen/GenDialect.cpp

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
271271

272272
out << R"(
273273
#include "llvm/Support/raw_ostream.h"
274+
#include <array>
274275
#endif // GET_INCLUDES
275276
276277
#ifdef GET_DIALECT_DEFS
@@ -326,10 +327,10 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
326327
}
327328
328329
bool $Dialect::isDialectOp(::llvm::CallInst& op) {
329-
::llvm::Function *calledFunc = op.getCalledFunction();
330+
::llvm::Function *calledFunc = op.getCalledFunction();
330331
if (!calledFunc)
331332
return false;
332-
333+
333334
return isDialectOp(calledFunc->getName());
334335
}
335336
@@ -448,26 +449,68 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
448449
if (!dialect->cppNamespace.empty())
449450
out << tgfmt("} // namespace $namespace\n", &fmt);
450451

451-
// Define specializations of OpDescription::get for reflection
452+
// Define specializations of OpDescription::{get, getAll} for reflection
452453
for (const auto &opPtr : dialect->operations) {
453454
Operation &op = *opPtr;
454455

455456
FmtContextScope scope{fmt};
456457
fmt.withOp(op.name);
457458
fmt.addSubst("mnemonic", op.mnemonic);
458459

460+
// We'd prefer fully qualifying the llvm_dialects namespace below with
461+
// leading "::", but this is parsed as oart of the preceding ArrayRef type
462+
// as there are just spaces in between. (gcc/clang/MSVC reject this)
463+
// The get() variant does not have this problem due to the `&` token.
459464
out << tgfmt(R"(
465+
template <>
466+
::llvm::ArrayRef<::llvm_dialects::OpDescription>
467+
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
468+
static const OpDescription desc{$0, "$dialect.$mnemonic"};
469+
return desc;
470+
}
471+
460472
template <>
461473
const ::llvm_dialects::OpDescription &
462474
::llvm_dialects::OpDescription::get<$namespace::$_op>() {
463-
static const ::llvm_dialects::OpDescription desc{$0, "$dialect.$mnemonic"};
464-
return desc;
475+
return getAll<$namespace::$_op>()[0];
465476
}
466477
467478
)",
468479
&fmt, op.haveResultOverloads() ? "true" : "false");
469480
}
470481

482+
// Define specializations of OpDescription::getAll for op classes
483+
for (const auto &opClassPtr : dialect->opClasses) {
484+
OpClass &opClass = *opClassPtr;
485+
486+
FmtContextScope scope{fmt};
487+
fmt.withOp(opClass.name);
488+
489+
// We'd prefer fully qualifying the llvm_dialects namespace below with
490+
// leading "::", but gcc/clang/MSVC reject this as they interpret the
491+
// ::llvm_dialects identifier than within the preceding ArrayRef type. The
492+
// get() variant does not have this problem as the `&` token separates the
493+
// two.
494+
out << tgfmt(R"(
495+
template <>
496+
::llvm::ArrayRef<::llvm_dialects::OpDescription>
497+
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
498+
static const std::array<::llvm_dialects::OpDescription, $0> desc{)",
499+
&fmt, opClass.operations.size());
500+
for (const auto &op : opClass.operations) {
501+
fmt.addSubst("mnemonic", op->mnemonic);
502+
out << tgfmt(R"(
503+
::llvm_dialects::OpDescription{$0, "$dialect.$mnemonic"},)",
504+
&fmt, op->haveResultOverloads() ? "true" : "false");
505+
}
506+
out << tgfmt(R"(
507+
};
508+
return desc;
509+
}
510+
)",
511+
&fmt, opClass.operations.size());
512+
}
513+
471514
out << R"(
472515
#endif // GET_DIALECT_DEFS
473516
)";

0 commit comments

Comments
 (0)