Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rnx and zn apis #42

Merged
merged 11 commits into from
Aug 21, 2024
28 changes: 27 additions & 1 deletion spqlios/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,22 @@ set(SRCS_GENERIC
arithmetic/vec_znx_big.c
arithmetic/znx_small.c
arithmetic/module_api.c
arithmetic/zn_vmp_int8_ref.c
arithmetic/zn_vmp_int16_ref.c
arithmetic/zn_vmp_int32_ref.c
arithmetic/zn_vmp_ref.c
arithmetic/zn_api.c
arithmetic/zn_conversions_ref.c
arithmetic/zn_approxdecomp_ref.c
arithmetic/vec_rnx_api.c
arithmetic/vec_rnx_conversions_ref.c
arithmetic/vec_rnx_svp_ref.c
reim/reim_execute.c
cplx/cplx_execute.c
reim4/reim4_execute.c
arithmetic/vec_rnx_arithmetic.c
arithmetic/vec_rnx_approxdecomp_ref.c
arithmetic/vec_rnx_vmp_ref.c
)
# C or assembly source files compiled only on x86 targets
set(SRCS_X86
Expand Down Expand Up @@ -95,9 +108,16 @@ set(SRCS_AVX2
arithmetic/vec_znx_avx.c
coeffs/coeffs_arithmetic_avx.c
arithmetic/vec_znx_dft_avx2.c
arithmetic/zn_vmp_int8_avx.c
arithmetic/zn_vmp_int16_avx.c
arithmetic/zn_vmp_int32_avx.c
q120/q120_arithmetic_avx2.c
q120/q120_ntt_avx2.c
arithmetic/vec_rnx_arithmetic_avx.c
arithmetic/vec_rnx_approxdecomp_avx.c
arithmetic/vec_rnx_vmp_avx.c
)
set_source_files_properties(${SRCS_AVX2} PROPERTIES COMPILE_OPTIONS "-mbmi2;-mavx2")

# C source files on float128 via libquadmath on x86 targets targets
Expand All @@ -110,6 +130,8 @@ set(SRCS_F128
set(HEADERSPUBLIC
commons.h
arithmetic/vec_znx_arithmetic.h
arithmetic/vec_rnx_arithmetic.h
arithmetic/zn_arithmetic.h
cplx/cplx_fft.h
reim/reim_fft.h
q120/q120_common.h
Expand All @@ -131,6 +153,10 @@ set(HEADERSPRIVATE
q120/q120_arithmetic_private.h
q120/q120_ntt_private.h
arithmetic/vec_znx_arithmetic.h
arithmetic/vec_rnx_arithmetic_private.h
arithmetic/vec_rnx_arithmetic_plugin.h
arithmetic/zn_arithmetic_private.h
arithmetic/zn_arithmetic_plugin.h
coeffs/coeffs_arithmetic.h
reim/reim_fft_core_template.h
)
Expand Down
318 changes: 318 additions & 0 deletions spqlios/arithmetic/vec_rnx_api.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,318 @@
#include <string.h>

#include "vec_rnx_arithmetic_private.h"

/** @brief fills the fft64 precomputed tables of an RNX module.
 *  All tables are sized by the complex dimension m (stored in the module). */
void fft64_init_rnx_module_precomp(MOD_RNX* module) {
  // forward and inverse FFT plans
  module->precomp.fft64.p_fft = new_reim_fft_precomp(module->m, 0);
  module->precomp.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0);
  // coefficient-vector product tables (mul and addmul variants)
  module->precomp.fft64.p_fftvec_mul = new_reim_fftvec_mul_precomp(module->m);
  module->precomp.fft64.p_fftvec_addmul = new_reim_fftvec_addmul_precomp(module->m);
}

/** @brief releases every table allocated by fft64_init_rnx_module_precomp. */
void fft64_finalize_rnx_module_precomp(MOD_RNX* module) {
  // release in reverse order of creation (the deleters are independent)
  delete_reim_fftvec_addmul_precomp(module->precomp.fft64.p_fftvec_addmul);
  delete_reim_fftvec_mul_precomp(module->precomp.fft64.p_fftvec_mul);
  delete_reim_ifft_precomp(module->precomp.fft64.p_ifft);
  delete_reim_fft_precomp(module->precomp.fft64.p_fft);
}

/** @brief fills the vtable of an fft64 RNX module: reference implementations
 *  first, then AVX overrides when the CPU supports them. */
void fft64_init_rnx_module_vtable(MOD_RNX* module) {
  // --- portable reference implementations ---
  // linear operations
  module->vtable.vec_rnx_zero = vec_rnx_zero_ref;
  module->vtable.vec_rnx_copy = vec_rnx_copy_ref;
  module->vtable.vec_rnx_add = vec_rnx_add_ref;
  module->vtable.vec_rnx_sub = vec_rnx_sub_ref;
  module->vtable.vec_rnx_negate = vec_rnx_negate_ref;
  // rotations and automorphisms
  module->vtable.vec_rnx_rotate = vec_rnx_rotate_ref;
  module->vtable.vec_rnx_automorphism = vec_rnx_automorphism_ref;
  module->vtable.vec_rnx_mul_xp_minus_one = vec_rnx_mul_xp_minus_one_ref;
  // vector-matrix products (prepared pmat)
  module->vtable.bytes_of_rnx_vmp_pmat = fft64_bytes_of_rnx_vmp_pmat;
  module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_ref;
  module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_ref;
  module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_ref;
  module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_ref;
  module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_ref;
  module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_ref;
  // gadget decomposition
  module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_ref;
  // conversions to/from integer and torus representations
  module->vtable.vec_rnx_to_znx32 = vec_rnx_to_znx32_ref;
  module->vtable.vec_rnx_from_znx32 = vec_rnx_from_znx32_ref;
  module->vtable.vec_rnx_to_tnx32 = vec_rnx_to_tnx32_ref;
  module->vtable.vec_rnx_from_tnx32 = vec_rnx_from_tnx32_ref;
  module->vtable.vec_rnx_to_tnxdbl = vec_rnx_to_tnxdbl_ref;
  // scalar-vector products (prepared ppol)
  module->vtable.bytes_of_rnx_svp_ppol = fft64_bytes_of_rnx_svp_ppol;
  module->vtable.rnx_svp_prepare = fft64_rnx_svp_prepare_ref;
  module->vtable.rnx_svp_apply = fft64_rnx_svp_apply_ref;

  // --- runtime-detected AVX overrides ---
  if (CPU_SUPPORTS("avx")) {
    module->vtable.vec_rnx_add = vec_rnx_add_avx;
    module->vtable.vec_rnx_sub = vec_rnx_sub_avx;
    module->vtable.vec_rnx_negate = vec_rnx_negate_avx;
    module->vtable.rnx_vmp_prepare_contiguous = fft64_rnx_vmp_prepare_contiguous_avx;
    module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes = fft64_rnx_vmp_prepare_contiguous_tmp_bytes_avx;
    module->vtable.rnx_vmp_apply_dft_to_dft = fft64_rnx_vmp_apply_dft_to_dft_avx;
    module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes = fft64_rnx_vmp_apply_dft_to_dft_tmp_bytes_avx;
    module->vtable.rnx_vmp_apply_tmp_a = fft64_rnx_vmp_apply_tmp_a_avx;
    module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes = fft64_rnx_vmp_apply_tmp_a_tmp_bytes_avx;
    module->vtable.rnx_approxdecomp_from_tnxdbl = rnx_approxdecomp_from_tnxdbl_avx;
  }
}

/** @brief initializes an already-allocated RNX module over R[X] of dimension n.
 *  Zeroes the struct, records n, m = n/2 and the module type, then installs the
 *  precomputed tables and vtable of the requested backend. */
void init_rnx_module_info(MOD_RNX* module, //
                          uint64_t n, RNX_MODULE_TYPE mtype) {
  memset(module, 0, sizeof(MOD_RNX));
  module->n = n;
  module->m = n >> 1;  // complex dimension: half the ring dimension
  module->mtype = mtype;
  switch (mtype) {
    case FFT64:
      fft64_init_rnx_module_precomp(module);
      fft64_init_rnx_module_vtable(module);
      break;
    default:
      // any other backend is not implemented
      NOT_SUPPORTED();
  }
}

/** @brief releases everything owned by an RNX module (not the struct itself). */
void finalize_rnx_module_info(MOD_RNX* module) {
  // user-attached payload first, via its registered deleter
  if (module->custom) module->custom_deleter(module->custom);
  switch (module->mtype) {
    case FFT64:
      // the fft64 vtable itself owns nothing: only the precomp is released
      fft64_finalize_rnx_module_precomp(module);
      break;
    default:
      NOT_SUPPORTED();  // unknown mtype
  }
}

/** @brief allocates and initializes an RNX module of ring dimension nn.
 *  @return the new module, or NULL if the allocation fails.
 *  Release with delete_rnx_module_info. */
EXPORT MOD_RNX* new_rnx_module_info(uint64_t nn, RNX_MODULE_TYPE mtype) {
  MOD_RNX* res = malloc(sizeof *res);
  // previously the result was passed to init unchecked: on OOM,
  // init_rnx_module_info would memset a NULL pointer (UB)
  if (!res) return NULL;
  init_rnx_module_info(res, nn, mtype);
  return res;
}

/** @brief finalizes and frees a module created by new_rnx_module_info.
 *  Deleting NULL is a no-op (same convention as free). */
EXPORT void delete_rnx_module_info(MOD_RNX* module_info) {
  // guard: finalize_rnx_module_info dereferences its argument,
  // so a NULL input would previously crash here
  if (!module_info) return;
  finalize_rnx_module_info(module_info);
  free(module_info);
}

EXPORT uint64_t rnx_module_get_n(const MOD_RNX* module) { return module->n; }

/** @brief allocates an uninitialized prepared matrix of the given dimensions
 *  (release with delete_rnx_vmp_pmat). */
EXPORT RNX_VMP_PMAT* new_rnx_vmp_pmat(const MOD_RNX* module, // N
                                      uint64_t nrows, uint64_t ncols) { // dimensions
  // size is backend-dependent: query it through the module
  const uint64_t nbytes = bytes_of_rnx_vmp_pmat(module, nrows, ncols);
  return (RNX_VMP_PMAT*)spqlios_alloc(nbytes);
}
EXPORT void delete_rnx_vmp_pmat(RNX_VMP_PMAT* ptr) { spqlios_free(ptr); }

//////////////// wrappers //////////////////

/** @brief sets res = a + b (dispatched to the backend selected at module creation). */
EXPORT void vec_rnx_add(const MOD_RNX* module,                              // N
                        double* res, uint64_t res_size, uint64_t res_sl,    // res
                        const double* a, uint64_t a_size, uint64_t a_sl,    // a
                        const double* b, uint64_t b_size, uint64_t b_sl) {  // b
  module->vtable.vec_rnx_add(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
}

/** @brief sets res = 0 (dispatched to the backend selected at module creation). */
EXPORT void vec_rnx_zero(const MOD_RNX* module,                              // N
                         double* res, uint64_t res_size, uint64_t res_sl) {  // res
  module->vtable.vec_rnx_zero(module, res, res_size, res_sl);
}

/** @brief sets res = a (dispatched to the backend selected at module creation). */
EXPORT void vec_rnx_copy(const MOD_RNX* module,                               // N
                         double* res, uint64_t res_size, uint64_t res_sl,     // res
                         const double* a, uint64_t a_size, uint64_t a_sl) {   // a
  module->vtable.vec_rnx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
}

/** @brief sets res = -a (dispatched to the backend selected at module creation). */
EXPORT void vec_rnx_negate(const MOD_RNX* module,                             // N
                           double* res, uint64_t res_size, uint64_t res_sl,   // res
                           const double* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
}

/** @brief sets res = a - b (dispatched to the backend selected at module creation). */
EXPORT void vec_rnx_sub(const MOD_RNX* module,                              // N
                        double* res, uint64_t res_size, uint64_t res_sl,    // res
                        const double* a, uint64_t a_size, uint64_t a_sl,    // a
                        const double* b, uint64_t b_size, uint64_t b_sl) {  // b
  module->vtable.vec_rnx_sub(module, res, res_size, res_sl, a, a_size, a_sl, b, b_size, b_sl);
}

/** @brief sets res = a . X^p (dispatched to the backend selected at module creation). */
EXPORT void vec_rnx_rotate(const MOD_RNX* module,                             // N
                           const int64_t p,                                   // rotation value
                           double* res, uint64_t res_size, uint64_t res_sl,   // res
                           const double* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_rotate(module, p, res, res_size, res_sl, a, a_size, a_sl);
}

/** @brief sets res = a(X^p) (dispatched to the backend selected at module creation). */
EXPORT void vec_rnx_automorphism(const MOD_RNX* module,                             // N
                                 int64_t p,                                         // X -> X^p
                                 double* res, uint64_t res_size, uint64_t res_sl,   // res
                                 const double* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_automorphism(module, p, res, res_size, res_sl, a, a_size, a_sl);
}

/** @brief sets res = a . (X^p - 1) (dispatched to the backend selected at module creation). */
EXPORT void vec_rnx_mul_xp_minus_one(const MOD_RNX* module,                             // N
                                     const int64_t p,                                   // rotation value
                                     double* res, uint64_t res_size, uint64_t res_sl,   // res
                                     const double* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_mul_xp_minus_one(module, p, res, res_size, res_sl, a, a_size, a_sl);
}
/** @brief number of bytes in a RNX_VMP_PMAT of the given dimensions
 *  (for manual allocation; backend-dependent). */
EXPORT uint64_t bytes_of_rnx_vmp_pmat(const MOD_RNX* module,              // N
                                      uint64_t nrows, uint64_t ncols) {   // dimensions
  return module->vtable.bytes_of_rnx_vmp_pmat(module, nrows, ncols);
}

/** @brief prepares a vmp matrix from a contiguous row-major input.
 *  tmp_space must hold at least rnx_vmp_prepare_contiguous_tmp_bytes(module) bytes. */
EXPORT void rnx_vmp_prepare_contiguous(const MOD_RNX* module,                           // N
                                       RNX_VMP_PMAT* pmat,                              // output
                                       const double* a, uint64_t nrows, uint64_t ncols, // a
                                       uint8_t* tmp_space) {                            // scratch space
  module->vtable.rnx_vmp_prepare_contiguous(module, pmat, a, nrows, ncols, tmp_space);
}

/** @brief number of scratch bytes required by rnx_vmp_prepare_contiguous. */
EXPORT uint64_t rnx_vmp_prepare_contiguous_tmp_bytes(const MOD_RNX* module) {
  return module->vtable.rnx_vmp_prepare_contiguous_tmp_bytes(module);
}

/** @brief applies a vmp product res = a x pmat.
 *  tmpa is clobbered by the computation; tmp_space must hold at least
 *  rnx_vmp_apply_tmp_a_tmp_bytes(...) bytes. */
EXPORT void rnx_vmp_apply_tmp_a(const MOD_RNX* module,                                      // N
                                double* res, uint64_t res_size, uint64_t res_sl,            // res
                                double* tmpa, uint64_t a_size, uint64_t a_sl,               // a (will be overwritten)
                                const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,   // prep matrix
                                uint8_t* tmp_space) {                                       // scratch space
  module->vtable.rnx_vmp_apply_tmp_a(module, res, res_size, res_sl, tmpa, a_size, a_sl, pmat, nrows, ncols, tmp_space);
}

/** @brief number of scratch bytes required by rnx_vmp_apply_tmp_a. */
EXPORT uint64_t rnx_vmp_apply_tmp_a_tmp_bytes(const MOD_RNX* module,            // N
                                              uint64_t res_size,                // res size
                                              uint64_t a_size,                  // a size
                                              uint64_t nrows, uint64_t ncols) { // prep matrix dims
  return module->vtable.rnx_vmp_apply_tmp_a_tmp_bytes(module, res_size, a_size, nrows, ncols);
}

/** @brief applies a vmp product in the DFT domain: res = a_dft x pmat.
 *  (the previous header comment was a copy-paste of the tmp_bytes one)
 *  tmp_space must hold at least rnx_vmp_apply_dft_to_dft_tmp_bytes(...) bytes. */
EXPORT void rnx_vmp_apply_dft_to_dft(const MOD_RNX* module,                                     // N
                                     double* res, uint64_t res_size, uint64_t res_sl,           // res
                                     const double* a_dft, uint64_t a_size, uint64_t a_sl,       // a
                                     const RNX_VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
                                     uint8_t* tmp_space) {  // scratch space (a_size*sizeof(reim4) bytes)
  module->vtable.rnx_vmp_apply_dft_to_dft(module, res, res_size, res_sl, a_dft, a_size, a_sl, pmat, nrows, ncols,
                                          tmp_space);
}

/** @brief minimal scratch size (in bytes) required by rnx_vmp_apply_dft_to_dft. */
EXPORT uint64_t rnx_vmp_apply_dft_to_dft_tmp_bytes(const MOD_RNX* module,            // N
                                                   uint64_t res_size,                // res
                                                   uint64_t a_size,                  // a
                                                   uint64_t nrows, uint64_t ncols) { // prep matrix
  return module->vtable.rnx_vmp_apply_dft_to_dft_tmp_bytes(module, res_size, a_size, nrows, ncols);
}

EXPORT uint64_t bytes_of_rnx_svp_ppol(const MOD_RNX* module) { return module->vtable.bytes_of_rnx_svp_ppol(module); }

/** @brief prepares a polynomial for scalar-vector products. */
EXPORT void rnx_svp_prepare(const MOD_RNX* module,  // N
                            RNX_SVP_PPOL* ppol,     // output
                            const double* pol) {    // a
  module->vtable.rnx_svp_prepare(module, ppol, pol);
}

/** @brief applies a scalar-vector product with a prepared polynomial. */
EXPORT void rnx_svp_apply(const MOD_RNX* module,                             // N
                          double* res, uint64_t res_size, uint64_t res_sl,   // output
                          const RNX_SVP_PPOL* ppol,                          // prepared pol
                          const double* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.rnx_svp_apply(module, res, res_size, res_sl, ppol, a, a_size, a_sl);
}

/** @brief approximate gadget decomposition of a into res,
 *  using the base-2^K gadget described by gadget. */
EXPORT void rnx_approxdecomp_from_tnxdbl(const MOD_RNX* module,                           // N
                                         const TNXDBL_APPROXDECOMP_GADGET* gadget,       // gadget, base 2^K
                                         double* res, uint64_t res_size, uint64_t res_sl,// res
                                         const double* a) {                              // a
  module->vtable.rnx_approxdecomp_from_tnxdbl(module, gadget, res, res_size, res_sl, a);
}

/** @brief converts a double coefficient vector into int32 coefficients. */
EXPORT void vec_rnx_to_znx32(const MOD_RNX* module,                             // N
                             int32_t* res, uint64_t res_size, uint64_t res_sl,  // res
                             const double* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_to_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
}

/** @brief converts int32 coefficients into a double coefficient vector. */
EXPORT void vec_rnx_from_znx32(const MOD_RNX* module,                              // N
                               double* res, uint64_t res_size, uint64_t res_sl,    // res
                               const int32_t* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_from_znx32(module, res, res_size, res_sl, a, a_size, a_sl);
}

/** @brief converts a double coefficient vector into a tnx32 (torus) representation. */
EXPORT void vec_rnx_to_tnx32(const MOD_RNX* module,                             // N
                             int32_t* res, uint64_t res_size, uint64_t res_sl,  // res
                             const double* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_to_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
}

/** @brief converts a tnx32 (torus) representation into a double coefficient vector. */
EXPORT void vec_rnx_from_tnx32(const MOD_RNX* module,                              // N
                               double* res, uint64_t res_size, uint64_t res_sl,    // res
                               const int32_t* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_from_tnx32(module, res, res_size, res_sl, a, a_size, a_sl);
}

/** @brief converts a double coefficient vector into a tnxdbl (torus, double) representation. */
EXPORT void vec_rnx_to_tnxdbl(const MOD_RNX* module,                             // N
                              double* res, uint64_t res_size, uint64_t res_sl,   // res
                              const double* a, uint64_t a_size, uint64_t a_sl) { // a
  module->vtable.vec_rnx_to_tnxdbl(module, res, res_size, res_sl, a, a_size, a_sl);
}
Loading
Loading