Skip to content

Commit d4a089d

Browse files
Merge pull request #2664 from crockeea:form_attr
PiperOrigin-RevId: 871860623
2 parents c304285 + cf04f2f commit d4a089d

14 files changed

Lines changed: 53 additions & 45 deletions

File tree

lib/Dialect/Polynomial/IR/BUILD

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ cc_library(
1515
":attributes_inc_gen",
1616
":canonicalization_inc_gen",
1717
":dialect_inc_gen",
18+
":enums_inc_gen",
1819
":ops_inc_gen",
1920
":types_inc_gen",
2021
"@heir//lib/Dialect/ModArith/IR:Dialect",
@@ -36,6 +37,7 @@ td_library(
3637
"PolynomialAttributes.td",
3738
"PolynomialCanonicalization.td",
3839
"PolynomialDialect.td",
40+
"PolynomialEnums.td",
3941
"PolynomialOps.td",
4042
"PolynomialTypes.td",
4143
],
@@ -72,6 +74,16 @@ add_heir_dialect_library(
7274
],
7375
)
7476

77+
add_heir_dialect_library(
78+
name = "enums_inc_gen",
79+
dialect = "Polynomial",
80+
kind = "enum",
81+
td_file = "PolynomialEnums.td",
82+
deps = [
83+
":td_files",
84+
],
85+
)
86+
7587
add_heir_dialect_library(
7688
name = "types_inc_gen",
7789
dialect = "Polynomial",

lib/Dialect/Polynomial/IR/PolynomialAttributes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,5 +8,6 @@
88

99
#define GET_ATTRDEF_CLASSES
1010
#include "lib/Dialect/Polynomial/IR/PolynomialAttributes.h.inc"
11+
#include "lib/Dialect/Polynomial/IR/PolynomialEnums.h.inc"
1112

1213
#endif // LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALATTRIBUTES_H_

lib/Dialect/Polynomial/IR/PolynomialAttributes.td

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ include "lib/Dialect/Polynomial/IR/PolynomialDialect.td"
55
include "mlir/IR/BuiltinAttributeInterfaces.td"
66
include "mlir/IR/OpBase.td"
77
include "mlir/IR/OpAsmInterface.td"
8+
include "mlir/IR/EnumAttr.td"
89

910
class Polynomial_Attr<string name, string attrMnemonic, list<Trait> traits = []>
1011
: AttrDef<Polynomial_Dialect, name, traits> {
@@ -232,22 +233,6 @@ def Polynomial_RingAttr : Polynomial_Attr<"Ring", "ring", [OpAsmAttrInterface]>
232233
}];
233234
}
234235

