diff --git a/.gitignore b/.gitignore index 90058ad..d682397 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,4 @@ cmake-build-* .idea + +build diff --git a/spqlios/arithmetic/module_api.c b/spqlios/arithmetic/module_api.c index 52140a0..1c85fca 100644 --- a/spqlios/arithmetic/module_api.c +++ b/spqlios/arithmetic/module_api.c @@ -46,14 +46,14 @@ static void fill_fft64_virtual_table(MODULE* module) { module->func.znx_small_single_product = fft64_znx_small_single_product; module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes; module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref; - module->func.vmp_prepare_contiguous_tmp_bytes = fft64_vmp_prepare_contiguous_tmp_bytes; + module->func.vmp_prepare_dblptr = fft64_vmp_prepare_dblptr_ref; + module->func.vmp_prepare_row = fft64_vmp_prepare_row_ref; + module->func.vmp_prepare_tmp_bytes = fft64_vmp_prepare_tmp_bytes; module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref; module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes; module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref; module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes; module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft; - module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft; - module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft; module->func.bytes_of_vec_znx_big = fft64_bytes_of_vec_znx_big; module->func.bytes_of_svp_ppol = fft64_bytes_of_svp_ppol; module->func.bytes_of_vmp_pmat = fft64_bytes_of_vmp_pmat; @@ -61,6 +61,8 @@ static void fill_fft64_virtual_table(MODULE* module) { // TODO add avx handlers here // TODO: enable when avx implementation is done module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx; + module->func.vmp_prepare_dblptr = fft64_vmp_prepare_dblptr_avx; + module->func.vmp_prepare_row = fft64_vmp_prepare_row_avx; module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx; module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx; } diff --git a/spqlios/arithmetic/vec_rnx_api.c b/spqlios/arithmetic/vec_rnx_api.c index 0f396fb..d1664ac 100644 --- a/spqlios/arithmetic/vec_rnx_api.c +++ b/spqlios/arithmetic/vec_rnx_api.c @@ -33,8 +33,10 @@ void fft64_init_rnx_module_vtable(MOD_RNX* module) { module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref; module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref; module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref; - module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref; + module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_ref; module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref; + module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_ref; + module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_ref; module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat; module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref; module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref; @@ -55,8 +57,10 @@ void fft64_init_rnx_module_vtable(MOD_RNX* module) { module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx; module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx; module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx; - module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx; + module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_avx; module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx; + module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_avx; + module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_avx; module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx; } } @@ -201,9 +205,29 @@ EXPORT void rnx_vmp_prepare_contiguous( // module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space); } +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void rnx_vmp_prepare_dblptr( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_prepare_dblptr(module, pmat, a, nrows, ncols, tmp_space); +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void rnx_vmp_prepare_row( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->vtable.rnx_vmp_prepare_row(module, pmat, a, row_i, nrows, ncols, tmp_space); +} + /** @brief number of scratch bytes necessary to prepare a matrix */ -EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module) { - return module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes(module); +EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module) { + return module->vtable.rnx_vmp_prepare_tmp_bytes(module); } /** @brief applies a vmp product res = a x pmat */ diff --git a/spqlios/arithmetic/vec_rnx_arithmetic.h b/spqlios/arithmetic/vec_rnx_arithmetic.h index 16a5e6d..c88bddb 100644 --- a/spqlios/arithmetic/vec_rnx_arithmetic.h +++ b/spqlios/arithmetic/vec_rnx_arithmetic.h @@ -289,8 +289,24 @@ EXPORT void rnx_vmp_prepare_contiguous( // uint8_t* tmp_space // scratch space ); +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void rnx_vmp_prepare_dblptr( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** a, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void rnx_vmp_prepare_row( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + /** @brief number of scratch bytes necessary to prepare a matrix */ -EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module); +EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module); /** @brief applies a vmp product res = a x pmat */ EXPORT void rnx_vmp_apply_tmp_a( // diff --git a/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h b/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h index f2e07eb..99277b4 100644 --- a/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h +++ b/spqlios/arithmetic/vec_rnx_arithmetic_plugin.h @@ -35,7 +35,9 @@ typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F; typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F; typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F; typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F; -typedef typeof(rnx_vmp_prepare_contiguous_tmp_bytes) RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F; +typedef typeof(rnx_vmp_prepare_dblptr) RNX_VMP_PREPARE_DBLPTR_F; +typedef typeof(rnx_vmp_prepare_row) RNX_VMP_PREPARE_ROW_F; +typedef typeof(rnx_vmp_prepare_tmp_bytes) RNX_VMP_PREPARE_TMP_BYTES_F; typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F; typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F; typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F; @@ -76,7 +78,9 @@ struct rnx_module_vtable_t { RNX_SVP_APPLY_F* rnx_svp_apply; BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat; RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous; - RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* rnx_vmp_prepare_contiguous_tmp_bytes; + RNX_VMP_PREPARE_DBLPTR_F* rnx_vmp_prepare_dblptr; + RNX_VMP_PREPARE_ROW_F* rnx_vmp_prepare_row; + RNX_VMP_PREPARE_TMP_BYTES_F* rnx_vmp_prepare_tmp_bytes; RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a; RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes; RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft; diff --git a/spqlios/arithmetic/vec_rnx_arithmetic_private.h b/spqlios/arithmetic/vec_rnx_arithmetic_private.h index 59a4cf8..6bf79c2 100644 --- a/spqlios/arithmetic/vec_rnx_arithmetic_private.h +++ b/spqlios/arithmetic/vec_rnx_arithmetic_private.h @@ -183,8 +183,32 @@ EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( // const double* mat, uint64_t nrows, uint64_t ncols, // a uint8_t* tmp_space // scratch space ); -EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module); -EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module); +EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_row_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT void fft64_rnx_vmp_prepare_row_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); +EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module); +EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module); EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( // const MOD_RNX* module, // N diff --git a/spqlios/arithmetic/vec_rnx_vmp_avx.c b/spqlios/arithmetic/vec_rnx_vmp_avx.c index 4c1b23d..7a492bc 100644 --- a/spqlios/arithmetic/vec_rnx_vmp_avx.c +++ b/spqlios/arithmetic/vec_rnx_vmp_avx.c @@ -57,6 +57,64 @@ EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( // } } +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + fft64_rnx_vmp_prepare_row_avx(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space); + } +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void fft64_rnx_vmp_prepare_row_avx( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_avx(nn, m, dtmp, row + col_i * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } else { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_avx(nn, m, res, row + col_i * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } +} + /** @brief minimal size of the tmp_space */ EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( // const MOD_RNX* module, // N diff --git a/spqlios/arithmetic/vec_rnx_vmp_ref.c b/spqlios/arithmetic/vec_rnx_vmp_ref.c index de14ba8..1a91f3c 100644 --- a/spqlios/arithmetic/vec_rnx_vmp_ref.c +++ b/spqlios/arithmetic/vec_rnx_vmp_ref.c @@ -62,8 +62,66 @@ EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( // } } +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + fft64_rnx_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space); + } +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void fft64_rnx_vmp_prepare_row_ref( // + const MOD_RNX* module, // N + RNX_VMP_PMAT* pmat, // output + const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->n; + const uint64_t m = module->m; + + double* const dtmp = (double*)tmp_space; + double* const output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + rnx_divide_by_m_ref(nn, m, dtmp, row + col_i * nn); + reim_fft(module->precomp.fft64.p_fft, dtmp); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp); + } + } + } else { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = output_mat + (col_i * nrows + row_i) * nn; + rnx_divide_by_m_ref(nn, m, res, row + col_i * nn); + reim_fft(module->precomp.fft64.p_fft, res); + } + } +} + /** @brief number of scratch bytes necessary to prepare a matrix */ -EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module) { +EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module) { const uint64_t nn = module->n; return nn * sizeof(int64_t); } @@ -220,10 +278,10 @@ EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( // /** @brief number of scratch bytes necessary to prepare a matrix */ #ifdef __APPLE__ -#pragma weak fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref +#pragma weak fft64_rnx_vmp_prepare_tmp_bytes_avx = fft64_rnx_vmp_prepare_tmp_bytes_ref #else -EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module) - __attribute((alias("fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref"))); +EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module) + __attribute((alias("fft64_rnx_vmp_prepare_tmp_bytes_ref"))); #endif /** @brief minimal size of the tmp_space */ diff --git a/spqlios/arithmetic/vec_znx.c b/spqlios/arithmetic/vec_znx.c index a850bfc..4590baa 100644 --- a/spqlios/arithmetic/vec_znx.c +++ b/spqlios/arithmetic/vec_znx.c @@ -249,27 +249,26 @@ EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module // return nn * sizeof(int64_t); } - // alias have to be defined in this unit: do not move #ifdef __APPLE__ EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( // const MODULE* module // N - ) { +) { return vec_znx_normalize_base2k_tmp_bytes_ref(module); } EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( // - const MODULE* module // N + const MODULE* module // N ) { return vec_znx_normalize_base2k_tmp_bytes_ref(module); } #else EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( // - const MODULE* module // N -) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); + const MODULE* module // N + ) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( // const MODULE* module // N -) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); + ) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref"))); #endif /** @brief sets res = 0 */ diff --git a/spqlios/arithmetic/vec_znx_arithmetic.h b/spqlios/arithmetic/vec_znx_arithmetic.h index b93a571..093f37c 100644 --- a/spqlios/arithmetic/vec_znx_arithmetic.h +++ b/spqlios/arithmetic/vec_znx_arithmetic.h @@ -143,20 +143,6 @@ EXPORT void vec_znx_automorphism(const MODULE* module, const int64_t* a, uint64_t a_size, uint64_t a_sl // a ); -/** @brief prepares a vmp matrix (contiguous row-major version) */ -EXPORT void vmp_prepare_contiguous(const MODULE* module, // N - VMP_PMAT* pmat, // output - const int64_t* mat, uint64_t nrows, uint64_t ncols, // a - uint8_t* tmp_space // scratch space -); - -/** @brief prepares a vmp matrix (mat[row*ncols+col] points to the item) */ -EXPORT void vmp_prepare_dblptr(const MODULE* module, // N - VMP_PMAT* pmat, // output - const int64_t** mat, uint64_t nrows, uint64_t ncols, // a - uint8_t* tmp_space // scratch space -); - /** @brief sets res = 0 */ EXPORT void vec_dft_zero(const MODULE* module, // N VEC_ZNX_DFT* res, uint64_t res_size // res @@ -312,6 +298,10 @@ EXPORT void znx_small_single_product(const MODULE* module, // N /** @brief tmp bytes required for znx_small_single_product */ EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module); +/** @brief minimal scratch space byte-size required for the vmp_prepare function */ +EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); + /** @brief prepares a vmp matrix (contiguous row-major version) */ EXPORT void vmp_prepare_contiguous(const MODULE* module, // N VMP_PMAT* pmat, // output @@ -319,9 +309,19 @@ EXPORT void vmp_prepare_contiguous(const MODULE* module, uint8_t* tmp_space // scratch space ); -/** @brief minimal scratch space byte-size required for the vmp_prepare function */ -EXPORT uint64_t vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N - uint64_t nrows, uint64_t ncols); +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void vmp_prepare_dblptr(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void vmp_prepare_row(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); /** @brief applies a vmp product (result in DFT space) */ EXPORT void vmp_apply_dft(const MODULE* module, // N diff --git a/spqlios/arithmetic/vec_znx_arithmetic_private.h b/spqlios/arithmetic/vec_znx_arithmetic_private.h index 528dfad..d642015 100644 --- a/spqlios/arithmetic/vec_znx_arithmetic_private.h +++ b/spqlios/arithmetic/vec_znx_arithmetic_private.h @@ -84,7 +84,9 @@ typedef typeof(svp_apply_dft) SVP_APPLY_DFT_F; typedef typeof(znx_small_single_product) ZNX_SMALL_SINGLE_PRODUCT_F; typedef typeof(znx_small_single_product_tmp_bytes) ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F; typedef typeof(vmp_prepare_contiguous) VMP_PREPARE_CONTIGUOUS_F; -typedef typeof(vmp_prepare_contiguous_tmp_bytes) VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F; +typedef typeof(vmp_prepare_dblptr) VMP_PREPARE_DBLPTR_F; +typedef typeof(vmp_prepare_row) VMP_PREPARE_ROW_F; +typedef typeof(vmp_prepare_tmp_bytes) VMP_PREPARE_TMP_BYTES_F; typedef typeof(vmp_apply_dft) VMP_APPLY_DFT_F; typedef typeof(vmp_apply_dft_tmp_bytes) VMP_APPLY_DFT_TMP_BYTES_F; typedef typeof(vmp_apply_dft_to_dft) VMP_APPLY_DFT_TO_DFT_F; @@ -127,7 +129,9 @@ struct module_virtual_functions_t { ZNX_SMALL_SINGLE_PRODUCT_F* znx_small_single_product; ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx_small_single_product_tmp_bytes; VMP_PREPARE_CONTIGUOUS_F* vmp_prepare_contiguous; - VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* vmp_prepare_contiguous_tmp_bytes; + VMP_PREPARE_DBLPTR_F* vmp_prepare_dblptr; + VMP_PREPARE_ROW_F* vmp_prepare_row; + VMP_PREPARE_TMP_BYTES_F* vmp_prepare_tmp_bytes; VMP_APPLY_DFT_F* vmp_apply_dft; VMP_APPLY_DFT_TMP_BYTES_F* vmp_apply_dft_tmp_bytes; VMP_APPLY_DFT_TO_DFT_F* vmp_apply_dft_to_dft; @@ -420,6 +424,20 @@ EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, uint8_t* tmp_space // scratch space ); +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void fft64_vmp_prepare_dblptr_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void fft64_vmp_prepare_row_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + /** @brief prepares a vmp matrix (contiguous row-major version) */ EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, // N VMP_PMAT* pmat, // output @@ -427,9 +445,23 @@ EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, uint8_t* tmp_space // scratch space ); +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void fft64_vmp_prepare_dblptr_avx(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void fft64_vmp_prepare_row_avx(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +); + /** @brief minimal scratch space byte-size required for the vmp_prepare function */ -EXPORT uint64_t fft64_vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N - uint64_t nrows, uint64_t ncols); +EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, // N + uint64_t nrows, uint64_t ncols); /** @brief applies a vmp product (result in DFT space) */ EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module, // N diff --git a/spqlios/arithmetic/vec_znx_big.c b/spqlios/arithmetic/vec_znx_big.c index 923703c..79bb88f 100644 --- a/spqlios/arithmetic/vec_znx_big.c +++ b/spqlios/arithmetic/vec_znx_big.c @@ -91,7 +91,7 @@ EXPORT uint64_t fft64_bytes_of_vec_znx_big(const MODULE* module, // N } EXPORT VEC_ZNX_BIG* new_vec_znx_big(const MODULE* module, // N - uint64_t size) { + uint64_t size) { return spqlios_alloc(bytes_of_vec_znx_big(module, size)); } diff --git a/spqlios/arithmetic/vec_znx_dft.c b/spqlios/arithmetic/vec_znx_dft.c index 16b3a9e..5dafee0 100644 --- a/spqlios/arithmetic/vec_znx_dft.c +++ b/spqlios/arithmetic/vec_znx_dft.c @@ -39,7 +39,7 @@ EXPORT uint64_t fft64_bytes_of_vec_znx_dft(const MODULE* module, // N } EXPORT VEC_ZNX_DFT* new_vec_znx_dft(const MODULE* module, // N - uint64_t size) { + uint64_t size) { return spqlios_alloc(bytes_of_vec_znx_dft(module, size)); } diff --git a/spqlios/arithmetic/vector_matrix_product.c b/spqlios/arithmetic/vector_matrix_product.c index 79ab40c..0429da7 100644 --- a/spqlios/arithmetic/vector_matrix_product.c +++ b/spqlios/arithmetic/vector_matrix_product.c @@ -17,7 +17,7 @@ EXPORT uint64_t fft64_bytes_of_vmp_pmat(const MODULE* module, // N } EXPORT VMP_PMAT* new_vmp_pmat(const MODULE* module, // N - uint64_t nrows, uint64_t ncols // dimensions + uint64_t nrows, uint64_t ncols // dimensions ) { return spqlios_alloc(bytes_of_vmp_pmat(module, nrows, ncols)); } @@ -33,10 +33,28 @@ EXPORT void vmp_prepare_contiguous(const MODULE* module, module->func.vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, tmp_space); } +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void vmp_prepare_dblptr(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->func.vmp_prepare_dblptr(module, pmat, mat, nrows, ncols, tmp_space); +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void vmp_prepare_row(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + module->func.vmp_prepare_row(module, pmat, row, row_i, nrows, ncols, tmp_space); +} + /** @brief minimal scratch space byte-size required for the vmp_prepare function */ -EXPORT uint64_t vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N - uint64_t nrows, uint64_t ncols) { - return module->func.vmp_prepare_contiguous_tmp_bytes(module, nrows, ncols); +EXPORT uint64_t vmp_prepare_tmp_bytes(const MODULE* module, // N + uint64_t nrows, uint64_t ncols) { + return module->func.vmp_prepare_tmp_bytes(module, nrows, ncols); } /** @brief prepares a vmp matrix (contiguous row-major version) */ @@ -87,9 +105,64 @@ EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module, } } +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void fft64_vmp_prepare_dblptr_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + fft64_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space); + } +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void fft64_vmp_prepare_row_ref(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->nn; + const uint64_t m = module->m; + + double* output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, row + col_i * nn); + reim_fft(module->mod.fft64.p_fft, (double*)tmp_space); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space); + } + } + } else { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = (double*)pmat + (col_i * nrows + row_i) * nn; + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, row + col_i * nn); + reim_fft(module->mod.fft64.p_fft, res); + } + } +} + /** @brief minimal scratch space byte-size required for the vmp_prepare function */ -EXPORT uint64_t fft64_vmp_prepare_contiguous_tmp_bytes(const MODULE* module, // N - uint64_t nrows, uint64_t ncols) { +EXPORT uint64_t fft64_vmp_prepare_tmp_bytes(const MODULE* module, // N + uint64_t nrows, uint64_t ncols) { const uint64_t nn = module->nn; return nn * sizeof(int64_t); } diff --git a/spqlios/arithmetic/vector_matrix_product_avx.c b/spqlios/arithmetic/vector_matrix_product_avx.c index f428650..ea89959 100644 --- a/spqlios/arithmetic/vector_matrix_product_avx.c +++ b/spqlios/arithmetic/vector_matrix_product_avx.c @@ -51,6 +51,60 @@ EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module, } } +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void fft64_vmp_prepare_dblptr_avx(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t** mat, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + for (uint64_t row_i = 0; row_i < nrows; row_i++) { + fft64_vmp_prepare_row_avx(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space); + } +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void fft64_vmp_prepare_row_avx(const MODULE* module, // N + VMP_PMAT* pmat, // output + const int64_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a + uint8_t* tmp_space // scratch space +) { + // there is an edge case if nn < 8 + const uint64_t nn = module->nn; + const uint64_t m = module->m; + double* output_mat = (double*)pmat; + double* start_addr = (double*)pmat; + uint64_t offset = nrows * ncols * 8; + + if (nn >= 8) { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)tmp_space, row + col_i * nn); + reim_fft(module->mod.fft64.p_fft, (double*)tmp_space); + + if (col_i == (ncols - 1) && (ncols % 2 == 1)) { + // special case: last column out of an odd column number + start_addr = output_mat + col_i * nrows * 8 // col == ncols-1 + + row_i * 8; + } else { + // general case: columns go by pair + start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index + + row_i * 2 * 8 // third: row index + + (col_i % 2) * 8; + } + + for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) { + // extract blk from tmp and save it + reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, (double*)tmp_space); + } + } + } else { + for (uint64_t col_i = 0; col_i < ncols; col_i++) { + double* res = (double*)pmat + (col_i * nrows + row_i) * nn; + reim_from_znx64(module->mod.fft64.p_conv, (SVP_PPOL*)res, row + col_i * nn); + reim_fft(module->mod.fft64.p_fft, res); + } + } +} + /** @brief applies a vmp product (result in DFT space) */ EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module, // N VEC_ZNX_DFT* res, uint64_t res_size, // res diff --git a/spqlios/arithmetic/zn_api.c b/spqlios/arithmetic/zn_api.c index 28d5c8d..4b81750 100644 --- a/spqlios/arithmetic/zn_api.c +++ b/spqlios/arithmetic/zn_api.c @@ -16,6 +16,8 @@ void default_init_z_module_vtable(MOD_Z* module) { module->vtable.i16_approxdecomp_from_tndbl = default_i16_approxdecomp_from_tndbl_ref; module->vtable.i32_approxdecomp_from_tndbl = default_i32_approxdecomp_from_tndbl_ref; module->vtable.zn32_vmp_prepare_contiguous = default_zn32_vmp_prepare_contiguous_ref; + module->vtable.zn32_vmp_prepare_dblptr = default_zn32_vmp_prepare_dblptr_ref; + module->vtable.zn32_vmp_prepare_row = default_zn32_vmp_prepare_row_ref; module->vtable.zn32_vmp_apply_i8 = default_zn32_vmp_apply_i8_ref; module->vtable.zn32_vmp_apply_i16 = default_zn32_vmp_apply_i16_ref; module->vtable.zn32_vmp_apply_i32 = default_zn32_vmp_apply_i32_ref; @@ -96,13 +98,27 @@ EXPORT void i32_approxdecomp_from_tndbl(const MOD_Z* module, module->vtable.i32_approxdecomp_from_tndbl(module, gadget, res, res_size, a, a_size); } -EXPORT void zn32_vmp_prepare_contiguous( // - const MOD_Z* module, - ZN32_VMP_PMAT* pmat, // output - const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a +/** @brief prepares a vmp matrix (contiguous row-major version) */ +EXPORT void zn32_vmp_prepare_contiguous(const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* mat, uint64_t nrows, uint64_t ncols) { // a module->vtable.zn32_vmp_prepare_contiguous(module, pmat, mat, nrows, ncols); } +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void zn32_vmp_prepare_dblptr(const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t** mat, uint64_t nrows, uint64_t ncols) { // a + module->vtable.zn32_vmp_prepare_dblptr(module, pmat, mat, nrows, ncols); +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void zn32_vmp_prepare_row(const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols) { // a + module->vtable.zn32_vmp_prepare_row(module, pmat, row, row_i, nrows, ncols); +} + /** @brief applies a vmp product (int32_t* input) */ EXPORT void zn32_vmp_apply_i32(const MOD_Z* module, int32_t* res, uint64_t res_size, const int32_t* a, uint64_t a_size, const ZN32_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols) { diff --git a/spqlios/arithmetic/zn_arithmetic.h b/spqlios/arithmetic/zn_arithmetic.h index 3503e20..7aec10a 100644 --- a/spqlios/arithmetic/zn_arithmetic.h +++ b/spqlios/arithmetic/zn_arithmetic.h @@ -65,6 +65,18 @@ EXPORT void zn32_vmp_prepare_contiguous( // ZN32_VMP_PMAT* pmat, // output const int32_t* mat, uint64_t nrows, uint64_t ncols); // a +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void zn32_vmp_prepare_dblptr( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t** mat, uint64_t nrows, uint64_t ncols); // a + +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void zn32_vmp_prepare_row( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols); // a + /** @brief applies a vmp product (int32_t* input) */ EXPORT void zn32_vmp_apply_i32( // const MOD_Z* module, // diff --git a/spqlios/arithmetic/zn_arithmetic_plugin.h b/spqlios/arithmetic/zn_arithmetic_plugin.h index d400a72..eb573cc 100644 --- a/spqlios/arithmetic/zn_arithmetic_plugin.h +++ b/spqlios/arithmetic/zn_arithmetic_plugin.h @@ -8,6 +8,8 @@ typedef typeof(i16_approxdecomp_from_tndbl) I16_APPROXDECOMP_FROM_TNDBL_F; typedef typeof(i32_approxdecomp_from_tndbl) I32_APPROXDECOMP_FROM_TNDBL_F; typedef typeof(bytes_of_zn32_vmp_pmat) BYTES_OF_ZN32_VMP_PMAT_F; typedef typeof(zn32_vmp_prepare_contiguous) ZN32_VMP_PREPARE_CONTIGUOUS_F; +typedef typeof(zn32_vmp_prepare_dblptr) ZN32_VMP_PREPARE_DBLPTR_F; +typedef typeof(zn32_vmp_prepare_row) ZN32_VMP_PREPARE_ROW_F; typedef typeof(zn32_vmp_apply_i32) ZN32_VMP_APPLY_I32_F; typedef typeof(zn32_vmp_apply_i16) ZN32_VMP_APPLY_I16_F; typedef typeof(zn32_vmp_apply_i8) ZN32_VMP_APPLY_I8_F; @@ -25,6 +27,8 @@ struct z_module_vtable_t { I32_APPROXDECOMP_FROM_TNDBL_F* i32_approxdecomp_from_tndbl; BYTES_OF_ZN32_VMP_PMAT_F* bytes_of_zn32_vmp_pmat; ZN32_VMP_PREPARE_CONTIGUOUS_F* zn32_vmp_prepare_contiguous; + ZN32_VMP_PREPARE_DBLPTR_F* zn32_vmp_prepare_dblptr; + ZN32_VMP_PREPARE_ROW_F* zn32_vmp_prepare_row; ZN32_VMP_APPLY_I32_F* zn32_vmp_apply_i32; ZN32_VMP_APPLY_I16_F* zn32_vmp_apply_i16; ZN32_VMP_APPLY_I8_F* zn32_vmp_apply_i8; diff --git a/spqlios/arithmetic/zn_arithmetic_private.h b/spqlios/arithmetic/zn_arithmetic_private.h index 3ff6c48..2de8a84 100644 --- a/spqlios/arithmetic/zn_arithmetic_private.h +++ b/spqlios/arithmetic/zn_arithmetic_private.h @@ -67,6 +67,20 @@ EXPORT void default_zn32_vmp_prepare_contiguous_ref( // const int32_t* mat, uint64_t nrows, uint64_t ncols // a ); +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void default_zn32_vmp_prepare_dblptr_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t** mat, uint64_t nrows, uint64_t ncols // a +); + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void default_zn32_vmp_prepare_row_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a +); + /** @brief applies a vmp product (int32_t* input) */ EXPORT void default_zn32_vmp_apply_i32_ref( // const MOD_Z* module, // diff --git a/spqlios/arithmetic/zn_vmp_ref.c b/spqlios/arithmetic/zn_vmp_ref.c index d75dca2..dd0b527 100644 --- a/spqlios/arithmetic/zn_vmp_ref.c +++ b/spqlios/arithmetic/zn_vmp_ref.c @@ -60,6 +60,53 @@ EXPORT void default_zn32_vmp_prepare_contiguous_ref( // } } +/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */ +EXPORT void default_zn32_vmp_prepare_dblptr_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t** mat, uint64_t nrows, uint64_t ncols // a +) { + for (uint64_t row_i = 0; row_i < nrows; ++row_i) { + default_zn32_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols); + } +} + +/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */ +EXPORT void default_zn32_vmp_prepare_row_ref( // + const MOD_Z* module, + ZN32_VMP_PMAT* pmat, // output + const int32_t* row, uint64_t row_i, uint64_t nrows, uint64_t ncols // a +) { + int32_t* const out = (int32_t*)pmat; + const uint64_t nblk = ncols >> 5; + const uint64_t ncols_rem = ncols & 31; + const uint64_t final_elems = (row_i == nrows - 1) && (8 - nrows * ncols) & 7; + for (uint64_t blk = 0; blk < nblk; ++blk) { + int32_t* outblk = out + blk * nrows * 32; + int32_t* dest = outblk + row_i * 32; + const int32_t* src = row + blk * 32; + for (uint64_t i = 0; i < 32; ++i) { + dest[i] = src[i]; + } + } + // copy the last block if any + if (ncols_rem) { + int32_t* outblk = out + nblk * nrows * 32; + int32_t* dest = outblk + row_i * ncols_rem; + const int32_t* src = row + nblk * 32; + for (uint64_t i = 0; i < ncols_rem; ++i) { + dest[i] = src[i]; + } + } + // zero-out the final elements that may be accessed + if (final_elems) { + int32_t* f = out + nrows * ncols; + for (uint64_t i = 0; i < final_elems; ++i) { + f[i] = 0; + } + } +} + #if 0 #define IMPL_zn32_vec_ixxx_matyyycols_ref(NCOLS) \ diff --git a/test/spqlios_reim_test.cpp b/test/spqlios_reim_test.cpp index 3432e32..8730dca 100644 --- a/test/spqlios_reim_test.cpp +++ b/test/spqlios_reim_test.cpp @@ -221,14 +221,10 @@ TEST(fft, reim_fft16_ref_vs_fma) { test_reim_fft_ref_vs_accel<16>(reim_fft16_ref #ifdef __aarch64__ static void reim_fft16_ref_neon_pom(double* dre, double* dim, const void* omega) { - const double* pom = (double*) omega; + const double* pom = (double*)omega; // put the omegas in neon order - double x_pom[] = { - pom[0], pom[1], pom[2], pom[3], - pom[4],pom[5], pom[6], pom[7], - pom[8], pom[10],pom[12], pom[14], - pom[9], pom[11],pom[13], pom[15] - }; + double x_pom[] = {pom[0], pom[1], pom[2], pom[3], pom[4], pom[5], pom[6], pom[7], + pom[8], pom[10], pom[12], pom[14], pom[9], pom[11], pom[13], pom[15]}; reim_fft16_ref(dre, dim, x_pom); } TEST(fft, reim_fft16_ref_vs_neon) { test_reim_fft_ref_vs_accel<16>(reim_fft16_ref_neon_pom, reim_fft16_neon); } diff --git a/test/spqlios_vec_rnx_test.cpp b/test/spqlios_vec_rnx_test.cpp index 2990299..1ff3389 100644 --- a/test/spqlios_vec_rnx_test.cpp +++ b/test/spqlios_vec_rnx_test.cpp @@ -239,8 +239,8 @@ void test_vec_rnx_elemw_unop_param_inplace(ACTUAL_FCN actual_function, EXPECT_FC } actual_function(mod, // N p, //; - la.data(), sa, a_sl, // res - la.data(), sa, a_sl // a + la.data(), sa, a_sl, // res + la.data(), sa, a_sl // a ); for (uint64_t i = 0; i < sa; ++i) { ASSERT_EQ(la.get_copy_zext(i), expect[i]) << n << " " << sa << " " << i; diff --git a/test/spqlios_vec_rnx_vmp_test.cpp b/test/spqlios_vec_rnx_vmp_test.cpp index 9bbb9d7..edaa12c 100644 --- a/test/spqlios_vec_rnx_vmp_test.cpp +++ b/test/spqlios_vec_rnx_vmp_test.cpp @@ -1,6 +1,6 @@ -#include "gtest/gtest.h" #include "../spqlios/arithmetic/vec_rnx_arithmetic_private.h" #include "../spqlios/reim/reim_fft.h" +#include "gtest/gtest.h" #include "testlib/vec_rnx_layout.h" static void test_vmp_apply_dft_to_dft_outplace( // @@ -113,7 +113,7 @@ TEST(vec_rnx, fft64_vmp_apply_dft_to_dft_avx) { /// rnx_vmp_prepare static void test_vmp_prepare_contiguous(RNX_VMP_PREPARE_CONTIGUOUS_F* prepare_contiguous, - RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* tmp_bytes) { + RNX_VMP_PREPARE_TMP_BYTES_F* tmp_bytes) { // tests when n < 8 for (uint64_t nn : {2, 4}) { const double one_over_m = 2. / nn; @@ -172,14 +172,92 @@ static void test_vmp_prepare_contiguous(RNX_VMP_PREPARE_CONTIGUOUS_F* prepare_co } TEST(vec_rnx, vmp_prepare_contiguous) { - test_vmp_prepare_contiguous(rnx_vmp_prepare_contiguous, rnx_vmp_prepare_contiguous_tmp_bytes); + test_vmp_prepare_contiguous(rnx_vmp_prepare_contiguous, rnx_vmp_prepare_tmp_bytes); } TEST(vec_rnx, fft64_vmp_prepare_contiguous_ref) { - test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_ref, fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref); + test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_ref, fft64_rnx_vmp_prepare_tmp_bytes_ref); } #ifdef __x86_64__ TEST(vec_rnx, fft64_vmp_prepare_contiguous_avx) { - test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_avx, fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx); + test_vmp_prepare_contiguous(fft64_rnx_vmp_prepare_contiguous_avx, fft64_rnx_vmp_prepare_tmp_bytes_avx); +} +#endif + +/// rnx_vmp_prepare_dblptr + +static void test_vmp_prepare_dblptr(RNX_VMP_PREPARE_DBLPTR_F* prepare_dblptr, RNX_VMP_PREPARE_TMP_BYTES_F* tmp_bytes) { + // tests when n < 8 + for (uint64_t nn : {2, 4}) { + const double one_over_m = 2. / nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + const double** mat_dblptr = (const double**)malloc(nrows * sizeof(double*)); + for (size_t row_i = 0; row_i < nrows; row_i++) { + mat_dblptr[row_i] = &mat.data()[row_i * ncols * nn]; + }; + prepare_dblptr(module, pmat.data, mat_dblptr, nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + const double* pmatv = (double*)pmat.data + (col * nrows + row) * nn; + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + const double* tmpv = tmp.data(); + for (uint64_t i = 0; i < nn; ++i) { + ASSERT_LE(abs(pmatv[i] - tmpv[i]), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } + // tests when n >= 8 + for (uint64_t nn : {8, 32}) { + const double one_over_m = 2. / nn; + MOD_RNX* module = new_rnx_module_info(nn, FFT64); + uint64_t nblk = nn / 8; + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + rnx_vec_f64_layout mat(nn, nrows * ncols, nn); + fft64_rnx_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(0); + std::vector tmp_space(tmp_bytes(module)); + thash hash_before = mat.content_hash(); + const double** mat_dblptr = (const double**)malloc(nrows * sizeof(double*)); + for (size_t row_i = 0; row_i < nrows; row_i++) { + mat_dblptr[row_i] = &mat.data()[row_i * ncols * nn]; + }; + prepare_dblptr(module, pmat.data, mat_dblptr, nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + reim_fft64vec tmp = one_over_m * simple_fft64(mat.get_copy(row * ncols + col)); + for (uint64_t blk = 0; blk < nblk; ++blk) { + reim4_elem expect = tmp.get_blk(blk); + reim4_elem actual = pmat.get(row, col, blk); + ASSERT_LE(infty_dist(actual, expect), 1e-10); + } + } + } + } + } + delete_rnx_module_info(module); + } +} + +TEST(vec_rnx, vmp_prepare_dblptr) { test_vmp_prepare_dblptr(rnx_vmp_prepare_dblptr, rnx_vmp_prepare_tmp_bytes); } +TEST(vec_rnx, fft64_vmp_prepare_dblptr_ref) { + test_vmp_prepare_dblptr(fft64_rnx_vmp_prepare_dblptr_ref, fft64_rnx_vmp_prepare_tmp_bytes_ref); +} +#ifdef __x86_64__ +TEST(vec_rnx, fft64_vmp_prepare_dblptr_avx) { + test_vmp_prepare_dblptr(fft64_rnx_vmp_prepare_dblptr_avx, fft64_rnx_vmp_prepare_tmp_bytes_avx); } #endif diff --git a/test/spqlios_vmp_product_test.cpp b/test/spqlios_vmp_product_test.cpp index cb55818..277da66 100644 --- a/test/spqlios_vmp_product_test.cpp +++ b/test/spqlios_vmp_product_test.cpp @@ -5,7 +5,7 @@ #include "testlib/polynomial_vector.h" static void test_vmp_prepare_contiguous(VMP_PREPARE_CONTIGUOUS_F* prepare_contiguous, - VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* tmp_bytes) { + VMP_PREPARE_TMP_BYTES_F* tmp_bytes) { // tests when n < 8 for (uint64_t nn : {2, 4}) { MODULE* module = new_module_info(nn, FFT64); @@ -14,7 +14,7 @@ static void test_vmp_prepare_contiguous(VMP_PREPARE_CONTIGUOUS_F* prepare_contig znx_vec_i64_layout mat(nn, nrows * ncols, nn); fft64_vmp_pmat_layout pmat(nn, nrows, ncols); mat.fill_random(30); - std::vector tmp_space(fft64_vmp_prepare_contiguous_tmp_bytes(module, nrows, ncols)); + std::vector tmp_space(fft64_vmp_prepare_tmp_bytes(module, nrows, ncols)); thash hash_before = mat.content_hash(); prepare_contiguous(module, pmat.data, mat.data(), nrows, ncols, tmp_space.data()); ASSERT_EQ(mat.content_hash(), hash_before); @@ -61,15 +61,87 @@ static void test_vmp_prepare_contiguous(VMP_PREPARE_CONTIGUOUS_F* prepare_contig } } -TEST(vec_znx, vmp_prepare_contiguous) { - test_vmp_prepare_contiguous(vmp_prepare_contiguous, vmp_prepare_contiguous_tmp_bytes); -} +TEST(vec_znx, vmp_prepare_contiguous) { test_vmp_prepare_contiguous(vmp_prepare_contiguous, vmp_prepare_tmp_bytes); } TEST(vec_znx, fft64_vmp_prepare_contiguous_ref) { - test_vmp_prepare_contiguous(fft64_vmp_prepare_contiguous_ref, fft64_vmp_prepare_contiguous_tmp_bytes); + test_vmp_prepare_contiguous(fft64_vmp_prepare_contiguous_ref, fft64_vmp_prepare_tmp_bytes); } #ifdef __x86_64__ TEST(vec_znx, fft64_vmp_prepare_contiguous_avx) { - test_vmp_prepare_contiguous(fft64_vmp_prepare_contiguous_avx, fft64_vmp_prepare_contiguous_tmp_bytes); + test_vmp_prepare_contiguous(fft64_vmp_prepare_contiguous_avx, fft64_vmp_prepare_tmp_bytes); +} +#endif + +static void test_vmp_prepare_dblptr(VMP_PREPARE_DBLPTR_F* prepare_dblptr, VMP_PREPARE_TMP_BYTES_F* tmp_bytes) { + // tests when n < 8 + for (uint64_t nn : {2, 4}) { + MODULE* module = new_module_info(nn, FFT64); + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + znx_vec_i64_layout mat(nn, nrows * ncols, nn); + fft64_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(30); + std::vector tmp_space(fft64_vmp_prepare_tmp_bytes(module, nrows, ncols)); + thash hash_before = mat.content_hash(); + const int64_t** mat_dblptr = (const int64_t**)malloc(nrows * sizeof(int64_t*)); + for (size_t row_i = 0; row_i < nrows; row_i++) { + mat_dblptr[row_i] = &mat.data()[row_i * ncols * nn]; + }; + prepare_dblptr(module, pmat.data, mat_dblptr, nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + const double* pmatv = (double*)pmat.data + (col * nrows + row) * nn; + reim_fft64vec tmp = simple_fft64(mat.get_copy(row * ncols + col)); + const double* tmpv = tmp.data(); + for (uint64_t i = 0; i < nn; ++i) { + ASSERT_LE(abs(pmatv[i] - tmpv[i]), 1e-10); + } + } + } + } + } + delete_module_info(module); + } + // tests when n >= 8 + for (uint64_t nn : {8, 32}) { + MODULE* module = new_module_info(nn, FFT64); + uint64_t nblk = nn / 8; + for (uint64_t nrows : {1, 2, 5}) { + for (uint64_t ncols : {2, 6, 7}) { + znx_vec_i64_layout mat(nn, nrows * ncols, nn); + fft64_vmp_pmat_layout pmat(nn, nrows, ncols); + mat.fill_random(30); + std::vector tmp_space(tmp_bytes(module, nrows, ncols)); + thash hash_before = mat.content_hash(); + const int64_t** mat_dblptr = (const int64_t**)malloc(nrows * sizeof(int64_t*)); + for (size_t row_i = 0; row_i < nrows; row_i++) { + mat_dblptr[row_i] = &mat.data()[row_i * ncols * nn]; + }; + prepare_dblptr(module, pmat.data, mat_dblptr, nrows, ncols, tmp_space.data()); + ASSERT_EQ(mat.content_hash(), hash_before); + for (uint64_t row = 0; row < nrows; ++row) { + for (uint64_t col = 0; col < ncols; ++col) { + reim_fft64vec tmp = simple_fft64(mat.get_copy(row * ncols + col)); + for (uint64_t blk = 0; blk < nblk; ++blk) { + reim4_elem expect = tmp.get_blk(blk); + reim4_elem actual = pmat.get(row, col, blk); + ASSERT_LE(infty_dist(actual, expect), 1e-10); + } + } + } + } + } + delete_module_info(module); + } +} + +TEST(vec_znx, vmp_prepare_dblptr) { test_vmp_prepare_dblptr(vmp_prepare_dblptr, vmp_prepare_tmp_bytes); } +TEST(vec_znx, fft64_vmp_prepare_dblptr_ref) { + test_vmp_prepare_dblptr(fft64_vmp_prepare_dblptr_ref, fft64_vmp_prepare_tmp_bytes); +} +#ifdef __x86_64__ +TEST(vec_znx, fft64_vmp_prepare_dblptr_avx) { + test_vmp_prepare_dblptr(fft64_vmp_prepare_dblptr_avx, fft64_vmp_prepare_tmp_bytes); } #endif diff --git a/test/spqlios_zn_vmp_test.cpp b/test/spqlios_zn_vmp_test.cpp index 8f6fa25..57b0ad0 100644 --- a/test/spqlios_zn_vmp_test.cpp +++ b/test/spqlios_zn_vmp_test.cpp @@ -2,14 +2,14 @@ #include "spqlios/arithmetic/zn_arithmetic_private.h" #include "testlib/zn_layouts.h" -static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_CONTIGUOUS_F prep) { +static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_CONTIGUOUS_F prepare_contiguous) { MOD_Z* module = new_z_module_info(DEFAULT); for (uint64_t nrows : {1, 2, 5, 15}) { for (uint64_t ncols : {1, 2, 32, 42, 67}) { std::vector src(nrows * ncols); zn32_pmat_layout out(nrows, ncols); for (int32_t& x : src) x = uniform_i64_bits(32); - prep(module, out.data, src.data(), nrows, ncols); + prepare_contiguous(module, out.data, src.data(), nrows, ncols); for (uint64_t i = 0; i < nrows; ++i) { for (uint64_t j = 0; j < ncols; ++j) { int32_t in = src[i * ncols + j]; @@ -25,6 +25,33 @@ static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_CONTIGUOUS_F prep) { TEST(zn, zn32_vmp_prepare_contiguous) { test_zn_vmp_prepare(zn32_vmp_prepare_contiguous); } TEST(zn, default_zn32_vmp_prepare_contiguous_ref) { test_zn_vmp_prepare(default_zn32_vmp_prepare_contiguous_ref); } +static void test_zn_vmp_prepare(ZN32_VMP_PREPARE_DBLPTR_F prepare_dblptr) { + MOD_Z* module = new_z_module_info(DEFAULT); + for (uint64_t nrows : {1, 2, 5, 15}) { + for (uint64_t ncols : {1, 2, 32, 42, 67}) { + std::vector src(nrows * ncols); + zn32_pmat_layout out(nrows, ncols); + for (int32_t& x : src) x = uniform_i64_bits(32); + const int32_t** mat_dblptr = (const int32_t**)malloc(nrows * sizeof(int32_t*)); + for (size_t row_i = 0; row_i < nrows; row_i++) { + mat_dblptr[row_i] = &src.data()[row_i * ncols]; + }; + prepare_dblptr(module, out.data, mat_dblptr, nrows, ncols); + for (uint64_t i = 0; i < nrows; ++i) { + for (uint64_t j = 0; j < ncols; ++j) { + int32_t in = src[i * ncols + j]; + int32_t actual = out.get(i, j); + ASSERT_EQ(actual, in); + } + } + } + } + delete_z_module_info(module); +} + +TEST(zn, zn32_vmp_prepare_dblptr) { test_zn_vmp_prepare(zn32_vmp_prepare_dblptr); } +TEST(zn, default_zn32_vmp_prepare_dblptr_ref) { test_zn_vmp_prepare(default_zn32_vmp_prepare_dblptr_ref); } + template static void test_zn_vmp_apply(void (*apply)(const MOD_Z*, int32_t*, uint64_t, const INTTYPE*, uint64_t, const ZN32_VMP_PMAT*, uint64_t, uint64_t)) {