diff --git a/spqlios/CMakeLists.txt b/spqlios/CMakeLists.txt index 797e3f2..50fbb13 100644 --- a/spqlios/CMakeLists.txt +++ b/spqlios/CMakeLists.txt @@ -4,6 +4,8 @@ enable_language(ASM) set(SRCS_GENERIC commons.c commons_private.c + coeffs/coeffs_arithmetic.c + arithmetic/vec_znx.c arithmetic/vec_znx_dft.c cplx/cplx_common.c cplx/cplx_conversions.c @@ -74,6 +76,8 @@ set_source_files_properties(${SRCS_AVX512} PROPERTIES COMPILE_OPTIONS "-mfma;-ma # C or assembly source files compiled only on x86: avx2 + bmi targets set(SRCS_AVX2 + arithmetic/vec_znx_avx.c + coeffs/coeffs_arithmetic_avx.c arithmetic/vec_znx_dft_avx2.c q120/q120_arithmetic_avx2.c q120/q120_ntt_avx2.c @@ -111,6 +115,7 @@ set(HEADERSPRIVATE q120/q120_arithmetic_private.h q120/q120_ntt_private.h arithmetic/vec_znx_arithmetic.h + coeffs/coeffs_arithmetic.h ) set(SPQLIOSSOURCES diff --git a/spqlios/arithmetic/vec_znx.c b/spqlios/arithmetic/vec_znx.c new file mode 100644 index 0000000..af38265 --- /dev/null +++ b/spqlios/arithmetic/vec_znx.c @@ -0,0 +1,332 @@ +#include +#include +#include +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../q120/q120_arithmetic.h" +#include "../q120/q120_ntt.h" +#include "../reim/reim_fft_internal.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +// general function (virtual dispatch) + +EXPORT void vec_znx_add(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_add(module, // N + res, res_size, res_sl, // res + a, a_size, a_sl, // a + b, b_size, b_sl // b + ); +} + +EXPORT void vec_znx_sub(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + module->func.vec_znx_sub(module, // N + res, res_size, res_sl, // res + a, a_size, a_sl, // a + b, b_size, b_sl // b + ); +} + +EXPORT void vec_znx_rotate(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_rotate(module, // N + p, // p + res, res_size, res_sl, // res + a, a_size, a_sl // a + ); +} + +EXPORT void vec_znx_automorphism(const MODULE* module, // N + const int64_t p, // X->X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_automorphism(module, // N + p, // p + res, res_size, res_sl, // res + a, a_size, a_sl // a + ); +} + +EXPORT void vec_znx_normalize_base2k(const MODULE* module, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + uint8_t* tmp_space // scratch space of size >= N +) { + module->func.vec_znx_normalize_base2k(module, // N + log2_base2k, // log2_base2k + res, res_size, res_sl, // res + a, a_size, a_sl, // a + tmp_space); +} + +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, // N + uint64_t res_size, // res size + uint64_t inp_size // inp size +) { + return module->func.vec_znx_normalize_base2k_tmp_bytes(module, // N + res_size, // res size + inp_size // inp size + ); +} + +// specialized function (ref) + +EXPORT void 
vec_znx_add_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sum_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } else { + const uint64_t sum_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? res_size : a_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_sub_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sub_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then negate to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_negate_i64_ref(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } else { + const uint64_t sub_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? res_size : a_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_rotate_ref(const MODULE* module, // N + const int64_t p, // rotation value + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + + const uint64_t rot_end_idx = res_size < a_size ? 
res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + int64_t* res_ptr = res + i * res_sl; + const int64_t* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + znx_rotate_inplace_i64(nn, p, res_ptr); + } else { + znx_rotate_i64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_automorphism_ref(const MODULE* module, // N + const int64_t p, // X->X^p + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->nn; + + const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size; + + for (uint64_t i = 0; i < auto_end_idx; ++i) { + int64_t* res_ptr = res + i * res_sl; + const int64_t* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + znx_automorphism_inplace_i64(nn, p, res_ptr); + } else { + znx_automorphism_i64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = auto_end_idx; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module, // N + uint64_t log2_base2k, // output base 2^K + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + uint8_t* tmp_space // scratch space of size >= N +) { + const uint64_t nn = module->nn; + + // use MSB limb of res for carry propagation + int64_t* cout = (int64_t*)tmp_space; + int64_t* cin = 0x0; + + // propagate carry until first limb of res + int64_t i = a_size - 1; + for (; i >= res_size; --i) { + znx_normalize(nn, log2_base2k, 0x0, cout, a + i * a_sl, cin); + cin = cout; + } + + // propagate carry and normalize + for (; i >= 1; --i) { + znx_normalize(nn, log2_base2k, res + i * res_sl, cout, a + i * a_sl, cin); + cin = cout; + } + + // normalize last limb + znx_normalize(nn, log2_base2k, res, 0x0, a, cin); + + // extend result with zeros + for (uint64_t i = a_size; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, // N + uint64_t res_size, // res size + uint64_t inp_size // inp size +) { + const uint64_t nn = module->nn; + return nn * sizeof(int64_t); +} + +// alias have to be defined in this unit: do not move +EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( // + const MODULE* module, // N + uint64_t res_size, // res size + uint64_t inp_size // inp size + ) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); + +// alias have to be defined in this unit: do not move +EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( // + const MODULE* module, // N + uint64_t res_size, // res size + uint64_t inp_size // inp size + ) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); + +EXPORT void std_free(void* addr) { free(addr); } + +/** @brief sets res = 0 */ +EXPORT void vec_znx_zero(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +) { + module->func.vec_znx_zero(module, res, res_size, res_sl); +} + +/** @brief sets res = a */ +EXPORT void vec_znx_copy(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_copy(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = a */ +EXPORT void vec_znx_negate(const 
MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->func.vec_znx_negate(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_znx_zero_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl // res +) { + uint64_t nn = module->nn; + for (uint64_t i = 0; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_copy_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} + +EXPORT void vec_znx_negate_ref(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_negate_i64_ref(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} diff --git a/spqlios/arithmetic/vec_znx_avx.c b/spqlios/arithmetic/vec_znx_avx.c new file mode 100644 index 0000000..100902d --- /dev/null +++ b/spqlios/arithmetic/vec_znx_avx.c @@ -0,0 +1,103 @@ +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_znx_arithmetic_private.h" + +// specialized function (ref) + +// Note: these functions do not have an avx variant. +#define znx_copy_i64_avx znx_copy_i64_ref +#define znx_zero_i64_avx znx_zero_i64_ref + +EXPORT void vec_znx_add_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sum_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } else { + const uint64_t sum_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? 
res_size : a_size; + // add up to the smallest dimension + for (uint64_t i = 0; i < sum_idx; ++i) { + znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sum_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_sub_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl, // a + const int64_t* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->nn; + if (a_size <= b_size) { + const uint64_t sub_idx = res_size < a_size ? res_size : a_size; + const uint64_t copy_idx = res_size < b_size ? res_size : b_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then negate to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } else { + const uint64_t sub_idx = res_size < b_size ? res_size : b_size; + const uint64_t copy_idx = res_size < a_size ? res_size : a_size; + // subtract up to the smallest dimension + for (uint64_t i = 0; i < sub_idx; ++i) { + znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + // then copy to the largest dimension + for (uint64_t i = sub_idx; i < copy_idx; ++i) { + znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + // then extend with zeros + for (uint64_t i = copy_idx; i < res_size; ++i) { + znx_zero_i64_avx(nn, res + i * res_sl); + } + } +} + +EXPORT void vec_znx_negate_avx(const MODULE* module, // N + int64_t* res, uint64_t res_size, uint64_t res_sl, // res + const int64_t* a, uint64_t a_size, uint64_t a_sl // a +) { + uint64_t nn = module->nn; + uint64_t smin = res_size < a_size ? 
res_size : a_size; + for (uint64_t i = 0; i < smin; ++i) { + znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = smin; i < res_size; ++i) { + znx_zero_i64_ref(nn, res + i * res_sl); + } +} diff --git a/spqlios/coeffs/coeffs_arithmetic.c b/spqlios/coeffs/coeffs_arithmetic.c new file mode 100644 index 0000000..01d15db --- /dev/null +++ b/spqlios/coeffs/coeffs_arithmetic.c @@ -0,0 +1,461 @@ +#include "coeffs_arithmetic.h" + +#include +#include + +/** res = a + b */ +EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] + b[i]; + } +} +/** res = a - b */ +EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] - b[i]; + } +} + +EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = -a[i]; + } +} +EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { memcpy(res, a, nn * sizeof(int64_t)); } + +EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res) { memset(res, 0, nn * sizeof(int64_t)); } + +EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma]; + } + } +} + +EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma]; + } + } +} + +EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma] - in[j]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + // rotate first half + res[j] = in[j - nma] - in[j]; + } + } +} + +EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + uint64_t a = (-p) & (2 * nn - 1); // a= (-p) (pos)mod (2*nn) + if (a < nn) { // rotate to the left + uint64_t nma = nn - a; + // rotate first half + for (uint64_t j = 0; j < nma; j++) { + res[j] = in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + res[j] = -in[j - nma] - in[j]; + } + } else { + a -= nn; + uint64_t nma = nn - a; + for (uint64_t j = 0; j < nma; j++) { + res[j] = -in[j + a] - in[j]; + } + for (uint64_t j = nma; j < nn; j++) { + // 
rotate first half + res[j] = in[j - nma] - in[j]; + } + } +} + +// 0 < p < 2nn +EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in) { + res[0] = in[0]; + uint64_t a = 0; + uint64_t _2mn = 2 * nn - 1; + for (uint64_t i = 1; i < nn; i++) { + a = (a + p) & _2mn; // i*p mod 2n + if (a < nn) { + res[a] = in[i]; // res[ip mod 2n] = res[i] + } else { + res[a - nn] = -in[i]; + } + } +} + +EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) { + res[0] = in[0]; + uint64_t a = 0; + uint64_t _2mn = 2 * nn - 1; + for (uint64_t i = 1; i < nn; i++) { + a = (a + p) & _2mn; + if (a < nn) { + res[a] = in[i]; // res[ip mod 2n] = res[i] + } else { + res[a - nn] = -in[i]; + } + } +} + +EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + double tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + double tmp2 = res[new_j_n]; + res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. + ++j_start; + } +} + +EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + uint64_t nb_modif = 0; + uint64_t j_start = 0; + while (nb_modif < nn) { + // follow the cycle that start with j_start + uint64_t j = j_start; + int64_t tmp1 = res[j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j + p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + int64_t tmp2 = res[new_j_n]; + res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1; + tmp1 = tmp2; + // move to the new location, and store the number of items modified + ++nb_modif; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator. 
+ ++j_start; + } +} + +__always_inline int64_t get_base_k_digit(const int64_t x, const uint64_t base_k) { + return (x << (64 - base_k)) >> (64 - base_k); +} + +__always_inline int64_t get_base_k_carry(const int64_t x, const int64_t digit, const uint64_t base_k) { + return (x - digit) >> base_k; +} + +EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in, + const int64_t* carry_in) { + assert(in); + if (out != 0) { + if (carry_in != 0x0 && carry_out != 0x0) { + // with carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t carry = get_base_k_carry(x, digit, base_k); + int64_t digit_plus_cin = digit + cin; + int64_t y = get_base_k_digit(digit_plus_cin, base_k); + int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k); + + out[i] = y; + carry_out[i] = cout; + } + } else if (carry_in != 0) { + // with carry in and carry out is dropped + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t digit_plus_cin = digit + cin; + int64_t y = get_base_k_digit(digit_plus_cin, base_k); + + out[i] = y; + } + + } else if (carry_out != 0) { + // no carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + + int64_t y = get_base_k_digit(x, base_k); + int64_t cout = get_base_k_carry(x, y, base_k); + + out[i] = y; + carry_out[i] = cout; + } + + } else { + // no carry in and carry out is dropped + for (uint64_t i = 0; i < nn; ++i) { + out[i] = get_base_k_digit(in[i], base_k); + } + } + } else { + assert(carry_out); + if (carry_in != 0x0) { + // with carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + const int64_t cin = carry_in[i]; + + int64_t digit = get_base_k_digit(x, base_k); + int64_t carry = get_base_k_carry(x, digit, base_k); + int64_t digit_plus_cin = digit + cin; + int64_t y = get_base_k_digit(digit_plus_cin, base_k); + int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k); + + carry_out[i] = cout; + } + } else { + // no carry in and carry out is computed + for (uint64_t i = 0; i < nn; ++i) { + const int64_t x = in[i]; + + int64_t y = get_base_k_digit(x, base_k); + int64_t cout = get_base_k_carry(x, y, base_k); + + carry_out[i] = cout; + } + } + } +} + +void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + const uint64_t m = nn >> 1; + // reduce p mod 2n + p &= _2mn; + // uint64_t vp = p & _2mn; + /// uint64_t target_modifs = m >> 1; + // we proceed by increasing binary valuation + for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn; + binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) { + // In this loop, we are going to treat the orbit of indexes = binval mod 2.binval. + // At the beginning of this loop we have: + // vp = binval * p mod 2n + // target_modif = m / binval (i.e. order of the orbit binval % 2.binval) + + // first, handle the orders 1 and 2. + // if p*binval == binval % 2n: we're done! + if (vp == binval) return; + // if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit! 
+ if (((vp + binval) & _2mn) == 0) { + for (uint64_t j = binval; j < m; j += binval) { + int64_t tmp = res[j]; + res[j] = -res[nn - j]; + res[nn - j] = -tmp; + } + res[m] = -res[m]; + return; + } + // if p*binval == binval + n % 2n: negate the orbit and exit + if (((vp - binval) & _mn) == 0) { + for (uint64_t j = binval; j < nn; j += 2 * binval) { + res[j] = -res[j]; + } + return; + } + // if p*binval == n - binval % 2n: mirror the orbit and continue! + if (((vp + binval) & _mn) == 0) { + for (uint64_t j = binval; j < m; j += 2 * binval) { + int64_t tmp = res[j]; + res[j] = res[nn - j]; + res[nn - j] = tmp; + } + continue; + } + // otherwise we will follow the orbit cycles, + // starting from binval and -binval in parallel + uint64_t j_start = binval; + uint64_t nb_modif = 0; + while (nb_modif < orb_size) { + // follow the cycle that start with j_start + uint64_t j = j_start; + int64_t tmp1 = res[j]; + int64_t tmp2 = res[nn - j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + int64_t tmp1a = res[new_j_n]; + int64_t tmp2a = res[nn - new_j_n]; + if (new_j < nn) { + res[new_j_n] = tmp1; + res[nn - new_j_n] = tmp2; + } else { + res[new_j_n] = -tmp1; + res[nn - new_j_n] = -tmp2; + } + tmp1 = tmp1a; + tmp2 = tmp2a; + // move to the new location, and store the number of items modified + nb_modif += 2; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do *5, because 5 is a generator. + j_start = (5 * j_start) & _mn; + } + } +} + +void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res) { + const uint64_t _2mn = 2 * nn - 1; + const uint64_t _mn = nn - 1; + const uint64_t m = nn >> 1; + // reduce p mod 2n + p &= _2mn; + // uint64_t vp = p & _2mn; + /// uint64_t target_modifs = m >> 1; + // we proceed by increasing binary valuation + for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn; + binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) { + // In this loop, we are going to treat the orbit of indexes = binval mod 2.binval. + // At the beginning of this loop we have: + // vp = binval * p mod 2n + // target_modif = m / binval (i.e. order of the orbit binval % 2.binval) + + // first, handle the orders 1 and 2. + // if p*binval == binval % 2n: we're done! + if (vp == binval) return; + // if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit! + if (((vp + binval) & _2mn) == 0) { + for (uint64_t j = binval; j < m; j += binval) { + double tmp = res[j]; + res[j] = -res[nn - j]; + res[nn - j] = -tmp; + } + res[m] = -res[m]; + return; + } + // if p*binval == binval + n % 2n: negate the orbit and exit + if (((vp - binval) & _mn) == 0) { + for (uint64_t j = binval; j < nn; j += 2 * binval) { + res[j] = -res[j]; + } + return; + } + // if p*binval == n - binval % 2n: mirror the orbit and continue! 
+ if (((vp + binval) & _mn) == 0) { + for (uint64_t j = binval; j < m; j += 2 * binval) { + double tmp = res[j]; + res[j] = res[nn - j]; + res[nn - j] = tmp; + } + continue; + } + // otherwise we will follow the orbit cycles, + // starting from binval and -binval in parallel + uint64_t j_start = binval; + uint64_t nb_modif = 0; + while (nb_modif < orb_size) { + // follow the cycle that start with j_start + uint64_t j = j_start; + double tmp1 = res[j]; + double tmp2 = res[nn - j]; + do { + // find where the value should go, and with which sign + uint64_t new_j = (j * p) & _2mn; // mod 2n to get the position and sign + uint64_t new_j_n = new_j & _mn; // mod n to get just the position + // exchange this position with tmp1 (and take care of the sign) + double tmp1a = res[new_j_n]; + double tmp2a = res[nn - new_j_n]; + if (new_j < nn) { + res[new_j_n] = tmp1; + res[nn - new_j_n] = tmp2; + } else { + res[new_j_n] = -tmp1; + res[nn - new_j_n] = -tmp2; + } + tmp1 = tmp1a; + tmp2 = tmp2a; + // move to the new location, and store the number of items modified + nb_modif += 2; + j = new_j_n; + } while (j != j_start); + // move to the start of the next cycle: + // we need to find an index that has not been touched yet, and pick it as next j_start. + // in practice, it is enough to do *5, because 5 is a generator. + j_start = (5 * j_start) & _mn; + } + } +} diff --git a/spqlios/coeffs/coeffs_arithmetic.h b/spqlios/coeffs/coeffs_arithmetic.h new file mode 100644 index 0000000..73a2b43 --- /dev/null +++ b/spqlios/coeffs/coeffs_arithmetic.h @@ -0,0 +1,73 @@ +#ifndef SPQLIOS_COEFFS_ARITHMETIC_H +#define SPQLIOS_COEFFS_ARITHMETIC_H + +#include "../commons.h" + +/** res = a + b */ +EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +/** res = a - b */ +EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b); +/** res = -a */ +EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a); +EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a); +/** res = a */ +EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a); +/** res = 0 */ +EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res); + +/** + * @param res = X^p *in mod X^nn +1 + * @param nn the ring dimension + * @param p a power for the rotation -2nn <= p <= 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); +EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res); +EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res); + +/** + * @brief res(X) = in(X^p) + * @param nn the ring dimension + * @param p is odd integer and must be between 0 < p < 2nn + * @param in is a rnx/znx vector of dimension nn + */ +EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in); +EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in); +EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res); +EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res); + +/** + * @brief res = (X^p-1).in + * @param nn the ring dimension + * @param p must be between -2nn 
<= p <= 2nn
+ * @param in is a rnx/znx vector of dimension nn
+ */
+EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in);
+EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
+
+/**
+ * @brief Normalize input plus carry mod-2^k. The following
+ * equality holds: @c {in + carry_in == out + carry_out . 2^k}.
+ *
+ * @c in must be in [-2^62 .. 2^62].
+ *
+ * @c out is in [ -2^(base_k-1), 2^(base_k-1) [.
+ *
+ * @c carry_in and @c carry_out have at most 64+1-k bits.
+ *
+ * A null @c carry_in or @c carry_out is ignored.
+ *
+ * @param[in] nn the ring dimension
+ * @param[in] base_k the base k
+ * @param out output normalized znx
+ * @param carry_out output carry znx
+ * @param[in] in input znx
+ * @param[in] carry_in input carry znx
+ */
+EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
+                          const int64_t* carry_in);
+
+#endif  // SPQLIOS_COEFFS_ARITHMETIC_H
diff --git a/spqlios/coeffs/coeffs_arithmetic_avx.c b/spqlios/coeffs/coeffs_arithmetic_avx.c
new file mode 100644
index 0000000..9fea143
--- /dev/null
+++ b/spqlios/coeffs/coeffs_arithmetic_avx.c
@@ -0,0 +1,84 @@
+#include <immintrin.h>
+
+#include "coeffs_arithmetic.h"
+
+// res = a + b. dimension n must be a power of 2
+EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
+  if (nn <= 2) {
+    if (nn == 1) {
+      res[0] = a[0] + b[0];
+    } else {
+      _mm_storeu_si128((__m128i*)res,                    //
+                       _mm_add_epi64(                    //
+                           _mm_loadu_si128((__m128i*)a),  //
+                           _mm_loadu_si128((__m128i*)b)));
+    }
+  } else {
+    const __m256i* aa = (__m256i*)a;
+    const __m256i* bb = (__m256i*)b;
+    __m256i* rr = (__m256i*)res;
+    __m256i* const rrend = (__m256i*)(res + nn);
+    do {
+      _mm256_storeu_si256(rr,                        //
+                          _mm256_add_epi64(          //
+                              _mm256_loadu_si256(aa),  //
+                              _mm256_loadu_si256(bb)));
+      ++rr;
+      ++aa;
+      ++bb;
+    } while (rr < rrend);
+  }
+}
+
+// res = a - b. dimension n must be a power of 2
+EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
+  if (nn <= 2) {
+    if (nn == 1) {
+      res[0] = a[0] - b[0];
+    } else {
+      _mm_storeu_si128((__m128i*)res,                    //
+                       _mm_sub_epi64(                    //
+                           _mm_loadu_si128((__m128i*)a),  //
+                           _mm_loadu_si128((__m128i*)b)));
+    }
+  } else {
+    const __m256i* aa = (__m256i*)a;
+    const __m256i* bb = (__m256i*)b;
+    __m256i* rr = (__m256i*)res;
+    __m256i* const rrend = (__m256i*)(res + nn);
+    do {
+      _mm256_storeu_si256(rr,                        //
+                          _mm256_sub_epi64(          //
+                              _mm256_loadu_si256(aa),  //
+                              _mm256_loadu_si256(bb)));
+      ++rr;
+      ++aa;
+      ++bb;
+    } while (rr < rrend);
+  }
+}
+
+EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) {
+  if (nn <= 2) {
+    if (nn == 1) {
+      res[0] = -a[0];
+    } else {
+      _mm_storeu_si128((__m128i*)res,             //
+                       _mm_sub_epi64(             //
+                           _mm_set1_epi64x(0),    //
+                           _mm_loadu_si128((__m128i*)a)));
+    }
+  } else {
+    const __m256i* aa = (__m256i*)a;
+    __m256i* rr = (__m256i*)res;
+    __m256i* const rrend = (__m256i*)(res + nn);
+    do {
+      _mm256_storeu_si256(rr,                      //
+                          _mm256_sub_epi64(        //
+                              _mm256_set1_epi64x(0),  //
+                              _mm256_loadu_si256(aa)));
+      ++rr;
+      ++aa;
+    } while (rr < rrend);
+  }
+}
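
The notes below are reviewer-added illustrative sketches; they are not part of the patch.

Note on the vec_znx_* entry points: each operand is a vector of `size` polynomials of dimension nn, stored with a limb stride `sl` (in int64 units, sl >= nn). When the operand sizes differ, vec_znx_add_ref adds limbs up to the smaller size, copies the remaining limbs of the larger operand, and zero-pads the tail of the result. A minimal sketch, assuming a MODULE obtained from the library's usual constructor (not shown in this diff) and assuming the header path below:

#include <stdint.h>
#include "spqlios/arithmetic/vec_znx_arithmetic.h"  // header path is an assumption

// res has 4 limbs, a has 2, b has 3 (all contiguous, stride sl = nn):
// limbs 0..1 of res receive a[i] + b[i], limb 2 receives a copy of b[2],
// and limb 3 is zeroed.
static void add_example(const MODULE* module, uint64_t nn,  //
                        int64_t* res, const int64_t* a, const int64_t* b) {
  const uint64_t sl = nn;  // contiguous layout
  vec_znx_add(module,                          //
              res, /*res_size=*/4, /*res_sl=*/sl,  //
              a, /*a_size=*/2, /*a_sl=*/sl,        //
              b, /*b_size=*/3, /*b_sl=*/sl);
}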
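
Note on the negacyclic conventions of znx_rotate_i64 and znx_mul_xp_minus_one: they compute res = X^p * in and res = (X^p - 1) * in modulo X^nn + 1. For nn = 4 and p = 1, X * (c0 + c1 X + c2 X^2 + c3 X^3) = -c3 + c0 X + c1 X^2 + c2 X^3. A small self-contained check using only functions declared in coeffs_arithmetic.h (header path is an assumption):

#include <assert.h>
#include <stdint.h>
#include "spqlios/coeffs/coeffs_arithmetic.h"  // header path is an assumption

static void rotation_convention_check(void) {
  const uint64_t nn = 4;
  const int64_t in[4] = {1, 2, 3, 4};  // 1 + 2X + 3X^2 + 4X^3
  int64_t rot[4], xpm1[4];
  znx_rotate_i64(nn, 1, rot, in);         // X * in mod X^4 + 1
  znx_mul_xp_minus_one(nn, 1, xpm1, in);  // (X - 1) * in mod X^4 + 1
  const int64_t exp_rot[4] = {-4, 1, 2, 3};
  for (uint64_t i = 0; i < nn; ++i) {
    assert(rot[i] == exp_rot[i]);
    assert(xpm1[i] == exp_rot[i] - in[i]);
  }
}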
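
Note on the automorphism convention: znx_automorphism_i64 maps in(X) to in(X^p) for odd p with 0 < p < 2nn. For nn = 4 and p = 3, c0 + c1 X^3 + c2 X^6 + c3 X^9 reduces to c0 + c3 X - c2 X^2 + c1 X^3 since X^4 = -1, and the in-place variant should agree with the out-of-place one. A small check:

#include <assert.h>
#include <stdint.h>
#include <string.h>
#include "spqlios/coeffs/coeffs_arithmetic.h"  // header path is an assumption

static void automorphism_convention_check(void) {
  const uint64_t nn = 4;
  const int64_t in[4] = {1, 2, 3, 4};  // 1 + 2X + 3X^2 + 4X^3
  int64_t out[4], inplace[4];
  znx_automorphism_i64(nn, 3, out, in);  // out(X) = in(X^3)
  const int64_t expected[4] = {1, 4, -3, 2};
  memcpy(inplace, in, sizeof(in));
  znx_automorphism_inplace_i64(nn, 3, inplace);
  for (uint64_t i = 0; i < nn; ++i) {
    assert(out[i] == expected[i]);
    assert(inplace[i] == out[i]);
  }
}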
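
Note on znx_normalize: each coefficient is split into a centered base-2^k digit in [-2^(k-1), 2^(k-1)) and a carry, so that in + carry_in == out + carry_out * 2^k holds coefficient-wise. For example with base_k = 8, the coefficient 1000 = -24 + 4 * 256 decomposes into digit -24 and carry 4. A small check:

#include <assert.h>
#include <stddef.h>
#include <stdint.h>
#include "spqlios/coeffs/coeffs_arithmetic.h"  // header path is an assumption

static void normalize_digit_carry_check(void) {
  const uint64_t nn = 2;
  const uint64_t base_k = 8;
  const int64_t in[2] = {1000, -3};
  int64_t out[2], carry_out[2];
  znx_normalize(nn, base_k, out, carry_out, in, /*carry_in=*/NULL);
  for (uint64_t i = 0; i < nn; ++i) {
    // digit is centered: -128 <= out[i] < 128
    assert(out[i] >= -128 && out[i] < 128);
    assert(in[i] == out[i] + carry_out[i] * 256);
  }
  assert(out[0] == -24 && carry_out[0] == 4);
}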
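
Note on vec_znx_normalize_base2k_ref: it walks the input limbs from the least significant one (highest index) down to limb 0, carrying through a single scratch polynomial of nn coefficients, which is why the *_tmp_bytes query returns nn * sizeof(int64_t). A minimal usage sketch, assuming a MODULE obtained elsewhere and the header path below; res and a are kept distinct since aliasing is not documented in this diff:

#include <stdint.h>
#include <stdlib.h>
#include "spqlios/arithmetic/vec_znx_arithmetic.h"  // header path is an assumption

static void normalize_vec(const MODULE* module, uint64_t k,               //
                          int64_t* res, uint64_t res_size, uint64_t res_sl,  //
                          const int64_t* a, uint64_t a_size, uint64_t a_sl) {
  // query and allocate the scratch space advertised by the library
  uint8_t* tmp = (uint8_t*)malloc(vec_znx_normalize_base2k_tmp_bytes(module, res_size, a_size));
  vec_znx_normalize_base2k(module, k,             //
                           res, res_size, res_sl,  //
                           a, a_size, a_sl, tmp);
  free(tmp);
}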
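
Note on the AVX variants: vec_znx_avx.c aliases znx_copy_i64_avx and znx_zero_i64_avx to the _ref versions because memcpy/memset are already bandwidth-bound; only add, sub and negate get dedicated AVX2 kernels. How the MODULE's func table ends up pointing at the _avx rather than the _ref kernels is not shown in this diff; purely as a hypothetical sketch (not the library's dispatch mechanism), a runtime selection could look like:

#include <stdint.h>
#include "spqlios/coeffs/coeffs_arithmetic.h"  // header path is an assumption

typedef void (*znx_add_fn)(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);

// Hypothetical selector: use the AVX2 kernel when the CPU supports it,
// otherwise fall back to the portable reference implementation.
static znx_add_fn select_znx_add(void) {
#if defined(__x86_64__) || defined(__i386__)
  if (__builtin_cpu_supports("avx2")) return znx_add_i64_avx;
#endif
  return znx_add_i64_ref;
}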