Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Encrypted matrix multiplication with plain vector #137

Merged
merged 3 commits into from
Aug 15, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions tenseal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

# utils
im2col_encoding = _ts_cpp.im2col_encoding
enc_matmul_encoding = _ts_cpp.enc_matmul_encoding


def context(
Expand Down
23 changes: 23 additions & 0 deletions tenseal/binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,27 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
return make_pair(ckks_vector, windows_nb);
});

m.def("enc_matmul_encoding", [](shared_ptr<TenSEALContext> ctx,
const vector<vector<double>> &input) {
vector<double> final_vector;
vector<vector<double>> padded_matrix;
padded_matrix.reserve(input.size());
// calculate the next power of 2
size_t plain_vec_size =
1 << (static_cast<size_t>(ceil(log2(input[0].size()))));

for (size_t i = 0; i < input.size(); i++) {
// pad the row to the next power of 2
vector<double> row(plain_vec_size, 0);
copy(input[i].begin(), input[i].end(), row.begin());
padded_matrix.push_back(row);
}

vertical_scan(padded_matrix, final_vector);
CKKSVector ckks_vector = CKKSVector(ctx, final_vector);
return ckks_vector;
});

py::class_<CKKSVector>(m, "CKKSVector")
// specifying scale
.def(py::init<shared_ptr<TenSEALContext> &, vector<double>, double>())
Expand Down Expand Up @@ -159,6 +180,8 @@ PYBIND11_MODULE(_tenseal_cpp, m) {
py::arg("n_jobs") = 0)
.def("conv2d_im2col", &CKKSVector::conv2d_im2col)
.def("conv2d_im2col_inplace", &CKKSVector::conv2d_im2col_inplace)
.def("enc_matmul_plain", &CKKSVector::enc_matmul_plain)
.def("enc_matmul_plain_inplace", &CKKSVector::enc_matmul_plain)
// python arithmetic
.def("__neg__", &CKKSVector::negate)
.def("__pow__", &CKKSVector::power)
Expand Down
46 changes: 32 additions & 14 deletions tenseal/cpp/tensors/ckksvector.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -587,34 +587,52 @@ CKKSVector& CKKSVector::conv2d_im2col_inplace(
vector<double> flatten_kernel;
horizontal_scan(kernel, flatten_kernel);

this->enc_matmul_plain_inplace(flatten_kernel, windows_nb);
return *this;
}

CKKSVector CKKSVector::enc_matmul_plain(const vector<double>& plain_vec,
const size_t rows_nb) {
CKKSVector new_vec = *this;
new_vec.enc_matmul_plain_inplace(plain_vec, rows_nb);
return new_vec;
}

CKKSVector& CKKSVector::enc_matmul_plain_inplace(
const vector<double>& plain_vec, const size_t rows_nb) {
if (plain_vec.empty()) {
throw invalid_argument("Plain vector can't be empty");
}

// calculate the next power of 2
size_t kernel_size = kernel.size() * kernel[0].size();
kernel_size = 1 << (static_cast<size_t>(ceil(log2(kernel_size))));
size_t plain_vec_size =
1 << (static_cast<size_t>(ceil(log2(plain_vec.size()))));

// pad the kernel with zeros to the next power of 2
flatten_kernel.resize(kernel_size, 0);
// pad the vector with zeros to the next power of 2
vector<double> padded_plain_vec(plain_vec);
padded_plain_vec.resize(plain_vec_size, 0);

size_t chunks_nb = flatten_kernel.size();
size_t chunks_nb = padded_plain_vec.size();

if (this->_size / windows_nb != chunks_nb) {
if (this->_size / rows_nb != chunks_nb) {
throw invalid_argument("Matrix shape doesn't match with vector size");
}

vector<double> plain_vec;
plain_vec.reserve(this->_size);
vector<double> new_plain_vec;
new_plain_vec.reserve(this->_size);

for (size_t i = 0; i < chunks_nb; i++) {
vector<double> tmp(windows_nb, flatten_kernel[i]);
plain_vec.insert(plain_vec.end(), tmp.begin(), tmp.end());
vector<double> tmp(rows_nb, padded_plain_vec[i]);
new_plain_vec.insert(new_plain_vec.end(), tmp.begin(), tmp.end());
}

// replicate the vector in order to be able to do multiple matrix
// multiplications
size_t slot_count = this->context->slot_count<CKKSEncoder>();
replicate_vector(plain_vec, slot_count);
replicate_vector(new_plain_vec, slot_count);
this->_size = slot_count;

this->mul_plain_inplace(plain_vec);
this->mul_plain_inplace(new_plain_vec);

auto galois_keys = this->context->galois_keys();

Expand All @@ -625,12 +643,12 @@ CKKSVector& CKKSVector::conv2d_im2col_inplace(
chunks_nb = static_cast<int>(
1 << (static_cast<size_t>(ceil(log2(chunks_nb))) - 1));
this->context->evaluator->rotate_vector_inplace(
tmp.ciphertext, static_cast<int>(windows_nb * chunks_nb),
tmp.ciphertext, static_cast<int>(rows_nb * chunks_nb),
*galois_keys);
this->add_inplace(tmp);
}

this->_size = windows_nb;
this->_size = rows_nb;

return *this;
}
Expand Down
10 changes: 9 additions & 1 deletion tenseal/cpp/tensors/ckksvector.h
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,21 @@ class CKKSVector {
CKKSVector& sum_inplace();

/**
* Matrix multiplication operations.
* Encrypted Vector multiplication with plain matrix.
**/
CKKSVector matmul_plain(const vector<vector<double>>& matrix,
size_t n_jobs = 0);
CKKSVector& matmul_plain_inplace(const vector<vector<double>>& matrix,
size_t n_jobs = 0);

/**
* Encrypted Matrix multiplication with plain vector.
**/
CKKSVector enc_matmul_plain(const vector<double>& plain_vec,
size_t row_size);
CKKSVector& enc_matmul_plain_inplace(const vector<double>& plain_vec,
size_t row_size);

/**
* Polynomial evaluation with `this` as variable.
* p(x) = coefficients[0] + coefficients[1] * x + ... + coefficients[i] *
Expand Down
23 changes: 23 additions & 0 deletions tests/python/tenseal/tensors/test_ckks_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,29 @@ def test_vec_plain_matrix_mul_depth2(context, vec, matrix1, matrix2, precision):
), "Matrix multiplication is incorrect."


@pytest.mark.parametrize(
"matrix_shape, vector_size",
[((1, 1), 1), ((2, 1), 1), ((3, 2), 2), ((4, 4), 4), ((9, 7), 7), ((16, 12), 12),],
)
def test_enc_matmul_plain(context, matrix_shape, vector_size, precision):
def generate_input(matrix_shape, vector_size):
# generated random values
matrix = np.random.randn(*matrix_shape)
vector = np.random.randn(vector_size)

return matrix, vector

matrix, vector = generate_input(matrix_shape, vector_size)
expected = matrix @ vector

context.generate_galois_keys()
ckks_vector = ts.enc_matmul_encoding(context, matrix.tolist())
result = ckks_vector.enc_matmul_plain(vector.tolist(), matrix_shape[0])
assert _almost_equal(
result.decrypt(), expected, precision
), "Matrix multiplication is incorrect."


@pytest.mark.parametrize(
"data, polynom",
[
Expand Down