Commit 7310527 (2 parents: d76a952 + a04fa51), authored May 24, 2024
Merge pull request #14 from tfhe/ng/vec_znx_api
vec_znx_arithmetic api def

3 files changed: +834 −0 lines
spqlios/CMakeLists.txt (+2)
@@ -87,6 +87,7 @@ set(SRCS_F128
 # H header files containing the public API (these headers are installed)
 set(HEADERSPUBLIC
     commons.h
+    arithmetic/vec_znx_arithmetic.h
     cplx/cplx_fft.h
     reim/reim_fft.h
     q120/q120_common.h
@@ -107,6 +108,7 @@ set(HEADERSPRIVATE
     reim/reim_fft_private.h
     q120/q120_arithmetic_private.h
     q120/q120_ntt_private.h
+    arithmetic/vec_znx_arithmetic.h
 )

 set(SPQLIOSSOURCES
spqlios/arithmetic/vec_znx_arithmetic.h (new file, +344)
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_H
#define SPQLIOS_VEC_ZNX_ARITHMETIC_H

#include <stdint.h>

#include "../commons.h"
#include "../reim/reim_fft.h"

/**
 * We support the following module families:
 * - FFT64:
 *   all polynomials must fit, at all times, within 52 bits.
 *   For FHE implementations, the recommended limb sizes are
 *   between K=10 and 20, which is good for low multiplicative depths.
 * - NTT120:
 *   all polynomials must fit, at all times, within 119 bits.
 *   For FHE implementations, the recommended limb sizes are
 *   between K=20 and 40, which is good for large multiplicative depths.
 */
typedef enum module_type_t { FFT64, NTT120 } MODULE_TYPE;

/** @brief opaque structure that describes the modules (ZnX, TnX) and the hardware */
typedef struct module_info_t MODULE;
/** @brief opaque type that represents a prepared matrix */
typedef struct vmp_pmat_t VMP_PMAT;
/** @brief opaque type that represents a vector of znx in DFT space */
typedef struct vec_znx_dft_t VEC_ZNX_DFT;
/** @brief opaque type that represents a vector of znx in large-coefficient space */
typedef struct vec_znx_bigcoeff_t VEC_ZNX_BIG;
/** @brief opaque type that represents a prepared scalar vector product */
typedef struct svp_ppol_t SVP_PPOL;
/** @brief opaque type that represents a prepared left convolution vector product */
typedef struct cnv_pvec_l_t CNV_PVEC_L;
/** @brief opaque type that represents a prepared right convolution vector product */
typedef struct cnv_pvec_r_t CNV_PVEC_R;

/** @brief allocates a prepared matrix (release with free) */
EXPORT VMP_PMAT* vmp_pmat_alloc(const MODULE* module,           // N
                                uint64_t nrows, uint64_t ncols  // dimensions
);

/** @brief allocates a vec_znx in DFT space (release with free) */
EXPORT VEC_ZNX_DFT* vec_znx_dft_alloc(const MODULE* module,  // N
                                      uint64_t size);

/** @brief allocates a vec_znx_big (release with free) */
EXPORT VEC_ZNX_BIG* vec_znx_big_alloc(const MODULE* module,  // N
                                      uint64_t size);

/** @brief allocates a prepared vector (release with free) */
EXPORT SVP_PPOL* svp_ppol_alloc(const MODULE* module);  // N

/** @brief frees anything (vec_znx, pvmp, pcnv...) that was allocated above.
 * It just calls free. It is exposed because foreign-language
 * bindings cannot always call libc directly.
 */
EXPORT void std_free(void* address);

/**
 * @brief obtains the module info for ring dimension N.
 * The module-info knows about:
 * - the dimension N (or the complex dimension m = N/2)
 * - any precomputed fft or ntt items
 * - the hardware (avx, arm64, x86, ...)
 */
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mode);
EXPORT void delete_module_info(MODULE* module_info);
EXPORT uint64_t module_get_n(const MODULE* module);
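
/*
 * Usage sketch: minimal module lifecycle. The ring dimension 4096 and the
 * limb count 8 are illustrative values, not recommendations.
 * @code
 *   MODULE* module = new_module_info(4096, FFT64);  // N = 4096, FFT64 family
 *   uint64_t n = module_get_n(module);              // n == 4096
 *   VEC_ZNX_DFT* v = vec_znx_dft_alloc(module, 8);  // 8 limbs in DFT space
 *   std_free(v);  // opaque buffers are released with std_free (or free)
 *   delete_module_info(module);
 * @endcode
 */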

/** @brief sets res = 0 */
EXPORT void vec_znx_zero(const MODULE* module,                             // N
                         int64_t* res, uint64_t res_size, uint64_t res_sl  // res
);

/** @brief sets res = a */
EXPORT void vec_znx_copy(const MODULE* module,                              // N
                         int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                         const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);

/** @brief sets res = -a */
EXPORT void vec_znx_negate(const MODULE* module,                              // N
                           int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                           const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);

/** @brief sets res = a + b */
EXPORT void vec_znx_add(const MODULE* module,                              // N
                        int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                        const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                        const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);

/** @brief sets res = a - b */
EXPORT void vec_znx_sub(const MODULE* module,                              // N
                        int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                        const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                        const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);
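
/*
 * Usage sketch for the (size, sl) arguments, assuming the contiguous layout
 * where limb i of a vector starts at offset i*sl and each limb holds N
 * coefficients (so sl == n for densely packed vectors); check this
 * convention against the implementation before relying on it.
 * @code
 *   uint64_t n = module_get_n(module);
 *   uint64_t size = 4;  // 4 limbs per vector
 *   int64_t* a = calloc(size * n, sizeof(int64_t));
 *   int64_t* b = calloc(size * n, sizeof(int64_t));
 *   int64_t* r = calloc(size * n, sizeof(int64_t));
 *   vec_znx_add(module, r, size, n, a, size, n, b, size, n);  // r = a + b
 * @endcode
 */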

/** @brief sets res = k-normalize-reduce(a) */
EXPORT void vec_znx_normalize_base2k(const MODULE* module,  // N
                                     uint64_t log2_base2k,  // output base 2^K
                                     int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                                     const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                     uint8_t* tmp_space  // scratch space (size >= N)
);

/** @brief returns the minimal byte length of scratch space for vec_znx_normalize_base2k */
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module,  // N
                                                   uint64_t res_size,  // res size
                                                   uint64_t inp_size   // inp size
);
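
/*
 * Usage sketch of the scratch-space pattern used throughout this API: query
 * the matching *_tmp_bytes function, allocate once, pass the buffer in.
 * The base value 17 is an arbitrary example.
 * @code
 *   uint64_t tb = vec_znx_normalize_base2k_tmp_bytes(module, res_size, a_size);
 *   uint8_t* tmp = malloc(tb);
 *   vec_znx_normalize_base2k(module, 17, res, res_size, n, a, a_size, n, tmp);
 *   free(tmp);
 * @endcode
 */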

/** @brief sets res = a . X^p */
EXPORT void vec_znx_rotate(const MODULE* module,  // N
                           const int64_t p,       // rotation value
                           int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                           const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);

/** @brief sets res = a(X^p) */
EXPORT void vec_znx_automorphism(const MODULE* module,  // N
                                 const int64_t p,       // X -> X^p
                                 int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                                 const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);
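
/*
 * Worked example, assuming the usual negacyclic ring Z[X]/(X^N+1): with N=4
 * and a = c0 + c1.X + c2.X^2 + c3.X^3, rotating by p=1 yields
 * a.X = -c3 + c0.X + c1.X^2 + c2.X^3 (since X^4 = -1), and the automorphism
 * with p=3 yields a(X^3) = c0 + c3.X - c2.X^2 + c1.X^3.
 */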

/** @brief prepares a vmp matrix (contiguous row-major version) */
EXPORT void vmp_prepare_contiguous(const MODULE* module,  // N
                                   VMP_PMAT* pmat,        // output
                                   const int64_t* mat, uint64_t nrows, uint64_t ncols,  // a
                                   uint8_t* tmp_space  // scratch space
);

/** @brief prepares a vmp matrix (mat[row*ncols+col] points to the item) */
EXPORT void vmp_prepare_dblptr(const MODULE* module,  // N
                               VMP_PMAT* pmat,        // output
                               const int64_t** mat, uint64_t nrows, uint64_t ncols,  // a
                               uint8_t* tmp_space  // scratch space
);

/** @brief sets res = 0 */
EXPORT void vec_dft_zero(const MODULE* module,  // N
                         VEC_ZNX_DFT* res, uint64_t res_size  // res
);

/** @brief sets res = a + b */
EXPORT void vec_dft_add(const MODULE* module,  // N
                        VEC_ZNX_DFT* res, uint64_t res_size,    // res
                        const VEC_ZNX_DFT* a, uint64_t a_size,  // a
                        const VEC_ZNX_DFT* b, uint64_t b_size   // b
);

/** @brief sets res = a - b */
EXPORT void vec_dft_sub(const MODULE* module,  // N
                        VEC_ZNX_DFT* res, uint64_t res_size,    // res
                        const VEC_ZNX_DFT* a, uint64_t a_size,  // a
                        const VEC_ZNX_DFT* b, uint64_t b_size   // b
);

/** @brief sets res = DFT(a) */
EXPORT void vec_znx_dft(const MODULE* module,  // N
                        VEC_ZNX_DFT* res, uint64_t res_size,  // res
                        const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

/** @brief sets res = iDFT(a_dft) -- output in big-coefficient space */
EXPORT void vec_znx_idft(const MODULE* module,  // N
                         VEC_ZNX_BIG* res, uint64_t res_size,  // res
                         const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
                         uint8_t* tmp  // scratch space
);

/** @brief tmp bytes required for vec_znx_idft */
EXPORT uint64_t vec_znx_idft_tmp_bytes(const MODULE* module);

/**
 * @brief sets res = iDFT(a_dft) -- output in big-coefficient space
 *
 * @note a_dft is overwritten
 */
EXPORT void vec_znx_idft_tmp_a(const MODULE* module,  // N
                               VEC_ZNX_BIG* res, uint64_t res_size,  // res
                               VEC_ZNX_DFT* a_dft, uint64_t a_size   // a is overwritten
);
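
/*
 * Usage sketch of the DFT round trip (sizes and base 2^17 are illustrative):
 * transform to DFT space, back to big-coefficient space, then renormalize.
 * @code
 *   VEC_ZNX_DFT* adft = vec_znx_dft_alloc(module, size);
 *   VEC_ZNX_BIG* abig = vec_znx_big_alloc(module, size);
 *   uint8_t* t1 = malloc(vec_znx_idft_tmp_bytes(module));
 *   vec_znx_dft(module, adft, size, a, size, n);       // adft = DFT(a)
 *   vec_znx_idft(module, abig, size, adft, size, t1);  // abig = iDFT(adft)
 *   uint8_t* t2 = malloc(vec_znx_big_normalize_base2k_tmp_bytes(module, size, size));
 *   vec_znx_big_normalize_base2k(module, 17, res, size, n, abig, size, t2);  // declared below
 *   free(t2); free(t1); std_free(abig); std_free(adft);
 * @endcode
 */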

/** @brief sets res = a + b */
EXPORT void vec_znx_big_add(const MODULE* module,  // N
                            VEC_ZNX_BIG* res, uint64_t res_size,    // res
                            const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                            const VEC_ZNX_BIG* b, uint64_t b_size   // b
);
/** @brief sets res = a + b */
EXPORT void vec_znx_big_add_small(const MODULE* module,  // N
                                  VEC_ZNX_BIG* res, uint64_t res_size,    // res
                                  const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                                  const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);
EXPORT void vec_znx_big_add_small2(const MODULE* module,  // N
                                   VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                   const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                   const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);

/** @brief sets res = a - b */
EXPORT void vec_znx_big_sub(const MODULE* module,  // N
                            VEC_ZNX_BIG* res, uint64_t res_size,    // res
                            const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                            const VEC_ZNX_BIG* b, uint64_t b_size   // b
);
EXPORT void vec_znx_big_sub_small_b(const MODULE* module,  // N
                                    VEC_ZNX_BIG* res, uint64_t res_size,    // res
                                    const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                                    const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);
EXPORT void vec_znx_big_sub_small_a(const MODULE* module,  // N
                                    VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                    const VEC_ZNX_BIG* b, uint64_t b_size  // b
);
EXPORT void vec_znx_big_sub_small2(const MODULE* module,  // N
                                   VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                   const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                   const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);

/** @brief sets res = k-normalize(a) -- output in int64 coefficient space */
EXPORT void vec_znx_big_normalize_base2k(const MODULE* module,  // N
                                         uint64_t log2_base2k,  // base 2^k
                                         int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                                         const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                                         uint8_t* tmp_space  // scratch space
);

/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module,  // N
                                                       uint64_t res_size,  // res size
                                                       uint64_t inp_size   // inp size
);

/** @brief applies a svp product, result = ppol * a, presented in DFT space */
EXPORT void fft64_svp_apply_dft(const MODULE* module,  // N
                                const VEC_ZNX_DFT* res, uint64_t res_size,  // output
                                const SVP_PPOL* ppol,  // prepared pol
                                const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

/** @brief sets res = k-normalize(a.subrange) -- output in int64 coefficient space */
EXPORT void vec_znx_big_range_normalize_base2k(        //
    const MODULE* module,                              // N
    uint64_t log2_base2k,                              // base 2^k
    int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
    const VEC_ZNX_BIG* a, uint64_t a_range_begin, uint64_t a_range_xend, uint64_t a_range_step,  // range
    uint8_t* tmp_space                                 // scratch space
);

/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes(  //
    const MODULE* module,  // N
    uint64_t res_size,     // res size
    uint64_t inp_size      // inp size
);

/** @brief sets res = a . X^p */
EXPORT void vec_znx_big_rotate(const MODULE* module,  // N
                               int64_t p,             // rotation value
                               VEC_ZNX_BIG* res, uint64_t res_size,  // res
                               const VEC_ZNX_BIG* a, uint64_t a_size  // a
);

/** @brief sets res = a(X^p) */
EXPORT void vec_znx_big_automorphism(const MODULE* module,  // N
                                     int64_t p,             // X -> X^p
                                     VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                     const VEC_ZNX_BIG* a, uint64_t a_size  // a
);

/** @brief applies a svp product, result = ppol * a, presented in DFT space */
EXPORT void svp_apply_dft(const MODULE* module,  // N
                          const VEC_ZNX_DFT* res, uint64_t res_size,  // output
                          const SVP_PPOL* ppol,  // prepared pol
                          const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

/** @brief prepares a svp polynomial */
EXPORT void svp_prepare(const MODULE* module,  // N
                        SVP_PPOL* ppol,        // output
                        const int64_t* pol     // a
);
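
/*
 * Usage sketch of the scalar-vector product path (names illustrative):
 * prepare the scalar polynomial once, then multiply vectors by it in DFT space.
 * @code
 *   SVP_PPOL* ppol = svp_ppol_alloc(module);
 *   svp_prepare(module, ppol, pol);                       // pol: N int64 coefficients
 *   svp_apply_dft(module, rdft, size, ppol, a, size, n);  // rdft = ppol * a (DFT space)
 *   std_free(ppol);
 * @endcode
 */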

/** @brief res = a * b : small integer polynomial product */
EXPORT void znx_small_single_product(const MODULE* module,  // N
                                     int64_t* res,      // output
                                     const int64_t* a,  // a
                                     const int64_t* b,  // b
                                     uint8_t* tmp);

/** @brief tmp bytes required for znx_small_single_product */
EXPORT uint64_t znx_small_single_product_tmp_bytes(const MODULE* module);
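
/*
 * Usage sketch (illustrative): one-shot product of two small polynomials,
 * with the same scratch-space pattern as above.
 * @code
 *   uint8_t* tmp = malloc(znx_small_single_product_tmp_bytes(module));
 *   znx_small_single_product(module, res, a, b, tmp);  // res = a * b
 *   free(tmp);
 * @endcode
 */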

/** @brief minimal scratch space byte-size required for the vmp_prepare function */
EXPORT uint64_t vmp_prepare_contiguous_tmp_bytes(const MODULE* module,  // N
                                                 uint64_t nrows, uint64_t ncols);

/** @brief applies a vmp product (result in DFT space) */
EXPORT void vmp_apply_dft(const MODULE* module,  // N
                          VEC_ZNX_DFT* res, uint64_t res_size,  // res
                          const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                          const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
                          uint8_t* tmp_space  // scratch space
);

/** @brief minimal size of the tmp_space */
EXPORT uint64_t vmp_apply_dft_tmp_bytes(const MODULE* module,  // N
                                        uint64_t res_size,  // res
                                        uint64_t a_size,    // a
                                        uint64_t nrows, uint64_t ncols  // prep matrix
);
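
/*
 * Usage sketch of the vector-matrix product pipeline (dimensions are
 * illustrative): prepare the matrix once, then apply it to many vectors.
 * @code
 *   VMP_PMAT* pmat = vmp_pmat_alloc(module, nrows, ncols);
 *   uint8_t* t1 = malloc(vmp_prepare_contiguous_tmp_bytes(module, nrows, ncols));
 *   vmp_prepare_contiguous(module, pmat, mat, nrows, ncols, t1);  // mat: nrows*ncols polynomials, row-major
 *   free(t1);
 *   uint8_t* t2 = malloc(vmp_apply_dft_tmp_bytes(module, res_size, a_size, nrows, ncols));
 *   vmp_apply_dft(module, rdft, res_size, a, a_size, n, pmat, nrows, ncols, t2);
 *   free(t2);
 *   std_free(pmat);
 * @endcode
 */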

/** @brief applies a vmp product to an input already in DFT space (result in DFT space) */
EXPORT void vmp_apply_dft_to_dft(const MODULE* module,  // N
                                 VEC_ZNX_DFT* res, const uint64_t res_size,  // res
                                 const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
                                 const VMP_PMAT* pmat, const uint64_t nrows,
                                 const uint64_t ncols,  // prep matrix
                                 uint8_t* tmp_space  // scratch space (a_size*sizeof(reim4) bytes)
);

/** @brief minimal size of the tmp_space */
EXPORT uint64_t vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module,  // N
                                               uint64_t res_size,  // res
                                               uint64_t a_size,    // a
                                               uint64_t nrows, uint64_t ncols  // prep matrix
);

#endif  // SPQLIOS_VEC_ZNX_ARITHMETIC_H
spqlios/arithmetic/vec_znx_arithmetic_private.h (new file, +488)
#ifndef SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
#define SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H

#include "../commons_private.h"
#include "../q120/q120_ntt.h"
#include "vec_znx_arithmetic.h"

/**
 * Layout families:
 *
 * fft64:
 *   K: <= 20, N: <= 65536, ell: <= 200
 *   vec<ZnX> normalized: represented by int64
 *   vec<ZnX> large: represented by int64 (expected <= 52 bits)
 *   vec<ZnX> DFT: represented by double (reim_fft space)
 *   On AVX2 infrastructure, PMAT, LCNV, RCNV use a special reim4_fft space
 *
 * ntt120:
 *   K: <= 50, N: <= 65536, ell: <= 80
 *   vec<ZnX> normalized: represented by int64
 *   vec<ZnX> large: represented by int128 (expected <= 120 bits)
 *   vec<ZnX> DFT: represented by int64x4 (ntt120 space)
 *   On AVX2 infrastructure, PMAT, LCNV, RCNV use a special ntt120 space
 *
 * ntt104:
 *   K: <= 40, N: <= 65536, ell: <= 80
 *   vec<ZnX> normalized: represented by int64
 *   vec<ZnX> large: represented by int128 (expected <= 104 bits)
 *   vec<ZnX> DFT: represented by int64x4 (ntt104 space)
 *   On AVX512 infrastructure, PMAT, LCNV, RCNV use a special ntt104 space
 */
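
/*
 * The representations above imply these back-of-the-envelope buffer sizes
 * per vector of `size` limbs over ring dimension n (alignment and padding,
 * if any, are not accounted for -- an assumption, not a documented layout):
 *   fft64  DFT:   size * n * sizeof(double)       (one double per coefficient)
 *   ntt120 DFT:   size * n * 4 * sizeof(int64_t)  (int64x4 per coefficient)
 *   large int128: size * n * 16 bytes             (ntt120 / ntt104 families)
 */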

struct fft64_module_info_t {
  // pre-computation for reim_fft
  REIM_FFT_PRECOMP* p_fft;
  // pre-computation for mul_fft
  REIM_FFTVEC_MUL_PRECOMP* mul_fft;
  // pre-computation for reim_from_znx64
  REIM_FROM_ZNX64_PRECOMP* p_conv;
  // pre-computation for reim_to_znx64
  REIM_TO_ZNX64_PRECOMP* p_reim_to_znx;
  // pre-computation for reim_ifft
  REIM_IFFT_PRECOMP* p_ifft;
  // pre-computation for reim_fftvec_addmul
  REIM_FFTVEC_ADDMUL_PRECOMP* p_addmul;
};

struct q120_module_info_t {
  // pre-computation for q120b to q120b ntt
  q120_ntt_precomp* p_ntt;
  // pre-computation for q120b to q120b intt
  q120_ntt_precomp* p_intt;
};

// TODO add function types here
typedef typeof(vmp_pmat_alloc) VMP_PMAT_ALLOC_F;
typedef typeof(vec_znx_dft_alloc) VEC_ZNX_DFT_ALLOC_F;
typedef typeof(vec_znx_big_alloc) VEC_ZNX_BIG_ALLOC_F;
typedef typeof(svp_ppol_alloc) SVP_PPOL_ALLOC_F;
typedef typeof(vec_znx_zero) VEC_ZNX_ZERO_F;
typedef typeof(vec_znx_copy) VEC_ZNX_COPY_F;
typedef typeof(vec_znx_negate) VEC_ZNX_NEGATE_F;
typedef typeof(vec_znx_add) VEC_ZNX_ADD_F;
typedef typeof(vec_znx_dft) VEC_ZNX_DFT_F;
typedef typeof(vec_znx_idft) VEC_ZNX_IDFT_F;
typedef typeof(vec_znx_idft_tmp_bytes) VEC_ZNX_IDFT_TMP_BYTES_F;
typedef typeof(vec_znx_idft_tmp_a) VEC_ZNX_IDFT_TMP_A_F;
typedef typeof(vec_znx_sub) VEC_ZNX_SUB_F;
typedef typeof(vec_znx_rotate) VEC_ZNX_ROTATE_F;
typedef typeof(vec_znx_automorphism) VEC_ZNX_AUTOMORPHISM_F;
typedef typeof(vec_znx_normalize_base2k) VEC_ZNX_NORMALIZE_BASE2K_F;
typedef typeof(vec_znx_normalize_base2k_tmp_bytes) VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F;
typedef typeof(vec_znx_big_normalize_base2k) VEC_ZNX_BIG_NORMALIZE_BASE2K_F;
typedef typeof(vec_znx_big_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F;
typedef typeof(vec_znx_big_range_normalize_base2k) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F;
typedef typeof(vec_znx_big_range_normalize_base2k_tmp_bytes) VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F;
typedef typeof(vec_znx_big_add) VEC_ZNX_BIG_ADD_F;
typedef typeof(vec_znx_big_add_small) VEC_ZNX_BIG_ADD_SMALL_F;
typedef typeof(vec_znx_big_add_small2) VEC_ZNX_BIG_ADD_SMALL2_F;
typedef typeof(vec_znx_big_sub) VEC_ZNX_BIG_SUB_F;
typedef typeof(vec_znx_big_sub_small_a) VEC_ZNX_BIG_SUB_SMALL_A_F;
typedef typeof(vec_znx_big_sub_small_b) VEC_ZNX_BIG_SUB_SMALL_B_F;
typedef typeof(vec_znx_big_sub_small2) VEC_ZNX_BIG_SUB_SMALL2_F;
typedef typeof(vec_znx_big_rotate) VEC_ZNX_BIG_ROTATE_F;
typedef typeof(vec_znx_big_automorphism) VEC_ZNX_BIG_AUTOMORPHISM_F;
typedef typeof(svp_prepare) SVP_PREPARE;
typedef typeof(svp_apply_dft) SVP_APPLY_DFT_F;
typedef typeof(znx_small_single_product) ZNX_SMALL_SINGLE_PRODUCT_F;
typedef typeof(znx_small_single_product_tmp_bytes) ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F;
typedef typeof(vmp_prepare_contiguous) VMP_PREPARE_CONTIGUOUS_F;
typedef typeof(vmp_prepare_contiguous_tmp_bytes) VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F;
typedef typeof(vmp_apply_dft) VMP_APPLY_DFT_F;
typedef typeof(vmp_apply_dft_tmp_bytes) VMP_APPLY_DFT_TMP_BYTES_F;
typedef typeof(vmp_apply_dft_to_dft) VMP_APPLY_DFT_TO_DFT_F;
typedef typeof(vmp_apply_dft_to_dft_tmp_bytes) VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F;

struct module_virtual_functions_t {
  // TODO add functions here
  VMP_PMAT_ALLOC_F* vmp_pmat_alloc;
  VEC_ZNX_DFT_ALLOC_F* vec_znx_dft_alloc;
  VEC_ZNX_BIG_ALLOC_F* vec_znx_big_alloc;
  SVP_PPOL_ALLOC_F* svp_ppol_alloc;
  VEC_ZNX_ZERO_F* vec_znx_zero;
  VEC_ZNX_COPY_F* vec_znx_copy;
  VEC_ZNX_NEGATE_F* vec_znx_negate;
  VEC_ZNX_ADD_F* vec_znx_add;
  VEC_ZNX_DFT_F* vec_znx_dft;
  VEC_ZNX_IDFT_F* vec_znx_idft;
  VEC_ZNX_IDFT_TMP_BYTES_F* vec_znx_idft_tmp_bytes;
  VEC_ZNX_IDFT_TMP_A_F* vec_znx_idft_tmp_a;
  VEC_ZNX_SUB_F* vec_znx_sub;
  VEC_ZNX_ROTATE_F* vec_znx_rotate;
  VEC_ZNX_AUTOMORPHISM_F* vec_znx_automorphism;
  VEC_ZNX_NORMALIZE_BASE2K_F* vec_znx_normalize_base2k;
  VEC_ZNX_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_normalize_base2k_tmp_bytes;
  VEC_ZNX_BIG_NORMALIZE_BASE2K_F* vec_znx_big_normalize_base2k;
  VEC_ZNX_BIG_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_normalize_base2k_tmp_bytes;
  VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_F* vec_znx_big_range_normalize_base2k;
  VEC_ZNX_BIG_RANGE_NORMALIZE_BASE2K_TMP_BYTES_F* vec_znx_big_range_normalize_base2k_tmp_bytes;
  VEC_ZNX_BIG_ADD_F* vec_znx_big_add;
  VEC_ZNX_BIG_ADD_SMALL_F* vec_znx_big_add_small;
  VEC_ZNX_BIG_ADD_SMALL2_F* vec_znx_big_add_small2;
  VEC_ZNX_BIG_SUB_F* vec_znx_big_sub;
  VEC_ZNX_BIG_SUB_SMALL_A_F* vec_znx_big_sub_small_a;
  VEC_ZNX_BIG_SUB_SMALL_B_F* vec_znx_big_sub_small_b;
  VEC_ZNX_BIG_SUB_SMALL2_F* vec_znx_big_sub_small2;
  VEC_ZNX_BIG_ROTATE_F* vec_znx_big_rotate;
  VEC_ZNX_BIG_AUTOMORPHISM_F* vec_znx_big_automorphism;
  SVP_PREPARE* svp_prepare;
  SVP_APPLY_DFT_F* svp_apply_dft;
  ZNX_SMALL_SINGLE_PRODUCT_F* znx_small_single_product;
  ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F* znx_small_single_product_tmp_bytes;
  VMP_PREPARE_CONTIGUOUS_F* vmp_prepare_contiguous;
  VMP_PREPARE_CONTIGUOUS_TMP_BYTES_F* vmp_prepare_contiguous_tmp_bytes;
  VMP_APPLY_DFT_F* vmp_apply_dft;
  VMP_APPLY_DFT_TMP_BYTES_F* vmp_apply_dft_tmp_bytes;
  VMP_APPLY_DFT_TO_DFT_F* vmp_apply_dft_to_dft;
  VMP_APPLY_DFT_TO_DFT_TMP_BYTES_F* vmp_apply_dft_to_dft_tmp_bytes;
};

union backend_module_info_t {
  struct fft64_module_info_t fft64;
  struct q120_module_info_t q120;
};

struct module_info_t {
  // generic parameters
  MODULE_TYPE module_type;
  uint64_t nn;
  uint64_t m;
  // backend-dependent precomputations
  union backend_module_info_t mod;
  // virtual functions
  struct module_virtual_functions_t func;
};
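
/*
 * Design note: with this layout, each public entry point can be a thin
 * trampoline through the module's virtual table. A sketch of the idea
 * (illustrative; the dispatching code itself is not part of this commit):
 * @code
 *   EXPORT void vec_znx_add(const MODULE* module,
 *                           int64_t* res, uint64_t res_size, uint64_t res_sl,
 *                           const int64_t* a, uint64_t a_size, uint64_t a_sl,
 *                           const int64_t* b, uint64_t b_size, uint64_t b_sl) {
 *     module->func.vec_znx_add(module, res, res_size, res_sl,
 *                              a, a_size, a_sl, b, b_size, b_sl);
 *   }
 * @endcode
 */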

EXPORT VMP_PMAT* fft64_vmp_pmat_alloc(const MODULE* module,           // N
                                      uint64_t nrows, uint64_t ncols  // dimensions
);

EXPORT VEC_ZNX_DFT* fft64_vec_znx_dft_alloc(const MODULE* module,  // N
                                            uint64_t size);

EXPORT VEC_ZNX_BIG* fft64_vec_znx_big_alloc(const MODULE* module,  // N
                                            uint64_t size);

EXPORT SVP_PPOL* fft64_svp_ppol_alloc(const MODULE* module);  // N

EXPORT void vec_znx_zero_ref(const MODULE* module,  // N
                             int64_t* res, uint64_t res_size, uint64_t res_sl  // res
);

EXPORT void vec_znx_copy_ref(const MODULE* precomp,  // N
                             int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                             const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);

EXPORT void vec_znx_negate_ref(const MODULE* module,  // N
                               int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                               const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);

EXPORT void vec_znx_negate_avx(const MODULE* module,  // N
                               int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                               const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);

EXPORT void vec_znx_add_ref(const MODULE* module,  // N
                            int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                            const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                            const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);
EXPORT void vec_znx_add_avx(const MODULE* module,  // N
                            int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                            const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                            const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);

EXPORT void vec_znx_sub_ref(const MODULE* precomp,  // N
                            int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                            const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                            const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);

EXPORT void vec_znx_sub_avx(const MODULE* module,  // N
                            int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                            const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                            const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);

EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module,  // N
                                         uint64_t log2_base2k,  // output base 2^K
                                         int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                                         const int64_t* a, uint64_t a_size, uint64_t a_sl,  // inp
                                         uint8_t* tmp_space  // scratch space
);

EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module,  // N
                                                       uint64_t res_size,  // res size
                                                       uint64_t inp_size   // inp size
);

EXPORT void vec_znx_rotate_ref(const MODULE* module,  // N
                               const int64_t p,       // rotation value
                               int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                               const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);

EXPORT void vec_znx_automorphism_ref(const MODULE* module,  // N
                                     const int64_t p,       // X->X^p
                                     int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                                     const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
);

EXPORT void vmp_prepare_ref(const MODULE* precomp,  // N
                            VMP_PMAT* pmat,         // output
                            const int64_t* mat, uint64_t nrows, uint64_t ncols  // a
);

EXPORT void vmp_apply_dft_ref(const MODULE* precomp,  // N
                              VEC_ZNX_DFT* res, uint64_t res_size,  // res
                              const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                              const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols  // prep matrix
);

EXPORT void vec_dft_zero_ref(const MODULE* precomp,  // N
                             VEC_ZNX_DFT* res, uint64_t res_size  // res
);

EXPORT void vec_dft_add_ref(const MODULE* precomp,  // N
                            VEC_ZNX_DFT* res, uint64_t res_size,    // res
                            const VEC_ZNX_DFT* a, uint64_t a_size,  // a
                            const VEC_ZNX_DFT* b, uint64_t b_size   // b
);

EXPORT void vec_dft_sub_ref(const MODULE* precomp,  // N
                            VEC_ZNX_DFT* res, uint64_t res_size,    // res
                            const VEC_ZNX_DFT* a, uint64_t a_size,  // a
                            const VEC_ZNX_DFT* b, uint64_t b_size   // b
);

EXPORT void vec_dft_ref(const MODULE* precomp,  // N
                        VEC_ZNX_DFT* res, uint64_t res_size,  // res
                        const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void vec_idft_ref(const MODULE* precomp,  // N
                         VEC_ZNX_BIG* res, uint64_t res_size,  // res
                         const VEC_ZNX_DFT* a_dft, uint64_t a_size);

EXPORT void vec_znx_big_normalize_ref(const MODULE* precomp,  // N
                                      uint64_t k,             // base 2^k
                                      int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                                      const VEC_ZNX_BIG* a, uint64_t a_size  // a
);

/** @brief applies a svp product, result = ppol * a, presented in DFT space */
EXPORT void fft64_svp_apply_dft_ref(const MODULE* module,  // N
                                    const VEC_ZNX_DFT* res, uint64_t res_size,  // output
                                    const SVP_PPOL* ppol,  // prepared pol
                                    const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

/** @brief sets res = k-normalize(a) -- output in int64 coefficient space */
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module,  // N
                                               uint64_t k,            // base 2^k
                                               int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                                               const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                                               uint8_t* tmp_space  // scratch space
);

/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module,  // N
                                                             uint64_t res_size,  // res size
                                                             uint64_t inp_size   // inp size
);

/** @brief sets res = k-normalize(a.subrange) -- output in int64 coefficient space */
EXPORT void fft64_vec_znx_big_range_normalize_base2k(const MODULE* module,  // N
                                                     uint64_t log2_base2k,  // base 2^k
                                                     int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
                                                     const VEC_ZNX_BIG* a, uint64_t a_range_begin,  // a
                                                     uint64_t a_range_xend, uint64_t a_range_step,  // range
                                                     uint8_t* tmp_space  // scratch space
);

/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module,  // N
                                                                   uint64_t res_size,  // res size
                                                                   uint64_t inp_size   // inp size
);

EXPORT void fft64_vec_znx_dft(const MODULE* module,  // N
                              VEC_ZNX_DFT* res, uint64_t res_size,  // res
                              const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void fft64_vec_znx_idft(const MODULE* module,  // N
                               VEC_ZNX_BIG* res, uint64_t res_size,  // res
                               const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
                               uint8_t* tmp  // scratch space
);

EXPORT uint64_t fft64_vec_znx_idft_tmp_bytes(const MODULE* module);

EXPORT void fft64_vec_znx_idft_tmp_a(const MODULE* module,  // N
                                     VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                     VEC_ZNX_DFT* a_dft, uint64_t a_size   // a is overwritten
);

EXPORT void ntt120_vec_znx_dft_avx(const MODULE* module,  // N
                                   VEC_ZNX_DFT* res, uint64_t res_size,  // res
                                   const int64_t* a, uint64_t a_size, uint64_t a_sl  // a
);

EXPORT void ntt120_vec_znx_idft_avx(const MODULE* module,  // N
                                    VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                    const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
                                    uint8_t* tmp  // scratch space
);

EXPORT uint64_t ntt120_vec_znx_idft_tmp_bytes_avx(const MODULE* module);

EXPORT void ntt120_vec_znx_idft_tmp_a_avx(const MODULE* module,  // N
                                          VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                          VEC_ZNX_DFT* a_dft, uint64_t a_size   // a is overwritten
);

// big additions/subtractions

/** @brief sets res = a + b */
EXPORT void fft64_vec_znx_big_add(const MODULE* module,  // N
                                  VEC_ZNX_BIG* res, uint64_t res_size,    // res
                                  const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                                  const VEC_ZNX_BIG* b, uint64_t b_size   // b
);
/** @brief sets res = a + b */
EXPORT void fft64_vec_znx_big_add_small(const MODULE* module,  // N
                                        VEC_ZNX_BIG* res, uint64_t res_size,    // res
                                        const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                                        const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);
EXPORT void fft64_vec_znx_big_add_small2(const MODULE* module,  // N
                                         VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                         const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                         const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);

/** @brief sets res = a - b */
EXPORT void fft64_vec_znx_big_sub(const MODULE* module,  // N
                                  VEC_ZNX_BIG* res, uint64_t res_size,    // res
                                  const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                                  const VEC_ZNX_BIG* b, uint64_t b_size   // b
);
EXPORT void fft64_vec_znx_big_sub_small_b(const MODULE* module,  // N
                                          VEC_ZNX_BIG* res, uint64_t res_size,    // res
                                          const VEC_ZNX_BIG* a, uint64_t a_size,  // a
                                          const int64_t* b, uint64_t b_size, uint64_t b_sl  // b
);
EXPORT void fft64_vec_znx_big_sub_small_a(const MODULE* module,  // N
                                          VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                          const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                          const VEC_ZNX_BIG* b, uint64_t b_size  // b
);
EXPORT void fft64_vec_znx_big_sub_small2(const MODULE* module,  // N
                                         VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                         const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                         const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
);

/** @brief sets res = a . X^p */
EXPORT void fft64_vec_znx_big_rotate(const MODULE* module,  // N
                                     int64_t p,             // rotation value
                                     VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                     const VEC_ZNX_BIG* a, uint64_t a_size  // a
);

/** @brief sets res = a(X^p) */
EXPORT void fft64_vec_znx_big_automorphism(const MODULE* module,  // N
                                           int64_t p,             // X -> X^p
                                           VEC_ZNX_BIG* res, uint64_t res_size,  // res
                                           const VEC_ZNX_BIG* a, uint64_t a_size  // a
);

/** @brief prepares a svp polynomial */
EXPORT void fft64_svp_prepare_ref(const MODULE* module,  // N
                                  SVP_PPOL* ppol,        // output
                                  const int64_t* pol     // a
);

/** @brief res = a * b : small integer polynomial product */
EXPORT void fft64_znx_small_single_product(const MODULE* module,  // N
                                           int64_t* res,      // output
                                           const int64_t* a,  // a
                                           const int64_t* b,  // b
                                           uint8_t* tmp);

/** @brief tmp bytes required for znx_small_single_product */
EXPORT uint64_t fft64_znx_small_single_product_tmp_bytes(const MODULE* module);

/** @brief prepares a vmp matrix (contiguous row-major version) */
EXPORT void fft64_vmp_prepare_contiguous_ref(const MODULE* module,  // N
                                             VMP_PMAT* pmat,        // output
                                             const int64_t* mat, uint64_t nrows, uint64_t ncols,  // a
                                             uint8_t* tmp_space  // scratch space
);

/** @brief prepares a vmp matrix (contiguous row-major version) */
EXPORT void fft64_vmp_prepare_contiguous_avx(const MODULE* module,  // N
                                             VMP_PMAT* pmat,        // output
                                             const int64_t* mat, uint64_t nrows, uint64_t ncols,  // a
                                             uint8_t* tmp_space  // scratch space
);

/** @brief minimal scratch space byte-size required for the vmp_prepare function */
EXPORT uint64_t fft64_vmp_prepare_contiguous_tmp_bytes(const MODULE* module,  // N
                                                       uint64_t nrows, uint64_t ncols);

/** @brief applies a vmp product (result in DFT space) */
EXPORT void fft64_vmp_apply_dft_ref(const MODULE* module,  // N
                                    VEC_ZNX_DFT* res, uint64_t res_size,  // res
                                    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
                                    uint8_t* tmp_space  // scratch space
);

/** @brief applies a vmp product (result in DFT space) */
EXPORT void fft64_vmp_apply_dft_avx(const MODULE* module,  // N
                                    VEC_ZNX_DFT* res, uint64_t res_size,  // res
                                    const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
                                    const VMP_PMAT* pmat, uint64_t nrows, uint64_t ncols,  // prep matrix
                                    uint8_t* tmp_space  // scratch space
);

/** @brief this inner function could be very handy */
EXPORT void fft64_vmp_apply_dft_to_dft_ref(const MODULE* module,  // N
                                           VEC_ZNX_DFT* res, const uint64_t res_size,  // res
                                           const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
                                           const VMP_PMAT* pmat, const uint64_t nrows,
                                           const uint64_t ncols,  // prep matrix
                                           uint8_t* tmp_space  // scratch space (a_size*sizeof(reim4) bytes)
);

/** @brief this inner function could be very handy */
EXPORT void fft64_vmp_apply_dft_to_dft_avx(const MODULE* module,  // N
                                           VEC_ZNX_DFT* res, const uint64_t res_size,  // res
                                           const VEC_ZNX_DFT* a_dft, uint64_t a_size,  // a
                                           const VMP_PMAT* pmat, const uint64_t nrows,
                                           const uint64_t ncols,  // prep matrix
                                           uint8_t* tmp_space  // scratch space (a_size*sizeof(reim4) bytes)
);

/** @brief minimal size of the tmp_space */
EXPORT uint64_t fft64_vmp_apply_dft_tmp_bytes(const MODULE* module,  // N
                                              uint64_t res_size,  // res
                                              uint64_t a_size,    // a
                                              uint64_t nrows, uint64_t ncols  // prep matrix
);

/** @brief minimal size of the tmp_space */
EXPORT uint64_t fft64_vmp_apply_dft_to_dft_tmp_bytes(const MODULE* module,  // N
                                                     uint64_t res_size,  // res
                                                     uint64_t a_size,    // a
                                                     uint64_t nrows, uint64_t ncols  // prep matrix
);

#endif  // SPQLIOS_VEC_ZNX_ARITHMETIC_PRIVATE_H
