Skip to content

Commit 1f03394

Browse files
committed
#858 time derivative of state vector gives state vector dot, raise errors if taking time derivative of *Dot classes
1 parent 130bbb9 commit 1f03394

File tree

2 files changed

+30
-3
lines changed

2 files changed

+30
-3
lines changed

pybamm/expression_tree/state_vector.py

+10-1
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,12 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
256256
return out
257257

258258
def _jac(self, variable):
259-
if isinstance(variable, pybamm.StateVector):
259+
if variable.id == pybamm.t.id:
260+
return StateVectorDot(*self._y_slices, name=self.name + "'",
261+
domain=self.domain,
262+
auxiliary_domains=self.auxiliary_domains,
263+
evaluation_array=self.evaluation_array)
264+
elif isinstance(variable, pybamm.StateVector):
260265
return self._jac_same_vector(variable)
261266
elif isinstance(variable, pybamm.StateVectorDot):
262267
return self._jac_diff_vector(variable)
@@ -312,6 +317,10 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
312317
return out
313318

314319
def _jac(self, variable):
320+
if variable.id == pybamm.t.id:
321+
raise pybamm.ModelError(
322+
"cannot take second time derivative of a state vector"
323+
)
315324
if isinstance(variable, pybamm.StateVectorDot):
316325
return self._jac_same_vector(variable)
317326
elif isinstance(variable, pybamm.StateVector):

pybamm/expression_tree/variable.py

+20-2
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ def _evaluate_for_shape(self):
4848
)
4949

5050
def _jac(self, variable):
51-
if variable == self:
51+
if variable.id == self.id:
5252
return pybamm.Scalar(1)
53-
elif variable == pybamm.t:
53+
elif variable.id == pybamm.t.id:
5454
return pybamm.VariableDot(self.name+"'",
5555
domain=self.domain,
5656
auxiliary_domains=self.auxiliary_domains)
@@ -99,6 +99,14 @@ def get_variable(self):
9999
domain=self._domain,
100100
auxiliary_domains=self._auxiliary_domains)
101101

102+
def _jac(self, variable):
103+
if variable.id == self.id:
104+
return pybamm.Scalar(1)
105+
elif variable.id == pybamm.t.id:
106+
raise pybamm.ModelError("cannot take second time derivative of a Variable")
107+
else:
108+
return pybamm.Scalar(0)
109+
102110

103111
class ExternalVariable(Variable):
104112
"""A node in the expression tree representing an external variable variable
@@ -161,3 +169,13 @@ def _base_evaluate(self, t=None, y=None, y_dot=None, u=None):
161169
# raise more informative error if can't find name in dict
162170
except KeyError:
163171
raise KeyError("External variable '{}' not found".format(self.name))
172+
173+
def _jac(self, variable):
174+
if variable.id == self.id:
175+
return pybamm.Scalar(1)
176+
elif variable.id == pybamm.t.id:
177+
raise pybamm.ModelError("cannot take time derivative of an external variable")
178+
else:
179+
return pybamm.Scalar(0)
180+
181+

0 commit comments

Comments
 (0)