Skip to content

Commit 75ef91e

Browse files
bwesterbarmfazh
authored andcommitted
kyber: remove division by q in ciphertext compression
On some platforms, division by q leaks some information on the ciphertext by its timing. If a keypair is reused, and an attacker has access to a decapsulation oracle, this reveals information on the private key. This is known as "kyberslash2". Note that this does not affect to the typical ephemeral usage in TLS.
1 parent 899732a commit 75ef91e

File tree

2 files changed

+123
-10
lines changed

2 files changed

+123
-10
lines changed

pke/kyber/internal/common/poly.go

+18-10
Original file line numberDiff line numberDiff line change
@@ -166,7 +166,7 @@ func (p *Poly) CompressMessageTo(m []byte) {
166166

167167
// Set p to Decompress_q(m, 1).
168168
//
169-
// Assumes d is in {3, 4, 5, 10, 11}. p will be normalized.
169+
// Assumes d is in {4, 5, 10, 11}. p will be normalized.
170170
func (p *Poly) Decompress(m []byte, d int) {
171171
// Decompress_q(x, d) = ⌈(q/2ᵈ)x⌋
172172
// = ⌊(q/2ᵈ)x+½⌋
@@ -244,20 +244,28 @@ func (p *Poly) Decompress(m []byte, d int) {
244244

245245
// Writes Compress_q(p, d) to m.
246246
//
247-
// Assumes p is normalized and d is in {3, 4, 5, 10, 11}.
247+
// Assumes p is normalized and d is in {4, 5, 10, 11}.
248248
func (p *Poly) CompressTo(m []byte, d int) {
249249
// Compress_q(x, d) = ⌈(2ᵈ/q)x⌋ mod⁺ 2ᵈ
250250
// = ⌊(2ᵈ/q)x+½⌋ mod⁺ 2ᵈ
251251
// = ⌊((x << d) + q/2) / q⌋ mod⁺ 2ᵈ
252252
// = DIV((x << d) + q/2, q) & ((1<<d) - 1)
253+
//
254+
// We approximate DIV(x, q) by computing (x*a)>>e, where a/(2^e) ≈ 1/q.
255+
// For d in {10,11} we use 20,642,679/2^36, which computes division by x/q
256+
// correctly for 0 ≤ x < 41,522,616, which fits (q << 11) + q/2 comfortably.
257+
// For d in {4,5} we use 315/2^20, which doesn't compute division by x/q
258+
// correctly for all inputs, but it's close enough that the end result
259+
// of the compression is correct. The advantage is that we do not need
260+
// to use a 64-bit intermediate value.
253261
switch d {
254262
case 4:
255263
var t [8]uint16
256264
idx := 0
257265
for i := 0; i < N/8; i++ {
258266
for j := 0; j < 8; j++ {
259-
t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
260-
uint32(Q)) & ((1 << 4) - 1)
267+
t[j] = uint16((((uint32(p[8*i+j])<<4)+uint32(Q)/2)*315)>>
268+
20) & ((1 << 4) - 1)
261269
}
262270
m[idx] = byte(t[0]) | byte(t[1]<<4)
263271
m[idx+1] = byte(t[2]) | byte(t[3]<<4)
@@ -271,8 +279,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
271279
idx := 0
272280
for i := 0; i < N/8; i++ {
273281
for j := 0; j < 8; j++ {
274-
t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
275-
uint32(Q)) & ((1 << 5) - 1)
282+
t[j] = uint16((((uint32(p[8*i+j])<<5)+uint32(Q)/2)*315)>>
283+
20) & ((1 << 5) - 1)
276284
}
277285
m[idx] = byte(t[0]) | byte(t[1]<<5)
278286
m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
@@ -287,8 +295,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
287295
idx := 0
288296
for i := 0; i < N/4; i++ {
289297
for j := 0; j < 4; j++ {
290-
t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
291-
uint32(Q)) & ((1 << 10) - 1)
298+
t[j] = uint16((uint64((uint32(p[4*i+j])<<10)+uint32(Q)/2)*
299+
20642679)>>36) & ((1 << 10) - 1)
292300
}
293301
m[idx] = byte(t[0])
294302
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
@@ -302,8 +310,8 @@ func (p *Poly) CompressTo(m []byte, d int) {
302310
idx := 0
303311
for i := 0; i < N/8; i++ {
304312
for j := 0; j < 8; j++ {
305-
t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
306-
uint32(Q)) & ((1 << 11) - 1)
313+
t[j] = uint16((uint64((uint32(p[8*i+j])<<11)+uint32(Q)/2)*
314+
20642679)>>36) & ((1 << 11) - 1)
307315
}
308316
m[idx] = byte(t[0])
309317
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)

pke/kyber/internal/common/poly_test.go

+105
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package common
22

33
import (
4+
"bytes"
45
"crypto/rand"
56
"fmt"
67
"testing"
@@ -273,3 +274,107 @@ func TestNormalizeAgainstGeneric(t *testing.T) {
273274
}
274275
}
275276
}
277+
278+
func (p *Poly) OldCompressTo(m []byte, d int) {
279+
switch d {
280+
case 4:
281+
var t [8]uint16
282+
idx := 0
283+
for i := 0; i < N/8; i++ {
284+
for j := 0; j < 8; j++ {
285+
t[j] = uint16(((uint32(p[8*i+j])<<4)+uint32(Q)/2)/
286+
uint32(Q)) & ((1 << 4) - 1)
287+
}
288+
m[idx] = byte(t[0]) | byte(t[1]<<4)
289+
m[idx+1] = byte(t[2]) | byte(t[3]<<4)
290+
m[idx+2] = byte(t[4]) | byte(t[5]<<4)
291+
m[idx+3] = byte(t[6]) | byte(t[7]<<4)
292+
idx += 4
293+
}
294+
295+
case 5:
296+
var t [8]uint16
297+
idx := 0
298+
for i := 0; i < N/8; i++ {
299+
for j := 0; j < 8; j++ {
300+
t[j] = uint16(((uint32(p[8*i+j])<<5)+uint32(Q)/2)/
301+
uint32(Q)) & ((1 << 5) - 1)
302+
}
303+
m[idx] = byte(t[0]) | byte(t[1]<<5)
304+
m[idx+1] = byte(t[1]>>3) | byte(t[2]<<2) | byte(t[3]<<7)
305+
m[idx+2] = byte(t[3]>>1) | byte(t[4]<<4)
306+
m[idx+3] = byte(t[4]>>4) | byte(t[5]<<1) | byte(t[6]<<6)
307+
m[idx+4] = byte(t[6]>>2) | byte(t[7]<<3)
308+
idx += 5
309+
}
310+
311+
case 10:
312+
var t [4]uint16
313+
idx := 0
314+
for i := 0; i < N/4; i++ {
315+
for j := 0; j < 4; j++ {
316+
t[j] = uint16(((uint32(p[4*i+j])<<10)+uint32(Q)/2)/
317+
uint32(Q)) & ((1 << 10) - 1)
318+
}
319+
m[idx] = byte(t[0])
320+
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<2)
321+
m[idx+2] = byte(t[1]>>6) | byte(t[2]<<4)
322+
m[idx+3] = byte(t[2]>>4) | byte(t[3]<<6)
323+
m[idx+4] = byte(t[3] >> 2)
324+
idx += 5
325+
}
326+
case 11:
327+
var t [8]uint16
328+
idx := 0
329+
for i := 0; i < N/8; i++ {
330+
for j := 0; j < 8; j++ {
331+
t[j] = uint16(((uint32(p[8*i+j])<<11)+uint32(Q)/2)/
332+
uint32(Q)) & ((1 << 11) - 1)
333+
}
334+
m[idx] = byte(t[0])
335+
m[idx+1] = byte(t[0]>>8) | byte(t[1]<<3)
336+
m[idx+2] = byte(t[1]>>5) | byte(t[2]<<6)
337+
m[idx+3] = byte(t[2] >> 2)
338+
m[idx+4] = byte(t[2]>>10) | byte(t[3]<<1)
339+
m[idx+5] = byte(t[3]>>7) | byte(t[4]<<4)
340+
m[idx+6] = byte(t[4]>>4) | byte(t[5]<<7)
341+
m[idx+7] = byte(t[5] >> 1)
342+
m[idx+8] = byte(t[5]>>9) | byte(t[6]<<2)
343+
m[idx+9] = byte(t[6]>>6) | byte(t[7]<<5)
344+
m[idx+10] = byte(t[7] >> 3)
345+
idx += 11
346+
}
347+
default:
348+
panic("unsupported d")
349+
}
350+
}
351+
352+
func TestCompressFullInputFirstCoeff(t *testing.T) {
353+
for _, d := range []int{4, 5, 10, 11} {
354+
d := d
355+
t.Run(fmt.Sprintf("d=%d", d), func(t *testing.T) {
356+
var p, q Poly
357+
bound := (Q + (1 << uint(d))) >> uint(d+1)
358+
buf := make([]byte, (N*d-1)/8+1)
359+
buf2 := make([]byte, len(buf))
360+
for i := int16(0); i < Q; i++ {
361+
p[0] = i
362+
p.CompressTo(buf, d)
363+
p.OldCompressTo(buf2, d)
364+
if !bytes.Equal(buf, buf2) {
365+
t.Fatalf("%d", i)
366+
}
367+
q.Decompress(buf, d)
368+
diff := sModQ(p[0] - q[0])
369+
if diff < 0 {
370+
diff = -diff
371+
}
372+
if diff > bound {
373+
t.Logf("%v\n", buf)
374+
t.Fatalf("|%d - %d mod^± q| = %d > %d",
375+
p[0], q[0], diff, bound)
376+
}
377+
}
378+
})
379+
}
380+
}

0 commit comments

Comments
 (0)