Skip to content

Commit d50e420

Browse files
committed
Add support for visiting op classes
Op classes are sets of ops, providing a common base class in the op class hierarchy, which can be used to group sets of ops with a common property. Before this patch, op class only support `isa<>` checks, but not visiting. This was because visiting relies on `OpDescription::get<OpT>`, which is only supported for concrete ops, but not op classes. This patch changes that by introducing a new `OpDescription::getAll<OpT>` template which returns an `ArrayRef` of OpDescriptions. The existing `get<OpT>` template is implemented in terms of the new `getAll` template. Visitor code is changed to use `getAll` instead, and handle non-trivial array refs (representing op classes) accordingly, using the already existing mechanism for op sets.
1 parent 3fef4f3 commit d50e420

6 files changed

Lines changed: 402 additions & 69 deletions

File tree

include/llvm-dialects/Dialect/OpDescription.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,14 @@ class OpDescription {
6565
return m_kind == Kind::Dialect || m_kind == Kind::DialectWithOverloads;
6666
}
6767

68+
// Only supported for concrete, non-op-class ops.
6869
template <typename OpT> static const OpDescription &get();
6970

71+
// For concrete, non-op-class ops, returns an array of length one that
72+
// contains the result of get(). For op classes, returns an array containing
73+
// the descriptions of all concrete ops that are of this opclass type.
74+
template <typename OpT> static llvm::ArrayRef<OpDescription> getAll();
75+
7076
Kind getKind() const { return m_kind; }
7177

7278
unsigned getOpcode() const;

