13
13
from Compiler .types import vectorized_classmethod
14
14
from Compiler .program import Tape , Program
15
15
from Compiler .exceptions import *
16
- from Compiler import util , oram , floatingpoint , library
16
+ from Compiler import util , oram , floatingpoint , library , comparison
17
17
from Compiler import instructions_base
18
18
import Compiler .GC .instructions as inst
19
19
import operator
20
20
import math
21
21
import itertools
22
22
from functools import reduce
23
23
24
+ class _binary :
25
+ def reveal_to (self , * args , ** kwargs ):
26
+ raise CompilerError (
27
+ '%s does not support revealing to indivual players' % type (self ))
28
+
24
29
class bits (Tape .Register , _structure , _bit ):
25
30
n = 40
26
31
unit = 64
@@ -149,6 +154,12 @@ def set_length(self, n):
149
154
self .n = n
150
155
def set_size (self , size ):
151
156
pass
157
+ def load_int (self , value ):
158
+ n_limbs = math .ceil (self .n / self .unit )
159
+ for i in range (n_limbs ):
160
+ self .conv_regint (min (self .unit , self .n - i * self .unit ),
161
+ self [i ], regint (value % 2 ** self .unit ))
162
+ value >>= self .unit
152
163
def load_other (self , other ):
153
164
if isinstance (other , cint ):
154
165
assert (self .n == other .size )
@@ -236,12 +247,14 @@ def _new_by_number(self, i, size=1):
236
247
return res
237
248
def if_else (self , x , y ):
238
249
"""
239
- Vectorized oblivious selection::
250
+ Bit-wise oblivious selection::
240
251
241
252
sb32 = sbits.get_type(32)
242
253
print_ln('%s', sb32(3).if_else(sb32(5), sb32(2)).reveal())
243
254
244
- This will output 1.
255
+ This will output 1 because it selects the two least
256
+ significant bits from 5 and the rest of the bits from 2.
257
+
245
258
"""
246
259
return result_conv (x , y )(self & (x ^ y ) ^ y )
247
260
def zero_if_not (self , condition ):
@@ -268,6 +281,9 @@ def copy_from_part(self, source, base, size):
268
281
self .bit_compose (source .bit_decompose ()[base :base + size ]))
269
282
def vector_size (self ):
270
283
return self .n
284
+ @staticmethod
285
+ def size_for_mem ():
286
+ return 1
271
287
272
288
class cbits (bits ):
273
289
""" Clear bits register. Helper type with limited functionality. """
@@ -302,13 +318,6 @@ def conv(cls, other):
302
318
else :
303
319
return super (cbits , cls ).conv (other )
304
320
types = {}
305
- def load_int (self , value ):
306
- n_limbs = math .ceil (self .n / self .unit )
307
- tmp = regint (size = n_limbs )
308
- for i in range (n_limbs ):
309
- tmp [i ].load_int (value % 2 ** self .unit )
310
- value >>= self .unit
311
- self .load_other (tmp )
312
321
def store_in_dynamic_mem (self , address ):
313
322
inst .stmsdci (self , cbits .conv (address ))
314
323
def clear_op (self , other , c_inst , ci_inst , op ):
@@ -502,11 +511,7 @@ def load_int(self, value):
502
511
if self .n <= 32 :
503
512
inst .ldbits (self , self .n , value )
504
513
else :
505
- size = math .ceil (self .n / self .unit )
506
- tmp = regint (size = size )
507
- for i in range (size ):
508
- tmp [i ].load_int ((value >> (i * 64 )) % 2 ** 64 )
509
- self .load_other (tmp )
514
+ bits .load_int (self , value )
510
515
def load_other (self , other ):
511
516
if isinstance (other , cbits ) and self .n == other .n :
512
517
inst .convcbit2s (self .n , self , other )
@@ -675,7 +680,7 @@ def bit_adder(*args, **kwargs):
675
680
def ripple_carry_adder (* args , ** kwargs ):
676
681
return sbitint .ripple_carry_adder (* args , ** kwargs )
677
682
678
- class sbitvec (_vec , _bit ):
683
+ class sbitvec (_vec , _bit , _binary ):
679
684
""" Vector of registers of secret bits, effectively a matrix of secret bits.
680
685
This facilitates parallel arithmetic operations in binary circuits.
681
686
Container types are not supported, use :py:obj:`sbitvec.get_type` for that.
@@ -732,15 +737,16 @@ def get_type(cls, n):
732
737
:py:obj:`v` and the columns by calling :py:obj:`elements`.
733
738
"""
734
739
class sbitvecn (cls , _structure ):
735
- @staticmethod
736
- def malloc (size , creator_tape = None ):
737
- return sbit .malloc (size * n , creator_tape = creator_tape )
740
+ @classmethod
741
+ def malloc (cls , size , creator_tape = None ):
742
+ return sbit .malloc (
743
+ size * cls .mem_size (), creator_tape = creator_tape )
738
744
@staticmethod
739
745
def n_elements ():
740
746
return 1
741
747
@staticmethod
742
748
def mem_size ():
743
- return n
749
+ return sbits . get_type ( n ). mem_size ()
744
750
@classmethod
745
751
def get_input_from (cls , player , size = 1 , f = 0 ):
746
752
""" Secret input from :py:obj:`player`. The input is decomposed
@@ -780,38 +786,28 @@ def __init__(self, other=None, size=None):
780
786
self .v = sbits .get_type (n )(other ).bit_decompose ()
781
787
assert len (self .v ) == n
782
788
assert size is None or size == self .v [0 ].n
783
- @vectorized_classmethod
784
- def load_mem (cls , address ):
785
- size = instructions_base .get_global_vector_size ()
786
- if size not in (None , 1 ):
787
- assert isinstance (address , int ) or len (address ) == 1
788
- sb = sbits .get_type (size )
789
- return cls .from_vec (sb .bit_compose (
790
- sbit .load_mem (address + i + j * n ) for j in range (size ))
791
- for i in range (n ))
792
- if not isinstance (address , int ):
793
- v = [sbit .load_mem (x , size = n ).v [0 ] for x in address ]
794
- return cls (v )
789
+ @classmethod
790
+ def load_mem (cls , address , size = None ):
791
+ if isinstance (address , int ) or len (address ) == 1 :
792
+ address = [address + i for i in range (size or 1 )]
795
793
else :
796
- return cls .from_vec (sbit .load_mem (address + i )
797
- for i in range (n ))
794
+ assert size == None
795
+ return cls (
796
+ [sbits .get_type (n ).load_mem (x ) for x in address ])
798
797
def store_in_mem (self , address ):
799
798
size = 1
800
799
for x in self .v :
801
800
if not util .is_constant (x ):
802
801
size = max (size , x .n )
803
- v = [sbits .get_type (size ).conv (x ) for x in self .v ]
804
- if not isinstance (address , int ) and len (address ) != 1 :
805
- v = self .elements ()
806
- assert len (v ) == len (address )
807
- for x , y in zip (v , address ):
808
- for i , xx in enumerate (x .bit_decompose (n )):
809
- xx .store_in_mem (y + i )
802
+ if isinstance (address , int ):
803
+ address = range (address , address + size )
804
+ elif len (address ) == 1 :
805
+ address = [address + i * self .mem_size ()
806
+ for i in range (size )]
810
807
else :
811
- assert isinstance (address , int ) or len (address ) == 1
812
- for i in range (n ):
813
- for j , x in enumerate (v [i ].bit_decompose ()):
814
- x .store_in_mem (address + i + j * n )
808
+ assert size == len (address )
809
+ for x , dest in zip (self .elements (), address ):
810
+ x .store_in_mem (dest )
815
811
@classmethod
816
812
def two_power (cls , nn , size = 1 ):
817
813
return cls .from_vec (
@@ -864,7 +860,7 @@ def __init__(self, elements=None, length=None, input_length=None):
864
860
assert isinstance (elements , sint )
865
861
if Program .prog .use_split ():
866
862
x = elements .split_to_two_summands (length )
867
- v = sbitint .carry_lookahead_adder (x [0 ], x [1 ], fewer_inv = True )
863
+ v = sbitint .bit_adder (x [0 ], x [1 ])
868
864
else :
869
865
prog = Program .prog
870
866
if not prog .options .ring :
@@ -877,6 +873,7 @@ def __init__(self, elements=None, length=None, input_length=None):
877
873
length , prog .security )
878
874
prog .use_edabit (backup )
879
875
return
876
+ comparison .require_ring_size (length , 'A2B conversion' )
880
877
l = int (Program .prog .options .ring )
881
878
r , r_bits = sint .get_edabit (length , size = elements .size )
882
879
c = ((elements - r ) << (l - length )).reveal ()
@@ -885,6 +882,8 @@ def __init__(self, elements=None, length=None, input_length=None):
885
882
x = sbitintvec .from_vec (r_bits ) + sbitintvec .from_vec (cb )
886
883
v = x .v
887
884
self .v = v [:length ]
885
+ elif isinstance (elements , sbitvec ):
886
+ self .v = elements .v
888
887
elif elements is not None and not (util .is_constant (elements ) and \
889
888
elements == 0 ):
890
889
self .v = sbits .trans (elements )
@@ -1347,13 +1346,19 @@ def elements(self):
1347
1346
def __add__ (self , other ):
1348
1347
if util .is_zero (other ):
1349
1348
return self
1350
- a , b = self .expand (other )
1349
+ try :
1350
+ a , b = self .expand (other )
1351
+ except :
1352
+ return NotImplemented
1351
1353
v = sbitint .bit_adder (a , b )
1352
1354
return self .get_type (len (v )).from_vec (v )
1353
1355
__radd__ = __add__
1354
1356
__sub__ = _bitint .__sub__
1355
1357
def __rsub__ (self , other ):
1356
- a , b = self .expand (other )
1358
+ try :
1359
+ a , b = self .expand (other )
1360
+ except :
1361
+ return NotImplemented
1357
1362
return self .from_vec (b ) - self .from_vec (a )
1358
1363
def __mul__ (self , other ):
1359
1364
if isinstance (other , sbits ):
@@ -1447,7 +1452,7 @@ def output(self):
1447
1452
inst .print_float_plainb (v , cbits .get_type (32 )(- self .f ), cbits (0 ),
1448
1453
cbits (0 ), cbits (0 ))
1449
1454
1450
- class sbitfix (_fix ):
1455
+ class sbitfix (_fix , _binary ):
1451
1456
""" Secret signed fixed-point number in one binary register.
1452
1457
Use :py:obj:`set_precision()` to change the precision.
1453
1458
@@ -1515,7 +1520,7 @@ class cls(_fix):
1515
1520
cls .set_precision (f , k )
1516
1521
return cls ._new (cls .int_type (other ), k , f )
1517
1522
1518
- class sbitfixvec (_fix , _vec ):
1523
+ class sbitfixvec (_fix , _vec , _binary ):
1519
1524
""" Vector of fixed-point numbers for parallel binary computation.
1520
1525
1521
1526
Use :py:obj:`set_precision()` to change the precision.
0 commit comments