Skip to content

Commit 130bbb9

Browse files
committed
#858 discretisation of VariableDot results in StateVectorDot
1 parent e2b4ba7 commit 130bbb9

File tree

3 files changed

+27
-1
lines changed

3 files changed

+27
-1
lines changed

pybamm/discretisations/discretisation.py

+7
Original file line numberDiff line numberDiff line change
@@ -855,6 +855,13 @@ def _process_symbol(self, symbol):
855855
disc_children = [self.process_symbol(child) for child in symbol.children]
856856
return symbol._function_new_copy(disc_children)
857857

858+
elif isinstance(symbol, pybamm.VariableDot):
859+
return pybamm.StateVectorDot(
860+
*self.y_slices[symbol.get_variable().id],
861+
domain=symbol.domain,
862+
auxiliary_domains=symbol.auxiliary_domains
863+
)
864+
858865
elif isinstance(symbol, pybamm.Variable):
859866
# Check if variable is a standard variable or an external variable
860867
if any(symbol.id == var.id for var in self.external_variables.values()):

pybamm/expression_tree/variable.py

+12
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,18 @@ class VariableDot(Variable):
8787
def __init__(self, name, domain=None, auxiliary_domains=None):
8888
super().__init__(name, domain=domain, auxiliary_domains=auxiliary_domains)
8989

90+
def get_variable(self):
91+
"""
92+
return a :class:`.Variable` corresponding to this VariableDot
93+
94+
Note: Variable._jac adds a dash to the name of the corresponding VariableDot, so
95+
we remove this here
96+
97+
"""
98+
return Variable(self.name[:-1],
99+
domain=self._domain,
100+
auxiliary_domains=self._auxiliary_domains)
101+
90102

91103
class ExternalVariable(Variable):
92104
"""A node in the expression tree representing an external variable variable

tests/unit/test_discretisations/test_discretisation.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,14 @@ def test_process_symbol_base(self):
329329
var_disc = disc.process_symbol(var)
330330
self.assertIsInstance(var_disc, pybamm.StateVector)
331331
self.assertEqual(var_disc.y_slices[0], disc.y_slices[var.id][0])
332+
333+
# variable dot
334+
var_dot = pybamm.VariableDot("var'")
335+
var_vec_dot = pybamm.VariableDot("var vec'", domain=["negative electrode"])
336+
var_dot_disc = disc.process_symbol(var_dot)
337+
self.assertIsInstance(var_dot_disc, pybamm.StateVectorDot)
338+
self.assertEqual(var_dot_disc.y_slices[0], disc.y_slices[var.id][0])
339+
332340
# scalar
333341
scal = pybamm.Scalar(5)
334342
scal_disc = disc.process_symbol(scal)
@@ -1086,7 +1094,6 @@ def test_mass_matirx_inverse(self):
10861094
model.mass_matrix_inv.entries.toarray(), mass_inv.toarray()
10871095
)
10881096

1089-
10901097
if __name__ == "__main__":
10911098
print("Add -v for more debug output")
10921099
import sys

0 commit comments

Comments
 (0)