Skip to content

Commit 65064c4

Browse files
authored
Merge pull request #31 from tfhe/ms/normalize_base2k_tmp_bytes_change
remove unnecessary argument for normalize_base2k tmp bytes funcs
2 parents ad1abe7 + c1de3cb commit 65064c4

6 files changed

+19
-47
lines changed

spqlios/arithmetic/vec_znx.c

+5-15
Original file line numberDiff line numberDiff line change
@@ -75,13 +75,9 @@ EXPORT void vec_znx_normalize_base2k(const MODULE* module,
7575
tmp_space);
7676
}
7777

78-
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, // N
79-
uint64_t res_size, // res size
80-
uint64_t inp_size // inp size
78+
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module // N
8179
) {
82-
return module->func.vec_znx_normalize_base2k_tmp_bytes(module, // N
83-
res_size, // res size
84-
inp_size // inp size
80+
return module->func.vec_znx_normalize_base2k_tmp_bytes(module // N
8581
);
8682
}
8783

@@ -247,26 +243,20 @@ EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module,
247243
}
248244
}
249245

250-
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, // N
251-
uint64_t res_size, // res size
252-
uint64_t inp_size // inp size
246+
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module // N
253247
) {
254248
const uint64_t nn = module->nn;
255249
return nn * sizeof(int64_t);
256250
}
257251

258252
// alias have to be defined in this unit: do not move
259253
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes( //
260-
const MODULE* module, // N
261-
uint64_t res_size, // res size
262-
uint64_t inp_size // inp size
254+
const MODULE* module // N
263255
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
264256

265257
// alias have to be defined in this unit: do not move
266258
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes( //
267-
const MODULE* module, // N
268-
uint64_t res_size, // res size
269-
uint64_t inp_size // inp size
259+
const MODULE* module // N
270260
) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
271261

272262
EXPORT void std_free(void* addr) { free(addr); }

spqlios/arithmetic/vec_znx_arithmetic.h

+3-9
Original file line numberDiff line numberDiff line change
@@ -107,9 +107,7 @@ EXPORT void vec_znx_normalize_base2k(const MODULE* module,
107107
);
108108

109109
/** @brief returns the minimal byte length of scratch space for vec_znx_normalize_base2k */
110-
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module, // N
111-
uint64_t res_size, // res size
112-
uint64_t inp_size // inp size
110+
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module // N
113111
);
114112

115113
/** @brief sets res = a . X^p */
@@ -234,9 +232,7 @@ EXPORT void vec_znx_big_normalize_base2k(const MODULE* module,
234232
);
235233

236234
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
237-
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, // N
238-
uint64_t res_size, // res size
239-
uint64_t inp_size // inp size
235+
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N
240236
);
241237

242238
/** @brief apply a svp product, result = ppol * a, presented in DFT space */
@@ -257,9 +253,7 @@ EXPORT void vec_znx_big_range_normalize_base2k(
257253

258254
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
259255
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
260-
const MODULE* module, // N
261-
uint64_t res_size, // res size
262-
uint64_t inp_size // inp size
256+
const MODULE* module // N
263257
);
264258

265259
/** @brief sets res = a . X^p */

spqlios/arithmetic/vec_znx_arithmetic_private.h

+3-9
Original file line numberDiff line numberDiff line change
@@ -215,9 +215,7 @@ EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module,
215215
uint8_t* tmp_space // scratch space
216216
);
217217

218-
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module, // N
219-
uint64_t res_size, // res size
220-
uint64_t inp_size // inp size
218+
EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module // N
221219
);
222220

223221
EXPORT void vec_znx_rotate_ref(const MODULE* module, // N
@@ -290,9 +288,7 @@ EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module,
290288
);
291289

292290
/** @brief returns the minimal byte length of scratch space for vec_znx_big_normalize_base2k */
293-
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, // N
294-
uint64_t res_size, // res size
295-
uint64_t inp_size // inp size
291+
EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N
296292

297293
);
298294

@@ -306,9 +302,7 @@ EXPORT void fft64_vec_znx_big_range_normalize_base2k(const MODULE* module,
306302
);
307303

308304
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
309-
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module, // N
310-
uint64_t res_size, // res size
311-
uint64_t inp_size // inp size
305+
EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(const MODULE* module // N
312306
);
313307

