-
Notifications
You must be signed in to change notification settings - Fork 137
Expand file tree
/
Copy pathMulToAddPdll.cpp
More file actions
49 lines (41 loc) · 1.7 KB
/
Copy pathMulToAddPdll.cpp
File metadata and controls
49 lines (41 loc) · 1.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#include "lib/Transform/Arith/MulToAddPdll.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/PDLInterp/IR/PDLInterp.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "mlir/include/mlir/Pass/Pass.h"
namespace mlir {
namespace tutorial {
#define GEN_PASS_DEF_MULTOADDPDLL
#include "lib/Transform/Arith/Passes.h.inc"
/// Native PDL constraint implementation registered under the name "Halve".
///
/// Reads the first constraint argument as an integer attribute and produces a
/// new integer attribute of the same type holding half of its value.
///
/// \param rewriter Used only as a builder for the result attribute.
/// \param results  Receives exactly one attribute on success: `value / 2`,
///                 where the division truncates toward zero.
/// \param args     Constraint arguments; `args[0]` is read as an Attribute.
/// \returns success() after pushing the result. Returns failure() -- which
///          makes the enclosing pattern not match instead of tripping an
///          assertion -- when no argument was supplied, when the attribute is
///          not an IntegerAttr, or when its value needs more than 64 signed
///          bits.
LogicalResult halveImpl(PatternRewriter &rewriter, PDLResultList &results,
                        ArrayRef<PDLValue> args) {
  if (args.empty())
    return failure();

  // A plain cast<IntegerAttr> asserts on any other attribute kind; a
  // constraint should reject such inputs rather than abort.
  auto cAttr = dyn_cast_if_present<IntegerAttr>(args[0].cast<Attribute>());
  if (!cAttr)
    return failure();

  // getSExtValue() asserts when the value does not fit in int64_t, so check
  // the same condition up front and decline to match instead.
  const llvm::APInt bits = cAttr.getValue();
  if (!bits.isSignedIntN(64))
    return failure();

  const int64_t value = bits.getSExtValue();
  results.push_back(rewriter.getIntegerAttr(cAttr.getType(), value / 2));
  return success();
}
/// Native PDL constraint implementation registered under the name "MinusOne".
///
/// Reads the first constraint argument as an integer attribute and produces a
/// new integer attribute of the same type holding that value decremented by
/// one.
///
/// \param rewriter Used only as a builder for the result attribute.
/// \param results  Receives exactly one attribute on success: `value - 1`.
/// \param args     Constraint arguments; `args[0]` is read as an Attribute.
/// \returns success() after pushing the result. Returns failure() -- which
///          makes the enclosing pattern not match instead of tripping an
///          assertion -- when no argument was supplied or when the attribute
///          is not an IntegerAttr.
LogicalResult minusOneImpl(PatternRewriter &rewriter, PDLResultList &results,
                           ArrayRef<PDLValue> args) {
  if (args.empty())
    return failure();

  // A plain cast<IntegerAttr> asserts on any other attribute kind; a
  // constraint should reject such inputs rather than abort.
  auto cAttr = dyn_cast_if_present<IntegerAttr>(args[0].cast<Attribute>());
  if (!cAttr)
    return failure();

  // Subtract on the APInt itself, at the attribute's own bit width. This
  // avoids the int64_t round-trip, which asserts for values wider than 64
  // bits and is signed-overflow UB for INT64_MIN - 1. The subtraction wraps
  // modulo 2^width on underflow.
  results.push_back(
      rewriter.getIntegerAttr(cAttr.getType(), cAttr.getValue() - 1));
  return success();
}
/// Registers the native constraint implementations used by the PDLL patterns
/// with the PDL module owned by `patterns`.
///
/// The names "Halve" and "MinusOne" are bound to `halveImpl` and
/// `minusOneImpl` respectively.
///
/// \param patterns Pattern set whose PDL pattern module receives the
///                 registrations.
void registerNativeConstraints(RewritePatternSet &patterns) {
  auto &pdlModule = patterns.getPDLPatterns();
  pdlModule.registerConstraintFunction("Halve", halveImpl);
  pdlModule.registerConstraintFunction("MinusOne", minusOneImpl);
}
/// Pass that applies the PDLL-generated rewrite patterns, together with their
/// native constraints, to the operation the pass runs on.
///
/// The base class `impl::MulToAddPdllBase` comes from the generated pass
/// definitions pulled in through `GEN_PASS_DEF_MULTOADDPDLL`.
struct MulToAddPdll : impl::MulToAddPdllBase<MulToAddPdll> {
  // Inherit the constructors of the generated base.
  using MulToAddPdllBase::MulToAddPdllBase;

  /// Builds the pattern set and runs the greedy rewrite driver over the
  /// current operation.
  ///
  /// Marked `override` so the compiler diagnoses any signature drift from the
  /// virtual hook declared by the pass base class.
  void runOnOperation() override {
    RewritePatternSet patterns(&getContext());
    populateGeneratedPDLLPatterns(patterns);
    registerNativeConstraints(patterns);

    // The driver's LogicalResult is deliberately discarded: the pass does not
    // signal failure based on it.
    (void)applyPatternsGreedily(getOperation(), std::move(patterns));
  }
};
} // namespace tutorial
} // namespace mlir