-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathcost.py
240 lines (199 loc) · 8.73 KB
/
cost.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
# Data for the cost of sieving from:
# Estimating quantum speedups for lattice sieves
# Martin R. Albrecht and Vlad Gheorghiu and Eamonn W. Postlethwaite and John M. Schanck
# data file "cost-estimate-list_decoding-classical.csv"
# The data used below is from the May 2020 version, available at
# https://eprint.iacr.org/eprint-bin/getfile.pl?entry=2019/1161&version=20200520:144757&file=1161.pdf
# This data file can be extracted from the PDF with the Linux tool `pdfdetach`.
# Another version of that data file is more easily accessible at
# https://github.com/jschanck/eprint-2019-1161/blob/main/data/cost-estimate-list_decoding-classical.csv
# and differs from the one we used by less than one bit at dimension 376.
from mpmath import mp
from math import ceil, floor, exp, pi
from math import log as ln
def log2(x):
    """Return the base-2 logarithm of ``x``, computed as ``ln(x) / ln(2)``."""
    natural = ln(x)
    return natural / ln(2.0)
# log2 of the AGPS20 gate count for FindAllPairs, keyed by sieving dimension
# beta' from 64 to 1024 in steps of 8 (121 entries). Values come from the
# AGPS20 data file referenced in the header comment above; agps20_gates()
# linearly interpolates between these grid points.
# NOTE(review): the step 128 -> 136 (+3.33) is noticeably larger than its
# neighbours (~+2.5); presumably faithful to the CSV, but worth confirming.
agps20_gate_data = {
64 :42.5948446291284, 72 :44.8735917172503, 80 :47.4653141889341, 88 :50.0329479433691, 96 :52.5817667347844,
104 :55.1130237325179, 112 :57.6295421450947, 120 :60.133284108578, 128 :62.1470129451821, 136 :65.4744488064273,
144 :67.951405476229, 152 :70.0494944191399, 160 :72.50927387359, 168 :74.9619105412039, 176 :77.4100782579645,
184 :79.3495443657483, 192 :81.7856479853679, 200 :84.2178462414349, 208 :86.646452845262, 216 :89.0717383389617,
224 :91.4939375786565, 232 :93.9132560751063, 240 :96.3298751307529, 248 :98.7439563146036, 256 :101.155644837658,
264 :104.091650357302, 272 :106.500713866161, 280 :108.907671199501, 288 :111.312627066864, 296 :113.715679081585,
304 :116.11691871212, 312 :118.516432037545, 320 :120.914300351043, 328 :123.310600632063, 336 :125.705405925853,
344 :128.098785623819, 352 :130.490805751072, 360 :132.881529104042, 368 :135.271015458153, 376 :137.659321707881,
384 :140.046501985502, 392 :142.432607773479, 400 :144.817688009257, 408 :147.201789183958, 416 :149.584955436701,
424 :151.967228645918, 432 :154.348648518547, 440 :156.729252677678, 448 :159.109076748918, 456 :161.488154445581,
464 :163.866517652676, 472 :166.24419650959, 480 :168.621219491327, 488 :170.997613488119, 496 :173.373403883249,
504 :175.748614628914, 512 :178.123268319974, 520 :180.931640474467, 528 :183.305745118107, 536 :185.679338509895,
544 :188.052439374005, 552 :190.425065356218, 560 :192.797233085084, 568 :195.168958230518, 576 :197.540255559816,
584 :199.911138991095, 592 :202.281621644196, 600 :204.651715889082, 608 :207.02143339179, 616 :209.390785157985,
624 :211.759781574203, 632 :214.128432446848, 640 :216.496747039019, 648 :218.864734105257, 656 :221.232401924303,
664 :223.599758329925, 672 :225.96681073994, 680 :228.333566183483, 688 :230.700031326626, 696 :233.066212496418,
704 :235.43211570344, 712 :237.797746662944, 720 :240.163110814653, 728 :242.528213341298, 736 :244.893059185964,
744 :247.25765306831, 752 :249.621999499728, 760 :251.986102797502, 768 :254.349967098032, 776 :256.71359636917,
784 :259.076994421734, 792 :261.440164920231, 800 :263.803111392861, 808 :266.165837240825, 816 :268.528345816343,
824 :270.890640143248, 832 :273.252723321704, 840 :275.614598434176, 848 :277.976268306208, 856 :280.337735739304,
864 :282.699003457275, 872 :285.060074111424, 880 :287.420950285349, 888 :289.781634499399, 896 :292.142129214795,
904 :294.502436837451, 912 :296.862559721505, 920 :299.222500172584, 928 :301.582260450819, 936 :303.941842773632,
944 :306.301249318305, 952 :308.660482224348, 960 :311.019543595679, 968 :313.378435502636, 976 :315.737159983825,
984 :318.095719047813, 992 :320.454114674691, 1000:322.8123488175, 1008:325.170423403542,1016:327.52834033558,
1024:329.886101492934
}
# Function C from AGPS20 source code
def caps_vol(d, theta, integrate=False, prec=None):
    """
    The probability that some v from the sphere has angle at most theta with some fixed u.

    :param d: We consider spheres of dimension `d-1`
    :param theta: angle in radians, expected in ``[0, pi]``
    :param integrate: if true, compute via explicit numerical integration
        instead of the regularized incomplete beta function
    :param prec: working precision in bits; defaults to the current ``mp.prec``
    :returns: an ``mpf`` cap fraction

    EXAMPLE::

        sage: caps_vol(80, pi/3)
        mpf('1.0042233739846629e-6')
    """
    prec = prec if prec else mp.prec
    with mp.workprec(prec):
        theta = mp.mpf(theta)
        d = mp.mpf(d)
        if integrate:
            r = (
                1
                / mp.sqrt(mp.pi)
                * mp.gamma(d / 2)
                / mp.gamma((d - 1) / 2)
                * mp.quad(lambda x: mp.sin(x) ** (d - 2), (0, theta), error=True)[0]
            )
        else:
            r = mp.betainc((d - 1) / 2, 1 / 2.0, x2=mp.sin(theta) ** 2, regularized=True) / 2
            # The closed form above is the cap fraction only for theta <= pi/2.
            # sin^2 is symmetric about pi/2, so a cap wider than a hemisphere is
            # the complement of the cap with angle pi - theta. Without this
            # correction the two branches disagree for theta > pi/2.
            if theta > mp.pi / 2:
                r = 1 - r
    return r
