forked from data61/MP-SPDZ
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient.py
115 lines (101 loc) · 3.55 KB
/
client.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import socket, ssl
import struct
import time
class Client:
def __init__(self, hostnames, port_base, my_client_id):
ctx = ssl.SSLContext()
name = 'C%d' % my_client_id
prefix = 'Player-Data/%s' % name
ctx.load_cert_chain(certfile=prefix + '.pem', keyfile=prefix + '.key')
ctx.load_verify_locations(capath='Player-Data')
self.sockets = []
for i, hostname in enumerate(hostnames):
for j in range(10000):
try:
plain_socket = socket.create_connection(
(hostname, port_base + i))
break
except ConnectionRefusedError:
if j < 60:
time.sleep(1)
else:
raise
octetStream(b'%d' % my_client_id).Send(plain_socket)
self.sockets.append(ctx.wrap_socket(plain_socket,
server_hostname='P%d' % i))
self.specification = octetStream()
self.specification.Receive(self.sockets[0])
def receive_triples(self, T, n):
triples = [[0, 0, 0] for i in range(n)]
os = octetStream()
for socket in self.sockets:
os.Receive(socket)
for triple in triples:
for i in range(3):
t = T()
t.unpack(os)
triple[i] += t
res = []
for triple in triples:
prod = triple[0] * triple[1]
if prod != triple[2]:
raise Exception(
'invalid triple, diff %s' % hex(prod.v - triple[2].v))
return triples
def send_private_inputs(self, values):
T = type(values[0])
triples = self.receive_triples(T, len(values))
os = octetStream()
assert len(values) == len(triples)
for value, triple in zip(values, triples):
(value + triple[0]).pack(os)
for socket in self.sockets:
os.Send(socket)
def receive_outputs(self, T, n):
triples = self.receive_triples(T, n)
return [triple[0] for triple in triples]
class octetStream:
def __init__(self, value=None):
self.buf = b''
self.ptr = 0
if value is not None:
self.buf += value
def reset_write_head(self):
self.buf = b''
self.ptr = 0
def Send(self, socket):
socket.sendall(struct.pack('<i', len(self.buf)))
socket.sendall(self.buf)
def Receive(self, socket):
length = struct.unpack('<I', socket.recv(4))[0]
self.buf = b''
while len(self.buf) < length:
self.buf += socket.recv(length - len(self.buf))
self.ptr = 0
def store(self, value):
self.buf += struct.pack('<i', value)
def get_int(self, length):
buf = self.consume(length)
if length == 4:
return struct.unpack('<i', buf)[0]
elif length == 8:
return struct.unpack('<q', buf)[0]
raise ValueError()
def get_bigint(self):
sign = self.consume(1)[0]
assert(sign in (0, 1))
length = self.get_int(4)
if length:
res = 0
buf = self.consume(length)
for i, b in enumerate(reversed(buf)):
res += b << (i * 8)
if sign:
res *= -1
return res
else:
return 0
def consume(self, length):
self.ptr += length
assert self.ptr <= len(self.buf)
return self.buf[self.ptr - length:self.ptr]