99#include " lib/Dialect/Comb/IR/CombDialect.h"
1010#include " lib/Dialect/Comb/IR/CombOps.h"
1111#include " lib/Dialect/LWE/IR/LWEAttributes.h"
12- #include " lib/Dialect/LWE/IR/LWEDialect.h"
1312#include " lib/Dialect/LWE/IR/LWEOps.h"
1413#include " lib/Dialect/LWE/IR/LWETypes.h"
1514#include " lib/Utils/ConversionUtils.h"
2726#include " mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
2827#include " mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
2928#include " mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
30- #include " mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
3129#include " mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
3230#include " mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
3331#include " mlir/include/mlir/IR/Value.h" // from @llvm-project
3432#include " mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
3533#include " mlir/include/mlir/Support/LLVM.h" // from @llvm-project
3634#include " mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
35+ #include " mlir/include/mlir/Support/WalkResult.h" // from @llvm-project
3736#include " mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
3837
3938namespace mlir ::heir::arith {
@@ -129,10 +128,6 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs,
129128 assert (inputs.size () == 1 );
130129 auto inputType = inputs[0 ].getType ();
131130
132- if (!isa<IntegerType>(getElementTypeOrSelf (inputType)))
133- llvm_unreachable (
134- " Non-integer types should never be the input to a materializeTarget." );
135-
136131 if (auto inValue = inputs.front ().getDefiningOp <mlir::arith::ConstantOp>()) {
137132 if (auto intAttr = dyn_cast<IntegerAttr>(inValue.getValueAttr ())) {
138133 return cggi::CreateTrivialOp::create (builder, loc, type, intAttr);
@@ -154,17 +149,9 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs,
154149 }
155150
156151 // Comes from function/loop argument: Trivial encrypt through LWE
157- lwe::LWECiphertextType ciphertextType;
158-
159- if (auto shapedType = dyn_cast<ShapedType>(type)) {
160- auto tensorElementSize =
161- shapedType.getElementType ().getIntOrFloatBitWidth ();
162- ciphertextType = lwe::getDefaultCGGICiphertextType (builder.getContext (),
163- tensorElementSize);
164- } else {
165- ciphertextType = lwe::getDefaultCGGICiphertextType (
166- builder.getContext (), inputType.getIntOrFloatBitWidth ());
167- }
152+ auto tensorElementSize = inputType.getIntOrFloatBitWidth ();
153+ lwe::LWECiphertextType ciphertextType = lwe::getDefaultCGGICiphertextType (
154+ builder.getContext (), tensorElementSize);
168155
169156 auto plaintextBits = ciphertextType.getPlaintextSpace ()
170157 .getRing ()
@@ -281,6 +268,9 @@ struct ConvertCmpOp : public OpConversionPattern<mlir::arith::CmpIOp> {
281268 LogicalResult matchAndRewrite (
282269 mlir::arith::CmpIOp op, OpAdaptor adaptor,
283270 ConversionPatternRewriter& rewriter) const override {
271+ if (isa<ShapedType>(op.getType ())) {
272+ return failure ();
273+ }
284274 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
285275
286276 ArithToCGGITypeConverter typeConverter (op->getContext ());
@@ -321,6 +311,9 @@ struct ConvertSubOp : public OpConversionPattern<mlir::arith::SubIOp> {
321311 LogicalResult matchAndRewrite (
322312 mlir::arith::SubIOp op, OpAdaptor adaptor,
323313 ConversionPatternRewriter& rewriter) const override {
314+ if (isa<ShapedType>(op.getType ())) {
315+ return failure ();
316+ }
324317 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
325318
326319 if (auto rhsDefOp = op.getRhs ().getDefiningOp ()) {
@@ -348,6 +341,9 @@ struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
348341 LogicalResult matchAndRewrite (
349342 mlir::arith::SelectOp op, OpAdaptor adaptor,
350343 ConversionPatternRewriter& rewriter) const override {
344+ if (isa<ShapedType>(op.getType ())) {
345+ return failure ();
346+ }
351347 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
352348
353349 auto cmuxOp = cggi::SelectOp::create (
@@ -392,6 +388,9 @@ struct ConvertShOp : public OpConversionPattern<SourceArithShOp> {
392388 LogicalResult matchAndRewrite (
393389 SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
394390 ConversionPatternRewriter& rewriter) const override {
391+ if (isa<ShapedType>(op.getType ())) {
392+ return failure ();
393+ }
395394 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
396395
397396 auto cteShiftSizeOp =
@@ -443,6 +442,9 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
443442 LogicalResult matchAndRewrite (
444443 SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
445444 ConversionPatternRewriter& rewriter) const override {
445+ if (isa<ShapedType>(op.getType ())) {
446+ return failure ();
447+ }
446448 ImplicitLocOpBuilder b (op.getLoc (), rewriter);
447449
448450 if (auto lhsDefOp = op.getLhs ().getDefiningOp ()) {
@@ -517,6 +519,31 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
517519 void runOnOperation () override {
518520 MLIRContext* context = &getContext ();
519521 auto * module = getOperation ();
522+
523+ auto walkResult = module ->walk ([&](Operation* op) {
524+ if (op->getName ().getDialectNamespace () == " tensor_ext" ) {
525+ op->emitError () << " --arith-to-cggi does not support tensor_ext "
526+ " operations. Lower them to scalars first." ;
527+ return WalkResult::interrupt ();
528+ }
529+ if (isa<mlir::arith::ArithDialect>(op->getDialect ())) {
530+ if (llvm::any_of (op->getResultTypes (),
531+ [](Type t) { return isa<ShapedType>(t); })) {
532+ if (!isa<mlir::arith::ConstantOp>(op)) {
533+ op->emitError ()
534+ << " --arith-to-cggi does not support arith operations on "
535+ " vectors or tensors. Lower them to scalars first." ;
536+ return WalkResult::interrupt ();
537+ }
538+ }
539+ }
540+ return WalkResult::advance ();
541+ });
542+
543+ if (walkResult.wasInterrupted ()) {
544+ return signalPassFailure ();
545+ }
546+
520547 ArithToCGGITypeConverter typeConverter (context);
521548
522549 RewritePatternSet patterns (context);
0 commit comments