Skip to content

Commit 7a5195d

Browse files
committed
Machine learning functionality, dishonest-majority binary secret sharing.
1 parent 5f0a7ad commit 7a5195d

File tree

203 files changed

+6255
-1484
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

203 files changed

+6255
-1484
lines changed

CHANGELOG.md

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
22

3+
## 0.1.2
4+
5+
- Machine learning capabilities used for [MobileNets inference](https://eprint.iacr.org/2019/131) and the iDASH submission
6+
- Binary computation for dishonest majority using secret sharing
7+
- Mathematical functions from [SCALE-MAMBA](https://github.com/KULeuven-COSIC/SCALE-MAMBA)
8+
- Fixed security bug: CowGear would reuse triples.
9+
310
## 0.1.1 (Aug 6, 2019)
411

512
- ECDSA

Compiler/allocator.py

+15-11
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,12 @@ def determine_scope(block, options):
101101
used_from_scope = set()
102102

103103
def find_in_scope(reg, scope):
104-
if scope is None:
105-
return False
106-
elif reg in scope.defined_registers:
107-
return True
108-
else:
109-
return find_in_scope(reg, scope.scope)
104+
while True:
105+
if scope is None:
106+
return False
107+
elif reg in scope.defined_registers:
108+
return True
109+
scope = scope.scope
110110

111111
def read(reg, n):
112112
if last_def[reg] == -1:
@@ -386,7 +386,7 @@ def dependency_graph(self, merge_classes):
386386
last_print_str = None
387387
last = defaultdict(lambda: defaultdict(lambda: None))
388388
last_open = deque()
389-
last_text_input = None
389+
last_text_input = [None, None]
390390

391391
depths = [0] * len(block.instructions)
392392
self.depths = depths
@@ -474,10 +474,14 @@ def keep_order(instr, n, t, arg_index=None):
474474

475475
# will be merged
476476
if isinstance(instr, TextInputInstruction):
477-
if last_text_input is not None and \
478-
type(block.instructions[last_text_input]) is not type(instr):
479-
add_edge(last_text_input, n)
480-
last_text_input = n
477+
if last_text_input[0] is not None:
478+
if instr.merge_id() != \
479+
block.instructions[last_text_input[0]].merge_id():
480+
add_edge(last_text_input[0], n)
481+
last_text_input[1] = last_text_input[0]
482+
elif last_text_input[1] is not None:
483+
add_edge(last_text_input[1], n)
484+
last_text_input[0] = n
481485

482486
if isinstance(instr, merge_classes):
483487
open_nodes.add(n)

Compiler/comparison.py

+20-4
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,12 @@ def LTZ(s, a, k, kappa):
8080
Trunc(t, a, k, k - 1, kappa, True)
8181
subsfi(s, t, 0)
8282

83+
def LessThanZero(a, k, kappa):
84+
import types
85+
res = types.sint()
86+
LTZ(res, a, k, kappa)
87+
return res
88+
8389
def Trunc(d, a, k, m, kappa, signed):
8490
"""
8591
d = a >> m
@@ -153,6 +159,8 @@ def TruncRoundNearest(a, k, m, kappa, signed=False):
153159
k: bit length of a
154160
m: compile-time integer
155161
"""
162+
if m == 0:
163+
return a
156164
if k == int(program.options.ring):
157165
# cannot work with bit length k+1
158166
tmp = TruncRing(None, a, k, m - 1, signed)
@@ -359,7 +367,7 @@ def CarryOutAux(d, a, kappa):
359367
movs(d, a[0][1])
360368

361369
# carry out with carry-in bit c
362-
def CarryOut(res, a, b, c, kappa):
370+
def CarryOut(res, a, b, c=0, kappa=None):
363371
"""
364372
res = last carry bit in addition of a and b
365373
@@ -368,21 +376,29 @@ def CarryOut(res, a, b, c, kappa):
368376
c: initial carry-in bit
369377
"""
370378
k = len(a)
379+
import types
371380
d = [program.curr_block.new_reg('s') for i in range(k)]
372-
t = [[program.curr_block.new_reg('s') for i in range(k)] for i in range(4)]
381+
t = [[types.sint() for i in range(k)] for i in range(4)]
373382
s = [program.curr_block.new_reg('s') for i in range(3)]
374383
for i in range(k):
375384
mulm(t[0][i], b[i], a[i])
376385
mulsi(t[1][i], t[0][i], 2)
377386
addm(t[2][i], b[i], a[i])
378387
subs(t[3][i], t[2][i], t[1][i])
379388
d[i] = [t[3][i], t[0][i]]
380-
mulsi(s[0], d[-1][0], c)
381-
adds(s[1], d[-1][1], s[0])
389+
s[0] = d[-1][0] * c
390+
s[1] = d[-1][1] + s[0]
382391
d[-1][1] = s[1]
383392

384393
CarryOutAux(res, d[::-1], kappa)
385394

395+
def CarryOutLE(a, b, c=0):
396+
""" Little-endian version """
397+
import types
398+
res = types.sint()
399+
CarryOut(res, a[::-1], b[::-1], c)
400+
return res
401+
386402
def BitLTL(res, a, b, kappa):
387403
"""
388404
res = a <? b (logarithmic rounds version)

Compiler/config.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,10 @@
4747
'bittriple': 0.00004828818388140422,
4848
'bitgf2ntriple': 0.00020716801325875284,
4949
'PreMulC': 2 * 0.00020716801325875284,
50-
})
50+
}),
51+
'all': { 'round': 0,
52+
'inv': 0,
53+
}
5154
}
5255

5356

Compiler/floatingpoint.py

+43-18
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,11 @@ def BitDecField(a, k, m, kappa, bits_to_compute=None):
325325
def Pow2(a, l, kappa):
326326
m = int(ceil(log(l, 2)))
327327
t = BitDec(a, m, m, kappa)
328-
x = [types.sint() for i in range(m)]
328+
return Pow2_from_bits(t)
329+
330+
def Pow2_from_bits(bits):
331+
m = len(bits)
332+
t = list(bits)
329333
pow2k = [types.cint() for i in range(m)]
330334
for i in range(m):
331335
pow2k[i] = two_power(2**i)
@@ -353,13 +357,20 @@ def B2U_from_Pow2(pow2a, l, kappa):
353357
#print ' '.join(str(b.value) for b in y)
354358
return [1 - y[i] for i in range(l)]
355359

356-
def Trunc(a, l, m, kappa, compute_modulo=False):
360+
def Trunc(a, l, m, kappa, compute_modulo=False, signed=False):
357361
""" Oblivious truncation by secret m """
362+
if util.is_constant(m) and not compute_modulo:
363+
# cheaper
364+
res = type(a)(size=a.size)
365+
comparison.Trunc(res, a, l, m, kappa, signed=signed)
366+
return res
358367
if l == 1:
359368
if compute_modulo:
360369
return a * m, 1 + m
361370
else:
362371
return a * (1 - m)
372+
if program.Program.prog.options.ring and not compute_modulo:
373+
return TruncInRing(a, l, Pow2(m, l, kappa))
363374
r = [types.sint() for i in range(l)]
364375
r_dprime = types.sint(0)
365376
r_prime = types.sint(0)
@@ -370,8 +381,6 @@ def Trunc(a, l, m, kappa, compute_modulo=False):
370381
x, pow2m = B2U(m, l, kappa)
371382
#assert(pow2m.value == 2**m.value)
372383
#assert(sum(b.value for b in x) == m.value)
373-
if program.Program.prog.options.ring and not compute_modulo:
374-
return TruncInRing(a, l, pow2m)
375384
for i in range(l):
376385
bit(r[i])
377386
t1 = two_power(i) * r[i]
@@ -495,17 +504,28 @@ def TruncPrRing(a, k, m, signed=True):
495504
return comparison.TruncLeakyInRing(a, k, m, signed=signed)
496505
else:
497506
from types import sint
498-
# extra bit to mask overflow
499-
r_bits = [sint.get_random_bit() for i in range(k + 1)]
500-
n_shift = n_ring - len(r_bits)
501-
tmp = a + sint.bit_compose(r_bits)
502-
masked = (tmp << n_shift).reveal()
503-
shifted = (masked << 1 >> (n_shift + m + 1))
504-
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
505-
res = shifted - sint.bit_compose(r_bits[m:k]) + (overflow << (k - m))
507+
if signed:
508+
a += (1 << (k - 1))
509+
if program.Program.prog.use_trunc_pr:
510+
res = sint()
511+
trunc_pr(res, a, k, m)
512+
else:
513+
# extra bit to mask overflow
514+
r_bits = [sint.get_random_bit() for i in range(k + 1)]
515+
n_shift = n_ring - len(r_bits)
516+
tmp = a + sint.bit_compose(r_bits)
517+
masked = (tmp << n_shift).reveal()
518+
shifted = (masked << 1 >> (n_shift + m + 1))
519+
overflow = r_bits[-1].bit_xor(masked >> (n_ring - 1))
520+
res = shifted - sint.bit_compose(r_bits[m:k]) + \
521+
(overflow << (k - m))
522+
if signed:
523+
res -= (1 << (k - m - 1))
506524
return res
507525

508526
def TruncPrField(a, k, m, kappa=None):
527+
if m == 0:
528+
return a
509529
if kappa is None:
510530
kappa = 40
511531

@@ -527,19 +547,24 @@ def SDiv(a, b, l, kappa, round_nearest=False):
527547
w = types.cint(int(2.9142 * two_power(l))) - 2 * b
528548
x = alpha - b * w
529549
y = a * w
530-
y = y.round(2 * l + 1, l, kappa, round_nearest)
550+
y = y.round(2 * l + 1, l, kappa, round_nearest, signed=False)
531551
x2 = types.sint()
532552
comparison.Mod2m(x2, x, 2 * l + 1, l, kappa, False)
533553
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
534554
for i in range(theta-1):
535-
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
536-
y = y.round(2 * l + 1, l + 1, kappa, round_nearest)
537-
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest)
538-
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest)
555+
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
556+
round_nearest,
557+
signed=False)
558+
y = y.round(2 * l + 1, l + 1, kappa, round_nearest, signed=False)
559+
x = x1 * x2 + (x2**2).round(2 * l + 1, l + 1, kappa, round_nearest,
560+
signed=False)
561+
x = x1 * x1 + x.round(2 * l + 1, l - 1, kappa, round_nearest,
562+
signed=False)
539563
x2 = types.sint()
540564
comparison.Mod2m(x2, x, 2 * l, l, kappa, False)
541565
x1 = comparison.TruncZeroes(x - x2, 2 * l + 1, l, True)
542-
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa, round_nearest)
566+
y = y * (x1 + two_power(l)) + (y * x2).round(2 * l, l, kappa,
567+
round_nearest, signed=False)
543568
y = y.round(2 * l + 1, l - 1, kappa, round_nearest)
544569
return y
545570

Compiler/instructions.py

+65
Original file line numberDiff line numberDiff line change
@@ -894,6 +894,55 @@ def add_usage(self, req_node):
894894
req_node.increment((self.field_type, 'input', player), \
895895
4 * self.get_size())
896896

897+
@base.vectorize
898+
class inputmixed(base.TextInputInstruction):
899+
__slots__ = []
900+
code = base.opcodes['INPUTMIXED']
901+
field_type = 'modp'
902+
# the following has to match TYPE: (N_DEST, N_PARAM)
903+
types = {
904+
0: (1, 0),
905+
1: (1, 1),
906+
2: (4, 1)
907+
}
908+
type_ids = {
909+
'int': 0,
910+
'fix': 1,
911+
'float': 2
912+
}
913+
914+
def __init__(self, name, *args):
915+
try:
916+
type_id = self.type_ids[name]
917+
except:
918+
pass
919+
super(inputmixed_class, self).__init__(type_id, *args)
920+
921+
@property
922+
def arg_format(self):
923+
for i in self.bases():
924+
t = self.args[i]
925+
yield 'int'
926+
for j in range(self.types[t][0]):
927+
yield 'sw'
928+
for j in range(self.types[t][1]):
929+
yield 'int'
930+
yield 'p'
931+
932+
def bases(self):
933+
i = 0
934+
while i < len(self.args):
935+
yield i
936+
i += sum(self.types[self.args[i]]) + 2
937+
938+
def add_usage(self, req_node):
939+
for i in self.bases():
940+
t = self.args[i]
941+
player = self.args[i + sum(self.types[t]) + 1]
942+
n_dest = self.types[t][0]
943+
req_node.increment((self.field_type, 'input', player), \
944+
n_dest * self.get_size())
945+
897946
@base.gf2n
898947
class startinput(base.RawInputInstruction):
899948
r""" Receive inputs from player $p$. """
@@ -957,6 +1006,11 @@ class print_reg_plain(base.IOInstruction):
9571006
code = base.opcodes['PRINTREGPLAIN']
9581007
arg_format = ['c']
9591008

1009+
class cond_print_plain(base.IOInstruction):
1010+
r""" Conditionally print the value of a register. """
1011+
code = base.opcodes['CONDPRINTPLAIN']
1012+
arg_format = ['c', 'c']
1013+
9601014
class print_int(base.IOInstruction):
9611015
r""" Print only the value of register \verb|ci| to stdout. """
9621016
__slots__ = []
@@ -1383,6 +1437,9 @@ def get_repeat(self):
13831437

13841438
def merge_id(self):
13851439
# can merge different sizes
1440+
# but not if large
1441+
if self.get_size() > 100:
1442+
return type(self), self.get_size()
13861443
return type(self)
13871444

13881445
# def expand(self):
@@ -1468,6 +1525,14 @@ def get_used(self):
14681525
for reg in self.args[i + 2:i + self.args[i]]:
14691526
yield reg
14701527

1528+
@base.vectorize
1529+
class trunc_pr(base.VarArgsInstruction):
1530+
""" Probalistic truncation for semi-honest computation """
1531+
""" with honest majority """
1532+
__slots__ = []
1533+
code = base.opcodes['TRUNC_PR']
1534+
arg_format = tools.cycle(['sw','s','int','int'])
1535+
14711536
###
14721537
### CISC-style instructions
14731538
###

Compiler/instructions_base.py

+3-15
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@
8989
MULS = 0xA6,
9090
MULRS = 0xA7,
9191
DOTPRODS = 0xA8,
92+
TRUNC_PR = 0xA9,
9293
# Data access
9394
TRIPLE = 0x50,
9495
BIT = 0x51,
@@ -102,6 +103,7 @@
102103
INPUT = 0x60,
103104
INPUTFIX = 0xF0,
104105
INPUTFLOAT = 0xF1,
106+
INPUTMIXED = 0xF2,
105107
STARTINPUT = 0x61,
106108
STOPINPUT = 0x62,
107109
READSOCKETC = 0x63,
@@ -168,6 +170,7 @@
168170
READFILESHARE = 0xBE,
169171
CONDPRINTSTR = 0xBF,
170172
PRINTFLOATPREC = 0xE0,
173+
CONDPRINTPLAIN = 0xE1,
171174
GBITDEC = 0x184,
172175
GBITCOM = 0x185,
173176
# Secure socket
@@ -767,21 +770,6 @@ def check_args(self):
767770
### Jumps etc
768771
###
769772

770-
class dummywrite(Instruction):
771-
""" Dummy instruction to create source node in the dependency graph,
772-
preventing read-before-write warnings. """
773-
__slots__ = []
774-
775-
def __init__(self, *args, **kwargs):
776-
self.arg_format = [arg.reg_type + 'w' for arg in args]
777-
super(dummywrite, self).__init__(*args, **kwargs)
778-
779-
def execute(self):
780-
pass
781-
782-
def get_encoding(self):
783-
return []
784-
785773
class JumpInstruction(Instruction):
786774
__slots__ = ['jump_arg']
787775

0 commit comments

Comments
 (0)