235-
def Polynomial_FormAttr : Polynomial_Attr<"Form", "form"> {
236-
let summary = "an attribute describing the polynomial representation form";
237-
let description = [{
238-
Indicates whether a polynomial value is represented in coefficient form or
239-
evaluation (NTT) form.
240-
241-
Example:
242-
243-
```mlir
244-
#form = #polynomial.form<isCoeffForm = true>
245-
```
246-
}];
247-
let parameters = (ins "bool":$isCoeffForm);
248-
let assemblyFormat = "`<` `isCoeffForm` `=` $isCoeffForm `>`";
249-
}
250-
251236
def Polynomial_PrimitiveRootAttr: Polynomial_Attr<"PrimitiveRoot", "primitive_root"> {
252237
let summary = "an attribute containing a typed root value and its degree as a root of unity";
253238
let description = [{

lib/Dialect/Polynomial/IR/PolynomialDialect.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ using namespace mlir::heir::polynomial;
2828
#define GET_TYPEDEF_CLASSES
2929
#include "lib/Dialect/Polynomial/IR/PolynomialTypes.cpp.inc"
3030
#define GET_OP_CLASSES
31+
#include "lib/Dialect/Polynomial/IR/PolynomialEnums.cpp.inc"
3132
#include "lib/Dialect/Polynomial/IR/PolynomialOps.cpp.inc"
3233

3334
void PolynomialDialect::initialize() {
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
#ifndef LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALENUMS_TD_
2+
#define LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALENUMS_TD_
3+
4+
include "lib/Dialect/Polynomial/IR/PolynomialDialect.td"
5+
include "mlir/IR/EnumAttr.td"
6+
7+
def Polynomial_Form_Coeff : I32EnumAttrCase<"COEFF", 0, "coeff">;
8+
def Polynomial_Form_Eval : I32EnumAttrCase<"EVAL", 1, "eval">;
9+
10+
def Polynomial_FormEnum : I32EnumAttr<"Form", "Polynomial representation form",
11+
[Polynomial_Form_Coeff, Polynomial_Form_Eval]> {
12+
let cppNamespace = "::mlir::heir::polynomial";
13+
}
14+
15+
16+
#endif // LIB_DIALECT_POLYNOMIAL_IR_POLYNOMIALENUMS_TD_

lib/Dialect/Polynomial/IR/PolynomialOps.cpp

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -189,7 +189,7 @@ bool isPrimitiveNthRootOfUnity(const APInt& root, const APInt& n,
189189
static LogicalResult verifyNTTOp(Operation* op, PolynomialType input,
190190
PolynomialType output,
191191
std::optional<PrimitiveRootAttr> root,
192-
bool expectedInputForm) {
192+
Form expectedInputForm) {
193193
RingAttr inputRing = input.getRing();
194194
RingAttr outputRing = output.getRing();
195195
if (outputRing != inputRing) {
@@ -198,16 +198,16 @@ static LogicalResult verifyNTTOp(Operation* op, PolynomialType input,
198198
<< " is not equivalent to the output ring " << outputRing;
199199
}
200200

201-
FormAttr inputForm = input.getForm();
202-
FormAttr outputForm = output.getForm();
203-
if (inputForm.getIsCoeffForm() != expectedInputForm) {
201+
Form inputForm = input.getForm();
202+
Form outputForm = output.getForm();
203+
if (inputForm != expectedInputForm) {
204204
return op->emitOpError()
205205
<< "expected input with isCoeffForm=" << expectedInputForm;
206206
}
207-
if (inputForm.getIsCoeffForm() == outputForm.getIsCoeffForm()) {
207+
if (inputForm == outputForm) {
208208
return op->emitOpError() << "input and output form must be different, but "
209209
"both have isCoeffForm="
210-
<< inputForm.getIsCoeffForm();
210+
<< inputForm;
211211
}
212212

213213
if (root.has_value()) {
@@ -287,12 +287,12 @@ static LogicalResult verifyNTTOp(Operation* op, PolynomialType input,
287287

288288
LogicalResult NTTOp::verify() {
289289
return verifyNTTOp(this->getOperation(), getInput().getType(),
290-
getOutput().getType(), getRoot(), true);
290+
getOutput().getType(), getRoot(), Form::COEFF);
291291
}
292292

293293
LogicalResult INTTOp::verify() {
294294
return verifyNTTOp(this->getOperation(), getInput().getType(),
295-
getOutput().getType(), getRoot(), false);
295+
getOutput().getType(), getRoot(), Form::EVAL);
296296
}
297297

298298
LogicalResult MulScalarOp::verify() {
@@ -443,7 +443,8 @@ LogicalResult NTTOp::inferReturnTypes(MLIRContext* ctx, std::optional<Location>,
443443
if (!inputTy) {
444444
return failure();
445445
}
446-
PolynomialType outputTy = PolynomialType::get(ctx, inputTy.getRing(), false);
446+
PolynomialType outputTy =
447+
PolynomialType::get(ctx, inputTy.getRing(), Form::EVAL);
447448
results.push_back(outputTy);
448449
return success();
449450
}
@@ -457,7 +458,8 @@ LogicalResult INTTOp::inferReturnTypes(
457458
if (!inputTy) {
458459
return failure();
459460
}
460-
PolynomialType outputTy = PolynomialType::get(ctx, inputTy.getRing(), true);
461+
PolynomialType outputTy =
462+
PolynomialType::get(ctx, inputTy.getRing(), Form::COEFF);
461463
results.push_back(outputTy);
462464
return success();
463465
}

lib/Dialect/Polynomial/IR/PolynomialTypes.td

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial", [OpA
1818
let parameters = (ins
1919
Polynomial_RingAttr:$ring,
2020
DefaultValuedParameter<
21-
"::mlir::heir::polynomial::FormAttr",
22-
"FormAttr::get($_ctxt, true)">:$form
21+
"::mlir::heir::polynomial::Form",
22+
"::mlir::heir::polynomial::Form::COEFF">:$form
2323
);
2424
let assemblyFormat = "`<` struct(params) `>`";
2525

@@ -28,16 +28,7 @@ def Polynomial_PolynomialType : Polynomial_Type<"Polynomial", "polynomial", [OpA
2828
static PolynomialType get(
2929
::mlir::MLIRContext *context,
3030
::mlir::heir::polynomial::RingAttr ring) {
31-
return PolynomialType::get(context, ring, FormAttr::get(context, true));
32-
}
33-
34-
// Convenience builder from a boolean form flag.
35-
static PolynomialType get(
36-
::mlir::MLIRContext *context,
37-
::mlir::heir::polynomial::RingAttr ring,
38-
bool isCoeffForm) {
39-
return PolynomialType::get(context, ring,
40-
FormAttr::get(context, isCoeffForm));
31+
return PolynomialType::get(context, ring, ::mlir::heir::polynomial::Form::COEFF);
4132
}
4233

4334
// OpAsmTypeInterface methods.

tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/lower_intt.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#root_val = #mod_arith.value<1925:!coeff_ty>
1010
#root = #polynomial.primitive_root<value=#root_val, degree=8:i32>
1111
!poly_ty = !polynomial.polynomial<ring=#ring>
12-
!ntt_poly_ty = !polynomial.polynomial<ring=#ring, form=<isCoeffForm=false>>
12+
!ntt_poly_ty = !polynomial.polynomial<ring=#ring, form=eval>
1313

1414
// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0) -> (d0)>
1515

tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/lower_ntt.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
#root_val = #mod_arith.value<1925:!coeff_ty>
1010
#root = #polynomial.primitive_root<value=#root_val, degree=8:i32>
1111
!poly_ty = !polynomial.polynomial<ring=#ring>
12-
!ntt_poly_ty = !polynomial.polynomial<ring=#ring, form=<isCoeffForm=false>>
12+
!ntt_poly_ty = !polynomial.polynomial<ring=#ring, form=eval>
1313

1414
// CHECK-DAG: #[[ID_MAP:.*]] = affine_map<(d0) -> (d0)>
1515

tests/Dialect/Polynomial/Conversions/heir_polynomial_to_llvm/runner/lower_intt.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
#root_val = #mod_arith.value<1925:!coeff_ty>
88
#root = #polynomial.primitive_root<value=#root_val, degree=8:i32>
99
!poly_ty = !polynomial.polynomial<ring=#ring>
10-
!ntt_poly_ty = !polynomial.polynomial<ring=#ring, form=<isCoeffForm = false>>
10+
!ntt_poly_ty = !polynomial.polynomial<ring=#ring, form=eval>
1111

1212
func.func public @test_intt() -> !poly_ty {
1313
%coeffs = arith.constant dense<[1467,2807,3471,7621]> : tensor<4xi32>

0 commit comments

Comments
 (0)