Skip to content

Commit 0d3127c

Browse files
authored
Merge pull request #19 from tfhe/ng/utests
module-api + some unittests
2 parents 572b104 + 20bfc45 commit 0d3127c

File tree

4 files changed

+190
-2
lines changed

4 files changed

+190
-2
lines changed

spqlios/CMakeLists.txt

+1-2
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ set(SRCS_GENERIC
3131
arithmetic/scalar_vector_product.c
3232
arithmetic/vec_znx_big.c
3333
arithmetic/znx_small.c
34+
arithmetic/module_api.c
3435
)
3536
# C or assembly source files compiled only on x86 targets
3637
set(SRCS_X86
@@ -154,8 +155,6 @@ if (ENABLE_SPQLIOS_F128)
154155
endif (ENABLE_SPQLIOS_F128)
155156

156157
add_library(libspqlios-static STATIC ${SPQLIOSSOURCES})
157-
add_library(spqlios SHARED ${SPQLIOSSOURCES}
158-
arithmetic/vector_matrix_product.c)
159158
add_library(libspqlios SHARED ${SPQLIOSSOURCES})
160159
set_property(TARGET libspqlios-static PROPERTY POSITION_INDEPENDENT_CODE ON)
161160
set_property(TARGET libspqlios PROPERTY OUTPUT_NAME spqlios)

spqlios/arithmetic/module_api.c

+162
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,162 @@
1+
#include <string.h>
2+
3+
#include "vec_znx_arithmetic_private.h"
4+
5+
static void fill_generic_virtual_table(MODULE* module) {
6+
// TODO add default ref handler here
7+
module->func.vec_znx_zero = vec_znx_zero_ref;
8+
module->func.vec_znx_copy = vec_znx_copy_ref;
9+
module->func.vec_znx_negate = vec_znx_negate_ref;
10+
module->func.vec_znx_add = vec_znx_add_ref;
11+
module->func.vec_znx_sub = vec_znx_sub_ref;
12+
module->func.vec_znx_rotate = vec_znx_rotate_ref;
13+
module->func.vec_znx_automorphism = vec_znx_automorphism_ref;
14+
module->func.vec_znx_normalize_base2k = vec_znx_normalize_base2k_ref;
15+
module->func.vec_znx_normalize_base2k_tmp_bytes = vec_znx_normalize_base2k_tmp_bytes_ref;
16+
if (CPU_SUPPORTS("avx2")) {
17+
// TODO add avx handlers here
18+
module->func.vec_znx_negate = vec_znx_negate_avx;
19+
module->func.vec_znx_add = vec_znx_add_avx;
20+
module->func.vec_znx_sub = vec_znx_sub_avx;
21+
}
22+
}
23+
24+
static void fill_fft64_virtual_table(MODULE* module) {
25+
// TODO add default ref handler here
26+
// module->func.vec_znx_dft = ...;
27+
module->func.vmp_pmat_alloc = fft64_vmp_pmat_alloc;
28+
module->func.vec_znx_dft_alloc = fft64_vec_znx_dft_alloc;
29+
module->func.vec_znx_big_alloc = fft64_vec_znx_big_alloc;
30+
module->func.svp_ppol_alloc = fft64_svp_ppol_alloc;
31+
module->func.vec_znx_big_normalize_base2k = fft64_vec_znx_big_normalize_base2k;
32+
module->func.vec_znx_big_normalize_base2k_tmp_bytes = fft64_vec_znx_big_normalize_base2k_tmp_bytes;
33+
module->func.vec_znx_big_range_normalize_base2k = fft64_vec_znx_big_range_normalize_base2k;
34+
module->func.vec_znx_big_range_normalize_base2k_tmp_bytes = fft64_vec_znx_big_range_normalize_base2k_tmp_bytes;
35+
module->func.vec_znx_dft = fft64_vec_znx_dft;
36+
module->func.vec_znx_idft = fft64_vec_znx_idft;
37+
module->func.vec_znx_idft_tmp_bytes = fft64_vec_znx_idft_tmp_bytes;
38+
module->func.vec_znx_idft_tmp_a = fft64_vec_znx_idft_tmp_a;
39+
module->func.vec_znx_big_add = fft64_vec_znx_big_add;
40+
module->func.vec_znx_big_add_small = fft64_vec_znx_big_add_small;
41+
module->func.vec_znx_big_add_small2 = fft64_vec_znx_big_add_small2;
42+
module->func.vec_znx_big_sub = fft64_vec_znx_big_sub;
43+
module->func.vec_znx_big_sub_small_a = fft64_vec_znx_big_sub_small_a;
44+
module->func.vec_znx_big_sub_small_b = fft64_vec_znx_big_sub_small_b;
45+
module->func.vec_znx_big_sub_small2 = fft64_vec_znx_big_sub_small2;
46+
module->func.vec_znx_big_rotate = fft64_vec_znx_big_rotate;
47+
module->func.vec_znx_big_automorphism = fft64_vec_znx_big_automorphism;
48+
module->func.svp_prepare = fft64_svp_prepare_ref;
49+
module->func.svp_apply_dft = fft64_svp_apply_dft_ref;
50+
module->func.znx_small_single_product = fft64_znx_small_single_product;
51+
module->func.znx_small_single_product_tmp_bytes = fft64_znx_small_single_product_tmp_bytes;
52+
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_ref;
53+
module->func.vmp_prepare_contiguous_tmp_bytes = fft64_vmp_prepare_contiguous_tmp_bytes;
54+
module->func.vmp_apply_dft = fft64_vmp_apply_dft_ref;
55+
module->func.vmp_apply_dft_tmp_bytes = fft64_vmp_apply_dft_tmp_bytes;
56+
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_ref;
57+
module->func.vmp_apply_dft_to_dft_tmp_bytes = fft64_vmp_apply_dft_to_dft_tmp_bytes;
58+
if (CPU_SUPPORTS("avx2")) {
59+
// TODO add avx handlers here
60+
// TODO: enable when avx implementation is done
61+
module->func.vmp_prepare_contiguous = fft64_vmp_prepare_contiguous_avx;
62+
module->func.vmp_apply_dft = fft64_vmp_apply_dft_avx;
63+
module->func.vmp_apply_dft_to_dft = fft64_vmp_apply_dft_to_dft_avx;
64+
}
65+
}
66+
67+
static void fill_ntt120_virtual_table(MODULE* module) {
68+
// TODO add default ref handler here
69+
// module->func.vec_znx_dft = ...;
70+
if (CPU_SUPPORTS("avx2")) {
71+
// TODO add avx handlers here
72+
module->func.vec_znx_dft = ntt120_vec_znx_dft_avx;
73+
module->func.vec_znx_idft = ntt120_vec_znx_idft_avx;
74+
module->func.vec_znx_idft_tmp_bytes = ntt120_vec_znx_idft_tmp_bytes_avx;
75+
module->func.vec_znx_idft_tmp_a = ntt120_vec_znx_idft_tmp_a_avx;
76+
}
77+
}
78+
79+
static void fill_virtual_table(MODULE* module) {
80+
fill_generic_virtual_table(module);
81+
switch (module->module_type) {
82+
case FFT64:
83+
fill_fft64_virtual_table(module);
84+
break;
85+
case NTT120:
86+
fill_ntt120_virtual_table(module);
87+
break;
88+
default:
89+
NOT_SUPPORTED(); // invalid type
90+
}
91+
}
92+
93+
static void fill_fft64_precomp(MODULE* module) {
94+
// fill any necessary precomp stuff
95+
module->mod.fft64.p_conv = new_reim_from_znx64_precomp(module->m, 50);
96+
module->mod.fft64.p_fft = new_reim_fft_precomp(module->m, 0);
97+
module->mod.fft64.p_reim_to_znx = new_reim_to_znx64_precomp(module->m, module->m, 63);
98+
module->mod.fft64.p_ifft = new_reim_ifft_precomp(module->m, 0);
99+
module->mod.fft64.p_addmul = new_reim_fftvec_addmul_precomp(module->m);
100+
module->mod.fft64.mul_fft = new_reim_fftvec_mul_precomp(module->m);
101+
}
102+
static void fill_ntt120_precomp(MODULE* module) {
103+
// fill any necessary precomp stuff
104+
if (CPU_SUPPORTS("avx2")) {
105+
module->mod.q120.p_ntt = q120_new_ntt_bb_precomp(module->nn);
106+
module->mod.q120.p_intt = q120_new_intt_bb_precomp(module->nn);
107+
}
108+
}
109+
110+
static void fill_module_precomp(MODULE* module) {
111+
switch (module->module_type) {
112+
case FFT64:
113+
fill_fft64_precomp(module);
114+
break;
115+
case NTT120:
116+
fill_ntt120_precomp(module);
117+
break;
118+
default:
119+
NOT_SUPPORTED(); // invalid type
120+
}
121+
}
122+
123+
static void fill_module(MODULE* module, uint64_t nn, MODULE_TYPE mtype) {
124+
// init to zero to ensure that any non-initialized field bug is detected
125+
// by at least a "proper" segfault
126+
memset(module, 0, sizeof(MODULE));
127+
module->module_type = mtype;
128+
module->nn = nn;
129+
module->m = nn >> 1;
130+
fill_module_precomp(module);
131+
fill_virtual_table(module);
132+
}
133+
134+
EXPORT MODULE* new_module_info(uint64_t N, MODULE_TYPE mtype) {
135+
MODULE* m = (MODULE*)malloc(sizeof(MODULE));
136+
fill_module(m, N, mtype);
137+
return m;
138+
}
139+
140+
EXPORT void delete_module_info(MODULE* mod) {
141+
switch (mod->module_type) {
142+
case FFT64:
143+
free(mod->mod.fft64.p_conv);
144+
free(mod->mod.fft64.p_fft);
145+
free(mod->mod.fft64.p_ifft);
146+
free(mod->mod.fft64.p_reim_to_znx);
147+
free(mod->mod.fft64.mul_fft);
148+
free(mod->mod.fft64.p_addmul);
149+
break;
150+
case NTT120:
151+
if (CPU_SUPPORTS("avx2")) {
152+
q120_del_ntt_bb_precomp(mod->mod.q120.p_ntt);
153+
q120_del_intt_bb_precomp(mod->mod.q120.p_intt);
154+
}
155+
break;
156+
default:
157+
break;
158+
}
159+
free(mod);
160+
}
161+
162+
EXPORT uint64_t module_get_n(const MODULE* module) { return module->nn; }

test/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ add_executable(spqlios-test
5959
spqlios_reim_conversions_test.cpp
6060
spqlios_q120_ntt_test.cpp
6161
spqlios_q120_arithmetic_test.cpp
62+
spqlios_znx_small_test.cpp
6263
)
6364
target_link_libraries(spqlios-test spqlios-testlib libspqlios ${gtest_libs})
6465
target_include_directories(spqlios-test PRIVATE ${test_incs})

test/spqlios_znx_small_test.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
#include <gtest/gtest.h>
2+
3+
#include "../spqlios/arithmetic/vec_znx_arithmetic_private.h"
4+
#include "testlib/negacyclic_polynomial.h"
5+
6+
static void test_znx_small_single_product(ZNX_SMALL_SINGLE_PRODUCT_F product,
7+
ZNX_SMALL_SINGLE_PRODUCT_TMP_BYTES_F product_tmp_bytes) {
8+
for (const uint64_t nn : {2, 4, 8, 64}) {
9+
MODULE* module = new_module_info(nn, FFT64);
10+
znx_i64 a = znx_i64::random_log2bound(nn, 20);
11+
znx_i64 b = znx_i64::random_log2bound(nn, 20);
12+
znx_i64 expect = naive_product(a, b);
13+
znx_i64 actual(nn);
14+
std::vector<uint8_t> tmp(znx_small_single_product_tmp_bytes(module));
15+
fft64_znx_small_single_product(module, actual.data(), a.data(), b.data(), tmp.data());
16+
ASSERT_EQ(actual, expect) << actual.get_coeff(0) << " vs. " << expect.get_coeff(0);
17+
delete_module_info(module);
18+
}
19+
}
20+
21+
TEST(znx_small, fft64_znx_small_single_product) {
22+
test_znx_small_single_product(fft64_znx_small_single_product, fft64_znx_small_single_product_tmp_bytes);
23+
}
24+
TEST(znx_small, znx_small_single_product) {
25+
test_znx_small_single_product(znx_small_single_product, znx_small_single_product_tmp_bytes);
26+
}

0 commit comments

Comments
 (0)