@@ -45,16 +45,17 @@ class JaxSolver(pybamm.BaseSolver):
45
45
for details.
46
46
"""
47
47
48
- def __init__ (self , method = 'RK45' , root_method = None ,
49
- rtol = 1e-6 , atol = 1e-6 , extra_options = None ):
48
+ def __init__ (
49
+ self , method = "RK45" , root_method = None , rtol = 1e-6 , atol = 1e-6 , extra_options = None
50
+ ):
50
51
# note: bdf solver itself calculates consistent initial conditions so can set
51
52
# root_method to none, allow user to override this behavior
52
53
super ().__init__ (method , rtol , atol , root_method = root_method )
53
- method_options = [' RK45' , ' BDF' ]
54
+ method_options = [" RK45" , " BDF" ]
54
55
if method not in method_options :
55
- raise ValueError (' method must be one of {}' .format (method_options ))
56
+ raise ValueError (" method must be one of {}" .format (method_options ))
56
57
self .ode_solver = False
57
- if method == ' RK45' :
58
+ if method == " RK45" :
58
59
self .ode_solver = True
59
60
self .extra_options = extra_options or {}
60
61
self .name = "JAX solver ({})" .format (method )
@@ -80,8 +81,9 @@ def get_solve(self, model, t_eval):
80
81
"""
81
82
if model not in self ._cached_solves :
82
83
if model not in self .models_set_up :
83
- raise RuntimeError ("Model is not set up for solving, run"
84
- "`solver.solve(model)` first" )
84
+ raise RuntimeError (
85
+ "Model is not set up for solving, run" "`solver.solve(model)` first"
86
+ )
85
87
86
88
self ._cached_solves [model ] = self .create_solve (model , t_eval )
87
89
@@ -106,32 +108,35 @@ def create_solve(self, model, t_eval):
106
108
107
109
"""
108
110
if model .convert_to_format != "jax" :
109
- raise RuntimeError ("Model must be converted to JAX to use this solver"
110
- " (i.e. `model.convert_to_format = 'jax')" )
111
+ raise RuntimeError (
112
+ "Model must be converted to JAX to use this solver"
113
+ " (i.e. `model.convert_to_format = 'jax')"
114
+ )
111
115
112
116
if model .terminate_events_eval :
113
- raise RuntimeError ("Terminate events not supported for this solver."
114
- " Model has the following events:"
115
- " {}.\n You can remove events using `model.events = []`."
116
- " It might be useful to first solve the model using a"
117
- " different solver to obtain the time of the event, then"
118
- " re-solve using no events and a fixed"
119
- " end-time" .format (model .events ))
117
+ raise RuntimeError (
118
+ "Terminate events not supported for this solver."
119
+ " Model has the following events:"
120
+ " {}.\n You can remove events using `model.events = []`."
121
+ " It might be useful to first solve the model using a"
122
+ " different solver to obtain the time of the event, then"
123
+ " re-solve using no events and a fixed"
124
+ " end-time" .format (model .events )
125
+ )
120
126
121
127
# Initial conditions, make sure they are an 0D array
122
128
y0 = jnp .array (model .y0 ).reshape (- 1 )
123
129
mass = None
124
- if self .method == ' BDF' :
130
+ if self .method == " BDF" :
125
131
mass = model .mass_matrix .entries .toarray ()
126
132
127
133
def rhs_ode (y , t , inputs ):
128
- return model .rhs_eval (t , y , inputs ),
134
+ return ( model .rhs_eval (t , y , inputs ),)
129
135
130
136
def rhs_dae (y , t , inputs ):
131
- return jnp .concatenate ([
132
- model .rhs_eval (t , y , inputs ),
133
- model .algebraic_eval (t , y , inputs ),
134
- ])
137
+ return jnp .concatenate (
138
+ [model .rhs_eval (t , y , inputs ), model .algebraic_eval (t , y , inputs )]
139
+ )
135
140
136
141
def solve_model_rk45 (inputs ):
137
142
y = odeint (
@@ -158,7 +163,7 @@ def solve_model_bdf(inputs):
158
163
)
159
164
return jnp .transpose (y )
160
165
161
- if self .method == ' RK45' :
166
+ if self .method == " RK45" :
162
167
return jax .jit (solve_model_rk45 )
163
168
else :
164
169
return jax .jit (solve_model_bdf )
@@ -194,5 +199,7 @@ def _integrate(self, model, t_eval, inputs=None):
194
199
termination = "final time"
195
200
t_event = None
196
201
y_event = onp .array (None )
197
- return pybamm .Solution (t_eval , y ,
198
- t_event , y_event , termination )
202
+ return pybamm .Solution (
203
+ t_eval , y , t_event , y_event , termination , model = model , inputs = inputs
204
+ )
205
+
0 commit comments