Skip to content

Commit 81b6b4a

Browse files
committed
Add support for visiting op classes
Op classes are sets of ops, providing a common base class in the op class hierarchy, which can be used to group sets of ops with a common property. Before this patch, op class only support `isa<>` checks, but not visiting. This was because visiting relies on `OpDescription::get<OpT>`, which is only supported for concrete ops, but not op classes. This patch changes that by introducing a new `OpDescription::getAll<OpT>` template which returns an `ArrayRef` of OpDescriptions. The existing `get<OpT>` template is implemented in terms of the new `getAll` template. Visitor code is changed to use `getAll` instead, and handle non-trivial array refs (representing op classes) accordingly, using the already existing mechanism for op sets.
1 parent b836521 commit 81b6b4a

6 files changed

Lines changed: 400 additions & 69 deletions

File tree

include/llvm-dialects/Dialect/OpDescription.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,14 @@ class OpDescription {
6565
return m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads;
6666
}
6767

68+
// Only supported for concrete ops, not for op classes.
6869
template <typename OpT> static const OpDescription &get();
6970

71+
// For concrete ops, returns a 1-element array containing the result of get().
72+
// For op classes, returns an array containing the descriptions of all
73+
// concrete ops that belong to this op class.
74+
template <typename OpT> static llvm::ArrayRef<OpDescription> getAll();
75+
7076
Kind getKind() const { return m_kind; }
7177

7278
unsigned getOpcode() const;

include/llvm-dialects/Dialect/Visitor.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,18 @@ class VisitorKey {
120120
friend class VisitorTemplate;
121121

122122
public:
123+
// OpT may be a concrete dialect op, or an op class.
123124
template <typename OpT> static VisitorKey op() {
124-
VisitorKey key{Kind::OpDescription};
125-
key.m_description = &OpDescription::get<OpT>();
125+
auto const descriptions = OpDescription::getAll<OpT>();
126+
if (descriptions.size() == 1) {
127+
VisitorKey key{Kind::OpDescription};
128+
key.m_description = &descriptions[0];
129+
return key;
130+
}
131+
// OpT is an op class. Resolve it by all concrete sub ops.
132+
static const OpSet set = OpSet::fromOpDescriptions(descriptions);
133+
VisitorKey key{Kind::OpSet};
134+
key.m_set = &set;
126135
return key;
127136
}
128137

lib/Dialect/OpDescription.cpp

Lines changed: 30 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ bool OpDescription::matchIntrinsic(unsigned intrinsicId) const {
9797
// ============================================================================
9898
// Descriptions of core instructions.
9999

100-
template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
100+
template <> ArrayRef<OpDescription> OpDescription::getAll<UnaryInstruction>() {
101101
static unsigned opcodes[] = {
102102
Instruction::Alloca,
103103
Instruction::Load,
@@ -109,39 +109,59 @@ template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
109109
#include "llvm/IR/Instruction.def"
110110
};
111111
static const OpDescription desc{Kind::Core, opcodes};
112-
return desc;
112+
return {&desc, 1};
113113
}
114114

115-
template <> const OpDescription &OpDescription::get<BinaryOperator>() {
115+
template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
116+
return getAll<UnaryInstruction>()[0];
117+
}
118+
119+
template <> ArrayRef<OpDescription> OpDescription::getAll<BinaryOperator>() {
116120
static unsigned opcodes[] = {
117121
#define HANDLE_BINARY_INST(num, opcode, Class) Instruction::opcode,
118122
#include "llvm/IR/Instruction.def"
119123
};
120124
static const OpDescription desc{Kind::Core, opcodes};
121-
return desc;
125+
return {&desc, 1};
126+
}
127+
128+
template <> const OpDescription &OpDescription::get<BinaryOperator>() {
129+
return getAll<BinaryOperator>()[0];
122130
}
123131

124132
// Generate OpDescription for all dedicate instruction classes.
125133
#define HANDLE_USER_INST(...)
126134
#define HANDLE_UNARY_INST(...)
127135
#define HANDLE_BINARY_INST(...)
128136
#define HANDLE_INST(num, opcode, Class) \
129-
template <> const OpDescription &OpDescription::get<Class>() { \
130-
static const OpDescription desc{Kind::Core, Instruction::opcode}; \
137+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
138+
static const std::array<OpDescription, 1> desc{ \
139+
OpDescription{Kind::Core, Instruction::opcode}}; \
131140
return desc; \
141+
} \
142+
template <> const OpDescription &OpDescription::get<Class>() { \
143+
return getAll<Class>()[0]; \
132144
}
133145
#include "llvm/IR/Instruction.def"
134146

135147
#define HANDLE_INTRINSIC_DESC(Class, opcode) \
136-
template <> const OpDescription &OpDescription::get<Class>() { \
137-
static const OpDescription desc{Kind::Intrinsic, Intrinsic::opcode}; \
148+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
149+
static const std::array<OpDescription, 1> desc{ \
150+
OpDescription{Kind::Intrinsic, Intrinsic::opcode}}; \
138151
return desc; \
152+
} \
153+
template <> const OpDescription &OpDescription::get<Class>() { \
154+
return getAll<Class>()[0]; \
139155
}
140156
#define HANDLE_INTRINSIC_DESC_OPCODE_SET(Class, ...) \
141-
template <> const OpDescription &OpDescription::get<Class>() { \
157+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
142158
static unsigned opcodes[] = {__VA_ARGS__}; \
143-
static const OpDescription desc{Kind::Intrinsic, opcodes}; \
159+
static const std::array<OpDescription, 1> desc{ \
160+
OpDescription{Kind::Intrinsic, opcodes}}; \
144161
return desc; \
162+
} \
163+
template <> const OpDescription &OpDescription::get<Class>() { \
164+
return getAll<Class>()[0]; \
145165
}
146166

147167
// ============================================================================

lib/TableGen/GenDialect.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
271271

272272
out << R"(
273273
#include "llvm/Support/raw_ostream.h"
274+
#include <array>
274275
#endif // GET_INCLUDES
275276
276277
#ifdef GET_DIALECT_DEFS
@@ -448,26 +449,70 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
448449
if (!dialect->cppNamespace.empty())
449450
out << tgfmt("} // namespace $namespace\n", &fmt);
450451

451-
// Define specializations of OpDescription::get for reflection
452+
// Define specializations of OpDescription::{get, getAll} for reflection
452453
for (const auto &opPtr : dialect->operations) {
453454
Operation &op = *opPtr;
454455

455456
FmtContextScope scope{fmt};
456457
fmt.withOp(op.name);
457458
fmt.addSubst("mnemonic", op.mnemonic);
458459

460+
// We'd prefer fully qualifying the llvm_dialects namespace below with
461+
// leading "::", but this is parsed as oart of the preceding ArrayRef type
462+
// as there are just spaces in between. (gcc/clang/MSVC reject this)
463+
// The get() variant does not have this problem due to the `&` token.
459464
out << tgfmt(R"(
465+
template <>
466+
::llvm::ArrayRef<::llvm_dialects::OpDescription>
467+
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
468+
static const std::array<::llvm_dialects::OpDescription, 1> desc{
469+
::llvm_dialects::OpDescription{$0, "$dialect.$mnemonic"}
470+
};
471+
return desc;
472+
}
473+
460474
template <>
461475
const ::llvm_dialects::OpDescription &
462476
::llvm_dialects::OpDescription::get<$namespace::$_op>() {
463-
static const ::llvm_dialects::OpDescription desc{$0, "$dialect.$mnemonic"};
464-
return desc;
477+
return getAll<$namespace::$_op>()[0];
465478
}
466479
467480
)",
468481
&fmt, op.haveResultOverloads() ? "true" : "false");
469482
}
470483

