|
2 | 2 |
|
3 | 3 | #include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h" |
4 | 4 | #include "lib/Dialect/Polynomial/IR/PolynomialOps.h" |
| 5 | +#include "lib/Dialect/Polynomial/IR/PolynomialTraits.h" |
5 | 6 | #include "lib/Dialect/Polynomial/IR/PolynomialTypes.h" |
6 | 7 | #include "lib/Dialect/Polynomial/Transforms/NTTSolver.h" |
7 | 8 | #include "llvm/include/llvm/ADT/DenseMap.h" // from @llvm-project |
@@ -57,21 +58,30 @@ enum class OpFormClass { |
57 | 58 | }; |
58 | 59 |
|
59 | 60 | OpFormClass opFormClass(Operation* op) { |
| 61 | + // This special case must come first because it overrides the more general |
| 62 | + // SameOperandsAndResultForm trait for MulOp. |
| 63 | + if (op->hasTrait<FixedFormEval>() || isa<MulOp>(op)) { |
| 64 | + return OpFormClass::EVAL; |
| 65 | + } |
| 66 | + |
60 | 67 | if (isa<func::ReturnOp>(op)) { |
61 | 68 | return OpFormClass::RETURN; |
62 | | - } else if (isa<ToTensorOp, LeadingTermOp, EvalOp, ConvertBasisOp, |
63 | | - MonicMonomialMulOp, FromTensorOp, ApplyCoefficientwiseOp>( |
64 | | - op)) { |
| 69 | + } |
| 70 | + |
| 71 | + if (op->hasTrait<FixedFormCoeff>() || isa<ToTensorOp, FromTensorOp>(op)) { |
65 | 72 | return OpFormClass::COEFF; |
66 | | - } else if (isa<MulOp>(op)) { |
67 | | - return OpFormClass::EVAL; |
68 | | - } else if (isa<AddOp, SubOp, MulScalarOp, ModSwitchOp, ExtractSliceOp, |
69 | | - tensor::ExtractSliceOp, tensor::ExtractOp, |
70 | | - tensor::FromElementsOp>(op)) { |
| 73 | + } |
| 74 | + |
| 75 | + if (op->hasTrait<SameOperandsAndResultForm>() || |
| 76 | + isa<tensor::ExtractSliceOp, tensor::ExtractOp, tensor::FromElementsOp>( |
| 77 | + op)) { |
71 | 78 | return OpFormClass::EITHER; |
72 | | - } else if (isa<MonomialOp, ConstantOp>(op)) { |
| 79 | + } |
| 80 | + |
| 81 | + if (isa<MonomialOp, ConstantOp>(op)) { |
73 | 82 | return OpFormClass::CONST; |
74 | 83 | } |
| 84 | + |
75 | 85 | return OpFormClass::UNKNOWN; |
76 | 86 | } |
77 | 87 |
|
@@ -244,34 +254,48 @@ void PolyMulToNTT::runOnOperation() { |
244 | 254 | // Eval poly inputs and outputs; this is really a mirror of the previous |
245 | 255 | // case |
246 | 256 | else if (opClass == OpFormClass::EVAL) { |
247 | | - Value y = polyResults[0]; |
248 | | - // Since this op outputs eval form, the use of coeff form implies the |
249 | | - // use of eval form |
250 | | - solver.implyForm(y, Form::COEFF, Form::EVAL); |
251 | | - // There's a conversion cost if y_c is needed |
252 | | - solver.addConversionCostForForm(y, Form::COEFF); |
253 | | - for (Value x : polyOperands) { |
254 | | - // Use of output in eval form implies use of input in eval form |
255 | | - solver.implyUse(y, x, Form::EVAL); |
| 257 | + if (polyResults.empty()) { |
| 258 | + for (Value v : polyOperands) { |
| 259 | + solver.forceDemandFixedForm(v, Form::EVAL); |
| 260 | + } |
| 261 | + } else { |
| 262 | + Value y = polyResults[0]; |
| 263 | + // Since this op outputs eval form, the use of coeff form implies the |
| 264 | + // use of eval form |
| 265 | + solver.implyForm(y, Form::COEFF, Form::EVAL); |
| 266 | + // There's a conversion cost if y_c is needed |
| 267 | + solver.addConversionCostForForm(y, Form::COEFF); |
| 268 | + for (Value x : polyOperands) { |
| 269 | + // Use of output in eval form implies use of input in eval form |
| 270 | + solver.implyUse(y, x, Form::EVAL); |
| 271 | + } |
256 | 272 | } |
257 | 273 | } |
258 | 274 | // Ops that work in either form, as long as inputs and outputs are all |
259 | 275 | // "uni-form" |
260 | 276 | else if (opClass == OpFormClass::EITHER) { |
261 | | - Value y = polyResults[0]; |
262 | | - // Since the value output by this op can be in either form, it gets a |
263 | | - // 'mode' variable. In short, if y_c is needed and y_e is not, we run the |
264 | | - // op in coeff mode, and vice versa. |
265 | | - solver.addOpMode(y); |
266 | | - for (Value x : polyOperands) { |
267 | | - // if y_mode = 0 and output (in either form) is needed, the inputs in |
268 | | - // coeff form are required if y_mode = 1 and output (in either form) is |
269 | | - // needed, the inputs in eval form are required |
270 | | - solver.implyMode(y, x); |
| 277 | + if (polyResults.empty()) { |
| 278 | + // If an op has no poly results, we don't have a mode variable to attach |
| 279 | + // to it, so we just allow each operand to be in either form. |
| 280 | + for (Value v : polyOperands) { |
| 281 | + solver.forceDemandEitherForm(v); |
| 282 | + } |
| 283 | + } else { |
| 284 | + Value y = polyResults[0]; |
| 285 | + // Since the value output by this op can be in either form, it gets a |
| 286 | + // 'mode' variable. In short, if y_c is needed and y_e is not, we run |
| 287 | + // the op in coeff mode, and vice versa. |
| 288 | + solver.addOpMode(y); |
| 289 | + for (Value x : polyOperands) { |
| 290 | + // if y_mode = 0 and output (in either form) is needed, the inputs in |
| 291 | + // coeff form are required if y_mode = 1 and output (in either form) |
| 292 | + // is needed, the inputs in eval form are required |
| 293 | + solver.implyMode(y, x); |
| 294 | + } |
| 295 | + // The only time there's a conversion cost is if both forms are needed. |
| 296 | + // If only one form is needed, the op runs in that mode. |
| 297 | + solver.addConversionCostIfBothForms(y); |
271 | 298 | } |
272 | | - // The only time there's a conversion cost is if both forms are needed. If |
273 | | - // only one form is needed, the op runs in that mode. |
274 | | - solver.addConversionCostIfBothForms(y); |
275 | 299 | } |
276 | 300 | // Ops that produce polynomials in any form. We can pre-compute these |
277 | 301 | // constants in either (or both!) form(s) |
@@ -615,10 +639,18 @@ void PolyMulToNTT::runOnOperation() { |
615 | 639 | // Ops that work in either form, as long as inputs and outputs are all |
616 | 640 | // "uni-form" |
617 | 641 | else if (opClass == OpFormClass::EITHER) { |
618 | | - Value v = polyResults[0]; |
619 | | - Form form = soln.getMode(v); |
620 | | - for (OpOperand* arg : polyOperands) { |
621 | | - arg->set(formToValue(arg->get(), form)); |
| 642 | + if (polyResults.empty()) { |
| 643 | + for (OpOperand* arg : polyOperands) { |
| 644 | + Value v = arg->get(); |
| 645 | + Form form = soln.needsForm(v, Form::COEFF) ? Form::COEFF : Form::EVAL; |
| 646 | + arg->set(formToValue(v, form)); |
| 647 | + } |
| 648 | + } else { |
| 649 | + Value v = polyResults[0]; |
| 650 | + Form form = soln.getMode(v); |
| 651 | + for (OpOperand* arg : polyOperands) { |
| 652 | + arg->set(formToValue(arg->get(), form)); |
| 653 | + } |
622 | 654 | } |
623 | 655 | } else { |
624 | 656 | op->emitOpError( |
|
0 commit comments