Skip to content

Commit 567d890

Browse files
j2kuncopybara-github
authored andcommitted
use form traits in PolyMulToNTT where possible
PiperOrigin-RevId: 895546447
1 parent cb5bda0 commit 567d890

2 files changed

Lines changed: 20 additions & 10 deletions

File tree

lib/Dialect/Polynomial/IR/PolynomialOps.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,7 @@ def Polynomial_YieldOp : Polynomial_Op<"yield", [Terminator, HasParent<"ApplyCoe
437437
}
438438

439439
def Polynomial_ApplyCoefficientwiseOp : Polynomial_Op<"apply_coefficientwise", [
440-
Pure, SingleBlock]> {
440+
Pure, SingleBlock, FixedFormCoeff]> {
441441
let summary = "Apply a region to each coefficient of a polynomial.";
442442
let description = [{
443443
`polynomial.apply_coefficientwise` takes a polynomial and applies a series

lib/Dialect/Polynomial/Transforms/PolyMulToNTT.cpp

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
44
#include "lib/Dialect/Polynomial/IR/PolynomialOps.h"
5+
#include "lib/Dialect/Polynomial/IR/PolynomialTraits.h"
56
#include "lib/Dialect/Polynomial/IR/PolynomialTypes.h"
67
#include "lib/Dialect/Polynomial/Transforms/NTTSolver.h"
78
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
@@ -57,21 +58,30 @@ enum class OpFormClass {
5758
};
5859

5960
OpFormClass opFormClass(Operation* op) {
61+
// This special case must come first because it overrides the more general
62+
// SameOperandsAndResultForm trait for MulOp.
63+
if (op->hasTrait<FixedFormEval>() || isa<MulOp>(op)) {
64+
return OpFormClass::EVAL;
65+
}
66+
6067
if (isa<func::ReturnOp>(op)) {
6168
return OpFormClass::RETURN;
62-
} else if (isa<ToTensorOp, LeadingTermOp, EvalOp, ConvertBasisOp,
63-
MonicMonomialMulOp, FromTensorOp, ApplyCoefficientwiseOp>(
64-
op)) {
69+
}
70+
71+
if (op->hasTrait<FixedFormCoeff>()) {
6572
return OpFormClass::COEFF;
66-
} else if (isa<MulOp>(op)) {
67-
return OpFormClass::EVAL;
68-
} else if (isa<AddOp, SubOp, MulScalarOp, ModSwitchOp, ExtractSliceOp,
69-
tensor::ExtractSliceOp, tensor::ExtractOp,
70-
tensor::FromElementsOp>(op)) {
73+
}
74+
75+
if (op->hasTrait<SameOperandsAndResultForm>() ||
76+
isa<ToTensorOp, FromTensorOp, tensor::ExtractSliceOp, tensor::ExtractOp,
77+
tensor::FromElementsOp>(op)) {
7178
return OpFormClass::EITHER;
72-
} else if (isa<MonomialOp, ConstantOp>(op)) {
79+
}
80+
81+
if (isa<MonomialOp, ConstantOp>(op)) {
7382
return OpFormClass::CONST;
7483
}
84+
7585
return OpFormClass::UNKNOWN;
7686
}
7787

0 commit comments

Comments
 (0)