@@ -189,7 +189,7 @@ bool isPrimitiveNthRootOfUnity(const APInt& root, const APInt& n,
189189static LogicalResult verifyNTTOp (Operation* op, PolynomialType input,
190190 PolynomialType output,
191191 std::optional<PrimitiveRootAttr> root,
192- bool expectedInputForm) {
192+ Form expectedInputForm) {
193193 RingAttr inputRing = input.getRing ();
194194 RingAttr outputRing = output.getRing ();
195195 if (outputRing != inputRing) {
@@ -198,16 +198,16 @@ static LogicalResult verifyNTTOp(Operation* op, PolynomialType input,
198198 << " is not equivalent to the output ring " << outputRing;
199199 }
200200
201- FormAttr inputForm = input.getForm ();
202- FormAttr outputForm = output.getForm ();
203- if (inputForm. getIsCoeffForm () != expectedInputForm) {
201+ Form inputForm = input.getForm ();
202+ Form outputForm = output.getForm ();
203+ if (inputForm != expectedInputForm) {
204204 return op->emitOpError ()
205205 << " expected input with isCoeffForm=" << expectedInputForm;
206206 }
207- if (inputForm. getIsCoeffForm () == outputForm. getIsCoeffForm () ) {
207+ if (inputForm == outputForm) {
208208 return op->emitOpError () << " input and output form must be different, but "
209209 " both have isCoeffForm="
210- << inputForm. getIsCoeffForm () ;
210+ << inputForm;
211211 }
212212
213213 if (root.has_value ()) {
@@ -287,12 +287,12 @@ static LogicalResult verifyNTTOp(Operation* op, PolynomialType input,
287287
288288LogicalResult NTTOp::verify () {
289289 return verifyNTTOp (this ->getOperation (), getInput ().getType (),
290- getOutput ().getType (), getRoot (), true );
290+ getOutput ().getType (), getRoot (), Form::COEFF );
291291}
292292
293293LogicalResult INTTOp::verify () {
294294 return verifyNTTOp (this ->getOperation (), getInput ().getType (),
295- getOutput ().getType (), getRoot (), false );
295+ getOutput ().getType (), getRoot (), Form::EVAL );
296296}
297297
298298LogicalResult MulScalarOp::verify () {
@@ -443,7 +443,8 @@ LogicalResult NTTOp::inferReturnTypes(MLIRContext* ctx, std::optional<Location>,
443443 if (!inputTy) {
444444 return failure ();
445445 }
446- PolynomialType outputTy = PolynomialType::get (ctx, inputTy.getRing (), false );
446+ PolynomialType outputTy =
447+ PolynomialType::get (ctx, inputTy.getRing (), Form::EVAL);
447448 results.push_back (outputTy);
448449 return success ();
449450}
@@ -457,7 +458,8 @@ LogicalResult INTTOp::inferReturnTypes(
457458 if (!inputTy) {
458459 return failure ();
459460 }
460- PolynomialType outputTy = PolynomialType::get (ctx, inputTy.getRing (), true );
461+ PolynomialType outputTy =
462+ PolynomialType::get (ctx, inputTy.getRing (), Form::COEFF);
461463 results.push_back (outputTy);
462464 return success ();
463465}
0 commit comments