Skip to content

Commit 400ffbf

Browse files
asraacopybara-github
authored andcommitted
fix: fixes lattigo in place transform by ensuring that storage values keep level state invariant
regression test for in-place issue #2635 PiperOrigin-RevId: 868309124
1 parent 0ccc456 commit 400ffbf

21 files changed

Lines changed: 412 additions & 62 deletions

lib/Analysis/LevelAnalysis/LevelAnalysis.cpp

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -50,37 +50,23 @@ static void debugLog(StringRef opName, ArrayRef<const LevelLattice*> operands,
5050
});
5151
};
5252

53-
LevelState transferForward(mgmt::ModReduceOp op,
54-
ArrayRef<const LevelLattice*> operands) {
55-
LevelState result = std::visit(
56-
Overloaded{
57-
[](MaxLevel) -> LevelState { return LevelState(Invalid{}); },
58-
[](Uninit) -> LevelState { return LevelState(Invalid{}); },
59-
[](Invalid) -> LevelState { return LevelState(Invalid{}); },
60-
[](int val) -> LevelState { return LevelState(val + 1); },
61-
},
62-
operands[0]->getValue().get());
63-
LLVM_DEBUG(debugLog("mod_reduce", operands, result));
64-
return result;
65-
}
66-
67-
LevelState transferForward(mgmt::LevelReduceOp op,
53+
LevelState transferForward(ReducesLevelOpInterface op,
6854
ArrayRef<const LevelLattice*> operands) {
6955
LevelState result = std::visit(
7056
Overloaded{
7157
[](MaxLevel) -> LevelState { return LevelState(Invalid{}); },
7258
[](Uninit) -> LevelState { return LevelState(Invalid{}); },
7359
[](Invalid) -> LevelState { return LevelState(Invalid{}); },
7460
[&](int val) -> LevelState {
75-
return LevelState(val + (int)op.getLevelToDrop());
61+
return LevelState(val + op.getLevelsToDrop());
7662
},
7763
},
7864
operands[0]->getValue().get());
79-
LLVM_DEBUG(debugLog("level_reduce", operands, result));
65+
LLVM_DEBUG(debugLog("ReduceLevelOpInterface", operands, result));
8066
return result;
8167
}
8268

83-
LevelState transferForward(mgmt::LevelReduceMinOp op,
69+
LevelState transferForward(ReducesAllLevelsOpInterface op,
8470
ArrayRef<const LevelLattice*> operands) {
8571
LevelState result = std::visit(
8672
Overloaded{
@@ -92,11 +78,11 @@ LevelState transferForward(mgmt::LevelReduceMinOp op,
9278
[](int val) -> LevelState { return LevelState(MaxLevel{}); },
9379
},
9480
operands[0]->getValue().get());
95-
LLVM_DEBUG(debugLog("level_reduce_min", operands, result));
81+
LLVM_DEBUG(debugLog("ReduceAllLevelsOpInterface", operands, result));
9682
return result;
9783
}
9884

99-
LevelState transferForward(mgmt::BootstrapOp op,
85+
LevelState transferForward(ResetsLevelOpInterface op,
10086
ArrayRef<const LevelLattice*> operands) {
10187
LevelState result = std::visit(
10288
Overloaded{
@@ -106,15 +92,18 @@ LevelState transferForward(mgmt::BootstrapOp op,
10692
[](int val) -> LevelState { return LevelState(0); },
10793
},
10894
operands[0]->getValue().get());
109-
LLVM_DEBUG(debugLog("bootstrap", operands, result));
95+
LLVM_DEBUG(debugLog("ResetsLevelOpInterface", operands, result));
11096
return result;
11197
}
11298

11399
LevelState deriveResultLevel(Operation* op,
114100
ArrayRef<const LevelLattice*> operands) {
115101
return llvm::TypeSwitch<Operation&, LevelState>(*op)
116-
.Case<mgmt::ModReduceOp, mgmt::LevelReduceOp, mgmt::BootstrapOp,
117-
mgmt::LevelReduceMinOp>(
102+
.Case<ResetsLevelOpInterface>(
103+
[&](auto op) -> LevelState { return transferForward(op, operands); })
104+
.Case<ReducesAllLevelsOpInterface>(
105+
[&](auto op) -> LevelState { return transferForward(op, operands); })
106+
.Case<ReducesLevelOpInterface>(
118107
[&](auto op) -> LevelState { return transferForward(op, operands); })
119108
.Default([&](auto& op) -> LevelState {
120109
LevelState result;

lib/Dialect/HEIRInterfaces.td

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,41 @@ def ResetsMulDepthOpInterface : OpInterface<"ResetsMulDepthOpInterface"> {
3434
}];
3535
}
3636

37+
def ResetsLevelOpInterface : OpInterface<"ResetsLevelOpInterface"> {
38+
let cppNamespace = "::mlir::heir";
39+
let description = [{
40+
An interface that signals when an operation resets level
41+
among its results, such as a `mgmt.bootstrap`.
42+
}];
43+
}
44+
45+
def ReducesLevelOpInterface : OpInterface<"ReducesLevelOpInterface"> {
46+
let cppNamespace = "::mlir::heir";
47+
let description = [{
48+
An interface that signals when an operation reduces level
49+
among its results, such as a `mgmt.mod_reduce` or `ckks.rescale`.
50+
}];
51+
52+
let methods = [
53+
InterfaceMethod<
54+
/*desc=*/"Return the number of levels to reduce by.",
55+
/*retTy=*/"int",
56+
/*methodName=*/"getLevelsToDrop",
57+
/*args=*/(ins ),
58+
/*body=*/[{}],
59+
/*defaultBody=*/[{ return 1; }]
60+
>,
61+
];
62+
}
63+
64+
def ReducesAllLevelsOpInterface : OpInterface<"ReducesAllLevelsOpInterface"> {
65+
let cppNamespace = "::mlir::heir";
66+
let description = [{
67+
An interface that signals when an operation reduces all level
68+
among its results, such as a `mgmt.level_reduce_min`.
69+
}];
70+
}
71+
3772
def LUTOpInterface : OpInterface<"LUTOpInterface"> {
3873
let cppNamespace = "::mlir::heir";
3974
let description = [{

lib/Dialect/Lattigo/IR/LattigoBGVOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,8 @@ def Lattigo_BGVMulOp : Lattigo_BGVBinaryInPlaceOp<"mul", [IncreasesMulDepthOpInt
182182
}];
183183
}
184184

185-
class Lattigo_BGVUnaryOp<string mnemonic> :
186-
Lattigo_BGVOp<mnemonic> {
185+
class Lattigo_BGVUnaryOp<string mnemonic, list<Trait> traits = []> :
186+
Lattigo_BGVOp<mnemonic, traits> {
187187
let arguments = (ins
188188
Lattigo_BGVEvaluator:$evaluator,
189189
Lattigo_RLWECiphertext:$input
@@ -198,7 +198,7 @@ def Lattigo_BGVRelinearizeNewOp : Lattigo_BGVUnaryOp<"relinearize_new"> {
198198
}];
199199
}
200200

201-
def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new"> {
201+
def Lattigo_BGVRescaleNewOp : Lattigo_BGVUnaryOp<"rescale_new", [ReducesLevelOpInterface]> {
202202
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
203203
let description = [{
204204
This operation rescales a ciphertext value in the Lattigo BGV dialect.
@@ -258,7 +258,7 @@ def Lattigo_BGVRelinearizeOp : Lattigo_BGVUnaryInPlaceOp<"relinearize"> {
258258
}];
259259
}
260260

261-
def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale"> {
261+
def Lattigo_BGVRescaleOp : Lattigo_BGVUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> {
262262
let summary = "Rescale a ciphertext in the Lattigo BGV dialect";
263263
let description = [{
264264
This operation rescales a ciphertext value in the Lattigo BGV dialect.

lib/Dialect/Lattigo/IR/LattigoCKKSOps.td

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -215,8 +215,8 @@ def Lattigo_CKKSMulOp : Lattigo_CKKSBinaryInPlaceOp<"mul", [IncreasesMulDepthOpI
215215
}];
216216
}
217217

218-
class Lattigo_CKKSUnaryOp<string mnemonic> :
219-
Lattigo_CKKSOp<mnemonic> {
218+
class Lattigo_CKKSUnaryOp<string mnemonic, list<Trait> traits = []> :
219+
Lattigo_CKKSOp<mnemonic, traits> {
220220
let arguments = (ins
221221
Lattigo_CKKSEvaluator:$evaluator,
222222
Lattigo_RLWECiphertext:$input
@@ -231,7 +231,7 @@ def Lattigo_CKKSRelinearizeNewOp : Lattigo_CKKSUnaryOp<"relinearize_new"> {
231231
}];
232232
}
233233

234-
def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new"> {
234+
def Lattigo_CKKSRescaleNewOp : Lattigo_CKKSUnaryOp<"rescale_new", [ReducesLevelOpInterface]> {
235235
let summary = "Rescale a ciphertext in the Lattigo CKKS dialect";
236236
let description = [{
237237
This operation rescales a ciphertext value in the Lattigo CKKS dialect.
@@ -284,7 +284,7 @@ def Lattigo_CKKSRelinearizeOp : Lattigo_CKKSUnaryInPlaceOp<"relinearize"> {
284284
}];
285285
}
286286

287-
def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale"> {
287+
def Lattigo_CKKSRescaleOp : Lattigo_CKKSUnaryInPlaceOp<"rescale", [ReducesLevelOpInterface]> {
288288
let summary = "Rescale a ciphertext in the Lattigo CKKS dialect";
289289
let description = [{
290290
This operation rescales a ciphertext value in the Lattigo CKKS dialect.
@@ -322,7 +322,7 @@ def Lattigo_CKKSRotateOp : Lattigo_CKKSUnaryInPlaceOp<"rotate", [
322322
let hasVerifier = 1;
323323
}
324324

325-
def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap"> {
325+
def Lattigo_CKKSBootstrapOp : Lattigo_CKKSUnaryOp<"bootstrap", [ResetsLevelOpInterface]> {
326326
let summary = "Bootstrap a ciphertext in the Lattigo CKKS dialect";
327327
let description = [{
328328
Bootstraps a ciphertext value in the Lattigo CKKS dialect.

lib/Dialect/Lattigo/IR/LattigoOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,10 @@ LogicalResult RLWENewEncryptorOp::verify() {
4747
return success();
4848
}
4949

50+
int RLWEDropLevelNewOp::getLevelsToDrop() { return getLevelToDrop(); }
51+
52+
int RLWEDropLevelOp::getLevelsToDrop() { return getLevelToDrop(); }
53+
5054
LogicalResult BGVRotateColumnsNewOp::verify() {
5155
return containsExactlyOneOrEmitError(getOperation(), getDynamicShift(),
5256
getStaticShift());

lib/Dialect/Lattigo/IR/LattigoRLWEOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,8 @@ def Lattigo_RLWEDecryptOp : Lattigo_RLWEOp<"decrypt"> {
122122
let results = (outs Lattigo_RLWEPlaintext:$plaintext);
123123
}
124124

125-
def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new"> {
125+
def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new",
126+
[DeclareOpInterfaceMethods<ReducesLevelOpInterface, ["getLevelsToDrop"]>]> {
126127
let summary = "Drop level of a ciphertext";
127128
let arguments = (ins
128129
Lattigo_RLWEEvaluator:$evaluator,
@@ -132,7 +133,8 @@ def Lattigo_RLWEDropLevelNewOp : Lattigo_RLWEOp<"drop_level_new"> {
132133
let results = (outs Lattigo_RLWECiphertext:$output);
133134
}
134135

135-
def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level", [InPlaceOpInterface]> {
136+
def Lattigo_RLWEDropLevelOp : Lattigo_RLWEOp<"drop_level",
137+
[InPlaceOpInterface, DeclareOpInterfaceMethods<ReducesLevelOpInterface, ["getLevelsToDrop"]>]> {
136138
let summary = "Drop level of a ciphertext";
137139
let description = [{
138140
This operation drops the level of a ciphertext

0 commit comments

Comments
 (0)