Skip to content

Commit cd25c2e

Browse files
committed
Decision tree training.
1 parent 9033656 commit cd25c2e

File tree

187 files changed

+2356
-328
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

187 files changed

+2356
-328
lines changed

BMR/Register.h

+3
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,9 @@ class Phase
235235
template <class T>
236236
static void ands(T& processor, const vector<int>& args) { processor.ands(args); }
237237
template <class T>
238+
static void andrsvec(T& processor, const vector<int>& args)
239+
{ processor.andrsvec(args); }
240+
template <class T>
238241
static void xors(T& processor, const vector<int>& args) { processor.xors(args); }
239242
template <class T>
240243
static void inputb(T& processor, const vector<int>& args) { processor.input(args); }

CHANGELOG.md

+11
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,16 @@
11
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
22

3+
## 0.3.4 (Nov 9, 2022)
4+
5+
- Decision tree learning
6+
- Optimized oblivious shuffle in Rep3
7+
- Optimized daBit generation in Rep3 and semi-honest HE-based 2PC
8+
- Optimized element-vector AND in SemiBin
9+
- Optimized input protocol in Shamir-based protocols
10+
- Square-root ORAM (@Quitlox)
11+
- Improved ORAM in binary circuits
12+
- UTF-8 outputs
13+
314
## 0.3.3 (Aug 25, 2022)
415

516
- Use SoftSpokenOT to avoid unclear security of KOS OT extension candidate

CONFIG

+3-1
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,11 @@ endif
6767
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5
6868

6969
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
70+
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
7071
LDLIBS += -lboost_system -lssl -lcrypto
7172

73+
CFLAGS += -I./local/include
74+
7275
ifeq ($(USE_NTL),1)
7376
CFLAGS += -DUSE_NTL
7477
LDLIBS := -lntl $(LDLIBS)
@@ -100,5 +103,4 @@ ifeq ($(USE_KOS),1)
100103
CFLAGS += -DUSE_KOS
101104
else
102105
CFLAGS += -std=c++17
103-
LDLIBS += -llibOTe -lcryptoTools
104106
endif

Compiler/GC/instructions.py

+49
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import Compiler.tools as tools
1414
import collections
1515
import itertools
16+
import math
1617

1718
class SecretBitsAF(base.RegisterArgFormat):
1819
reg_type = 'sb'
@@ -50,6 +51,7 @@ class ClearBitsAF(base.RegisterArgFormat):
5051
INPUTBVEC = 0x247,
5152
SPLIT = 0x248,
5253
CONVCBIT2S = 0x249,
54+
ANDRSVEC = 0x24a,
5355
XORCBI = 0x210,
5456
BITDECC = 0x211,
5557
NOTCB = 0x212,
@@ -155,6 +157,52 @@ class andrs(BinaryVectorInstruction):
155157