314308
EXPORT void fft64_vec_znx_dft(const MODULE* module, // N

spqlios/arithmetic/vec_znx_big.c

+4-10
Original file line numberDiff line numberDiff line change
@@ -209,13 +209,9 @@ EXPORT void vec_znx_big_normalize_base2k(const MODULE* module,
209209
tmp_space);
210210
}
211211

212-
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module, // N
213-
uint64_t res_size, // res size
214-
uint64_t inp_size // inp size
212+
EXPORT uint64_t vec_znx_big_normalize_base2k_tmp_bytes(const MODULE* module // N
215213
) {
216-
return module->func.vec_znx_big_normalize_base2k_tmp_bytes(module, // N
217-
res_size, // res size
218-
inp_size // inp size
214+
return module->func.vec_znx_big_normalize_base2k_tmp_bytes(module // N
219215
);
220216
}
221217

@@ -233,11 +229,9 @@ EXPORT void vec_znx_big_range_normalize_base2k(
233229

234230
/** @brief returns the minimal byte length of scratch space for vec_znx_big_range_normalize_base2k */
235231
EXPORT uint64_t vec_znx_big_range_normalize_base2k_tmp_bytes( //
236-
const MODULE* module, // N
237-
uint64_t res_size, // res size
238-
uint64_t inp_size // inp size
232+
const MODULE* module // N
239233
) {
240-
return module->func.vec_znx_big_range_normalize_base2k_tmp_bytes(module, res_size, inp_size);
234+
return module->func.vec_znx_big_range_normalize_base2k_tmp_bytes(module);
241235
}
242236

243237
EXPORT void fft64_vec_znx_big_normalize_base2k(const MODULE* module, // N

test/spqlios_vec_znx_big_test.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ static void test_vec_znx_big_normalize(VEC_ZNX_BIG_NORMALIZE_BASE2K_F normalize,
152152
uint64_t r_sl = n + 3;
153153
def_rand_big(a, n, sa);
154154
znx_vec_i64_layout r(n, sr, r_sl);
155-
std::vector<uint8_t> tmp_space(normalize_tmp_bytes(module, sr, sa));
155+
std::vector<uint8_t> tmp_space(normalize_tmp_bytes(module));
156156
normalize(module, k, r.data(), sr, r_sl, a.data, sa, tmp_space.data());
157157
}
158158
}
@@ -189,7 +189,7 @@ static void test_vec_znx_big_range_normalize( //
189189
znx_vec_i64_layout r(n, sr, r_sl);
190190
znx_vec_i64_layout r2(n, sr, r_sl);
191191
// tmp_space is large-enough for both
192-
std::vector<uint8_t> tmp_space(normalize_tmp_bytes(module, sr, sa));
192+
std::vector<uint8_t> tmp_space(normalize_tmp_bytes(module));
193193
normalize(module, k, r.data(), sr, r_sl, a.data, a_start, a_end, a_step, tmp_space.data());
194194
fft64_vec_znx_big_normalize_base2k(module, k, r2.data(), sr, r_sl, aextr.data, range_size, tmp_space.data());
195195
for (uint64_t i = 0; i < sr; ++i) {

test/spqlios_vec_znx_test.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ void test_vec_znx_normalize_outplace(ACTUAL_FCN test_normalize, TMP_BYTES_FNC tm
472472
uint64_t b_sl = uniform_u64_bits(3) * 5 + n;
473473
znx_vec_i64_layout lb(n, sb, b_sl);
474474

475-
const uint64_t tmp_size = tmp_bytes(mod, sa, sa);
475+
const uint64_t tmp_size = tmp_bytes(mod);
476476
uint8_t* tmp = new uint8_t[tmp_size];
477477
test_normalize(mod, // N
478478
base_k, // base_k
@@ -517,7 +517,7 @@ void test_vec_znx_normalize_inplace(ACTUAL_FCN test_normalize, TMP_BYTES_FNC tmp
517517
}
518518
vec_poly_normalize(base_k, la_norm);
519519

520-
const uint64_t tmp_size = tmp_bytes(mod, sa, sa);
520+
const uint64_t tmp_size = tmp_bytes(mod);
521521
uint8_t* tmp = new uint8_t[tmp_size];
522522
test_normalize(mod, // N
523523
base_k, // base_k

0 commit comments

Comments
 (0)