Skip to content

Commit 802cfd0

Browse files
#664 convert scalars and operations to casadi
1 parent 9171302 commit 802cfd0

23 files changed

+182
-4
lines changed

.requirements-docs.txt

+1
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ pandas>=0.23
55
anytree>=2.4.3
66
autograd>=1.2
77
scikit-fem>=0.2.0
8+
casadi>=3.5.0
89
guzzle-sphinx-theme
910
sphinx>=1.5

docs/source/expression_tree/index.rst

+1-2
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ Expression Tree
1515
unary_operator
1616
concatenations
1717
broadcasts
18-
simplify
1918
functions
2019
interpolant
21-
evaluate
20+
operations/index
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Convert to CasADi
2+
=================
3+
4+
.. autoclass:: pybamm.CasadiConverter
5+
:members:
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
Operations on expression trees
2+
==============================
3+
4+
Classes and functions that operate on the expression tree
5+
6+
.. toctree::
7+
8+
simplify
9+
evaluate
10+
convert_to_casadi

docs/source/solvers/casadi_solver.rst

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
Casadi Solver
2+
=============
3+
4+
.. autoclass:: pybamm.CasadiSolver
5+
:members:

docs/source/solvers/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ Solvers
77
base_solvers
88
scipy_solver
99
scikits_solvers
10+
casadi_solver
1011
solution

