Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

CKKS Tensor ops #173

Merged
merged 10 commits into from
Dec 3, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 14 additions & 14 deletions tenseal/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -272,15 +272,9 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
// because dot doesn't have a magic function like __add__
// we prefer to overload it instead of having dot_plain functions
.def("dot", &CKKSVector::dot_product)
.def("dot",
[](shared_ptr<CKKSVector> obj, const vector<double> &other) {
return obj->dot_product_plain(other);
})
.def("dot", &CKKSVector::dot_product_plain)
.def("dot_", &CKKSVector::dot_product_inplace)
.def("dot_",
[](shared_ptr<CKKSVector> obj, const vector<double> &other) {
return obj->dot_product_plain_inplace(other);
})
.def("dot_", &CKKSVector::dot_product_plain_inplace)
.def("sum", &CKKSVector::sum, py::arg("axis") = 0)
.def("sum_", &CKKSVector::sum_inplace, py::arg("axis") = 0)
.def(
Expand Down Expand Up @@ -452,6 +446,12 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
[](const shared_ptr<TenSEALContext> &ctx, const std::string &data) {
return CKKSTensor::Create(ctx, data);
}))
.def("decrypt",
[](shared_ptr<CKKSTensor> obj) { return obj->decrypt(); })
.def("decrypt",
[](shared_ptr<CKKSTensor> obj, const shared_ptr<SecretKey> &sk) {
return obj->decrypt(sk);
})
.def("sum", &CKKSTensor::sum, py::arg("axis") = 0)
.def("sum_", &CKKSTensor::sum_inplace, py::arg("axis") = 0)
.def("sum_batch", &CKKSTensor::sum_batch)
Expand All @@ -460,12 +460,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("neg_", &CKKSTensor::negate_inplace)
.def("square", &CKKSTensor::square)
.def("square_", &CKKSTensor::square_inplace)
.def("decrypt",
[](shared_ptr<CKKSTensor> obj) { return obj->decrypt(); })
.def("decrypt",
[](shared_ptr<CKKSTensor> obj, const shared_ptr<SecretKey> &sk) {
return obj->decrypt(sk);
})
.def("pow", &CKKSTensor::power)
.def("pow_", &CKKSTensor::power_inplace)
.def("add", &CKKSTensor::add)
.def("add_", &CKKSTensor::add_inplace)
.def("sub", &CKKSTensor::sub)
Expand Down Expand Up @@ -496,6 +492,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
py::overload_cast<const double &>(&CKKSTensor::mul_plain_inplace))
.def("mul_plain_", py::overload_cast<const PlainTensor<double> &>(
&CKKSTensor::mul_plain_inplace))
.def("polyval", &CKKSTensor::polyval)
.def("polyval_", &CKKSTensor::polyval_inplace)
// python arithmetic
.def("__add__", &CKKSTensor::add)
.def("__add__", py::overload_cast<const double &>(
Expand Down Expand Up @@ -558,6 +556,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
.def("__isub__", py::overload_cast<const PlainTensor<double> &>(
&CKKSTensor::sub_plain_inplace))
.def("__neg__", &CKKSTensor::negate)
.def("__pow__", &CKKSTensor::power)
.def("__ipow__", &CKKSTensor::power_inplace)
.def("context",
[](shared_ptr<CKKSTensor> obj) { return obj->tenseal_context(); })
.def("serialize",
Expand Down
108 changes: 95 additions & 13 deletions tenseal/cpp/tensors/ckkstensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,31 @@ shared_ptr<CKKSTensor> CKKSTensor::square_inplace() {
}

shared_ptr<CKKSTensor> CKKSTensor::power_inplace(unsigned int power) {
    // x^0 == 1: replace the tensor with a fresh encryption of all ones,
    // keeping the current scale and batching mode.
    if (power == 0) {
        auto all_ones = PlainTensor<double>::repeat_value(1, this->shape());
        *this = CKKSTensor(this->tenseal_context(), all_ones,
                           this->_init_scale, _batch_size.has_value());
        return shared_from_this();
    }

    // x^1: nothing to compute.
    if (power == 1) {
        return shared_from_this();
    }

    // x^2: a single homomorphic squaring.
    if (power == 2) {
        this->square_inplace();
        return shared_from_this();
    }

    // Square-and-multiply: peel off the largest power of two <= power.
    const int pow2 = 1 << static_cast<int>(floor(log2(power)));
    const unsigned int remainder = power - pow2;
    if (remainder == 0) {
        // power is exactly a power of two: x^power = (x^(power/2))^2.
        this->power_inplace(pow2 / 2)->square_inplace();
    } else {
        // x^power = x^remainder * x^pow2.
        auto pow2_term = this->power(pow2);
        this->power_inplace(remainder)->mul_inplace(pow2_term);
    }

    return shared_from_this();
}

Expand Down Expand Up @@ -160,8 +184,18 @@ void CKKSTensor::perform_plain_op(seal::Ciphertext& ct, seal::Plaintext other,
this->tenseal_context()->evaluator->sub_plain_inplace(ct, other);
break;
case OP::MUL:
this->tenseal_context()->evaluator->multiply_plain_inplace(ct,
other);
try {
this->tenseal_context()->evaluator->multiply_plain_inplace(
ct, other);
} catch (const std::logic_error& e) {
if (strcmp(e.what(), "result ciphertext is transparent") == 0) {
// replace by encryption of zero
this->tenseal_context()->encryptor->encrypt_zero(ct);
ct.scale() = this->_init_scale;
} else { // Something else, need to be forwarded
throw;
}
}
this->auto_relin(ct);
this->auto_rescale(ct);
break;
Expand Down Expand Up @@ -199,16 +233,16 @@ shared_ptr<CKKSTensor> CKKSTensor::op_inplace(
std::min((i + 1) * batch_size, this->_data.size())));
}
// waiting
std::optional<std::exception> fail;
optional<string> fail;
for (size_t i = 0; i < futures.size(); i++) {
try {
futures[i].get();
} catch (std::exception& e) {
fail = e;
fail = e.what();
}
}

if (fail) throw invalid_argument(fail.value().what());
if (fail) throw invalid_argument(fail.value());
}

return shared_from_this();
Expand Down Expand Up @@ -248,16 +282,18 @@ shared_ptr<CKKSTensor> CKKSTensor::op_plain_inplace(
std::min((i + 1) * batch_size, this->_data.size())));
}
// waiting
std::optional<std::exception> fail;
optional<string> fail;
for (size_t i = 0; i < futures.size(); i++) {
try {
futures[i].get();
} catch (std::exception& e) {
fail = e;
fail = e.what();
}
}

