Skip to content

Commit 2cb99ad

Browse files
committed
#1477 fix bug in jax evaluate
1 parent df0ff95 commit 2cb99ad

File tree

1 file changed

+8
-2
lines changed

1 file changed

+8
-2
lines changed

pybamm/expression_tree/operations/evaluate_python.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -685,7 +685,10 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
685685
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
686686
result = result.reshape(result.shape[0], -1)
687687

688-
return result
688+
if known_evals is not None:
689+
return result, known_evals
690+
else:
691+
return result
689692

690693

691694
class EvaluatorJaxSensitivities:
@@ -704,4 +707,7 @@ def evaluate(self, t=None, y=None, y_dot=None, inputs=None, known_evals=None):
704707
# execute code
705708
result = self._jac_evaluate(*self._constants, t, y, y_dot, inputs, known_evals)
706709

707-
return result
710+
if known_evals is not None:
711+
return result, known_evals
712+
else:
713+
return result

0 commit comments

Comments
 (0)