Skip to content

Commit 6cc3fcc

Browse files
committed
Maintenance.
1 parent c62ab2c commit 6cc3fcc

File tree

135 files changed

+1658
-1062
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

135 files changed

+1658
-1062
lines changed

.gitmodules

-3
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,6 @@
11
[submodule "SimpleOT"]
22
path = deps/SimpleOT
33
url = https://github.com/mkskeller/SimpleOT
4-
[submodule "mpir"]
5-
path = deps/mpir
6-
url = https://github.com/wbhart/mpir
74
[submodule "Programs/Circuits"]
85
path = Programs/Circuits
96
url = https://github.com/mkskeller/bristol-fashion

BMR/RealProgramParty.h

-2
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,6 @@ class RealProgramParty : public ProgramPartySpec<T>
4343

4444
bool one_shot;
4545

46-
size_t data_sent;
47-
4846
public:
4947
static RealProgramParty& s();
5048

BMR/RealProgramParty.hpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,6 @@ RealProgramParty<T>::RealProgramParty(int argc, const char** argv) :
154154
while (next != GC::DONE_BREAK);
155155

156156
MC->Check(*P);
157-
data_sent = P->total_comm().sent;
158157

159158
if (online_opts.verbose)
160159
P->total_comm().print();
@@ -216,7 +215,7 @@ RealProgramParty<T>::~RealProgramParty()
216215
delete prep;
217216
delete garble_inputter;
218217
delete garble_protocol;
219-
cout << "Data sent = " << data_sent * 1e-6 << " MB" << endl;
218+
garble_machine.print_comm(*this->P, this->P->total_comm());
220219
T::MAC_Check::teardown();
221220
}
222221

BMR/Register.h

+7-4
Original file line numberDiff line numberDiff line change
@@ -62,11 +62,13 @@ class BaseKeyVector
6262
#endif
6363
};
6464
#else
65-
class BaseKeyVector : public vector<Key>
65+
class BaseKeyVector : public CheckVector<Key>
6666
{
67+
typedef CheckVector<Key> super;
68+
6769
public:
68-
BaseKeyVector(int size = 0) : vector<Key>(size, Key(0)) {}
69-
void resize(int size) { vector<Key>::resize(size, Key(0)); }
70+
BaseKeyVector(int size = 0) : super(size, Key(0)) {}
71+
void resize(int size) { super::resize(size, Key(0)); }
7072
};
7173
#endif
7274

@@ -296,7 +298,8 @@ class ProgramRegister : public Phase, public Register
296298
static void andm(GC::Processor<U>&, const BaseInstruction&)
297299
{ throw runtime_error("andm not implemented"); }
298300

299-
static void run_tapes(const vector<int>&) { throw not_implemented(); }
301+
static void run_tapes(const vector<int>&)
302+
{ throw runtime_error("multi-threading not implemented"); }
300303

301304
// most BMR phases don't need actual input
302305
template<class T>

CHANGELOG.md

+15
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,20 @@
11
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
22

3+
## 0.3.6 (May 9, 2023)
4+
5+
- More extensive benchmarking outputs
6+
- Replace MPIR by GMP
7+
- Secure reading of edaBits from files
8+
- Semi-honest client communication
9+
- Back-propagation for average pooling
10+
- Parallelized convolution
11+
- Probabilistic truncation as in ABY3
12+
- More balanced communication in Shamir secret sharing
13+
- Avoid unnecessary communication in Dealer protocol
14+
- Linear solver using Cholesky decomposition
15+
- Accept .py files for compilation
16+
- Fixed security bug: proper accounting for random elements
17+
318
## 0.3.5 (Feb 16, 2023)
419

520
- Easier-to-use machine learning interface

CONFIG

+20-1
Original file line numberDiff line numberDiff line change
@@ -35,15 +35,32 @@ ARM := $(shell uname -m | grep x86; echo $$?)
3535
OS := $(shell uname -s)
3636
ifeq ($(MACHINE), x86_64)
3737
ifeq ($(OS), Linux)
38+
ifeq ($(shell cat /proc/cpuinfo | grep -q avx2; echo $$?), 0)
3839
AVX_OT = 1
3940
else
4041
AVX_OT = 0
4142
endif
4243
else
44+
AVX_OT = 0
45+
endif
46+
else
4347
ARCH =
4448
AVX_OT = 0
4549
endif
4650

