Skip to content

Commit 1e2c019

Browse files
j2kuncopybara-github
authored andcommitted
Fix #2484 (extra guards for arith-to-cggi)
PiperOrigin-RevId: 868417961
1 parent 9099cb0 commit 1e2c019

2 files changed

Lines changed: 58 additions & 41 deletions

File tree

lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp

Lines changed: 44 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "lib/Dialect/Comb/IR/CombDialect.h"
1010
#include "lib/Dialect/Comb/IR/CombOps.h"
1111
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
12-
#include "lib/Dialect/LWE/IR/LWEDialect.h"
1312
#include "lib/Dialect/LWE/IR/LWEOps.h"
1413
#include "lib/Dialect/LWE/IR/LWETypes.h"
1514
#include "lib/Utils/ConversionUtils.h"
@@ -27,13 +26,13 @@
2726
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
2827
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
2928
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
30-
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
3129
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
3230
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
3331
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
3432
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
3533
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
3634
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
35+
#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project
3736
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project
3837

3938
namespace mlir::heir::arith {
@@ -129,10 +128,6 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs,
129128
assert(inputs.size() == 1);
130129
auto inputType = inputs[0].getType();
131130

132-
if (!isa<IntegerType>(getElementTypeOrSelf(inputType)))
133-
llvm_unreachable(
134-
"Non-integer types should never be the input to a materializeTarget.");
135-
136131
if (auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>()) {
137132
if (auto intAttr = dyn_cast<IntegerAttr>(inValue.getValueAttr())) {
138133
return cggi::CreateTrivialOp::create(builder, loc, type, intAttr);
@@ -154,17 +149,9 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs,
154149
}
155150

156151
// Comes from function/loop argument: Trivial encrypt through LWE
157-
lwe::LWECiphertextType ciphertextType;
158-
159-
if (auto shapedType = dyn_cast<ShapedType>(type)) {
160-
auto tensorElementSize =
161-
shapedType.getElementType().getIntOrFloatBitWidth();
162-
ciphertextType = lwe::getDefaultCGGICiphertextType(builder.getContext(),
163-
tensorElementSize);
164-
} else {
165-
ciphertextType = lwe::getDefaultCGGICiphertextType(
166-
builder.getContext(), inputType.getIntOrFloatBitWidth());
167-
}
152+
auto tensorElementSize = inputType.getIntOrFloatBitWidth();
153+
lwe::LWECiphertextType ciphertextType = lwe::getDefaultCGGICiphertextType(
154+
builder.getContext(), tensorElementSize);
168155

169156
auto plaintextBits = ciphertextType.getPlaintextSpace()
170157
.getRing()
@@ -281,6 +268,9 @@ struct ConvertCmpOp : public OpConversionPattern<mlir::arith::CmpIOp> {
281268
LogicalResult matchAndRewrite(
282269
mlir::arith::CmpIOp op, OpAdaptor adaptor,
283270
ConversionPatternRewriter& rewriter) const override {
271+
if (isa<ShapedType>(op.getType())) {
272+
return failure();
273+
}
284274
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
285275

286276
ArithToCGGITypeConverter typeConverter(op->getContext());
@@ -321,6 +311,9 @@ struct ConvertSubOp : public OpConversionPattern<mlir::arith::SubIOp> {
321311
LogicalResult matchAndRewrite(
322312
mlir::arith::SubIOp op, OpAdaptor adaptor,
323313
ConversionPatternRewriter& rewriter) const override {
314+
if (isa<ShapedType>(op.getType())) {
315+
return failure();
316+
}
324317
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
325318

326319
if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
@@ -348,6 +341,9 @@ struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
348341
LogicalResult matchAndRewrite(
349342
mlir::arith::SelectOp op, OpAdaptor adaptor,
350343
ConversionPatternRewriter& rewriter) const override {
344+
if (isa<ShapedType>(op.getType())) {
345+
return failure();
346+
}
351347
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
352348

353349
auto cmuxOp = cggi::SelectOp::create(
@@ -392,6 +388,9 @@ struct ConvertShOp : public OpConversionPattern<SourceArithShOp> {
392388
LogicalResult matchAndRewrite(
393389
SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
394390
ConversionPatternRewriter& rewriter) const override {
391+
if (isa<ShapedType>(op.getType())) {
392+
return failure();
393+
}
395394
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
396395

397396
auto cteShiftSizeOp =
@@ -443,6 +442,9 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
443442
LogicalResult matchAndRewrite(
444443
SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
445444
ConversionPatternRewriter& rewriter) const override {
445+
if (isa<ShapedType>(op.getType())) {
446+
return failure();
447+
}
446448
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
447449

448450
if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
@@ -517,6 +519,31 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
517519
void runOnOperation() override {
518520
MLIRContext* context = &getContext();
519521
auto* module = getOperation();
522+
523+
auto walkResult = module->walk([&](Operation* op) {
524+
if (op->getName().getDialectNamespace() == "tensor_ext") {
525+
op->emitError() << "--arith-to-cggi does not support tensor_ext "
526+
"operations. Lower them to scalars first.";
527+
return WalkResult::interrupt();
528+
}
529+
if (isa<mlir::arith::ArithDialect>(op->getDialect())) {
530+
if (llvm::any_of(op->getResultTypes(),
531+
[](Type t) { return isa<ShapedType>(t); })) {
532+
if (!isa<mlir::arith::ConstantOp>(op)) {
533+
op->emitError()
534+
<< "--arith-to-cggi does not support arith operations on "
535+
"vectors or tensors. Lower them to scalars first.";
536+
return WalkResult::interrupt();
537+
}
538+
}
539+
}
540+
return WalkResult::advance();
541+
});
542+
543+
if (walkResult.wasInterrupted()) {
544+
return signalPassFailure();
545+
}
546+
520547
ArithToCGGITypeConverter typeConverter(context);
521548

522549
RewritePatternSet patterns(context);

tests/Regression/issue_2484.mlir

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,20 @@
1-
// RUN: heir-opt --secret-insert-mgmt-ckks=slot-number=1024 --tensor-linalg-to-affine-loops --arith-to-cggi --verify-diagnostics -split-input-file %s | FileCheck %s
1+
// RUN: heir-opt --tensor-linalg-to-affine-loops --arith-to-cggi --verify-diagnostics -split-input-file %s
22

3-
// CHECK: func.func @scalar_op_combinations
4-
// CHECK-SAME: %[[ARG0:.*]]: ![[CT:.*]], %[[ARG1:.*]]: ![[CT]]
5-
func.func @scalar_op_combinations(%arg0: i32, %arg1: i32) -> i32 {
6-
// CHECK: %[[ADD:.*]] = cggi.add %[[ARG0]], %[[ARG1]]
7-
%0 = arith.addi %arg0, %arg1 : i32
8-
// CHECK: %[[MUL:.*]] = cggi.mul %[[ARG0]], %[[ARG1]]
9-
%1 = arith.muli %arg0, %arg1 : i32
10-
// CHECK: %[[CMP:.*]] = cggi.cmp %[[ADD]], %[[MUL]] {predicate = 2 : i64}
11-
%cond = arith.cmpi slt, %0, %1 : i32
12-
// CHECK: %[[SEL:.*]] = cggi.cmux %[[CMP]], %[[MUL]], %[[ARG0]]
13-
%2 = arith.select %cond, %1, %arg0 : i1, i32
14-
// CHECK: return %[[SEL]]
15-
return %2 : i32
3+
module {
4+
func.func @add_tensor(%arg0: tensor<4xi16> {secret.secret}, %arg1: tensor<4xi16> {secret.secret}) -> tensor<4xi16> {
5+
// expected-error@+1 {{--arith-to-cggi does not support arith operations on vectors or tensors. Lower them to scalars first.}}
6+
%0 = arith.addi %arg0, %arg1 : tensor<4xi16>
7+
return %0 : tensor<4xi16>
8+
}
169
}
1710

1811
// -----
1912

20-
func.func @scalar_mul(%arg0: !secret.secret<i16>) -> !secret.secret<i16> {
21-
%0 = secret.generic(%arg0 : !secret.secret<i16>) {
22-
^bb0(%arg1: i16):
23-
%c2 = arith.constant 2 : i16
24-
// expected-error@+1 {{failed to legalize unresolved materialization}}
25-
%1 = arith.muli %arg1, %c2 : i16
26-
// expected-note@+1 {{see existing live user here}}
27-
secret.yield %1 : i16
28-
} -> (!secret.secret<i16>)
29-
return %0 : !secret.secret<i16>
13+
module {
14+
func.func @rotate_tensor(%arg0: tensor<4xi16> {secret.secret}) -> tensor<4xi16> {
15+
%c1 = arith.constant 1 : index
16+
// expected-error@+1 {{--arith-to-cggi does not support tensor_ext operations. Lower them to scalars first.}}
17+
%0 = tensor_ext.rotate %arg0, %c1 : tensor<4xi16>, index
18+
return %0 : tensor<4xi16>
19+
}
3020
}

0 commit comments

Comments
 (0)