Skip to content

Commit 91321ff

Browse files
committed
Functionality to call high-level code from C++.
1 parent e7554cc commit 91321ff

File tree

245 files changed

+3868
-1132
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

245 files changed

+3868
-1132
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ callgrind.out.*
4949
Programs/Bytecode/*
5050
Programs/Schedules/*
5151
Programs/Public-Input/*
52+
Programs/Functions
5253
*.com
5354
*.class
5455
*.dll

.gitmodules

+3
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,6 @@
1313
[submodule "deps/SimplestOT_C"]
1414
path = deps/SimplestOT_C
1515
url = https://github.com/mkskeller/SimplestOT_C
16+
[submodule "deps/sse2neon"]
17+
path = deps/sse2neon
18+
url = https://github.com/DLTcollab/sse2neon

CHANGELOG.md

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
The changelog explains changes pulled through from the private development repository. Bug fixes and small enhancements are committed between releases and not documented here.
22

3+
## 0.4.0 (November 21, 2024)
4+
5+
- Functionality to call high-level code from C++
6+
- Matrix triples from file for all appropriate protocols
7+
- Exit with message on errors instead of uncaught exceptions
8+
- Reduce memory usage for binary memory
9+
- Optimized cint-regint conversion in Dealer protocol
10+
- Fixed security bug: missing MAC check in probabilistic truncation
11+
312
## 0.3.9 (July 9, 2024)
413

514
- Inference with non-sequential PyTorch networks

CONFIG

+2
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,8 @@ CXX = clang++
7171
# use CONFIG.mine to overwrite DIR settings
7272
-include CONFIG.mine
7373

74+
AVX_SIMPLEOT := $(AVX_OT)
75+
7476
ifeq ($(USE_GF2N_LONG),1)
7577
GF2N_LONG = -DUSE_GF2N_LONG
7678
endif

Compiler/GC/instructions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -540,8 +540,8 @@ class split(base.Instruction):
540540
541541
:param: number of arguments to follow (number of bits times number of additive shares plus one)
542542
:param: source (sint)
543-
:param: first share of least significant bit
544-
:param: second share of least significant bit
543+
:param: first share of least significant bit (sbit)
544+
:param: second share of least significant bit (sbit)
545545
:param: (remaining share of least significant bit)...
546546
:param: (repeat from first share for bit one step higher)...
547547
"""

Compiler/GC/types.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -737,6 +737,7 @@ def get_type(cls, n):
737737
:py:obj:`v` and the columns by calling :py:obj:`elements`.
738738
"""
739739
class sbitvecn(cls, _structure):
740+
n_bits = n
740741
@staticmethod
741742
def get_type(n):
742743
return cls.get_type(n)
@@ -757,17 +758,19 @@ def get_input_from(cls, player, size=1, f=0):
757758
758759
:param: player (int)
759760
"""
760-
v = [0] * n
761761
sbits._check_input_player(player)
762762
instructions_base.check_vector_size(size)
763-
for i in range(size):
764-
vv = [sbit() for i in range(n)]
765-
inst.inputbvec(n + 3, f, player, *vv)
766-
for j in range(n):
767-
tmp = vv[j] << i
768-
v[j] = tmp ^ v[j]
769-
sbits._check_input_player(player)
770-
return cls.from_vec(v)
763+
if size == 1:
764+
res = cls.from_vec(sbit() for i in range(n))
765+
inst.inputbvec(n + 3, f, player, *res.v)
766+
return res
767+
else:
768+
elements = []
769+
for i in range(size):
770+
v = sbits.get_type(n)()
771+
inst.inputb(player, n, f, v)
772+
elements.append(v)
773+
return cls(elements)
771774
get_raw_input_from = get_input_from
772775
@classmethod
773776
def from_vec(cls, vector):

Compiler/allocator.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,8 @@ def alloc_reg(self, reg, free):
178178
dup = dup.vectorbase
179179
self.alloc[dup] = self.alloc[base]
180180
dup.i = self.alloc[base]
181+
if not dup.dup_count:
182+
dup.dup_count = len(base.duplicates)
181183

182184
def dealloc_reg(self, reg, inst, free):
183185
if reg.vector:
@@ -275,8 +277,9 @@ def finalize(self, options):
275277
for reg in self.alloc:
276278
for x in reg.get_all():
277279
if x not in self.dealloc and reg not in self.dealloc \
278-
and len(x.duplicates) == 0:
279-
print('Warning: read before write at register', x)
280+
and len(x.duplicates) == x.dup_count:
281+
print('Warning: read before write at register %s/%x' %
282+
(x, id(x)))
280283
print('\tregister trace: %s' % format_trace(x.caller,
281284
'\t\t'))
282285
if options.stop:
@@ -750,6 +753,8 @@ def eliminate(i):
750753
G.remove_node(i)
751754
merge_nodes.discard(i)
752755
stats[type(instructions[i]).__name__] += 1
756+
for reg in instructions[i].get_def():
757+
self.block.parent.program.base_addresses.pop(reg)
753758
instructions[i] = None
754759
if unused_result:
755760
eliminate(i)

Compiler/compilerLib.py

+10
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,18 @@
1313

1414

1515
class Compiler:
16+
singleton = None
17+
1618
def __init__(self, custom_args=None, usage=None, execute=False,
1719
split_args=False):
20+
if Compiler.singleton:
21+
raise CompilerError(
22+
"Cannot have more than one compiler instance. "
23+
"It's not possible to run direct compilation programs with "
24+
"compile.py or compile-run.py.")
25+
else:
26+
Compiler.singleton = self
27+
1828
if usage:
1929
self.usage = usage
2030
else:

Compiler/dijkstra.py

+49-11
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
""" This module implements `Dijkstra's algorithm
2+
<https://en.wikipedia.org/wiki/Dijkstra%27s_algorithm>`_ based on
3+
oblivious RAM. """
4+
5+
16
from Compiler.oram import *
27

38
from Compiler.program import Program
@@ -222,7 +227,21 @@ def dump(self, msg=''):
222227
print_ln()
223228
print_ln()
224229

225-
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None):
230+
def dijkstra(source, edges, e_index, oram_type, n_loops=None, int_type=None,
231+
debug=False):
232+
""" Securely compute Dijstra's algorithm on a secret graph. See
233+
:download:`../Programs/Source/dijkstra_example.mpc` for an
234+
explanation of the required inputs.
235+
236+
:param source: source node (secret or clear-text integer)
237+
:param edges: ORAM representation of edges
238+
:param e_index: ORAM representation of vertices
239+
:param oram_type: ORAM type to use internally (default:
240+
:py:func:`~Compiler.oram.OptimalORAM`)
241+
:param n_loops: when to stop (default: number of edges)
242+
:param int_type: secret integer type (default: sint)
243+
244+
"""
226245
vert_loops = n_loops * e_index.size // edges.size \
227246
if n_loops else -1
228247
dist = oram_type(e_index.size, entry_size=(32,log2(e_index.size)), \
@@ -267,27 +286,46 @@ def f(i):
267286
dist.access(v, (basic_type(alt), u), is_shorter)
268287
#previous.access(v, u, is_shorter)
269288
Q.update(v, basic_type(alt), is_shorter)
270-
print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s', \
271-
u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(), \
272-
not_visited.reveal())
289+
if debug:
290+
print_ln('u: %s, v: %s, alt: %s, dv: %s, first visit: %s, '
291+
'shorter: %s, running: %s, queue size: %s, last edge: %s',
292+
u.reveal(), v.reveal(), alt.reveal(), dv[0].reveal(),
293+
not_visited.reveal(), is_shorter.reveal(),
294+
running.reveal(), Q.size.reveal(), last_edge.reveal())
273295
return dist
274296

275297
def convert_graph(G):
298+
""" Convert a `NetworkX directed graph
299+
<https://networkx.org/documentation/stable/reference/classes/digraph.html>`_
300+
to the cleartext representation of what :py:func:`dijkstra` expects. """
301+
G = G.copy()
302+
for u in G:
303+
for v in G[u]:
304+
G[u][v].setdefault('weight', 1)
276305
edges = [None] * (2 * G.size())
277306
e_index = [None] * (len(G))
278307
i = 0
279-
for v in G:
308+
for v in sorted(G):
280309
e_index[v] = i
281-
for u in G[v]:
310+
for u in sorted(G[v]):
282311
edges[i] = [u, G[v][u]['weight'], 0]
283312
i += 1
313+
if not G[v]:
314+
edges[i] = [v, 0, 0]
315+
i += 1
284316
edges[i-1][-1] = 1
285-
return edges, e_index
317+
return list(filter(lambda x: x, edges)), e_index
286318

287-
def test_dijkstra(G, source, oram_type=ORAM, n_loops=None, int_type=sint):
288-
for u in G:
289-
for v in G[u]:
290-
G[u][v].setdefault('weight', 1)
319+
def test_dijkstra(G, source, oram_type=ORAM, n_loops=None,
320+
int_type=sint):
321+
""" Securely compute Dijstra's algorithm on a cleartext graph.
322+
323+
:param G: directed graph with NetworkX interface
324+
:param source: source node (secret or clear-text integer)
325+
:param n_loops: when to stop (default: number of edges)
326+
:param int_type: secret integer type (default: sint)
327+
328+
"""
291329
edges_list, e_index_list = convert_graph(G)
292330
edges = oram_type(len(edges_list), \
293331
entry_size=(log2(len(G)), log2(len(G)), 1), \

Compiler/instructions.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ class stop(base.Instruction):
399399
arg_format = ['i']
400400

401401
class use(base.Instruction):
402-
""" Offline data usage. Necessary to avoid reusage while using
402+
r""" Offline data usage. Necessary to avoid reusage while using
403403
preprocessing from files. Also used to multithreading for expensive
404404
preprocessing.
405405
@@ -419,7 +419,7 @@ def get_usage(cls, args):
419419
args[2].i}
420420

421421
class use_inp(base.Instruction):
422-
""" Input usage. Necessary to avoid reusage while using
422+
r""" Input usage. Necessary to avoid reusage while using
423423
preprocessing from files.
424424
425425
:param: domain (0: integer, 1: :math:`\mathrm{GF}(2^n)`, 2: bit)
@@ -1738,7 +1738,7 @@ class print_reg_plains(base.IOInstruction):
17381738
arg_format = ['s']
17391739

17401740
class cond_print_plain(base.IOInstruction):
1741-
""" Conditionally output clear register (with precision).
1741+
r""" Conditionally output clear register (with precision).
17421742
Outputs :math:`x \cdot 2^p` where :math:`p` is the precision.
17431743
17441744
:param: condition (cint, no output if zero)
@@ -1989,7 +1989,7 @@ class closeclientconnection(base.IOInstruction):
19891989
code = base.opcodes['CLOSECLIENTCONNECTION']
19901990
arg_format = ['ci']
19911991

1992-
class writesharestofile(base.IOInstruction):
1992+
class writesharestofile(base.VectorInstruction, base.IOInstruction):
19931993
""" Write shares to ``Persistence/Transactions-P<playerno>.data``
19941994
(appending at the end).
19951995
@@ -2002,11 +2002,12 @@ class writesharestofile(base.IOInstruction):
20022002
__slots__ = []
20032003
code = base.opcodes['WRITEFILESHARE']
20042004
arg_format = tools.chain(['ci'], itertools.repeat('s'))
2005+
vector_index = 1
20052006

20062007
def has_var_args(self):
20072008
return True
20082009

2009-
class readsharesfromfile(base.IOInstruction):
2010+
class readsharesfromfile(base.VectorInstruction, base.IOInstruction):
20102011
""" Read shares from ``Persistence/Transactions-P<playerno>.data``.
20112012
20122013
:param: number of arguments to follow / number of shares plus two (int)
@@ -2018,6 +2019,7 @@ class readsharesfromfile(base.IOInstruction):
20182019
__slots__ = []
20192020
code = base.opcodes['READFILESHARE']
20202021
arg_format = tools.chain(['ci', 'ciw'], itertools.repeat('sw'))
2022+
vector_index = 2
20212023

20222024
def has_var_args(self):
20232025
return True
@@ -2341,7 +2343,7 @@ class convint(base.Instruction):
23412343

23422344
@base.vectorize
23432345
class convmodp(base.Instruction):
2344-
""" Convert clear integer register (vector) to clear register
2346+
r""" Convert clear integer register (vector) to clear register
23452347
(vector). If the bit length is zero, the unsigned conversion is
23462348
used, otherwise signed conversion is used. This makes a difference
23472349
when computing modulo a prime :math:`p`. Signed conversion of
@@ -2814,13 +2816,11 @@ class check(base.Instruction):
28142816
@base.gf2n
28152817
@base.vectorize
28162818
class sqrs(base.CISC):
2817-
""" Secret squaring $s_i = s_j \cdot s_j$. """
2819+
r""" Secret squaring $s_i = s_j \cdot s_j$. """
28182820
__slots__ = []
28192821
arg_format = ['sw', 's']
28202822

28212823
def expand(self):
2822-
if program.options.ring:
2823-
return muls(self.args[0], self.args[1], self.args[1])
28242824
s = [program.curr_block.new_reg('s') for i in range(6)]
28252825
c = [program.curr_block.new_reg('c') for i in range(2)]
28262826
square(s[0], s[1])

Compiler/instructions_base.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -1200,9 +1200,11 @@ def has_var_args(self):
12001200
class VectorInstruction(Instruction):
12011201
__slots__ = []
12021202
is_vec = lambda self: True
1203+
vector_index = 0
12031204

12041205
def get_code(self):
1205-
return super(VectorInstruction, self).get_code(len(self.args[0]))
1206+
return super(VectorInstruction, self).get_code(
1207+
len(self.args[self.vector_index]))
12061208

12071209
class Ciscable(Instruction):
12081210
def copy(self, size, subs):

0 commit comments

Comments
 (0)