From 7fb55fbf027e590a140f890605b4dc0acecb2493 Mon Sep 17 00:00:00 2001
From: Sandra Guasch <sandra.guasch@sbt-ext.com>
Date: Fri, 24 May 2024 13:27:43 +0000
Subject: [PATCH 1/2] vec_znx operations

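A minimal usage sketch of the new coefficient-level primitives (illustrative
only; the include path and the base-2^k choice are assumptions, not fixed by
this patch):

    #include <stdio.h>
    #include "spqlios/coeffs/coeffs_arithmetic.h"

    int main(void) {
      enum { NN = 8 };  // ring dimension (a power of two)
      int64_t a[NN] = {1, 2, 3, 4, 5, 6, 7, 8};
      int64_t b[NN] = {8, 7, 6, 5, 4, 3, 2, 1};
      int64_t sum[NN], rot[NN], digit[NN], carry[NN];

      znx_add_i64_ref(NN, sum, a, b);  // sum = a + b, coefficient-wise
      znx_rotate_i64(NN, 3, rot, a);   // rot = X^3 * a mod (X^NN + 1)

      // split sum into a centered base-2^4 digit and a carry limb, so that
      // sum[i] == digit[i] + carry[i] * 2^4 with digit[i] in [-8, 8)
      znx_normalize(NN, 4, digit, carry, sum, NULL);

      for (int i = 0; i < NN; ++i) printf("%ld ", (long)digit[i]);
      printf("\n");
      return 0;
    }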
---
 spqlios/arithmetic/vec_znx.c           | 332 ++++++++++++++++++
 spqlios/arithmetic/vec_znx_avx.c       | 103 ++++++
 spqlios/coeffs/coeffs_arithmetic.c     | 461 +++++++++++++++++++++++++
 spqlios/coeffs/coeffs_arithmetic.h     |  73 ++++
 spqlios/coeffs/coeffs_arithmetic_avx.c |  84 +++++
 5 files changed, 1053 insertions(+)
 create mode 100644 spqlios/arithmetic/vec_znx.c
 create mode 100644 spqlios/arithmetic/vec_znx_avx.c
 create mode 100644 spqlios/coeffs/coeffs_arithmetic.c
 create mode 100644 spqlios/coeffs/coeffs_arithmetic.h
 create mode 100644 spqlios/coeffs/coeffs_arithmetic_avx.c

diff --git a/spqlios/arithmetic/vec_znx.c b/spqlios/arithmetic/vec_znx.c
new file mode 100644
index 0000000..af38265
--- /dev/null
+++ b/spqlios/arithmetic/vec_znx.c
@@ -0,0 +1,332 @@
+#include <assert.h>
+#include <math.h>
+#include <stdint.h>
+#include <stdlib.h>
+#include <string.h>
+
+#include "../coeffs/coeffs_arithmetic.h"
+#include "../q120/q120_arithmetic.h"
+#include "../q120/q120_ntt.h"
+#include "../reim/reim_fft_internal.h"
+#include "../reim4/reim4_arithmetic.h"
+#include "vec_znx_arithmetic.h"
+#include "vec_znx_arithmetic_private.h"
+
+// general function (virtual dispatch)
+
+EXPORT void vec_znx_add(const MODULE* module,                              // N
+                        int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                        const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
+                        const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
+) {
+  module->func.vec_znx_add(module,                 // N
+                           res, res_size, res_sl,  // res
+                           a, a_size, a_sl,        // a
+                           b, b_size, b_sl         // b
+  );
+}
+
+EXPORT void vec_znx_sub(const MODULE* module,                              // N
+                        int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                        const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
+                        const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
+) {
+  module->func.vec_znx_sub(module,                 // N
+                           res, res_size, res_sl,  // res
+                           a, a_size, a_sl,        // a
+                           b, b_size, b_sl         // b
+  );
+}
+
+EXPORT void vec_znx_rotate(const MODULE* module,                              // N
+                           const int64_t p,                                   // rotation value
+                           int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                           const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  module->func.vec_znx_rotate(module,                 // N
+                              p,                      // p
+                              res, res_size, res_sl,  // res
+                              a, a_size, a_sl         // a
+  );
+}
+
+EXPORT void vec_znx_automorphism(const MODULE* module,                              // N
+                                 const int64_t p,                                   // X->X^p
+                                 int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                                 const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  module->func.vec_znx_automorphism(module,                 // N
+                                    p,                      // p
+                                    res, res_size, res_sl,  // res
+                                    a, a_size, a_sl         // a
+  );
+}
+
+EXPORT void vec_znx_normalize_base2k(const MODULE* module,                              // N
+                                     uint64_t log2_base2k,                              // output base 2^K
+                                     int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                                     const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
+                                     uint8_t* tmp_space  // scratch space of size >= N*8 bytes
+) {
+  module->func.vec_znx_normalize_base2k(module,                 // N
+                                        log2_base2k,            // log2_base2k
+                                        res, res_size, res_sl,  // res
+                                        a, a_size, a_sl,        // a
+                                        tmp_space);
+}
+
+EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes(const MODULE* module,  // N
+                                                   uint64_t res_size,     // res size
+                                                   uint64_t inp_size      // inp size
+) {
+  return module->func.vec_znx_normalize_base2k_tmp_bytes(module,    // N
+                                                         res_size,  // res size
+                                                         inp_size   // inp size
+  );
+}
+
+// specialized function (ref)
+
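+// Note: vector operands may have different numbers of limbs; missing limbs are
+// treated as zero, and the result is truncated or zero-extended to res_size.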
+EXPORT void vec_znx_add_ref(const MODULE* module,                              // N
+                            int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                            const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
+                            const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
+) {
+  const uint64_t nn = module->nn;
+  if (a_size <= b_size) {
+    const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
+    const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
+    // add up to the smallest dimension
+    for (uint64_t i = 0; i < sum_idx; ++i) {
+      znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
+    }
+    // then copy to the largest dimension
+    for (uint64_t i = sum_idx; i < copy_idx; ++i) {
+      znx_copy_i64_ref(nn, res + i * res_sl, b + i * b_sl);
+    }
+    // then extend with zeros
+    for (uint64_t i = copy_idx; i < res_size; ++i) {
+      znx_zero_i64_ref(nn, res + i * res_sl);
+    }
+  } else {
+    const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
+    const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
+    // add up to the smallest dimension
+    for (uint64_t i = 0; i < sum_idx; ++i) {
+      znx_add_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
+    }
+    // then copy to the largest dimension
+    for (uint64_t i = sum_idx; i < copy_idx; ++i) {
+      znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
+    }
+    // then extend with zeros
+    for (uint64_t i = copy_idx; i < res_size; ++i) {
+      znx_zero_i64_ref(nn, res + i * res_sl);
+    }
+  }
+}
+
+EXPORT void vec_znx_sub_ref(const MODULE* module,                              // N
+                            int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                            const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
+                            const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
+) {
+  const uint64_t nn = module->nn;
+  if (a_size <= b_size) {
+    const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
+    const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
+    // subtract up to the smallest dimension
+    for (uint64_t i = 0; i < sub_idx; ++i) {
+      znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
+    }
+    // then negate to the largest dimension
+    for (uint64_t i = sub_idx; i < copy_idx; ++i) {
+      znx_negate_i64_ref(nn, res + i * res_sl, b + i * b_sl);
+    }
+    // then extend with zeros
+    for (uint64_t i = copy_idx; i < res_size; ++i) {
+      znx_zero_i64_ref(nn, res + i * res_sl);
+    }
+  } else {
+    const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
+    const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
+    // subtract up to the smallest dimension
+    for (uint64_t i = 0; i < sub_idx; ++i) {
+      znx_sub_i64_ref(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
+    }
+    // then copy to the largest dimension
+    for (uint64_t i = sub_idx; i < copy_idx; ++i) {
+      znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
+    }
+    // then extend with zeros
+    for (uint64_t i = copy_idx; i < res_size; ++i) {
+      znx_zero_i64_ref(nn, res + i * res_sl);
+    }
+  }
+}
+
+EXPORT void vec_znx_rotate_ref(const MODULE* module,                              // N
+                               const int64_t p,                                   // rotation value
+                               int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                               const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  const uint64_t nn = module->nn;
+
+  const uint64_t rot_end_idx = res_size < a_size ? res_size : a_size;
+  // rotate up to the smallest dimension
+  for (uint64_t i = 0; i < rot_end_idx; ++i) {
+    int64_t* res_ptr = res + i * res_sl;
+    const int64_t* a_ptr = a + i * a_sl;
+    if (res_ptr == a_ptr) {
+      znx_rotate_inplace_i64(nn, p, res_ptr);
+    } else {
+      znx_rotate_i64(nn, p, res_ptr, a_ptr);
+    }
+  }
+  // then extend with zeros
+  for (uint64_t i = rot_end_idx; i < res_size; ++i) {
+    znx_zero_i64_ref(nn, res + i * res_sl);
+  }
+}
+
+EXPORT void vec_znx_automorphism_ref(const MODULE* module,                              // N
+                                     const int64_t p,                                   // X->X^p
+                                     int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                                     const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  const uint64_t nn = module->nn;
+
+  const uint64_t auto_end_idx = res_size < a_size ? res_size : a_size;
+
+  for (uint64_t i = 0; i < auto_end_idx; ++i) {
+    int64_t* res_ptr = res + i * res_sl;
+    const int64_t* a_ptr = a + i * a_sl;
+    if (res_ptr == a_ptr) {
+      znx_automorphism_inplace_i64(nn, p, res_ptr);
+    } else {
+      znx_automorphism_i64(nn, p, res_ptr, a_ptr);
+    }
+  }
+  // then extend with zeros
+  for (uint64_t i = auto_end_idx; i < res_size; ++i) {
+    znx_zero_i64_ref(nn, res + i * res_sl);
+  }
+}
+
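+// carries propagate from limb a_size-1 down to limb 0; tmp_space holds the
+// running carry between two consecutive limbs.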
+EXPORT void vec_znx_normalize_base2k_ref(const MODULE* module,                              // N
+                                         uint64_t log2_base2k,                              // output base 2^K
+                                         int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                                         const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
+                                         uint8_t* tmp_space  // scratch space of size >= N*8 bytes
+) {
+  const uint64_t nn = module->nn;
+
+  // use the scratch space to hold the carry that is propagated between limbs
+  int64_t* cout = (int64_t*)tmp_space;
+  int64_t* cin = 0x0;
+
+  // propagate the carry through the input limbs that will not be stored in res
+  int64_t i = a_size - 1;
+  for (; i >= res_size; --i) {
+    znx_normalize(nn, log2_base2k, 0x0, cout, a + i * a_sl, cin);
+    cin = cout;
+  }
+
+  // propagate carry and normalize
+  for (; i >= 1; --i) {
+    znx_normalize(nn, log2_base2k, res + i * res_sl, cout, a + i * a_sl, cin);
+    cin = cout;
+  }
+
+  // normalize last limb
+  znx_normalize(nn, log2_base2k, res, 0x0, a, cin);
+
+  // extend result with zeros
+  for (uint64_t i = a_size; i < res_size; ++i) {
+    znx_zero_i64_ref(nn, res + i * res_sl);
+  }
+}
+
+EXPORT uint64_t vec_znx_normalize_base2k_tmp_bytes_ref(const MODULE* module,  // N
+                                                       uint64_t res_size,     // res size
+                                                       uint64_t inp_size      // inp size
+) {
+  const uint64_t nn = module->nn;
+  return nn * sizeof(int64_t);
+}
+
+// aliases have to be defined in this unit: do not move
+EXPORT uint64_t fft64_vec_znx_big_normalize_base2k_tmp_bytes(  //
+    const MODULE* module,                                      // N
+    uint64_t res_size,                                         // res size
+    uint64_t inp_size                                          // inp size
+    ) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
+
+// aliases have to be defined in this unit: do not move
+EXPORT uint64_t fft64_vec_znx_big_range_normalize_base2k_tmp_bytes(  //
+    const MODULE* module,                                            // N
+    uint64_t res_size,                                               // res size
+    uint64_t inp_size                                                // inp size
+    ) __attribute((alias("vec_znx_normalize_base2k_tmp_bytes_ref")));
+
+EXPORT void std_free(void* addr) { free(addr); }
+
+/** @brief sets res = 0 */
+EXPORT void vec_znx_zero(const MODULE* module,                             // N
+                         int64_t* res, uint64_t res_size, uint64_t res_sl  // res
+) {
+  module->func.vec_znx_zero(module, res, res_size, res_sl);
+}
+
+/** @brief sets res = a */
+EXPORT void vec_znx_copy(const MODULE* module,                              // N
+                         int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                         const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  module->func.vec_znx_copy(module, res, res_size, res_sl, a, a_size, a_sl);
+}
+
+/** @brief sets res = -a */
+EXPORT void vec_znx_negate(const MODULE* module,                              // N
+                           int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                           const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  module->func.vec_znx_negate(module, res, res_size, res_sl, a, a_size, a_sl);
+}
+
+EXPORT void vec_znx_zero_ref(const MODULE* module,                             // N
+                             int64_t* res, uint64_t res_size, uint64_t res_sl  // res
+) {
+  uint64_t nn = module->nn;
+  for (uint64_t i = 0; i < res_size; ++i) {
+    znx_zero_i64_ref(nn, res + i * res_sl);
+  }
+}
+
+EXPORT void vec_znx_copy_ref(const MODULE* module,                              // N
+                             int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                             const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  uint64_t nn = module->nn;
+  uint64_t smin = res_size < a_size ? res_size : a_size;
+  for (uint64_t i = 0; i < smin; ++i) {
+    znx_copy_i64_ref(nn, res + i * res_sl, a + i * a_sl);
+  }
+  for (uint64_t i = smin; i < res_size; ++i) {
+    znx_zero_i64_ref(nn, res + i * res_sl);
+  }
+}
+
+EXPORT void vec_znx_negate_ref(const MODULE* module,                              // N
+                               int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                               const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  uint64_t nn = module->nn;
+  uint64_t smin = res_size < a_size ? res_size : a_size;
+  for (uint64_t i = 0; i < smin; ++i) {
+    znx_negate_i64_ref(nn, res + i * res_sl, a + i * a_sl);
+  }
+  for (uint64_t i = smin; i < res_size; ++i) {
+    znx_zero_i64_ref(nn, res + i * res_sl);
+  }
+}
diff --git a/spqlios/arithmetic/vec_znx_avx.c b/spqlios/arithmetic/vec_znx_avx.c
new file mode 100644
index 0000000..100902d
--- /dev/null
+++ b/spqlios/arithmetic/vec_znx_avx.c
@@ -0,0 +1,103 @@
+#include <string.h>
+
+#include "../coeffs/coeffs_arithmetic.h"
+#include "../reim4/reim4_arithmetic.h"
+#include "vec_znx_arithmetic_private.h"
+
+// specialized function (ref)
+
+// Note: these functions do not have an avx variant.
+#define znx_copy_i64_avx znx_copy_i64_ref
+#define znx_zero_i64_avx znx_zero_i64_ref
+
+EXPORT void vec_znx_add_avx(const MODULE* module,                              // N
+                            int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                            const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
+                            const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
+) {
+  const uint64_t nn = module->nn;
+  if (a_size <= b_size) {
+    const uint64_t sum_idx = res_size < a_size ? res_size : a_size;
+    const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
+    // add up to the smallest dimension
+    for (uint64_t i = 0; i < sum_idx; ++i) {
+      znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
+    }
+    // then copy to the largest dimension
+    for (uint64_t i = sum_idx; i < copy_idx; ++i) {
+      znx_copy_i64_avx(nn, res + i * res_sl, b + i * b_sl);
+    }
+    // then extend with zeros
+    for (uint64_t i = copy_idx; i < res_size; ++i) {
+      znx_zero_i64_avx(nn, res + i * res_sl);
+    }
+  } else {
+    const uint64_t sum_idx = res_size < b_size ? res_size : b_size;
+    const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
+    // add up to the smallest dimension
+    for (uint64_t i = 0; i < sum_idx; ++i) {
+      znx_add_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
+    }
+    // then copy to the largest dimension
+    for (uint64_t i = sum_idx; i < copy_idx; ++i) {
+      znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
+    }
+    // then extend with zeros
+    for (uint64_t i = copy_idx; i < res_size; ++i) {
+      znx_zero_i64_avx(nn, res + i * res_sl);
+    }
+  }
+}
+
+EXPORT void vec_znx_sub_avx(const MODULE* module,                              // N
+                            int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                            const int64_t* a, uint64_t a_size, uint64_t a_sl,  // a
+                            const int64_t* b, uint64_t b_size, uint64_t b_sl   // b
+) {
+  const uint64_t nn = module->nn;
+  if (a_size <= b_size) {
+    const uint64_t sub_idx = res_size < a_size ? res_size : a_size;
+    const uint64_t copy_idx = res_size < b_size ? res_size : b_size;
+    // subtract up to the smallest dimension
+    for (uint64_t i = 0; i < sub_idx; ++i) {
+      znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
+    }
+    // then negate to the largest dimension
+    for (uint64_t i = sub_idx; i < copy_idx; ++i) {
+      znx_negate_i64_avx(nn, res + i * res_sl, b + i * b_sl);
+    }
+    // then extend with zeros
+    for (uint64_t i = copy_idx; i < res_size; ++i) {
+      znx_zero_i64_avx(nn, res + i * res_sl);
+    }
+  } else {
+    const uint64_t sub_idx = res_size < b_size ? res_size : b_size;
+    const uint64_t copy_idx = res_size < a_size ? res_size : a_size;
+    // subtract up to the smallest dimension
+    for (uint64_t i = 0; i < sub_idx; ++i) {
+      znx_sub_i64_avx(nn, res + i * res_sl, a + i * a_sl, b + i * b_sl);
+    }
+    // then copy to the largest dimension
+    for (uint64_t i = sub_idx; i < copy_idx; ++i) {
+      znx_copy_i64_avx(nn, res + i * res_sl, a + i * a_sl);
+    }
+    // then extend with zeros
+    for (uint64_t i = copy_idx; i < res_size; ++i) {
+      znx_zero_i64_avx(nn, res + i * res_sl);
+    }
+  }
+}
+
+EXPORT void vec_znx_negate_avx(const MODULE* module,                              // N
+                               int64_t* res, uint64_t res_size, uint64_t res_sl,  // res
+                               const int64_t* a, uint64_t a_size, uint64_t a_sl   // a
+) {
+  uint64_t nn = module->nn;
+  uint64_t smin = res_size < a_size ? res_size : a_size;
+  for (uint64_t i = 0; i < smin; ++i) {
+    znx_negate_i64_avx(nn, res + i * res_sl, a + i * a_sl);
+  }
+  for (uint64_t i = smin; i < res_size; ++i) {
+    znx_zero_i64_ref(nn, res + i * res_sl);
+  }
+}
diff --git a/spqlios/coeffs/coeffs_arithmetic.c b/spqlios/coeffs/coeffs_arithmetic.c
new file mode 100644
index 0000000..01d15db
--- /dev/null
+++ b/spqlios/coeffs/coeffs_arithmetic.c
@@ -0,0 +1,461 @@
+#include "coeffs_arithmetic.h"
+
+#include <assert.h>
+#include <string.h>
+
+/** res = a + b */
+EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
+  for (uint64_t i = 0; i < nn; ++i) {
+    res[i] = a[i] + b[i];
+  }
+}
+/** res = a - b */
+EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
+  for (uint64_t i = 0; i < nn; ++i) {
+    res[i] = a[i] - b[i];
+  }
+}
+
+EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) {
+  for (uint64_t i = 0; i < nn; ++i) {
+    res[i] = -a[i];
+  }
+}
+EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a) { memcpy(res, a, nn * sizeof(int64_t)); }
+
+EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res) { memset(res, 0, nn * sizeof(int64_t)); }
+
+EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in) {
+  uint64_t a = (-p) & (2 * nn - 1);  // a= (-p) (pos)mod (2*nn)
+
+  if (a < nn) {  // rotate to the left
+    uint64_t nma = nn - a;
+    // rotate first half
+    for (uint64_t j = 0; j < nma; j++) {
+      res[j] = in[j + a];
+    }
+    for (uint64_t j = nma; j < nn; j++) {
+      res[j] = -in[j - nma];
+    }
+  } else {
+    a -= nn;
+    uint64_t nma = nn - a;
+    for (uint64_t j = 0; j < nma; j++) {
+      res[j] = -in[j + a];
+    }
+    for (uint64_t j = nma; j < nn; j++) {
+      // wrapped indices: the two sign flips cancel out
+      res[j] = in[j - nma];
+    }
+  }
+}
+
+EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
+  uint64_t a = (-p) & (2 * nn - 1);  // a= (-p) (pos)mod (2*nn)
+
+  if (a < nn) {  // rotate to the left
+    uint64_t nma = nn - a;
+    // rotate first half
+    for (uint64_t j = 0; j < nma; j++) {
+      res[j] = in[j + a];
+    }
+    for (uint64_t j = nma; j < nn; j++) {
+      res[j] = -in[j - nma];
+    }
+  } else {
+    a -= nn;
+    uint64_t nma = nn - a;
+    for (uint64_t j = 0; j < nma; j++) {
+      res[j] = -in[j + a];
+    }
+    for (uint64_t j = nma; j < nn; j++) {
+      // wrapped indices: the two sign flips cancel out
+      res[j] = in[j - nma];
+    }
+  }
+}
+
+EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in) {
+  uint64_t a = (-p) & (2 * nn - 1);  // a= (-p) (pos)mod (2*nn)
+  if (a < nn) {                      // rotate to the left
+    uint64_t nma = nn - a;
+    // rotate first half
+    for (uint64_t j = 0; j < nma; j++) {
+      res[j] = in[j + a] - in[j];
+    }
+    for (uint64_t j = nma; j < nn; j++) {
+      res[j] = -in[j - nma] - in[j];
+    }
+  } else {
+    a -= nn;
+    uint64_t nma = nn - a;
+    for (uint64_t j = 0; j < nma; j++) {
+      res[j] = -in[j + a] - in[j];
+    }
+    for (uint64_t j = nma; j < nn; j++) {
+      // wrapped indices: the two sign flips cancel out
+      res[j] = in[j - nma] - in[j];
+    }
+  }
+}
+
+EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
+  uint64_t a = (-p) & (2 * nn - 1);  // a= (-p) (pos)mod (2*nn)
+  if (a < nn) {                      // rotate to the left
+    uint64_t nma = nn - a;
+    // rotate first half
+    for (uint64_t j = 0; j < nma; j++) {
+      res[j] = in[j + a] - in[j];
+    }
+    for (uint64_t j = nma; j < nn; j++) {
+      res[j] = -in[j - nma] - in[j];
+    }
+  } else {
+    a -= nn;
+    uint64_t nma = nn - a;
+    for (uint64_t j = 0; j < nma; j++) {
+      res[j] = -in[j + a] - in[j];
+    }
+    for (uint64_t j = nma; j < nn; j++) {
+      // wrapped indices: the two sign flips cancel out
+      res[j] = in[j - nma] - in[j];
+    }
+  }
+}
+
+// p odd, 0 < p < 2nn
+EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in) {
+  res[0] = in[0];
+  uint64_t a = 0;
+  uint64_t _2mn = 2 * nn - 1;
+  for (uint64_t i = 1; i < nn; i++) {
+    a = (a + p) & _2mn;  // i*p mod 2n
+    if (a < nn) {
+      res[a] = in[i];  // res[i*p mod 2n] = in[i]
+    } else {
+      res[a - nn] = -in[i];
+    }
+  }
+}
+
+EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in) {
+  res[0] = in[0];
+  uint64_t a = 0;
+  uint64_t _2mn = 2 * nn - 1;
+  for (uint64_t i = 1; i < nn; i++) {
+    a = (a + p) & _2mn;
+    if (a < nn) {
+      res[a] = in[i];  // res[i*p mod 2n] = in[i]
+    } else {
+      res[a - nn] = -in[i];
+    }
+  }
+}
+
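+// in-place variant of rnx_rotate_f64: the coefficient at index j moves to index
+// (j+p) mod 2nn, following the cycles of this permutation with O(1) extra memory;
+// a sign flip is applied whenever a coefficient wraps around X^nn = -1.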
+EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res) {
+  const uint64_t _2mn = 2 * nn - 1;
+  const uint64_t _mn = nn - 1;
+  uint64_t nb_modif = 0;
+  uint64_t j_start = 0;
+  while (nb_modif < nn) {
+    // follow the cycle that starts at j_start
+    uint64_t j = j_start;
+    double tmp1 = res[j];
+    do {
+      // find where the value should go, and with which sign
+      uint64_t new_j = (j + p) & _2mn;  // mod 2n to get the position and sign
+      uint64_t new_j_n = new_j & _mn;   // mod n to get just the position
+      // exchange this position with tmp1 (and take care of the sign)
+      double tmp2 = res[new_j_n];
+      res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
+      tmp1 = tmp2;
+      // move to the new location, and store the number of items modified
+      ++nb_modif;
+      j = new_j_n;
+    } while (j != j_start);
+    // move to the start of the next cycle:
+    // we need to find an index that has not been touched yet, and pick it as next j_start.
+    // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
+    ++j_start;
+  }
+}
+
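+// in-place variant of znx_rotate_i64 (same cycle-following strategy as the f64
+// version above)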
+EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
+  const uint64_t _2mn = 2 * nn - 1;
+  const uint64_t _mn = nn - 1;
+  uint64_t nb_modif = 0;
+  uint64_t j_start = 0;
+  while (nb_modif < nn) {
+    // follow the cycle that starts at j_start
+    uint64_t j = j_start;
+    int64_t tmp1 = res[j];
+    do {
+      // find where the value should go, and with which sign
+      uint64_t new_j = (j + p) & _2mn;  // mod 2n to get the position and sign
+      uint64_t new_j_n = new_j & _mn;   // mod n to get just the position
+      // exchange this position with tmp1 (and take care of the sign)
+      int64_t tmp2 = res[new_j_n];
+      res[new_j_n] = (new_j < nn) ? tmp1 : -tmp1;
+      tmp1 = tmp2;
+      // move to the new location, and store the number of items modified
+      ++nb_modif;
+      j = new_j_n;
+    } while (j != j_start);
+    // move to the start of the next cycle:
+    // we need to find an index that has not been touched yet, and pick it as next j_start.
+    // in practice, it is enough to do +1, because the group of rotations is cyclic and 1 is a generator.
+    ++j_start;
+  }
+}
+
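+// returns the centered representative of x mod 2^base_k, i.e. the low base_k bits
+// of x sign-extended to the range [-2^(base_k-1), 2^(base_k-1))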
+__always_inline int64_t get_base_k_digit(const int64_t x, const uint64_t base_k) {
+  return (x << (64 - base_k)) >> (64 - base_k);
+}
+
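+// returns the exact quotient (x - digit) / 2^base_k, i.e. the carry that remains
+// once the base-2^k digit has been removed from x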
+__always_inline int64_t get_base_k_carry(const int64_t x, const int64_t digit, const uint64_t base_k) {
+  return (x - digit) >> base_k;
+}
+
+EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
+                          const int64_t* carry_in) {
+  assert(in);
+  if (out != 0) {
+    if (carry_in != 0x0 && carry_out != 0x0) {
+      // with carry in and carry out is computed
+      for (uint64_t i = 0; i < nn; ++i) {
+        const int64_t x = in[i];
+        const int64_t cin = carry_in[i];
+
+        int64_t digit = get_base_k_digit(x, base_k);
+        int64_t carry = get_base_k_carry(x, digit, base_k);
+        int64_t digit_plus_cin = digit + cin;
+        int64_t y = get_base_k_digit(digit_plus_cin, base_k);
+        int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
+
+        out[i] = y;
+        carry_out[i] = cout;
+      }
+    } else if (carry_in != 0) {
+      // with carry in and carry out is dropped
+      for (uint64_t i = 0; i < nn; ++i) {
+        const int64_t x = in[i];
+        const int64_t cin = carry_in[i];
+
+        int64_t digit = get_base_k_digit(x, base_k);
+        int64_t digit_plus_cin = digit + cin;
+        int64_t y = get_base_k_digit(digit_plus_cin, base_k);
+
+        out[i] = y;
+      }
+
+    } else if (carry_out != 0) {
+      // no carry in and carry out is computed
+      for (uint64_t i = 0; i < nn; ++i) {
+        const int64_t x = in[i];
+
+        int64_t y = get_base_k_digit(x, base_k);
+        int64_t cout = get_base_k_carry(x, y, base_k);
+
+        out[i] = y;
+        carry_out[i] = cout;
+      }
+
+    } else {
+      // no carry in and carry out is dropped
+      for (uint64_t i = 0; i < nn; ++i) {
+        out[i] = get_base_k_digit(in[i], base_k);
+      }
+    }
+  } else {
+    assert(carry_out);
+    if (carry_in != 0x0) {
+      // with carry in and carry out is computed
+      for (uint64_t i = 0; i < nn; ++i) {
+        const int64_t x = in[i];
+        const int64_t cin = carry_in[i];
+
+        int64_t digit = get_base_k_digit(x, base_k);
+        int64_t carry = get_base_k_carry(x, digit, base_k);
+        int64_t digit_plus_cin = digit + cin;
+        int64_t y = get_base_k_digit(digit_plus_cin, base_k);
+        int64_t cout = carry + get_base_k_carry(digit_plus_cin, y, base_k);
+
+        carry_out[i] = cout;
+      }
+    } else {
+      // no carry in and carry out is computed
+      for (uint64_t i = 0; i < nn; ++i) {
+        const int64_t x = in[i];
+
+        int64_t y = get_base_k_digit(x, base_k);
+        int64_t cout = get_base_k_carry(x, y, base_k);
+
+        carry_out[i] = cout;
+      }
+    }
+  }
+}
+
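+// in-place variant of znx_automorphism_i64: the permutation j -> j*p mod 2nn is
+// applied cycle by cycle, orbit by orbit (orbits are grouped by the 2-adic
+// valuation of the index), with a sign flip whenever an index wraps past nn.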
+void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res) {
+  const uint64_t _2mn = 2 * nn - 1;
+  const uint64_t _mn = nn - 1;
+  const uint64_t m = nn >> 1;
+  // reduce p mod 2n
+  p &= _2mn;
+  // uint64_t vp = p & _2mn;
+  /// uint64_t target_modifs = m >> 1;
+  // we proceed by increasing binary valuation
+  for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
+       binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
+    // In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
+    // At the beginning of this loop we have:
+    //   vp = binval * p mod 2n
+    //   orb_size = m / binval (i.e. the size of the orbit binval % 2.binval)
+
+    // first, handle the orders 1 and 2.
+    // if p*binval == binval % 2n: we're done!
+    if (vp == binval) return;
+    // if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
+    if (((vp + binval) & _2mn) == 0) {
+      for (uint64_t j = binval; j < m; j += binval) {
+        int64_t tmp = res[j];
+        res[j] = -res[nn - j];
+        res[nn - j] = -tmp;
+      }
+      res[m] = -res[m];
+      return;
+    }
+    // if p*binval == binval + n % 2n: negate the orbit and exit
+    if (((vp - binval) & _mn) == 0) {
+      for (uint64_t j = binval; j < nn; j += 2 * binval) {
+        res[j] = -res[j];
+      }
+      return;
+    }
+    // if p*binval == n - binval % 2n: mirror the orbit and continue!
+    if (((vp + binval) & _mn) == 0) {
+      for (uint64_t j = binval; j < m; j += 2 * binval) {
+        int64_t tmp = res[j];
+        res[j] = res[nn - j];
+        res[nn - j] = tmp;
+      }
+      continue;
+    }
+    // otherwise we will follow the orbit cycles,
+    // starting from binval and -binval in parallel
+    uint64_t j_start = binval;
+    uint64_t nb_modif = 0;
+    while (nb_modif < orb_size) {
+      // follow the cycle that starts at j_start
+      uint64_t j = j_start;
+      int64_t tmp1 = res[j];
+      int64_t tmp2 = res[nn - j];
+      do {
+        // find where the value should go, and with which sign
+        uint64_t new_j = (j * p) & _2mn;  // mod 2n to get the position and sign
+        uint64_t new_j_n = new_j & _mn;   // mod n to get just the position
+        // exchange this position with tmp1 (and take care of the sign)
+        int64_t tmp1a = res[new_j_n];
+        int64_t tmp2a = res[nn - new_j_n];
+        if (new_j < nn) {
+          res[new_j_n] = tmp1;
+          res[nn - new_j_n] = tmp2;
+        } else {
+          res[new_j_n] = -tmp1;
+          res[nn - new_j_n] = -tmp2;
+        }
+        tmp1 = tmp1a;
+        tmp2 = tmp2a;
+        // move to the new location, and store the number of items modified
+        nb_modif += 2;
+        j = new_j_n;
+      } while (j != j_start);
+      // move to the start of the next cycle:
+      // we need to find an index that has not been touched yet, and pick it as next j_start.
+      // in practice, it is enough to do *5, because 5 is a generator.
+      j_start = (5 * j_start) & _mn;
+    }
+  }
+}
+
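+// in-place variant of rnx_automorphism_f64 (same orbit/cycle strategy as the i64
+// version above)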
+void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res) {
+  const uint64_t _2mn = 2 * nn - 1;
+  const uint64_t _mn = nn - 1;
+  const uint64_t m = nn >> 1;
+  // reduce p mod 2n
+  p &= _2mn;
+  // uint64_t vp = p & _2mn;
+  /// uint64_t target_modifs = m >> 1;
+  // we proceed by increasing binary valuation
+  for (uint64_t binval = 1, vp = p & _2mn, orb_size = m; binval < nn;
+       binval <<= 1, vp = (vp << 1) & _2mn, orb_size >>= 1) {
+    // In this loop, we are going to treat the orbit of indexes = binval mod 2.binval.
+    // At the beginning of this loop we have:
+    //   vp = binval * p mod 2n
+    //   orb_size = m / binval (i.e. the size of the orbit binval % 2.binval)
+
+    // first, handle the orders 1 and 2.
+    // if p*binval == binval % 2n: we're done!
+    if (vp == binval) return;
+    // if p*binval == -binval % 2n: nega-mirror the orbit and all the sub-orbits and exit!
+    if (((vp + binval) & _2mn) == 0) {
+      for (uint64_t j = binval; j < m; j += binval) {
+        double tmp = res[j];
+        res[j] = -res[nn - j];
+        res[nn - j] = -tmp;
+      }
+      res[m] = -res[m];
+      return;
+    }
+    // if p*binval == binval + n % 2n: negate the orbit and exit
+    if (((vp - binval) & _mn) == 0) {
+      for (uint64_t j = binval; j < nn; j += 2 * binval) {
+        res[j] = -res[j];
+      }
+      return;
+    }
+    // if p*binval == n - binval % 2n: mirror the orbit and continue!
+    if (((vp + binval) & _mn) == 0) {
+      for (uint64_t j = binval; j < m; j += 2 * binval) {
+        double tmp = res[j];
+        res[j] = res[nn - j];
+        res[nn - j] = tmp;
+      }
+      continue;
+    }
+    // otherwise we will follow the orbit cycles,
+    // starting from binval and -binval in parallel
+    uint64_t j_start = binval;
+    uint64_t nb_modif = 0;
+    while (nb_modif < orb_size) {
+      // follow the cycle that starts at j_start
+      uint64_t j = j_start;
+      double tmp1 = res[j];
+      double tmp2 = res[nn - j];
+      do {
+        // find where the value should go, and with which sign
+        uint64_t new_j = (j * p) & _2mn;  // mod 2n to get the position and sign
+        uint64_t new_j_n = new_j & _mn;   // mod n to get just the position
+        // exchange this position with tmp1 (and take care of the sign)
+        double tmp1a = res[new_j_n];
+        double tmp2a = res[nn - new_j_n];
+        if (new_j < nn) {
+          res[new_j_n] = tmp1;
+          res[nn - new_j_n] = tmp2;
+        } else {
+          res[new_j_n] = -tmp1;
+          res[nn - new_j_n] = -tmp2;
+        }
+        tmp1 = tmp1a;
+        tmp2 = tmp2a;
+        // move to the new location, and store the number of items modified
+        nb_modif += 2;
+        j = new_j_n;
+      } while (j != j_start);
+      // move to the start of the next cycle:
+      // we need to find an index that has not been touched yet, and pick it as next j_start.
+      // in practice, it is enough to do *5, because 5 is a generator.
+      j_start = (5 * j_start) & _mn;
+    }
+  }
+}
diff --git a/spqlios/coeffs/coeffs_arithmetic.h b/spqlios/coeffs/coeffs_arithmetic.h
new file mode 100644
index 0000000..73a2b43
--- /dev/null
+++ b/spqlios/coeffs/coeffs_arithmetic.h
@@ -0,0 +1,73 @@
+#ifndef SPQLIOS_COEFFS_ARITHMETIC_H
+#define SPQLIOS_COEFFS_ARITHMETIC_H
+
+#include "../commons.h"
+
+/** res = a + b */
+EXPORT void znx_add_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
+EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
+/** res = a - b */
+EXPORT void znx_sub_i64_ref(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
+EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b);
+/** res = -a */
+EXPORT void znx_negate_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
+EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a);
+/** res = a */
+EXPORT void znx_copy_i64_ref(uint64_t nn, int64_t* res, const int64_t* a);
+/** res = 0 */
+EXPORT void znx_zero_i64_ref(uint64_t nn, int64_t* res);
+
+/**
+ * @brief res = X^p * in mod (X^nn + 1)
+ * @param nn the ring dimension
+ * @param p the power of the rotation, with -2nn <= p <= 2nn
+ * @param in is a rnx/znx vector of dimension nn
+ */
+EXPORT void rnx_rotate_f64(uint64_t nn, int64_t p, double* res, const double* in);
+EXPORT void znx_rotate_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
+EXPORT void rnx_rotate_inplace_f64(uint64_t nn, int64_t p, double* res);
+EXPORT void znx_rotate_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
+
+/**
+ * @brief res(X) = in(X^p)
+ * @param nn the ring dimension
+ * @param p an odd integer with 0 < p < 2nn
+ * @param in is a rnx/znx vector of dimension nn
+ */
+EXPORT void rnx_automorphism_f64(uint64_t nn, int64_t p, double* res, const double* in);
+EXPORT void znx_automorphism_i64(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
+EXPORT void rnx_automorphism_inplace_f64(uint64_t nn, int64_t p, double* res);
+EXPORT void znx_automorphism_inplace_i64(uint64_t nn, int64_t p, int64_t* res);
+
+/**
+ * @brief res = (X^p-1).in
+ * @param nn the ring dimension
+ * @param p must satisfy -2nn <= p <= 2nn
+ * @param in is a rnx/znx vector of dimension nn
+ */
+EXPORT void rnx_mul_xp_minus_one(uint64_t nn, int64_t p, double* res, const double* in);
+EXPORT void znx_mul_xp_minus_one(uint64_t nn, int64_t p, int64_t* res, const int64_t* in);
+
+/**
+ * @brief      Normalize input plus carry mod-2^k. The following
+ *             equality holds: @c {in + carry_in == out + carry_out * 2^k}.
+ *
+ *             @c in must be in [-2^62 .. 2^62]
+ *
+ *             @c out is in [ -2^(base_k-1), 2^(base_k-1) [.
+ *
+ *             @c carry_in and @c carry_out have at most 64+1-k bits.
+ *
+ *             Null @c carry_in or @c carry_out are ignored.
+ *
+ * @param[in]  nn         the ring dimension
+ * @param[in]  base_k     the base k
+ * @param      out        output normalized znx
+ * @param      carry_out  output carry znx
+ * @param[in]  in         input znx
+ * @param[in]  carry_in   input carry znx
+ */
+EXPORT void znx_normalize(uint64_t nn, uint64_t base_k, int64_t* out, int64_t* carry_out, const int64_t* in,
+                          const int64_t* carry_in);
+
+#endif  // SPQLIOS_COEFFS_ARITHMETIC_H
diff --git a/spqlios/coeffs/coeffs_arithmetic_avx.c b/spqlios/coeffs/coeffs_arithmetic_avx.c
new file mode 100644
index 0000000..9fea143
--- /dev/null
+++ b/spqlios/coeffs/coeffs_arithmetic_avx.c
@@ -0,0 +1,84 @@
+#include <immintrin.h>
+
+#include "coeffs_arithmetic.h"
+
+// res = a + b. dimension nn must be a power of 2
+EXPORT void znx_add_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
+  if (nn <= 2) {
+    if (nn == 1) {
+      res[0] = a[0] + b[0];
+    } else {
+      _mm_storeu_si128((__m128i*)res,                     //
+                       _mm_add_epi64(                     //
+                           _mm_loadu_si128((__m128i*)a),  //
+                           _mm_loadu_si128((__m128i*)b)));
+    }
+  } else {
+    const __m256i* aa = (__m256i*)a;
+    const __m256i* bb = (__m256i*)b;
+    __m256i* rr = (__m256i*)res;
+    __m256i* const rrend = (__m256i*)(res + nn);
+    do {
+      _mm256_storeu_si256(rr,                          //
+                          _mm256_add_epi64(            //
+                              _mm256_loadu_si256(aa),  //
+                              _mm256_loadu_si256(bb)));
+      ++rr;
+      ++aa;
+      ++bb;
+    } while (rr < rrend);
+  }
+}
+
+// res = a - b. dimension nn must be a power of 2
+EXPORT void znx_sub_i64_avx(uint64_t nn, int64_t* res, const int64_t* a, const int64_t* b) {
+  if (nn <= 2) {
+    if (nn == 1) {
+      res[0] = a[0] - b[0];
+    } else {
+      _mm_storeu_si128((__m128i*)res,                     //
+                       _mm_sub_epi64(                     //
+                           _mm_loadu_si128((__m128i*)a),  //
+                           _mm_loadu_si128((__m128i*)b)));
+    }
+  } else {
+    const __m256i* aa = (__m256i*)a;
+    const __m256i* bb = (__m256i*)b;
+    __m256i* rr = (__m256i*)res;
+    __m256i* const rrend = (__m256i*)(res + nn);
+    do {
+      _mm256_storeu_si256(rr,                          //
+                          _mm256_sub_epi64(            //
+                              _mm256_loadu_si256(aa),  //
+                              _mm256_loadu_si256(bb)));
+      ++rr;
+      ++aa;
+      ++bb;
+    } while (rr < rrend);
+  }
+}
+
+EXPORT void znx_negate_i64_avx(uint64_t nn, int64_t* res, const int64_t* a) {
+  if (nn <= 2) {
+    if (nn == 1) {
+      res[0] = -a[0];
+    } else {
+      _mm_storeu_si128((__m128i*)res,           //
+                       _mm_sub_epi64(           //
+                           _mm_set1_epi64x(0),  //
+                           _mm_loadu_si128((__m128i*)a)));
+    }
+  } else {
+    const __m256i* aa = (__m256i*)a;
+    __m256i* rr = (__m256i*)res;
+    __m256i* const rrend = (__m256i*)(res + nn);
+    do {
+      _mm256_storeu_si256(rr,                         //
+                          _mm256_sub_epi64(           //
+                              _mm256_set1_epi64x(0),  //
+                              _mm256_loadu_si256(aa)));
+      ++rr;
+      ++aa;
+    } while (rr < rrend);
+  }
+}

From dc6790dab57ce0b9beb559b07ae09e816d9d164d Mon Sep 17 00:00:00 2001
From: Sandra Guasch <sandra.guasch@sbt-ext.com>
Date: Fri, 24 May 2024 13:34:55 +0000
Subject: [PATCH 2/2] update CMakeLists with new vec_znx and coeffs sources

---
 spqlios/CMakeLists.txt | 5 +++++
 1 file changed, 5 insertions(+)

diff --git a/spqlios/CMakeLists.txt b/spqlios/CMakeLists.txt
index 797e3f2..50fbb13 100644
--- a/spqlios/CMakeLists.txt
+++ b/spqlios/CMakeLists.txt
@@ -4,6 +4,8 @@ enable_language(ASM)
 set(SRCS_GENERIC
         commons.c 
         commons_private.c
+        coeffs/coeffs_arithmetic.c
+        arithmetic/vec_znx.c
         arithmetic/vec_znx_dft.c
         cplx/cplx_common.c
         cplx/cplx_conversions.c
@@ -74,6 +76,8 @@ set_source_files_properties(${SRCS_AVX512} PROPERTIES COMPILE_OPTIONS "-mfma;-ma
 
 # C or assembly source files compiled only on x86: avx2 + bmi targets
 set(SRCS_AVX2
+        arithmetic/vec_znx_avx.c
+        coeffs/coeffs_arithmetic_avx.c
         arithmetic/vec_znx_dft_avx2.c
         q120/q120_arithmetic_avx2.c
         q120/q120_ntt_avx2.c
@@ -111,6 +115,7 @@ set(HEADERSPRIVATE
         q120/q120_arithmetic_private.h
         q120/q120_ntt_private.h
         arithmetic/vec_znx_arithmetic.h
+        coeffs/coeffs_arithmetic.h
         )
 
 set(SPQLIOSSOURCES