|
2 | 2 |
|
3 | 3 | #include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h" |
4 | 4 | #include "lib/Dialect/Polynomial/IR/PolynomialOps.h" |
| 5 | +#include "lib/Dialect/Polynomial/IR/PolynomialTraits.h" |
5 | 6 | #include "lib/Dialect/Polynomial/IR/PolynomialTypes.h" |
6 | 7 | #include "lib/Dialect/Polynomial/Transforms/NTTSolver.h" |
7 | 8 | #include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project |
@@ -57,21 +58,30 @@ enum class OpFormClass { |
57 | 58 | }; |
58 | 59 |
|
59 | 60 | OpFormClass opFormClass(Operation* op) { |
| 61 | + // This special case must come first because it overrides the more general |
| 62 | + // SameOperandsAndResultForm trait for MulOp. |
| 63 | + if (op->hasTrait<FixedFormEval>() || isa<MulOp>(op)) { |
| 64 | + return OpFormClass::EVAL; |
| 65 | + } |
| 66 | + |
60 | 67 | if (isa<func::ReturnOp>(op)) { |
61 | 68 | return OpFormClass::RETURN; |
62 | | - } else if (isa<ToTensorOp, LeadingTermOp, EvalOp, ConvertBasisOp, |
63 | | - MonicMonomialMulOp, FromTensorOp, ApplyCoefficientwiseOp>( |
64 | | - op)) { |
| 69 | + } |
| 70 | + |
| 71 | + if (op->hasTrait<FixedFormCoeff>()) { |
65 | 72 | return OpFormClass::COEFF; |
66 | | - } else if (isa<MulOp>(op)) { |
67 | | - return OpFormClass::EVAL; |
68 | | - } else if (isa<AddOp, SubOp, MulScalarOp, ModSwitchOp, ExtractSliceOp, |
69 | | - tensor::ExtractSliceOp, tensor::ExtractOp, |
70 | | - tensor::FromElementsOp>(op)) { |
| 73 | + } |
| 74 | + |
| 75 | + if (op->hasTrait<SameOperandsAndResultForm>() || |
| 76 | + isa<ToTensorOp, FromTensorOp, tensor::ExtractSliceOp, tensor::ExtractOp, |
| 77 | + tensor::FromElementsOp>(op)) { |
71 | 78 | return OpFormClass::EITHER; |
72 | | - } else if (isa<MonomialOp, ConstantOp>(op)) { |
| 79 | + } |
| 80 | + |
| 81 | + if (isa<MonomialOp, ConstantOp>(op)) { |
73 | 82 | return OpFormClass::CONST; |
74 | 83 | } |
| 84 | + |
75 | 85 | return OpFormClass::UNKNOWN; |
76 | 86 | } |
77 | 87 |
|
|
0 commit comments