From 1e2c0198005c53c12ae2a4718c1dae88af780d3c Mon Sep 17 00:00:00 2001 From: Jeremy Kun Date: Tue, 10 Feb 2026 18:44:23 -0800 Subject: [PATCH] Fix #2484 (extra guards for arith-to-cggi) PiperOrigin-RevId: 868417961 --- .../Conversions/ArithToCGGI/ArithToCGGI.cpp | 61 +++++++++++++------ tests/Regression/issue_2484.mlir | 38 +++++------- 2 files changed, 58 insertions(+), 41 deletions(-) diff --git a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp index 77a802eb87..0561b33383 100644 --- a/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp +++ b/lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp @@ -9,7 +9,6 @@ #include "lib/Dialect/Comb/IR/CombDialect.h" #include "lib/Dialect/Comb/IR/CombOps.h" #include "lib/Dialect/LWE/IR/LWEAttributes.h" -#include "lib/Dialect/LWE/IR/LWEDialect.h" #include "lib/Dialect/LWE/IR/LWEOps.h" #include "lib/Dialect/LWE/IR/LWETypes.h" #include "lib/Utils/ConversionUtils.h" @@ -27,13 +26,13 @@ #include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project #include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/include/mlir/IR/Value.h" // from @llvm-project #include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project #include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project #include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project #include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir::heir::arith { @@ -129,10 +128,6 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs, assert(inputs.size() == 1); auto inputType = inputs[0].getType(); - if (!isa(getElementTypeOrSelf(inputType))) - llvm_unreachable( - "Non-integer types should never be the input to a materializeTarget."); - if (auto inValue = inputs.front().getDefiningOp()) { if (auto intAttr = dyn_cast(inValue.getValueAttr())) { return cggi::CreateTrivialOp::create(builder, loc, type, intAttr); @@ -154,17 +149,9 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs, } // Comes from function/loop argument: Trivial encrypt through LWE - lwe::LWECiphertextType ciphertextType; - - if (auto shapedType = dyn_cast(type)) { - auto tensorElementSize = - shapedType.getElementType().getIntOrFloatBitWidth(); - ciphertextType = lwe::getDefaultCGGICiphertextType(builder.getContext(), - tensorElementSize); - } else { - ciphertextType = lwe::getDefaultCGGICiphertextType( - builder.getContext(), inputType.getIntOrFloatBitWidth()); - } + auto tensorElementSize = inputType.getIntOrFloatBitWidth(); + lwe::LWECiphertextType ciphertextType = lwe::getDefaultCGGICiphertextType( + builder.getContext(), tensorElementSize); auto plaintextBits = ciphertextType.getPlaintextSpace() .getRing() @@ -281,6 +268,9 @@ struct ConvertCmpOp : public OpConversionPattern { LogicalResult matchAndRewrite( mlir::arith::CmpIOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (isa(op.getType())) { + return failure(); + } ImplicitLocOpBuilder b(op.getLoc(), rewriter); ArithToCGGITypeConverter typeConverter(op->getContext()); @@ -321,6 +311,9 @@ struct ConvertSubOp : public OpConversionPattern { LogicalResult matchAndRewrite( mlir::arith::SubIOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (isa(op.getType())) { + return failure(); + } ImplicitLocOpBuilder b(op.getLoc(), rewriter); if (auto rhsDefOp = op.getRhs().getDefiningOp()) { @@ -348,6 +341,9 @@ struct ConvertSelectOp : public OpConversionPattern { LogicalResult matchAndRewrite( mlir::arith::SelectOp op, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (isa(op.getType())) { + return failure(); + } ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto cmuxOp = cggi::SelectOp::create( @@ -392,6 +388,9 @@ struct ConvertShOp : public OpConversionPattern { LogicalResult matchAndRewrite( SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (isa(op.getType())) { + return failure(); + } ImplicitLocOpBuilder b(op.getLoc(), rewriter); auto cteShiftSizeOp = @@ -443,6 +442,9 @@ struct ConvertArithBinOp : public OpConversionPattern { LogicalResult matchAndRewrite( SourceArithOp op, typename SourceArithOp::Adaptor adaptor, ConversionPatternRewriter& rewriter) const override { + if (isa(op.getType())) { + return failure(); + } ImplicitLocOpBuilder b(op.getLoc(), rewriter); if (auto lhsDefOp = op.getLhs().getDefiningOp()) { @@ -517,6 +519,31 @@ struct ArithToCGGI : public impl::ArithToCGGIBase { void runOnOperation() override { MLIRContext* context = &getContext(); auto* module = getOperation(); + + auto walkResult = module->walk([&](Operation* op) { + if (op->getName().getDialectNamespace() == "tensor_ext") { + op->emitError() << "--arith-to-cggi does not support tensor_ext " + "operations. Lower them to scalars first."; + return WalkResult::interrupt(); + } + if (isa(op->getDialect())) { + if (llvm::any_of(op->getResultTypes(), + [](Type t) { return isa(t); })) { + if (!isa(op)) { + op->emitError() + << "--arith-to-cggi does not support arith operations on " + "vectors or tensors. Lower them to scalars first."; + return WalkResult::interrupt(); + } + } + } + return WalkResult::advance(); + }); + + if (walkResult.wasInterrupted()) { + return signalPassFailure(); + } + ArithToCGGITypeConverter typeConverter(context); RewritePatternSet patterns(context); diff --git a/tests/Regression/issue_2484.mlir b/tests/Regression/issue_2484.mlir index 4f634d1ca7..10caa23c9f 100644 --- a/tests/Regression/issue_2484.mlir +++ b/tests/Regression/issue_2484.mlir @@ -1,30 +1,20 @@ -// RUN: heir-opt --secret-insert-mgmt-ckks=slot-number=1024 --tensor-linalg-to-affine-loops --arith-to-cggi --verify-diagnostics -split-input-file %s | FileCheck %s +// RUN: heir-opt --tensor-linalg-to-affine-loops --arith-to-cggi --verify-diagnostics -split-input-file %s -// CHECK: func.func @scalar_op_combinations -// CHECK-SAME: %[[ARG0:.*]]: ![[CT:.*]], %[[ARG1:.*]]: ![[CT]] -func.func @scalar_op_combinations(%arg0: i32, %arg1: i32) -> i32 { - // CHECK: %[[ADD:.*]] = cggi.add %[[ARG0]], %[[ARG1]] - %0 = arith.addi %arg0, %arg1 : i32 - // CHECK: %[[MUL:.*]] = cggi.mul %[[ARG0]], %[[ARG1]] - %1 = arith.muli %arg0, %arg1 : i32 - // CHECK: %[[CMP:.*]] = cggi.cmp %[[ADD]], %[[MUL]] {predicate = 2 : i64} - %cond = arith.cmpi slt, %0, %1 : i32 - // CHECK: %[[SEL:.*]] = cggi.cmux %[[CMP]], %[[MUL]], %[[ARG0]] - %2 = arith.select %cond, %1, %arg0 : i1, i32 - // CHECK: return %[[SEL]] - return %2 : i32 +module { + func.func @add_tensor(%arg0: tensor<4xi16> {secret.secret}, %arg1: tensor<4xi16> {secret.secret}) -> tensor<4xi16> { + // expected-error@+1 {{--arith-to-cggi does not support arith operations on vectors or tensors. Lower them to scalars first.}} + %0 = arith.addi %arg0, %arg1 : tensor<4xi16> + return %0 : tensor<4xi16> + } } // ----- -func.func @scalar_mul(%arg0: !secret.secret) -> !secret.secret { - %0 = secret.generic(%arg0 : !secret.secret) { - ^bb0(%arg1: i16): - %c2 = arith.constant 2 : i16 - // expected-error@+1 {{failed to legalize unresolved materialization}} - %1 = arith.muli %arg1, %c2 : i16 - // expected-note@+1 {{see existing live user here}} - secret.yield %1 : i16 - } -> (!secret.secret) - return %0 : !secret.secret +module { + func.func @rotate_tensor(%arg0: tensor<4xi16> {secret.secret}) -> tensor<4xi16> { + %c1 = arith.constant 1 : index + // expected-error@+1 {{--arith-to-cggi does not support tensor_ext operations. Lower them to scalars first.}} + %0 = tensor_ext.rotate %arg0, %c1 : tensor<4xi16>, index + return %0 : tensor<4xi16> + } }