# Return log2 of the number of gates for FindAllPairs according to AGPS20
def agps20_gates(beta_prime):
    """Return log2 of the AGPS20 FindAllPairs gate count in sieving dimension ``beta_prime``.

    ``agps20_gate_data`` is tabulated every 8 dimensions. Off-grid dimensions
    are linearly interpolated (in log2 space) between the two surrounding grid
    points. A ``KeyError`` propagates if a needed grid point is not tabulated.
    """
    k = beta_prime / 8
    if k == round(k):
        # On the grid: read the table directly.
        return agps20_gate_data[beta_prime]
    lower = floor(k)
    weight = k - lower
    low_cost = agps20_gate_data[8 * lower]
    high_cost = agps20_gate_data[8 * (lower + 1)]
    return weight * high_cost + (1 - weight) * low_cost
# Return log2 of the number of vectors for sieving according to AGPS20
def agps20_vectors(beta_prime):
    """Return log2 of the AGPS20 list-size estimate for a sieve in dimension ``beta_prime``.

    The number of vectors is taken as ``1 / caps_vol(beta_prime, pi/3)``. That
    is the inverse of the probability that a point on the sphere lies within
    angle pi/3 of a fixed point.
    """
    # The former unused local `k = round(beta_prime)` was dropped. It had no
    # effect other than raising on non-finite input.
    N = 1. / caps_vol(beta_prime, mp.pi / 3.)
    return log2(N)
# Progressivity Overhead Factor
C = 1.0 / (1.0 - 2 ** -0.292)  # = sum over i >= 0 of 2**(-0.292 * i)
def dims4free(beta):
    """Return ``ceil(beta * ln(4/3) / ln(beta / (2*pi*e)))``.

    This is the "dimensions for free" count subtracted from the block size to
    get the sieving dimension. Note: for ``beta < 2*pi*e`` the denominator is
    negative, so the result is negative.
    """
    numerator = beta * ln(4. / 3.)
    denominator = ln(beta / (2 * pi * exp(1)))
    return ceil(numerator / denominator)
