@@ -172,6 +172,53 @@ struct ConvertLut3Op : public OpConversionPattern<cggi::Lut3Op> {
172172 }
173173};
174174
175+ // / Convert a Lut4Op to:
176+ // / - generate_lookup_table
177+ // / - scalar_left_shift
178+ // / - add_op
179+ // / - apply_lookup_table
180+ // /
181+ // / Note the generated lookup tables are not uniqued across applications of this
182+ // / pattern, so a separate step is required at the end to collect all the
183+ // / identical lookup tables, and this can be done with a --cse pass.
184+ struct ConvertLut4Op : public OpConversionPattern <cggi::Lut4Op> {
185+ ConvertLut4Op (mlir::MLIRContext* context)
186+ : OpConversionPattern<cggi::Lut4Op>(context) {}
187+
188+ using OpConversionPattern::OpConversionPattern;
189+
190+ LogicalResult matchAndRewrite (
191+ cggi::Lut4Op op, OpAdaptor adaptor,
192+ ConversionPatternRewriter& rewriter) const override {
193+ ImplicitLocOpBuilder b (op.getLoc (), rewriter);
194+ FailureOr<Value> result = getContextualServerKey (op.getOperation ());
195+ if (failed (result)) return result;
196+
197+ Value serverKey = result.value ();
198+ // A followup -cse pass should combine repeated LUT generation ops.
199+ auto lut = tfhe_rust::GenerateLookupTableOp::create (
200+ b, serverKey, adaptor.getLookupTable ());
201+ // Construct input = d << 3 + c << 2 + b << 1 + a
202+ auto shiftedD = tfhe_rust::ScalarLeftShiftOp::create (
203+ b, serverKey, adaptor.getD (), b.getIndexAttr (3 ));
204+ auto shiftedC = tfhe_rust::ScalarLeftShiftOp::create (
205+ b, serverKey, adaptor.getC (), b.getIndexAttr (2 ));
206+ auto shiftedB = tfhe_rust::ScalarLeftShiftOp::create (
207+ b, serverKey, adaptor.getB (), b.getIndexAttr (1 ));
208+
209+ auto summedCD = tfhe_rust::AddOp::create (b, adaptor.getB ().getType (),
210+ serverKey, shiftedC, shiftedD);
211+ auto summedBCD = tfhe_rust::AddOp::create (b, adaptor.getB ().getType (),
212+ serverKey, shiftedB, summedCD);
213+ auto summedABCD = tfhe_rust::AddOp::create (
214+ b, adaptor.getB ().getType (), serverKey, summedBCD, adaptor.getA ());
215+
216+ rewriter.replaceOp (op, tfhe_rust::ApplyLookupTableOp::create (
217+ b, serverKey, summedABCD, lut));
218+ return success ();
219+ }
220+ };
221+
175222struct ConvertLut2Op : public OpConversionPattern <cggi::Lut2Op> {
176223 ConvertLut2Op (mlir::MLIRContext* context)
177224 : OpConversionPattern<cggi::Lut2Op>(context) {}
@@ -619,9 +666,9 @@ class CGGIToTfheRust : public impl::CGGIToTfheRustBase<CGGIToTfheRust> {
619666
620667 patterns.add <
621668 AddServerKeyArg, AddServerKeyArgCall, ConvertEncodeOp, ConvertLut2Op,
622- ConvertLut3Op, ConvertNotOp, ConvertTrivialEncryptOp, ConvertTrivialOp ,
623- ConvertCGGITRBinOp<lwe::AddOp, tfhe_rust::AddOp>, ConvertScalarMulOp ,
624- ConvertCGGITRBinOp<cggi::AddOp, tfhe_rust::AddOp>,
669+ ConvertLut3Op, ConvertLut4Op, ConvertNotOp, ConvertTrivialEncryptOp ,
670+ ConvertTrivialOp, ConvertCGGITRBinOp<lwe::AddOp, tfhe_rust::AddOp>,
671+ ConvertScalarMulOp, ConvertCGGITRBinOp<cggi::AddOp, tfhe_rust::AddOp>,
625672 ConvertCGGITRBinOp<cggi::MulOp, tfhe_rust::MulOp>,
626673 ConvertCGGITRBinOp<cggi::SubOp, tfhe_rust::SubOp>,
627674 ConvertCGGITRBinOp<cggi::SubOp, tfhe_rust::SubOp>,
0 commit comments