Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

#1066 add numpy function sqrt, sin, cos and exp to convert_to_casadi #1067

Merged
merged 7 commits into from
Jun 23, 2020
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -10,6 +10,8 @@

## Bug fixes

- Allowed for pybamm functions exp, sin, cos, sqrt to be used in expression trees that
are converted to casadi format ([#1067](https://github.com/pybamm-team/PyBaMM/pull/1067)
- Fix a bug where variables that depend on y and z were transposed in `QuickPlot` ([#1055](https://github.com/pybamm-team/PyBaMM/pull/1055))

## Breaking changes
2 changes: 1 addition & 1 deletion examples/scripts/compare_lithium_ion.py
Original file line number Diff line number Diff line change
@@ -3,7 +3,7 @@
#
import pybamm

pybamm.set_logging_level("INFO")
# pybamm.set_logging_level("INFO")

# load models
models = [
2 changes: 1 addition & 1 deletion pybamm/expression_tree/array.py
Original file line number Diff line number Diff line change
@@ -74,7 +74,7 @@ def entries_string(self, value):
if issparse(entries):
self._entries_string = str(entries.__dict__)
else:
self._entries_string = entries.tostring()
self._entries_string = entries.tobytes()

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()`. """
2 changes: 1 addition & 1 deletion pybamm/expression_tree/interpolant.py
Original file line number Diff line number Diff line change
@@ -86,7 +86,7 @@ def entries_string(self, value):
self._entries_string = value
else:
entries = self.data
self._entries_string = entries.tostring()
self._entries_string = entries.tobytes()

def set_id(self):
""" See :meth:`pybamm.Symbol.set_id()`. """
22 changes: 22 additions & 0 deletions pybamm/expression_tree/operations/convert_to_casadi.py
Original file line number Diff line number Diff line change
@@ -101,6 +101,28 @@ def _convert(self, symbol, t, y, y_dot, inputs):
return casadi.mmax(*converted_children)
elif symbol.function == np.abs:
return casadi.fabs(*converted_children)
elif symbol.function == np.sqrt:
return casadi.sqrt(*converted_children)
elif symbol.function == np.sin:
return casadi.sin(*converted_children)
elif symbol.function == np.arcsinh:
return casadi.arcsinh(*converted_children)
elif symbol.function == np.arccosh:
return casadi.arccosh(*converted_children)
elif symbol.function == np.tanh:
return casadi.tanh(*converted_children)
elif symbol.function == np.cosh:
return casadi.cosh(*converted_children)
elif symbol.function == np.sinh:
return casadi.sinh(*converted_children)
elif symbol.function == np.cos:
return casadi.cos(*converted_children)
elif symbol.function == np.exp:
return casadi.exp(*converted_children)
elif symbol.function == np.log:
return casadi.log(*converted_children)
elif symbol.function == np.sign:
return casadi.sign(*converted_children)
elif isinstance(symbol.function, (PchipInterpolator, CubicSpline)):
return casadi.interpolant("LUT", "bspline", [symbol.x], symbol.y)(
*converted_children
4 changes: 2 additions & 2 deletions pybamm/expression_tree/symbol.py
Original file line number Diff line number Diff line change
@@ -740,7 +740,7 @@ def shape(self):
# Default behaviour is to try to evaluate the object directly
# Try with some large y, to avoid having to unpack (slow)
try:
y = np.linspace(0.1, 0.9, int(1e4))
y = np.nan * np.ones((1000, 1))
evaluated_self = self.evaluate(0, y, y, inputs="shape test")
# If that fails, fall back to calculating how big y should really be
except ValueError:
@@ -753,7 +753,7 @@ def shape(self):
len(x._evaluation_array) for x in state_vectors_in_node
)
# Pick a y that won't cause RuntimeWarnings
y = np.linspace(0.1, 0.9, min_y_size)
y = np.nan * np.ones((min_y_size, 1))
evaluated_self = self.evaluate(0, y, y, inputs="shape test")

# Return shape of evaluated object
Original file line number Diff line number Diff line change
@@ -30,11 +30,11 @@ def test_convert_scalar_symbols(self):
self.assertEqual(abs(c).to_casadi(), casadi.MX(1))

# function
def sin(x):
return np.sin(x)
def square_plus_one(x):
return x ** 2 + 1

f = pybamm.Function(sin, b)
self.assertEqual(f.to_casadi(), casadi.MX(np.sin(1)))
f = pybamm.Function(square_plus_one, b)
self.assertEqual(f.to_casadi(), 2)

def myfunction(x, y):
return x + y
@@ -95,6 +95,12 @@ def test_special_functions(self):
self.assert_casadi_equal(
pybamm.Function(np.abs, c).to_casadi(), casadi.MX(3), evalf=True
)
for np_fun in [np.sqrt, np.tanh, np.cosh, np.sinh,
np.exp, np.log, np.sign, np.sin, np.cos,
np.arccosh, np.arcsinh]:
self.assert_casadi_equal(
pybamm.Function(np_fun, c).to_casadi(), casadi.MX(np_fun(3)), evalf=True
)

def test_interpolation(self):
x = np.linspace(0, 1)[:, np.newaxis]