include/llvm-dialects/Dialect/Visitor.h

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,9 +120,18 @@ class VisitorKey {
120120
friend class VisitorTemplate;
121121

122122
public:
123+
// OpT may be a concrete dialect op, or an op class.
123124
template <typename OpT> static VisitorKey op() {
124-
VisitorKey key{Kind::OpDescription};
125-
key.m_description = &OpDescription::get<OpT>();
125+
auto const descriptions = OpDescription::getAll<OpT>();
126+
if (descriptions.size() == 1) {
127+
VisitorKey key{Kind::OpDescription};
128+
key.m_description = &descriptions[0];
129+
return key;
130+
}
131+
// OpT is an op class. Resolve it by all concrete sub ops.
132+
static const OpSet set = OpSet::fromOpDescriptions(descriptions);
133+
VisitorKey key{Kind::OpSet};
134+
key.m_set = &set;
126135
return key;
127136
}
128137

lib/Dialect/OpDescription.cpp

Lines changed: 32 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ bool OpDescription::matchIntrinsic(unsigned intrinsicId) const {
9797
// ============================================================================
9898
// Descriptions of core instructions.
9999

100-
template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
100+
template <> ArrayRef<OpDescription> OpDescription::getAll<UnaryInstruction>() {
101101
static unsigned opcodes[] = {
102102
Instruction::Alloca,
103103
Instruction::Load,
@@ -108,40 +108,62 @@ template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
108108
#define HANDLE_CAST_INST(num, opcode, Class) Instruction::opcode,
109109
#include "llvm/IR/Instruction.def"
110110
};
111-
static const OpDescription desc{Kind::Core, opcodes};
111+
static const std::array<OpDescription, 1> desc{
112+
OpDescription{Kind::Core, opcodes}};
112113
return desc;
113114
}
114115

115-
template <> const OpDescription &OpDescription::get<BinaryOperator>() {
116+
template <> const OpDescription &OpDescription::get<UnaryInstruction>() {
117+
return getAll<UnaryInstruction>()[0];
118+
}
119+
120+
template <> ArrayRef<OpDescription> OpDescription::getAll<BinaryOperator>() {
116121
static unsigned opcodes[] = {
117122
#define HANDLE_BINARY_INST(num, opcode, Class) Instruction::opcode,
118123
#include "llvm/IR/Instruction.def"
119124
};
120-
static const OpDescription desc{Kind::Core, opcodes};
125+
static const std::array<OpDescription, 1> desc{
126+
OpDescription{Kind::Core, opcodes}};
121127
return desc;
122128
}
123129

130+
template <> const OpDescription &OpDescription::get<BinaryOperator>() {
131+
return getAll<BinaryOperator>()[0];
132+
}
133+
124134
// Generate OpDescription for all dedicate instruction classes.
125135
#define HANDLE_USER_INST(...)
126136
#define HANDLE_UNARY_INST(...)
127137
#define HANDLE_BINARY_INST(...)
128138
#define HANDLE_INST(num, opcode, Class) \
129-
template <> const OpDescription &OpDescription::get<Class>() { \
130-
static const OpDescription desc{Kind::Core, Instruction::opcode}; \
139+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
140+
static const std::array<OpDescription, 1> desc{ \
141+
OpDescription{Kind::Core, Instruction::opcode}}; \
131142
return desc; \
143+
} \
144+
template <> const OpDescription &OpDescription::get<Class>() { \
145+
return getAll<Class>()[0]; \
132146
}
133147
#include "llvm/IR/Instruction.def"
134148

135149
#define HANDLE_INTRINSIC_DESC(Class, opcode) \
136-
template <> const OpDescription &OpDescription::get<Class>() { \
137-
static const OpDescription desc{Kind::Intrinsic, Intrinsic::opcode}; \
150+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
151+
static const std::array<OpDescription, 1> desc{ \
152+
OpDescription{Kind::Intrinsic, Intrinsic::opcode}}; \
138153
return desc; \
154+
} \
155+
template <> const OpDescription &OpDescription::get<Class>() { \
156+
return getAll<Class>()[0]; \
139157
}
140158
#define HANDLE_INTRINSIC_DESC_OPCODE_SET(Class, ...) \
141-
template <> const OpDescription &OpDescription::get<Class>() { \
159+
template <> ArrayRef<OpDescription> OpDescription::getAll<Class>() { \
142160
static unsigned opcodes[] = {__VA_ARGS__}; \
143-
static const OpDescription desc{Kind::Intrinsic, opcodes}; \
161+
static const std::array<OpDescription, 1> desc{ \
162+
OpDescription{Kind::Intrinsic, opcodes}}; \
144163
return desc; \
164+
} \
165+
template <> const OpDescription &OpDescription::get<Class>() { \
166+
return getAll<Class>()[0]; \
145167
}
146168

147169
// ============================================================================

lib/TableGen/GenDialect.cpp

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,6 +271,7 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
271271

272272
out << R"(
273273
#include "llvm/Support/raw_ostream.h"
274+
#include <array>
274275
#endif // GET_INCLUDES
275276
276277
#ifdef GET_DIALECT_DEFS
@@ -448,26 +449,70 @@ void llvm_dialects::genDialectDefs(raw_ostream &out, RecordKeeperTy &records) {
448449
if (!dialect->cppNamespace.empty())
449450
out << tgfmt("} // namespace $namespace\n", &fmt);
450451

451-
// Define specializations of OpDescription::get for reflection
452+
// Define specializations of OpDescription::{get, getAll} for reflection
452453
for (const auto &opPtr : dialect->operations) {
453454
Operation &op = *opPtr;
454455

455456
FmtContextScope scope{fmt};
456457
fmt.withOp(op.name);
457458
fmt.addSubst("mnemonic", op.mnemonic);
458459

460+
// We'd prefer fully qualifying the llvm_dialects namespace below with
461+
// leading "::", but this is parsed as oart of the preceding ArrayRef type
462+
// as there are just spaces in between. (gcc/clang/MSVC reject this)
463+
// The get() variant does not have this problem due to the `&` token.
459464
out << tgfmt(R"(
465+
template <>
466+
::llvm::ArrayRef<::llvm_dialects::OpDescription>
467+
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
468+
static const std::array<::llvm_dialects::OpDescription, 1> desc{
469+
::llvm_dialects::OpDescription{$0, "$dialect.$mnemonic"}
470+
};
471+
return desc;
472+
}
473+
460474
template <>
461475
const ::llvm_dialects::OpDescription &
462476
::llvm_dialects::OpDescription::get<$namespace::$_op>() {
463-
static const ::llvm_dialects::OpDescription desc{$0, "$dialect.$mnemonic"};
464-
return desc;
477+
return getAll<$namespace::$_op>()[0];
465478
}
466479
467480
)",
468481
&fmt, op.haveResultOverloads() ? "true" : "false");
469482
}
470483

484+
// Define specializations of OpDescription::getAll for op classes
485+
for (const auto &opClassPtr : dialect->opClasses) {
486+
OpClass &opClass = *opClassPtr;
487+
488+
FmtContextScope scope{fmt};
489+
fmt.withOp(opClass.name);
490+
491+
// We'd prefer fully qualifying the llvm_dialects namespace below with
492+
// leading "::", but gcc/clang/MSVC reject this as they interpret the
493+
// ::llvm_dialects identifier than within the preceding ArrayRef type. The
494+
// get() variant does not have this problem as the `&` token separates the
495+
// two.
496+
out << tgfmt(R"(
497+
template <>
498+
::llvm::ArrayRef<::llvm_dialects::OpDescription>
499+
llvm_dialects::OpDescription::getAll<$namespace::$_op>() {
500+
static const std::array<::llvm_dialects::OpDescription, $0> desc{)",
501+
&fmt, opClass.operations.size());
502+
for (const auto &op : opClass.operations) {
503+
fmt.addSubst("mnemonic", op->mnemonic);
504+
out << tgfmt(R"(
505+
::llvm_dialects::OpDescription{$0, "$dialect.$mnemonic"},)",
506+
&fmt, op->haveResultOverloads() ? "true" : "false");
507+
}
508+
out << tgfmt(R"(
509+
};
510+
return desc;
511+
}
512+
)",
513+
&fmt, opClass.operations.size());
514+
}
515+
471516
out << R"(
472517
#endif // GET_DIALECT_DEFS
473518
)";

0 commit comments

Comments
 (0)