484+
// Define specializations of OpDescription::getAll for op classes
485+
for (const auto &opClassPtr : dialect->opClasses) {
486+
OpClass &opClass = *opClassPtr;
487+
488+
FmtContextScope scope{fmt};
489+
fmt.withOp(opClass.name);
490+
491+
// We'd prefer fully qualifying the llvm_dialects namespace below with
492+
// leading "::", but gcc/clang/MSVC reject this as they interpret the
493+
// ::llvm_dialects identifier than within the preceding ArrayRef type. The
494+
// get() variant does not have this problem as the `&` token separates the
495+
// two.
496+
out << tgfmt(R"(
497+
template <>
498+
::llvm::ArrayRef<::llvm_dialects::OpDescription>
499+
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
500+
static const std::array<::llvm_dialects::OpDescription, $0> desc{)",
501+
&fmt, opClass.operations.size());
502+
for (const auto &op : opClass.operations) {
503+
fmt.addSubst("mnemonic", op->mnemonic);
504+
out << tgfmt(R"(
505+
::llvm_dialects::OpDescription{$0, "$dialect.$mnemonic"},)",
506+
&fmt, op->haveResultOverloads() ? "true" : "false");
507+
}
508+
out << tgfmt(R"(
509+
};
510+
return desc;
511+
}
512+
)",
513+
&fmt, opClass.operations.size());
514+
}
515+
471516
out << R"(
472517
#endif // GET_DIALECT_DEFS
473518
)";

0 commit comments

Comments
 (0)