Skip to content

Commit a31e209

Browse files
authored
Merge pull request #1472 from pybamm-team/issue-1465-evaluate-function-parameters
#1465 allow evaluate parameter to return arrays
2 parents b4735e1 + cb04f46 commit a31e209

File tree

6 files changed

+43
-33
lines changed

6 files changed

+43
-33
lines changed

CHANGELOG.md

+2
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
## Features
44

5+
- `plot` and `plot2D` now take and return a matplotlib Axis to allow for easier customization ([#1472](https://github.com/pybamm-team/PyBaMM/pull/1472))
6+
- `ParameterValues.evaluate` can now return arrays to allow function parameters to be easily evaluated ([#1472](https://github.com/pybamm-team/PyBaMM/pull/1472))
57
- Added Batch Study class ([#1455](https://github.com/pybamm-team/PyBaMM/pull/1455))
68
- Added `ConcatenationVariable`, which is automatically created when variables are concatenated ([#1453](https://github.com/pybamm-team/PyBaMM/pull/1453))
79
- Added "fast with events" mode for the CasADi solver, which solves a model and finds events more efficiently than "safe" mode. As of PR #1450 this feature is still being tested and "safe" mode remains the default ([#1450](https://github.com/pybamm-team/PyBaMM/pull/1450))

pybamm/parameters/parameter_values.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -776,14 +776,14 @@ def evaluate(self, symbol):
776776
777777
Returns
778778
-------
779-
number of array
779+
number or array
780780
The evaluated symbol
781781
"""
782782
processed_symbol = self.process_symbol(symbol)
783-
if processed_symbol.evaluates_to_constant_number():
783+
if processed_symbol.is_constant():
784784
return processed_symbol.evaluate()
785785
else:
786-
raise ValueError("symbol must evaluate to a constant scalar")
786+
raise ValueError("symbol must evaluate to a constant scalar or array")
787787

788788
def _ipython_key_completions_(self):
789789
return list(self._dict_items.keys())

pybamm/plotting/plot.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .quick_plot import ax_min, ax_max
66

77

8-
def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
8+
def plot(x, y, ax=None, testing=False, **kwargs):
99
"""
1010
Generate a simple 1D plot. Calls `matplotlib.pyplot.plot` with keyword
1111
arguments 'kwargs'. For a list of 'kwargs' see the
@@ -17,10 +17,8 @@ def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
1717
The array to plot on the x axis
1818
y : :class:`pybamm.Array`
1919
The array to plot on the y axis
20-
xlabel : str, optional
21-
The label for the x axis
22-
ylabel : str, optional
23-
The label for the y axis
20+
ax : matplotlib Axis, optional
21+
The axis on which to put the plot. If None, a new figure and axis is created.
2422
testing : bool, optional
2523
Whether to actually make the plot (turned off for unit tests)
2624
kwargs
@@ -34,13 +32,15 @@ def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
3432
if not isinstance(y, pybamm.Array):
3533
raise TypeError("y must be 'pybamm.Array'")
3634

37-
plt.plot(x.entries, y.entries, **kwargs)
38-
plt.ylim([ax_min(y.entries), ax_max(y.entries)])
39-
plt.xlabel(xlabel)
40-
plt.ylabel(ylabel)
41-
plt.title(title)
35+
if ax is not None:
36+
testing = True
37+
else:
38+
_, ax = plt.subplots()
39+
40+
ax.plot(x.entries, y.entries, **kwargs)
41+
ax.set_ylim([ax_min(y.entries), ax_max(y.entries)])
4242

4343
if not testing: # pragma: no cover
4444
plt.show()
4545

46-
return
46+
return ax

pybamm/plotting/plot2D.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from .quick_plot import ax_min, ax_max
66

77

8-
def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
8+
def plot2D(x, y, z, ax=None, testing=False, **kwargs):
99
"""
1010
Generate a simple 2D plot. Calls `matplotlib.pyplot.contourf` with keyword
1111
arguments 'kwargs'. For a list of 'kwargs' see the
@@ -19,12 +19,8 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg
1919
The array to plot on the y axis. Can be of shape (M, N) or (M, 1)
2020
z : :class:`pybamm.Array`
2121
The array to plot on the z axis. Is of shape (M, N)
22-
xlabel : str, optional
23-
The label for the x axis
24-
ylabel : str, optional
25-
The label for the y axis
26-
title : str, optional
27-
The title for the plot
22+
ax : matplotlib Axis, optional
23+
The axis on which to put the plot. If None, a new figure and axis is created.
2824
testing : bool, optional
2925
Whether to actually make the plot (turned off for unit tests)
3026
@@ -38,6 +34,11 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg
3834
if not isinstance(z, pybamm.Array):
3935
raise TypeError("z must be 'pybamm.Array'")
4036

37+
if ax is not None:
38+
testing = True
39+
else:
40+
_, ax = plt.subplots()
41+
4142
# Get correct entries of x and y depending on shape
4243
if x.shape == y.shape == z.shape:
4344
x_entries = x.entries
@@ -46,20 +47,17 @@ def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwarg
4647
x_entries = x.entries[:, 0]
4748
y_entries = y.entries[:, 0]
4849

49-
plt.contourf(
50+
plot = ax.contourf(
5051
x_entries,
5152
y_entries,
5253
z.entries,
5354
vmin=ax_min(z.entries),
5455
vmax=ax_max(z.entries),
5556
**kwargs
5657
)
57-
plt.xlabel(xlabel)
58-
plt.ylabel(ylabel)
59-
plt.title(title)
60-
plt.colorbar()
58+
plt.colorbar(plot, ax=ax)
6159

6260
if not testing: # pragma: no cover
6361
plt.show()
6462

65-
return
63+
return ax

tests/unit/test_parameters/test_parameter_values.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -756,13 +756,14 @@ def test_evaluate(self):
756756
c = pybamm.Parameter("c")
757757
self.assertEqual(parameter_values.evaluate(a), 1)
758758
self.assertEqual(parameter_values.evaluate(a + (b * c)), 7)
759+
d = pybamm.Parameter("a") + pybamm.Parameter("b") * pybamm.Array([4, 5])
760+
np.testing.assert_array_equal(
761+
parameter_values.evaluate(d), np.array([9, 11])[:, np.newaxis]
762+
)
759763

760764
y = pybamm.StateVector(slice(0, 1))
761765
with self.assertRaises(ValueError):
762766
parameter_values.evaluate(y)
763-
array = pybamm.Array(np.array([1, 2, 3]))
764-
with self.assertRaises(ValueError):
765-
parameter_values.evaluate(array)
766767

767768
def test_export_csv(self):
768769
def some_function(self):

tests/unit/test_plotting/test_plot.py

+12-3
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,18 @@
11
import pybamm
22
import unittest
33
import numpy as np
4+
import matplotlib.pyplot as plt
45

56

67
class TestPlot(unittest.TestCase):
78
def test_plot(self):
89
x = pybamm.Array(np.array([0, 3, 10]))
910
y = pybamm.Array(np.array([6, 16, 78]))
10-
pybamm.plot(x, y, xlabel="x", ylabel="y", title="title", testing=True)
11+
pybamm.plot(x, y, testing=True)
12+
13+
_, ax = plt.subplots()
14+
ax_out = pybamm.plot(x, y, ax=ax, testing=True)
15+
self.assertEqual(ax_out, ax)
1116

1217
def test_plot_fail(self):
1318
x = pybamm.Array(np.array([0]))
@@ -22,10 +27,14 @@ def test_plot2D(self):
2227
X, Y = pybamm.meshgrid(x, y)
2328

2429
# plot with array directly
25-
pybamm.plot2D(x, y, Y, xlabel="x", ylabel="y", title="title", testing=True)
30+
pybamm.plot2D(x, y, Y, testing=True)
2631

2732
# plot with meshgrid
28-
pybamm.plot2D(X, Y, Y, xlabel="x", ylabel="y", title="title", testing=True)
33+
pybamm.plot2D(X, Y, Y, testing=True)
34+
35+
_, ax = plt.subplots()
36+
ax_out = pybamm.plot2D(X, Y, Y, ax=ax, testing=True)
37+
self.assertEqual(ax_out, ax)
2938

3039
def test_plot2D_fail(self):
3140
x = pybamm.Array(np.array([0]))

0 commit comments

Comments
 (0)