#cost of bkz with progressive sieve
def theo_bkz_cost(n, beta, J=1):
    """Theoretical (AGPS20-based) cost of BKZ with a progressive sieve.

    Returns ``(gates, bits)`` where
    ``gates = log2((n - beta) / J * C**2) + agps20_gates(beta')`` and
    ``bits = log2(8 * beta') + agps20_vectors(beta')``, with
    ``beta' = floor(beta - dims4free(beta))``.

    Special cases:
    * ``(0, 0)`` if ``beta <= 10``, or ``beta' < 64``, or ``beta' > beta``.
    * ``(inf, inf)`` if ``beta' > 1024`` (beyond the AGPS20 table).

    NOTE(review): ``n - beta <= 0`` makes the log argument non-positive and
    raises ``ValueError``; callers presumably keep ``beta < n``.
    """
    if beta <= 10:
        return (0, 0)
    sieve_dim = floor(beta - dims4free(beta))
    if sieve_dim < 64 or sieve_dim > beta:
        return (0, 0)
    if sieve_dim > 1024:
        return (float("inf"), float("inf"))
    calls = 1. * (n - beta) / J
    gates = log2(calls * C * C) + agps20_gates(sieve_dim)
    bits = log2(8 * sieve_dim) + agps20_vectors(sieve_dim)
    return (gates, bits)
def theo_pump_cost(beta):
    """Theoretical (AGPS20-based) cost of a single pump with block size ``beta``.

    Returns ``(gates, bits)`` where ``gates = log2(C**2) + agps20_gates(beta')``
    and ``bits = log2(8 * beta') + agps20_vectors(beta')``, with
    ``beta' = floor(beta - dims4free(beta))``.

    Special cases:
    * ``(0, 0)`` if ``beta <= 10``, or ``beta' < 64``, or ``beta' > beta``.
    * ``(inf, inf)`` if ``beta' > 1024``.
    """
    if beta <= 10:
        return (0, 0)
    sieve_dim = floor(beta - dims4free(beta))
    if sieve_dim < 64 or sieve_dim > beta:
        return (0, 0)
    if sieve_dim > 1024:
        return (float("inf"), float("inf"))
    overhead = C * C
    gates = log2(overhead) + agps20_gates(sieve_dim)
    bits = log2(8 * sieve_dim) + agps20_vectors(sieve_dim)
    return (gates, bits)
def pump_cost(d, beta, cost_model=1):
    """Return ``(log2 cost, log2 memory)`` of a single pump with block size ``beta``.

    :param d: lattice dimension (forwarded to ``get_pump_time`` for ``cost_model=2``)
    :param beta: block size
    :param cost_model: ``1`` for the theoretical AGPS20 gate count
        (``theo_pump_cost``), ``2`` for the practical fitted timing
        (``get_pump_time``). Memory always comes from ``theo_pump_cost``.
    :raises ValueError: for any other ``cost_model``; previously this silently
        returned ``None``, which only failed later at the caller's unpacking.
    """
    if cost_model == 1:
        return theo_pump_cost(beta)
    if cost_model == 2:
        return log2(get_pump_time(beta, d)), theo_pump_cost(beta)[1]
    raise ValueError("unknown cost_model {!r} (expected 1 or 2)".format(cost_model))
def bkz_cost(d, beta, J=1, cost_model=1):
    """Return ``(log2 cost, log2 memory)`` of progressive BKZ with block size ``beta``.

    :param d: lattice dimension
    :param beta: block size
    :param J: jump value; both models divide the cost by ``J``
    :param cost_model: ``1`` for the theoretical AGPS20 gate count
        (``theo_bkz_cost``), ``2`` for the practical fitted pnj-BKZ timing
        (``get_pre_pnj_time``). Memory always comes from ``theo_bkz_cost``.
    :raises ValueError: for any other ``cost_model``; previously this silently
        returned ``None``.
    """
    if cost_model == 1:
        return theo_bkz_cost(d, beta, J)
    if cost_model == 2:
        f = dims4free(beta)
        return log2(get_pre_pnj_time(d, beta, f, J)), theo_bkz_cost(d, beta, J)[1]
    raise ValueError("unknown cost_model {!r} (expected 1 or 2)".format(cost_model))
