Skip to content

Commit 5153c63

Browse files
committed
More accessible machine learning functionality.
1 parent 7266f3b commit 5153c63

File tree

119 files changed

+3857
-969
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

119 files changed

+3857
-969
lines changed

BMR/Register.h

+2
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,8 @@ class ProgramRegister : public Phase, public Register
296296
static void andm(GC::Processor<U>&, const BaseInstruction&)
297297
{ throw runtime_error("andm not implemented"); }
298298

299+
static void run_tapes(const vector<int>&) { throw not_implemented(); }
300+
299301
// most BMR phases don't need actual input
300302
template<class T>
301303
static T get_input(GC::Processor<T>& processor, const InputArgs& args)

CHANGELOG.md

+12
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,17 @@
11
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
22

3+
## 0.3.5 (Feb 16, 2023)
4+
5+
- Easier-to-use machine learning interface
6+
- Integrated compilation-execution facility
7+
- Import/export sequential models and parameters from/to PyTorch
8+
- Binary-format input files
9+
- Less aggressive round optimization for faster compilation by default
10+
- Multithreading with client interface
11+
- Functionality to protect order of specific memory accesses
12+
- Oblivious transfer works again on older (pre-2011) x86 CPUs
13+
- clang is used by default
14+
315
## 0.3.4 (Nov 9, 2022)
416

517
- Decision tree learning

CONFIG

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ endif
4747
USE_KOS = 0
4848

4949
# allow to set compiler in CONFIG.mine
50-
CXX = g++
50+
CXX = clang++
5151

5252
# use CONFIG.mine to overwrite DIR settings
5353
-include CONFIG.mine

Compiler/GC/types.py

+21-14
Original file line numberDiff line numberDiff line change
@@ -711,23 +711,31 @@ def n_elements():
711711
def mem_size():
712712
return n
713713
@classmethod
714-
def get_input_from(cls, player):
714+
def get_input_from(cls, player, size=1, f=0):
715715
""" Secret input from :py:obj:`player`. The input is decomposed
716716
into bits.
717717
718718
:param: player (int)
719719
"""
720+
v = [0] * n
720721
sbits._check_input_player(player)
721-
res = cls.from_vec(sbit() for i in range(n))
722-
inst.inputbvec(n + 3, 0, player, *res.v)
723-
return res
722+
instructions_base.check_vector_size(size)
723+
for i in range(size):
724+
vv = [sbit() for i in range(n)]
725+
inst.inputbvec(n + 3, f, player, *vv)
726+
for j in range(n):
727+
tmp = vv[j] << i
728+
v[j] = tmp ^ v[j]
729+
sbits._check_input_player(player)
730+
return cls.from_vec(v)
724731
get_raw_input_from = get_input_from
725732
@classmethod
726733
def from_vec(cls, vector):
727734
res = cls()
728735
res.v = _complement_two_extend(list(vector), n)[:n]
729736
return res
730737
def __init__(self, other=None, size=None):
738+
instructions_base.check_vector_size(size)
731739
if other is not None:
732740
if util.is_constant(other):
733741
t = sbits.get_type(size or 1)
@@ -1148,6 +1156,9 @@ class sbitint(_bitint, _number, sbits, _sbitintbase):
11481156
mul: 15
11491157
lt: 0
11501158
1159+
This class is retained for compatibility, but development now
1160+
focuses on :py:class:`sbitintvec`.
1161+
11511162
"""
11521163
n_bits = None
11531164
bin_type = None
@@ -1347,9 +1358,12 @@ def output(self):
13471358
cbits(0), cbits(0))
13481359

13491360
class sbitfix(_fix):
1350-
""" Secret signed integer in one binary register.
1361+
""" Secret signed fixed-point number in one binary register.
13511362
Use :py:obj:`set_precision()` to change the precision.
13521363
1364+
This class is retained for compatibility, but development now
1365+
focuses on :py:class:`sbitfixvec`.
1366+
13531367
Example::
13541368
13551369
print_ln('add: %s', (sbitfix(0.5) + sbitfix(0.3)).reveal())
@@ -1453,15 +1467,8 @@ def get_input_from(cls, player, size=1):
14531467
14541468
:param: player (int)
14551469
"""
1456-
v = [0] * sbitfix.k
1457-
sbits._check_input_player(player)
1458-
for i in range(size):
1459-
vv = [sbit() for i in range(sbitfix.k)]
1460-
inst.inputbvec(len(v) + 3, sbitfix.f, player, *vv)
1461-
for j in range(sbitfix.k):
1462-
tmp = vv[j] << i
1463-
v[j] = tmp ^ v[j]
1464-
return cls._new(cls.int_type.from_vec(v))
1470+
return cls._new(cls.int_type.get_input_from(player, size=size,
1471+
f=sbitfix.f))
14651472
def __init__(self, value=None, *args, **kwargs):
14661473
if isinstance(value, (list, tuple)):
14671474
self.v = self.int_type.from_vec(sbitvec([x.v for x in value]))

