1
- from functools import partial
2
1
import operator as op
3
2
import numpy as onp
4
3
import collections
@@ -83,8 +82,10 @@ def fun_bind_inputs(y, t):
83
82
t0 = t_eval [0 ]
84
83
h0 = t_eval [1 ] - t0
85
84
86
- stepper = _bdf_init (fun_bind_inputs , jac_bind_inputs , mass , t0 , y0 , h0 , rtol , atol )
87
- i = 0
85
+ stepper , failed = _bdf_init (
86
+ fun_bind_inputs , jac_bind_inputs , mass , t0 , y0 , h0 , rtol , atol
87
+ )
88
+ i = failed * len (t_eval )
88
89
y_out = jnp .empty ((len (t_eval ), len (y0 )), dtype = y0 .dtype )
89
90
90
91
init_state = [stepper , t_eval , i , y_out ]
@@ -108,21 +109,24 @@ def for_body(j, y_out):
108
109
return [stepper , t_eval , index , y_out ]
109
110
110
111
stepper , t_eval , i , y_out = jax .lax .while_loop (cond_fun , body_fun ,
111
- init_state )
112
+ init_state )
112
113
113
114
return y_out
114
115
115
116
116
117
BDFInternalStates = [
117
- 't' , 'atol' , 'rtol' , 'M' , 'newton_tol' , 'order' , 'h' , 'n_equal_steps' , 'D' ,
118
- 'y0' , 'scale_y0' , 'kappa' , 'gamma' , 'alpha' , 'c' , 'error_const' , 'J' , 'LU' , 'U' ,
119
- 'psi' , 'n_function_evals' , 'n_jacobian_evals' , 'n_lu_decompositions' , 'n_steps' ]
118
+ 't' , 'atol' , 'rtol' , 'M' , 'newton_tol' , 'order' , 'h' , 'n_equal_steps' , 'D' ,
119
+ 'y0' , 'scale_y0' , 'kappa' , 'gamma' , 'alpha' , 'c' , 'error_const' , 'J' , 'LU' , 'U' ,
120
+ 'psi' , 'n_function_evals' , 'n_jacobian_evals' , 'n_lu_decompositions' , 'n_steps'
121
+ ]
120
122
BDFState = collections .namedtuple ('BDFState' , BDFInternalStates )
121
123
122
124
jax .tree_util .register_pytree_node (
123
- BDFState ,
124
- lambda xs : (tuple (xs ), None ),
125
- lambda _ , xs : BDFState (* xs ))
125
+ BDFState ,
126
+ lambda xs : (tuple (xs ), None ),
127
+ lambda _ , xs : BDFState (* xs )
128
+ )
129
+
126
130
127
131
def _bdf_init (fun , jac , mass , t0 , y0 , h0 , rtol , atol ):
128
132
"""
@@ -165,7 +169,9 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
165
169
state ['newton_tol' ] = jnp .max ((10 * EPS / rtol , jnp .min ((0.03 , rtol ** 0.5 ))))
166
170
167
171
scale_y0 = atol + rtol * jnp .abs (y0 )
168
- y0 = _select_initial_conditions (fun , mass , t0 , y0 , state ['newton_tol' ], scale_y0 )
172
+ y0 , not_converged = _select_initial_conditions (
173
+ fun , mass , t0 , y0 , state ['newton_tol' ], scale_y0
174
+ )
169
175
170
176
f0 = fun (y0 , t0 )
171
177
order = 1
@@ -207,7 +213,7 @@ def _bdf_init(fun, jac, mass, t0, y0, h0, rtol, atol):
207
213
tuple_state = BDFState (* [state [k ] for k in BDFInternalStates ])
208
214
y0 , scale_y0 = _predict (tuple_state , D )
209
215
psi = _update_psi (tuple_state , D )
210
- return tuple_state ._replace (y0 = y0 , scale_y0 = scale_y0 , psi = psi )
216
+ return tuple_state ._replace (y0 = y0 , scale_y0 = scale_y0 , psi = psi ), not_converged
211
217
212
218
213
219
def _compute_R (order , factor ):
@@ -239,7 +245,7 @@ def _select_initial_conditions(fun, M, t0, y0, tol, scale_y0):
239
245
# if all differentiable variables then return y0 (can use normal python if since M
240
246
# is static)
241
247
if not jnp .any (algebraic_variables ):
242
- return y0
248
+ return y0 , False
243
249
244
250
# calculate consistent initial conditions via a newton on -J_a @ delta = f_a This
245
251
# follows this reference:
@@ -256,7 +262,7 @@ def fun_a(y_a):
256
262
scale_y0_a = scale_y0 [algebraic_variables ]
257
263
258
264
d = jnp .zeros (y0_a .shape [0 ], dtype = y0 .dtype )
259
- y_a = jnp .array (y0_a )
265
+ y_a = jnp .array (y0_a , copy = True )
260
266
261
267
# calculate neg jacobian of fun_a
262
268
J_a = jax .jacfwd (fun_a )(y_a )
@@ -290,13 +296,12 @@ def while_body(while_state):
290
296
291
297
return [k + 1 , not_converged , dy_norm_old , d , y_a ]
292
298
293
-
294
299
k , not_converged , dy_norm_old , d , y_a = jax .lax .while_loop (while_cond ,
295
300
while_body ,
296
301
while_state )
297
302
y_tilde = jax .ops .index_update (y0 , algebraic_variables , y_a )
298
303
299
- return y_tilde
304
+ return y_tilde , not_converged
300
305
301
306
302
307
def _select_initial_step (atol , rtol , fun , t0 , y0 , f0 , h0 ):
@@ -399,9 +404,7 @@ def _update_step_size(state, factor):
399
404
- psi term
400
405
"""
401
406
order = state .order
402
- h = state .h
403
-
404
- h *= factor
407
+ h = state .h * factor
405
408
n_equal_steps = 0
406
409
c = h * state .alpha [order ]
407
410
@@ -432,6 +435,7 @@ def _update_step_size(state, factor):
432
435
n_lu_decompositions = n_lu_decompositions , h = h , c = c ,
433
436
D = D , psi = psi , y0 = y0 , scale_y0 = scale_y0 )
434
437
438
+
435
439
def _update_jacobian (state , jac ):
436
440
"""
437
441
we update the jacobian using J(t_{n+1}, y^0_{n+1})
@@ -481,7 +485,7 @@ def while_body(while_state):
481
485
pred = rate >= 1
482
486
pred += rate ** (NEWTON_MAXITER - k ) / (1 - rate ) * dy_norm > tol
483
487
pred *= dy_norm_old >= 0
484
- k += pred * (NEWTON_MAXITER - k )
488
+ k += pred * (NEWTON_MAXITER - k - 1 )
485
489
486
490
d += dy
487
491
y = y0 + d
@@ -495,11 +499,13 @@ def while_body(while_state):
495
499
496
500
return [k + 1 , not_converged , dy_norm_old , d , y , n_function_evals ]
497
501
498
- k , not_converged , dy_norm_old , d , y , n_function_evals = jax .lax .while_loop (while_cond ,
499
- while_body ,
500
- while_state )
502
+ k , not_converged , dy_norm_old , d , y , n_function_evals = \
503
+ jax .lax .while_loop (while_cond ,
504
+ while_body ,
505
+ while_state )
501
506
return not_converged , k , y , d , state ._replace (n_function_evals = n_function_evals )
502
507
508
+
503
509
def rms_norm (arg ):
504
510
return jnp .sqrt (jnp .mean (arg ** 2 ))
505
511
@@ -508,7 +514,7 @@ def _prepare_next_step(state, d):
508
514
D = _update_difference_for_next_step (state , d )
509
515
psi = _update_psi (state , D )
510
516
y0 , scale_y0 = _predict (state , D )
511
- return state ._replace (D = D ,psi = psi ,y0 = y0 ,scale_y0 = scale_y0 )
517
+ return state ._replace (D = D , psi = psi , y0 = y0 , scale_y0 = scale_y0 )
512
518
513
519
514
520
def _prepare_next_step_order_change (state , d , y , n_iter ):
@@ -543,7 +549,7 @@ def _prepare_next_step_order_change(state, d, y, n_iter):
543
549
# now we have the three factors for orders k-1, k and k+1, pick the maximum in
544
550
# order to maximise the resultant step size
545
551
max_index = jnp .argmax (factors )
546
- order = order + max_index - 1
552
+ order += max_index - 1
547
553
548
554
factor = jnp .min ((MAX_FACTOR , safety * factors [max_index ]))
549
555
@@ -578,16 +584,20 @@ def while_body(while_state):
578
584
# newton iteration did not converge, but jacobian has already been
579
585
# evaluated so reduce step size by 0.3 (as per [1]) and try again
580
586
state = tree_multimap (
581
- partial (jnp .where , not_converged * updated_jacobian ),
582
- _update_step_size (state , 0.3 ),
583
- state
587
+ partial (jnp .where , not_converged * updated_jacobian ),
588
+ _update_step_size (state , 0.3 ),
589
+ state
584
590
)
585
591
586
- # if not converged and jacobian not updated, then update the jacobian and try again
592
+ # if not converged and jacobian not updated, then update the jacobian and try
593
+ # again
587
594
(state , updated_jacobian ) = tree_multimap (
588
- partial (jnp .where , not_converged * (updated_jacobian == False )),
589
- (_update_jacobian (state , jac ), True ),
590
- (state , False )
595
+ partial (
596
+ jnp .where ,
597
+ not_converged * (updated_jacobian == False ) # noqa: E712
598
+ ),
599
+ (_update_jacobian (state , jac ), True ),
600
+ (state , False + updated_jacobian )
591
601
)
592
602
593
603
safety = 0.9 * (2 * NEWTON_MAXITER + 1 ) / (2 * NEWTON_MAXITER + n_iter )
@@ -606,17 +616,19 @@ def while_body(while_state):
606
616
error_norm ** (- 1 / (state .order + 1 ))))
607
617
608
618
(state , step_accepted ) = tree_multimap (
609
- partial (jnp .where , (not_converged == False ) * (error_norm > 1 )),
619
+ partial (
620
+ jnp .where ,
621
+ (not_converged == False ) * (error_norm > 1 ) # noqa: E712
622
+ ),
610
623
(_update_step_size (state , factor ), False ),
611
- (state , True )
624
+ (state , not_converged == False )
612
625
)
613
626
614
627
return [state , step_accepted , updated_jacobian , y , d , n_iter ]
615
628
616
629
state , step_accepted , updated_jacobian , y , d , n_iter = \
617
630
jax .lax .while_loop (while_cond , while_body , while_state )
618
631
619
-
620
632
# take the accepted step
621
633
n_steps = state .n_steps + 1
622
634
t = state .t + state .h
@@ -625,7 +637,6 @@ def while_body(while_state):
625
637
# (see page 83 of [2])
626
638
n_equal_steps = state .n_equal_steps + 1
627
639
628
-
629
640
state = state ._replace (n_equal_steps = n_equal_steps , t = t , n_steps = n_steps )
630
641
631
642
state = tree_multimap (
@@ -802,7 +813,7 @@ def flax_scan(f, init, xs, length=None): # pragma: no cover
802
813
return carry , onp .stack (ys )
803
814
804
815
805
- @partial (jax .jit , static_argnums = (0 , 1 , 2 , 3 ))
816
+ @jax . partial (jax .jit , static_argnums = (0 , 1 , 2 , 3 ))
806
817
def _bdf_odeint_wrapper (func , mass , rtol , atol , y0 , ts , * args ):
807
818
y0 , unravel = ravel_pytree (y0 )
808
819
if mass is None :
0 commit comments