Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for vmp_prepare_dblptr #49

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,2 +1,4 @@
cmake-build-*
.idea

build
8 changes: 5 additions & 3 deletions spqlios/arithmetic/module_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -46,21 +46,23 @@ static void fill_fft64_virtual_table(MODULE* module) {
module->func.znx_small_single_product = fft64_znx_small_single_product;
module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes;
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref;
module->func.vmp_prepare_contiguous_tmp_bytes = fft64_vmp_prepare_contiguous_tmp_bytes;
module->func.vmp_prepare_dblptr = fft64_vmp_prepare_dblptr_ref;
module->func.vmp_prepare_row = fft64_vmp_prepare_row_ref;
module->func.vmp_prepare_tmp_bytes = fft64_vmp_prepare_tmp_bytes;
module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref;
module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes;
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref;
module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes;
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
module->func.bytes_of_vec_znx_dft = fft64_bytes_of_vec_znx_dft;
module->func.bytes_of_vec_znx_big = fft64_bytes_of_vec_znx_big;
module->func.bytes_of_svp_ppol = fft64_bytes_of_svp_ppol;
module->func.bytes_of_vmp_pmat = fft64_bytes_of_vmp_pmat;
if (CPU_SUPPORTS("avx2")) {
// TODO add avx handlers here
// TODO: enable when avx implementation is done
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx;
module->func.vmp_prepare_dblptr = fft64_vmp_prepare_dblptr_avx;
module->func.vmp_prepare_row = fft64_vmp_prepare_row_avx;
module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx;
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx;
}
Expand Down
32 changes: 28 additions & 4 deletions spqlios/arithmetic/vec_rnx_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ void fft64_init_rnx_module_vtable(MOD_RNX* module) {
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref;
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref;
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref;
module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref;
module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_ref;
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref;
module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_ref;
module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_ref;
module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat;
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref;
module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref;
Expand All @@ -55,8 +57,10 @@ void fft64_init_rnx_module_vtable(MOD_RNX* module) {
module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx;
module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx;
module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx;
module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx;
module->vtable.rnx_vmp_prepare_tmp_bytes = fft64_rnx_vmp_prepare_tmp_bytes_avx;
module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx;
module->vtable.rnx_vmp_prepare_dblptr = fft64_rnx_vmp_prepare_dblptr_avx;
module->vtable.rnx_vmp_prepare_row = fft64_rnx_vmp_prepare_row_avx;
module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx;
}
}
Expand Down Expand Up @@ -201,9 +205,29 @@ EXPORT void rnx_vmp_prepare_contiguous( //
module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space);
}

/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
EXPORT void rnx_vmp_prepare_dblptr( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double** a, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
) {
module->vtable.rnx_vmp_prepare_dblptr(module, pmat, a, nrows, ncols, tmp_space);
}

/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
EXPORT void rnx_vmp_prepare_row( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
) {
module->vtable.rnx_vmp_prepare_row(module, pmat, a, row_i, nrows, ncols, tmp_space);
}

/** @brief number of scratch bytes necessary to prepare a matrix */
EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module) {
return module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes(module);
EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module) {
return module->vtable.rnx_vmp_prepare_tmp_bytes(module);
}

/** @brief applies a vmp product res = a x pmat */
Expand Down
18 changes: 17 additions & 1 deletion spqlios/arithmetic/vec_rnx_arithmetic.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,8 +289,24 @@ EXPORT void rnx_vmp_prepare_contiguous( //
uint8_t* tmp_space // scratch space
);

/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
EXPORT void rnx_vmp_prepare_dblptr( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double** a, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
);

/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
EXPORT void rnx_vmp_prepare_row( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double* a, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
);

/** @brief number of scratch bytes necessary to prepare a matrix */
EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module);
EXPORT uint64_t rnx_vmp_prepare_tmp_bytes(const MOD_RNX* module);

