Skip to content

Commit 3e31c5f

Browse files
committedJul 14, 2020
#1104 fix flake8 and some minor bugs
1 parent 0ca87c4 commit 3e31c5f

File tree

2 files changed

+49
-38
lines changed

2 files changed

+49
-38
lines changed
 

‎pybamm/solvers/jax_bdf_solver.py

+48-37
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from functools import partial
21
import operator as op
32
import numpy as onp
43
import collections
@@ -83,8 +82,10 @@ def fun_bind_inputs(y, t):
8382
t0 = t_eval[0]
8483
h0 = t_eval[1] - t0
8584

86-
stepper = _bdf_init(fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol)
87-
i = 0
85+
stepper, failed = _bdf_init(
86+
fun_bind_inputs, jac_bind_inputs, mass, t0, y0, h0, rtol, atol
87+
)
88+
i = failed * len(t_eval)
8889
y_out = jnp.empty((len(t_eval), len(y0)), dtype=y0.dtype)
8990

9091
init_state = [stepper, t_eval, i, y_out]
@@ -108,21 +109,24 @@ def for_body(j, y_out):
108109
return [stepper, t_eval, index, y_out]
109110

110111
stepper, t_eval, i, y_out = jax.lax.while_loop(cond_fun, body_fun,
111-
init_state)
112+
init_state)
112113

113114
return y_out
114115

115116

116117
BDFInternalStates = [
117-
't', 'atol', 'rtol', 'M', 'newton_tol', 'order', 'h', 'n_equal_steps', 'D',
118-
'y0', 'scale_y0', 'kappa', 'gamma', 'alpha', 'c', 'error_const', 'J', 'LU', 'U',
119-
'psi', 'n_function_evals', 'n_jacobian_evals', 'n_lu_decompositions', 'n_steps']
118+
't', 'atol', 'rtol', 'M', 'newton_tol', 'order', 'h', 'n_equal_steps', 'D',
119+
'y0', 'scale_y0', 'kappa', 'gamma', 'alpha', 'c', 'error_const', 'J', 'LU', 'U',
120+
'psi', 'n_function_evals', 'n_jacobian_evals', 'n_lu_decompositions', 'n_steps'
121+
]
120122
BDFState = collections.namedtuple('BDFState', BDFInternalStates)
121123

122124
jax.tree_util.register_pytree_node(
123-
BDFState,
124-
lambda xs: (tuple(xs), None),
125-
lambda _, xs: BDFState(*xs))
125+
BDFState,
126+
lambda xs: (tuple(xs), None),
127+
lambda _, xs: BDFState(*xs)
128+
)
129+
126130

127131
def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
128132
"""
@@ -165,7 +169,9 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
165169
state['newton_tol'] = jnp.max((10 * EPS / rtol, jnp.min((0.03, rtol ** 0.5))))
166170

167171
scale_y0 = atol + rtol * jnp.abs(y0)
168-
y0 = _select_initial_conditions(fun, mass, t0, y0, state['newton_tol'], scale_y0)
172+
y0, not_converged = _select_initial_conditions(
173+
fun, mass, t0, y0, state['newton_tol'], scale_y0
174+
)
169175

170176
f0 = fun(y0, t0)
171177
order = 1
@@ -207,7 +213,7 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
207213
tuple_state = BDFState(*[state[k] for k in BDFInternalStates])
208214
y0, scale_y0 = _predict(tuple_state, D)
209215
psi = _update_psi(tuple_state, D)
210-
return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi)
216+
return tuple_state._replace(y0=y0, scale_y0=scale_y0, psi=psi), not_converged
211217

212218

213219
def _compute_R(order, factor):
@@ -239,7 +245,7 @@ def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0):
239245
# if all differentiable variables then return y0 (can use normal python if since M
240246
# is static)
241247
if not jnp.any(algebraic_variables):
242-
return y0
248+
return y0, False
243249

244250
# calculate consistent initial conditions via a newton on -J_a @ delta = f_a This
245251
# follows this reference:
@@ -256,7 +262,7 @@ def fun_a(y_a):
256262
scale_y0_a = scale_y0[algebraic_variables]
257263

258264
d = jnp.zeros(y0_a.shape[0], dtype=y0.dtype)
259-
y_a = jnp.array(y0_a)
265+
y_a = jnp.array(y0_a, copy=True)
260266

261267
# calculate neg jacobian of fun_a
262268
J_a = jax.jacfwd(fun_a)(y_a)
@@ -290,13 +296,12 @@ def while_body(while_state):
290296

291297
return [k + 1, not_converged, dy_norm_old, d, y_a]
292298

293-
294299
k, not_converged, dy_norm_old, d, y_a = jax.lax.while_loop(while_cond,
295300
while_body,
296301
while_state)
297302
y_tilde = jax.ops.index_update(y0, algebraic_variables, y_a)
298303

299-
return y_tilde
304+
return y_tilde, not_converged
300305

301306

302307
def _select_initial_step(atol, rtol, fun, t0, y0, f0, h0):
@@ -399,9 +404,7 @@ def _update_step_size(state, factor):
399404
- psi term
400405
"""
401406
order = state.order
402-
h = state.h
403-
404-
h *= factor
407+
h = state.h * factor
405408
n_equal_steps = 0
406409
c = h * state.alpha[order]
407410

@@ -432,6 +435,7 @@ def _update_step_size(state, factor):
432435
n_lu_decompositions=n_lu_decompositions, h=h, c=c,
433436
D=D, psi=psi, y0=y0, scale_y0=scale_y0)
434437

