Skip to content

Commit eae7586

Browse files
sage: Exit with non-zero status in case of failures
1 parent b54d843 commit eae7586

3 files changed

+29
-18
lines changed

sage/group_prover.sage

+6-7
Original file line numberDiff line numberDiff line change
@@ -299,22 +299,21 @@ def check_symbolic(R, assumeLaw, assumeAssert, assumeBranch, require):
299299

300300
if conflicts(R, assume):
301301
# This formula does not apply
302-
return None
302+
return (True, None)
303303

304304
describe = describe_extra(R, assumeLaw + assumeBranch, assumeAssert)
305+
if describe != "":
306+
describe = " (assuming " + describe + ")"
305307

306308
ok, msg = prove_zero(R, require.zero, assume)
307309
if not ok:
308-
return "FAIL, %s fails (assuming %s)" % (str(msg), describe)
310+
return (False, "FAIL, %s fails%s" % (str(msg), describe))
309311

310312
res, expl = prove_nonzero(R, require.nonzero, assume)
311313
if not res:
312-
return "FAIL, %s fails (assuming %s)" % (str(expl), describe)
314+
return (False, "FAIL, %s fails%s" % (str(expl), describe))
313315

314-
if describe != "":
315-
return "OK (assuming %s)" % describe
316-
else:
317-
return "OK"
316+
return (True, "OK%s" % describe)
318317

319318

320319
def concrete_verify(c):

sage/prove_group_implementations.sage

+13-10
Original file line numberDiff line numberDiff line change
@@ -292,15 +292,18 @@ def formula_secp256k1_gej_add_ge_old(branch, a, b):
292292
return (constraints(zero={b.Z - 1 : 'b.z=1', b.Infinity : 'b_finite'}), constraints(zero=zero, nonzero=nonzero), jacobianpoint(rx, ry, rz))
293293

294294
if __name__ == "__main__":
295-
check_symbolic_jacobian_weierstrass("secp256k1_gej_add_var", 0, 7, 5, formula_secp256k1_gej_add_var)
296-
check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge_var", 0, 7, 5, formula_secp256k1_gej_add_ge_var)
297-
check_symbolic_jacobian_weierstrass("secp256k1_gej_add_zinv_var", 0, 7, 5, formula_secp256k1_gej_add_zinv_var)
298-
check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge", 0, 7, 16, formula_secp256k1_gej_add_ge)
299-
check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge_old [should fail]", 0, 7, 4, formula_secp256k1_gej_add_ge_old)
295+
success = True
296+
success = success & check_symbolic_jacobian_weierstrass("secp256k1_gej_add_var", 0, 7, 5, formula_secp256k1_gej_add_var)
297+
success = success & check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge_var", 0, 7, 5, formula_secp256k1_gej_add_ge_var)
298+
success = success & check_symbolic_jacobian_weierstrass("secp256k1_gej_add_zinv_var", 0, 7, 5, formula_secp256k1_gej_add_zinv_var)
299+
success = success & check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge", 0, 7, 16, formula_secp256k1_gej_add_ge)
300+
success = success & (not check_symbolic_jacobian_weierstrass("secp256k1_gej_add_ge_old [should fail]", 0, 7, 4, formula_secp256k1_gej_add_ge_old))
300301

301302
if len(sys.argv) >= 2 and sys.argv[1] == "--exhaustive":
302-
check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_var", 0, 7, 5, formula_secp256k1_gej_add_var, 43)
303-
check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge_var", 0, 7, 5, formula_secp256k1_gej_add_ge_var, 43)
304-
check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_zinv_var", 0, 7, 5, formula_secp256k1_gej_add_zinv_var, 43)
305-
check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge", 0, 7, 16, formula_secp256k1_gej_add_ge, 43)
306-
check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge_old [should fail]", 0, 7, 4, formula_secp256k1_gej_add_ge_old, 43)
303+
success = success & check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_var", 0, 7, 5, formula_secp256k1_gej_add_var, 43)
304+
success = success & check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge_var", 0, 7, 5, formula_secp256k1_gej_add_ge_var, 43)
305+
success = success & check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_zinv_var", 0, 7, 5, formula_secp256k1_gej_add_zinv_var, 43)
306+
success = success & check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge", 0, 7, 16, formula_secp256k1_gej_add_ge, 43)
307+
success = success & (not check_exhaustive_jacobian_weierstrass("secp256k1_gej_add_ge_old [should fail]", 0, 7, 4, formula_secp256k1_gej_add_ge_old, 43))
308+
309+
sys.exit(int(not success))

sage/weierstrass_prover.sage

+10-1
Original file line numberDiff line numberDiff line change
@@ -184,6 +184,7 @@ def check_exhaustive_jacobian_weierstrass(name, A, B, branches, formula, p):
184184
if r:
185185
points.append(point)
186186

187+
ret = True
187188
for za in range(1, p):
188189
for zb in range(1, p):
189190
for pa in points:
@@ -211,8 +212,11 @@ def check_exhaustive_jacobian_weierstrass(name, A, B, branches, formula, p):
211212
match = True
212213
r, e = concrete_verify(require)
213214
if not r:
215+
ret = False
214216
print(" failure in branch %i for (%s,%s,%s,%s) + (%s,%s,%s,%s) = (%s,%s,%s,%s): %s" % (branch, pA.X, pA.Y, pA.Z, pA.Infinity, pB.X, pB.Y, pB.Z, pB.Infinity, pC.X, pC.Y, pC.Z, pC.Infinity, e))
217+
215218
print()
219+
return ret
216220

217221

218222
def check_symbolic_function(R, assumeAssert, assumeBranch, f, A, B, pa, pb, pA, pB, pC):
@@ -244,6 +248,7 @@ def check_symbolic_jacobian_weierstrass(name, A, B, branches, formula):
244248

245249
print("Formula " + name + ":")
246250
count = 0
251+
ret = True
247252
for branch in range(branches):
248253
assumeFormula, assumeBranch, pC = formula(branch, pA, pB)
249254
pC.X = lift(pC.X)
@@ -252,7 +257,10 @@ def check_symbolic_jacobian_weierstrass(name, A, B, branches, formula):
252257
pC.Infinity = lift(pC.Infinity)
253258

254259
for key in laws_jacobian_weierstrass:
255-
res[key].append((check_symbolic_function(R, assumeFormula, assumeBranch, laws_jacobian_weierstrass[key], A, B, pa, pb, pA, pB, pC), branch))
260+
success, msg = check_symbolic_function(R, assumeFormula, assumeBranch, laws_jacobian_weierstrass[key], A, B, pa, pb, pA, pB, pC)
261+
if not success:
262+
ret = False
263+
res[key].append((msg, branch))
256264

257265
for key in res:
258266
print(" %s:" % key)
@@ -262,3 +270,4 @@ def check_symbolic_jacobian_weierstrass(name, A, B, branches, formula):
262270
print(" branch %i: %s" % (x[1], x[0]))
263271

264272
print()
273+
return ret

0 commit comments

Comments
 (0)