Compiler/allocator.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,6 @@ def dependency_graph(self, merge_classes):
315315
last_def = defaultdict_by_id(lambda: -1)
316316
last_mem_write = []
317317
last_mem_read = []
318-
warned_about_mem = []
319318
last_mem_write_of = defaultdict(list)
320319
last_mem_read_of = defaultdict(list)
321320
last_print_str = None
@@ -364,20 +363,22 @@ def mem_access(n, instr, last_access_this_kind, last_access_other_kind):
364363
addr_i = addr + i
365364
handle_mem_access(addr_i, reg_type, last_access_this_kind,
366365
last_access_other_kind)
367-
if block.warn_about_mem and not warned_about_mem and \
368-
(instr.get_size() > 100):
366+
if block.warn_about_mem and \
367+
not block.parent.warned_about_mem and \
368+
(instr.get_size() > 100) and not instr._protect:
369369
print('WARNING: Order of memory instructions ' \
370370
'not preserved due to long vector, errors possible')
371-
warned_about_mem.append(True)
371+
block.parent.warned_about_mem = True
372372
else:
373373
handle_mem_access(addr, reg_type, last_access_this_kind,
374374
last_access_other_kind)
375-
if block.warn_about_mem and not warned_about_mem and \
376-
not isinstance(instr, DirectMemoryInstruction):
375+
if block.warn_about_mem and \
376+
not block.parent.warned_about_mem and \
377+
not isinstance(instr, DirectMemoryInstruction) and \
378+
not instr._protect:
377379
print('WARNING: Order of memory instructions ' \
378380
'not preserved, errors possible')
379-
# hack
380-
warned_about_mem.append(True)
381+
block.parent.warned_about_mem = True
381382

382383
def strict_mem_access(n, last_this_kind, last_other_kind):
383384
if last_other_kind and last_this_kind and \
@@ -473,14 +474,14 @@ def keep_text_order(inst, n):
473474
depths[n] = depth
474475

475476
if isinstance(instr, ReadMemoryInstruction):
476-
if options.preserve_mem_order:
477+
if options.preserve_mem_order or instr._protect:
477478
strict_mem_access(n, last_mem_read, last_mem_write)
478-
else:
479+
elif not options.preserve_mem_order:
479480
mem_access(n, instr, last_mem_read_of, last_mem_write_of)
480481
elif isinstance(instr, WriteMemoryInstruction):
481-
if options.preserve_mem_order:
482+
if options.preserve_mem_order or instr._protect:
482483
strict_mem_access(n, last_mem_write, last_mem_read)
483-
else:
484+
elif not options.preserve_mem_order:
484485
mem_access(n, instr, last_mem_write_of, last_mem_read_of)
485486
elif isinstance(instr, matmulsm):
486487
if options.preserve_mem_order:
@@ -495,7 +496,7 @@ def keep_text_order(inst, n):
495496
add_edge(last_print_str, n)
496497
last_print_str = n
497498
elif isinstance(instr, PublicFileIOInstruction):
498-
keep_order(instr, n, instr.__class__)
499+
keep_order(instr, n, PublicFileIOInstruction)
499500
elif isinstance(instr, prep_class):
500501
keep_order(instr, n, instr.args[0])
501502
elif isinstance(instr, StackInstruction):
@@ -586,7 +587,7 @@ class RegintOptimizer:
586587
def __init__(self):
587588
self.cache = util.dict_by_id()
588589

589-
def run(self, instructions):
590+
def run(self, instructions, program):
590591
for i, inst in enumerate(instructions):
591592
if isinstance(inst, ldint_class):
592593
self.cache[inst.args[0]] = inst.args[1]
@@ -601,6 +602,7 @@ def run(self, instructions):
601602
elif isinstance(inst, IndirectMemoryInstruction):
602603
if inst.args[1] in self.cache:
603604
instructions[i] = inst.get_direct(self.cache[inst.args[1]])
605+
instructions[i]._protect = inst._protect
604606
elif type(inst) == convint_class:
605607
if inst.args[1] in self.cache:
606608
res = self.cache[inst.args[1]]
@@ -614,4 +616,13 @@ def run(self, instructions):
614616
if op == 0:
615617
instructions[i] = ldsi(inst.args[0], 0,
616618
add_to_prog=False)
619+
elif isinstance(inst, (crash, cond_print_str, cond_print_plain)):
620+
if inst.args[0] in self.cache:
621+
cond = self.cache[inst.args[0]]
622+
if not cond:
623+
instructions[i] = None
624+
pre = len(instructions)
617625
instructions[:] = list(filter(lambda x: x is not None, instructions))
626+
post = len(instructions)
627+
if pre != post and program.options.verbose:
628+
print('regint optimizer removed %d instructions' % (pre - post))

0 commit comments

Comments
 (0)