438+
435439
def _update_jacobian(state, jac):
436440
"""
437441
we update the jacobian using J(t_{n+1}, y^0_{n+1})
@@ -481,7 +485,7 @@ def while_body(while_state):
481485
pred = rate >= 1
482486
pred += rate ** (NEWTON_MAXITER - k) / (1 - rate) * dy_norm > tol
483487
pred *= dy_norm_old >= 0
484-
k += pred * (NEWTON_MAXITER - k)
488+
k += pred * (NEWTON_MAXITER - k - 1)
485489

486490
d += dy
487491
y = y0 + d
@@ -495,11 +499,13 @@ def while_body(while_state):
495499

496500
return [k + 1, not_converged, dy_norm_old, d, y, n_function_evals]
497501

498-
k, not_converged, dy_norm_old, d, y, n_function_evals = jax.lax.while_loop(while_cond,
499-
while_body,
500-
while_state)
502+
k, not_converged, dy_norm_old, d, y, n_function_evals = \
503+
jax.lax.while_loop(while_cond,
504+
while_body,
505+
while_state)
501506
return not_converged, k, y, d, state._replace(n_function_evals=n_function_evals)
502507

508+
503509
def rms_norm(arg):
504510
return jnp.sqrt(jnp.mean(arg**2))
505511

@@ -508,7 +514,7 @@ def _prepare_next_step(state, d):
508514
D = _update_difference_for_next_step(state, d)
509515
psi = _update_psi(state, D)
510516
y0, scale_y0 = _predict(state, D)
511-
return state._replace(D=D,psi=psi,y0=y0,scale_y0=scale_y0)
517+
return state._replace(D=D, psi=psi, y0=y0, scale_y0=scale_y0)
512518

513519

514520
def _prepare_next_step_order_change(state, d, y, n_iter):
@@ -543,7 +549,7 @@ def _prepare_next_step_order_change(state, d, y, n_iter):
543549
# now we have the three factors for orders k-1, k and k+1, pick the maximum in
544550
# order to maximise the resultant step size
545551
max_index = jnp.argmax(factors)
546-
order = order + max_index - 1
552+
order += max_index - 1
547553

548554
factor = jnp.min((MAX_FACTOR, safety * factors[max_index]))
549555

@@ -578,16 +584,20 @@ def while_body(while_state):
578584
# newton iteration did not converge, but jacobian has already been
579585
# evaluated so reduce step size by 0.3 (as per [1]) and try again
580586
state = tree_multimap(
581-
partial(jnp.where, not_converged * updated_jacobian),
582-
_update_step_size(state, 0.3),
583-
state
587+
partial(jnp.where, not_converged * updated_jacobian),
588+
_update_step_size(state, 0.3),
589+
state
584590
)
585591

586-
# if not converged and jacobian not updated, then update the jacobian and try again
592+
# if not converged and jacobian not updated, then update the jacobian and try
593+
# again
587594
(state, updated_jacobian) = tree_multimap(
588-
partial(jnp.where, not_converged * (updated_jacobian == False)),
589-
(_update_jacobian(state, jac), True),
590-
(state, False)
595+
partial(
596+
jnp.where,
597+
not_converged * (updated_jacobian == False) # noqa: E712
598+
),
599+
(_update_jacobian(state, jac), True),
600+
(state, False + updated_jacobian)
591601
)
592602

593603
safety = 0.9 * (2 * NEWTON_MAXITER + 1) / (2 * NEWTON_MAXITER + n_iter)
@@ -606,17 +616,19 @@ def while_body(while_state):
606616
error_norm ** (-1 / (state.order + 1))))
607617

608618
(state, step_accepted) = tree_multimap(
609-
partial(jnp.where, (not_converged == False) * (error_norm > 1)),
619+
partial(
620+
jnp.where,
621+
(not_converged == False) * (error_norm > 1) # noqa: E712
622+
),
610623
(_update_step_size(state, factor), False),
611-
(state, True)
624+
(state, not_converged == False)
612625
)
613626

614627
return [state, step_accepted, updated_jacobian, y, d, n_iter]
615628

616629
state, step_accepted, updated_jacobian, y, d, n_iter = \
617630
jax.lax.while_loop(while_cond, while_body, while_state)
618631

619-
620632
# take the accepted step
621633
n_steps = state.n_steps + 1
622634
t = state.t + state.h
@@ -625,7 +637,6 @@ def while_body(while_state):
625637
# (see page 83 of [2])
626638
n_equal_steps = state.n_equal_steps + 1
627639

628-
629640
state = state._replace(n_equal_steps=n_equal_steps, t=t, n_steps=n_steps)
630641

631642
state = tree_multimap(
@@ -802,7 +813,7 @@ def flax_scan(f, init, xs, length=None): # pragma: no cover
802813
return carry, onp.stack(ys)
803814

804815

805-
@partial(jax.jit, static_argnums=(0, 1, 2, 3))
816+
@jax.partial(jax.jit, static_argnums=(0, 1, 2, 3))
806817
def _bdf_odeint_wrapper(func, mass, rtol, atol, y0, ts, *args):
807818
y0, unravel = ravel_pytree(y0)
808819
if mass is None:

‎tests/unit/test_solvers/test_jax_solver.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ def test_model_solver(self):
3838
t_first_solve = time.perf_counter() - t0
3939
np.testing.assert_array_equal(solution.t, t_eval)
4040
np.testing.assert_allclose(solution.y[0], np.exp(0.1 * solution.t),
41-
rtol=1e-7, atol=1e-7)
41+
rtol=1e-6, atol=1e-6)
4242

4343
# Test time
4444
self.assertEqual(

0 commit comments

Comments
 (0)
Please sign in to comment.