Skip to content

Commit 25099b7

Browse files
API changes for functional CKKS bootstrapping and bug fix (#1062)
* rename convert * condense setting the depth for FBT * fix bug in multiprecision sign evaluation using CKKS FBT --------- Co-authored-by: Andreea Alexandru <aalexandru@dualitytech.com>
1 parent 907f5b5 commit 25099b7

7 files changed

Lines changed: 172 additions & 162 deletions

File tree

src/pke/examples/CKKS_FUNCTIONAL_BOOTSTRAPING.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,8 +62,8 @@ The desired number of levels that should remain after the function evaluation (f
6262
setup for the computation is called by `EvalFBTSetup` and the necessary keys are generated by calling `EvalBootstrapKeyGen`.
6363

6464
The RLWE input ciphertext needs to be converted to a CKKS ciphertext in order to commence the functional bootstrapping. This is done
65-
by calling `SchemeletRLWEMP::convert`. Then, `EvalFBT` is called to obtain a CKKS encrypting the coefficients of the function
66-
evaluation output. Finally, to return to the exact RLWE scheme, `SchemeletRLWEMP::convert` should be called again.
65+
by calling `SchemeletRLWEMP::ConvertRLWEToCKKS`. Then, `EvalFBT` is called to obtain a CKKS encrypting the coefficients of the function
66+
evaluation output. Finally, to return to the exact RLWE scheme, `SchemeletRLWEMP::ConvertCKKSToRLWE` should be called.
6767

6868
Internally, `EvalFBT` performs the following steps: modulus raising, coefficient to slots transform (equivalent to homomorphic
6969
encoding), complex exponential evaluation, computing powers of the complex exponential, power series evaluation of the

src/pke/examples/functional-bootstrapping-ckks.cpp

Lines changed: 51 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,13 @@ const BigInteger QBFVINIT(BigInteger(1) << 60);
4545
const BigInteger QBFVINITLARGE(BigInteger(1) << 80);
4646

4747
void ArbitraryLUT(BigInteger QBFVInit, BigInteger PInput, BigInteger POutput, BigInteger Q, BigInteger Bigq,
48-
uint32_t scaleTHI, size_t order, uint32_t numSlots, uint32_t ringDim,
48+
uint64_t scaleTHI, size_t order, uint32_t numSlots, uint32_t ringDim,
4949
std::function<int64_t(int64_t)> func);
5050
void MultiValueBootstrapping(BigInteger QBFVInit, BigInteger PInput, BigInteger POutput, BigInteger Q, BigInteger Bigq,
51-
uint32_t scaleTHI, size_t order, uint32_t numSlots, uint32_t ringDim,
51+
uint64_t scaleTHI, size_t order, uint32_t numSlots, uint32_t ringDim,
5252
uint32_t levelComputation);
5353
void MultiPrecisionSign(BigInteger QBFVInit, BigInteger PInput, BigInteger PDigit, BigInteger Q, BigInteger Bigq,
54-
uint32_t scaleTHI, uint32_t scaleStepTHI, size_t order, uint32_t numSlots, uint32_t ringDim);
54+
uint64_t scaleTHI, uint64_t scaleStepTHI, size_t order, uint32_t numSlots, uint32_t ringDim);
5555

5656
int main() {
5757
std::cerr << "\n*1.* Compute the function (x % PInput - POutput / 2) % POutput." << std::endl << std::endl;
@@ -97,7 +97,7 @@ int main() {
9797
}
9898

9999
void ArbitraryLUT(BigInteger QBFVInit, BigInteger PInput, BigInteger POutput, BigInteger Q, BigInteger Bigq,
100-
uint32_t scaleTHI, size_t order, uint32_t numSlots, uint32_t ringDim,
100+
uint64_t scaleTHI, size_t order, uint32_t numSlots, uint32_t ringDim,
101101
std::function<int64_t(int64_t)> func) {
102102
/* 1. Figure out whether sparse packing or full packing should be used.
103103
* numSlots represents the number of values to be encrypted in BFV.
@@ -163,12 +163,12 @@ void ArbitraryLUT(BigInteger QBFVInit, BigInteger PInput, BigInteger POutput, Bi
163163
parameters.SetNumLargeDigits(dnum);
164164
parameters.SetBatchSize(numSlotsCKKS);
165165
parameters.SetRingDim(ringDim);
166-
uint32_t depth = levelsAvailableAfterBootstrap + lvlb[0] + lvlb[1] + 2;
166+
uint32_t depth = levelsAvailableAfterBootstrap;
167167

168168
if (binaryLUT)
169-
depth += FHECKKSRNS::AdjustDepthFBT(coeffint, PInput, order, secretKeyDist);
169+
depth += FHECKKSRNS::GetFBTDepth(lvlb, coeffint, PInput, order, secretKeyDist);
170170
else
171-
depth += FHECKKSRNS::AdjustDepthFBT(coeffcomp, PInput, order, secretKeyDist);
171+
depth += FHECKKSRNS::GetFBTDepth(lvlb, coeffcomp, PInput, order, secretKeyDist);
172172

173173
parameters.SetMultiplicativeDepth(depth);
174174

@@ -209,8 +209,8 @@ void ArbitraryLUT(BigInteger QBFVInit, BigInteger PInput, BigInteger POutput, Bi
209209

210210
/* 7. Convert from the RLWE ciphertext to a CKKS ciphertext (both use the same secret key).
211211
*/
212-
auto ctxt = SchemeletRLWEMP::convert(*cc, ctxtBFV, keyPair.publicKey, Bigq, numSlotsCKKS,
213-
depth - (levelsAvailableBeforeBootstrap > 0));
212+
auto ctxt = SchemeletRLWEMP::ConvertRLWEToCKKS(*cc, ctxtBFV, keyPair.publicKey, Bigq, numSlotsCKKS,
213+
depth - (levelsAvailableBeforeBootstrap > 0));
214214

215215
/* 8. Apply the LUT over the ciphertext.
216216
*/
@@ -222,7 +222,7 @@ void ArbitraryLUT(BigInteger QBFVInit, BigInteger PInput, BigInteger POutput, Bi
222222

223223
/* 9. Convert the result back to RLWE.
224224
*/
225-
auto polys = SchemeletRLWEMP::convert(ctxtAfterFBT, Q);
225+
auto polys = SchemeletRLWEMP::ConvertCKKSToRLWE(ctxtAfterFBT, Q);
226226

227227
auto computed = SchemeletRLWEMP::DecryptCoeff(polys, Q, POutput, keyPair.secretKey, ep, numSlotsCKKS, numSlots);
228228

@@ -243,7 +243,7 @@ void ArbitraryLUT(BigInteger QBFVInit, BigInteger PInput, BigInteger POutput, Bi
243243
}
244244

245245
void MultiValueBootstrapping(BigInteger QBFVInit, BigInteger PInput, BigInteger POutput, BigInteger Q, BigInteger Bigq,
246-
uint32_t scaleTHI, size_t order, uint32_t numSlots, uint32_t ringDim,
246+
uint64_t scaleTHI, size_t order, uint32_t numSlots, uint32_t ringDim,
247247
uint32_t levelsComputation) {
248248
/* 1. Figure out whether sparse packing or full packing should be used.
249249
* numSlots represents the number of values to be encrypted in BFV.
@@ -321,12 +321,12 @@ void MultiValueBootstrapping(BigInteger QBFVInit, BigInteger PInput, BigInteger
321321
parameters.SetNumLargeDigits(dnum);
322322
parameters.SetBatchSize(numSlotsCKKS);
323323
parameters.SetRingDim(ringDim);
324-
uint32_t depth = levelsAvailableAfterBootstrap + lvlb[0] + lvlb[1] + 2 + levelsComputation;
324+
uint32_t depth = levelsAvailableAfterBootstrap + levelsComputation;
325325

326326
if (binaryLUT)
327-
depth += FHECKKSRNS::AdjustDepthFBT(coeffint1, PInput, order, secretKeyDist);
327+
depth += FHECKKSRNS::GetFBTDepth(lvlb, coeffint1, PInput, order, secretKeyDist);
328328
else
329-
depth += FHECKKSRNS::AdjustDepthFBT(coeffcomp1, PInput, order, secretKeyDist);
329+
depth += FHECKKSRNS::GetFBTDepth(lvlb, coeffcomp1, PInput, order, secretKeyDist);
330330

331331
parameters.SetMultiplicativeDepth(depth);
332332

@@ -384,8 +384,8 @@ void MultiValueBootstrapping(BigInteger QBFVInit, BigInteger PInput, BigInteger
384384

385385
/* 9. Convert from the RLWE ciphertext to a CKKS ciphertext (both use the same secret key).
386386
*/
387-
auto ctxt = SchemeletRLWEMP::convert(*cc, ctxtBFV, keyPair.publicKey, Bigq, numSlotsCKKS,
388-
depth - (levelsAvailableBeforeBootstrap > 0));
387+
auto ctxt = SchemeletRLWEMP::ConvertRLWEToCKKS(*cc, ctxtBFV, keyPair.publicKey, Bigq, numSlotsCKKS,
388+
depth - (levelsAvailableBeforeBootstrap > 0));
389389

390390
/* 10. Apply the LUTs over the ciphertext.
391391
* First, compute the complex exponential and its powers to reuse.
@@ -448,7 +448,7 @@ void MultiValueBootstrapping(BigInteger QBFVInit, BigInteger PInput, BigInteger
448448
ctxtAfterFBT2 = cc->EvalHomDecoding(ctxtAfterFBT2, scaleTHI, levelsComputation - 1);
449449
}
450450

451-
auto polys = SchemeletRLWEMP::convert(ctxtAfterFBT1, Q);
451+
auto polys = SchemeletRLWEMP::ConvertCKKSToRLWE(ctxtAfterFBT1, Q);
452452

453453
/* 11. Convert the results back to RLWE.
454454
*/
@@ -465,7 +465,7 @@ void MultiValueBootstrapping(BigInteger QBFVInit, BigInteger PInput, BigInteger
465465
auto max_error_it = std::max_element(exact.begin(), exact.end());
466466
std::cerr << "Max absolute error obtained in the first LUT: " << *max_error_it << std::endl << std::endl;
467467

468-
polys = SchemeletRLWEMP::convert(ctxtAfterFBT2, Q);
468+
polys = SchemeletRLWEMP::ConvertCKKSToRLWE(ctxtAfterFBT2, Q);
469469

470470
computed = SchemeletRLWEMP::DecryptCoeff(polys, Q, POutput, keyPair.secretKey, ep, numSlotsCKKS, numSlots, flagBR);
471471

@@ -481,7 +481,7 @@ void MultiValueBootstrapping(BigInteger QBFVInit, BigInteger PInput, BigInteger
481481
}
482482

483483
void MultiPrecisionSign(BigInteger QBFVInit, BigInteger PInput, BigInteger PDigit, BigInteger Q, BigInteger Bigq,
484-
uint32_t scaleTHI, uint32_t scaleStepTHI, size_t order, uint32_t numSlots, uint32_t ringDim) {
484+
uint64_t scaleTHI, uint64_t scaleStepTHI, size_t order, uint32_t numSlots, uint32_t ringDim) {
485485
/* 1. Figure out whether sparse packing or full packing should be used.
486486
* numSlots represents the number of values to be encrypted in BFV.
487487
* If this number is the same as the ring dimension, then the CKKS slots is half.
@@ -567,12 +567,12 @@ void MultiPrecisionSign(BigInteger QBFVInit, BigInteger PInput, BigInteger PDigi
567567
parameters.SetBatchSize(numSlotsCKKS);
568568
parameters.SetRingDim(ringDim);
569569

570-
uint32_t depth = levelsAvailableAfterBootstrap + lvlb[0] + lvlb[1] + 2;
570+
uint32_t depth = levelsAvailableAfterBootstrap;
571571

572572
if (binaryLUT)
573-
depth += FHECKKSRNS::AdjustDepthFBT(coeffintMod, PDigit, order, secretKeyDist);
573+
depth += FHECKKSRNS::GetFBTDepth(lvlb, coeffintMod, PDigit, order, secretKeyDist);
574574
else
575-
depth += FHECKKSRNS::AdjustDepthFBT(coeffcompMod, PDigit, order, secretKeyDist);
575+
depth += FHECKKSRNS::GetFBTDepth(lvlb, coeffcompMod, PDigit, order, secretKeyDist);
576576

577577
parameters.SetMultiplicativeDepth(depth);
578578

@@ -611,23 +611,27 @@ void MultiPrecisionSign(BigInteger QBFVInit, BigInteger PInput, BigInteger PDigi
611611
auto ctxtBFV = SchemeletRLWEMP::EncryptCoeff(x, QBFVInit, PInput, keyPair.secretKey, ep);
612612

613613
SchemeletRLWEMP::ModSwitch(ctxtBFV, Q, QBFVInit);
614+
uint32_t QBFVBits = Q.GetMSB() - 1;
614615

615616
/* 8. Set up the sign loop parameters. */
616-
double QBFVDouble = Q.ConvertToDouble();
617-
double pBFVDouble = PInput.ConvertToDouble();
618-
double pDigitDouble = PDigit.ConvertToDouble();
619-
double qDigitDouble = Bigq.ConvertToDouble();
620-
BigInteger pOrig = PInput;
621617
std::vector<int64_t> coeffint;
622618
std::vector<std::complex<double>> coeffcomp;
623619
if (binaryLUT)
624620
coeffint = coeffintMod;
625621
else
626622
coeffcomp = coeffcompMod;
627623

628-
bool step = false;
629-
bool go = QBFVDouble > qDigitDouble;
630-
size_t levelsToDrop = 0;
624+
const bool checkeq2 = PDigit.ConvertToInt() == 2;
625+
const bool checkgt2 = PDigit.ConvertToInt() > 2;
626+
const uint32_t pDigitBits = PDigit.GetMSB() - 1;
627+
628+
BigInteger QNew;
629+
BigInteger pOrig = PInput;
630+
631+
bool step = false;
632+
bool go = QBFVBits > dcrtBits;
633+
size_t levelsToDrop = 0;
634+
uint32_t postScalingBits = 0;
631635

632636
/* 9. Start the sign loop. For arbitrary digit size, pNew > 2, the last iteration needs
633637
* to evaluate step pNew not mod pNew.
@@ -640,50 +644,48 @@ void MultiPrecisionSign(BigInteger QBFVInit, BigInteger PInput, BigInteger PDigi
640644
encryptedDigit[0].SwitchModulus(Bigq, 1, 0, 0);
641645
encryptedDigit[1].SwitchModulus(Bigq, 1, 0, 0);
642646

643-
auto ctxt = SchemeletRLWEMP::convert(*cc, encryptedDigit, keyPair.publicKey, Bigq, numSlotsCKKS,
644-
depth - (levelsAvailableBeforeBootstrap > 0));
647+
auto ctxt = SchemeletRLWEMP::ConvertRLWEToCKKS(*cc, encryptedDigit, keyPair.publicKey, Bigq, numSlotsCKKS,
648+
depth - (levelsAvailableBeforeBootstrap > 0));
645649

646650
/* 9.2 Bootstrap the digit.*/
647651
Ciphertext<DCRTPoly> ctxtAfterFBT;
648652
if (binaryLUT)
649-
ctxtAfterFBT = cc->EvalFBT(ctxt, coeffint, PDigit.GetMSB() - 1, ep->GetModulus(),
650-
pOrig.ConvertToDouble() / pBFVDouble * scaleTHI, levelsToDrop, order);
653+
ctxtAfterFBT = cc->EvalFBT(ctxt, coeffint, pDigitBits, ep->GetModulus(), scaleTHI * (1 << postScalingBits),
654+
levelsToDrop, order);
651655
else
652-
ctxtAfterFBT = cc->EvalFBT(ctxt, coeffcomp, PDigit.GetMSB() - 1, ep->GetModulus(),
653-
pOrig.ConvertToDouble() / pBFVDouble * scaleTHI, levelsToDrop, order);
656+
ctxtAfterFBT = cc->EvalFBT(ctxt, coeffcomp, pDigitBits, ep->GetModulus(), scaleTHI * (1 << postScalingBits),
657+
levelsToDrop, order);
654658

655659
/* 9.3 Convert the result back to RLWE and update the
656660
* plaintext and ciphertext modulus of the ciphertext for the next iteration.
657661
*/
658-
auto polys = SchemeletRLWEMP::convert(ctxtAfterFBT, Q);
659-
660-
BigInteger QNew(BigInteger(1) << static_cast<uint32_t>(std::log2(QBFVDouble) - std::log2(pDigitDouble)));
661-
BigInteger PNew(BigInteger(1) << static_cast<uint32_t>(std::log2(pBFVDouble) - std::log2(pDigitDouble)));
662+
auto polys = SchemeletRLWEMP::ConvertCKKSToRLWE(ctxtAfterFBT, Q);
662663

663664
if (!step) {
664665
/* 9.4 If not in the last iteration, subtract the digit from the ciphertext. */
665666
ctxtBFV[0] = ctxtBFV[0] - polys[0];
666667
ctxtBFV[1] = ctxtBFV[1] - polys[1];
667668

668669
/* 9.5 Do modulus switching from Q to QNew for the RLWE ciphertext. */
670+
QNew = Q >> pDigitBits;
669671
ctxtBFV[0] = ctxtBFV[0].MultiplyAndRound(QNew, Q);
670672
ctxtBFV[0].SwitchModulus(QNew, 1, 0, 0);
671673
ctxtBFV[1] = ctxtBFV[1].MultiplyAndRound(QNew, Q);
672674
ctxtBFV[1].SwitchModulus(QNew, 1, 0, 0);
673-
674-
QBFVDouble /= pDigitDouble;
675-
pBFVDouble /= pDigitDouble;
676-
Q = QNew;
677-
PInput = PNew;
675+
Q >>= pDigitBits;
676+
PInput >>= pDigitBits;
677+
QBFVBits -= pDigitBits;
678+
postScalingBits += pDigitBits;
678679
}
679680
else {
680681
/* 9.6 If in the last iteration, return the digit. */
681-
ctxtBFV[0] = polys[0];
682-
ctxtBFV[1] = polys[1];
682+
ctxtBFV[0] = std::move(polys[0]);
683+
ctxtBFV[1] = std::move(polys[1]);
683684
}
684685

685686
/* 9.7 If in the last iteration, decrypt and assess correctness. */
686-
if ((PDigit.ConvertToInt() == 2 && QBFVDouble <= qDigitDouble) || step) {
687+
go = QBFVBits > dcrtBits;
688+
if (step || (checkeq2 && !go)) {
687689
auto computed =
688690
SchemeletRLWEMP::DecryptCoeff(ctxtBFV, Q, PInput, keyPair.secretKey, ep, numSlotsCKKS, numSlots);
689691

@@ -699,9 +701,7 @@ void MultiPrecisionSign(BigInteger QBFVInit, BigInteger PInput, BigInteger PDigi
699701
}
700702

701703
/* 9.8 Determine whether it is the last iteration and if not, update the parameters for the next iteration. */
702-
go = QBFVDouble > qDigitDouble;
703-
704-
if (PDigit.ConvertToInt() > 2 && !go && !step) {
704+
if (checkgt2 && !go && !step) {
705705
if (!binaryLUT)
706706
coeffcomp = coeffcompStep;
707707
scaleTHI = scaleStepTHI;

src/pke/include/scheme/ckksrns/ckksrns-fhe.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,11 @@ class FHECKKSRNS : public FHERNS {
281281

282282
static uint32_t GetBootstrapDepth(const std::vector<uint32_t>& levelBudget, SecretKeyDist secretKeyDist);
283283

284+
template <typename VectorDataType>
285+
static uint32_t GetFBTDepth(const std::vector<uint32_t>& levelBudget,
286+
const std::vector<VectorDataType>& coefficients, const BigInteger& PInput, size_t order,
287+
SecretKeyDist skd);
288+
284289
template <typename VectorDataType>
285290
static uint32_t AdjustDepthFBT(const std::vector<VectorDataType>& coefficients, const BigInteger& PInput,
286291
size_t order, SecretKeyDist skd = SPARSE_TERNARY);

src/pke/include/schemelet/rlwe-mp.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,11 +58,11 @@ class SchemeletRLWEMP {
5858

5959
static void ModSwitch(std::vector<Poly>& input, const BigInteger& Q1, const BigInteger& Q2);
6060

61-
static Ciphertext<DCRTPoly> convert(const CryptoContextImpl<DCRTPoly>& cc, const std::vector<Poly>& coeffs,
62-
const PublicKey<DCRTPoly>& pubKey, const BigInteger& Bigq, uint32_t slots,
63-
uint32_t level = 0);
61+
static Ciphertext<DCRTPoly> ConvertRLWEToCKKS(const CryptoContextImpl<DCRTPoly>& cc,
62+
const std::vector<Poly>& coeffs, const PublicKey<DCRTPoly>& pubKey,
63+
const BigInteger& Bigq, uint32_t slots, uint32_t level = 0);
6464

65-
static std::vector<Poly> convert(ConstCiphertext<DCRTPoly>& ctxt, const BigInteger& Q);
65+
static std::vector<Poly> ConvertCKKSToRLWE(ConstCiphertext<DCRTPoly>& ctxt, const BigInteger& Q);
6666

6767
static BigInteger GetQPrime(const PublicKey<DCRTPoly>& pubKey, uint32_t lvls);
6868
};

0 commit comments

Comments
 (0)