Skip to content

Commit 9f8eb9b

Browse files
j2kuncopybara-github
authored andcommitted
use form traits in PolyMulToNTT where possible
PiperOrigin-RevId: 895546447
1 parent 56ab538 commit 9f8eb9b

2 files changed

Lines changed: 70 additions & 39 deletions

File tree

lib/Dialect/Polynomial/IR/PolynomialOps.td

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,7 +199,7 @@ def Polynomial_MonicMonomialMulOp: Polynomial_Op<"monic_monomial_mul", [AllTypes
199199
let results = (outs PolynomialLike:$output);
200200
}
201201

202-
def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
202+
def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure, SameOperandsAndResultForm]> {
203203
let summary = "Creates a polynomial from integer coefficients or evaluations stored in a tensor.";
204204
let description = [{
205205
`polynomial.from_tensor` creates a polynomial value from a tensor of
@@ -236,7 +236,7 @@ def Polynomial_FromTensorOp : Polynomial_Op<"from_tensor", [Pure]> {
236236
let hasVerifier = 1;
237237
}
238238

239-
def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure]> {
239+
def Polynomial_ToTensorOp : Polynomial_Op<"to_tensor", [Pure, SameOperandsAndResultForm]> {
240240
let summary = "Creates a tensor containing the coefficients or evaluations of a polynomial.";
241241
let description = [{
242242
`polynomial.to_tensor` creates a dense tensor value containing the
@@ -291,7 +291,6 @@ def Polynomial_ModSwitchOp : Polynomial_Op<"mod_switch", [Pure, SameOperandsAndR
291291
let hasVerifier = 1;
292292
}
293293

294-
295294
def Polynomial_AnyTypedPolynomialAttr : AnyAttrOf<[
296295
Polynomial_TypedFloatPolynomialAttr,
297296
Polynomial_TypedIntPolynomialAttr,
@@ -441,7 +440,7 @@ def Polynomial_YieldOp : Polynomial_Op<"yield", [Terminator, HasParent<"ApplyCoe
441440
}
442441

443442
def Polynomial_ApplyCoefficientwiseOp : Polynomial_Op<"apply_coefficientwise", [
444-
Pure, SingleBlock]> {
443+
Pure, SingleBlock, FixedFormCoeff]> {
445444
let summary = "Apply a region to each coefficient of a polynomial.";
446445
let description = [{
447446
`polynomial.apply_coefficientwise` takes a polynomial and applies a series

lib/Dialect/Polynomial/Transforms/PolyMulToNTT.cpp

Lines changed: 67 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h"
44
#include "lib/Dialect/Polynomial/IR/PolynomialOps.h"
5+
#include "lib/Dialect/Polynomial/IR/PolynomialTraits.h"
56
#include "lib/Dialect/Polynomial/IR/PolynomialTypes.h"
67
#include "lib/Dialect/Polynomial/Transforms/NTTSolver.h"
78
#include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project
@@ -57,21 +58,30 @@ enum class OpFormClass {
5758
};
5859

5960
OpFormClass opFormClass(Operation* op) {
61+
// This special case must come first because it overrides the more general
62+
// SameOperandsAndResultForm trait for MulOp.
63+
if (op->hasTrait<FixedFormEval>() || isa<MulOp>(op)) {
64+
return OpFormClass::EVAL;
65+
}
66+
6067
if (isa<func::ReturnOp>(op)) {
6168
return OpFormClass::RETURN;
62-
} else if (isa<ToTensorOp, LeadingTermOp, EvalOp, ConvertBasisOp,
63-
MonicMonomialMulOp, FromTensorOp, ApplyCoefficientwiseOp>(
64-
op)) {
69+
}
70+
71+
if (op->hasTrait<FixedFormCoeff>() || isa<ToTensorOp, FromTensorOp>(op)) {
6572
return OpFormClass::COEFF;
66-
} else if (isa<MulOp>(op)) {
67-
return OpFormClass::EVAL;
68-
} else if (isa<AddOp, SubOp, MulScalarOp, ModSwitchOp, ExtractSliceOp,
69-
tensor::ExtractSliceOp, tensor::ExtractOp,
70-
tensor::FromElementsOp>(op)) {
73+
}
74+
75+
if (op->hasTrait<SameOperandsAndResultForm>() ||
76+
isa<tensor::ExtractSliceOp, tensor::ExtractOp, tensor::FromElementsOp>(
77+
op)) {
7178
return OpFormClass::EITHER;
72-
} else if (isa<MonomialOp, ConstantOp>(op)) {
79+
}
80+
81+
if (isa<MonomialOp, ConstantOp>(op)) {
7382
return OpFormClass::CONST;
7483
}
84+
7585
return OpFormClass::UNKNOWN;
7686
}
7787

@@ -244,34 +254,48 @@ void PolyMulToNTT::runOnOperation() {
244254
// Eval poly inputs and outputs; this is really a mirror of the previous
245255
// case
246256
else if (opClass == OpFormClass::EVAL) {
247-
Value y = polyResults[0];
248-
// Since this op outputs eval form, the use of coeff form implies the
249-
// use of eval form
250-
solver.implyForm(y, Form::COEFF, Form::EVAL);
251-
// There's a conversion cost if y_c is needed
252-
solver.addConversionCostForForm(y, Form::COEFF);
253-
for (Value x : polyOperands) {
254-
// Use of output in eval form implies use of input in eval form
255-
solver.implyUse(y, x, Form::EVAL);
257+
if (polyResults.empty()) {
258+
for (Value v : polyOperands) {
259+
solver.forceDemandFixedForm(v, Form::EVAL);
260+
}
261+
} else {
262+
Value y = polyResults[0];
263+
// Since this op outputs eval form, the use of coeff form implies the
264+
// use of eval form
265+
solver.implyForm(y, Form::COEFF, Form::EVAL);
266+
// There's a conversion cost if y_c is needed
267+
solver.addConversionCostForForm(y, Form::COEFF);
268+
for (Value x : polyOperands) {
269+
// Use of output in eval form implies use of input in eval form
270+
solver.implyUse(y, x, Form::EVAL);
271+
}
256272
}
257273
}
258274
// Ops that work in either form, as long as inputs and outputs are all
259275
// "uni-form"
260276
else if (opClass == OpFormClass::EITHER) {
261-
Value y = polyResults[0];
262-
// Since the value output by this op can be in either form, it gets a
263-
// 'mode' variable. In short, if y_c is needed and y_e is not, we run the
264-
// op in coeff mode, and vice versa.
265-
solver.addOpMode(y);
266-
for (Value x : polyOperands) {
267-
// if y_mode = 0 and output (in either form) is needed, the inputs in
268-
// coeff form are required if y_mode = 1 and output (in either form) is
269-
// needed, the inputs in eval form are required
270-
solver.implyMode(y, x);
277+
if (polyResults.empty()) {
278+
// If an op has no poly results, we don't have a mode variable to attach
279+
// to it, so we just allow each operand to be in either form.
280+
for (Value v : polyOperands) {
281+
solver.forceDemandEitherForm(v);
282+
}
283+
} else {
284+
Value y = polyResults[0];
285+
// Since the value output by this op can be in either form, it gets a
286+
// 'mode' variable. In short, if y_c is needed and y_e is not, we run
287+
// the op in coeff mode, and vice versa.
288+
solver.addOpMode(y);
289+
for (Value x : polyOperands) {
290+
// if y_mode = 0 and output (in either form) is needed, the inputs in
291+
// coeff form are required if y_mode = 1 and output (in either form)
292+
// is needed, the inputs in eval form are required
293+
solver.implyMode(y, x);
294+
}
295+
// The only time there's a conversion cost is if both forms are needed.
296+
// If only one form is needed, the op runs in that mode.
297+
solver.addConversionCostIfBothForms(y);
271298
}
272-
// The only time there's a conversion cost is if both forms are needed. If
273-
// only one form is needed, the op runs in that mode.
274-
solver.addConversionCostIfBothForms(y);
275299
}
276300
// Ops that produce polynomials in any form. We can pre-compute these
277301
// constants in either (or both!) form(s)
@@ -615,10 +639,18 @@ void PolyMulToNTT::runOnOperation() {
615639
// Ops that work in either form, as long as inputs and outputs are all
616640
// "uni-form"
617641
else if (opClass == OpFormClass::EITHER) {
618-
Value v = polyResults[0];
619-
Form form = soln.getMode(v);
620-
for (OpOperand* arg : polyOperands) {
621-
arg->set(formToValue(arg->get(), form));
642+
if (polyResults.empty()) {
643+
for (OpOperand* arg : polyOperands) {
644+
Value v = arg->get();
645+
Form form = soln.needsForm(v, Form::COEFF) ? Form::COEFF : Form::EVAL;
646+
arg->set(formToValue(v, form));
647+
}
648+
} else {
649+
Value v = polyResults[0];
650+
Form form = soln.getMode(v);
651+
for (OpOperand* arg : polyOperands) {
652+
arg->set(formToValue(arg->get(), form));
653+
}
622654
}
623655
} else {
624656
op->emitOpError(

0 commit comments

Comments
 (0)