51+
ifeq ($(OS), Darwin)
52+
BREW_CFLAGS += -I/usr/local/opt/openssl/include -I`brew --prefix`/opt/openssl/include -I`brew --prefix`/include
53+
BREW_LDLIBS += -L/usr/local/opt/openssl/lib -L`brew --prefix`/lib -L`brew --prefix`/opt/openssl/lib
54+
endif
55+
56+
ifeq ($(OS), Linux)
57+
ifeq ($(ARM), 1)
58+
ifeq ($(shell cat /proc/cpuinfo | grep -q aes; echo $$?), 0)
59+
ARCH = -march=armv8.2-a+crypto
60+
endif
61+
endif
62+
endif
63+
4764
USE_KOS = 0
4865

4966
# allow to set compiler in CONFIG.mine
@@ -66,7 +83,8 @@ endif
6683
# Default for MAX_MOD_SZ is 10, which suffices for all Overdrive protocols
6784
# MOD = -DMAX_MOD_SZ=10 -DGFP_MOD_SZ=5
6885

69-
LDLIBS = -lmpirxx -lmpir -lsodium $(MY_LDLIBS)
86+
LDLIBS = -lgmpxx -lgmp -lsodium $(MY_LDLIBS)
87+
LDLIBS += $(BREW_LDLIBS)
7088
LDLIBS += -Wl,-rpath -Wl,$(CURDIR)/local/lib -L$(CURDIR)/local/lib
7189
LDLIBS += -lboost_system -lssl -lcrypto
7290

@@ -88,6 +106,7 @@ BOOST = -lboost_thread $(MY_BOOST)
88106
endif
89107

90108
CFLAGS += $(ARCH) $(MY_CFLAGS) $(GDEBUG) -Wextra -Wall $(OPTIM) -I$(ROOT) -I$(ROOT)/deps -pthread $(PROF) $(DEBUG) $(MOD) $(GF2N_LONG) $(PREP_DIR) $(SSL_DIR) $(SECURE) -std=c++11 -Werror
109+
CFLAGS += $(BREW_CFLAGS)
91110
CPPFLAGS = $(CFLAGS)
92111
LD = $(CXX)
93112

Compiler/GC/instructions.py

+2
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
class SecretBitsAF(base.RegisterArgFormat):
1919
reg_type = 'sb'
20+
name = 'sbit'
2021
class ClearBitsAF(base.RegisterArgFormat):
2122
reg_type = 'cb'
23+
name = 'cbit'
2224

2325
base.ArgFormats['sb'] = SecretBitsAF
2426
base.ArgFormats['sbw'] = SecretBitsAF

Compiler/allocator.py

+12-21
Original file line numberDiff line numberDiff line change
@@ -338,16 +338,19 @@ def add_edge(i, j):
338338
d[j] = d[i]
339339

340340
def read(reg, n):
341-
last_read[reg] = n
342341
for dup in reg.duplicates:
343-
if last_def[dup] != -1:
342+
if last_def[dup] not in (-1, n):
344343
add_edge(last_def[dup], n)
344+
last_read[reg] = n
345345

346346
def write(reg, n):
347-
last_def[reg] = n
348347
for dup in reg.duplicates:
349348
if last_read[dup] not in (-1, n):
350349
add_edge(last_read[dup], n)
350+
if id(dup) in [id(x) for x in block.instructions[n].get_used()] and \
351+
last_read[dup] not in (-1, n):
352+
add_edge(last_read[dup], n)
353+
last_def[reg] = n
351354

352355
def handle_mem_access(addr, reg_type, last_access_this_kind,
353356
last_access_other_kind):
@@ -434,19 +437,19 @@ def keep_text_order(inst, n):
434437
# if options.debug:
435438
# col = colordict[instr.__class__.__name__]
436439
# G.add_node(n, color=col, label=str(instr))
437-
for reg in inputs:
440+
for reg in outputs:
438441
if reg.vector and instr.is_vec():
439442
for i in reg.vector:
440-
read(i, n)
443+
write(i, n)
441444
else:
442-
read(reg, n)
445+
write(reg, n)
443446

