@@ -325,7 +325,11 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
325
325
def Pow2 (a , l , kappa ):
326
326
m = int (ceil (log (l , 2 )))
327
327
t = BitDec (a , m , m , kappa )
328
- x = [types .sint () for i in range (m )]
328
+ return Pow2_from_bits (t )
329
+
330
+ def Pow2_from_bits (bits ):
331
+ m = len (bits )
332
+ t = list (bits )
329
333
pow2k = [types .cint () for i in range (m )]
330
334
for i in range (m ):
331
335
pow2k [i ] = two_power (2 ** i )
@@ -353,13 +357,20 @@ def B2U_from_Pow2(pow2a, l, kappa):
353
357
#print ' '.join(str(b.value) for b in y)
354
358
return [1 - y [i ] for i in range (l )]
355
359
356
- def Trunc (a , l , m , kappa , compute_modulo = False ):
360
+ def Trunc (a , l , m , kappa , compute_modulo = False , signed = False ):
357
361
""" Oblivious truncation by secret m """
362
+ if util .is_constant (m ) and not compute_modulo :
363
+ # cheaper
364
+ res = type (a )(size = a .size )
365
+ comparison .Trunc (res , a , l , m , kappa , signed = signed )
366
+ return res
358
367
if l == 1 :
359
368
if compute_modulo :
360
369
return a * m , 1 + m
361
370
else :
362
371
return a * (1 - m )
372
+ if program .Program .prog .options .ring and not compute_modulo :
373
+ return TruncInRing (a , l , Pow2 (m , l , kappa ))
363
374
r = [types .sint () for i in range (l )]
364
375
r_dprime = types .sint (0 )
365
376
r_prime = types .sint (0 )
@@ -370,8 +381,6 @@ def Trunc(a, l, m, kappa, compute_modulo=False):
370
381
x , pow2m = B2U (m , l , kappa )
371
382
#assert(pow2m.value == 2**m.value)
372
383
#assert(sum(b.value for b in x) == m.value)
373
- if program .Program .prog .options .ring and not compute_modulo :
374
- return TruncInRing (a , l , pow2m )
375
384
for i in range (l ):
376
385
bit (r [i ])
377
386
t1 = two_power (i ) * r [i ]
@@ -495,17 +504,28 @@ def TruncPrRing(a, k, m, signed=True):
495
504
return comparison .TruncLeakyInRing (a , k , m , signed = signed )
496
505
else :
497
506
from types import sint
498
- # extra bit to mask overflow
499
- r_bits = [sint .get_random_bit () for i in range (k + 1 )]
500
- n_shift = n_ring - len (r_bits )
501
- tmp = a + sint .bit_compose (r_bits )
502
- masked = (tmp << n_shift ).reveal ()
503
- shifted = (masked << 1 >> (n_shift + m + 1 ))
504
- overflow = r_bits [- 1 ].bit_xor (masked >> (n_ring - 1 ))
505
- res = shifted - sint .bit_compose (r_bits [m :k ]) + (overflow << (k - m ))
507
+ if signed :
508
+ a += (1 << (k - 1 ))
509
+ if program .Program .prog .use_trunc_pr :
510
+ res = sint ()
511
+ trunc_pr (res , a , k , m )
512
+ else :
513
+ # extra bit to mask overflow
514
+ r_bits = [sint .get_random_bit () for i in range (k + 1 )]
515
+ n_shift = n_ring - len (r_bits )
516
+ tmp = a + sint .bit_compose (r_bits )
517
+ masked = (tmp << n_shift ).reveal ()
518
+ shifted = (masked << 1 >> (n_shift + m + 1 ))
519
+ overflow = r_bits [- 1 ].bit_xor (masked >> (n_ring - 1 ))
520
+ res = shifted - sint .bit_compose (r_bits [m :k ]) + \
521
+ (overflow << (k - m ))
522
+ if signed :
523
+ res -= (1 << (k - m - 1 ))
506
524
return res
507
525
508
526
def TruncPrField (a , k , m , kappa = None ):
527
+ if m == 0 :
528
+ return a
509
529
if kappa is None :
510
530
kappa = 40
511
531
@@ -527,19 +547,24 @@ def SDiv(a, b, l, kappa, round_nearest=False):
527
547
w = types .cint (int (2.9142 * two_power (l ))) - 2 * b
528
548
x = alpha - b * w
529
549
y = a * w
530
- y = y .round (2 * l + 1 , l , kappa , round_nearest )
550
+ y = y .round (2 * l + 1 , l , kappa , round_nearest , signed = False )
531
551
x2 = types .sint ()
532
552
comparison .Mod2m (x2 , x , 2 * l + 1 , l , kappa , False )
533
553
x1 = comparison .TruncZeroes (x - x2 , 2 * l + 1 , l , True )
534
554
for i in range (theta - 1 ):
535
- y = y * (x1 + two_power (l )) + (y * x2 ).round (2 * l , l , kappa , round_nearest )
536
- y = y .round (2 * l + 1 , l + 1 , kappa , round_nearest )
537
- x = x1 * x2 + (x2 ** 2 ).round (2 * l + 1 , l + 1 , kappa , round_nearest )
538
- x = x1 * x1 + x .round (2 * l + 1 , l - 1 , kappa , round_nearest )
555
+ y = y * (x1 + two_power (l )) + (y * x2 ).round (2 * l , l , kappa ,
556
+ round_nearest ,
557
+ signed = False )
558
+ y = y .round (2 * l + 1 , l + 1 , kappa , round_nearest , signed = False )
559
+ x = x1 * x2 + (x2 ** 2 ).round (2 * l + 1 , l + 1 , kappa , round_nearest ,
560
+ signed = False )
561
+ x = x1 * x1 + x .round (2 * l + 1 , l - 1 , kappa , round_nearest ,
562
+ signed = False )
539
563
x2 = types .sint ()
540
564
comparison .Mod2m (x2 , x , 2 * l , l , kappa , False )
541
565
x1 = comparison .TruncZeroes (x - x2 , 2 * l + 1 , l , True )
542
- y = y * (x1 + two_power (l )) + (y * x2 ).round (2 * l , l , kappa , round_nearest )
566
+ y = y * (x1 + two_power (l )) + (y * x2 ).round (2 * l , l , kappa ,
567
+ round_nearest , signed = False )
543
568
y = y .round (2 * l + 1 , l - 1 , kappa , round_nearest )
544
569
return y
545
570
0 commit comments