156158
def add_usage(self, req_node):
157159
req_node.increment(('bit', 'triple'), sum(self.args[::4]))
160+
req_node.increment(('bit', 'mixed'),
161+
sum(int(math.ceil(x / 64)) for x in self.args[::4]))
162+
163+
class andrsvec(base.VarArgsInstruction, base.Mergeable,
164+
base.DynFormatInstruction):
165+
""" Constant-vector AND of secret bit registers (vectorized version).
166+
167+
:param: total number of arguments to follow (int)
168+
:param: number of arguments to follow for one operation /
169+
operation vector size plus three (int)
170+
:param: vector size (int)
171+
:param: result vector (sbit)
172+
:param: (repeat)...
173+
:param: constant operand (sbits)
174+
:param: vector operand
175+
:param: (repeat)...
176+
:param: (repeat from number of arguments to follow for one operation)...
177+
178+
"""
179+
code = opcodes['ANDRSVEC']
180+
181+
def __init__(self, *args, **kwargs):
182+
super(andrsvec, self).__init__(*args, **kwargs)
183+
for i, n in self.bases(iter(self.args)):
184+
size = self.args[i + 1]
185+
for x in self.args[i + 2:i + n]:
186+
assert x.n == size
187+
188+
@classmethod
189+
def dynamic_arg_format(cls, args):
190+
yield 'int'
191+
for i, n in cls.bases(args):
192+
yield 'int'
193+
n_args = (n - 3) // 2
194+
assert n_args > 0
195+
for j in range(n_args):
196+
yield 'sbw'
197+
for j in range(n_args + 1):
198+
yield 'sb'
199+
yield 'int'
200+
201+
def add_usage(self, req_node):
202+
for i, n in self.bases(iter(self.args)):
203+
size = self.args[i + 1]
204+
req_node.increment(('bit', 'triple'), size * (n - 3) // 2)
205+
req_node.increment(('bit', 'mixed'), size)
158206

159207
class ands(BinaryVectorInstruction):
160208
""" Bitwise AND of secret bit register vector.
@@ -605,6 +653,7 @@ def dynamic_arg_format(cls, args):
605653
for i, n in cls.bases(args):
606654
yield 'int'
607655
yield 'p'
656+
assert n > 3
608657
for j in range(n - 3):
609658
yield 'sbw'
610659
yield 'int'

Compiler/GC/types.py

+49-10
Original file line numberDiff line numberDiff line change
@@ -652,7 +652,7 @@ class sbitvec(_vec, _bit):
652652
You can access the rows by member :py:obj:`v` and the columns by calling
653653
:py:obj:`elements`.
654654
655-
There are three ways to create an instance:
655+
There are four ways to create an instance:
656656
657657
1. By transposition::
658658
@@ -685,6 +685,11 @@ class sbitvec(_vec, _bit):
685685
This should output::
686686
687687
[1, 0, 1]
688+
689+
4. Private input::
690+
691+
x = sbitvec.get_type(32).get_input_from(player)
692+
688693
"""
689694
bit_extend = staticmethod(lambda v, n: v[:n] + [0] * (n - len(v)))
690695
is_clear = False
@@ -904,6 +909,34 @@ def half_adder(self, other):
904909
def __mul__(self, other):
905910
if isinstance(other, int):
906911
return self.from_vec(x * other for x in self.v)
912+
if isinstance(other, sbitvec):
913+
if len(other.v) == 1:
914+
other = other.v[0]
915+
elif len(self.v) == 1:
916+
self, other = other, self.v[0]
917+
else:
918+
raise CompilerError('no operand of lenght 1: %d/%d',
919+
(len(self.v), len(other.v)))
920+
if not isinstance(other, sbits):
921+
return NotImplemented
922+
ops = []
923+
for x in self.v:
924+
if not util.is_zero(x):
925+
assert x.n == other.n
926+
ops.append(x)
927+
if ops:
928+
prods = [sbits.get_type(other.n)() for i in ops]
929+
inst.andrsvec(3 + 2 * len(ops), other.n, *prods, other, *ops)
930+
res = []
931+
i = 0
932+
for x in self.v:
933+
if util.is_zero(x):
934+
res.append(0)
935+
else:
936+
res.append(prods[i])
937+
i += 1
938+
return sbitvec.from_vec(res)
939+
__rmul__ = __mul__
907940
def __add__(self, other):
908941
return self.from_vec(x + y for x, y in zip(self.v, other))
909942
def bit_and(self, other):
@@ -945,6 +978,13 @@ def expand(self, other, expand=True):
945978
else:
946979
res.append([x.expand(m) if (expand and isinstance(x, bits)) else x for x in y.v])
947980
return res
981+
def demux(self):
982+
if len(self) == 1:
983+
return sbitvec.from_vec([self.v[0].bit_not(), self.v[0]])
984+
a = sbitvec.from_vec(self.v[:len(self) // 2]).demux()
985+
b = sbitvec.from_vec(self.v[len(self) // 2:]).demux()
986+
prod = [a * bb for bb in b.v]
987+
return sbitvec.from_vec(reduce(operator.add, (x.v for x in prod)))
948988

949989
class bit(object):
950990
n = 1
@@ -1243,20 +1283,19 @@ def __mul__(self, other):
12431283
return other * self.v[0]
12441284
elif isinstance(other, sbitfixvec):
12451285
return NotImplemented
1246-
_, other_bits = self.expand(other, False)
1286+
my_bits, other_bits = self.expand(other, False)
1287+
matrix = []
12471288
m = float('inf')
1248-
for x in itertools.chain(self.v, other_bits):
1289+
for x in itertools.chain(my_bits, other_bits):
12491290
try:
12501291
m = min(m, x.n)
12511292
except:
12521293
pass
1253-
if m == 1:
1254-
op = operator.mul
1255-
else:
1256-
op = operator.and_
1257-
matrix = []
12581294
for i, b in enumerate(other_bits):
1259-
matrix.append([op(x, b) for x in self.v[:len(self.v)-i]])
1295+
if m == 1:
1296+
matrix.append([x * b for x in my_bits[:len(self.v)-i]])
1297+
else:
1298+
matrix.append((sbitvec.from_vec(my_bits[:len(self.v)-i]) * b).v)
12601299
v = sbitint.wallace_tree_from_matrix(matrix)
12611300
return self.from_vec(v[:len(self.v)])
12621301
__rmul__ = __mul__
@@ -1366,7 +1405,7 @@ class cls(_fix):
13661405
cls.set_precision(f, k)
13671406
return cls._new(cls.int_type(other), k, f)
13681407

1369-
class sbitfixvec(_fix):
1408+
class sbitfixvec(_fix, _vec):
13701409
""" Vector of fixed-point numbers for parallel binary computation.
13711410
13721411
Use :py:obj:`set_precision()` to change the precision.

Compiler/allocator.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -261,6 +261,7 @@ def longest_paths_merge(self):
261261
instructions = self.instructions
262262
merge_nodes = self.open_nodes
263263
depths = self.depths
264+
self.req_num = defaultdict(lambda: 0)
264265
if not merge_nodes:
265266
return 0
266267

@@ -281,6 +282,7 @@ def longest_paths_merge(self):
281282
print('Merging %d %s in round %d/%d' % \
282283
(len(merge), t.__name__, i, len(merges)))
283284
self.do_merge(merge)
285+
self.req_num[t.__name__, 'round'] += 1
284286

285287
preorder = None
286288

@@ -530,7 +532,9 @@ def eliminate_dead_code(self):
530532
can_eliminate_defs = True
531533
for reg in inst.get_def():
532534
for dup in reg.duplicates:
533-
if not dup.can_eliminate:
535+
if not (dup.can_eliminate and reduce(
536+
operator.and_,
537+
(x.can_eliminate for x in dup.vector), True)):
534538
can_eliminate_defs = False
535539
break
536540
# remove if instruction has result that isn't used

Compiler/circuit.py

-2
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,6 @@ def sha3_256(x):
137137
0x4a43f8804b0ad882fa493be44dff80f562d661a05647c15166d71ebff8c6ffa7
138138
0xf0d7aa0ab2d92d580bb080e17cbb52627932ba37f085d3931270d31c39357067
139139
140-
Note that :py:obj:`sint` to :py:obj:`sbitvec` conversion is only
141-
implemented for computation modulo a power of two.
142140
"""
143141

144142
global Keccak_f

Compiler/circuit_oram.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11

2-
from Compiler.path_oram import *
2+
from Compiler.oram import *
3+
from Compiler.path_oram import PathORAM, XOR
34
from Compiler.util import bit_compose
45

56
def first_diff(a_bits, b_bits):

Compiler/compilerLib.py

+7
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,13 @@ def build_option_parser(self):
125125
default=defaults.binary,
126126
help="bit length of sint in binary circuit (default: 0 for arithmetic)",
127127
)
128+
parser.add_option(
129+
"-G",
130+
"--garbled-circuit",
131+
dest="garbled",
132+
action="store_true",
133+
help="compile for binary circuits only (default: false)",
134+
)
128135
parser.add_option(
129136
"-F",
130137
"--field",

0 commit comments

Comments
 (0)