Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 44 additions & 17 deletions lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
#include "lib/Dialect/Comb/IR/CombDialect.h"
#include "lib/Dialect/Comb/IR/CombOps.h"
#include "lib/Dialect/LWE/IR/LWEAttributes.h"
#include "lib/Dialect/LWE/IR/LWEDialect.h"
#include "lib/Dialect/LWE/IR/LWEOps.h"
#include "lib/Dialect/LWE/IR/LWETypes.h"
#include "lib/Utils/ConversionUtils.h"
Expand All @@ -27,13 +26,13 @@
#include "mlir/include/mlir/IR/BuiltinAttributes.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project
#include "mlir/include/mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/include/mlir/IR/ImplicitLocOpBuilder.h" // from @llvm-project
#include "mlir/include/mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/include/mlir/IR/TypeUtilities.h" // from @llvm-project
#include "mlir/include/mlir/IR/Value.h" // from @llvm-project
#include "mlir/include/mlir/IR/ValueRange.h" // from @llvm-project
#include "mlir/include/mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/include/mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/include/mlir/Support/WalkResult.h" // from @llvm-project
#include "mlir/include/mlir/Transforms/DialectConversion.h" // from @llvm-project

namespace mlir::heir::arith {
Expand Down Expand Up @@ -129,10 +128,6 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs,
assert(inputs.size() == 1);
auto inputType = inputs[0].getType();

if (!isa<IntegerType>(getElementTypeOrSelf(inputType)))
llvm_unreachable(
"Non-integer types should never be the input to a materializeTarget.");

if (auto inValue = inputs.front().getDefiningOp<mlir::arith::ConstantOp>()) {
if (auto intAttr = dyn_cast<IntegerAttr>(inValue.getValueAttr())) {
return cggi::CreateTrivialOp::create(builder, loc, type, intAttr);
Expand All @@ -154,17 +149,9 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs,
}

// Comes from function/loop argument: Trivial encrypt through LWE
lwe::LWECiphertextType ciphertextType;

if (auto shapedType = dyn_cast<ShapedType>(type)) {
auto tensorElementSize =
shapedType.getElementType().getIntOrFloatBitWidth();
ciphertextType = lwe::getDefaultCGGICiphertextType(builder.getContext(),
tensorElementSize);
} else {
ciphertextType = lwe::getDefaultCGGICiphertextType(
builder.getContext(), inputType.getIntOrFloatBitWidth());
}
auto tensorElementSize = inputType.getIntOrFloatBitWidth();
lwe::LWECiphertextType ciphertextType = lwe::getDefaultCGGICiphertextType(
builder.getContext(), tensorElementSize);

auto plaintextBits = ciphertextType.getPlaintextSpace()
.getRing()
Expand Down Expand Up @@ -281,6 +268,9 @@ struct ConvertCmpOp : public OpConversionPattern<mlir::arith::CmpIOp> {
LogicalResult matchAndRewrite(
mlir::arith::CmpIOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (isa<ShapedType>(op.getType())) {
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

ArithToCGGITypeConverter typeConverter(op->getContext());
Expand Down Expand Up @@ -321,6 +311,9 @@ struct ConvertSubOp : public OpConversionPattern<mlir::arith::SubIOp> {
LogicalResult matchAndRewrite(
mlir::arith::SubIOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (isa<ShapedType>(op.getType())) {
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (auto rhsDefOp = op.getRhs().getDefiningOp()) {
Expand Down Expand Up @@ -348,6 +341,9 @@ struct ConvertSelectOp : public OpConversionPattern<mlir::arith::SelectOp> {
LogicalResult matchAndRewrite(
mlir::arith::SelectOp op, OpAdaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (isa<ShapedType>(op.getType())) {
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cmuxOp = cggi::SelectOp::create(
Expand Down Expand Up @@ -392,6 +388,9 @@ struct ConvertShOp : public OpConversionPattern<SourceArithShOp> {
LogicalResult matchAndRewrite(
SourceArithShOp op, typename SourceArithShOp::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (isa<ShapedType>(op.getType())) {
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

auto cteShiftSizeOp =
Expand Down Expand Up @@ -443,6 +442,9 @@ struct ConvertArithBinOp : public OpConversionPattern<SourceArithOp> {
LogicalResult matchAndRewrite(
SourceArithOp op, typename SourceArithOp::Adaptor adaptor,
ConversionPatternRewriter& rewriter) const override {
if (isa<ShapedType>(op.getType())) {
return failure();
}
ImplicitLocOpBuilder b(op.getLoc(), rewriter);

if (auto lhsDefOp = op.getLhs().getDefiningOp()) {
Expand Down Expand Up @@ -517,6 +519,31 @@ struct ArithToCGGI : public impl::ArithToCGGIBase<ArithToCGGI> {
void runOnOperation() override {
MLIRContext* context = &getContext();
auto* module = getOperation();

auto walkResult = module->walk([&](Operation* op) {
if (op->getName().getDialectNamespace() == "tensor_ext") {
op->emitError() << "--arith-to-cggi does not support tensor_ext "
"operations. Lower them to scalars first.";
return WalkResult::interrupt();
}
if (isa<mlir::arith::ArithDialect>(op->getDialect())) {
if (llvm::any_of(op->getResultTypes(),
[](Type t) { return isa<ShapedType>(t); })) {
if (!isa<mlir::arith::ConstantOp>(op)) {
op->emitError()
<< "--arith-to-cggi does not support arith operations on "
"vectors or tensors. Lower them to scalars first.";
return WalkResult::interrupt();
}
}
}
return WalkResult::advance();
});

if (walkResult.wasInterrupted()) {
return signalPassFailure();
}

ArithToCGGITypeConverter typeConverter(context);

RewritePatternSet patterns(context);
Expand Down
38 changes: 14 additions & 24 deletions tests/Regression/issue_2484.mlir
Original file line number Diff line number Diff line change
@@ -1,30 +1,20 @@
// RUN: heir-opt --secret-insert-mgmt-ckks=slot-number=1024 --tensor-linalg-to-affine-loops --arith-to-cggi --verify-diagnostics -split-input-file %s | FileCheck %s
// RUN: heir-opt --tensor-linalg-to-affine-loops --arith-to-cggi --verify-diagnostics -split-input-file %s

// CHECK: func.func @scalar_op_combinations
// CHECK-SAME: %[[ARG0:.*]]: ![[CT:.*]], %[[ARG1:.*]]: ![[CT]]
func.func @scalar_op_combinations(%arg0: i32, %arg1: i32) -> i32 {
// CHECK: %[[ADD:.*]] = cggi.add %[[ARG0]], %[[ARG1]]
%0 = arith.addi %arg0, %arg1 : i32
// CHECK: %[[MUL:.*]] = cggi.mul %[[ARG0]], %[[ARG1]]
%1 = arith.muli %arg0, %arg1 : i32
// CHECK: %[[CMP:.*]] = cggi.cmp %[[ADD]], %[[MUL]] {predicate = 2 : i64}
%cond = arith.cmpi slt, %0, %1 : i32
// CHECK: %[[SEL:.*]] = cggi.cmux %[[CMP]], %[[MUL]], %[[ARG0]]
%2 = arith.select %cond, %1, %arg0 : i1, i32
// CHECK: return %[[SEL]]
return %2 : i32
module {
func.func @add_tensor(%arg0: tensor<4xi16> {secret.secret}, %arg1: tensor<4xi16> {secret.secret}) -> tensor<4xi16> {
// expected-error@+1 {{--arith-to-cggi does not support arith operations on vectors or tensors. Lower them to scalars first.}}
%0 = arith.addi %arg0, %arg1 : tensor<4xi16>
return %0 : tensor<4xi16>
}
}

// -----

func.func @scalar_mul(%arg0: !secret.secret<i16>) -> !secret.secret<i16> {
%0 = secret.generic(%arg0 : !secret.secret<i16>) {
^bb0(%arg1: i16):
%c2 = arith.constant 2 : i16
// expected-error@+1 {{failed to legalize unresolved materialization}}
%1 = arith.muli %arg1, %c2 : i16
// expected-note@+1 {{see existing live user here}}
secret.yield %1 : i16
} -> (!secret.secret<i16>)
return %0 : !secret.secret<i16>
module {
func.func @rotate_tensor(%arg0: tensor<4xi16> {secret.secret}) -> tensor<4xi16> {
%c1 = arith.constant 1 : index
// expected-error@+1 {{--arith-to-cggi does not support tensor_ext operations. Lower them to scalars first.}}
%0 = tensor_ext.rotate %arg0, %c1 : tensor<4xi16>, index
return %0 : tensor<4xi16>
}
}
Loading