/** @brief applies a vmp product res = a x pmat */
EXPORT void rnx_vmp_apply_tmp_a( //
Expand Down
8 changes: 6 additions & 2 deletions spqlios/arithmetic/vec_rnx_arithmetic_plugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,9 @@ typedef typeof(rnx_svp_prepare) RNX_SVP_PREPARE_F;
typedef typeof(rnx_svp_apply) RNX_SVP_APPLY_F;
typedef typeof(bytes_of_rnx_vmp_pmat) BYTES_OF_RNX_VMP_PMAT_F;
typedef typeof(rnx_vmp_prepare_contiguous) RNX_VMP_PREPARE_CONTIGUOUS_F;
typedef typeof(rnx_vmp_prepare_contiguous_tmp_bytes) RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F;
typedef typeof(rnx_vmp_prepare_dblptr) RNX_VMP_PREPARE_DBLPTR_F;
typedef typeof(rnx_vmp_prepare_row) RNX_VMP_PREPARE_ROW_F;
typedef typeof(rnx_vmp_prepare_tmp_bytes) RNX_VMP_PREPARE_TMP_BYTES_F;
typedef typeof(rnx_vmp_apply_tmp_a) RNX_VMP_APPLY_TMP_A_F;
typedef typeof(rnx_vmp_apply_tmp_a_tmp_bytes) RNX_VMP_APPLY_TMP_A_TMP_BYTES_F;
typedef typeof(rnx_vmp_apply_dft_to_dft) RNX_VMP_APPLY_DFT_TO_DFT_F;
Expand Down Expand Up @@ -76,7 +78,9 @@ struct rnx_module_vtable_t {
RNX_SVP_APPLY_F* rnx_svp_apply;
BYTES_OF_RNX_VMP_PMAT_F* bytes_of_rnx_vmp_pmat;
RNX_VMP_PREPARE_CONTIGUOUS_F* rnx_vmp_prepare_contiguous;
RNX_VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* rnx_vmp_prepare_contiguous_tmp_bytes;
RNX_VMP_PREPARE_DBLPTR_F* rnx_vmp_prepare_dblptr;
RNX_VMP_PREPARE_ROW_F* rnx_vmp_prepare_row;
RNX_VMP_PREPARE_TMP_BYTES_F* rnx_vmp_prepare_tmp_bytes;
RNX_VMP_APPLY_TMP_A_F* rnx_vmp_apply_tmp_a;
RNX_VMP_APPLY_TMP_A_TMP_BYTES_F* rnx_vmp_apply_tmp_a_tmp_bytes;
RNX_VMP_APPLY_DFT_TO_DFT_F* rnx_vmp_apply_dft_to_dft;
Expand Down
28 changes: 26 additions & 2 deletions spqlios/arithmetic/vec_rnx_arithmetic_private.h
Original file line number Diff line number Diff line change
Expand Up @@ -183,8 +183,32 @@ EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
const double* mat, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
);
EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module);
EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module);
EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double** mat, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
);
EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double** mat, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
);
EXPORT void fft64_rnx_vmp_prepare_row_ref( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
);
EXPORT void fft64_rnx_vmp_prepare_row_avx( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double* mat, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
);
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module);
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module);

EXPORT void fft64_rnx_vmp_apply_tmp_a_ref( //
const MOD_RNX* module, // N
Expand Down
58 changes: 58 additions & 0 deletions spqlios/arithmetic/vec_rnx_vmp_avx.c
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,64 @@ EXPORT void fft64_rnx_vmp_prepare_contiguous_avx( //
}
}

/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
EXPORT void fft64_rnx_vmp_prepare_dblptr_avx( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double** mat, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
) {
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
fft64_rnx_vmp_prepare_row_avx(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space);
}
}

/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
EXPORT void fft64_rnx_vmp_prepare_row_avx( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
) {
// there is an edge case if nn < 8
const uint64_t nn = module->n;
const uint64_t m = module->m;

double* const dtmp = (double*)tmp_space;
double* const output_mat = (double*)pmat;
double* start_addr = (double*)pmat;
uint64_t offset = nrows * ncols * 8;

if (nn >= 8) {
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
rnx_divide_by_m_avx(nn, m, dtmp, row + col_i * nn);
reim_fft(module->precomp.fft64.p_fft, dtmp);

if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
// special case: last column out of an odd column number
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
+ row_i * 8;
} else {
// general case: columns go by pair
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
+ row_i * 2 * 8 // third: row index
+ (col_i % 2) * 8;
}

for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
// extract blk from tmp and save it
reim4_extract_1blk_from_reim_avx(m, blk_i, start_addr + blk_i * offset, dtmp);
}
}
} else {
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
double* res = output_mat + (col_i * nrows + row_i) * nn;
rnx_divide_by_m_avx(nn, m, res, row + col_i * nn);
reim_fft(module->precomp.fft64.p_fft, res);
}
}
}