444-
for reg in outputs:
447+
for reg in inputs:
445448
if reg.vector and instr.is_vec():
446449
for i in reg.vector:
447-
write(i, n)
450+
read(i, n)
448451
else:
449-
write(reg, n)
452+
read(reg, n)
450453

451454
# will be merged
452455
if isinstance(instr, TextInputInstruction):
@@ -556,18 +559,6 @@ def eliminate(i):
556559
if unused_result:
557560
eliminate(i)
558561
count += 1
559-
# remove unnecessary stack instructions
560-
# left by optimization with budget
561-
if isinstance(inst, popint_class) and \
562-
(not G.degree(i) or (G.degree(i) == 1 and
563-
isinstance(instructions[list(G[i])[0]], StackInstruction))) \
564-
and \
565-
inst.args[0].can_eliminate and \
566-
len(G.pred[i]) == 1 and \
567-
isinstance(instructions[list(G.pred[i])[0]], pushint_class):
568-
eliminate(list(G.pred[i])[0])
569-
eliminate(i)
570-
count += 2
571562
if count > 0 and self.block.parent.program.verbose:
572563
print('Eliminated %d dead instructions, among which %d opens: %s' \
573564
% (count, open_count, dict(stats)))

Compiler/comparison.py

+3
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ def set_variant(options):
5050
do_precomp = False
5151
elif variant is not None:
5252
raise CompilerError('Unknown comparison variant: %s' % variant)
53+
if const_rounds and instructions_base.program.options.binary:
54+
raise CompilerError(
55+
'Comparison variant choice incompatible with binary circuits')
5356

5457
def ld2i(c, n):
5558
""" Load immediate 2^n into clear GF(p) register c """

Compiler/compilerLib.py

+8-7
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ def __init__(self, custom_args=None, usage=None, execute=False):
2222
self.custom_args = custom_args
2323
self.build_option_parser()
2424
self.VARS = {}
25+
self.root = os.path.dirname(__file__) + '/..'
2526

2627
def build_option_parser(self):
2728
parser = OptionParser(usage=self.usage)
@@ -269,7 +270,7 @@ def build_program(self, name=None):
269270
self.prog = Program(self.args, self.options, name=name)
270271
if self.execute:
271272
if self.options.execute in \
272-
("emulate", "ring", "rep-field", "semi2k"):
273+
("emulate", "ring", "rep-field"):
273274
self.prog.use_trunc_pr = True
274275
if self.options.execute in ("ring",):
275276
self.prog.use_split(3)
@@ -405,7 +406,7 @@ def compile_file(self):
405406
infile = open(self.prog.infile)
406407

407408
# make compiler modules directly accessible
408-
sys.path.insert(0, "Compiler")
409+
sys.path.insert(0, "%s/Compiler" % self.root)
409410
# create the tapes
410411
exec(compile(infile.read(), infile.name, "exec"), self.VARS)
411412

@@ -477,15 +478,15 @@ def executable_from_protocol(protocol):
477478

478479
def local_execution(self, args=[]):
479480
executable = self.executable_from_protocol(self.options.execute)
480-
if not os.path.exists(executable):
481+
if not os.path.exists("%s/%s" % (self.root, executable)):
481482
print("Creating binary for virtual machine...")
482483
try:
483-
subprocess.run(["make", executable], check=True)
484+
subprocess.run(["make", executable], check=True, cwd=self.root)
484485
except:
485486
raise CompilerError(
486487
"Cannot produce %s. " % executable + \
487488
"Note that compilation requires a few GB of RAM.")
488-
vm = 'Scripts/%s.sh' % self.options.execute
489+
vm = "%s/Scripts/%s.sh" % (self.root, self.options.execute)
489490
os.execl(vm, vm, self.prog.name, *args)
490491

491492
def remote_execution(self, args=[]):
@@ -496,7 +497,7 @@ def remote_execution(self, args=[]):
496497
from fabric import Connection
497498
import subprocess
498499
print("Creating static binary for virtual machine...")
499-
subprocess.run(["make", "static/%s" % vm], check=True)
500+
subprocess.run(["make", "static/%s" % vm], check=True, cwd=self.root)
500501