pybamm/__init__.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -150,18 +150,21 @@ def version(formatted=False):
150150
UndefinedOperationError,
151151
GeometryError,
152152
)
153-
from .expression_tree.simplify import (
153+
154+
# Operations
155+
from .expression_tree.operations.simplify import (
154156
Simplification,
155157
simplify_if_constant,
156158
simplify_addition_subtraction,
157159
simplify_multiplication_division,
158160
)
159-
from .expression_tree.evaluate import (
161+
from .expression_tree.operations.evaluate import (
160162
find_symbols,
161163
id_to_python_variable,
162164
to_python,
163165
EvaluatorPython,
164166
)
167+
from .expression_tree.operations.convert_to_casadi import CasadiConverter
165168

166169
#
167170
# Model classes

pybamm/expression_tree/operations/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
#
2+
# Convert a PyBaMM expression tree to a CasADi expression tree
3+
#
4+
import pybamm
5+
import casadi
6+
7+
8+
class CasadiConverter(object):
9+
def __init__(self, casadi_symbols=None):
10+
self._casadi_symbols = casadi_symbols or {}
11+
12+
def convert(self, symbol):
13+
"""
14+
This function recurses down the tree, applying any simplifications defined in
15+
classes derived from pybamm.Symbol. E.g. any expression multiplied by a
16+
pybamm.Scalar(0) will be simplified to a pybamm.Scalar(0).
17+
If a symbol has already been simplified, the stored value is returned.
18+
19+
Parameters
20+
----------
21+
symbol : :class:`pybamm.Symbol`
22+
The symbol to convert
23+
24+
Returns
25+
-------
26+
CasADi symbol
27+
The convert symbol
28+
"""
29+
30+
try:
31+
return self._casadi_symbols[symbol.id]
32+
except KeyError:
33+
casadi_symbol = self._convert(symbol)
34+
self._casadi_symbols[symbol.id] = casadi_symbol
35+
36+
return casadi_symbol
37+
38+
def _convert(self, symbol):
39+
""" See :meth:`Simplification.convert()`. """
40+
if isinstance(symbol, pybamm.Scalar):
41+
return casadi.SX(symbol.evaluate())
42+
43+
if isinstance(symbol, pybamm.BinaryOperator):
44+
left, right = symbol.children
45+
# process children
46+
converted_left = self.convert(left)
47+
converted_right = self.convert(right)
48+
# _binary_evaluate defined in derived classes for specific rules
49+
return symbol._binary_evaluate(converted_left, converted_right)
50+
51+
elif isinstance(symbol, pybamm.UnaryOperator):
52+
converted_child = self.convert(symbol.child)
53+
if isinstance(symbol, pybamm.AbsoluteValue):
54+
return casadi.fabs(converted_child)
55+
return symbol._unary_evaluate(converted_child)
56+
57+
elif isinstance(symbol, pybamm.Function):
58+
converted_children = [None] * len(symbol.children)
59+
for i, child in enumerate(symbol.children):
60+
converted_children[i] = self.convert(child)
61+
return symbol._function_evaluate(converted_children)
62+
63+
elif isinstance(symbol, pybamm.Concatenation):
64+
converted_children = [self.convert(child) for child in symbol.children]
65+
return symbol._concatenation_evaluate(converted_children)
66+
67+
else:
68+
raise TypeError(
69+
"""
70+
Cannot convert symbol of type '{}' to CasADi. Symbols must all be
71+
'linear algebra' at this stage.
72+
""".format(
73+
type(symbol)
74+
)
75+
)

pybamm/expression_tree/symbol.py

+7
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,13 @@ def simplify(self, simplified_symbols=None):
584584
""" Simplify the expression tree. See :class:`pybamm.Simplification`. """
585585
return pybamm.Simplification(simplified_symbols).simplify(self)
586586

587+
def to_casadi(self, casadi_symbols=None):
588+
"""
589+
Convert the expression tree to a CasADi expression tree.
590+
See :class:`pybamm.CasadiConverter`.
591+
"""
592+
return pybamm.CasadiConverter(casadi_symbols).convert(self)
593+
587594
def new_copy(self):
588595
"""
589596
Make a new copy of a symbol, to avoid Tree corruption errors while bypassing

pybamm/solvers/casadi_solver.py

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
#
2+
# Wrap CasADi
3+
#
4+
import pybamm
5+
import casadi
6+
7+
8+
class CasadiSolver(pybamm.DaeSolver):
9+
pass

setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ def load_version():
4545
"anytree>=2.4.3",
4646
"autograd>=1.2",
4747
"scikit-fem>=0.2.0",
48+
"casadi>=3.5.0",
4849
# Note: Matplotlib is loaded for debug plots, but to ensure pybamm runs
4950
# on systems without an attached display, it should never be imported
5051
# outside of plot() methods.

tests/unit/test_expression_tree/test_operations/__init__.py

Whitespace-only changes.
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
#
2+
# Test for the Simplify class
3+
#
4+
import casadi
5+
import math
6+
import numpy as np
7+
import pybamm
8+
import unittest
9+
from tests import get_discretisation_for_testing
10+
11+
12+
class TestCasadiConverter(unittest.TestCase):
13+
def test_convert_scalar_symbols(self):
14+
a = pybamm.Scalar(0)
15+
b = pybamm.Scalar(1)
16+
c = pybamm.Scalar(-1)
17+
d = pybamm.Scalar(2)
18+
19+
self.assertEqual(a.to_casadi(), casadi.SX(0))
20+
self.assertEqual(d.to_casadi(), casadi.SX(2))
21+
22+
# negate
23+
self.assertEqual((-b).to_casadi(), casadi.SX(-1))
24+
# absolute value
25+
self.assertEqual(abs(c).to_casadi(), casadi.SX(1))
26+
27+
# function
28+
def sin(x):
29+
return np.sin(x)
30+
31+
f = pybamm.Function(sin, b)
32+
self.assertEqual((f).to_casadi(), casadi.SX(np.sin(1)))
33+
34+
def myfunction(x, y):
35+
return x + y
36+
37+
f = pybamm.Function(myfunction, b, d)
38+
self.assertEqual((f).to_casadi(), casadi.SX(3))
39+
40+
# addition
41+
self.assertEqual((a + b).to_casadi(), casadi.SX(1))
42+
# subtraction
43+
self.assertEqual((c - d).to_casadi(), casadi.SX(-3))
44+
# multiplication
45+
self.assertEqual((c * d).to_casadi(), casadi.SX(-2))
46+
# power
47+
self.assertEqual((c ** d).to_casadi(), casadi.SX(1))
48+
# division
49+
self.assertEqual((b / d).to_casadi(), casadi.SX(1 / 2))
50+
51+
def test_convert_array_symbols(self):
52+
pass
53+
54+
55+
if __name__ == "__main__":
56+
print("Add -v for more debug output")
57+
import sys
58+
59+
if "-v" in sys.argv:
60+
debug = True
61+
pybamm.settings.debug_mode = True
62+
unittest.main()

0 commit comments

Comments
 (0)