if (fail) throw invalid_argument(fail.value().what());
if (fail) {
throw invalid_argument(fail.value());
}
}

return shared_from_this();
Expand Down Expand Up @@ -289,16 +325,16 @@ shared_ptr<CKKSTensor> CKKSTensor::op_plain_inplace(const double& operand,
std::min((i + 1) * batch_size, this->_data.size())));
}
// waiting
std::optional<std::exception> fail;
std::optional<std::string> fail;
for (size_t i = 0; i < futures.size(); i++) {
try {
futures[i].get();
} catch (std::exception& e) {
fail = e;
fail = e.what();
}
}

if (fail) throw invalid_argument(fail.value().what());
if (fail) throw invalid_argument(fail.value());
}

return shared_from_this();
Expand Down Expand Up @@ -402,6 +438,7 @@ shared_ptr<CKKSTensor> CKKSTensor::sum_inplace(size_t axis) {
// reinsert the batched axis
new_shape.insert(new_shape.begin(), *_batch_size);
}

_data = new_data;
_shape = new_shape;
return shared_from_this();
Expand All @@ -420,7 +457,52 @@ shared_ptr<CKKSTensor> CKKSTensor::sum_batch_inplace() {

shared_ptr<CKKSTensor> CKKSTensor::polyval_inplace(
    const vector<double>& coefficients) {
    // Evaluate, element-wise and in place, the polynomial
    //   p(x) = coefficients[0] + coefficients[1]*x + ... + coefficients[d]*x^d
    // using precomputed squarings x^(2^i) so only O(log d) ciphertext
    // squarings are needed.
    if (coefficients.size() == 0) {
        throw invalid_argument(
            "the coefficients vector need to have at least one element");
    }

    // Strip trailing zero coefficients so we never compute unused powers.
    int degree = static_cast<int>(coefficients.size()) - 1;
    while (degree >= 0) {
        if (coefficients[degree] == 0.0)
            degree--;
        else
            break;
    }

    if (degree == -1) {
        // Every coefficient is zero: the result is an encryption of zeros.
        auto zeros = PlainTensor<double>::repeat_value(0, this->shape());
        *this = CKKSTensor(this->tenseal_context(), zeros, this->_init_scale,
                           _batch_size.has_value());
        return shared_from_this();
    }

    if (degree == 0) {
        // Constant polynomial: return an encryption of coefficients[0]
        // directly. This also avoids evaluating log2(0) below, which is
        // undefined (-inf) and would corrupt max_square.
        auto constant =
            PlainTensor<double>::repeat_value(coefficients[0], this->shape());
        *this = CKKSTensor(this->tenseal_context(), constant,
                           this->_init_scale, _batch_size.has_value());
        return shared_from_this();
    }

    // pre-compute x^(2^i) for i in [0, floor(log2(degree))]
    auto x = this->copy();

    int max_square = static_cast<int>(floor(log2(degree)));
    vector<shared_ptr<CKKSTensor>> x_squares;
    x_squares.reserve(max_square + 1);
    x_squares.push_back(x->copy());  // x
    for (int i = 1; i <= max_square; i++) {
        x->square_inplace();
        x_squares.push_back(x->copy());  // x^(2^i)
    }

    // Start from an encryption of the constant term...
    auto cst_coeff =
        PlainTensor<double>::repeat_value(coefficients[0], this->shape());
    auto result =
        CKKSTensor::Create(this->tenseal_context(), cst_coeff,
                           this->_init_scale, _batch_size.has_value());

    // ...then accumulate coefficients[i] * x^i for each non-zero term.
    for (int i = 1; i <= degree; i++) {
        if (coefficients[i] == 0.0) continue;
        x = compute_polynomial_term(i, coefficients[i], x_squares);
        result->add_inplace(x);
    }

    this->_data = result->data();
    return shared_from_this();
}

Expand Down
9 changes: 9 additions & 0 deletions tenseal/cpp/tensors/plain_tensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,15 @@ class PlainTensor {
_shape = {_data.size()};
}

/**
 * Create a tensor of the given shape where every element equals `value`.
 * @param value the value repeated in every cell of the tensor.
 * @param shape the target shape; the element count is the product of all
 *              dimensions (an empty shape yields a 1-element tensor).
 * @return a PlainTensor of `shape` filled with `value`.
 */
static PlainTensor<plain_t> repeat_value(plain_t value,
                                         const vector<size_t>& shape) {
    // Pass shape by const reference: the previous by-value parameter
    // copied the vector on every call for no benefit.
    size_t size = 1;
    for (auto& dim : shape) size *= dim;

    vector<plain_t> repeated(size, value);
    return PlainTensor<plain_t>(repeated, shape);
}

private:
vector<plain_t> _data;
vector<size_t> _shape;
Expand Down
2 changes: 1 addition & 1 deletion tenseal/deps.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def tenseal_deps():
http_archive(
name = "com_microsoft_seal",
build_file = "//third_party:seal.BUILD",
sha256 = "7751b57c0c66c1e81bb25cdddeaca6340e4475e11ab04faa27f3e0dc7526c236",
sha256 = "79c0e45bf301f4577a7633b14e8b26e37eefc89fd4f6a29d13f87e5f22a372ad",
strip_prefix = "SEAL-3.6.0",
urls = ["https://github.com/microsoft/SEAL/archive/v3.6.0.tar.gz"],
)
Expand Down
39 changes: 39 additions & 0 deletions tests/cpp/tensors/ckkstensor_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ TEST_F(CKKSTensorTest, TestCKKSSumNoBatching) {
decr = l->decrypt();
ASSERT_TRUE(are_close(decr.data(), {6, 15}));

data = PlainTensor(vector<double>({1, 2, 3, 4, 5, 6}), vector<size_t>({6}));
l = CKKSTensor::Create(ctx, data, std::pow(2, 40), false);

l->sum_inplace();
ASSERT_THAT(l->shape(), ElementsAreArray(vector<size_t>({})));
decr = l->decrypt();
ASSERT_TRUE(are_close(decr.data(), {21}));

data = PlainTensor(vector<double>({1, 2, 3, 4, 5, 6, 7, 8}),
vector<size_t>({2, 2, 2}));
l = CKKSTensor::Create(ctx, data, std::pow(2, 40), false);
Expand Down Expand Up @@ -152,6 +160,37 @@ TEST_F(CKKSTensorTest, TestCKKSSumBatching) {
ASSERT_TRUE(are_close(decr.data(), {6, 15}));
}

TEST_F(CKKSTensorTest, TestCKKSPower) {
    auto ctx =
        TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60});
    ASSERT_TRUE(ctx != nullptr);
    ctx->generate_galois_keys();

    auto plain_input =
        PlainTensor(vector<double>({1, 2, 3, 4, 5, 6}), vector<size_t>({2, 3}));
    auto encrypted =
        CKKSTensor::Create(ctx, plain_input, std::pow(2, 40), true);

    // Squaring must preserve the shape and square each element.
    auto squared = encrypted->power(2);
    ASSERT_THAT(squared->shape(), ElementsAreArray({2, 3}));
    auto decrypted = squared->decrypt();
    ASSERT_TRUE(are_close(decrypted.data(), {1, 4, 9, 16, 25, 36}));
}

TEST_F(CKKSTensorTest, TestCKKSTensorPolyval) {
    auto ctx =
        TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60});
    ASSERT_TRUE(ctx != nullptr);

    auto plain_input =
        PlainTensor(vector<double>({1, 2, 3, 4, 5, 6}), vector<size_t>({2, 3}));
    auto encrypted =
        CKKSTensor::Create(ctx, plain_input, std::pow(2, 40), true);

    // Evaluate p(x) = 1 + x + x^2 element-wise; shape is unchanged.
    auto evaluated = encrypted->polyval({1, 1, 1});
    ASSERT_THAT(evaluated->shape(), ElementsAreArray({2, 3}));
    auto decrypted = evaluated->decrypt();
    ASSERT_TRUE(are_close(decrypted.data(), {3, 7, 13, 21, 31, 43}));
}

TEST_F(CKKSTensorTest, TestCreateCKKSFail) {
auto ctx =
TenSEALContext::Create(scheme_type::ckks, 8192, -1, {60, 40, 40, 60});
Expand Down
Loading