Skip to content

Commit 84b0ceb

Browse files
Merge pull request #2458 from WoutLegiest:lut4
PiperOrigin-RevId: 860227954
2 parents b2c4279 + 4cd7e48 commit 84b0ceb

24 files changed

Lines changed: 960 additions & 82 deletions

File tree

.github/workflows/nightly.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,4 +51,5 @@ jobs:
5151
bazel-bin/tools/heir-translate
5252
bazel-bin/tools/heir-lsp
5353
bazel-bin/external/edu_berkeley_abc/abc
54-
lib/Transforms/YosysOptimizer/yosys/techmap.v
54+
lib/Transforms/YosysOptimizer/yosys/techmap_lut3.v
55+
lib/Transforms/YosysOptimizer/yosys/techmap_lut4.v

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ scripts/jupyter/.ipynb_checkpoints/
2727
scripts/jupyter/heir-opt
2828
scripts/jupyter/heir-translate
2929
scripts/jupyter/abc
30-
scripts/jupyter/techmap.v
30+
scripts/jupyter/techmap_lut3.v
31+
scripts/jupyter/techmap_lut4.v
3132

3233
# for rust codegen tests
3334
**/Cargo.lock

lib/Dialect/Arith/Conversions/ArithToCGGI/ArithToCGGI.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,12 @@ static Value materializeTarget(OpBuilder& builder, Type type, ValueRange inputs,
160160
if (auto shapedType = dyn_cast<ShapedType>(type)) {
161161
auto tensorElementSize =
162162
shapedType.getElementType().getIntOrFloatBitWidth();
163-
ciphertextType = lwe::getDefaultCGGICiphertextType(builder.getContext(),
164-
tensorElementSize);
163+
ciphertextType = lwe::getDefaultCGGICiphertextType(
164+
builder.getContext(), tensorElementSize, tensorElementSize);
165165
} else {
166166
ciphertextType = lwe::getDefaultCGGICiphertextType(
167-
builder.getContext(), inputType.getIntOrFloatBitWidth());
167+
builder.getContext(), inputType.getIntOrFloatBitWidth(),
168+
inputType.getIntOrFloatBitWidth());
168169
}
169170

170171
auto plaintextBits = ciphertextType.getPlaintextSpace()

lib/Dialect/Arith/Conversions/ArithToCGGIQuart/ArithToCGGIQuart.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,12 +41,14 @@ static constexpr unsigned maxIntWidth = 16;
4141

4242
static lwe::LWECiphertextType convertArithToCGGIType(IntegerType type,
4343
MLIRContext* ctx) {
44-
return lwe::getDefaultCGGICiphertextType(ctx, type.getIntOrFloatBitWidth());
44+
return lwe::getDefaultCGGICiphertextType(ctx, type.getIntOrFloatBitWidth(),
45+
type.getIntOrFloatBitWidth());
4546
}
4647

4748
static std::optional<Type> convertArithToCGGIQuartType(IntegerType type,
4849
MLIRContext* ctx) {
49-
auto lweType = lwe::getDefaultCGGICiphertextType(ctx, maxIntWidth);
50+
auto lweType =
51+
lwe::getDefaultCGGICiphertextType(ctx, maxIntWidth, maxIntWidth);
5052

5153
float width = type.getWidth();
5254
float realWidth = maxIntWidth >> 1;
@@ -107,7 +109,8 @@ static Value createTrivialOpMaxWidth(ImplicitLocOpBuilder b, int value) {
107109
auto maxWideIntType = IntegerType::get(b.getContext(), maxIntWidth);
108110
auto intAttr = b.getIntegerAttr(maxWideIntType, value);
109111

110-
auto lweType = lwe::getDefaultCGGICiphertextType(b.getContext(), maxIntWidth);
112+
auto lweType = lwe::getDefaultCGGICiphertextType(b.getContext(), maxIntWidth,
113+
maxIntWidth);
111114

112115
return cggi::CreateTrivialOp::create(b, lweType, intAttr);
113116
}

lib/Dialect/CGGI/Conversions/CGGIToTfheRust/CGGIToTfheRust.cpp

Lines changed: 50 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,53 @@ struct ConvertLut3Op : public OpConversionPattern<cggi::Lut3Op> {
172172
}
173173
};
174174