def summary(n, beta):
    """Combine the theoretical BKZ cost and one extra pump for dimension ``n``, block size ``beta``.

    :returns: ``(n, beta, beta_prime, gates, bits)`` where
        ``gates = log2(2**gates_bkz + 2**gates_pump)`` and
        ``bits = max(bits_bkz, bits_pump)``.
    """
    beta_prime = floor(beta - dims4free(beta))
    gates1, bits1 = bkz_cost(n, beta)
    # pump_cost's signature is (d, beta, cost_model). Calling pump_cost(beta)
    # omitted the required `beta` argument and always raised TypeError.
    gates2, bits2 = pump_cost(n, beta)
    # NOTE(review): looks like leftover debug output; kept to preserve behaviour.
    print(gates1, gates2)
    gates = log2(2**gates1 + 2**gates2)
    bits = max(bits1, bits2)
    return (n, beta, beta_prime, gates, bits)
###########################
#practical cost model
# threads = 32, gpus = 2, pnj-bkz cost
def get_k1_k2_pnj(beta, sieve):
    """Return fitted coefficients ``(k1, k2)`` for the practical pnj-BKZ time model.

    The time model (see ``get_pre_pnj_time``) uses ``2 ** (k1 * (beta - f) + k2)``.
    The two ``sieve == False`` regimes cover block sizes up to 60. All other
    regimes are selected by ``beta`` alone. Block sizes in ``[0, 10)`` give
    ``(0, 0)``.
    """
    if 0 <= beta < 10:
        return 0, 0
    if 10 <= beta <= 42 and sieve == False:
        return 0.03, 5.188
    if beta <= 60 and sieve == False:
        return 0.19, -1.741
    # (upper bound on beta, k1, k2), checked in increasing order.
    fits = (
        (97, 0.056, 7.85),
        (118, 0.215, -7.61),
        (128, 0.314, -19.24),
    )
    for upper, k1, k2 in fits:
        if beta <= upper:
            return k1, k2
    return 0.368, -26.15
# threads = 32, gpus = 2, test pump
def get_k1_k2_pump(beta):
    """Return fitted coefficients ``(k1, k2)`` for the practical pump time model.

    The time model (see ``get_pump_time``) is ``2 ** (k1 * beta + k2)``.
    Block sizes in ``[0, 10)`` give ``(0, 0)``.
    """
    if 0 <= beta < 10:
        return 0, 0
    if 10 <= beta <= 60:
        return 0.035657, -2.317327
    # (upper bound on beta, k1, k2), checked in increasing order.
    # NOTE: an earlier, commented-out fit for the last regime used
    # k1 = 0.3642, k2 = -24.398.
    fits = (
        (96, 0.078794, -0.039742),
        (116, 0.231927, -14.713430),
        (128, 0.314, -24.21),
    )
    for upper, k1, k2 in fits:
        if beta <= upper:
            return k1, k2
    return 0.368, -31.12
#get pump time test in threads = 20
def get_pump_time(beta, d):
    """Return the fitted practical pump time ``2 ** (k1 * beta + k2)``, rounded to 4 decimals.

    ``(k1, k2)`` come from ``get_k1_k2_pump(beta)``. ``d`` is accepted for
    interface compatibility but is not used. The time unit is whatever the
    fit was calibrated in (presumably seconds; confirm against the benchmarks).
    """
    k1, k2 = get_k1_k2_pump(beta)
    exponent = k1 * beta + k2
    return round(2 ** exponent, 4)
#get pnj-BKZ time test in threads = 20
def get_pre_pnj_time(d, beta, f, jump):
    """Return the fitted practical pnj-BKZ time, rounded to 4 decimals.

    Computes ``2 ** (k1 * (beta - f) + k2) * (0.018 * d - 2.24) / jump``, where
    ``(k1, k2)`` come from ``get_k1_k2_pnj``. The non-sieve fits apply only
    when ``beta <= 60``.

    NOTE(review): ``0.018 * d - 2.24`` is non-positive for ``d <= 124``, which
    yields a non-positive time (and a math domain error in a later ``log2``).
    """
    sieve = not (beta <= 60)
    k1, k2 = get_k1_k2_pnj(beta, sieve)
    c3, c4 = 0.018, -2.24
    per_block = 2 ** (k1 * (beta - f) + k2)
    dim_factor = c3 * d + c4
    return round(per_block * dim_factor / jump, 4)