@@ -50,37 +50,23 @@ static void debugLog(StringRef opName, ArrayRef<const LevelLattice*> operands,
5050 });
5151};
5252
53- LevelState transferForward (mgmt::ModReduceOp op,
54- ArrayRef<const LevelLattice*> operands) {
55- LevelState result = std::visit (
56- Overloaded{
57- [](MaxLevel) -> LevelState { return LevelState (Invalid{}); },
58- [](Uninit) -> LevelState { return LevelState (Invalid{}); },
59- [](Invalid) -> LevelState { return LevelState (Invalid{}); },
60- [](int val) -> LevelState { return LevelState (val + 1 ); },
61- },
62- operands[0 ]->getValue ().get ());
63- LLVM_DEBUG (debugLog (" mod_reduce" , operands, result));
64- return result;
65- }
66-
67- LevelState transferForward (mgmt::LevelReduceOp op,
53+ LevelState transferForward (ReducesLevelOpInterface op,
6854 ArrayRef<const LevelLattice*> operands) {
6955 LevelState result = std::visit (
7056 Overloaded{
7157 [](MaxLevel) -> LevelState { return LevelState (Invalid{}); },
7258 [](Uninit) -> LevelState { return LevelState (Invalid{}); },
7359 [](Invalid) -> LevelState { return LevelState (Invalid{}); },
7460 [&](int val) -> LevelState {
75- return LevelState (val + ( int ) op.getLevelToDrop ());
61+ return LevelState (val + op.getLevelsToDrop ());
7662 },
7763 },
7864 operands[0 ]->getValue ().get ());
79- LLVM_DEBUG (debugLog (" level_reduce " , operands, result));
65+ LLVM_DEBUG (debugLog (" ReduceLevelOpInterface " , operands, result));
8066 return result;
8167}
8268
83- LevelState transferForward (mgmt::LevelReduceMinOp op,
69+ LevelState transferForward (ReducesAllLevelsOpInterface op,
8470 ArrayRef<const LevelLattice*> operands) {
8571 LevelState result = std::visit (
8672 Overloaded{
@@ -92,11 +78,11 @@ LevelState transferForward(mgmt::LevelReduceMinOp op,
9278 [](int val) -> LevelState { return LevelState (MaxLevel{}); },
9379 },
9480 operands[0 ]->getValue ().get ());
95- LLVM_DEBUG (debugLog (" level_reduce_min " , operands, result));
81+ LLVM_DEBUG (debugLog (" ReduceAllLevelsOpInterface " , operands, result));
9682 return result;
9783}
9884
99- LevelState transferForward (mgmt::BootstrapOp op,
85+ LevelState transferForward (ResetsLevelOpInterface op,
10086 ArrayRef<const LevelLattice*> operands) {
10187 LevelState result = std::visit (
10288 Overloaded{
@@ -106,15 +92,18 @@ LevelState transferForward(mgmt::BootstrapOp op,
10692 [](int val) -> LevelState { return LevelState (0 ); },
10793 },
10894 operands[0 ]->getValue ().get ());
109- LLVM_DEBUG (debugLog (" bootstrap " , operands, result));
95+ LLVM_DEBUG (debugLog (" ResetsLevelOpInterface " , operands, result));
11096 return result;
11197}
11298
11399LevelState deriveResultLevel (Operation* op,
114100 ArrayRef<const LevelLattice*> operands) {
115101 return llvm::TypeSwitch<Operation&, LevelState>(*op)
116- .Case <mgmt::ModReduceOp, mgmt::LevelReduceOp, mgmt::BootstrapOp,
117- mgmt::LevelReduceMinOp>(
102+ .Case <ResetsLevelOpInterface>(
103+ [&](auto op) -> LevelState { return transferForward (op, operands); })
104+ .Case <ReducesAllLevelsOpInterface>(
105+ [&](auto op) -> LevelState { return transferForward (op, operands); })
106+ .Case <ReducesLevelOpInterface>(
118107 [&](auto op) -> LevelState { return transferForward (op, operands); })
119108 .Default ([&](auto & op) -> LevelState {
120109 LevelState result;
0 commit comments