Skip to content

Commit

Permalink
Code improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuponitskiy-duality committed Mar 3, 2025
1 parent 28e7a1a commit e54a469
Showing 1 changed file with 53 additions and 56 deletions.
109 changes: 53 additions & 56 deletions src/pke/lib/scheme/ckksrns/ckksrns-fhe.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,19 @@
#include <iostream>
#endif

namespace {
// GetBigModulus() calculates the big modulus as the product of
// the "compositeDegree" number of parameter modulus
double GetBigModulus(const std::shared_ptr<lbcrypto::CryptoParametersCKKSRNS> cryptoParams) {
double qDouble = 1.0;
uint32_t compositeDegree = cryptoParams->GetCompositeDegree();
for (uint32_t j = 0; j < compositeDegree; ++j) {
qDouble *= cryptoParams->GetElementParams()->GetParams()[j]->GetModulus().ConvertToDouble();
}

return qDouble;
}
} // namespace
namespace lbcrypto {

//------------------------------------------------------------------------------
Expand Down Expand Up @@ -167,17 +180,12 @@ void FHECKKSRNS::EvalBootstrapSetup(const CryptoContextImpl<DCRTPoly>& cc, std::
uint32_t compositeDegree = cryptoParams->GetCompositeDegree();

// Extract the modulus prior to bootstrapping
NativeInteger q = cryptoParams->GetElementParams()->GetParams()[0]->GetModulus().ConvertToInt();
double qDouble = q.ConvertToDouble();
for (uint32_t j = 1; j < compositeDegree; ++j) {
NativeInteger qj = cryptoParams->GetElementParams()->GetParams()[j]->GetModulus().ConvertToInt();
qDouble *= qj.ConvertToDouble();
}
double qDouble = GetBigModulus(cryptoParams);

uint128_t factor = ((uint128_t)1 << (static_cast<uint32_t>(std::round(std::log2(qDouble)))));
uint128_t factor = (static_cast<uint128_t>(1) << (static_cast<uint32_t>(std::round(std::log2(qDouble)))));
double pre = (compositeDegree > 1) ? 1.0 : qDouble / factor;
double k = (cryptoParams->GetSecretKeyDist() == SPARSE_TERNARY) ? K_SPARSE : 1.0;
double scaleEnc = (compositeDegree > 1) ? 1.0 / k : pre / k;
double scaleEnc = pre / k;
double scaleDec = (compositeDegree > 1) ? qDouble / cryptoParams->GetScalingFactorReal(0) : 1 / pre;

uint32_t approxModDepth = GetModDepthInternal(cryptoParams->GetSecretKeyDist());
Expand Down Expand Up @@ -299,12 +307,7 @@ void FHECKKSRNS::EvalBootstrapPrecompute(const CryptoContextImpl<DCRTPoly>& cc,
uint32_t compositeDegree = cryptoParams->GetCompositeDegree();

// Extract the modulus prior to bootstrapping
NativeInteger q = cryptoParams->GetElementParams()->GetParams()[0]->GetModulus().ConvertToInt();
double qDouble = q.ConvertToDouble();
for (size_t j = 1; j < compositeDegree; ++j) {
NativeInteger qj = cryptoParams->GetElementParams()->GetParams()[j]->GetModulus().ConvertToInt();
qDouble *= qj.ConvertToDouble();
}
double qDouble = GetBigModulus(cryptoParams);

uint128_t factor = (static_cast<uint128_t>(1) << (static_cast<uint32_t>(std::round(std::log2(qDouble)))));
double pre = qDouble / factor;
Expand Down Expand Up @@ -466,12 +469,7 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
}
auto elementParamsRaisedPtr = std::make_shared<ILDCRTParams<DCRTPoly::Integer>>(M, moduli, roots);

NativeInteger q = elementParamsRaisedPtr->GetParams()[0]->GetModulus().ConvertToInt();
double qDouble = q.ConvertToDouble();
for (uint32_t j = 1; j < compositeDegree; ++j) {
NativeInteger qj = elementParamsRaisedPtr->GetParams()[j]->GetModulus().ConvertToInt();
qDouble *= qj.ConvertToDouble();
}
double qDouble = GetBigModulus(cryptoParams);

const auto p = cryptoParams->GetPlaintextModulus();
double powP = pow(2, p);
Expand Down Expand Up @@ -536,24 +534,24 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
qhat_inv_modqj[j] = qhat_modqj[j].ModInverse(qj[j]);
}

NativeInteger qjProduct =
std::accumulate(qj.begin(), qj.end(), NativeInteger{1}, std::multiplies<NativeInteger>());
uint32_t init_element_index = compositeDegree;
for (size_t i = 0; i < ctxtDCRT.size(); i++) {
std::vector<DCRTPoly> temp(compositeDegree + 1, DCRTPoly(elementParamsRaisedPtr, COEFFICIENT));
std::vector<DCRTPoly> ctxtDCRT_modq(compositeDegree, DCRTPoly(elementParamsRaisedPtr, COEFFICIENT));

ctxtDCRT[i].SetFormat(COEFFICIENT);

for (size_t j = 0; j < ctxtDCRT[i].GetNumOfElements(); j++) {
for (size_t k = 0; k < compositeDegree; k++)
ctxtDCRT_modq[k].SetElementAtIndex(j, ctxtDCRT[i].GetElementAtIndex(j) * qhat_inv_modqj[k]);
}

//=========================================================================================================
temp[0] = ctxtDCRT_modq[0].GetElementAtIndex(0);
for (size_t j = 0; j < elementParamsRaisedPtr->GetParams().size(); j++) {
for (size_t k = 1; k < compositeDegree; k++)
temp[0].SetElementAtIndex(j, temp[0].GetElementAtIndex(j) * qj[k]);
for (auto& el : temp[0].GetAllElements()) {
el *= qjProduct;
}

//=========================================================================================================
for (size_t d = 1; d < compositeDegree; d++) {
temp[init_element_index] = ctxtDCRT_modq[d].GetElementAtIndex(d);

Expand All @@ -562,22 +560,23 @@ Ciphertext<DCRTPoly> FHECKKSRNS::EvalBootstrap(ConstCiphertext<DCRTPoly> ciphert
temp[d].SetElementAtIndex(k, temp[0].GetElementAtIndex(k) * qj[k]);
}
}
//=========================================================================================================
NativeInteger qjProductD{1};
for (size_t k = 0; k < compositeDegree; k++) {
if (k != d)
qjProductD *= qj[k];
}

for (size_t j = compositeDegree; j < elementParamsRaisedPtr->GetParams().size(); j++) {
temp[d].SetElementAtIndex(j, temp[init_element_index].GetElementAtIndex(j) * qj[0]);
for (size_t k = 1; k < compositeDegree; k++) {
if (k != d) {
temp[d].SetElementAtIndex(j, temp[d].GetElementAtIndex(j) * qj[k]);
}
}
auto value = temp[init_element_index].GetElementAtIndex(j) * qjProductD;
temp[d].SetElementAtIndex(j, value);
}
temp[d].SetElementAtIndex(d, temp[init_element_index].GetElementAtIndex(d) * qj[0]);
for (size_t k = 1; k < compositeDegree; k++) {
if (k != d) {
temp[d].SetElementAtIndex(d, temp[d].GetElementAtIndex(d) * qj[k]);
}
//=========================================================================================================
{
auto value = temp[init_element_index].GetElementAtIndex(d) * qjProductD;
temp[d].SetElementAtIndex(d, value);
}

//=========================================================================================================
temp[0] += temp[d];
}

Expand Down Expand Up @@ -2557,34 +2556,19 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
moduli[i] = nativeParams[i]->GetModulus();
}

DCRTPoly::Integer intPowP{static_cast<uint64_t>(std::llround(powP))};
std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);

std::vector<DCRTPoly::Integer> crtPowP;
if (cryptoParams->GetScalingTechnique() == COMPOSITESCALINGAUTO ||
cryptoParams->GetScalingTechnique() == COMPOSITESCALINGMANUAL) {
// Duhyeong: Support the case powP > 2^64
// Later we might need to use the NATIVE_INT=128 version of FHECKKSRNS::MakeAuxPlaintext for higher precision
int32_t logPowP = static_cast<int32_t>(ceil(log2(fabs(powP))));
// DCRTPoly::Integer intPowP;
int32_t logApprox_PowP;

if (logPowP > 64) {
// Compute approxFactor, a value to scale down by, in case the value exceeds a 64-bit integer.
logValid = (logPowP <= LargeScalingFactorConstants::MAX_BITS_IN_WORD) ?
logPowP :
LargeScalingFactorConstants::MAX_BITS_IN_WORD;
logApprox_PowP = logPowP - logValid;
approxFactor = pow(2, logApprox_PowP);
// Multiply scFactor in two steps: powP / approxFactor and then approxFactor
intPowP = std::llround(powP / approxFactor);
}
else {
intPowP = std::llround(powP);
}

// std::vector<DCRTPoly::Integer> crtPowP(numTowers, intPowP);
crtPowP.resize(numTowers, intPowP);

if (logPowP > 64) {
int32_t logApprox_PowP = logPowP - logValid;
if (logApprox_PowP > 0) {
int32_t logStep = (logApprox <= LargeScalingFactorConstants::MAX_LOG_STEP) ?
logApprox_PowP :
Expand All @@ -2603,7 +2587,20 @@ Plaintext FHECKKSRNS::MakeAuxPlaintext(const CryptoContextImpl<DCRTPoly>& cc, co
}
crtPowP = CKKSPackedEncoding::CRTMult(crtPowP, crtApprox, moduli);
}
else {
double approxFactor = pow(2, logApprox_PowP);
DCRTPoly::Integer intPowP{static_cast<uint64_t>(std::llround(powP / approxFactor))};
crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
}
}
else {
DCRTPoly::Integer intPowP{static_cast<uint64_t>(std::llround(powP))};
crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
}
}
else {
DCRTPoly::Integer intPowP{static_cast<uint64_t>(std::llround(powP))};
crtPowP = std::vector<DCRTPoly::Integer>(numTowers, intPowP);
}

auto currPowP = crtPowP;
Expand Down

0 comments on commit e54a469

Please sign in to comment.