|
| 1 | +use std::hash::poseidon2_permutation; |
| 2 | +use std::option::Option; |
| 3 | + |
| 4 | +use dep::protocol_types::point::Point; |
| 5 | + |
| 6 | +global TWO_POW_128: Field = 0x100000000000000000000000000000000; |
| 7 | + |
| 8 | +/// Poseidon2 Encryption. |
| 9 | +/// |
| 10 | +/// ~160 constraints to encrypt 8 fields. Use this hash if you favour proving speed over long-term privacy for your users. |
| 11 | +/// |
| 12 | +/// WARNING: Poseidon2 as an _encryption scheme_ isn't considered as secure as more battle-tested encryption schemes, e.g. AES128. |
| 13 | +/// This is because: |
| 14 | +/// - it's relatively new; |
| 15 | +/// - it isn't used much in the wild, so there's less incentive for hackers or bounty hunters to try to break it; |
| 16 | +/// - it doesn't provide post-quantum privacy. |
| 17 | +/// |
| 18 | +/// If you want to protect your users' privacy decades into the future, it might be prudent to choose |
| 19 | +/// a more 'traditional' encryption scheme. |
| 20 | +/// If your app is "lower stakes", and your users will only care about their privacy in the near future or immediate future, then |
| 21 | +/// this encryption scheme might be for you! |
| 22 | +/// |
| 23 | +/// See the paper: https://drive.google.com/file/d/1EVrP3DzoGbmzkRmYnyEDcIQcXVU7GlOd/view |
| 24 | +/// |
| 25 | +/// Note: The return length is: L padded to the next multiple of 3, plus 1 for a message auth code of s[1]. |
| 26 | +/// |
| 27 | +/// @param nonce is only needed if your use case needs to protect against replay attacks. |
| 28 | +pub fn poseidon2_encrypt<let L: u32>( |
| 29 | + msg: [Field; L], |
| 30 | + shared_secret: Point, |
| 31 | + nonce: Field, |
| 32 | +) -> [Field; ((L + 2) / 3) * 3 + 1] { |
| 33 | + // TODO: assert(nonce < 2^128), assert(L < 2^120); |
| 34 | + let mut s = [0, shared_secret.x, shared_secret.y, nonce + (L as Field) * TWO_POW_128]; |
| 35 | + |
| 36 | + // We wish to compute NUM_MISSING_ELEMENTS, which is how many elements we must add as padding so |
| 37 | + // that the message length becomes a multiple of 3. |
| 38 | + let CEIL = (L + 3 - 1) / 3; // ceil(L / 3) |
| 39 | + let L_UPPER_BOUND = CEIL * 3; |
| 40 | + let NUM_MISSING_ELEMENTS = L_UPPER_BOUND - L; |
| 41 | + |
| 42 | + // The Noir compiler doesn't like using the above-defined constants as array lengths, |
| 43 | + // so these declarations are pretty verbose: |
| 44 | + let mut m = [0 as Field; ((L + 3 - 1) / 3) * 3]; // [Field; L_UPPER_BOUND] |
| 45 | + let mut c = [0 as Field; ((L + 3 - 1) / 3) * 3 + 1]; // [Field; L_UPPER_BOUND + 1] |
| 46 | + |
| 47 | + for i in 0..L { |
| 48 | + m[i] = msg[i]; |
| 49 | + } |
| 50 | + // Pad with 0's: |
| 51 | + for i in 0..NUM_MISSING_ELEMENTS { |
| 52 | + m[L + i] = 0; |
| 53 | + } |
| 54 | + |
| 55 | + for i in 0..CEIL { |
| 56 | + s = poseidon2_permutation(s, 4); |
| 57 | + |
| 58 | + // Absorb 3 elements of the message: |
| 59 | + let j = 3 * i; |
| 60 | + s[1] = s[1] + m[j]; |
| 61 | + s[2] = s[2] + m[j + 1]; |
| 62 | + s[3] = s[3] + m[j + 2]; |
| 63 | + |
| 64 | + // Release 3 elements of ciphertext: |
| 65 | + c[j] = s[1]; |
| 66 | + c[j + 1] = s[2]; |
| 67 | + c[j + 2] = s[3]; |
| 68 | + } |
| 69 | + |
| 70 | + // Iterate Poseidon2 on the state, one last time: |
| 71 | + s = poseidon2_permutation(s, 4); |
| 72 | + |
| 73 | + // Release the last ciphertext element: |
| 74 | + c[L_UPPER_BOUND] = s[1]; |
| 75 | + |
| 76 | + c |
| 77 | +} |
| 78 | + |
| 79 | +pub fn poseidon2_decrypt<let L: u32>( |
| 80 | + ciphertext: [Field; ((L + 3 - 1) / 3) * 3 + 1], |
| 81 | + shared_secret: Point, |
| 82 | + nonce: Field, |
| 83 | +) -> Option<[Field; L]> { |
| 84 | + let mut s = [0, shared_secret.x, shared_secret.y, nonce + (L as Field) * TWO_POW_128]; |
| 85 | + |
| 86 | + let CEIL = (L + 3 - 1) / 3; // ceil(L / 3) |
| 87 | + let L_UPPER_BOUND = CEIL * 3; |
| 88 | + let NUM_EXTRA_ELEMENTS = L_UPPER_BOUND - L; |
| 89 | + |
| 90 | + let mut m = [0 as Field; ((L + 3 - 1) / 3) * 3]; // [Field; L_UPPER_BOUND] |
| 91 | + let c = ciphertext; |
| 92 | + |
| 93 | + for i in 0..CEIL { |
| 94 | + s = poseidon2_permutation(s, 4); |
| 95 | + |
| 96 | + // Release 3 elements of message: |
| 97 | + let j = 3 * i; |
| 98 | + // QUESTION: the paper says to do what's commented-out, but actually, the thing that works is the uncommented code: |
| 99 | + // m[j] = s[1] + c[j]; |
| 100 | + // m[j + 1] = s[2] + c[j + 1]; |
| 101 | + // m[j + 2] = s[3] + c[j + 2]; |
| 102 | + m[j] = c[j] - s[1]; |
| 103 | + m[j + 1] = c[j + 1] - s[2]; |
| 104 | + m[j + 2] = c[j + 2] - s[3]; |
| 105 | + |
| 106 | + // Modify state: |
| 107 | + s[1] = c[j]; |
| 108 | + s[2] = c[j + 1]; |
| 109 | + s[3] = c[j + 2]; |
| 110 | + } |
| 111 | + |
| 112 | + // Iterate Poseidon2 on the state, one last time: |
| 113 | + s = poseidon2_permutation(s, 4); |
| 114 | + |
| 115 | + let mut msg: [Field; L] = [0; L]; |
| 116 | + for i in 0..L { |
| 117 | + msg[i] = m[i]; |
| 118 | + } |
| 119 | + |
| 120 | + let mut decryption_failed: bool = false; |
| 121 | + for i in 0..NUM_EXTRA_ELEMENTS { |
| 122 | + // If decryption is successful, and if the original plaintext was not a multiple of 3, |
| 123 | + // then there will be some lingering values (up to the next multiple of 3 of L) that |
| 124 | + // should be 0. If they are not 0, decryption has failed. |
| 125 | + if m[L + i] != 0 { |
| 126 | + decryption_failed = true; |
| 127 | + } |
| 128 | + } |
| 129 | + |
| 130 | + // Release the last ciphertext element: |
| 131 | + if c[L_UPPER_BOUND] != s[1] { |
| 132 | + // Decryption has failed if the message authentication code (the final |
| 133 | + // element of c) doesn't match s[1]. |
| 134 | + decryption_failed = true; |
| 135 | + } |
| 136 | + |
| 137 | + if decryption_failed { |
| 138 | + Option::none() |
| 139 | + } else { |
| 140 | + Option::some(msg) |
| 141 | + } |
| 142 | +} |
| 143 | + |
| 144 | +mod test { |
| 145 | + use super::{poseidon2_decrypt, poseidon2_encrypt, TWO_POW_128}; |
| 146 | + use std::{ |
| 147 | + embedded_curve_ops::{fixed_base_scalar_mul, multi_scalar_mul}, |
| 148 | + hash::from_field_unsafe as fr_to_fq_unsafe, |
| 149 | + }; |
| 150 | + |
| 151 | + // Helper function that allows us to test encryption, then decryption, for various sizes of message. |
| 152 | + fn encrypt_then_decrypt<let N: u32>(msg: [Field; N]) { |
| 153 | + // Alice encrypting to Bob: |
| 154 | + |
| 155 | + let bob_sk = 0x2345; // Obviously, Alice doesn't know this. |
| 156 | + let bob_pk = fixed_base_scalar_mul(fr_to_fq_unsafe(bob_sk)); |
| 157 | + |
| 158 | + let eph_sk = 0x5678; |
| 159 | + let eph_pk = fixed_base_scalar_mul(fr_to_fq_unsafe(eph_sk)); |
| 160 | + let shared_secret = multi_scalar_mul([bob_pk], [fr_to_fq_unsafe(eph_sk)]); |
| 161 | + |
| 162 | + let nonce = 3; // TODO. Can even be a timestamp. Why is this even needed? |
| 163 | + |
| 164 | + let ciphertext = poseidon2_encrypt(msg, shared_secret, nonce); |
| 165 | + |
| 166 | + // ****************** |
| 167 | + |
| 168 | + // Bob sees: [Epk, ciphertext, nonce]: |
| 169 | + |
| 170 | + let shared_secret = multi_scalar_mul([eph_pk], [fr_to_fq_unsafe(bob_sk)]); |
| 171 | + |
| 172 | + let result = poseidon2_decrypt(ciphertext, shared_secret, nonce); |
| 173 | + |
| 174 | + assert(result.is_some()); |
| 175 | + assert(msg == result.unwrap_unchecked()); |
| 176 | + } |
| 177 | + |
| 178 | + #[test] |
| 179 | + fn poseidon2_encryption() { |
| 180 | + encrypt_then_decrypt([1]); |
| 181 | + encrypt_then_decrypt([1, 2]); |
| 182 | + encrypt_then_decrypt([1, 2, 3]); |
| 183 | + encrypt_then_decrypt([1, 2, 3, 4]); |
| 184 | + encrypt_then_decrypt([1, 2, 3, 4, 5]); |
| 185 | + encrypt_then_decrypt([1, 2, 3, 4, 5, 6]); |
| 186 | + encrypt_then_decrypt([1, 2, 3, 4, 5, 6, 7]); |
| 187 | + encrypt_then_decrypt([1, 2, 3, 4, 5, 6, 7, 8]); |
| 188 | + encrypt_then_decrypt([1, 2, 3, 4, 5, 6, 7, 8, 9]); |
| 189 | + encrypt_then_decrypt([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]); |
| 190 | + } |
| 191 | + |
| 192 | + #[test] |
| 193 | + fn test_poseidon2_decryption_with_bad_secret_fails() { |
| 194 | + // Alice encrypting to Bob: |
| 195 | + |
| 196 | + let bob_sk = 0x2345; // Obviously, Alice doesn't know this. |
| 197 | + let bob_pk = fixed_base_scalar_mul(fr_to_fq_unsafe(bob_sk)); |
| 198 | + |
| 199 | + let eph_sk = 0x5678; |
| 200 | + let eph_pk = fixed_base_scalar_mul(fr_to_fq_unsafe(eph_sk)); |
| 201 | + let shared_secret = multi_scalar_mul([bob_pk], [fr_to_fq_unsafe(eph_sk)]); |
| 202 | + |
| 203 | + let msg = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]; |
| 204 | + |
| 205 | + let nonce = 3; |
| 206 | + |
| 207 | + let ciphertext = poseidon2_encrypt(msg, shared_secret, nonce); |
| 208 | + |
| 209 | + // ****************** |
| 210 | + |
| 211 | + // Bob sees: [Epk, ciphertext, nonce]: |
| 212 | + |
| 213 | + let mut shared_secret = multi_scalar_mul([eph_pk], [fr_to_fq_unsafe(bob_sk)]); |
| 214 | + // Let's intentionally corrupt the shared secret, so that decryption should fail |
| 215 | + shared_secret.x += 1; |
| 216 | + |
| 217 | + let result = poseidon2_decrypt(ciphertext, shared_secret, nonce); |
| 218 | + |
| 219 | + assert(result.is_none()); |
| 220 | + } |
| 221 | + |
| 222 | + // Helper function with encryption boilerplate |
| 223 | + fn encrypt_and_return_ct_length<let N: u32>(msg: [Field; N]) -> u32 { |
| 224 | + // Alice encrypting to Bob: |
| 225 | + |
| 226 | + let bob_sk = 0x2345; // Obviously, Alice doesn't know this. |
| 227 | + let bob_pk = fixed_base_scalar_mul(fr_to_fq_unsafe(bob_sk)); |
| 228 | + |
| 229 | + let eph_sk = 0x5678; |
| 230 | + let eph_pk = fixed_base_scalar_mul(fr_to_fq_unsafe(eph_sk)); |
| 231 | + let shared_secret = multi_scalar_mul([bob_pk], [fr_to_fq_unsafe(eph_sk)]); |
| 232 | + |
| 233 | + let nonce = 3; // TODO. Can even be a timestamp. Why is this even needed? |
| 234 | + |
| 235 | + let ciphertext = poseidon2_encrypt(msg, shared_secret, nonce); |
| 236 | + |
| 237 | + ciphertext.len() |
| 238 | + } |
| 239 | + |
| 240 | + #[test] |
| 241 | + fn test_ciphertext_lengths() { |
| 242 | + // Hard-coded expectations are computed by taking the input array |
| 243 | + // length, computing the next multiple of 3, then adding 1. |
| 244 | + assert(encrypt_and_return_ct_length([1]) == 4); |
| 245 | + assert(encrypt_and_return_ct_length([1, 2]) == 4); |
| 246 | + assert(encrypt_and_return_ct_length([1, 2, 3]) == 4); |
| 247 | + assert(encrypt_and_return_ct_length([1, 2, 3, 4]) == 7); |
| 248 | + assert(encrypt_and_return_ct_length([1, 2, 3, 4, 5]) == 7); |
| 249 | + assert(encrypt_and_return_ct_length([1, 2, 3, 4, 5, 6]) == 7); |
| 250 | + assert(encrypt_and_return_ct_length([1, 2, 3, 4, 5, 6, 7]) == 10); |
| 251 | + assert(encrypt_and_return_ct_length([1, 2, 3, 4, 5, 6, 7, 8]) == 10); |
| 252 | + assert(encrypt_and_return_ct_length([1, 2, 3, 4, 5, 6, 7, 8, 9]) == 10); |
| 253 | + } |
| 254 | + |
| 255 | + #[test] |
| 256 | + fn test_2_pow_128() { |
| 257 | + assert(2.pow_32(128) == TWO_POW_128); |
| 258 | + } |
| 259 | +} |
0 commit comments