From b1a49aa3e4952acfac1e5159da0cf14727b5b1fd Mon Sep 17 00:00:00 2001 From: Nicolas Gama Date: Mon, 5 Aug 2024 23:21:51 +0200 Subject: [PATCH 01/11] vec-rnx and zn api declaration inc fix --- spqlios/CMakeLists.txt | 6 + spqlios/arithmetic/vec_rnx_arithmetic.h | 340 ++++++++++++++++++ .../arithmetic/vec_rnx_arithmetic_plugin.h | 88 +++++ .../arithmetic/vec_rnx_arithmetic_private.h | 284 +++++++++++++++ spqlios/arithmetic/zn_arithmetic.h | 135 +++++++ spqlios/arithmetic/zn_arithmetic_plugin.h | 39 ++ spqlios/arithmetic/zn_arithmetic_private.h | 150 ++++++++ 7 files changed, 1042 insertions(+) create mode 100644 spqlios/arithmetic/vec_rnx_arithmetic.h create mode 100644 spqlios/arithmetic/vec_rnx_arithmetic_plugin.h create mode 100644 spqlios/arithmetic/vec_rnx_arithmetic_private.h create mode 100644 spqlios/arithmetic/zn_arithmetic.h create mode 100644 spqlios/arithmetic/zn_arithmetic_plugin.h create mode 100644 spqlios/arithmetic/zn_arithmetic_private.h diff --git a/spqlios/CMakeLists.txt b/spqlios/CMakeLists.txt index 44f9b7b..2490d1a 100644 --- a/spqlios/CMakeLists.txt +++ b/spqlios/CMakeLists.txt @@ -110,6 +110,8 @@ set(SRCS_F128 set(HEADERSPUBLIC commons.h arithmetic/vec_znx_arithmetic.h + arithmetic/vec_rnx_arithmetic.h + arithmetic/zn_arithmetic.h cplx/cplx_fft.h reim/reim_fft.h q120/q120_common.h @@ -131,6 +133,10 @@ set(HEADERSPRIVATE q120/q120_arithmetic_private.h q120/q120_ntt_private.h arithmetic/vec_znx_arithmetic.h + arithmetic/vec_rnx_arithmetic_private.h + arithmetic/vec_rnx_arithmetic_plugin.h + arithmetic/zn_arithmetic_private.h + arithmetic/zn_arithmetic_plugin.h coeffs/coeffs_arithmetic.h reim/reim_fft_core_template.h ) diff --git a/spqlios/arithmetic/vec_rnx_arithmetic.h b/spqlios/arithmetic/vec_rnx_arithmetic.h new file mode 100644 index 0000000..65ef889 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_arithmetic.h @@ -0,0 +1,340 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_H + +#include + +#include "spqlios/commons.h" + +/** + * We support the following module families: + * - FFT64: + * the overall precision should fit at all times over 52 bits. + */ +typedef enum rnx_module_type_t { FFT64 } RNX_MODULE_TYPE; + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +typedef struct rnx_module_info_t MOD_RNX; + +/** + * @brief obtain a module info for ring dimension N + * the module-info knows about: + * - the dimension N (or the complex dimension m=N/2) + * - any moduleuted fft or ntt items + * - the hardware (avx, arm64, x86, ...) 
+ */ +EXPORT MOD_RNX* new_rnx_module_info(uint64_t N, RNX_MODULE_TYPE mode); +EXPORT void delete_rnx_module_info(MOD_RNX* module_info); +EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module); + +// basic arithmetic + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +); + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +); + +/** @brief sets res = a . X^p */ +EXPORT void vec_rnx_rotate( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a . (X^p - 1) */ +EXPORT void vec_rnx_mul_xp_minus_one( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/////////////////////////////////////////////////////////////////// +// conversions // +/////////////////////////////////////////////////////////////////// + +EXPORT void vec_rnx_to_znx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_znx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_tnx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnx32x2( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_from_tnx32x2( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +); + +EXPORT void vec_rnx_to_tnxdbl( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + 
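+/*
+ * Usage sketch (illustrative only: the dimension, vector sizes and strides
+ * below are arbitrary placeholder values, not requirements of the API):
+ *
+ *   MOD_RNX* mod = new_rnx_module_info(64, FFT64);   // ring dimension N = 64
+ *   double a[2 * 64] = {0}, b[2 * 64] = {0};         // vectors of 2 polynomials, stride N
+ *   double r[2 * 64], rot[2 * 64];
+ *   vec_rnx_add(mod, r, 2, 64, a, 2, 64, b, 2, 64);  // r = a + b
+ *   vec_rnx_rotate(mod, 3, rot, 2, 64, r, 2, 64);    // rot = r . X^3
+ *   delete_rnx_module_info(mod);
+ */
+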
+///////////////////////////////////////////////////////////////////
+//  isolated products (n.log(n), but not particularly optimized)  //
+///////////////////////////////////////////////////////////////////
+
+/** @brief res = a * b : small polynomial product */
+EXPORT void rnx_small_single_product(  //
+    const MOD_RNX* module,             // N
+    double* res,                       // output
+    const double* a,                   // a
+    const double* b,                   // b
+    uint8_t* tmp);                     // scratch space
+
+EXPORT uint64_t rnx_small_single_product_tmp_bytes(const MOD_RNX* module);
+
+/** @brief res = a * b centermod 1: small polynomial product */
+EXPORT void tnxdbl_small_single_product(  //
+    const MOD_RNX* module,                // N
+    double* torus_res,                    // output
+    const double* int_a,                  // a
+    const double* torus_b,                // b
+    uint8_t* tmp);                        // scratch space
+
+EXPORT uint64_t tnxdbl_small_single_product_tmp_bytes(const MOD_RNX* module);
+
+/** @brief res = a * b: small polynomial product */
+EXPORT void znx32_small_single_product(  //
+    const MOD_RNX* module,               // N
+    int32_t* int_res,                    // output
+    const int32_t* int_a,                // a
+    const int32_t* int_b,                // b
+    uint8_t* tmp);                       // scratch space
+
+EXPORT uint64_t znx32_small_single_product_tmp_bytes(const MOD_RNX* module);
+
+/** @brief res = a * b centermod 1: small polynomial product */
+EXPORT void tnx32_small_single_product(  //
+    const MOD_RNX* module,               // N
+    int32_t* torus_res,                  // output
+    const int32_t* int_a,                // a
+    const int32_t* torus_b,              // b
+    uint8_t* tmp);                       // scratch space
+
+EXPORT uint64_t tnx32_small_single_product_tmp_bytes(const MOD_RNX* module);
+
+///////////////////////////////////////////////////////////////////
+//          prepared gadget decompositions (optimized)            //
+///////////////////////////////////////////////////////////////////
+
+// decompose from tnx32
+
+typedef struct tnx32_approxdecomp_gadget_t TNX32_APPROXDECOMP_GADGET;
+
+/** @brief new gadget: delete with delete_tnx32_approxdecomp_gadget */
+EXPORT TNX32_APPROXDECOMP_GADGET* new_tnx32_approxdecomp_gadget(  //
+    const MOD_RNX* module,                                        // N
+    uint64_t k, uint64_t ell                                      // base 2^K and size
+);
+EXPORT void delete_tnx32_approxdecomp_gadget(TNX32_APPROXDECOMP_GADGET* gadget);
+
+/** @brief sets res = gadget_decompose(a) */
+EXPORT void rnx_approxdecomp_from_tnx32(              //
+    const MOD_RNX* module,                            // N
+    const TNX32_APPROXDECOMP_GADGET* gadget,          // output base 2^K
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const int32_t* a                                  // a
+);
+
+// decompose from tnx32x2
+
+typedef struct tnx32x2_approxdecomp_gadget_t TNX32X2_APPROXDECOMP_GADGET;
+
+/** @brief new gadget: delete with delete_tnx32x2_approxdecomp_gadget */
+EXPORT TNX32X2_APPROXDECOMP_GADGET* new_tnx32x2_approxdecomp_gadget(const MOD_RNX* module, uint64_t ka, uint64_t ella,
+                                                                    uint64_t kb, uint64_t ellb);
+EXPORT void delete_tnx32x2_approxdecomp_gadget(TNX32X2_APPROXDECOMP_GADGET* gadget);
+
+/** @brief sets res = gadget_decompose(a) */
+EXPORT void rnx_approxdecomp_from_tnx32x2(            //
+    const MOD_RNX* module,                            // N
+    const TNX32X2_APPROXDECOMP_GADGET* gadget,        // output base 2^K
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const int32_t* a                                  // a
+);
+
+// decompose from tnxdbl
+
+typedef struct tnxdbl_approxdecomp_gadget_t TNXDBL_APPROXDECOMP_GADGET;
+
+/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */
+EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget(  //
+    const MOD_RNX* module,                                          // N
+    uint64_t k, uint64_t ell                                        // base 2^K and size
+);
+EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget);
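+
+/*
+ * Gadget decomposition sketch (illustrative only: k, ell, the dimension and
+ * the sizes below are arbitrary placeholder values, and mod / a are the module
+ * and a torus polynomial of doubles from the usage sketch above):
+ *
+ *   TNXDBL_APPROXDECOMP_GADGET* g = new_tnxdbl_approxdecomp_gadget(mod, 8, 4);  // base 2^8, 4 digits
+ *   double dec[4 * 64];                                   // ell polynomials, stride N = 64
+ *   rnx_approxdecomp_from_tnxdbl(mod, g, dec, 4, 64, a);  // dec = gadget_decompose(a)
+ *   delete_tnxdbl_approxdecomp_gadget(g);
+ */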
+
+/** @brief sets res = gadget_decompose(a) */
+EXPORT void rnx_approxdecomp_from_tnxdbl(             //
+    const MOD_RNX* module,                            // N
+    const TNXDBL_APPROXDECOMP_GADGET* gadget,         // output base 2^K
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a);                                 // a
+
+///////////////////////////////////////////////////////////////////
+//          prepared scalar-vector product (optimized)            //
+///////////////////////////////////////////////////////////////////
+
+/** @brief opaque type that represents a polynomial of RnX prepared for a scalar-vector product */
+typedef struct rnx_svp_ppol_t RNX_SVP_PPOL;
+
+/** @brief number of bytes in a RNX_SVP_PPOL (for manual allocation) */
+EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module);  // N
+
+/** @brief allocates a prepared vector (release with delete_rnx_svp_ppol) */
+EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module);  // N
+
+/** @brief frees memory for a prepared vector */
+EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* res);
+
+/** @brief prepares a svp polynomial */
+EXPORT void rnx_svp_prepare(const MOD_RNX* module,  // N
+                            RNX_SVP_PPOL* ppol,     // output
+                            const double* pol       // a
+);
+
+/** @brief apply a svp product, result = ppol * a, presented in DFT space */
+EXPORT void rnx_svp_apply(                            //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // output
+    const RNX_SVP_PPOL* ppol,                         // prepared pol
+    const double* a, uint64_t a_size, uint64_t a_sl   // a
+);
+
+///////////////////////////////////////////////////////////////////
+//          prepared vector-matrix product (optimized)            //
+///////////////////////////////////////////////////////////////////
+
+typedef struct rnx_vmp_pmat_t RNX_VMP_PMAT;
+
+/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */
+EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module,            // N
+                                      uint64_t nrows, uint64_t ncols);  // dimensions
+
+/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */
+EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module,            // N
+                                      uint64_t nrows, uint64_t ncols);  // dimensions
+EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr);
+
+/** @brief prepares a vmp matrix (contiguous row-major version) */
+EXPORT void rnx_vmp_prepare_contiguous(               //
+    const MOD_RNX* module,                            // N
+    RNX_VMP_PMAT* pmat,                               // output
+    const double* a, uint64_t nrows, uint64_t ncols,  // a
+    uint8_t* tmp_space                                // scratch space
+);
+
+/** @brief number of scratch bytes necessary to prepare a matrix */
+EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module);
+
+/** @brief applies a vmp product res = a x pmat */
+EXPORT void rnx_vmp_apply_tmp_a(                               //
+    const MOD_RNX* module,                                     // N
+    double* res, uint64_t res_size, uint64_t res_sl,           // res
+    double* tmpa, uint64_t a_size, uint64_t a_sl,              // a (will be overwritten)
+    const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
+    uint8_t* tmp_space                                         // scratch space
+);
+
+EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes(  //
+    const MOD_RNX* module,                      // N
+    uint64_t res_size,                          // res size
+    uint64_t a_size,                            // a size
+    uint64_t nrows, uint64_t ncols              // prep matrix dims
+);
+
+/** @brief applies a vmp product res = a x pmat (a and res are given in DFT space) */
+EXPORT void rnx_vmp_apply_dft_to_dft(                          //
+    const MOD_RNX* module,                                     // N
+    double* res, uint64_t res_size, uint64_t res_sl,           // res
+    const double* a_dft, uint64_t a_size, uint64_t a_sl,       // a
+    const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
+    uint8_t* tmp_space                                         // scratch space (a_size*sizeof(reim4) bytes)
+);
+
+/** @brief minimal size of the tmp_space */
+EXPORT uint64_t
rnx_vmp_apply_dft_to_dft_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/** @brief sets res = DFT(a) */ +EXPORT void vec_rnx_dft(const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = iDFT(a_dft) -- idft is not normalized */ +EXPORT void vec_rnx_idft(const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl // a +); + +#endif // SPQLIOS_VEC_RNX_ARITHMETIC_H diff --git a/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h b/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h new file mode 100644 index 0000000..f2e07eb --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h @@ -0,0 +1,88 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H + +#include "vec_rnx_arithmetic.h" + +typedef typeof(vec_rnx_zero) VEC_RNX_ZERO_F; +typedef typeof(vec_rnx_copy) VEC_RNX_COPY_F; +typedef typeof(vec_rnx_negate) VEC_RNX_NEGATE_F; +typedef typeof(vec_rnx_add) VEC_RNX_ADD_F; +typedef typeof(vec_rnx_sub) VEC_RNX_SUB_F; +typedef typeof(vec_rnx_rotate) VEC_RNX_ROTATE_F; +typedef typeof(vec_rnx_mul_xp_minus_one) VEC_RNX_MUL_XP_MINUS_ONE_F; +typedef typeof(vec_rnx_automorphism) VEC_RNX_AUTOMORPHISM_F; +typedef typeof(vec_rnx_to_znx32) VEC_RNX_TO_ZNX32_F; +typedef typeof(vec_rnx_from_znx32) VEC_RNX_FROM_ZNX32_F; +typedef typeof(vec_rnx_to_tnx32) VEC_RNX_TO_TNX32_F; +typedef typeof(vec_rnx_from_tnx32) VEC_RNX_FROM_TNX32_F; +typedef typeof(vec_rnx_to_tnx32x2) VEC_RNX_TO_TNX32X2_F; +typedef typeof(vec_rnx_from_tnx32x2) VEC_RNX_FROM_TNX32X2_F; +typedef typeof(vec_rnx_to_tnxdbl) VEC_RNX_TO_TNXDBL_F; +// typedef typeof(vec_rnx_from_tnxdbl) VEC_RNX_FROM_TNXDBL_F; +typedef typeof(rnx_small_single_product) RNX_SMALL_SINGLE_PRODUCT_F; +typedef typeof(rnx_small_single_product_tmp_bytes) RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(tnxdbl_small_single_product) TNXDBL_SMALL_SINGLE_PRODUCT_F; +typedef typeof(tnxdbl_small_single_product_tmp_bytes) TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(znx32_small_single_product) ZNX32_SMALL_SINGLE_PRODUCT_F; +typedef typeof(znx32_small_single_product_tmp_bytes) ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(tnx32_small_single_product) TNX32_SMALL_SINGLE_PRODUCT_F; +typedef typeof(tnx32_small_single_product_tmp_bytes) TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; +typedef typeof(rnx_approxdecomp_from_tnx32) RNX_APPROXDECOMP_FROM_TNX32_F; +typedef typeof(rnx_approxdecomp_from_tnx32x2) RNX_APPROXDECOMP_FROM_TNX32X2_F; +typedef typeof(rnx_approxdecomp_from_tnxdbl) RNX_APPROXDECOMP_FROM_TNXDBL_F; +typedef typeof(bytes_of_rnx_svp_ppol) BYTES_OF_RNX_SVP_PPOL_F; +typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F; +typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F; +typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F; +typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(rnx_vmp_prepare_contiguous_tmp_bytes) RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F; +typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F; +typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F; +typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F; +typedef typeof(rnx_vmp_apply_dft_to_dft_tmp_bytes) RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F; +typedef typeof(vec_rnx_dft) VEC_RNX_DFT_F; +typedef 
typeof(vec_rnx_idft) VEC_RNX_IDFT_F; + +typedef struct rnx_module_vtable_t RNX_MODULE_VTABLE; +struct rnx_module_vtable_t { + VEC_RNX_ZERO_F* vec_rnx_zero; + VEC_RNX_COPY_F* vec_rnx_copy; + VEC_RNX_NEGATE_F* vec_rnx_negate; + VEC_RNX_ADD_F* vec_rnx_add; + VEC_RNX_SUB_F* vec_rnx_sub; + VEC_RNX_ROTATE_F* vec_rnx_rotate; + VEC_RNX_MUL_XP_MINUS_ONE_F* vec_rnx_mul_xp_minus_one; + VEC_RNX_AUTOMORPHISM_F* vec_rnx_automorphism; + VEC_RNX_TO_ZNX32_F* vec_rnx_to_znx32; + VEC_RNX_FROM_ZNX32_F* vec_rnx_from_znx32; + VEC_RNX_TO_TNX32_F* vec_rnx_to_tnx32; + VEC_RNX_FROM_TNX32_F* vec_rnx_from_tnx32; + VEC_RNX_TO_TNX32X2_F* vec_rnx_to_tnx32x2; + VEC_RNX_FROM_TNX32X2_F* vec_rnx_from_tnx32x2; + VEC_RNX_TO_TNXDBL_F* vec_rnx_to_tnxdbl; + RNX_SMALL_SINGLE_PRODUCT_F* rnx_small_single_product; + RNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* rnx_small_single_product_tmp_bytes; + TNXDBL_SMALL_SINGLE_PRODUCT_F* tnxdbl_small_single_product; + TNXDBL_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnxdbl_small_single_product_tmp_bytes; + ZNX32_SMALL_SINGLE_PRODUCT_F* znx32_small_single_product; + ZNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx32_small_single_product_tmp_bytes; + TNX32_SMALL_SINGLE_PRODUCT_F* tnx32_small_single_product; + TNX32_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* tnx32_small_single_product_tmp_bytes; + RNX_APPROXDECOMP_FROM_TNX32_F* rnx_approxdecomp_from_tnx32; + RNX_APPROXDECOMP_FROM_TNX32X2_F* rnx_approxdecomp_from_tnx32x2; + RNX_APPROXDECOMP_FROM_TNXDBL_F* rnx_approxdecomp_from_tnxdbl; + BYTES_OF_RNX_SVP_PPOL_F* bytes_of_rnx_svp_ppol; + RNX_SVP_PREPARE_F* rnx_svp_prepare; + RNX_SVP_APPLY_F* rnx_svp_apply; + BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat; + RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous; + RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* rnx_vmp_prepare_contiguous_tmp_bytes; + RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a; + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes; + RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft; + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* rnx_vmp_apply_dft_to_dft_tmp_bytes; + VEC_RNX_DFT_F* vec_rnx_dft; + VEC_RNX_IDFT_F* vec_rnx_idft; +}; + +#endif // SPQLIOS_VEC_RNX_ARITHMETIC_PLUGIN_H diff --git a/spqlios/arithmetic/vec_rnx_arithmetic_private.h b/spqlios/arithmetic/vec_rnx_arithmetic_private.h new file mode 100644 index 0000000..b761eb7 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_arithmetic_private.h @@ -0,0 +1,284 @@ +#ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H +#define SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H + +#include "spqlios/commons_private.h" +#include "spqlios/reim/reim_fft.h" +#include "vec_rnx_arithmetic.h" +#include "vec_rnx_arithmetic_plugin.h" + +typedef struct fft64_rnx_module_precomp_t FFT64_RNX_MODULE_PRECOMP; +struct fft64_rnx_module_precomp_t { + REIM_FFT_PRECOMP* p_fft; + REIM_IFFT_PRECOMP* p_ifft; + REIM_FFTVEC_MUL_PRECOMP* p_fftvec_mul; + REIM_FFTVEC_ADDMUL_PRECOMP* p_fftvec_addmul; +}; + +typedef union rnx_module_precomp_t RNX_MODULE_PRECOMP; +union rnx_module_precomp_t { + FFT64_RNX_MODULE_PRECOMP fft64; +}; + +void fft64_init_rnx_module_precomp(MOD_RNX* module); + +void fft64_finalize_rnx_module_precomp(MOD_RNX* module); + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +struct rnx_module_info_t { + uint64_t n; + uint64_t m; + RNX_MODULE_TYPE mtype; + RNX_MODULE_VTABLE vtable; + RNX_MODULE_PRECOMP precomp; + void* custom; + void (*custom_deleter)(void*); +}; + +void init_rnx_module_info(MOD_RNX* module, // + uint64_t, RNX_MODULE_TYPE mtype); + +void finalize_rnx_module_info(MOD_RNX* module); + +void 
fft64_init_rnx_module_vtable(MOD_RNX* module);
+
+///////////////////////////////////////////////////////////////////
+//          prepared gadget decompositions (optimized)            //
+///////////////////////////////////////////////////////////////////
+
+struct tnx32_approxdecomp_gadget_t {
+  uint64_t k;
+  uint64_t ell;
+  int32_t add_cst;      // 1/2.(sum 2^-(i+1)K)
+  int32_t rshift_base;  // 32 - K
+  int64_t and_mask;     // 2^K-1
+  int64_t or_mask;      // double(2^52)
+  double sub_cst;       // double(2^52 + 2^(K-1))
+  uint8_t rshifts[8];   // 32 - (i+1).K
+};
+
+struct tnx32x2_approxdecomp_gadget_t {
+  // TODO
+};
+
+struct tnxdbl_approxdecomp_gadget_t {
+  uint64_t k;
+  uint64_t ell;
+  double add_cst;     // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[)
+  uint64_t and_mask;  // uint64(2^(K)-1)
+  uint64_t or_mask;   // double(2^52)
+  double sub_cst;     // double(2^52 + 2^(K-1))
+};
+
+EXPORT void vec_rnx_add_ref(                          //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl,  // a
+    const double* b, uint64_t b_size, uint64_t b_sl   // b
+);
+EXPORT void vec_rnx_add_avx(                          //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl,  // a
+    const double* b, uint64_t b_size, uint64_t b_sl   // b
+);
+
+/** @brief sets res = 0 */
+EXPORT void vec_rnx_zero_ref(                        //
+    const MOD_RNX* module,                           // N
+    double* res, uint64_t res_size, uint64_t res_sl  // res
+);
+
+/** @brief sets res = a */
+EXPORT void vec_rnx_copy_ref(                         //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl   // a
+);
+
+/** @brief sets res = -a */
+EXPORT void vec_rnx_negate_ref(                       //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl   // a
+);
+
+/** @brief sets res = -a */
+EXPORT void vec_rnx_negate_avx(                       //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl   // a
+);
+
+/** @brief sets res = a - b */
+EXPORT void vec_rnx_sub_ref(                          //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl,  // a
+    const double* b, uint64_t b_size, uint64_t b_sl   // b
+);
+
+/** @brief sets res = a - b */
+EXPORT void vec_rnx_sub_avx(                          //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl,  // a
+    const double* b, uint64_t b_size, uint64_t b_sl   // b
+);
+
+/** @brief sets res = a .
X^p */ +EXPORT void vec_rnx_rotate_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism_ref( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +); + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols); + +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +); +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module); +EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module); + +EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +); + +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +); + +/// gadget decompositions + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void 
rnx_approxdecomp_from_tnxdbl_ref(                     //
+    const MOD_RNX* module,                            // N
+    const TNXDBL_APPROXDECOMP_GADGET* gadget,         // output base 2^K
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a);                                 // a
+EXPORT void rnx_approxdecomp_from_tnxdbl_avx(         //
+    const MOD_RNX* module,                            // N
+    const TNXDBL_APPROXDECOMP_GADGET* gadget,         // output base 2^K
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a);                                 // a
+
+EXPORT void vec_rnx_mul_xp_minus_one_ref(             //
+    const MOD_RNX* module,                            // N
+    const int64_t p,                                  // rotation value
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl   // a
+);
+
+EXPORT void vec_rnx_to_znx32_ref(                      //
+    const MOD_RNX* module,                             // N
+    int32_t* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl    // a
+);
+
+EXPORT void vec_rnx_from_znx32_ref(                   //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const int32_t* a, uint64_t a_size, uint64_t a_sl  // a
+);
+
+EXPORT void vec_rnx_to_tnx32_ref(                      //
+    const MOD_RNX* module,                             // N
+    int32_t* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl    // a
+);
+
+EXPORT void vec_rnx_from_tnx32_ref(                   //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const int32_t* a, uint64_t a_size, uint64_t a_sl  // a
+);
+
+EXPORT void vec_rnx_to_tnxdbl_ref(                    //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // res
+    const double* a, uint64_t a_size, uint64_t a_sl   // a
+);
+
+EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module);  // N
+
+/** @brief prepares a svp polynomial */
+EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module,  // N
+                                      RNX_SVP_PPOL* ppol,     // output
+                                      const double* pol       // a
+);
+
+/** @brief apply a svp product, result = ppol * a, presented in DFT space */
+EXPORT void fft64_rnx_svp_apply_ref(                  //
+    const MOD_RNX* module,                            // N
+    double* res, uint64_t res_size, uint64_t res_sl,  // output
+    const RNX_SVP_PPOL* ppol,                         // prepared pol
+    const double* a, uint64_t a_size, uint64_t a_sl   // a
+);
+
+#endif  // SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H
diff --git a/spqlios/arithmetic/zn_arithmetic.h b/spqlios/arithmetic/zn_arithmetic.h
new file mode 100644
index 0000000..3503e20
--- /dev/null
+++ b/spqlios/arithmetic/zn_arithmetic.h
@@ -0,0 +1,135 @@
+#ifndef SPQLIOS_ZN_ARITHMETIC_H
+#define SPQLIOS_ZN_ARITHMETIC_H
+
+#include <stdint.h>
+
+#include "../commons.h"
+
+typedef enum z_module_type_t { DEFAULT } Z_MODULE_TYPE;
+
+/** @brief opaque structure that describes the module and the hardware */
+typedef struct z_module_info_t MOD_Z;
+
+/**
+ * @brief obtain a module info
+ * the module-info knows about:
+ *  - any precomputed items
+ *  - the hardware (avx, arm64, x86, ...)
+ */ +EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mode); +EXPORT void delete_z_module_info(MOD_Z* module_info); + +typedef struct tndbl_approxdecomp_gadget_t TNDBL_APPROXDECOMP_GADGET; + +EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, // + uint64_t k, + uint64_t ell); // base 2^k, and size + +EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr); + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief opaque type that represents a prepared matrix */ +typedef struct zn32_vmp_pmat_t ZN32_VMP_PMAT; + +/** @brief size in bytes of a prepared matrix (for custom allocation) */ +EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols); // dimensions + +/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */ +EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols); // dimensions + +/** @brief deletes a prepared matrix (release with free) */ +EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr); // dimensions + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void zn32_vmp_prepare_contiguous( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols); // a + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void zn32_vmp_apply_i32( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void zn32_vmp_apply_i16( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void zn32_vmp_apply_i8( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int8_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +// explicit conversions + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in i32 space. 
+ * WARNING: ||a||_inf must be <= 2^18 in this function + */ +EXPORT void dbl_round_to_i32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int32 space) to double + * WARNING: ||a||_inf must be <= 2^18 in this function + */ +EXPORT void i32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in int64 space + * WARNING: ||a||_inf must be <= 2^50 in this function + */ +EXPORT void dbl_round_to_i64(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int64 space, <= 2^50) to double + * WARNING: ||a||_inf must be <= 2^50 in this function + */ +EXPORT void i64_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +); + +#endif // SPQLIOS_ZN_ARITHMETIC_H diff --git a/spqlios/arithmetic/zn_arithmetic_plugin.h b/spqlios/arithmetic/zn_arithmetic_plugin.h new file mode 100644 index 0000000..d400a72 --- /dev/null +++ b/spqlios/arithmetic/zn_arithmetic_plugin.h @@ -0,0 +1,39 @@ +#ifndef SPQLIOS_ZN_ARITHMETIC_PLUGIN_H +#define SPQLIOS_ZN_ARITHMETIC_PLUGIN_H + +#include "zn_arithmetic.h" + +typedef typeof(i8_approxdecomp_from_tndbl) I8_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F; +typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F; +typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F; +typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F; +typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F; +typedef typeof(dbl_to_tn32) DBL_TO_TN32_F; +typedef typeof(tn32_to_dbl) TN32_TO_DBL_F; +typedef typeof(dbl_round_to_i32) DBL_ROUND_TO_I32_F; +typedef typeof(i32_to_dbl) I32_TO_DBL_F; +typedef typeof(dbl_round_to_i64) DBL_ROUND_TO_I64_F; +typedef typeof(i64_to_dbl) I64_TO_DBL_F; + +typedef struct z_module_vtable_t Z_MODULE_VTABLE; +struct z_module_vtable_t { + I8_APPROXDECOMP_FROM_TNDBL_F* i8_approxdecomp_from_tndbl; + I16_APPROXDECOMP_FROM_TNDBL_F* i16_approxdecomp_from_tndbl; + I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl; + BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat; + ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous; + ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32; + ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16; + ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8; + DBL_TO_TN32_F* dbl_to_tn32; + TN32_TO_DBL_F* tn32_to_dbl; + DBL_ROUND_TO_I32_F* dbl_round_to_i32; + I32_TO_DBL_F* i32_to_dbl; + DBL_ROUND_TO_I64_F* dbl_round_to_i64; + I64_TO_DBL_F* i64_to_dbl; +}; + +#endif // SPQLIOS_ZN_ARITHMETIC_PLUGIN_H diff --git a/spqlios/arithmetic/zn_arithmetic_private.h b/spqlios/arithmetic/zn_arithmetic_private.h new file mode 100644 index 0000000..3ff6c48 --- /dev/null +++ b/spqlios/arithmetic/zn_arithmetic_private.h @@ -0,0 +1,150 @@ +#ifndef SPQLIOS_ZN_ARITHMETIC_PRIVATE_H +#define SPQLIOS_ZN_ARITHMETIC_PRIVATE_H + +#include "../commons_private.h" +#include "zn_arithmetic.h" +#include "zn_arithmetic_plugin.h" + +typedef struct main_z_module_precomp_t MAIN_Z_MODULE_PRECOMP; +struct main_z_module_precomp_t { + // TODO +}; + +typedef union z_module_precomp_t Z_MODULE_PRECOMP; +union z_module_precomp_t { + MAIN_Z_MODULE_PRECOMP main; +}; + +void 
main_init_z_module_precomp(MOD_Z* module); + +void main_finalize_z_module_precomp(MOD_Z* module); + +/** @brief opaque structure that describes the modules (RnX,ZnX,TnX) and the hardware */ +struct z_module_info_t { + Z_MODULE_TYPE mtype; + Z_MODULE_VTABLE vtable; + Z_MODULE_PRECOMP precomp; + void* custom; + void (*custom_deleter)(void*); +}; + +void init_z_module_info(MOD_Z* module, Z_MODULE_TYPE mtype); + +void main_init_z_module_vtable(MOD_Z* module); + +struct tndbl_approxdecomp_gadget_t { + uint64_t k; + uint64_t ell; + double add_cst; // 3.2^51-(K.ell) + 1/2.(sum 2^-(i+1)K) + int64_t and_mask; // (2^K)-1 + int64_t sub_cst; // 2^(K-1) + uint8_t rshifts[64]; // 2^(ell-1-i).K for i in [0:ell-1] +}; + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, + uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, + uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size); // a + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void default_zn32_vmp_prepare_contiguous_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols // a +); + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_ref( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void default_zn32_vmp_apply_i16_ref( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void default_zn32_vmp_apply_i8_ref( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int8_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_avx( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void default_zn32_vmp_apply_i16_avx( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int16_t* a, uint64_t a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void default_zn32_vmp_apply_i8_avx( // + const MOD_Z* module, // N + int32_t* res, uint64_t res_size, // res + const int8_t* a, uint64_t 
a_size, // a + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols); // prep matrix + +// explicit conversions + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +); + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +); + +/** small int (int64 space) to double */ +EXPORT void i64_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +); + +#endif // SPQLIOS_ZN_ARITHMETIC_PRIVATE_H From a0ff8ff9ecc2adf7f4fa52c55a3b375e02c00132 Mon Sep 17 00:00:00 2001 From: Nicolas Gama Date: Mon, 5 Aug 2024 23:31:31 +0200 Subject: [PATCH 02/11] zn api implementation --- spqlios/CMakeLists.txt | 10 + spqlios/arithmetic/zn_api.c | 169 +++++++++++++++++ spqlios/arithmetic/zn_approxdecomp_ref.c | 81 ++++++++ spqlios/arithmetic/zn_conversions_ref.c | 108 +++++++++++ spqlios/arithmetic/zn_vmp_int16_avx.c | 4 + spqlios/arithmetic/zn_vmp_int16_ref.c | 4 + spqlios/arithmetic/zn_vmp_int32_avx.c | 223 +++++++++++++++++++++++ spqlios/arithmetic/zn_vmp_int32_ref.c | 88 +++++++++ spqlios/arithmetic/zn_vmp_int8_avx.c | 4 + spqlios/arithmetic/zn_vmp_int8_ref.c | 4 + spqlios/arithmetic/zn_vmp_ref.c | 138 ++++++++++++++ 11 files changed, 833 insertions(+) create mode 100644 spqlios/arithmetic/zn_api.c create mode 100644 spqlios/arithmetic/zn_approxdecomp_ref.c create mode 100644 spqlios/arithmetic/zn_conversions_ref.c create mode 100644 spqlios/arithmetic/zn_vmp_int16_avx.c create mode 100644 spqlios/arithmetic/zn_vmp_int16_ref.c create mode 100644 spqlios/arithmetic/zn_vmp_int32_avx.c create mode 100644 spqlios/arithmetic/zn_vmp_int32_ref.c create mode 100644 spqlios/arithmetic/zn_vmp_int8_avx.c create mode 100644 spqlios/arithmetic/zn_vmp_int8_ref.c create mode 100644 spqlios/arithmetic/zn_vmp_ref.c diff --git a/spqlios/CMakeLists.txt b/spqlios/CMakeLists.txt index 2490d1a..a4c0c6d 100644 --- a/spqlios/CMakeLists.txt +++ b/spqlios/CMakeLists.txt @@ -32,6 +32,13 @@ set(SRCS_GENERIC arithmetic/vec_znx_big.c arithmetic/znx_small.c arithmetic/module_api.c + arithmetic/zn_vmp_int8_ref.c + arithmetic/zn_vmp_int16_ref.c + arithmetic/zn_vmp_int32_ref.c + arithmetic/zn_vmp_ref.c + arithmetic/zn_api.c + arithmetic/zn_conversions_ref.c + arithmetic/zn_approxdecomp_ref.c reim/reim_execute.c cplx/cplx_execute.c reim4/reim4_execute.c @@ -95,6 +102,9 @@ set(SRCS_AVX2 arithmetic/vec_znx_avx.c coeffs/coeffs_arithmetic_avx.c arithmetic/vec_znx_dft_avx2.c + arithmetic/zn_vmp_int8_avx.c + arithmetic/zn_vmp_int16_avx.c + arithmetic/zn_vmp_int32_avx.c q120/q120_arithmetic_avx2.c q120/q120_ntt_avx2.c ) diff --git a/spqlios/arithmetic/zn_api.c b/spqlios/arithmetic/zn_api.c new file mode 100644 index 0000000..28d5c8d --- /dev/null +++ 
b/spqlios/arithmetic/zn_api.c @@ -0,0 +1,169 @@ +#include + +#include "zn_arithmetic_private.h" + +void default_init_z_module_precomp(MOD_Z* module) { + // Add here initialization of items that are in the precomp +} + +void default_finalize_z_module_precomp(MOD_Z* module) { + // Add here deleters for items that are in the precomp +} + +void default_init_z_module_vtable(MOD_Z* module) { + // Add function pointers here + module->vtable.i8_approxdecomp_from_tndbl = default_i8_approxdecomp_from_tndbl_ref; + module->vtable.i16_approxdecomp_from_tndbl = default_i16_approxdecomp_from_tndbl_ref; + module->vtable.i32_approxdecomp_from_tndbl = default_i32_approxdecomp_from_tndbl_ref; + module->vtable.zn32_vmp_prepare_contiguous = default_zn32_vmp_prepare_contiguous_ref; + module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_ref; + module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_ref; + module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_ref; + module->vtable.dbl_to_tn32 = dbl_to_tn32_ref; + module->vtable.tn32_to_dbl = tn32_to_dbl_ref; + module->vtable.dbl_round_to_i32 = dbl_round_to_i32_ref; + module->vtable.i32_to_dbl = i32_to_dbl_ref; + module->vtable.dbl_round_to_i64 = dbl_round_to_i64_ref; + module->vtable.i64_to_dbl = i64_to_dbl_ref; + + // Add optimized function pointers here + if (CPU_SUPPORTS("avx")) { + module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_avx; + module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_avx; + module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_avx; + } +} + +void init_z_module_info(MOD_Z* module, // + Z_MODULE_TYPE mtype) { + memset(module, 0, sizeof(MOD_Z)); + module->mtype = mtype; + switch (mtype) { + case DEFAULT: + default_init_z_module_precomp(module); + default_init_z_module_vtable(module); + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +void finalize_z_module_info(MOD_Z* module) { + if (module->custom) module->custom_deleter(module->custom); + switch (module->mtype) { + case DEFAULT: + default_finalize_z_module_precomp(module); + // fft64_finalize_rnx_module_vtable(module); // nothing to finalize + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +EXPORT MOD_Z* new_z_module_info(Z_MODULE_TYPE mtype) { + MOD_Z* res = (MOD_Z*)malloc(sizeof(MOD_Z)); + init_z_module_info(res, mtype); + return res; +} + +EXPORT void delete_z_module_info(MOD_Z* module_info) { + finalize_z_module_info(module_info); + free(module_info); +} + +//////////////// wrappers ////////////////// + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void i8_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i8_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void i16_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i16_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res (in general, 
size ell.a_size) + const double* a, uint64_t a_size) { // a + module->vtable.i32_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); +} + +EXPORT void zn32_vmp_prepare_contiguous( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a + module->vtable.zn32_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols); +} + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void zn32_vmp_apply_i32(const MOD_Z* module, int32_t* res, uint64_t res_size, const int32_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i32(module, res, res_size, a, a_size, pmat, nrows, ncols); +} +/** @brief applies a vmp product (int16_t* input) */ +EXPORT void zn32_vmp_apply_i16(const MOD_Z* module, int32_t* res, uint64_t res_size, const int16_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i16(module, res, res_size, a, a_size, pmat, nrows, ncols); +} + +/** @brief applies a vmp product (int8_t* input) */ +EXPORT void zn32_vmp_apply_i8(const MOD_Z* module, int32_t* res, uint64_t res_size, const int8_t* a, uint64_t a_size, + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + module->vtable.zn32_vmp_apply_i8(module, res, res_size, a, a_size, pmat, nrows, ncols); +} + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_to_tn32(module, res, res_size, a, a_size); +} + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + module->vtable.tn32_to_dbl(module, res, res_size, a, a_size); +} + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_round_to_i32(module, res, res_size, a, a_size); +} + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + module->vtable.i32_to_dbl(module, res, res_size, a, a_size); +} + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + module->vtable.dbl_round_to_i64(module, res, res_size, a, a_size); +} + +/** small int (int64 space, <= 2^50) to double */ +EXPORT void i64_to_dbl(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +) { + module->vtable.i64_to_dbl(module, res, res_size, a, a_size); +} diff --git a/spqlios/arithmetic/zn_approxdecomp_ref.c b/spqlios/arithmetic/zn_approxdecomp_ref.c new file mode 100644 index 0000000..616b9a3 --- /dev/null +++ b/spqlios/arithmetic/zn_approxdecomp_ref.c @@ -0,0 +1,81 @@ +#include + +#include "zn_arithmetic_private.h" + +EXPORT TNDBL_APPROXDECOMP_GADGET* new_tndbl_approxdecomp_gadget(const MOD_Z* module, // + uint64_t k, uint64_t ell) { + if (k * ell > 50) { + return spqlios_error("approx decomposition requested is too precise for doubles"); + } + if (k < 1) { + return spqlios_error("approx decomposition supports k>=1"); + } + TNDBL_APPROXDECOMP_GADGET* 
res = malloc(sizeof(TNDBL_APPROXDECOMP_GADGET)); + memset(res, 0, sizeof(TNDBL_APPROXDECOMP_GADGET)); + res->k = k; + res->ell = ell; + double add_cst = INT64_C(3) << (51 - k * ell); + for (uint64_t i = 0; i < ell; ++i) { + add_cst += pow(2., -(double)(i * k + 1)); + } + res->add_cst = add_cst; + res->and_mask = (UINT64_C(1) << k) - 1; + res->sub_cst = UINT64_C(1) << (k - 1); + for (uint64_t i = 0; i < ell; ++i) res->rshifts[i] = (ell - 1 - i) * k; + return res; +} +EXPORT void delete_tndbl_approxdecomp_gadget(TNDBL_APPROXDECOMP_GADGET* ptr) { free(ptr); } + +EXPORT int default_init_tndbl_approxdecomp_gadget(const MOD_Z* module, // + TNDBL_APPROXDECOMP_GADGET* res, // + uint64_t k, uint64_t ell) { + return 0; +} + +typedef union { + double dv; + uint64_t uv; +} du_t; + +#define IMPL_ixx_approxdecomp_from_tndbl_ref(ITYPE) \ + if (res_size != a_size * gadget->ell) NOT_IMPLEMENTED(); \ + const uint64_t ell = gadget->ell; \ + const double add_cst = gadget->add_cst; \ + const uint8_t* const rshifts = gadget->rshifts; \ + const ITYPE and_mask = gadget->and_mask; \ + const ITYPE sub_cst = gadget->sub_cst; \ + ITYPE* rr = res; \ + const double* aa = a; \ + const double* aaend = a + a_size; \ + while (aa < aaend) { \ + du_t t = {.dv = *aa + add_cst}; \ + for (uint64_t i = 0; i < ell; ++i) { \ + ITYPE v = (ITYPE)(t.uv >> rshifts[i]); \ + *rr = (v & and_mask) - sub_cst; \ + ++rr; \ + } \ + ++aa; \ + } + +/** @brief sets res = gadget_decompose(a) (int8_t* output) */ +EXPORT void default_i8_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int8_t* res, uint64_t res_size, // res (in general, size ell.a_size) + const double* a, uint64_t a_size // +){IMPL_ixx_approxdecomp_from_tndbl_ref(int8_t)} + +/** @brief sets res = gadget_decompose(a) (int16_t* output) */ +EXPORT void default_i16_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int16_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +){IMPL_ixx_approxdecomp_from_tndbl_ref(int16_t)} + +/** @brief sets res = gadget_decompose(a) (int32_t* output) */ +EXPORT void default_i32_approxdecomp_from_tndbl_ref(const MOD_Z* module, // N + const TNDBL_APPROXDECOMP_GADGET* gadget, // gadget + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + IMPL_ixx_approxdecomp_from_tndbl_ref(int32_t) +} diff --git a/spqlios/arithmetic/zn_conversions_ref.c b/spqlios/arithmetic/zn_conversions_ref.c new file mode 100644 index 0000000..f016a71 --- /dev/null +++ b/spqlios/arithmetic/zn_conversions_ref.c @@ -0,0 +1,108 @@ +#include + +#include "zn_arithmetic_private.h" + +typedef union { + double dv; + int64_t s64v; + int32_t s32v; + uint64_t u64v; + uint32_t u32v; +} di_t; + +/** reduction mod 1, output in torus32 space */ +EXPORT void dbl_to_tn32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = 0.5 + (double)(INT64_C(3) << (51 - 32)); + static const int32_t XOR_CST = (INT32_C(1) << 31); + const uint64_t msize = res_size < a_size ? 
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = t.s32v ^ XOR_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int32_t)); +} + +/** real centerlift mod 1, output in double space */ +EXPORT void tn32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + static const uint32_t XOR_CST = (UINT32_C(1) << 31); + static const di_t OR_CST = {.dv = (double)(INT64_C(1) << (52 - 32))}; + static const double SUB_CST = 0.5 + (double)(INT64_C(1) << (52 - 32)); + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + uint32_t ai = a[i] ^ XOR_CST; + di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} + +/** round to the nearest int, output in i32 space */ +EXPORT void dbl_round_to_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = (double)((INT64_C(3) << (51)) + (INT64_C(1) << (31))); + static const int32_t XOR_CST = INT32_C(1) << 31; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = t.s32v ^ XOR_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int32_t)); +} + +/** small int (int32 space) to double */ +EXPORT void i32_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int32_t* a, uint64_t a_size // a +) { + static const uint32_t XOR_CST = (UINT32_C(1) << 31); + static const di_t OR_CST = {.dv = (double)(INT64_C(1) << 52)}; + static const double SUB_CST = (double)((INT64_C(1) << 52) + (INT64_C(1) << 31)); + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + uint32_t ai = a[i] ^ XOR_CST; + di_t t = {.u64v = OR_CST.u64v | (uint64_t)ai}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} + +/** round to the nearest int, output in int64 space */ +EXPORT void dbl_round_to_i64_ref(const MOD_Z* module, // + int64_t* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double ADD_CST = (double)(INT64_C(3) << (51)); + static const int64_t AND_CST = (INT64_C(1) << 52) - 1; + static const int64_t SUB_CST = INT64_C(1) << 51; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.dv = a[i] + ADD_CST}; + res[i] = (t.s64v & AND_CST) - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(int64_t)); +} + +/** small int (int64 space) to double */ +EXPORT void i64_to_dbl_ref(const MOD_Z* module, // + double* res, uint64_t res_size, // res + const int64_t* a, uint64_t a_size // a +) { + static const uint64_t ADD_CST = UINT64_C(1) << 51; + static const uint64_t AND_CST = (UINT64_C(1) << 52) - 1; + static const di_t OR_CST = {.dv = (INT64_C(1) << 52)}; + static const double SUB_CST = INT64_C(3) << 51; + const uint64_t msize = res_size < a_size ? 
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + di_t t = {.u64v = ((a[i] + ADD_CST) & AND_CST) | OR_CST.u64v}; + res[i] = t.dv - SUB_CST; + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} diff --git a/spqlios/arithmetic/zn_vmp_int16_avx.c b/spqlios/arithmetic/zn_vmp_int16_avx.c new file mode 100644 index 0000000..563f199 --- /dev/null +++ b/spqlios/arithmetic/zn_vmp_int16_avx.c @@ -0,0 +1,4 @@ +#define INTTYPE int16_t +#define INTSN i16 + +#include "zn_vmp_int32_avx.c" diff --git a/spqlios/arithmetic/zn_vmp_int16_ref.c b/spqlios/arithmetic/zn_vmp_int16_ref.c new file mode 100644 index 0000000..0626c9b --- /dev/null +++ b/spqlios/arithmetic/zn_vmp_int16_ref.c @@ -0,0 +1,4 @@ +#define INTTYPE int16_t +#define INTSN i16 + +#include "zn_vmp_int32_ref.c" diff --git a/spqlios/arithmetic/zn_vmp_int32_avx.c b/spqlios/arithmetic/zn_vmp_int32_avx.c new file mode 100644 index 0000000..3fbc8fb --- /dev/null +++ b/spqlios/arithmetic/zn_vmp_int32_avx.c @@ -0,0 +1,223 @@ +// This file is actually a template: it will be compiled multiple times with +// different INTTYPES +#ifndef INTTYPE +#define INTTYPE int32_t +#define INTSN i32 +#endif + +#include +#include + +#include "zn_arithmetic_private.h" + +#define concat_inner(aa, bb, cc) aa##_##bb##_##cc +#define concat(aa, bb, cc) concat_inner(aa, bb, cc) +#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc) + +static void zn32_vec_mat32cols_avx_prefetch(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b) { + if (nrows == 0) { + memset(res, 0, 32 * sizeof(int32_t)); + return; + } + const int32_t* bb = b; + const int32_t* pref_bb = b; + const uint64_t pref_iters = 128; + const uint64_t pref_start = pref_iters < nrows ? pref_iters : nrows; + const uint64_t pref_last = pref_iters > nrows ? 
0 : nrows - pref_iters; + // let's do some prefetching of the GSW key, since on some cpus, + // it helps + for (uint64_t i = 0; i < pref_start; ++i) { + __builtin_prefetch(pref_bb, 0, _MM_HINT_T0); + __builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0); + pref_bb += 32; + } + // we do the first iteration + __m256i x = _mm256_set1_epi32(a[0]); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + __m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))); + bb += 32; + uint64_t row = 1; + for (; // + row < pref_last; // + ++row, bb += 32) { + // prefetch the next iteration + __builtin_prefetch(pref_bb, 0, _MM_HINT_T0); + __builtin_prefetch(pref_bb + 16, 0, _MM_HINT_T0); + pref_bb += 32; + INTTYPE ai = a[row]; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + for (; // + row < nrows; // + ++row, bb += 32) { + INTTYPE ai = a[row]; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); + _mm256_storeu_si256((__m256i*)(res + 24), r3); +} + +void zn32_vec_fn(mat32cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 32 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + __m256i r3 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + r3 = _mm256_add_epi32(r3, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 24)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); + _mm256_storeu_si256((__m256i*)(res + 24), r3); +} + +void zn32_vec_fn(mat24cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 24 * sizeof(int32_t)); + return; 
+ } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + __m256i r2 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + r2 = _mm256_add_epi32(r2, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 16)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); + _mm256_storeu_si256((__m256i*)(res + 16), r2); +} +void zn32_vec_fn(mat16cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 16 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + __m256i r1 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + r1 = _mm256_add_epi32(r1, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb + 8)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); + _mm256_storeu_si256((__m256i*)(res + 8), r1); +} + +void zn32_vec_fn(mat8cols_avx)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + if (nrows == 0) { + memset(res, 0, 8 * sizeof(int32_t)); + return; + } + const INTTYPE* aa = a; + const INTTYPE* const aaend = a + nrows; + const int32_t* bb = b; + __m256i x = _mm256_set1_epi32(*aa); + __m256i r0 = _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb))); + bb += b_sl; + ++aa; + for (; // + aa < aaend; // + bb += b_sl, ++aa) { + INTTYPE ai = *aa; + if (ai == 0) continue; + x = _mm256_set1_epi32(ai); + r0 = _mm256_add_epi32(r0, _mm256_mullo_epi32(x, _mm256_loadu_si256((__m256i*)(bb)))); + } + _mm256_storeu_si256((__m256i*)(res), r0); +} + +typedef void (*vm_f)(uint64_t nrows, // + int32_t* res, // + const INTTYPE* a, // + const int32_t* b, uint64_t b_sl // +); +static const vm_f zn32_vec_mat8kcols_avx[4] = { // + zn32_vec_fn(mat8cols_avx), // + zn32_vec_fn(mat16cols_avx), // + zn32_vec_fn(mat24cols_avx), // + zn32_vec_fn(mat32cols_avx)}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void concat(default_zn32_vmp_apply, INTSN, avx)( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const INTTYPE* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? a_size : nrows; + const uint64_t cols = res_size < ncols ? 
res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint64_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_mat32cols_avx_prefetch(rows, rr, a, mat); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem; + int32_t tmp[32]; + zn32_vec_mat8kcols_avx[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} diff --git a/spqlios/arithmetic/zn_vmp_int32_ref.c b/spqlios/arithmetic/zn_vmp_int32_ref.c new file mode 100644 index 0000000..c3d0bc9 --- /dev/null +++ b/spqlios/arithmetic/zn_vmp_int32_ref.c @@ -0,0 +1,88 @@ +// This file is actually a template: it will be compiled multiple times with +// different INTTYPES +#ifndef INTTYPE +#define INTTYPE int32_t +#define INTSN i32 +#endif + +#include + +#include "zn_arithmetic_private.h" + +#define concat_inner(aa, bb, cc) aa##_##bb##_##cc +#define concat(aa, bb, cc) concat_inner(aa, bb, cc) +#define zn32_vec_fn(cc) concat(zn32_vec, INTSN, cc) + +// the ref version shares the same implementation for each fixed column size +// optimized implementations may do something different. +static __always_inline void IMPL_zn32_vec_matcols_ref( + const uint64_t NCOLS, // fixed number of columns + uint64_t nrows, // nrows of b + int32_t* res, // result: size NCOLS, only the first min(b_sl, NCOLS) are relevant + const INTTYPE* a, // a: nrows-sized vector + const int32_t* b, uint64_t b_sl // b: nrows * min(b_sl, NCOLS) matrix +) { + memset(res, 0, NCOLS * sizeof(int32_t)); + for (uint64_t row = 0; row < nrows; ++row) { + int32_t ai = a[row]; + const int32_t* bb = b + row * b_sl; + for (uint64_t i = 0; i < NCOLS; ++i) { + res[i] += ai * bb[i]; + } + } +} + +void zn32_vec_fn(mat32cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(32, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat24cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(24, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat16cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(16, nrows, res, a, b, b_sl); +} +void zn32_vec_fn(mat8cols_ref)(uint64_t nrows, int32_t* res, const INTTYPE* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_matcols_ref(8, nrows, res, a, b, b_sl); +} + +typedef void (*vm_f)(uint64_t nrows, // + int32_t* res, // + const INTTYPE* a, // + const int32_t* b, uint64_t b_sl // +); +static const vm_f zn32_vec_mat8kcols_ref[4] = { // + zn32_vec_fn(mat8cols_ref), // + zn32_vec_fn(mat16cols_ref), // + zn32_vec_fn(mat24cols_ref), // + zn32_vec_fn(mat32cols_ref)}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void concat(default_zn32_vmp_apply, INTSN, ref)( // + const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const INTTYPE* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? a_size : nrows; + const uint64_t cols = res_size < ncols ? 
res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint32_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_fn(mat32cols_ref)(rows, rr, a, mat, 32); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 32 : orig_rem; + int32_t tmp[32]; + zn32_vec_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} diff --git a/spqlios/arithmetic/zn_vmp_int8_avx.c b/spqlios/arithmetic/zn_vmp_int8_avx.c new file mode 100644 index 0000000..74480aa --- /dev/null +++ b/spqlios/arithmetic/zn_vmp_int8_avx.c @@ -0,0 +1,4 @@ +#define INTTYPE int8_t +#define INTSN i8 + +#include "zn_vmp_int32_avx.c" diff --git a/spqlios/arithmetic/zn_vmp_int8_ref.c b/spqlios/arithmetic/zn_vmp_int8_ref.c new file mode 100644 index 0000000..d1de571 --- /dev/null +++ b/spqlios/arithmetic/zn_vmp_int8_ref.c @@ -0,0 +1,4 @@ +#define INTTYPE int8_t +#define INTSN i8 + +#include "zn_vmp_int32_ref.c" diff --git a/spqlios/arithmetic/zn_vmp_ref.c b/spqlios/arithmetic/zn_vmp_ref.c new file mode 100644 index 0000000..d75dca2 --- /dev/null +++ b/spqlios/arithmetic/zn_vmp_ref.c @@ -0,0 +1,138 @@ +#include + +#include "zn_arithmetic_private.h" + +/** @brief size in bytes of a prepared matrix (for custom allocation) */ +EXPORT uint64_t bytes_of_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols // dimensions +) { + return (nrows * ncols + 7) * sizeof(int32_t); +} + +/** @brief allocates a prepared matrix (release with delete_zn32_vmp_pmat) */ +EXPORT ZN32_VMP_PMAT* new_zn32_vmp_pmat(const MOD_Z* module, // N + uint64_t nrows, uint64_t ncols) { + return (ZN32_VMP_PMAT*)spqlios_alloc(bytes_of_zn32_vmp_pmat(module, nrows, ncols)); +} + +/** @brief deletes a prepared matrix (release with free) */ +EXPORT void delete_zn32_vmp_pmat(ZN32_VMP_PMAT* ptr) { spqlios_free(ptr); } + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void default_zn32_vmp_prepare_contiguous_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols // a +) { + int32_t* const out = (int32_t*)pmat; + const uint64_t nblk = ncols >> 5; + const uint64_t ncols_rem = ncols & 31; + const uint64_t final_elems = (8 - nrows * ncols) & 7; + for (uint64_t blk = 0; blk < nblk; ++blk) { + int32_t* outblk = out + blk * nrows * 32; + const int32_t* srcblk = mat + blk * 32; + for (uint64_t row = 0; row < nrows; ++row) { + int32_t* dest = outblk + row * 32; + const int32_t* src = srcblk + row * ncols; + for (uint64_t i = 0; i < 32; ++i) { + dest[i] = src[i]; + } + } + } + // copy the last block if any + if (ncols_rem) { + int32_t* outblk = out + nblk * nrows * 32; + const int32_t* srcblk = mat + nblk * 32; + for (uint64_t row = 0; row < nrows; ++row) { + int32_t* dest = outblk + row * ncols_rem; + const int32_t* src = srcblk + row * ncols; + for (uint64_t i = 0; i < ncols_rem; ++i) { + dest[i] = src[i]; + } + } + } + // zero-out the final elements that may be accessed + if (final_elems) { + int32_t* f = out + nrows * ncols; + for (uint64_t i = 0; i < final_elems; ++i) { + f[i] = 0; + } + } +} + +#if 0 + +#define IMPL_zn32_vec_ixxx_matyyycols_ref(NCOLS) \ + memset(res, 0, 
NCOLS * sizeof(int32_t)); \ + for (uint64_t row = 0; row < nrows; ++row) { \ + int32_t ai = a[row]; \ + const int32_t* bb = b + row * b_sl; \ + for (uint64_t i = 0; i < NCOLS; ++i) { \ + res[i] += ai * bb[i]; \ + } \ + } + +#define IMPL_zn32_vec_ixxx_mat8cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(8) +#define IMPL_zn32_vec_ixxx_mat16cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(16) +#define IMPL_zn32_vec_ixxx_mat24cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(24) +#define IMPL_zn32_vec_ixxx_mat32cols_ref() IMPL_zn32_vec_ixxx_matyyycols_ref(32) + +void zn32_vec_i8_mat32cols_ref(uint64_t nrows, int32_t* res, const int8_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} +void zn32_vec_i16_mat32cols_ref(uint64_t nrows, int32_t* res, const int16_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} + +void zn32_vec_i32_mat32cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat32cols_ref() +} +void zn32_vec_i32_mat24cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat24cols_ref() +} +void zn32_vec_i32_mat16cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat16cols_ref() +} +void zn32_vec_i32_mat8cols_ref(uint64_t nrows, int32_t* res, const int32_t* a, const int32_t* b, uint64_t b_sl) { + IMPL_zn32_vec_ixxx_mat8cols_ref() +} +typedef void (*zn32_vec_i32_mat8kcols_ref_f)(uint64_t nrows, // + int32_t* res, // + const int32_t* a, // + const int32_t* b, uint64_t b_sl // +); +zn32_vec_i32_mat8kcols_ref_f zn32_vec_i32_mat8kcols_ref[4] = { // + zn32_vec_i32_mat8cols_ref, zn32_vec_i32_mat16cols_ref, // + zn32_vec_i32_mat24cols_ref, zn32_vec_i32_mat32cols_ref}; + +/** @brief applies a vmp product (int32_t* input) */ +EXPORT void default_zn32_vmp_apply_i32_ref(const MOD_Z* module, // + int32_t* res, uint64_t res_size, // + const int32_t* a, uint64_t a_size, // + const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { + const uint64_t rows = a_size < nrows ? a_size : nrows; + const uint64_t cols = res_size < ncols ? res_size : ncols; + const uint64_t ncolblk = cols >> 5; + const uint64_t ncolrem = cols & 31; + // copy the first full blocks + const uint32_t full_blk_size = nrows * 32; + const int32_t* mat = (int32_t*)pmat; + int32_t* rr = res; + for (uint64_t blk = 0; // + blk < ncolblk; // + ++blk, mat += full_blk_size, rr += 32) { + zn32_vec_i32_mat32cols_ref(rows, rr, a, mat, 32); + } + // last block + if (ncolrem) { + uint64_t orig_rem = ncols - (ncolblk << 5); + uint64_t b_sl = orig_rem >= 32 ? 
32 : orig_rem; + int32_t tmp[32]; + zn32_vec_i32_mat8kcols_ref[(ncolrem - 1) >> 3](rows, tmp, a, mat, b_sl); + memcpy(rr, tmp, ncolrem * sizeof(int32_t)); + } + // trailing bytes + memset(res + cols, 0, (res_size - cols) * sizeof(int32_t)); +} + +#endif From 4d9721df06a031c7828e942fe366228d9aeebcb9 Mon Sep 17 00:00:00 2001 From: georgiev Date: Tue, 6 Aug 2024 13:38:03 +0200 Subject: [PATCH 03/11] Add vec rnx arithmetic and approximate decomposition for both ref and avx --- spqlios/CMakeLists.txt | 6 +- spqlios/arithmetic/vec_rnx_approxdecomp_avx.c | 59 +++++ spqlios/arithmetic/vec_rnx_approxdecomp_ref.c | 75 ++++++ spqlios/arithmetic/vec_rnx_arithmetic.c | 223 ++++++++++++++++++ spqlios/arithmetic/vec_rnx_arithmetic.h | 2 +- spqlios/arithmetic/vec_rnx_arithmetic_avx.c | 189 +++++++++++++++ .../arithmetic/vec_rnx_arithmetic_private.h | 4 +- 7 files changed, 554 insertions(+), 4 deletions(-) create mode 100644 spqlios/arithmetic/vec_rnx_approxdecomp_avx.c create mode 100644 spqlios/arithmetic/vec_rnx_approxdecomp_ref.c create mode 100644 spqlios/arithmetic/vec_rnx_arithmetic.c create mode 100644 spqlios/arithmetic/vec_rnx_arithmetic_avx.c diff --git a/spqlios/CMakeLists.txt b/spqlios/CMakeLists.txt index a4c0c6d..c5b14c2 100644 --- a/spqlios/CMakeLists.txt +++ b/spqlios/CMakeLists.txt @@ -42,6 +42,8 @@ set(SRCS_GENERIC reim/reim_execute.c cplx/cplx_execute.c reim4/reim4_execute.c + arithmetic/vec_rnx_arithmetic.c + arithmetic/vec_rnx_approxdecomp_ref.c ) # C or assembly source files compiled only on x86 targets set(SRCS_X86 @@ -107,7 +109,9 @@ set(SRCS_AVX2 arithmetic/zn_vmp_int32_avx.c q120/q120_arithmetic_avx2.c q120/q120_ntt_avx2.c - ) + arithmetic/vec_rnx_arithmetic_avx.c + arithmetic/vec_rnx_approxdecomp_avx.c +) set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2") # C source files on float128 via libquadmath on x86 targets targets diff --git a/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c b/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c new file mode 100644 index 0000000..2acda14 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_approxdecomp_avx.c @@ -0,0 +1,59 @@ +#include + +#include "immintrin.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnxdbl_avx( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a // a +) { + const uint64_t nn = module->n; + if (nn < 4) return rnx_approxdecomp_from_tnxdbl_ref(module, gadget, res, res_size, res_sl, a); + const uint64_t ell = gadget->ell; + const __m256i k = _mm256_set1_epi64x(gadget->k); + const __m256d add_cst = _mm256_set1_pd(gadget->add_cst); + const __m256i and_mask = _mm256_set1_epi64x(gadget->and_mask); + const __m256i or_mask = _mm256_set1_epi64x(gadget->or_mask); + const __m256d sub_cst = _mm256_set1_pd(gadget->sub_cst); + const uint64_t msize = res_size <= ell ? 
res_size : ell; + // gadget decompose column by column + if (msize == ell) { + // this is the main scenario when msize == ell + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; j += 4) { + double* rr = last_r + j; + const double* aa = a + j; + __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst); + __m256i t_int = _mm256_castpd_si256(t_dbl); + do { + __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask); + _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst)); + t_int = _mm256_srlv_epi64(t_int, k); + rr -= res_sl; + } while (rr >= res); + } + } else if (msize > 0) { + // otherwise, if msize < ell: there is one additional rshift + const __m256i first_rsh = _mm256_set1_epi64x((ell - msize) * gadget->k); + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; j += 4) { + double* rr = last_r + j; + const double* aa = a + j; + __m256d t_dbl = _mm256_add_pd(_mm256_loadu_pd(aa), add_cst); + __m256i t_int = _mm256_srlv_epi64(_mm256_castpd_si256(t_dbl), first_rsh); + do { + __m256i u_int = _mm256_or_si256(_mm256_and_si256(t_int, and_mask), or_mask); + _mm256_storeu_pd(rr, _mm256_sub_pd(_mm256_castsi256_pd(u_int), sub_cst)); + t_int = _mm256_srlv_epi64(t_int, k); + rr -= res_sl; + } while (rr >= res); + } + } + // zero-out the last slices (if any) + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c b/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c new file mode 100644 index 0000000..eab2d12 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_approxdecomp_ref.c @@ -0,0 +1,75 @@ +#include + +#include "vec_rnx_arithmetic_private.h" + +typedef union di { + double dv; + uint64_t uv; +} di_t; + +/** @brief new gadget: delete with delete_tnxdbl_approxdecomp_gadget */ +EXPORT TNXDBL_APPROXDECOMP_GADGET* new_tnxdbl_approxdecomp_gadget( // + const MOD_RNX* module, // N + uint64_t k, uint64_t ell // base 2^K and size +) { + if (k * ell > 50) return spqlios_error("gadget requires a too large fp precision"); + TNXDBL_APPROXDECOMP_GADGET* res = spqlios_alloc(sizeof(TNXDBL_APPROXDECOMP_GADGET)); + res->k = k; + res->ell = ell; + // double add_cst; // double(3.2^(51-ell.K) + 1/2.(sum 2^(-iK)) for i=[0,ell[) + union di add_cst; + add_cst.dv = UINT64_C(3) << (51 - ell * k); + for (uint64_t i = 0; i < ell; ++i) { + add_cst.uv |= UINT64_C(1) << ((i + 1) * k - 1); + } + res->add_cst = add_cst.dv; + // uint64_t and_mask; // uint64(2^(K)-1) + res->and_mask = (UINT64_C(1) << k) - 1; + // uint64_t or_mask; // double(2^52) + union di or_mask; + or_mask.dv = (UINT64_C(1) << 52); + res->or_mask = or_mask.uv; + // double sub_cst; // double(2^52 + 2^(K-1)) + res->sub_cst = ((UINT64_C(1) << 52) + (UINT64_C(1) << (k - 1))); + return res; +} + +EXPORT void delete_tnxdbl_approxdecomp_gadget(TNXDBL_APPROXDECOMP_GADGET* gadget) { spqlios_free(gadget); } + +/** @brief sets res = gadget_decompose(a) */ +EXPORT void rnx_approxdecomp_from_tnxdbl_ref( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a // a +) { + const uint64_t nn = module->n; + const uint64_t k = gadget->k; + const uint64_t ell = gadget->ell; + const double add_cst = gadget->add_cst; + const uint64_t and_mask = gadget->and_mask; + const uint64_t or_mask = gadget->or_mask; + const double sub_cst = gadget->sub_cst; + const uint64_t msize = 
res_size <= ell ? res_size : ell; + const uint64_t first_rsh = (ell - msize) * k; + // gadget decompose column by column + if (msize > 0) { + double* const last_r = res + (msize - 1) * res_sl; + for (uint64_t j = 0; j < nn; ++j) { + double* rr = last_r + j; + di_t t = {.dv = a[j] + add_cst}; + if (msize < ell) t.uv >>= first_rsh; + do { + di_t u; + u.uv = (t.uv & and_mask) | or_mask; + *rr = u.dv - sub_cst; + t.uv >>= k; + rr -= res_sl; + } while (rr >= res); + } + } + // zero-out the last slices (if any) + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/arithmetic/vec_rnx_arithmetic.c b/spqlios/arithmetic/vec_rnx_arithmetic.c new file mode 100644 index 0000000..eb56899 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_arithmetic.c @@ -0,0 +1,223 @@ +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +void rnx_add_ref(uint64_t nn, double* res, const double* a, const double* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] + b[i]; + } +} + +void rnx_sub_ref(uint64_t nn, double* res, const double* a, const double* b) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = a[i] - b[i]; + } +} + +void rnx_negate_ref(uint64_t nn, double* res, const double* a) { + for (uint64_t i = 0; i < nn; ++i) { + res[i] = -a[i]; + } +} + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +) { + const uint64_t nn = module->n; + for (uint64_t i = 0; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? 
res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + memcpy(res_ptr, a_ptr, nn * sizeof(double)); + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + rnx_negate_ref(nn, res_ptr, a_ptr); + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + rnx_negate_ref(nn, res + i * res_sl, b + i * b_sl); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = a . X^p */ +EXPORT void vec_rnx_rotate_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_rotate_inplace_f64(nn, p, res_ptr); + } else { + rnx_rotate_f64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism_ref( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? 
res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_automorphism_inplace_f64(nn, p, res_ptr); + } else { + rnx_automorphism_f64(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a . (X^p - 1) */ +EXPORT void vec_rnx_mul_xp_minus_one_ref( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + + const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size; + // rotate up to the smallest dimension + for (uint64_t i = 0; i < rot_end_idx; ++i) { + double* res_ptr = res + i * res_sl; + const double* a_ptr = a + i * a_sl; + if (res_ptr == a_ptr) { + rnx_mul_xp_minus_one_inplace(nn, p, res_ptr); + } else { + rnx_mul_xp_minus_one(nn, p, res_ptr, a_ptr); + } + } + // then extend with zeros + for (uint64_t i = rot_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/arithmetic/vec_rnx_arithmetic.h b/spqlios/arithmetic/vec_rnx_arithmetic.h index 65ef889..16a5e6d 100644 --- a/spqlios/arithmetic/vec_rnx_arithmetic.h +++ b/spqlios/arithmetic/vec_rnx_arithmetic.h @@ -3,7 +3,7 @@ #include -#include "spqlios/commons.h" +#include "../commons.h" /** * We support the following module families: diff --git a/spqlios/arithmetic/vec_rnx_arithmetic_avx.c b/spqlios/arithmetic/vec_rnx_arithmetic_avx.c new file mode 100644 index 0000000..04b3ec0 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_arithmetic_avx.c @@ -0,0 +1,189 @@ +#include +#include + +#include "vec_rnx_arithmetic_private.h" + +void rnx_add_avx(uint64_t nn, double* res, const double* a, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_add_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b))); + } else if (nn == 2) { + _mm_storeu_pd(res, _mm_add_pd(_mm_loadu_pd(a), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = *a + *b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x0, x1, x2, x3, x4, x5; + const double* aa = a; + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x0 = _mm256_loadu_pd(aa); + x1 = _mm256_loadu_pd(aa + 4); + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_add_pd(x0, x2); + x5 = _mm256_add_pd(x1, x3); + _mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + aa += 8; + bb += 8; + rr += 8; + } while (rr < rrend); +} + +void rnx_sub_avx(uint64_t nn, double* res, const double* a, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_sub_pd(_mm256_loadu_pd(a), _mm256_loadu_pd(b))); + } else if (nn == 2) { + _mm_storeu_pd(res, _mm_sub_pd(_mm_loadu_pd(a), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = *a - *b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x0, x1, x2, x3, x4, x5; + const double* aa = a; + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x0 = _mm256_loadu_pd(aa); + x1 = _mm256_loadu_pd(aa + 4); + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_sub_pd(x0, x2); + x5 = _mm256_sub_pd(x1, x3); + 
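// store both 4-double lanes of the difference back to res +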
_mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + aa += 8; + bb += 8; + rr += 8; + } while (rr < rrend); +} + +void rnx_negate_avx(uint64_t nn, double* res, const double* b) { + if (nn < 8) { + if (nn == 4) { + _mm256_storeu_pd(res, _mm256_sub_pd(_mm256_set1_pd(0), _mm256_loadu_pd(b))); + } else if (nn == 2) { + _mm_storeu_pd(res, _mm_sub_pd(_mm_set1_pd(0), _mm_loadu_pd(b))); + } else if (nn == 1) { + *res = -*b; + } else { + NOT_SUPPORTED(); // not a power of 2 + } + return; + } + // general case: nn >= 8 + __m256d x2, x3, x4, x5; + const __m256d ZERO = _mm256_set1_pd(0); + const double* bb = b; + double* rr = res; + double* const rrend = res + nn; + do { + x2 = _mm256_loadu_pd(bb); + x3 = _mm256_loadu_pd(bb + 4); + x4 = _mm256_sub_pd(ZERO, x2); + x5 = _mm256_sub_pd(ZERO, x3); + _mm256_storeu_pd(rr, x4); + _mm256_storeu_pd(rr + 4, x5); + bb += 8; + rr += 8; + } while (rr < rrend); +} + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, b + i * b_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_add_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_negate_avx(nn, res + i * res_sl, a + i * a_sl); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + const uint64_t nn = module->n; + if (a_size < b_size) { + const uint64_t msize = res_size < a_size ? res_size : a_size; + const uint64_t nsize = res_size < b_size ? res_size : b_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + rnx_negate_avx(nn, res + i * res_sl, b + i * b_sl); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } else { + const uint64_t msize = res_size < b_size ? 
res_size : b_size; + const uint64_t nsize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + rnx_sub_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl); + } + for (uint64_t i = msize; i < nsize; ++i) { + memcpy(res + i * res_sl, a + i * a_sl, nn * sizeof(double)); + } + for (uint64_t i = nsize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } + } +} diff --git a/spqlios/arithmetic/vec_rnx_arithmetic_private.h b/spqlios/arithmetic/vec_rnx_arithmetic_private.h index b761eb7..59a4cf8 100644 --- a/spqlios/arithmetic/vec_rnx_arithmetic_private.h +++ b/spqlios/arithmetic/vec_rnx_arithmetic_private.h @@ -1,8 +1,8 @@ #ifndef SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H #define SPQLIOS_VEC_RNX_ARITHMETIC_PRIVATE_H -#include "spqlios/commons_private.h" -#include "spqlios/reim/reim_fft.h" +#include "../commons_private.h" +#include "../reim/reim_fft.h" #include "vec_rnx_arithmetic.h" #include "vec_rnx_arithmetic_plugin.h" From 571a6d92bf075599003891be37f80992e51d4e12 Mon Sep 17 00:00:00 2001 From: Sandra Guasch Date: Wed, 7 Aug 2024 13:52:59 +0000 Subject: [PATCH 04/11] add vmp --- spqlios/CMakeLists.txt | 3 + spqlios/arithmetic/vec_rnx_vmp_avx.c | 196 +++++++++++++++++++++ spqlios/arithmetic/vec_rnx_vmp_ref.c | 251 +++++++++++++++++++++++++++ 3 files changed, 450 insertions(+) create mode 100644 spqlios/arithmetic/vec_rnx_vmp_avx.c create mode 100644 spqlios/arithmetic/vec_rnx_vmp_ref.c diff --git a/spqlios/CMakeLists.txt b/spqlios/CMakeLists.txt index c5b14c2..1738c44 100644 --- a/spqlios/CMakeLists.txt +++ b/spqlios/CMakeLists.txt @@ -44,6 +44,7 @@ set(SRCS_GENERIC reim4/reim4_execute.c arithmetic/vec_rnx_arithmetic.c arithmetic/vec_rnx_approxdecomp_ref.c + arithmetic/vec_rnx_vmp_ref.c ) # C or assembly source files compiled only on x86 targets set(SRCS_X86 @@ -111,6 +112,8 @@ set(SRCS_AVX2 q120/q120_ntt_avx2.c arithmetic/vec_rnx_arithmetic_avx.c arithmetic/vec_rnx_approxdecomp_avx.c + arithmetic/vec_rnx_vmp_avx.c + ) set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2") diff --git a/spqlios/arithmetic/vec_rnx_vmp_avx.c b/spqlios/arithmetic/vec_rnx_vmp_avx.c new file mode 100644 index 0000000..4c1b23d --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_vmp_avx.c @@ -0,0 +1,196 @@ +#include +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim/reim_fft.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_avx(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 
2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_avx(nn, m, res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } + } +} + +/** @brief minimal size of the tmp_space */ +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->n; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? ncols : res_size; + + if (row_max > 0 && col_max > 0) { + if (nn >= 8) { + // let's do some prefetching of the GSW key, since on some cpus, + // it helps + const uint64_t ms4 = m >> 2; // m/4 + const uint64_t gsw_iter_doubles = 8 * nrows * ncols; + const uint64_t pref_doubles = 1200; + const double* gsw_pref_ptr = mat_input; + const double* const gsw_ptr_end = mat_input + ms4 * gsw_iter_doubles; + const double* gsw_pref_ptr_target = mat_input + pref_doubles; + for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) { + __builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0); + } + const double* mat_blk_start; + uint64_t blk_i; + for (blk_i = 0, mat_blk_start = mat_input; blk_i < ms4; blk_i++, mat_blk_start += gsw_iter_doubles) { + // prefetch the next iteration + if (gsw_pref_ptr_target < gsw_ptr_end) { + gsw_pref_ptr_target += gsw_iter_doubles; + if (gsw_pref_ptr_target > gsw_ptr_end) gsw_pref_ptr_target = gsw_ptr_end; + for (; gsw_pref_ptr < gsw_pref_ptr_target; gsw_pref_ptr += 8) { + __builtin_prefetch(gsw_pref_ptr, 0, _MM_HINT_T0); + } + } + reim4_extract_1blk_from_contiguous_reim_sl_avx(m, a_sl, row_max, blk_i, extracted_blk, a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_avx(m, blk_i, res + col_i * res_sl, mat2cols_output); + reim4_save_1blk_to_reim_avx(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_avx2(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_avx(m, blk_i, res + last_col * res_sl, 
mat2cols_output); + } + } + } else { + const double* in; + uint64_t in_sl; + if (res == a_dft) { + // it is in place: copy the input vector + in = (double*)tmp_space; + in_sl = nn; + // vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl); + for (uint64_t row_i = 0; row_i < row_max; row_i++) { + memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double)); + } + } else { + // it is out of place: do the product directly + in = a_dft; + in_sl = a_sl; + } + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + { + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, // + res + col_i * res_sl, // + in, // + pmat_col); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, // + res + col_i * res_sl, // + in + row_i * in_sl, // + pmat_col + row_i * nn); + } + } + } + } + // zero out remaining bytes (if any) + for (uint64_t i = col_max; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void fft64_rnx_vmp_apply_tmp_a_avx( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->n; + const uint64_t rows = nrows < a_size ? nrows : a_size; + const uint64_t cols = ncols < res_size ? ncols : res_size; + + // fft is done in place on the input (tmpa is destroyed) + for (uint64_t i = 0; i < rows; ++i) { + reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl); + } + fft64_rnx_vmp_apply_dft_to_dft_avx(module, // + res, cols, res_sl, // + tmpa, rows, a_sl, // + pmat, nrows, ncols, // + tmp_space); + // ifft is done in place on the output + for (uint64_t i = 0; i < cols; ++i) { + reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl); + } + // zero out the remaining positions + for (uint64_t i = cols; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} diff --git a/spqlios/arithmetic/vec_rnx_vmp_ref.c b/spqlios/arithmetic/vec_rnx_vmp_ref.c new file mode 100644 index 0000000..de14ba8 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_vmp_ref.c @@ -0,0 +1,251 @@ +#include +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "../reim/reim_fft.h" +#include "../reim4/reim4_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t fft64_bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return nrows * ncols * module->n * sizeof(double); +} + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + 
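// divide the (row_i, col_i) polynomial by m and take its FFT before splitting it into reim4 blocks +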
rnx_divide_by_m_ref(nn, m, dtmp, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } + } else { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_ref(nn, m, res, mat + (row_i * ncols + col_i) * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } + } +} + +/** @brief number of scratch bytes necessary to prepare a matrix */ +EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module) { + const uint64_t nn = module->n; + return nn * sizeof(int64_t); +} + +/** @brief minimal size of the tmp_space */ +EXPORT void fft64_rnx_vmp_apply_dft_to_dft_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + const uint64_t m = module->m; + const uint64_t nn = module->n; + + double* mat2cols_output = (double*)tmp_space; // 128 bytes + double* extracted_blk = (double*)tmp_space + 16; // 64*min(nrows,a_size) bytes + + double* mat_input = (double*)pmat; + + const uint64_t row_max = nrows < a_size ? nrows : a_size; + const uint64_t col_max = ncols < res_size ? 
ncols : res_size; + + if (row_max > 0 && col_max > 0) { + if (nn >= 8) { + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + double* mat_blk_start = mat_input + blk_i * (8 * nrows * ncols); + + reim4_extract_1blk_from_contiguous_reim_sl_ref(m, a_sl, row_max, blk_i, extracted_blk, a_dft); + // apply mat2cols + for (uint64_t col_i = 0; col_i < col_max - 1; col_i += 2) { + uint64_t col_offset = col_i * (8 * nrows); + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + + reim4_save_1blk_to_reim_ref(m, blk_i, res + col_i * res_sl, mat2cols_output); + reim4_save_1blk_to_reim_ref(m, blk_i, res + (col_i + 1) * res_sl, mat2cols_output + 8); + } + + // check if col_max is odd, then special case + if (col_max % 2 == 1) { + uint64_t last_col = col_max - 1; + uint64_t col_offset = last_col * (8 * nrows); + + // the last column is alone in the pmat: vec_mat1col + if (ncols == col_max) { + reim4_vec_mat1col_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } else { + // the last column is part of a colpair in the pmat: vec_mat2cols and ignore the second position + reim4_vec_mat2cols_product_ref(row_max, mat2cols_output, extracted_blk, mat_blk_start + col_offset); + } + reim4_save_1blk_to_reim_ref(m, blk_i, res + last_col * res_sl, mat2cols_output); + } + } + } else { + const double* in; + uint64_t in_sl; + if (res == a_dft) { + // it is in place: copy the input vector + in = (double*)tmp_space; + in_sl = nn; + // vec_rnx_copy(module, (double*)tmp_space, row_max, nn, a_dft, row_max, a_sl); + for (uint64_t row_i = 0; row_i < row_max; row_i++) { + memcpy((double*)tmp_space + row_i * nn, a_dft + row_i * a_sl, nn * sizeof(double)); + } + } else { + // it is out of place: do the product directly + in = a_dft; + in_sl = a_sl; + } + for (uint64_t col_i = 0; col_i < col_max; col_i++) { + double* pmat_col = mat_input + col_i * nrows * nn; + { + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, // + res + col_i * res_sl, // + in, // + pmat_col); + } + for (uint64_t row_i = 1; row_i < row_max; row_i++) { + reim_fftvec_addmul(module->precomp.fft64.p_fftvec_addmul, // + res + col_i * res_sl, // + in + row_i * in_sl, // + pmat_col + row_i * nn); + } + } + } + } + // zero out remaining bytes (if any) + for (uint64_t i = col_max; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res (addr must be != a) + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + const uint64_t nn = module->n; + const uint64_t rows = nrows < a_size ? nrows : a_size; + const uint64_t cols = ncols < res_size ? 
ncols : res_size; + + // fft is done in place on the input (tmpa is destroyed) + for (uint64_t i = 0; i < rows; ++i) { + reim_fft(module->precomp.fft64.p_fft, tmpa + i * a_sl); + } + fft64_rnx_vmp_apply_dft_to_dft_ref(module, // + res, cols, res_sl, // + tmpa, rows, a_sl, // + pmat, nrows, ncols, // + tmp_space); + // ifft is done in place on the output + for (uint64_t i = 0; i < cols; ++i) { + reim_ifft(module->precomp.fft64.p_ifft, res + i * res_sl); + } + // zero out the remaining positions + for (uint64_t i = cols; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + const uint64_t row_max = nrows < a_size ? nrows : a_size; + + return (128) + (64 * row_max); +} + +#ifdef __APPLE__ +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref(module, res_size, a_size, nrows, ncols); +} +#else +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif +// avx aliases that need to be defined in the same .c file + +/** @brief number of scratch bytes necessary to prepare a matrix */ +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module) + __attribute((alias("fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref"))); +#endif + +/** @brief minimal size of the tmp_space */ +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif + +#ifdef __APPLE__ +#pragma weak fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref +#else +EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix + ) __attribute((alias("fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref"))); +#endif +// wrappers From 5ab2272b02c2cef74703b285c42f1871a4e0ad37 Mon Sep 17 00:00:00 2001 From: Maurice Shih Date: Wed, 7 Aug 2024 12:20:08 -0700 Subject: [PATCH 05/11] added rnx_conversions_ref and rnx_svp_ref --- spqlios/arithmetic/vec_rnx_api.c | 318 +++++++++++++++++++ spqlios/arithmetic/vec_rnx_conversions_ref.c | 91 ++++++ spqlios/arithmetic/vec_rnx_svp_ref.c | 47 +++ 3 files changed, 456 insertions(+) create mode 100644 spqlios/arithmetic/vec_rnx_api.c create mode 100644 spqlios/arithmetic/vec_rnx_conversions_ref.c create mode 100644 spqlios/arithmetic/vec_rnx_svp_ref.c diff --git a/spqlios/arithmetic/vec_rnx_api.c b/spqlios/arithmetic/vec_rnx_api.c new file mode 100644 index 0000000..0f396fb --- /dev/null 
+++ b/spqlios/arithmetic/vec_rnx_api.c @@ -0,0 +1,318 @@ +#include + +#include "vec_rnx_arithmetic_private.h" + +void fft64_init_rnx_module_precomp(MOD_RNX* module) { + // Add here initialization of items that are in the precomp + const uint64_t m = module->m; + module->precomp.fft64.p_fft = new_reim_fft_precomp(m, 0); + module->precomp.fft64.p_ifft = new_reim_ifft_precomp(m, 0); + module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(m); + module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(m); +} + +void fft64_finalize_rnx_module_precomp(MOD_RNX* module) { + // Add here deleters for items that are in the precomp + delete_reim_fft_precomp(module->precomp.fft64.p_fft); + delete_reim_ifft_precomp(module->precomp.fft64.p_ifft); + delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul); + delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul); +} + +void fft64_init_rnx_module_vtable(MOD_RNX* module) { + // Add function pointers here + module->vtable.vec_rnx_add = vec_rnx_add_ref; + module->vtable.vec_rnx_zero = vec_rnx_zero_ref; + module->vtable.vec_rnx_copy = vec_rnx_copy_ref; + module->vtable.vec_rnx_negate = vec_rnx_negate_ref; + module->vtable.vec_rnx_sub = vec_rnx_sub_ref; + module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref; + module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref; + module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref; + module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref; + module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref; + module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref; + module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref; + module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref; + module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref; + module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat; + module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref; + module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref; + module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref; + module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref; + module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref; + module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref; + module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol; + module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref; + module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref; + + // Add optimized function pointers here + if (CPU_SUPPORTS("avx")) { + module->vtable.vec_rnx_add = vec_rnx_add_avx; + module->vtable.vec_rnx_sub = vec_rnx_sub_avx; + module->vtable.vec_rnx_negate = vec_rnx_negate_avx; + module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx; + module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx; + module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx; + module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx; + module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx; + module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx; + module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx; + } +} + +void init_rnx_module_info(MOD_RNX* module, // + uint64_t n, RNX_MODULE_TYPE mtype) { + 
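From a caller's perspective the vtable filled in above is invisible: new_rnx_module_info builds the FFT precomputations and picks the ref or AVX entry points once, and every public wrapper then dispatches through module->vtable. A minimal sketch with illustrative names, assuming the usual contiguous layout where the stride between consecutive polynomials (the *_sl parameters) equals the ring dimension n:

    #include <stdint.h>
    #include <stdlib.h>
    #include "spqlios/arithmetic/vec_rnx_arithmetic.h"

    void add_two_vectors(void) {
      const uint64_t n = 64, size = 3, sl = n;  /* contiguous: stride == n doubles */
      MOD_RNX* mod = new_rnx_module_info(n, FFT64);
      double* a = (double*)calloc(size * sl, sizeof(double));
      double* b = (double*)calloc(size * sl, sizeof(double));
      double* r = (double*)calloc(size * sl, sizeof(double));
      /* dispatches to vec_rnx_add_ref, or vec_rnx_add_avx when the CPU supports AVX */
      vec_rnx_add(mod, r, size, sl, a, size, sl, b, size, sl);
      free(a); free(b); free(r);
      delete_rnx_module_info(mod);
    }
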
memset(module, 0, sizeof(MOD_RNX)); + module->n = n; + module->m = n >> 1; + module->mtype = mtype; + switch (mtype) { + case FFT64: + fft64_init_rnx_module_precomp(module); + fft64_init_rnx_module_vtable(module); + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +void finalize_rnx_module_info(MOD_RNX* module) { + if (module->custom) module->custom_deleter(module->custom); + switch (module->mtype) { + case FFT64: + fft64_finalize_rnx_module_precomp(module); + // fft64_finalize_rnx_module_vtable(module); // nothing to finalize + break; + default: + NOT_SUPPORTED(); // unknown mtype + } +} + +EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) { + MOD_RNX* res = (MOD_RNX*)malloc(sizeof(MOD_RNX)); + init_rnx_module_info(res, nn, mtype); + return res; +} + +EXPORT void delete_rnx_module_info(MOD_RNX* module_info) { + finalize_rnx_module_info(module_info); + free(module_info); +} + +EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) { return module->n; } + +/** @brief allocates a prepared matrix (release with delete_rnx_vmp_pmat) */ +EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return (RNX_VMP_PMAT*)spqlios_alloc(bytes_of_rnx_vmp_pmat(module, nrows, ncols)); +} +EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) { spqlios_free(ptr); } + +//////////////// wrappers ////////////////// + +/** @brief sets res = a + b */ +EXPORT void vec_rnx_add( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = 0 */ +EXPORT void vec_rnx_zero( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl // res +) { + module->vtable.vec_rnx_zero(module, res, res_size, res_sl); +} + +/** @brief sets res = a */ +EXPORT void vec_rnx_copy( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = -a */ +EXPORT void vec_rnx_negate( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = a - b */ +EXPORT void vec_rnx_sub( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl, // a + const double* b, uint64_t b_size, uint64_t b_sl // b +) { + module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl); +} + +/** @brief sets res = a . 
X^p */ +EXPORT void vec_rnx_rotate( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl); +} + +/** @brief sets res = a(X^p) */ +EXPORT void vec_rnx_automorphism( // + const MOD_RNX* module, // N + int64_t p, // X -> X^p + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_mul_xp_minus_one( // + const MOD_RNX* module, // N + const int64_t p, // rotation value + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl); +} +/** @brief number of bytes in a RNX_VMP_PMAT (for manual allocation) */ +EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module, // N + uint64_t nrows, uint64_t ncols) { // dimensions + return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols); +} + +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void rnx_vmp_prepare_contiguous( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space); +} + +/** @brief number of scratch bytes necessary to prepare a matrix */ +EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module) { + return module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes(module); +} + +/** @brief applies a vmp product res = a x pmat */ +EXPORT void rnx_vmp_apply_tmp_a( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + double* tmpa, uint64_t a_size, uint64_t a_sl, // a (will be overwritten) + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space); +} + +EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res size + uint64_t a_size, // a size + uint64_t nrows, uint64_t ncols // prep matrix dims +) { + return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols); +} + +/** @brief minimal size of the tmp_space */ +EXPORT void rnx_vmp_apply_dft_to_dft( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a_dft, uint64_t a_size, uint64_t a_sl, // a + const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols, // prep matrix + uint8_t* tmp_space // scratch space (a_size*sizeof(reim4) bytes) +) { + module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols, + tmp_space); +} + +/** @brief minimal size of the tmp_space */ +EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes( // + const MOD_RNX* module, // N + uint64_t res_size, // res + uint64_t a_size, // a + uint64_t nrows, uint64_t ncols // prep matrix +) { + return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols); +} + +EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return 
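Taken together, the vmp wrappers above follow a prepare-once / apply-many pattern. A sketch of the preparation step with illustrative names (the matrix is given as nrows*ncols contiguous row-major polynomials; the 64-byte alignment of the scratch buffer is an assumption carried over from the test allocator):

    #include <stdint.h>
    #include <stdlib.h>
    #include "spqlios/arithmetic/vec_rnx_arithmetic.h"

    /* Prepares a vmp matrix; the result is consumed by rnx_vmp_apply_tmp_a or
     * rnx_vmp_apply_dft_to_dft and released with delete_rnx_vmp_pmat. */
    RNX_VMP_PMAT* prepare_matrix(const MOD_RNX* mod, const double* mat_rows,
                                 uint64_t nrows, uint64_t ncols) {
      RNX_VMP_PMAT* pmat = new_rnx_vmp_pmat(mod, nrows, ncols);
      uint64_t tmp_bytes = rnx_vmp_prepare_contiguous_tmp_bytes(mod);
      uint8_t* tmp = (uint8_t*)aligned_alloc(64, (tmp_bytes + 63) & ~(uint64_t)63);
      rnx_vmp_prepare_contiguous(mod, pmat, mat_rows, nrows, ncols, tmp);
      free(tmp);
      return pmat;
    }
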
module->vtable.bytes_of_rnx_svp_ppol(module); } + +EXPORT void rnx_svp_prepare(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +) { + module->vtable.rnx_svp_prepare(module, ppol, pol); +} + +EXPORT void rnx_svp_apply( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.rnx_svp_apply(module, // N + res, res_size, res_sl, // output + ppol, // prepared pol + a, a_size, a_sl); +} + +EXPORT void rnx_approxdecomp_from_tnxdbl( // + const MOD_RNX* module, // N + const TNXDBL_APPROXDECOMP_GADGET* gadget, // output base 2^K + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a) { // a + module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a); +} + +EXPORT void vec_rnx_to_znx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_from_znx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_to_tnx32( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_from_tnx32( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl); +} + +EXPORT void vec_rnx_to_tnxdbl( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl); +} diff --git a/spqlios/arithmetic/vec_rnx_conversions_ref.c b/spqlios/arithmetic/vec_rnx_conversions_ref.c new file mode 100644 index 0000000..2a1b296 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_conversions_ref.c @@ -0,0 +1,91 @@ +#include + +#include "vec_rnx_arithmetic_private.h" +#include "zn_arithmetic_private.h" + +EXPORT void vec_rnx_to_znx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + dbl_round_to_i32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(int32_t)); + } +} + +EXPORT void vec_rnx_from_znx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? 
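The scalar-vector product (svp) wrappers above use the same prepare/apply split: one polynomial is preprocessed into an RNX_SVP_PPOL, then multiplied against whole vectors. A sketch with illustrative names, assuming new_rnx_svp_ppol and delete_rnx_svp_ppol are exported by the public header alongside the wrappers:

    #include <stdint.h>
    #include "spqlios/arithmetic/vec_rnx_arithmetic.h"

    void svp_multiply(const MOD_RNX* mod,
                      const double* single_poly,                       /* n coefficients */
                      const double* a, uint64_t a_size, uint64_t a_sl, /* vector operand */
                      double* res, uint64_t res_size, uint64_t res_sl) {
      RNX_SVP_PPOL* ppol = new_rnx_svp_ppol(mod);  /* bytes_of_rnx_svp_ppol(mod) bytes */
      rnx_svp_prepare(mod, ppol, single_poly);     /* scaled and FFT'd once */
      rnx_svp_apply(mod, res, res_size, res_sl, ppol, a, a_size, a_sl);
      delete_rnx_svp_ppol(ppol);
    }
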
res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + i32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(int32_t)); + } +} +EXPORT void vec_rnx_to_tnx32_ref( // + const MOD_RNX* module, // N + int32_t* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + dbl_to_tn32_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(int32_t)); + } +} +EXPORT void vec_rnx_from_tnx32_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const int32_t* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + tn32_to_dbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(int32_t)); + } +} + +static void dbl_to_tndbl_ref( // + const void* UNUSED, // N + double* res, uint64_t res_size, // res + const double* a, uint64_t a_size // a +) { + static const double OFF_CST = INT64_C(3) << 51; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + double ai = a[i] + OFF_CST; + res[i] = a[i] - (ai - OFF_CST); + } + memset(res + msize, 0, (res_size - msize) * sizeof(double)); +} + +EXPORT void vec_rnx_to_tnxdbl_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // res + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + const uint64_t msize = res_size < a_size ? res_size : a_size; + for (uint64_t i = 0; i < msize; ++i) { + dbl_to_tndbl_ref(NULL, res + i * res_sl, nn, a + i * a_sl, nn); + } + for (uint64_t i = msize; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(int32_t)); + } +} diff --git a/spqlios/arithmetic/vec_rnx_svp_ref.c b/spqlios/arithmetic/vec_rnx_svp_ref.c new file mode 100644 index 0000000..f811148 --- /dev/null +++ b/spqlios/arithmetic/vec_rnx_svp_ref.c @@ -0,0 +1,47 @@ +#include + +#include "../coeffs/coeffs_arithmetic.h" +#include "vec_rnx_arithmetic_private.h" + +EXPORT uint64_t fft64_bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->n * sizeof(double); } + +EXPORT RNX_SVP_PPOL* new_rnx_svp_ppol(const MOD_RNX* module) { return spqlios_alloc(bytes_of_rnx_svp_ppol(module)); } + +EXPORT void delete_rnx_svp_ppol(RNX_SVP_PPOL* ppol) { spqlios_free(ppol); } + +/** @brief prepares a svp polynomial */ +EXPORT void fft64_rnx_svp_prepare_ref(const MOD_RNX* module, // N + RNX_SVP_PPOL* ppol, // output + const double* pol // a +) { + double* const dppol = (double*)ppol; + rnx_divide_by_m_ref(module->n, module->m, dppol, pol); + reim_fft(module->precomp.fft64.p_fft, dppol); +} + +EXPORT void fft64_rnx_svp_apply_ref( // + const MOD_RNX* module, // N + double* res, uint64_t res_size, uint64_t res_sl, // output + const RNX_SVP_PPOL* ppol, // prepared pol + const double* a, uint64_t a_size, uint64_t a_sl // a +) { + const uint64_t nn = module->n; + double* const dppol = (double*)ppol; + + const uint64_t auto_end_idx = res_size < a_size ? 
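The OFF_CST constant in dbl_to_tndbl_ref above is the classic double-precision rounding trick: adding 1.5*2^52 lands the sum in a binade whose unit in the last place is exactly 1, so the FPU's round-to-nearest computes rint() for free; subtracting the constant back recovers the rounded integer, and res keeps the centered fractional part a - rint(a). A self-contained illustration (hypothetical helper, assumes the default round-to-nearest mode and |a| < 2^51):

    #include <assert.h>
    #include <math.h>
    #include <stdint.h>

    /* same trick as dbl_to_tndbl_ref, applied to a single value */
    static double frac_mod_1(double a) {
      static const double OFF_CST = (double)(INT64_C(3) << 51); /* 1.5 * 2^52 */
      double rounded = (a + OFF_CST) - OFF_CST;                 /* == rint(a) here */
      return a - rounded;                                       /* in [-0.5, 0.5] */
    }

    int main(void) {
      assert(frac_mod_1(3.25) == 0.25);
      assert(frac_mod_1(-2.75) == 0.25);   /* -2.75 - (-3) */
      assert(fabs(frac_mod_1(7.5)) == 0.5);
      return 0;
    }
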
res_size : a_size; + for (uint64_t i = 0; i < auto_end_idx; ++i) { + const double* a_ptr = a + i * a_sl; + double* const res_ptr = res + i * res_sl; + // copy the polynomial to res, apply fft in place, call fftvec + // _mul, apply ifft in place. + memcpy(res_ptr, a_ptr, nn * sizeof(double)); + reim_fft(module->precomp.fft64.p_fft, (double*)res_ptr); + reim_fftvec_mul(module->precomp.fft64.p_fftvec_mul, res_ptr, res_ptr, dppol); + reim_ifft(module->precomp.fft64.p_ifft, res_ptr); + } + + // then extend with zeros + for (uint64_t i = auto_end_idx; i < res_size; ++i) { + memset(res + i * res_sl, 0, nn * sizeof(double)); + } +} From 8b07fe908ad359f6338478a039dd3358aa9aa458 Mon Sep 17 00:00:00 2001 From: georgiev Date: Fri, 9 Aug 2024 11:06:21 +0200 Subject: [PATCH 06/11] Add zn tests for approxdecomps, conversions, vmp and layouts --- test/spqlios_zn_approxdecomp_test.cpp | 46 ++++++++++++ test/spqlios_zn_conversions_test.cpp | 104 ++++++++++++++++++++++++++ test/spqlios_zn_vmp_test.cpp | 67 +++++++++++++++++ test/testlib/zn_layouts.cpp | 55 ++++++++++++++ test/testlib/zn_layouts.h | 29 +++++++ 5 files changed, 301 insertions(+) create mode 100644 test/spqlios_zn_approxdecomp_test.cpp create mode 100644 test/spqlios_zn_conversions_test.cpp create mode 100644 test/spqlios_zn_vmp_test.cpp create mode 100644 test/testlib/zn_layouts.cpp create mode 100644 test/testlib/zn_layouts.h diff --git a/test/spqlios_zn_approxdecomp_test.cpp b/test/spqlios_zn_approxdecomp_test.cpp new file mode 100644 index 0000000..d21f420 --- /dev/null +++ b/test/spqlios_zn_approxdecomp_test.cpp @@ -0,0 +1,46 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/zn_arithmetic_private.h" +#include "testlib/test_commons.h" + +template +static void test_tndbl_approxdecomp( // + void (*approxdec)(const MOD_Z*, const TNDBL_APPROXDECOMP_GADGET*, INTTYPE*, uint64_t, const double*, uint64_t) // +) { + for (const uint64_t nn : {1, 3, 8, 51}) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (const uint64_t ell : {1, 2, 7}) { + for (const uint64_t k : {2, 5}) { + TNDBL_APPROXDECOMP_GADGET* gadget = new_tndbl_approxdecomp_gadget(module, k, ell); + for (const uint64_t res_size : {ell * nn}) { + std::vector in(nn); + std::vector out(res_size); + for (double& x : in) x = uniform_f64_bounds(-10, 10); + approxdec(module, gadget, out.data(), res_size, in.data(), nn); + // reconstruct the output + double err_bnd = pow(2., -double(ell * k) - 1); + for (uint64_t j = 0; j < nn; ++j) { + double in_j = in[j]; + double out_j = 0; + for (uint64_t i = 0; i < ell; ++i) { + out_j += out[ell * j + i] * pow(2., -double((i + 1) * k)); + } + double err = out_j - in_j; + double err_abs = fabs(err - rint(err)); + ASSERT_LE(err_abs, err_bnd); + } + } + delete_tndbl_approxdecomp_gadget(gadget); + } + } + delete_z_module_info(module); + } +} + +TEST(vec_rnx, i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i8_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i8_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i8_approxdecomp_from_tndbl_ref); } + +TEST(vec_rnx, i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i16_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i16_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i16_approxdecomp_from_tndbl_ref); } + +TEST(vec_rnx, i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(i32_approxdecomp_from_tndbl); } +TEST(vec_rnx, default_i32_tndbl_rnx_approxdecomp) { test_tndbl_approxdecomp(default_i32_approxdecomp_from_tndbl_ref); } diff --git a/test/spqlios_zn_conversions_test.cpp 
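The gadget test above pins down the contract of the approximate decomposition: for every input coefficient x, the ell signed digits d_1..d_ell written out satisfy |x - sum_i d_i * 2^(-i*k)| <= 2^(-ell*k - 1) modulo 1. A caller-side sketch of the i8 variant with illustrative names, assuming the gadget and decomposition functions are exported by the public zn_arithmetic.h header:

    #include <stdint.h>
    #include "spqlios/arithmetic/zn_arithmetic.h"

    /* digits must hold ell*n entries: digit i of coefficient j is stored at
     * digits[ell*j + i] and carries the weight 2^-((i+1)*k) */
    void decompose_i8(const double* x, uint64_t n, uint64_t k, uint64_t ell, int8_t* digits) {
      MOD_Z* mod = new_z_module_info(DEFAULT);
      TNDBL_APPROXDECOMP_GADGET* gadget = new_tndbl_approxdecomp_gadget(mod, k, ell);
      i8_approxdecomp_from_tndbl(mod, gadget, digits, ell * n, x, n);
      delete_tndbl_approxdecomp_gadget(gadget);
      delete_z_module_info(mod);
    }
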
b/test/spqlios_zn_conversions_test.cpp new file mode 100644 index 0000000..da2b94b --- /dev/null +++ b/test/spqlios_zn_conversions_test.cpp @@ -0,0 +1,104 @@ +#include +#include + +#include "testlib/test_commons.h" + +template +static void test_conv(void (*conv_f)(const MOD_Z*, DST_T* res, uint64_t res_size, const SRC_T* a, uint64_t a_size), + DST_T (*ideal_conv_f)(SRC_T x), SRC_T (*random_f)()) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t a_size : {0, 1, 2, 42}) { + for (uint64_t res_size : {0, 1, 2, 42}) { + for (uint64_t trials = 0; trials < 100; ++trials) { + std::vector a(a_size); + std::vector res(res_size); + uint64_t msize = std::min(a_size, res_size); + for (SRC_T& x : a) x = random_f(); + conv_f(module, res.data(), res_size, a.data(), a_size); + for (uint64_t i = 0; i < msize; ++i) { + DST_T expect = ideal_conv_f(a[i]); + DST_T actual = res[i]; + ASSERT_EQ(expect, actual); + } + for (uint64_t i = msize; i < res_size; ++i) { + DST_T expect = 0; + SRC_T actual = res[i]; + ASSERT_EQ(expect, actual); + } + } + } + } + delete_z_module_info(module); +} + +static int32_t ideal_dbl_to_tn32(double a) { + double _2p32 = INT64_C(1) << 32; + double a_mod_1 = a - rint(a); + int64_t t = rint(a_mod_1 * _2p32); + return int32_t(t); +} + +static double random_f64_10() { return uniform_f64_bounds(-10, 10); } + +static void test_dbl_to_tn32(DBL_TO_TN32_F dbl_to_tn32_f) { + test_conv(dbl_to_tn32_f, ideal_dbl_to_tn32, random_f64_10); +} + +TEST(zn_arithmetic, dbl_to_tn32) { test_dbl_to_tn32(dbl_to_tn32); } +TEST(zn_arithmetic, dbl_to_tn32_ref) { test_dbl_to_tn32(dbl_to_tn32_ref); } + +static double ideal_tn32_to_dbl(int32_t a) { + const double _2p32 = INT64_C(1) << 32; + return double(a) / _2p32; +} + +static int32_t random_t32() { return uniform_i64_bits(32); } + +static void test_tn32_to_dbl(TN32_TO_DBL_F tn32_to_dbl_f) { test_conv(tn32_to_dbl_f, ideal_tn32_to_dbl, random_t32); } + +TEST(zn_arithmetic, tn32_to_dbl) { test_tn32_to_dbl(tn32_to_dbl); } +TEST(zn_arithmetic, tn32_to_dbl_ref) { test_tn32_to_dbl(tn32_to_dbl_ref); } + +static int32_t ideal_dbl_round_to_i32(double a) { return int32_t(rint(a)); } + +static double random_dbl_explaw_18() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(6) % 19); } + +static void test_dbl_round_to_i32(DBL_ROUND_TO_I32_F dbl_round_to_i32_f) { + test_conv(dbl_round_to_i32_f, ideal_dbl_round_to_i32, random_dbl_explaw_18); +} + +TEST(zn_arithmetic, dbl_round_to_i32) { test_dbl_round_to_i32(dbl_round_to_i32); } +TEST(zn_arithmetic, dbl_round_to_i32_ref) { test_dbl_round_to_i32(dbl_round_to_i32_ref); } + +static double ideal_i32_to_dbl(int32_t a) { return double(a); } + +static int32_t random_i32_explaw_18() { return uniform_i64_bits(uniform_u64_bits(6) % 19); } + +static void test_i32_to_dbl(I32_TO_DBL_F i32_to_dbl_f) { + test_conv(i32_to_dbl_f, ideal_i32_to_dbl, random_i32_explaw_18); +} + +TEST(zn_arithmetic, i32_to_dbl) { test_i32_to_dbl(i32_to_dbl); } +TEST(zn_arithmetic, i32_to_dbl_ref) { test_i32_to_dbl(i32_to_dbl_ref); } + +static int64_t ideal_dbl_round_to_i64(double a) { return rint(a); } + +static double random_dbl_explaw_50() { return uniform_f64_bounds(-1., 1.) 
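ideal_dbl_to_tn32 above spells out the torus32 convention used by these conversions: an int32 value v represents the real v / 2^32 modulo 1, and a double is encoded by rounding its centered fractional part. A tiny hand-checked instance (hypothetical helper, same formula as the test oracle):

    #include <assert.h>
    #include <math.h>
    #include <stdint.h>

    static int32_t dbl_to_torus32(double a) {
      double frac = a - rint(a);                           /* centered fractional part */
      return (int32_t)(int64_t)rint(frac * 4294967296.0);  /* scale by 2^32 */
    }

    int main(void) {
      assert(dbl_to_torus32(0.25) == INT32_C(1) << 30);     /* 0.25 -> 2^30 */
      assert(dbl_to_torus32(0.75) == -(INT32_C(1) << 30));  /* 0.75 == -0.25 mod 1 */
      assert(dbl_to_torus32(1.0) == 0);                     /* integers map to 0 */
      return 0;
    }
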
* pow(2., uniform_u64_bits(7) % 51); } + +static void test_dbl_round_to_i64(DBL_ROUND_TO_I64_F dbl_round_to_i64_f) { + test_conv(dbl_round_to_i64_f, ideal_dbl_round_to_i64, random_dbl_explaw_50); +} + +TEST(zn_arithmetic, dbl_round_to_i64) { test_dbl_round_to_i64(dbl_round_to_i64); } +TEST(zn_arithmetic, dbl_round_to_i64_ref) { test_dbl_round_to_i64(dbl_round_to_i64_ref); } + +static double ideal_i64_to_dbl(int64_t a) { return double(a); } + +static int64_t random_i64_explaw_50() { return uniform_i64_bits(uniform_u64_bits(7) % 51); } + +static void test_i64_to_dbl(I64_TO_DBL_F i64_to_dbl_f) { + test_conv(i64_to_dbl_f, ideal_i64_to_dbl, random_i64_explaw_50); +} + +TEST(zn_arithmetic, i64_to_dbl) { test_i64_to_dbl(i64_to_dbl); } +TEST(zn_arithmetic, i64_to_dbl_ref) { test_i64_to_dbl(i64_to_dbl_ref); } diff --git a/test/spqlios_zn_vmp_test.cpp b/test/spqlios_zn_vmp_test.cpp new file mode 100644 index 0000000..8f6fa25 --- /dev/null +++ b/test/spqlios_zn_vmp_test.cpp @@ -0,0 +1,67 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/zn_arithmetic_private.h" +#include "testlib/zn_layouts.h" + +static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_CONTIGUOUS_F prep) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t nrows : {1, 2, 5, 15}) { + for (uint64_t ncols : {1, 2, 32, 42, 67}) { + std::vector src(nrows * ncols); + zn32_pmat_layout out(nrows, ncols); + for (int32_t& x : src) x = uniform_i64_bits(32); + prep(module, out.data, src.data(), nrows, ncols); + for (uint64_t i = 0; i < nrows; ++i) { + for (uint64_t j = 0; j < ncols; ++j) { + int32_t in = src[i * ncols + j]; + int32_t actual = out.get(i, j); + ASSERT_EQ(actual, in); + } + } + } + } + delete_z_module_info(module); +} + +TEST(zn, zn32_vmp_prepare_contiguous) { test_zn_vmp_prepare(zn32_vmp_prepare_contiguous); } +TEST(zn, default_zn32_vmp_prepare_contiguous_ref) { test_zn_vmp_prepare(default_zn32_vmp_prepare_contiguous_ref); } + +template +static void test_zn_vmp_apply(void (*apply)(const MOD_Z*, int32_t*, uint64_t, const INTTYPE*, uint64_t, + const ZN32_VMP_PMAT*, uint64_t, uint64_t)) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t nrows : {1, 2, 5, 15}) { + for (uint64_t ncols : {1, 2, 32, 42, 67}) { + for (uint64_t a_size : {1, 2, 5, 15}) { + for (uint64_t res_size : {1, 2, 32, 42, 67}) { + std::vector a(a_size); + zn32_pmat_layout out(nrows, ncols); + std::vector res(res_size); + for (INTTYPE& x : a) x = uniform_i64_bits(32); + out.fill_random(); + std::vector expect = vmp_product(a.data(), a_size, res_size, out); + apply(module, res.data(), res_size, a.data(), a_size, out.data, nrows, ncols); + for (uint64_t i = 0; i < res_size; ++i) { + int32_t exp = expect[i]; + int32_t actual = res[i]; + ASSERT_EQ(actual, exp); + } + } + } + } + } + delete_z_module_info(module); +} + +TEST(zn, zn32_vmp_apply_i32) { test_zn_vmp_apply(zn32_vmp_apply_i32); } +TEST(zn, zn32_vmp_apply_i16) { test_zn_vmp_apply(zn32_vmp_apply_i16); } +TEST(zn, zn32_vmp_apply_i8) { test_zn_vmp_apply(zn32_vmp_apply_i8); } + +TEST(zn, default_zn32_vmp_apply_i32_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_ref); } +TEST(zn, default_zn32_vmp_apply_i16_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_ref); } +TEST(zn, default_zn32_vmp_apply_i8_ref) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_ref); } + +#ifdef __x86_64__ +TEST(zn, default_zn32_vmp_apply_i32_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i32_avx); } +TEST(zn, default_zn32_vmp_apply_i16_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i16_avx); } +TEST(zn, 
default_zn32_vmp_apply_i8_avx) { test_zn_vmp_apply(default_zn32_vmp_apply_i8_avx); } +#endif diff --git a/test/testlib/zn_layouts.cpp b/test/testlib/zn_layouts.cpp new file mode 100644 index 0000000..185c9b6 --- /dev/null +++ b/test/testlib/zn_layouts.cpp @@ -0,0 +1,55 @@ +#include "zn_layouts.h" + +zn32_pmat_layout::zn32_pmat_layout(uint64_t nrows, uint64_t ncols) + : nrows(nrows), // + ncols(ncols), // + data((ZN32_VMP_PMAT*)malloc((nrows * ncols + 7) * sizeof(int32_t))) {} + +zn32_pmat_layout::~zn32_pmat_layout() { free(data); } + +int32_t* zn32_pmat_layout::get_addr(uint64_t row, uint64_t col) const { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow" << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "col overflow" << col << " / " << ncols); + const uint64_t nblk = ncols >> 5; + const uint64_t rem_ncols = ncols & 31; + uint64_t blk = col >> 5; + uint64_t col_rem = col & 31; + if (blk < nblk) { + // column is part of a full block + return (int32_t*)data + blk * nrows * 32 + row * 32 + col_rem; + } else { + // column is part of the last block + return (int32_t*)data + blk * nrows * 32 + row * rem_ncols + col_rem; + } +} +int32_t zn32_pmat_layout::get(uint64_t row, uint64_t col) const { return *get_addr(row, col); } +int32_t zn32_pmat_layout::get_zext(uint64_t row, uint64_t col) const { + if (row >= nrows || col >= ncols) return 0; + return *get_addr(row, col); +} +void zn32_pmat_layout::set(uint64_t row, uint64_t col, int32_t value) { *get_addr(row, col) = value; } +void zn32_pmat_layout::fill_random() { + int32_t* d = (int32_t*)data; + for (uint64_t i = 0; i < nrows * ncols; ++i) d[i] = uniform_i64_bits(32); +} +thash zn32_pmat_layout::content_hash() const { return test_hash(data, nrows * ncols * sizeof(int32_t)); } + +template +std::vector vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat) { + uint64_t rows = std::min(vec_size, mat.nrows); + uint64_t cols = std::min(out_size, mat.ncols); + std::vector res(out_size, 0); + for (uint64_t j = 0; j < cols; ++j) { + for (uint64_t i = 0; i < rows; ++i) { + res[j] += vec[i] * mat.get(i, j); + } + } + return res; +} + +template std::vector vmp_product(const int8_t* vec, uint64_t vec_size, uint64_t out_size, + const zn32_pmat_layout& mat); +template std::vector vmp_product(const int16_t* vec, uint64_t vec_size, uint64_t out_size, + const zn32_pmat_layout& mat); +template std::vector vmp_product(const int32_t* vec, uint64_t vec_size, uint64_t out_size, + const zn32_pmat_layout& mat); diff --git a/test/testlib/zn_layouts.h b/test/testlib/zn_layouts.h new file mode 100644 index 0000000..6a89173 --- /dev/null +++ b/test/testlib/zn_layouts.h @@ -0,0 +1,29 @@ +#ifndef SPQLIOS_EXT_ZN_LAYOUTS_H +#define SPQLIOS_EXT_ZN_LAYOUTS_H + +#include "spqlios/arithmetic/zn_arithmetic.h" +#include "testlib/test_commons.h" + +class zn32_pmat_layout { + public: + const uint64_t nrows; + const uint64_t ncols; + ZN32_VMP_PMAT* const data; + zn32_pmat_layout(uint64_t nrows, uint64_t ncols); + + private: + int32_t* get_addr(uint64_t row, uint64_t col) const; + + public: + int32_t get(uint64_t row, uint64_t col) const; + int32_t get_zext(uint64_t row, uint64_t col) const; + void set(uint64_t row, uint64_t col, int32_t value); + void fill_random(); + thash content_hash() const; + ~zn32_pmat_layout(); +}; + +template +std::vector vmp_product(const T* vec, uint64_t vec_size, uint64_t out_size, const zn32_pmat_layout& mat); + +#endif // SPQLIOS_EXT_ZN_LAYOUTS_H From 4fe204dc66e0dac0cbd6f9d41e1502f1fdaa40fa Mon Sep 17 
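The zn32 vmp tests above exercise the prepare/apply pair exactly as a caller would, and zn32_pmat_layout documents the expected blocked storage (columns grouped by 32, row-major inside each block, a packed final partial block). A sketch of the i32 flow with illustrative names; the prepared-matrix allocation below simply mirrors the (nrows*ncols + 7) int32 words used by the test harness, and an official size helper from the library, if one exists, should be preferred:

    #include <stdint.h>
    #include <stdlib.h>
    #include "spqlios/arithmetic/zn_arithmetic.h"

    void zn32_product(const int32_t* mat, uint64_t nrows, uint64_t ncols, /* row-major */
                      const int32_t* a, uint64_t a_size,
                      int32_t* res, uint64_t res_size) {
      MOD_Z* mod = new_z_module_info(DEFAULT);
      ZN32_VMP_PMAT* pmat = (ZN32_VMP_PMAT*)malloc((nrows * ncols + 7) * sizeof(int32_t));
      zn32_vmp_prepare_contiguous(mod, pmat, mat, nrows, ncols);
      /* res[j] = sum over i < min(a_size, nrows) of a[i]*mat[i][j] in int32 arithmetic,
       * for j < min(res_size, ncols); remaining outputs are zeroed */
      zn32_vmp_apply_i32(mod, res, res_size, a, a_size, pmat, nrows, ncols);
      free(pmat);
      delete_z_module_info(mod);
    }
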
00:00:00 2001 From: Maurice Shih Date: Fri, 9 Aug 2024 15:52:21 -0700 Subject: [PATCH 07/11] added some vec_rnx test files --- ...qlios_vec_rnx_approxdecomp_tnxdbl_test.cpp | 42 ++ test/spqlios_vec_rnx_test.cpp | 417 ++++++++++++++++++ test/testlib/vec_rnx_layout.cpp | 182 ++++++++ test/testlib/vec_rnx_layout.h | 85 ++++ 4 files changed, 726 insertions(+) create mode 100644 test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp create mode 100644 test/spqlios_vec_rnx_test.cpp create mode 100644 test/testlib/vec_rnx_layout.cpp create mode 100644 test/testlib/vec_rnx_layout.h diff --git a/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp b/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp new file mode 100644 index 0000000..5b36ed0 --- /dev/null +++ b/test/spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp @@ -0,0 +1,42 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "testlib/vec_rnx_layout.h" + +static void test_rnx_approxdecomp(RNX_APPROXDECOMP_FROM_TNXDBL_F approxdec) { + for (const uint64_t nn : {2, 4, 8, 32}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (const uint64_t ell : {1, 2, 7}) { + for (const uint64_t k : {2, 5}) { + TNXDBL_APPROXDECOMP_GADGET* gadget = new_tnxdbl_approxdecomp_gadget(module, k, ell); + for (const uint64_t res_size : {ell, ell - 1, ell + 1}) { + const uint64_t res_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, 1, nn); + in.fill_random(3); + rnx_vec_f64_layout out(nn, res_size, res_sl); + approxdec(module, gadget, out.data(), res_size, res_sl, in.data()); + // reconstruct the output + uint64_t msize = std::min(res_size, ell); + double err_bnd = msize == ell ? pow(2., -double(msize * k) - 1) : pow(2., -double(msize * k)); + for (uint64_t j = 0; j < nn; ++j) { + double in_j = in.data()[j]; + double out_j = 0; + for (uint64_t i = 0; i < res_size; ++i) { + out_j += out.get_copy(i).get_coeff(j) * pow(2., -double((i + 1) * k)); + } + double err = out_j - in_j; + double err_abs = fabs(err - rint(err)); + ASSERT_LE(err_abs, err_bnd); + } + } + delete_tnxdbl_approxdecomp_gadget(gadget); + } + } + delete_rnx_module_info(module); + } +} + +TEST(vec_rnx, rnx_approxdecomp) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl); } +TEST(vec_rnx, rnx_approxdecomp_ref) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_ref); } +#ifdef __x86_64__ +TEST(vec_rnx, rnx_approxdecomp_avx) { test_rnx_approxdecomp(rnx_approxdecomp_from_tnxdbl_avx); } +#endif diff --git a/test/spqlios_vec_rnx_test.cpp b/test/spqlios_vec_rnx_test.cpp new file mode 100644 index 0000000..0aa8723 --- /dev/null +++ b/test/spqlios_vec_rnx_test.cpp @@ -0,0 +1,417 @@ +#include + +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "spqlios/reim/reim_fft.h" +#include "test/testlib/vec_rnx_layout.h" + +// disabling this test by default, since it depicts on purpose wrong accesses +#if 0 +TEST(rnx_layout, valgrind_antipattern_test) { + uint64_t n = 4; + rnx_vec_f64_layout v(n, 7, 13); + // this should be ok + v.set(0, rnx_f64::zero(n)); + // this should abort (wrong ring dimension) + ASSERT_DEATH(v.set(3, rnx_f64::zero(2 * n)), ""); + // this should abort (out of bounds) + ASSERT_DEATH(v.set(8, rnx_f64::zero(n)), ""); + // this should be ok + ASSERT_EQ(v.get_copy_zext(0), rnx_f64::zero(n)); + // should be an uninit read + ASSERT_TRUE(!(v.get_copy_zext(2) == rnx_f64::zero(n))); // should be uninit + // should be an invalid read (inter-slice) + ASSERT_NE(v.data()[4], 0); + ASSERT_EQ(v.data()[2], 0); // should be ok + // should be an 
uninit read + ASSERT_NE(v.data()[13], 0); // should be uninit +} +#endif + +// test of binary operations + +// test for out of place calls +template +void test_vec_znx_elemw_binop_outplace(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 8, 128}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + for (uint64_t sc : {7, 13, 15}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + uint64_t c_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + rnx_vec_f64_layout lc(n, sc, c_sl); + std::vector expect(sc); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sc; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lc.data(), sc, c_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sc; ++i) { + ASSERT_EQ(lc.get_copy_zext(i), expect[i]); + } + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace1 calls +template +void test_vec_znx_elemw_binop_inplace1(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 4, 64}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace2 calls +template +void test_vec_znx_elemw_binop_inplace2(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {4, 32, 64}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for (uint64_t sa : {3, 9, 12}) { + for (uint64_t sb : {3, 9, 12}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + lb.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), lb.get_copy_zext(i)); + } + binop(mod, // N + lb.data(), sb, b_sl, // res + la.data(), sa, a_sl, // a + lb.data(), sb, b_sl); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]); + } + } + } + delete_rnx_module_info(mod); + } +} +// test for inplace3 calls +template +void test_vec_znx_elemw_binop_inplace3(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + for (uint64_t n : {2, 16, 1024}) { + RNX_MODULE_TYPE mtype = FFT64; + MOD_RNX* mod = new_rnx_module_info(n, mtype); + for 
(uint64_t sa : {2, 6, 11}) { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 1.)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_binop(la.get_copy_zext(i), la.get_copy_zext(i)); + } + binop(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl, // a + la.data(), sa, a_sl); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]); + } + } + delete_rnx_module_info(mod); + } +} +template +void test_vec_znx_elemw_binop(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_znx_elemw_binop_outplace(binop, ref_binop); + test_vec_znx_elemw_binop_inplace1(binop, ref_binop); + test_vec_znx_elemw_binop_inplace2(binop, ref_binop); + test_vec_znx_elemw_binop_inplace3(binop, ref_binop); +} + +static rnx_f64 poly_add(const rnx_f64& a, const rnx_f64& b) { return a + b; } +static rnx_f64 poly_sub(const rnx_f64& a, const rnx_f64& b) { return a - b; } +TEST(vec_znx, vec_znx_add) { test_vec_znx_elemw_binop(vec_rnx_add, poly_add); } +TEST(vec_znx, vec_znx_add_ref) { test_vec_znx_elemw_binop(vec_rnx_add_ref, poly_add); } +#ifdef __x86_64__ +TEST(vec_znx, vec_znx_add_avx) { test_vec_znx_elemw_binop(vec_rnx_add_avx, poly_add); } +#endif +TEST(vec_znx, vec_znx_sub) { test_vec_znx_elemw_binop(vec_rnx_sub, poly_sub); } +TEST(vec_znx, vec_znx_sub_ref) { test_vec_znx_elemw_binop(vec_rnx_sub_ref, poly_sub); } +#ifdef __x86_64__ +TEST(vec_znx, vec_znx_sub_avx) { test_vec_znx_elemw_binop(vec_rnx_sub_avx, poly_sub); } +#endif + +// test for out of place calls +template +void test_vec_rnx_elemw_unop_param_outplace(ACTUAL_FCN test_mul_xp_minus_one, EXPECT_FCN ref_mul_xp_minus_one, + int64_t (*param_gen)()) { + for (uint64_t n : {2, 4, 8, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_mul_xp_minus_one(p, la.get_copy_zext(i)); + } + test_mul_xp_minus_one(mod, // + p, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_rnx_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_rnx_elemw_unop_param_inplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function, + int64_t (*param_gen)()) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + int64_t p = param_gen(); + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_function(p, la.get_copy_zext(i)); + } + actual_function(mod, // N + p, //; + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} + +static int64_t 
random_mul_xp_minus_one_param() { return uniform_i64(); } +static int64_t random_automorphism_param() { return 2 * uniform_i64() + 1; } +static int64_t random_rotation_param() { return uniform_i64(); } + +template +void test_vec_rnx_elemw_mul_xp_minus_one(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_mul_xp_minus_one_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_mul_xp_minus_one_param); +} +template +void test_vec_rnx_elemw_rotate(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_rotation_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_rotation_param); +} +template +void test_vec_rnx_elemw_automorphism(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_unop_param_outplace(binop, ref_binop, random_automorphism_param); + test_vec_rnx_elemw_unop_param_inplace(binop, ref_binop, random_automorphism_param); +} + +static rnx_f64 poly_mul_xp_minus_one(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i - p) - a.get_coeff(i)); + } + return res; +} +static rnx_f64 poly_rotate(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i, a.get_coeff(i - p)); + } + return res; +} +static rnx_f64 poly_automorphism(const int64_t p, const rnx_f64& a) { + uint64_t n = a.nn(); + rnx_f64 res(n); + for (uint64_t i = 0; i < n; ++i) { + res.set_coeff(i * p, a.get_coeff(i)); + } + return res; +} + +TEST(vec_rnx, vec_rnx_mul_xp_minus_one) { + test_vec_rnx_elemw_mul_xp_minus_one(vec_rnx_mul_xp_minus_one, poly_mul_xp_minus_one); +} +TEST(vec_rnx, vec_rnx_mul_xp_minus_one_ref) { + test_vec_rnx_elemw_mul_xp_minus_one(vec_rnx_mul_xp_minus_one_ref, poly_mul_xp_minus_one); +} + +TEST(vec_rnx, vec_rnx_rotate) { test_vec_rnx_elemw_rotate(vec_rnx_rotate, poly_rotate); } +TEST(vec_rnx, vec_rnx_rotate_ref) { test_vec_rnx_elemw_rotate(vec_rnx_rotate_ref, poly_rotate); } +TEST(vec_rnx, vec_rnx_automorphism) { test_vec_rnx_elemw_automorphism(vec_rnx_automorphism, poly_automorphism); } +TEST(vec_rnx, vec_rnx_automorphism_ref) { + test_vec_rnx_elemw_automorphism(vec_rnx_automorphism_ref, poly_automorphism); +} + +// test for out of place calls +template +void test_vec_rnx_elemw_unop_outplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function) { + for (uint64_t n : {2, 4, 8, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {7, 13, 15}) { + for (uint64_t sb : {7, 13, 15}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + uint64_t b_sl = uniform_u64_bits(3) * 4 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + rnx_vec_f64_layout lb(n, sb, b_sl); + std::vector expect(sb); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sb; ++i) { + expect[i] = ref_function(la.get_copy_zext(i)); + } + actual_function(mod, // + lb.data(), sb, b_sl, // + la.data(), sa, a_sl // + ); + for (uint64_t i = 0; i < sb; ++i) { + ASSERT_EQ(lb.get_copy_zext(i), expect[i]) << n << " " << sa << " " << sb << " " << i; + } + } + } + } + delete_rnx_module_info(mod); + } +} + +// test for inplace calls +template +void test_vec_rnx_elemw_unop_inplace(ACTUAL_FCN actual_function, EXPECT_FCN ref_function) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + 
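poly_rotate and poly_automorphism above are the plain negacyclic reference models: in R[X]/(X^N + 1) we have X^N = -1, so any coefficient index leaving [0, N) wraps around with a sign flip (the testlib's rnx_f64::get_coeff is assumed to implement exactly that for out-of-range indices). A hand-checked instance for N = 2, rotating a = 1 + 2X by X:

    #include <assert.h>
    #include <stdint.h>

    /* X * (1 + 2X) = X + 2X^2 = -2 + X, because X^2 = -1 in R[X]/(X^2 + 1) */
    int main(void) {
      double a[2] = {1.0, 2.0};  /* 1 + 2X */
      double res[2];
      int64_t p = 1, n = 2;
      for (int64_t i = 0; i < n; ++i) {
        int64_t j = i - p;       /* source coefficient index */
        double sign = 1.0;
        while (j < 0) { j += n; sign = -sign; }  /* X^n = -1 wraparound */
        res[i] = sign * a[j];
      }
      assert(res[0] == -2.0 && res[1] == 1.0);
      return 0;
    }
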
uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + std::vector expect(sa); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + for (uint64_t i = 0; i < sa; ++i) { + expect[i] = ref_function(la.get_copy_zext(i)); + } + actual_function(mod, // N + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} +template +void test_vec_rnx_elemw_unop(ACTUAL_FCN unnop, EXPECT_FCN ref_unnop) { + test_vec_rnx_elemw_unop_outplace(unnop, ref_unnop); + test_vec_rnx_elemw_unop_inplace(unnop, ref_unnop); +} + +static rnx_f64 poly_copy(const rnx_f64& a) { return a; } +static rnx_f64 poly_negate(const rnx_f64& a) { return -a; } + +TEST(vec_rnx, vec_rnx_copy) { test_vec_rnx_elemw_unop(vec_rnx_copy, poly_copy); } +TEST(vec_rnx, vec_rnx_copy_ref) { test_vec_rnx_elemw_unop(vec_rnx_copy_ref, poly_copy); } +TEST(vec_rnx, vec_rnx_negate) { test_vec_rnx_elemw_unop(vec_rnx_negate, poly_negate); } +TEST(vec_rnx, vec_rnx_negate_ref) { test_vec_rnx_elemw_unop(vec_rnx_negate_ref, poly_negate); } +#ifdef __x86_64__ +TEST(vec_rnx, vec_rnx_negate_avx) { test_vec_rnx_elemw_unop(vec_rnx_negate_avx, poly_negate); } +#endif + +// test for inplace calls +void test_vec_rnx_zero(VEC_RNX_ZERO_F actual_function) { + for (uint64_t n : {2, 16, 1024}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + for (uint64_t sa : {2, 6, 11}) { + { + uint64_t a_sl = uniform_u64_bits(3) * 5 + n; + rnx_vec_f64_layout la(n, sa, a_sl); + const rnx_f64 ZERO = rnx_f64::zero(n); + for (uint64_t i = 0; i < sa; ++i) { + la.set(i, rnx_f64::random_log2bound(n, 0)); + } + actual_function(mod, // N + la.data(), sa, a_sl // res + ); + for (uint64_t i = 0; i < sa; ++i) { + ASSERT_EQ(la.get_copy_zext(i), ZERO) << n << " " << sa << " " << i; + } + } + } + delete_rnx_module_info(mod); + } +} + +TEST(vec_rnx, vec_rnx_zero) { test_vec_rnx_zero(vec_rnx_zero); } + +TEST(vec_rnx, vec_rnx_zero_ref) { test_vec_rnx_zero(vec_rnx_zero_ref); } diff --git a/test/testlib/vec_rnx_layout.cpp b/test/testlib/vec_rnx_layout.cpp new file mode 100644 index 0000000..5fe9121 --- /dev/null +++ b/test/testlib/vec_rnx_layout.cpp @@ -0,0 +1,182 @@ +#include "vec_rnx_layout.h" + +#include + +#include "spqlios/arithmetic/vec_rnx_arithmetic.h" + +#ifdef VALGRIND_MEM_TESTS +#include "valgrind/memcheck.h" +#endif + +#define CANARY_PADDING (1024) +#define GARBAGE_VALUE (242) + +rnx_vec_f64_layout::rnx_vec_f64_layout(uint64_t n, uint64_t size, uint64_t slice) : n(n), size(size), slice(slice) { + REQUIRE_DRAMATICALLY(is_pow2(n), "not a power of 2" << n); + REQUIRE_DRAMATICALLY(slice >= n, "slice too small" << slice << " < " << n); + this->region = (uint8_t*)malloc(size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + this->data_start = (double*)(region + CANARY_PADDING); + // ensure that any invalid value is kind-of garbage + memset(region, GARBAGE_VALUE, size * slice * sizeof(int64_t) + 2 * CANARY_PADDING); + // mark inter-slice memory as not accessible +#ifdef VALGRIND_MEM_TESTS + VALGRIND_MAKE_MEM_NOACCESS(region, CANARY_PADDING); + VALGRIND_MAKE_MEM_NOACCESS(region + size * slice * sizeof(int64_t) + CANARY_PADDING, CANARY_PADDING); + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_UNDEFINED(data_start + i * slice, n * sizeof(int64_t)); + } + if (size != slice) { + for (uint64_t i = 0; i < size; ++i) { + VALGRIND_MAKE_MEM_NOACCESS(data_start + i * 
slice + n, (slice - n) * sizeof(int64_t)); + } + } +#endif +} + +rnx_vec_f64_layout::~rnx_vec_f64_layout() { free(region); } + +rnx_f64 rnx_vec_f64_layout::get_copy_zext(uint64_t index) const { + if (index < size) { + return rnx_f64(n, data_start + index * slice); + } else { + return rnx_f64::zero(n); + } +} + +rnx_f64 rnx_vec_f64_layout::get_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return rnx_f64(n, data_start + index * slice); +} + +reim_fft64vec rnx_vec_f64_layout::get_dft_copy_zext(uint64_t index) const { + if (index < size) { + return reim_fft64vec(n, data_start + index * slice); + } else { + return reim_fft64vec::zero(n); + } +} + +reim_fft64vec rnx_vec_f64_layout::get_dft_copy(uint64_t index) const { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + return reim_fft64vec(n, data_start + index * slice); +} + +void rnx_vec_f64_layout::set(uint64_t index, const rnx_f64& elem) { + REQUIRE_DRAMATICALLY(index < size, "index overflow: " << index << " / " << size); + REQUIRE_DRAMATICALLY(elem.nn() == n, "incompatible ring dimensions: " << elem.nn() << " / " << n); + elem.save_as(data_start + index * slice); +} + +double* rnx_vec_f64_layout::data() { return data_start; } +const double* rnx_vec_f64_layout::data() const { return data_start; } + +void rnx_vec_f64_layout::fill_random(double log2bound) { + for (uint64_t i = 0; i < size; ++i) { + set(i, rnx_f64::random_log2bound(n, log2bound)); + } +} + +thash rnx_vec_f64_layout::content_hash() const { + test_hasher hasher; + for (uint64_t i = 0; i < size; ++i) { + hasher.update(data() + i * slice, n * sizeof(int64_t)); + } + return hasher.hash(); +} + +fft64_rnx_vmp_pmat_layout::fft64_rnx_vmp_pmat_layout(uint64_t n, uint64_t nrows, uint64_t ncols) + : nn(n), + nrows(nrows), + ncols(ncols), // + data((RNX_VMP_PMAT*)alloc64(nrows * ncols * nn * 8)) {} + +double* fft64_rnx_vmp_pmat_layout::get_addr(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "col overflow: " << col << " / " << ncols); + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + double* d = (double*)data; + if (col == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + return d + blk * nrows * ncols * 8 // major: blk + + col * nrows * 8 // col == ncols-1 + + row * 8; + } else { + // general case: columns go by pair + return d + blk * nrows * ncols * 8 // major: blk + + (col / 2) * (2 * nrows) * 8 // second: col pair index + + row * 2 * 8 // third: row index + + (col % 2) * 8; // minor: col in colpair + } +} + +reim4_elem fft64_rnx_vmp_pmat_layout::get(uint64_t row, uint64_t col, uint64_t blk) const { + return reim4_elem(get_addr(row, col, blk)); +} +reim4_elem fft64_rnx_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col, uint64_t blk) const { + REQUIRE_DRAMATICALLY(blk < nn / 8, "block overflow: " << blk << " / " << (nn / 8)); + if (row < nrows && col < ncols) { + return reim4_elem(get_addr(row, col, blk)); + } else { + return reim4_elem::zero(); + } +} +void fft64_rnx_vmp_pmat_layout::set(uint64_t row, uint64_t col, uint64_t blk, const reim4_elem& value) const { + value.save_as(get_addr(row, col, blk)); +} + +fft64_rnx_vmp_pmat_layout::~fft64_rnx_vmp_pmat_layout() { spqlios_free(data); } + +reim_fft64vec fft64_rnx_vmp_pmat_layout::get_zext(uint64_t row, uint64_t col) const { + if (row 
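get_addr above is the ground truth for the fft64 prepared-matrix layout: the nn/8 reim4 blocks form the major dimension, columns travel in pairs inside each block, and an odd trailing column is packed on its own. A worked offset, using the formula exactly as written (all quantities counted in doubles):

    #include <assert.h>
    #include <stdint.h>

    /* nn = 16 (2 blocks), nrows = 3, ncols = 4, element (row=2, col=3, blk=1):
     *   blk*nrows*ncols*8 + (col/2)*(2*nrows)*8 + row*2*8 + (col%2)*8
     * = 96 + 48 + 32 + 8 = 184 doubles from the start of the pmat */
    int main(void) {
      uint64_t nrows = 3, ncols = 4, row = 2, col = 3, blk = 1;
      uint64_t off = blk * nrows * ncols * 8 + (col / 2) * (2 * nrows) * 8
                   + row * 2 * 8 + (col % 2) * 8;
      assert(off == 184);
      return 0;
    }
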
>= nrows || col >= ncols) { + return reim_fft64vec::zero(nn); + } + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + col * nrows) * nn; + return reim_fft64vec(nn, addr); + } + // otherwise, reconstruct it block by block + reim_fft64vec res(nn); + for (uint64_t blk = 0; blk < nn / 8; ++blk) { + reim4_elem v = get(row, col, blk); + res.set_blk(blk, v); + } + return res; +} +void fft64_rnx_vmp_pmat_layout::set(uint64_t row, uint64_t col, const reim_fft64vec& value) { + REQUIRE_DRAMATICALLY(row < nrows, "row overflow: " << row << " / " << nrows); + REQUIRE_DRAMATICALLY(col < ncols, "row overflow: " << col << " / " << ncols); + if (nn < 8) { + // the pmat is just col major + double* addr = (double*)data + (row + col * nrows) * nn; + value.save_as(addr); + return; + } + // otherwise, reconstruct it block by block + for (uint64_t blk = 0; blk < nn / 8; ++blk) { + reim4_elem v = value.get_blk(blk); + set(row, col, blk, v); + } +} +void fft64_rnx_vmp_pmat_layout::fill_random(double log2bound) { + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + set(row, col, reim_fft64vec::random(nn, log2bound)); + } + } +} + +fft64_rnx_svp_ppol_layout::fft64_rnx_svp_ppol_layout(uint64_t n) + : nn(n), // + data((RNX_SVP_PPOL*)alloc64(nn * 8)) {} + +reim_fft64vec fft64_rnx_svp_ppol_layout::get_copy() const { return reim_fft64vec(nn, (double*)data); } + +void fft64_rnx_svp_ppol_layout::set(const reim_fft64vec& value) { value.save_as((double*)data); } + +void fft64_rnx_svp_ppol_layout::fill_dft_random(uint64_t log2bound) { set(reim_fft64vec::dft_random(nn, log2bound)); } + +void fft64_rnx_svp_ppol_layout::fill_random(double log2bound) { set(reim_fft64vec::random(nn, log2bound)); } + +fft64_rnx_svp_ppol_layout::~fft64_rnx_svp_ppol_layout() { spqlios_free(data); } +thash fft64_rnx_svp_ppol_layout::content_hash() const { return test_hash(data, nn * sizeof(double)); } \ No newline at end of file diff --git a/test/testlib/vec_rnx_layout.h b/test/testlib/vec_rnx_layout.h new file mode 100644 index 0000000..6b2415b --- /dev/null +++ b/test/testlib/vec_rnx_layout.h @@ -0,0 +1,85 @@ +#ifndef SPQLIOS_EXT_VEC_RNX_LAYOUT_H +#define SPQLIOS_EXT_VEC_RNX_LAYOUT_H + +#include "../../spqlios/arithmetic/vec_rnx_arithmetic.h" +#include "testlib/fft64_dft.h" +#include "testlib/negacyclic_polynomial.h" +#include "testlib/reim4_elem.h" +#include "testlib/test_commons.h" + +/** @brief a test memory layout for rnx i64 polynomials vectors */ +class rnx_vec_f64_layout { + uint64_t n; + uint64_t size; + uint64_t slice; + double* data_start; + uint8_t* region; + + public: + // NO-COPY structure + rnx_vec_f64_layout(const rnx_vec_f64_layout&) = delete; + void operator=(const rnx_vec_f64_layout&) = delete; + rnx_vec_f64_layout(rnx_vec_f64_layout&&) = delete; + void operator=(rnx_vec_f64_layout&&) = delete; + /** @brief initialises a memory layout */ + rnx_vec_f64_layout(uint64_t n, uint64_t size, uint64_t slice); + /** @brief destructor */ + ~rnx_vec_f64_layout(); + + /** @brief get a copy of item index index (extended with zeros) */ + rnx_f64 get_copy_zext(uint64_t index) const; + /** @brief get a copy of item index index (extended with zeros) */ + rnx_f64 get_copy(uint64_t index) const; + /** @brief get a copy of item index index (extended with zeros) */ + reim_fft64vec get_dft_copy_zext(uint64_t index) const; + /** @brief get a copy of item index index (extended with zeros) */ + reim_fft64vec get_dft_copy(uint64_t index) const; + + /** @brief get a copy of item 
index index (index Date: Wed, 14 Aug 2024 14:13:04 +0000 Subject: [PATCH 08/11] add rnx tests --- test/spqlios_vec_rnx_conversions_test.cpp | 134 ++++++++++ test/spqlios_vec_rnx_ppol_test.cpp | 73 ++++++ test/spqlios_vec_rnx_vmp_test.cpp | 291 ++++++++++++++++++++++ 3 files changed, 498 insertions(+) create mode 100644 test/spqlios_vec_rnx_conversions_test.cpp create mode 100644 test/spqlios_vec_rnx_ppol_test.cpp create mode 100644 test/spqlios_vec_rnx_vmp_test.cpp diff --git a/test/spqlios_vec_rnx_conversions_test.cpp b/test/spqlios_vec_rnx_conversions_test.cpp new file mode 100644 index 0000000..3b629e6 --- /dev/null +++ b/test/spqlios_vec_rnx_conversions_test.cpp @@ -0,0 +1,134 @@ +#include +#include + +#include "testlib/test_commons.h" + +template +static void test_conv(void (*conv_f)(const MOD_RNX*, // + DST_T* res, uint64_t res_size, uint64_t res_sl, // + const SRC_T* a, uint64_t a_size, uint64_t a_sl), // + DST_T (*ideal_conv_f)(SRC_T x), // + SRC_T (*random_f)() // +) { + for (uint64_t nn : {2, 4, 16, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t a_size : {0, 1, 2, 5}) { + for (uint64_t res_size : {0, 1, 3, 5}) { + for (uint64_t trials = 0; trials < 20; ++trials) { + uint64_t a_sl = nn + uniform_u64_bits(2); + uint64_t res_sl = nn + uniform_u64_bits(2); + std::vector a(a_sl * a_size); + std::vector res(res_sl * res_size); + uint64_t msize = std::min(a_size, res_size); + for (uint64_t i = 0; i < a_size; ++i) { + for (uint64_t j = 0; j < nn; ++j) { + a[i * a_sl + j] = random_f(); + } + } + conv_f(module, res.data(), res_size, res_sl, a.data(), a_size, a_sl); + for (uint64_t i = 0; i < msize; ++i) { + for (uint64_t j = 0; j < nn; ++j) { + SRC_T aij = a[i * a_sl + j]; + DST_T expect = ideal_conv_f(aij); + DST_T actual = res[i * res_sl + j]; + ASSERT_EQ(expect, actual); + } + } + for (uint64_t i = msize; i < res_size; ++i) { + DST_T expect = 0; + for (uint64_t j = 0; j < nn; ++j) { + SRC_T actual = res[i * res_sl + j]; + ASSERT_EQ(expect, actual); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static int32_t ideal_dbl_to_tn32(double a) { + double _2p32 = INT64_C(1) << 32; + double a_mod_1 = a - rint(a); + int64_t t = rint(a_mod_1 * _2p32); + return int32_t(t); +} + +static double random_f64_10() { return uniform_f64_bounds(-10, 10); } + +static void test_vec_rnx_to_tnx32(VEC_RNX_TO_TNX32_F vec_rnx_to_tnx32_f) { + test_conv(vec_rnx_to_tnx32_f, ideal_dbl_to_tn32, random_f64_10); +} + +TEST(vec_rnx_arithmetic, vec_rnx_to_tnx32) { test_vec_rnx_to_tnx32(vec_rnx_to_tnx32); } +TEST(vec_rnx_arithmetic, vec_rnx_to_tnx32_ref) { test_vec_rnx_to_tnx32(vec_rnx_to_tnx32_ref); } + +static double ideal_tn32_to_dbl(int32_t a) { + const double _2p32 = INT64_C(1) << 32; + return double(a) / _2p32; +} + +static int32_t random_t32() { return uniform_i64_bits(32); } + +static void test_vec_rnx_from_tnx32(VEC_RNX_FROM_TNX32_F vec_rnx_from_tnx32_f) { + test_conv(vec_rnx_from_tnx32_f, ideal_tn32_to_dbl, random_t32); +} + +TEST(vec_rnx_arithmetic, vec_rnx_from_tnx32) { test_vec_rnx_from_tnx32(vec_rnx_from_tnx32); } +TEST(vec_rnx_arithmetic, vec_rnx_from_tnx32_ref) { test_vec_rnx_from_tnx32(vec_rnx_from_tnx32_ref); } + +static int32_t ideal_dbl_round_to_i32(double a) { return int32_t(rint(a)); } + +static double random_dbl_explaw_18() { return uniform_f64_bounds(-1., 1.) 
* pow(2., uniform_u64_bits(6) % 19); } + +static void test_vec_rnx_to_znx32(VEC_RNX_TO_ZNX32_F vec_rnx_to_znx32_f) { + test_conv(vec_rnx_to_znx32_f, ideal_dbl_round_to_i32, random_dbl_explaw_18); +} + +TEST(zn_arithmetic, vec_rnx_to_znx32) { test_vec_rnx_to_znx32(vec_rnx_to_znx32); } +TEST(zn_arithmetic, vec_rnx_to_znx32_ref) { test_vec_rnx_to_znx32(vec_rnx_to_znx32_ref); } + +static double ideal_i32_to_dbl(int32_t a) { return double(a); } + +static int32_t random_i32_explaw_18() { return uniform_i64_bits(uniform_u64_bits(6) % 19); } + +static void test_vec_rnx_from_znx32(VEC_RNX_FROM_ZNX32_F vec_rnx_from_znx32_f) { + test_conv(vec_rnx_from_znx32_f, ideal_i32_to_dbl, random_i32_explaw_18); +} + +TEST(zn_arithmetic, vec_rnx_from_znx32) { test_vec_rnx_from_znx32(vec_rnx_from_znx32); } +TEST(zn_arithmetic, vec_rnx_from_znx32_ref) { test_vec_rnx_from_znx32(vec_rnx_from_znx32_ref); } + +static double ideal_dbl_to_tndbl(double a) { return a - rint(a); } + +static void test_vec_rnx_to_tnxdbl(VEC_RNX_TO_TNXDBL_F vec_rnx_to_tnxdbl_f) { + test_conv(vec_rnx_to_tnxdbl_f, ideal_dbl_to_tndbl, random_f64_10); +} + +TEST(zn_arithmetic, vec_rnx_to_tnxdbl) { test_vec_rnx_to_tnxdbl(vec_rnx_to_tnxdbl); } +TEST(zn_arithmetic, vec_rnx_to_tnxdbl_ref) { test_vec_rnx_to_tnxdbl(vec_rnx_to_tnxdbl_ref); } + +#if 0 +static int64_t ideal_dbl_round_to_i64(double a) { return rint(a); } + +static double random_dbl_explaw_50() { return uniform_f64_bounds(-1., 1.) * pow(2., uniform_u64_bits(7) % 51); } + +static void test_dbl_round_to_i64(DBL_ROUND_TO_I64_F dbl_round_to_i64_f) { + test_conv(dbl_round_to_i64_f, ideal_dbl_round_to_i64, random_dbl_explaw_50); +} + +TEST(zn_arithmetic, dbl_round_to_i64) { test_dbl_round_to_i64(dbl_round_to_i64); } +TEST(zn_arithmetic, dbl_round_to_i64_ref) { test_dbl_round_to_i64(dbl_round_to_i64_ref); } + +static double ideal_i64_to_dbl(int64_t a) { return double(a); } + +static int64_t random_i64_explaw_50() { return uniform_i64_bits(uniform_u64_bits(7) % 51); } + +static void test_i64_to_dbl(I64_TO_DBL_F i64_to_dbl_f) { + test_conv(i64_to_dbl_f, ideal_i64_to_dbl, random_i64_explaw_50); +} + +TEST(zn_arithmetic, i64_to_dbl) { test_i64_to_dbl(i64_to_dbl); } +TEST(zn_arithmetic, i64_to_dbl_ref) { test_i64_to_dbl(i64_to_dbl_ref); } +#endif diff --git a/test/spqlios_vec_rnx_ppol_test.cpp b/test/spqlios_vec_rnx_ppol_test.cpp new file mode 100644 index 0000000..e5e0cbd --- /dev/null +++ b/test/spqlios_vec_rnx_ppol_test.cpp @@ -0,0 +1,73 @@ +#include + +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "spqlios/reim/reim_fft.h" +#include "test/testlib/vec_rnx_layout.h" + +static void test_vec_rnx_svp_prepare(RNX_SVP_PREPARE_F* rnx_svp_prepare, BYTES_OF_RNX_SVP_PPOL_F* tmp_bytes) { + for (uint64_t n : {2, 4, 8, 64}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + const double invm = 1. 
/ mod->m; + + rnx_f64 in = rnx_f64::random_log2bound(n, 40); + rnx_f64 in_divide_by_m = rnx_f64::zero(n); + for (uint64_t i = 0; i < n; ++i) { + in_divide_by_m.set_coeff(i, in.get_coeff(i) * invm); + } + fft64_rnx_svp_ppol_layout out(n); + reim_fft64vec expect = simple_fft64(in_divide_by_m); + rnx_svp_prepare(mod, out.data, in.data()); + const double* ed = (double*)expect.data(); + const double* ac = (double*)out.data; + for (uint64_t i = 0; i < n; ++i) { + ASSERT_LE(abs(ed[i] - ac[i]), 1e-10) << i << n; + } + delete_rnx_module_info(mod); + } +} +TEST(vec_rnx, vec_rnx_svp_prepare) { test_vec_rnx_svp_prepare(rnx_svp_prepare, bytes_of_rnx_svp_ppol); } +TEST(vec_rnx, vec_rnx_svp_prepare_ref) { + test_vec_rnx_svp_prepare(fft64_rnx_svp_prepare_ref, fft64_bytes_of_rnx_svp_ppol); +} + +static void test_vec_rnx_svp_apply(RNX_SVP_APPLY_F* apply) { + for (uint64_t n : {2, 4, 8, 64, 128}) { + MOD_RNX* mod = new_rnx_module_info(n, FFT64); + + // poly 1 to multiply - create and prepare + fft64_rnx_svp_ppol_layout ppol(n); + ppol.fill_random(1.); + for (uint64_t sa : {3, 5, 8}) { + for (uint64_t sr : {3, 5, 8}) { + uint64_t a_sl = n + uniform_u64_bits(2); + uint64_t r_sl = n + uniform_u64_bits(2); + // poly 2 to multiply + rnx_vec_f64_layout a(n, sa, a_sl); + a.fill_random(19); + + // original operation result + rnx_vec_f64_layout res(n, sr, r_sl); + thash hash_a_before = a.content_hash(); + thash hash_ppol_before = ppol.content_hash(); + apply(mod, res.data(), sr, r_sl, ppol.data, a.data(), sa, a_sl); + ASSERT_EQ(a.content_hash(), hash_a_before); + ASSERT_EQ(ppol.content_hash(), hash_ppol_before); + // create expected value + reim_fft64vec ppo = ppol.get_copy(); + std::vector expect(sr); + for (uint64_t i = 0; i < sr; ++i) { + expect[i] = simple_ifft64(ppo * simple_fft64(a.get_copy_zext(i))); + } + // this is the largest precision we can safely expect + double prec_expect = n * pow(2., 19 - 50); + for (uint64_t i = 0; i < sr; ++i) { + rnx_f64 actual = res.get_copy_zext(i); + ASSERT_LE(infty_dist(actual, expect[i]), prec_expect); + } + } + } + delete_rnx_module_info(mod); + } +} +TEST(vec_rnx, vec_rnx_svp_apply) { test_vec_rnx_svp_apply(rnx_svp_apply); } +TEST(vec_rnx, vec_rnx_svp_apply_ref) { test_vec_rnx_svp_apply(fft64_rnx_svp_apply_ref); } diff --git a/test/spqlios_vec_rnx_vmp_test.cpp b/test/spqlios_vec_rnx_vmp_test.cpp new file mode 100644 index 0000000..9e09a2b --- /dev/null +++ b/test/spqlios_vec_rnx_vmp_test.cpp @@ -0,0 +1,291 @@ +#include "gtest/gtest.h" +#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "spqlios/reim/reim_fft.h" +#include "test/testlib/vec_rnx_layout.h" + +static void test_vmp_apply_dft_to_dft_outplace( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_sl = nn + uniform_u64_bits(2); + const uint64_t out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, in_size, in_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + rnx_vec_f64_layout out(nn, out_size, out_sl); + in.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < 
std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * in.get_dft_copy(row); + } + expect[col] = ex; + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + out.data(), out_size, out_sl, // + in.data(), in_size, in_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec actual = out.get_dft_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_dft_to_dft_inplace( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 2, 6}) { + for (uint64_t mat_ncols : {1, 2, 7, 8}) { + for (uint64_t in_size : {1, 3, 6}) { + for (uint64_t out_size : {1, 3, 6}) { + const uint64_t in_out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in_out(nn, std::max(in_size, out_size), in_out_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + in_out.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size, reim_fft64vec(nn)); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * in_out.get_dft_copy(row); + } + expect[col] = ex; + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + in_out.data(), out_size, in_out_sl, // + in_out.data(), in_size, in_out_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec actual = in_out.get_dft_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_dft_to_dft( // + RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // + RNX_VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* tmp_bytes) { + test_vmp_apply_dft_to_dft_outplace(apply, tmp_bytes); + test_vmp_apply_dft_to_dft_inplace(apply, tmp_bytes); +} + +TEST(vec_rnx, vmp_apply_to_dft) { + test_vmp_apply_dft_to_dft(rnx_vmp_apply_dft_to_dft, rnx_vmp_apply_dft_to_dft_tmp_bytes); +} +TEST(vec_rnx, fft64_vmp_apply_dft_to_dft_ref) { + test_vmp_apply_dft_to_dft(fft64_rnx_vmp_apply_dft_to_dft_ref, fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_rnx, fft64_vmp_apply_dft_to_dft_avx) { + test_vmp_apply_dft_to_dft(fft64_rnx_vmp_apply_dft_to_dft_avx, fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx); +} +#endif + +/// rnx_vmp_prepare + +static void test_vmp_prepare_contiguous(RNX_VMP_PREPARE_CONTIGUOUS_F* prepare_contiguous, + RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* tmp_bytes) { + // tests when n < 8 + for (uint64_t nn : {2, 4}) { + const double one_over_m = 2. 
/ nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + const double* pmatv = (double*)pmat.data + (col * nrows + row) * nn; + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + const double* tmpv = tmp.data(); + for (uint64_t i = 0; i < nn; ++i) { + ASSERT_LE(abs(pmatv[i] - tmpv[i]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } + // tests when n >= 8 + for (uint64_t nn : {8, 32}) { + const double one_over_m = 2. / nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + uint64_t nblk = nn / 8; + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + for (uint64_t blk = 0; blk < nblk; ++blk) { + reim4_elem expect = tmp.get_blk(blk); + reim4_elem actual = pmat.get(row, col, blk); + ASSERT_LE(infty_dist(actual, expect), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +TEST(vec_znx, vmp_prepare_contiguous) { + test_vmp_prepare_contiguous(rnx_vmp_prepare_contiguous, rnx_vmp_prepare_contiguous_tmp_bytes); +} +TEST(vec_znx, fft64_vmp_prepare_contiguous_ref) { + test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_ref, fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_prepare_contiguous_avx) { + test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_avx, fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx); +} +#endif + +/// rnx_vmp_apply_dft_to_dft + +static void test_vmp_apply_tmp_a_outplace( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_sl = nn + uniform_u64_bits(2); + const uint64_t out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in(nn, in_size, in_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + rnx_vec_f64_layout out(nn, out_size, out_sl); + in.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * simple_fft64(in.get_copy(row)); + } + expect[col] = simple_ifft64(ex); + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, 
mat_nrows, mat_ncols)); + apply(module, // + out.data(), out_size, out_sl, // + in.data(), in_size, in_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + rnx_f64 actual = out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_tmp_a_inplace( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + for (uint64_t nn : {2, 4, 8, 64}) { + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t mat_nrows : {1, 4, 7}) { + for (uint64_t mat_ncols : {1, 2, 5}) { + for (uint64_t in_size : {1, 4, 7}) { + for (uint64_t out_size : {1, 2, 5}) { + const uint64_t in_out_sl = nn + uniform_u64_bits(2); + rnx_vec_f64_layout in_out(nn, std::max(in_size, out_size), in_out_sl); + fft64_rnx_vmp_pmat_layout pmat(nn, mat_nrows, mat_ncols); + in_out.fill_random(0); + pmat.fill_random(0); + // naive computation of the product + std::vector expect(out_size); + for (uint64_t col = 0; col < out_size; ++col) { + reim_fft64vec ex = reim_fft64vec::zero(nn); + for (uint64_t row = 0; row < std::min(mat_nrows, in_size); ++row) { + ex += pmat.get_zext(row, col) * simple_fft64(in_out.get_copy(row)); + } + expect[col] = simple_ifft64(ex); + } + // apply the product + std::vector tmp(tmp_bytes(module, out_size, in_size, mat_nrows, mat_ncols)); + apply(module, // + in_out.data(), out_size, in_out_sl, // + in_out.data(), in_size, in_out_sl, // + pmat.data, mat_nrows, mat_ncols, // + tmp.data()); + // check that the output is close from the expectation + for (uint64_t col = 0; col < out_size; ++col) { + rnx_f64 actual = in_out.get_copy_zext(col); + ASSERT_LE(infty_dist(actual, expect[col]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +static void test_vmp_apply_tmp_a( // + RNX_VMP_APPLY_TMP_A_F* apply, // + RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* tmp_bytes) { + test_vmp_apply_tmp_a_outplace(apply, tmp_bytes); + test_vmp_apply_tmp_a_inplace(apply, tmp_bytes); +} + +TEST(vec_znx, fft64_vmp_apply_tmp_a) { test_vmp_apply_tmp_a(rnx_vmp_apply_tmp_a, rnx_vmp_apply_tmp_a_tmp_bytes); } +TEST(vec_znx, fft64_vmp_apply_tmp_a_ref) { + test_vmp_apply_tmp_a(fft64_rnx_vmp_apply_tmp_a_ref, fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_apply_tmp_a_avx) { + test_vmp_apply_tmp_a(fft64_rnx_vmp_apply_tmp_a_avx, fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx); +} +#endif From b21ebaa2f9cd69a73b04ba34dc86105a28f1a797 Mon Sep 17 00:00:00 2001 From: Sandra Guasch Date: Wed, 14 Aug 2024 15:58:59 +0000 Subject: [PATCH 09/11] add classes to build --- spqlios/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/spqlios/CMakeLists.txt b/spqlios/CMakeLists.txt index 1738c44..4326576 100644 --- a/spqlios/CMakeLists.txt +++ b/spqlios/CMakeLists.txt @@ -39,6 +39,9 @@ set(SRCS_GENERIC arithmetic/zn_api.c arithmetic/zn_conversions_ref.c arithmetic/zn_approxdecomp_ref.c + arithmetic/vec_rnx_api.c + arithmetic/vec_rnx_conversions_ref.c + arithmetic/vec_rnx_svp_ref.c reim/reim_execute.c cplx/cplx_execute.c reim4/reim4_execute.c From 0b4f2de2f38d72ad57d25653e19da143515c0150 Mon Sep 17 00:00:00 2001 From: Sandra Guasch Date: Tue, 20 Aug 2024 07:28:26 +0000 Subject: [PATCH 10/11] fix tests --- test/CMakeLists.txt | 6 ++++++ test/spqlios_vec_rnx_test.cpp | 32 +++++++++++++++---------------- test/spqlios_vec_rnx_vmp_test.cpp | 12 
++++++------ test/testlib/vec_rnx_layout.cpp | 2 +- test/testlib/vec_rnx_layout.h | 8 ++++---- test/testlib/zn_layouts.h | 4 ++-- 6 files changed, 35 insertions(+), 29 deletions(-) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index bac4bd3..874477e 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -56,6 +56,8 @@ add_library(spqlios-testlib SHARED testlib/sha3.c testlib/polynomial_vector.h testlib/polynomial_vector.cpp + testlib/vec_rnx_layout.h + testlib/vec_rnx_layout.cpp ) if (VALGRIND_DIR) target_include_directories(spqlios-testlib PRIVATE ${VALGRIND_DIR}) @@ -84,6 +86,10 @@ set(UNITTEST_FILES spqlios_svp_test.cpp spqlios_svp_product_test.cpp spqlios_vec_znx_test.cpp + spqlios_vec_rnx_vmp_test.cpp + spqlios_vec_rnx_conversions_test.cpp + spqlios_vec_rnx_ppol_test.cpp + ) add_executable(spqlios-test ${UNITTEST_FILES}) diff --git a/test/spqlios_vec_rnx_test.cpp b/test/spqlios_vec_rnx_test.cpp index 0aa8723..2990299 100644 --- a/test/spqlios_vec_rnx_test.cpp +++ b/test/spqlios_vec_rnx_test.cpp @@ -2,7 +2,7 @@ #include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" #include "spqlios/reim/reim_fft.h" -#include "test/testlib/vec_rnx_layout.h" +#include "testlib/vec_rnx_layout.h" // disabling this test by default, since it depicts on purpose wrong accesses #if 0 @@ -31,7 +31,7 @@ TEST(rnx_layout, valgrind_antipattern_test) { // test for out of place calls template -void test_vec_znx_elemw_binop_outplace(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { +void test_vec_rnx_elemw_binop_outplace(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { for (uint64_t n : {2, 4, 8, 128}) { RNX_MODULE_TYPE mtype = FFT64; MOD_RNX* mod = new_rnx_module_info(n, mtype); @@ -69,7 +69,7 @@ void test_vec_znx_elemw_binop_outplace(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { } // test for inplace1 calls template -void test_vec_znx_elemw_binop_inplace1(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { +void test_vec_rnx_elemw_binop_inplace1(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { for (uint64_t n : {2, 4, 64}) { RNX_MODULE_TYPE mtype = FFT64; MOD_RNX* mod = new_rnx_module_info(n, mtype); @@ -103,7 +103,7 @@ void test_vec_znx_elemw_binop_inplace1(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { } // test for inplace2 calls template -void test_vec_znx_elemw_binop_inplace2(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { +void test_vec_rnx_elemw_binop_inplace2(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { for (uint64_t n : {4, 32, 64}) { RNX_MODULE_TYPE mtype = FFT64; MOD_RNX* mod = new_rnx_module_info(n, mtype); @@ -137,7 +137,7 @@ void test_vec_znx_elemw_binop_inplace2(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { } // test for inplace3 calls template -void test_vec_znx_elemw_binop_inplace3(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { +void test_vec_rnx_elemw_binop_inplace3(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { for (uint64_t n : {2, 16, 1024}) { RNX_MODULE_TYPE mtype = FFT64; MOD_RNX* mod = new_rnx_module_info(n, mtype); @@ -163,24 +163,24 @@ void test_vec_znx_elemw_binop_inplace3(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { } } template -void test_vec_znx_elemw_binop(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { - test_vec_znx_elemw_binop_outplace(binop, ref_binop); - test_vec_znx_elemw_binop_inplace1(binop, ref_binop); - test_vec_znx_elemw_binop_inplace2(binop, ref_binop); - test_vec_znx_elemw_binop_inplace3(binop, ref_binop); +void test_vec_rnx_elemw_binop(ACTUAL_FCN binop, EXPECT_FCN ref_binop) { + test_vec_rnx_elemw_binop_outplace(binop, ref_binop); + test_vec_rnx_elemw_binop_inplace1(binop, ref_binop); + 
test_vec_rnx_elemw_binop_inplace2(binop, ref_binop); + test_vec_rnx_elemw_binop_inplace3(binop, ref_binop); } static rnx_f64 poly_add(const rnx_f64& a, const rnx_f64& b) { return a + b; } static rnx_f64 poly_sub(const rnx_f64& a, const rnx_f64& b) { return a - b; } -TEST(vec_znx, vec_znx_add) { test_vec_znx_elemw_binop(vec_rnx_add, poly_add); } -TEST(vec_znx, vec_znx_add_ref) { test_vec_znx_elemw_binop(vec_rnx_add_ref, poly_add); } +TEST(vec_rnx, vec_rnx_add) { test_vec_rnx_elemw_binop(vec_rnx_add, poly_add); } +TEST(vec_rnx, vec_rnx_add_ref) { test_vec_rnx_elemw_binop(vec_rnx_add_ref, poly_add); } #ifdef __x86_64__ -TEST(vec_znx, vec_znx_add_avx) { test_vec_znx_elemw_binop(vec_rnx_add_avx, poly_add); } +TEST(vec_rnx, vec_rnx_add_avx) { test_vec_rnx_elemw_binop(vec_rnx_add_avx, poly_add); } #endif -TEST(vec_znx, vec_znx_sub) { test_vec_znx_elemw_binop(vec_rnx_sub, poly_sub); } -TEST(vec_znx, vec_znx_sub_ref) { test_vec_znx_elemw_binop(vec_rnx_sub_ref, poly_sub); } +TEST(vec_rnx, vec_rnx_sub) { test_vec_rnx_elemw_binop(vec_rnx_sub, poly_sub); } +TEST(vec_rnx, vec_rnx_sub_ref) { test_vec_rnx_elemw_binop(vec_rnx_sub_ref, poly_sub); } #ifdef __x86_64__ -TEST(vec_znx, vec_znx_sub_avx) { test_vec_znx_elemw_binop(vec_rnx_sub_avx, poly_sub); } +TEST(vec_rnx, vec_rnx_sub_avx) { test_vec_rnx_elemw_binop(vec_rnx_sub_avx, poly_sub); } #endif // test for out of place calls diff --git a/test/spqlios_vec_rnx_vmp_test.cpp b/test/spqlios_vec_rnx_vmp_test.cpp index 9e09a2b..9bbb9d7 100644 --- a/test/spqlios_vec_rnx_vmp_test.cpp +++ b/test/spqlios_vec_rnx_vmp_test.cpp @@ -1,7 +1,7 @@ #include "gtest/gtest.h" -#include "spqlios/arithmetic/vec_rnx_arithmetic_private.h" -#include "spqlios/reim/reim_fft.h" -#include "test/testlib/vec_rnx_layout.h" +#include "../spqlios/arithmetic/vec_rnx_arithmetic_private.h" +#include "../spqlios/reim/reim_fft.h" +#include "testlib/vec_rnx_layout.h" static void test_vmp_apply_dft_to_dft_outplace( // RNX_VMP_APPLY_DFT_TO_DFT_F* apply, // @@ -171,14 +171,14 @@ static void test_vmp_prepare_contiguous(RNX_VMP_PREPARE_CONTIGUOUS_F* prepare_co } } -TEST(vec_znx, vmp_prepare_contiguous) { +TEST(vec_rnx, vmp_prepare_contiguous) { test_vmp_prepare_contiguous(rnx_vmp_prepare_contiguous, rnx_vmp_prepare_contiguous_tmp_bytes); } -TEST(vec_znx, fft64_vmp_prepare_contiguous_ref) { +TEST(vec_rnx, fft64_vmp_prepare_contiguous_ref) { test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_ref, fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref); } #ifdef __x86_64__ -TEST(vec_znx, fft64_vmp_prepare_contiguous_avx) { +TEST(vec_rnx, fft64_vmp_prepare_contiguous_avx) { test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_avx, fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx); } #endif diff --git a/test/testlib/vec_rnx_layout.cpp b/test/testlib/vec_rnx_layout.cpp index 5fe9121..2a61e81 100644 --- a/test/testlib/vec_rnx_layout.cpp +++ b/test/testlib/vec_rnx_layout.cpp @@ -2,7 +2,7 @@ #include -#include "spqlios/arithmetic/vec_rnx_arithmetic.h" +#include "../../spqlios/arithmetic/vec_rnx_arithmetic.h" #ifdef VALGRIND_MEM_TESTS #include "valgrind/memcheck.h" diff --git a/test/testlib/vec_rnx_layout.h b/test/testlib/vec_rnx_layout.h index 6b2415b..a92bc04 100644 --- a/test/testlib/vec_rnx_layout.h +++ b/test/testlib/vec_rnx_layout.h @@ -2,10 +2,10 @@ #define SPQLIOS_EXT_VEC_RNX_LAYOUT_H #include "../../spqlios/arithmetic/vec_rnx_arithmetic.h" -#include "testlib/fft64_dft.h" -#include "testlib/negacyclic_polynomial.h" -#include "testlib/reim4_elem.h" -#include "testlib/test_commons.h" +#include 
"fft64_dft.h" +#include "negacyclic_polynomial.h" +#include "reim4_elem.h" +#include "test_commons.h" /** @brief a test memory layout for rnx i64 polynomials vectors */ class rnx_vec_f64_layout { diff --git a/test/testlib/zn_layouts.h b/test/testlib/zn_layouts.h index 6a89173..b36ce3e 100644 --- a/test/testlib/zn_layouts.h +++ b/test/testlib/zn_layouts.h @@ -1,8 +1,8 @@ #ifndef SPQLIOS_EXT_ZN_LAYOUTS_H #define SPQLIOS_EXT_ZN_LAYOUTS_H -#include "spqlios/arithmetic/zn_arithmetic.h" -#include "testlib/test_commons.h" +#include "../../spqlios/arithmetic/zn_arithmetic.h" +#include "test_commons.h" class zn32_pmat_layout { public: From 42a2343dc720f7aa3f37a1d44007340b42c8ad5e Mon Sep 17 00:00:00 2001 From: Sandra Guasch Date: Tue, 20 Aug 2024 07:36:22 +0000 Subject: [PATCH 11/11] add remaining tests --- test/CMakeLists.txt | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 874477e..132ba71 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -58,6 +58,8 @@ add_library(spqlios-testlib SHARED testlib/polynomial_vector.cpp testlib/vec_rnx_layout.h testlib/vec_rnx_layout.cpp + testlib/zn_layouts.h + testlib/zn_layouts.cpp ) if (VALGRIND_DIR) target_include_directories(spqlios-testlib PRIVATE ${VALGRIND_DIR}) @@ -86,9 +88,15 @@ set(UNITTEST_FILES spqlios_svp_test.cpp spqlios_svp_product_test.cpp spqlios_vec_znx_test.cpp + spqlios_vec_rnx_test.cpp spqlios_vec_rnx_vmp_test.cpp spqlios_vec_rnx_conversions_test.cpp spqlios_vec_rnx_ppol_test.cpp + spqlios_vec_rnx_approxdecomp_tnxdbl_test.cpp + spqlios_zn_approxdecomp_test.cpp + spqlios_zn_conversions_test.cpp + spqlios_zn_vmp_test.cpp + )