501502
# transfer files
502503
import glob
@@ -519,7 +520,7 @@ def run(i):
519520
"mkdir -p %s/{Player-Data,Programs/{Bytecode,Schedules}} " % \
520521
dest)
521522
# executable
522-
connection.put("static/%s" % vm, dest)
523+
connection.put("%s/static/%s" % (self.root, vm), dest)
523524
# program
524525
dest += "/"
525526
connection.put("Programs/Schedules/%s.sch" % self.prog.name,

Compiler/floatingpoint.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -289,7 +289,7 @@ def BitDecRingRaw(a, k, m):
289289
def BitDecRing(a, k, m):
290290
bits = BitDecRingRaw(a, k, m)
291291
# reversing to reduce number of rounds
292-
return [types.sint.conv(bit) for bit in reversed(bits)][::-1]
292+
return [types.sintbit.conv(bit) for bit in reversed(bits)][::-1]
293293

294294
def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
295295
instructions_base.set_global_vector_size(a.size)
@@ -306,7 +306,7 @@ def BitDecFieldRaw(a, k, m, kappa, bits_to_compute=None):
306306

307307
def BitDecField(a, k, m, kappa, bits_to_compute=None):
308308
res = BitDecFieldRaw(a, k, m, kappa, bits_to_compute)
309-
return [types.sint.conv(bit) for bit in res]
309+
return [types.sintbit.conv(bit) for bit in res]
310310

311311

312312
@instructions_base.ret_cisc

Compiler/instructions.py

+23-8
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,17 @@ class reqbl(base.Instruction):
356356
code = base.opcodes['REQBL']
357357
arg_format = ['int']
358358

359+
class active(base.Instruction):
360+
""" Indicate whether program is compatible with malicious-security
361+
protocols.
362+
363+
:param: 0 for no, 1 for yes
364+
"""
365+
code = base.opcodes['ACTIVE']
366+
arg_format = ['int']
367+
359368
class time(base.IOInstruction):
369+
360370
""" Output time since start of computation. """
361371
code = base.opcodes['TIME']
362372
arg_format = []
@@ -2418,9 +2428,10 @@ def add_usage(self, req_node):
24182428
super(matmulsm, self).add_usage(req_node)
24192429
req_node.increment(('matmul', tuple(self.args[3:6])), 1)
24202430

2421-
class conv2ds(base.DataInstruction):
2431+
class conv2ds(base.DataInstruction, base.VarArgsInstruction, base.Mergeable):
24222432
""" Secret 2D convolution.
24232433
2434+
:param: number of arguments to follow (int)
24242435
:param: result (sint vector in row-first order)
24252436
:param: inputs (sint vector in row-first order)
24262437
:param: weights (sint vector in row-first order)
@@ -2436,10 +2447,12 @@ class conv2ds(base.DataInstruction):
24362447
:param: padding height (int)
24372448
:param: padding width (int)
24382449
:param: batch size (int)
2450+
:param: repeat from result...
2451+
24392452
"""
24402453
code = base.opcodes['CONV2DS']
2441-
arg_format = ['sw','s','s','int','int','int','int','int','int','int','int',
2442-
'int','int','int','int']
2454+
arg_format = itertools.cycle(['sw','s','s','int','int','int','int','int',
2455+
'int','int','int','int','int','int','int'])
24432456
data_type = 'triple'
24442457
is_vec = lambda self: True
24452458

@@ -2450,14 +2463,16 @@ def __init__(self, *args, **kwargs):
24502463
assert args[2].size == args[7] * args[8] * args[11]
24512464

24522465
def get_repeat(self):
2453-
return self.args[3] * self.args[4] * self.args[7] * self.args[8] * \
2454-
self.args[11] * self.args[14]
2466+
args = self.args
2467+
return sum(args[i+3] * args[i+4] * args[i+7] * args[i+8] * \
2468+
args[i+11] * args[i+14] for i in range(0, len(args), 15))
24552469

24562470
def add_usage(self, req_node):
24572471
super(conv2ds, self).add_usage(req_node)
2458-
args = self.args
2459-
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
2460-
args[14] * args[3] * args[4])), 1)
2472+
for i in range(0, len(self.args), 15):
2473+
args = self.args[i:i + 15]
2474+
req_node.increment(('matmul', (1, args[7] * args[8] * args[11],
2475+
args[14] * args[3] * args[4])), 1)
24612476

24622477
@base.vectorize
24632478
class trunc_pr(base.VarArgsInstruction):

0 commit comments

Comments
 (0)