@@ -43,7 +43,7 @@ def pop(self, size):
43
43
else :
44
44
done = False
45
45
for x in self .by_logsize [logsize + 1 :]:
46
- for block_size , addresses in x .items ():
46
+ for block_size , addresses in sorted ( x .items () ):
47
47
if len (addresses ) > 0 :
48
48
done = True
49
49
break
@@ -60,16 +60,92 @@ def pop(self, size):
60
60
self .by_address [addr + size ] = diff
61
61
return addr
62
62
63
+ class AllocRange :
64
+ def __init__ (self , base = 0 ):
65
+ self .base = base
66
+ self .top = base
67
+ self .limit = base
68
+ self .grow = True
69
+ self .pool = defaultdict (set )
70
+
71
+ def alloc (self , size ):
72
+ if self .pool [size ]:
73
+ return self .pool [size ].pop ()
74
+ elif self .grow or self .top + size <= self .limit :
75
+ res = self .top
76
+ self .top += size
77
+ self .limit = max (self .limit , self .top )
78
+ if res >= REG_MAX :
79
+ raise RegisterOverflowError ()
80
+ return res
81
+
82
+ def free (self , base , size ):
83
+ assert self .base <= base < self .top
84
+ self .pool [size ].add (base )
85
+
86
+ def stop_growing (self ):
87
+ self .grow = False
88
+
89
+ def consolidate (self ):
90
+ regs = []
91
+ for size , pool in self .pool .items ():
92
+ for base in pool :
93
+ regs .append ((base , size ))
94
+ for base , size in reversed (sorted (regs )):
95
+ if base + size == self .top :
96
+ self .top -= size
97
+ self .pool [size ].remove (base )
98
+ regs .pop ()
99
+ else :
100
+ if program .Program .prog .verbose :
101
+ print ('cannot free %d register blocks '
102
+ 'by a gap of %d at %d' %
103
+ (len (regs ), self .top - size - base , base ))
104
+ break
105
+
106
+ class AllocPool :
107
+ def __init__ (self ):
108
+ self .ranges = defaultdict (lambda : [AllocRange ()])
109
+ self .by_base = {}
110
+
111
+ def alloc (self , reg_type , size ):
112
+ for r in self .ranges [reg_type ]:
113
+ res = r .alloc (size )
114
+ if res is not None :
115
+ self .by_base [reg_type , res ] = r
116
+ return res
117
+
118
+ def free (self , reg ):
119
+ r = self .by_base .pop ((reg .reg_type , reg .i ))
120
+ r .free (reg .i , reg .size )
121
+
122
+ def new_ranges (self , min_usage ):
123
+ for t , n in min_usage .items ():
124
+ r = self .ranges [t ][- 1 ]
125
+ assert (n >= r .limit )
126
+ if r .limit < n :
127
+ r .stop_growing ()
128
+ self .ranges [t ].append (AllocRange (n ))
129
+
130
+ def consolidate (self ):
131
+ for r in self .ranges .values ():
132
+ for rr in r :
133
+ rr .consolidate ()
134
+
135
+ def n_fragments (self ):
136
+ return max (len (r ) for r in self .ranges )
137
+
63
138
class StraightlineAllocator :
64
139
"""Allocate variables in a straightline program using n registers.
65
140
It is based on the precondition that every register is only defined once."""
66
141
def __init__ (self , n , program ):
67
142
self .alloc = dict_by_id ()
68
- self .usage = Compiler . program . RegType . create_dict (lambda : 0 )
143
+ self .max_usage = defaultdict (lambda : 0 )
69
144
self .defined = dict_by_id ()
70
145
self .dealloc = set_by_id ()
71
- self . n = n
146
+ assert ( n == REG_MAX )
72
147
self .program = program
148
+ self .old_pool = None
73
149
74
150
def alloc_reg (self , reg , free ):
75
151
base = reg .vectorbase
@@ -79,14 +155,7 @@ def alloc_reg(self, reg, free):
79
155
80
156
reg_type = reg .reg_type
81
157
size = base .size
82
- if free [reg_type , size ]:
83
- res = free [reg_type , size ].pop ()
84
- else :
85
- if self .usage [reg_type ] < self .n :
86
- res = self .usage [reg_type ]
87
- self .usage [reg_type ] += size
88
- else :
89
- raise RegisterOverflowError ()
158
+ res = free .alloc (reg_type , size )
90
159
self .alloc [base ] = res
91
160
92
161
base .i = self .alloc [base ]
@@ -126,7 +195,7 @@ def dealloc_reg(self, reg, inst, free):
126
195
for x in itertools .chain (dup .duplicates , base .duplicates ):
127
196
to_check .add (x )
128
197
129
- free [ reg . reg_type , base . size ]. append ( self . alloc [ base ] )
198
+ free . free ( base )
130
199
if inst .is_vec () and base .vector :
131
200
self .defined [base ] = inst
132
201
for i in base .vector :
@@ -135,6 +204,7 @@ def dealloc_reg(self, reg, inst, free):
135
204
self .defined [reg ] = inst
136
205
137
206
def process (self , program , alloc_pool ):
207
+ self .update_usage (alloc_pool )
138
208
for k ,i in enumerate (reversed (program )):
139
209
unused_regs = []
140
210
for j in i .get_def ():
@@ -161,12 +231,26 @@ def process(self, program, alloc_pool):
161
231
if k % 1000000 == 0 and k > 0 :
162
232
print ("Allocated registers for %d instructions at" % k , time .asctime ())
163
233
234
+ self .update_max_usage (alloc_pool )
235
+ alloc_pool .consolidate ()
236
+
164
237
# print "Successfully allocated registers"
165
238
# print "modp usage: %d clear, %d secret" % \
166
239
# (self.usage[Compiler.program.RegType.ClearModp], self.usage[Compiler.program.RegType.SecretModp])
167
240
# print "GF2N usage: %d clear, %d secret" % \
168
241
# (self.usage[Compiler.program.RegType.ClearGF2N], self.usage[Compiler.program.RegType.SecretGF2N])
169
- return self .usage
242
+ return self .max_usage
243
+
244
+ def update_max_usage (self , alloc_pool ):
245
+ for t , r in alloc_pool .ranges .items ():
246
+ self .max_usage [t ] = max (self .max_usage [t ], r [- 1 ].limit )
247
+
248
+ def update_usage (self , alloc_pool ):
249
+ if self .old_pool :
250
+ self .update_max_usage (self .old_pool )
251
+ if id (self .old_pool ) != id (alloc_pool ):
252
+ alloc_pool .new_ranges (self .max_usage )
253
+ self .old_pool = alloc_pool
170
254
171
255
def finalize (self , options ):
172
256
for reg in self .alloc :
@@ -178,6 +262,21 @@ def finalize(self, options):
178
262
'\t \t ' ))
179
263
if options .stop :
180
264
sys .exit (1 )
265
+ if self .program .verbose :
266
+ def p (sizes ):
267
+ total = defaultdict (lambda : 0 )
268
+ for (t , size ) in sorted (sizes ):
269
+ n = sizes [t , size ]
270
+ total [t ] += size * n
271
+ print ('%s:%d*%d' % (t , size , n ), end = ' ' )
272
+ print ()
273
+ print ('Total:' , dict (total ))
274
+
275
+ sizes = defaultdict (lambda : 0 )
276
+ for reg in self .alloc :
277
+ x = reg .reg_type , reg .size
278
+ print ('Used registers: ' , end = '' )
279
+ p (sizes )
181
280
182
281
def determine_scope (block , options ):
183
282
last_def = defaultdict_by_id (lambda : - 1 )
0 commit comments