/** @brief minimal size of the tmp_space */
EXPORT void fft64_rnx_vmp_apply_dft_to_dft_avx( //
const MOD_RNX* module, // N
Expand Down
66 changes: 62 additions & 4 deletions spqlios/arithmetic/vec_rnx_vmp_ref.c
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,66 @@ EXPORT void fft64_rnx_vmp_prepare_contiguous_ref( //
}
}

/** @brief prepares a vmp matrix (mat[row]+col*N points to the item) */
EXPORT void fft64_rnx_vmp_prepare_dblptr_ref( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double** mat, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
) {
for (uint64_t row_i = 0; row_i < nrows; row_i++) {
fft64_rnx_vmp_prepare_row_ref(module, pmat, mat[row_i], row_i, nrows, ncols, tmp_space);
}
}

/** @brief prepares the ith-row of a vmp matrix with nrows and ncols */
EXPORT void fft64_rnx_vmp_prepare_row_ref( //
const MOD_RNX* module, // N
RNX_VMP_PMAT* pmat, // output
const double* row, uint64_t row_i, uint64_t nrows, uint64_t ncols, // a
uint8_t* tmp_space // scratch space
) {
// there is an edge case if nn < 8
const uint64_t nn = module->n;
const uint64_t m = module->m;

double* const dtmp = (double*)tmp_space;
double* const output_mat = (double*)pmat;
double* start_addr = (double*)pmat;
uint64_t offset = nrows * ncols * 8;

if (nn >= 8) {
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
rnx_divide_by_m_ref(nn, m, dtmp, row + col_i * nn);
reim_fft(module->precomp.fft64.p_fft, dtmp);

if (col_i == (ncols - 1) && (ncols % 2 == 1)) {
// special case: last column out of an odd column number
start_addr = output_mat + col_i * nrows * 8 // col == ncols-1
+ row_i * 8;
} else {
// general case: columns go by pair
start_addr = output_mat + (col_i / 2) * (2 * nrows) * 8 // second: col pair index
+ row_i * 2 * 8 // third: row index
+ (col_i % 2) * 8;
}

for (uint64_t blk_i = 0; blk_i < m / 4; blk_i++) {
// extract blk from tmp and save it
reim4_extract_1blk_from_reim_ref(m, blk_i, start_addr + blk_i * offset, dtmp);
}
}
} else {
for (uint64_t col_i = 0; col_i < ncols; col_i++) {
double* res = output_mat + (col_i * nrows + row_i) * nn;
rnx_divide_by_m_ref(nn, m, res, row + col_i * nn);
reim_fft(module->precomp.fft64.p_fft, res);
}
}
}

/** @brief number of scratch bytes necessary to prepare a matrix */
EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref(const MOD_RNX* module) {
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_ref(const MOD_RNX* module) {
const uint64_t nn = module->n;
return nn * sizeof(int64_t);
}
Expand Down Expand Up @@ -220,10 +278,10 @@ EXPORT uint64_t fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref( //

/** @brief number of scratch bytes necessary to prepare a matrix */
#ifdef __APPLE__
#pragma weak fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref
#pragma weak fft64_rnx_vmp_prepare_tmp_bytes_avx = fft64_rnx_vmp_prepare_tmp_bytes_ref
#else
EXPORT uint64_t fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx(const MOD_RNX* module)
__attribute((alias("fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref")));
EXPORT uint64_t fft64_rnx_vmp_prepare_tmp_bytes_avx(const MOD_RNX* module)
__attribute((alias("fft64_rnx_vmp_prepare_tmp_bytes_ref")));
#endif

/** @brief minimal size of the tmp_space */
Expand Down
11 changes: 5 additions & 6 deletions spqlios/arithmetic/vec_znx.c
Original file line number Diff line number Diff line change
Expand Up @@ -249,27 +249,26 @@ EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module //
return nn * sizeof(int64_t);
}


// alias have to be defined in this unit: do not move
#ifdef __APPLE__
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
const MODULE* module // N
) {
) {
return vec_znx_normalize_base2k_tmp_bytes_ref(module);
}
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
const MODULE* module // N
const MODULE* module // N
) {
return vec_znx_normalize_base2k_tmp_bytes_ref(module);
}
#else
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
const MODULE* module // N
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
const MODULE* module // N
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));

EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
const MODULE* module // N
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
#endif

/** @brief sets res = 0 */
Expand Down
Loading