@@ -197,7 +197,7 @@ def _integrate(self, model, t_eval, inputs_dict=None):
197
197
# Initialize solution
198
198
solution = pybamm .Solution (
199
199
np .array ([t ]), y0 , model , inputs_dict ,
200
- sensitivities = explicit_sensitivities
200
+ sensitivities = False ,
201
201
)
202
202
solution .solve_time = 0
203
203
solution .integration_time = 0
@@ -240,7 +240,8 @@ def _integrate(self, model, t_eval, inputs_dict=None):
240
240
# halve the step size and try again.
241
241
try :
242
242
current_step_sol = self ._run_integrator (
243
- model , y0 , inputs_dict , inputs , t_window , use_grid = use_grid
243
+ model , y0 , inputs_dict , inputs , t_window , use_grid = use_grid ,
244
+ extract_sensitivities_in_solution = False ,
244
245
)
245
246
solved = True
246
247
except pybamm .SolverError :
@@ -273,6 +274,20 @@ def _integrate(self, model, t_eval, inputs_dict=None):
273
274
t = t_window [- 1 ]
274
275
# update y0
275
276
y0 = solution .all_ys [- 1 ][:, - 1 ]
277
+
278
+ # now we extract sensitivities from the solution
279
+ if (explicit_sensitivities ):
280
+ # save original ys[0] and replace with separated soln
281
+ # TODO: This is a dodgy hack, perhaps re-init the solution object?
282
+ solution ._all_ys_and_sens = [solution ._all_ys [0 ][:]]
283
+ solution ._all_ys [0 ], solution ._sensitivities = \
284
+ solution ._extract_explicit_sensitivities (
285
+ solution .all_models [0 ],
286
+ solution .all_ys [0 ],
287
+ solution .all_ts [0 ],
288
+ solution .all_inputs [0 ],
289
+ )
290
+
276
291
return solution
277
292
278
293
def _solve_for_event (self , coarse_solution , init_event_signs ):
@@ -598,12 +613,20 @@ def create_integrator(self, model, inputs, t_eval=None, use_event_switch=False):
598
613
599
614
return integrator
600
615
601
- def _run_integrator (self , model , y0 , inputs_dict , inputs , t_eval , use_grid = True ):
616
+ def _run_integrator (self , model , y0 , inputs_dict ,
617
+ inputs , t_eval , use_grid = True ,
618
+ extract_sensitivities_in_solution = None ,
619
+ ):
602
620
pybamm .logger .debug ("Running CasADi integrator" )
603
621
604
622
# are we solving explicit forward equations?
605
623
explicit_sensitivities = bool (self .calculate_sensitivites )
606
624
625
+ # by default we extract sensitivities in the solution if we
626
+ # are calculating the sensitivities
627
+ if extract_sensitivities_in_solution is None :
628
+ extract_sensitivities_in_solution = explicit_sensitivities
629
+
607
630
if use_grid is True :
608
631
t_eval_shifted = t_eval - t_eval [0 ]
609
632
t_eval_shifted_rounded = np .round (t_eval_shifted , decimals = 12 ).tobytes ()
@@ -614,8 +637,9 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
614
637
len_rhs = model .concatenated_rhs .size
615
638
616
639
# Check y0 to see if it includes sensitivities
617
- if model .len_rhs_and_alg != y0 .shape [0 ]:
618
- len_rhs = len_rhs * (inputs .shape [0 ] + 1 )
640
+ if explicit_sensitivities :
641
+ num_parameters = model .len_rhs_sens // model .len_rhs
642
+ len_rhs = len_rhs * (num_parameters + 1 )
619
643
620
644
y0_diff = y0 [:len_rhs ]
621
645
y0_alg = y0 [len_rhs :]
@@ -634,7 +658,7 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
634
658
y_sol = casadi .vertcat (casadi_sol ["xf" ], casadi_sol ["zf" ])
635
659
sol = pybamm .Solution (
636
660
t_eval , y_sol , model , inputs_dict ,
637
- sensitivities = explicit_sensitivities
661
+ sensitivities = extract_sensitivities_in_solution
638
662
)
639
663
sol .integration_time = integration_time
640
664
return sol
@@ -665,7 +689,7 @@ def _run_integrator(self, model, y0, inputs_dict, inputs, t_eval, use_grid=True)
665
689
666
690
sol = pybamm .Solution (
667
691
t_eval , y_sol , model , inputs_dict ,
668
- sensitivities = explicit_sensitivities
692
+ sensitivities = extract_sensitivities_in_solution
669
693
)
670
694
sol .integration_time = integration_time
671
695
return sol
0 commit comments