175+
/// Convert a Lut4Op to:
176+
/// - generate_lookup_table
177+
/// - scalar_left_shift
178+
/// - add_op
179+
/// - apply_lookup_table
180+
///
181+
/// Note the generated lookup tables are not uniqued across applications of this
182+
/// pattern, so a separate step is required at the end to collect all the
183+
/// identical lookup tables, and this can be done with a --cse pass.
184+
struct ConvertLut4Op : public OpConversionPattern<cggi::Lut4Op> {
185+
ConvertLut4Op(mlir::MLIRContext* context)
186+
: OpConversionPattern<cggi::Lut4Op>(context) {}
187+
188+
using OpConversionPattern::OpConversionPattern;
189+
190+
LogicalResult matchAndRewrite(
191+
cggi::Lut4Op op, OpAdaptor adaptor,
192+
ConversionPatternRewriter& rewriter) const override {
193+
ImplicitLocOpBuilder b(op.getLoc(), rewriter);
194+
FailureOr<Value> result = getContextualServerKey(op.getOperation());
195+
if (failed(result)) return result;
196+
197+
Value serverKey = result.value();
198+
// A followup -cse pass should combine repeated LUT generation ops.
199+
auto lut = tfhe_rust::GenerateLookupTableOp::create(
200+
b, serverKey, adaptor.getLookupTable());
201+
// Construct input = d << 3 + c << 2 + b << 1 + a
202+
auto shiftedD = tfhe_rust::ScalarLeftShiftOp::create(
203+
b, serverKey, adaptor.getD(), b.getIndexAttr(3));
204+
auto shiftedC = tfhe_rust::ScalarLeftShiftOp::create(
205+
b, serverKey, adaptor.getC(), b.getIndexAttr(2));
206+
auto shiftedB = tfhe_rust::ScalarLeftShiftOp::create(
207+
b, serverKey, adaptor.getB(), b.getIndexAttr(1));
208+
209+
auto summedCD = tfhe_rust::AddOp::create(b, adaptor.getB().getType(),
210+
serverKey, shiftedC, shiftedD);
211+
auto summedBCD = tfhe_rust::AddOp::create(b, adaptor.getB().getType(),
212+
serverKey, shiftedB, summedCD);
213+
auto summedABCD = tfhe_rust::AddOp::create(
214+
b, adaptor.getB().getType(), serverKey, summedBCD, adaptor.getA());
215+
216+
rewriter.replaceOp(op, tfhe_rust::ApplyLookupTableOp::create(
217+
b, serverKey, summedABCD, lut));
218+
return success();
219+
}
220+
};
221+
175222
struct ConvertLut2Op : public OpConversionPattern<cggi::Lut2Op> {
176223
ConvertLut2Op(mlir::MLIRContext* context)
177224
: OpConversionPattern<cggi::Lut2Op>(context) {}
@@ -619,9 +666,9 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase<CGGIToTfheRust> {
619666

620667
patterns.add<
621668
AddServerKeyArg, AddServerKeyArgCall, ConvertEncodeOp, ConvertLut2Op,
622-
ConvertLut3Op, ConvertNotOp, ConvertTrivialEncryptOp, ConvertTrivialOp,
623-
ConvertCGGITRBinOp<lwe::AddOp, tfhe_rust::AddOp>, ConvertScalarMulOp,
624-
ConvertCGGITRBinOp<cggi::AddOp, tfhe_rust::AddOp>,
669+
ConvertLut3Op, ConvertLut4Op, ConvertNotOp, ConvertTrivialEncryptOp,
670+
ConvertTrivialOp, ConvertCGGITRBinOp<lwe::AddOp, tfhe_rust::AddOp>,
671+
ConvertScalarMulOp, ConvertCGGITRBinOp<cggi::AddOp, tfhe_rust::AddOp>,
625672
ConvertCGGITRBinOp<cggi::MulOp, tfhe_rust::MulOp>,
626673
ConvertCGGITRBinOp<cggi::SubOp, tfhe_rust::SubOp>,
627674
ConvertCGGITRBinOp<cggi::SubOp, tfhe_rust::SubOp>,

lib/Dialect/CGGI/IR/CGGIOps.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,10 @@ std::optional<ValueRange> Lut3Op::getLookupTableInputs() {
2828
return ValueRange{getC(), getB(), getA()};
2929
}
3030

31+
std::optional<ValueRange> Lut4Op::getLookupTableInputs() {
32+
return ValueRange{getD(), getC(), getB(), getA()};
33+
}
34+
3135
std::optional<ValueRange> LutLinCombOp::getLookupTableInputs() {
3236
return ValueRange{getInputs()};
3337
}

lib/Dialect/CGGI/IR/CGGIPBSOps.td

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,12 @@ def CGGI_Lut3Op : CGGI_LutOp<"lut3",
6969
let results = (outs LWECiphertextLike:$output);
7070
}
7171

72+
def CGGI_Lut4Op : CGGI_LutOp<"lut4", [AllTypesMatch<["a", "b", "c", "d", "output"]>]> {
73+
let summary = "A lookup table on four inputs.";
74+
let arguments = (ins LWECiphertextLike:$d, LWECiphertextLike:$c, LWECiphertextLike:$b, LWECiphertextLike:$a, Builtin_IntegerAttr:$lookup_table);
75+
let results = (outs LWECiphertextLike:$output);
76+
}
77+
7278
def CGGI_PackedLut3Op : CGGI_Op<"packed_lut3", [
7379
Pure,
7480
SameOperandsAndResultType,

lib/Dialect/LWE/IR/LWETypes.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace lwe {
1919
// application data.
2020
LWECiphertextType getDefaultCGGICiphertextType(MLIRContext* ctx,
2121
int messageWidth,
22-
int plaintextBits = 3);
22+
int plaintextBits);
2323

2424
inline LWEPlaintextType getCorrespondingPlaintextType(
2525
LWECiphertextType ctType) {

lib/Dialect/Secret/Conversions/SecretToCGGI/SecretToCGGI.cpp

Lines changed: 17 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ class SecretTypeConverter : public ContextAwareTypeConverter {
255255
Type getLWECiphertextForInt(MLIRContext* ctx, Type type) const {
256256
if (IntegerType intType = dyn_cast<IntegerType>(type)) {
257257
if (intType.getWidth() == 1) {
258-
return lwe::getDefaultCGGICiphertextType(ctx, 1);
258+
return lwe::getDefaultCGGICiphertextType(ctx, 1, this->minBitWidth);
259259
}
260260
return RankedTensorType::get(
261261
{intType.getWidth()},
@@ -322,11 +322,20 @@ class SecretGenericOpLUTConversion
322322
// Assemble the lookup table.
323323
comb::TruthTableOp truthOp =
324324
cast<comb::TruthTableOp>(op.getBody()->getOperations().front());
325-
return rewriter
326-
.replaceOpWithNewOp<cggi::Lut3Op>(op, encodedInputs[0],
327-
encodedInputs[1], encodedInputs[2],
328-
truthOp.getLookupTable())
329-
.getOperation();
325+
326+
if (encodedInputs.size() == 3)
327+
return rewriter
328+
.replaceOpWithNewOp<cggi::Lut3Op>(op, encodedInputs[0],
329+
encodedInputs[1], encodedInputs[2],
330+
truthOp.getLookupTable())
331+
.getOperation();
332+
if (encodedInputs.size() == 4)
333+
return rewriter
334+
.replaceOpWithNewOp<cggi::Lut4Op>(
335+
op, encodedInputs[0], encodedInputs[1], encodedInputs[2],
336+
encodedInputs[3], truthOp.getLookupTable())
337+
.getOperation();
338+
return rewriter.notifyMatchFailure(op, "expected 3 or 4 LUT inputs");
330339
}
331340
};
332341

@@ -722,11 +731,6 @@ struct ConvertFromElementsOp
722731
overflowAttr),
723732
b.getIndexAttr(ciphertextBits));
724733

725-
// b.create<lwe::TrivialEncryptOp>(
726-
// ctTy,
727-
// b.create<lwe::EncodeOp>(ptTy, element, ctTy.getEncoding()),
728-
// lwe::LWEParamsAttr())
729-
// .getResult();
730734
values.push_back(ctElement);
731735
}
732736
}
@@ -776,8 +780,8 @@ static int findLUTSize(MLIRContext* context, Operation* module) {
776780
auto processOperation = [&](Operation* op) {
777781
if (isa<comb::CombDialect>(op->getDialect())) {
778782
int currentSize = 0;
779-
if (dyn_cast<comb::TruthTableOp>(op))
780-
currentSize = 3;
783+
if (auto ttOp = dyn_cast<comb::TruthTableOp>(op))
784+
currentSize = ttOp.getInputs().size();
781785
else
782786
currentSize = op->getResults().getTypes()[0].getIntOrFloatBitWidth();
783787

lib/Pipelines/BooleanPipelineRegistration.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,12 +124,15 @@ void mlirToCGGIPipeline(OpPassManager& pm,
124124
pm.addPass(secret::createSecretDistributeGeneric());
125125
pm.addPass(createCanonicalizerPass());
126126
pm.addPass(createSecretToCGGI());
127+
127128
break;
128129
case Integer:
129130
pm.addPass(arith::createArithToCGGI());
130131
break;
131132
}
132133
// Cleanup SecretToCGGI
134+
pm.addPass(createRemoveDeadValuesPass());
135+
pm.addPass(createSymbolDCEPass());
133136
pm.addPass(createCanonicalizerPass());
134137
pm.addPass(createLinalgCanonicalizations());
135138
pm.addPass(createForwardInsertToExtract());

0 commit comments

Comments
 (0)