Skip to content

Commit cfd982d

Browse files
authored
Merge pull request #1008 from pybamm-team/issue-992-plot-arrays
Issue 992 plot arrays
2 parents 49aaff0 + 449e6f5 commit cfd982d

File tree

21 files changed

+282
-30
lines changed

21 files changed

+282
-30
lines changed

.flake8

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ exclude=
1212
share,
1313
pyvenv.cfg,
1414
third-party,
15-
sundials-5.0.0,
15+
KLU_module_deps,
1616
ignore=
1717
# False positive for white space before ':' on list slice
1818
# black should format these correctly

CHANGELOG.md

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
## Features
44

55
- Added `BackwardIndefiniteIntegral` symbol ([#1014](https://github.com/pybamm-team/PyBaMM/pull/1014))
6+
- Added `plot` and `plot2D` to enable easy plotting of `pybamm.Array` objects ([#1008](https://github.com/pybamm-team/PyBaMM/pull/1008))
67
- Added SEI film resistance as an option ([#994](https://github.com/pybamm-team/PyBaMM/pull/994))
78
- Added tab, edge, and surface cooling ([#965](https://github.com/pybamm-team/PyBaMM/pull/965))
89
- Added functionality to solver to automatically discretise a 0D model ([#947](https://github.com/pybamm-team/PyBaMM/pull/947))

docs/index.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ API documentation
6666
source/solvers/index
6767
source/experiments/index
6868
source/simulation
69-
source/quick_plot
69+
source/plotting/index
7070
source/util
7171
source/citations
7272
source/parameters_cli

docs/source/expression_tree/array.rst

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Array
2+
=====
3+
4+
.. autoclass:: pybamm.Array
5+
:members:
6+
7+
.. autofunction:: pybamm.linspace
8+
9+
.. autofunction:: pybamm.meshgrid

docs/source/expression_tree/index.rst

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Expression Tree
88
variable
99
independent_variable
1010
scalar
11+
array
1112
matrix
1213
vector
1314
state_vector

docs/source/plotting/dynamic_plot.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Dynamic Plot
2+
============
3+
4+
.. autofunction:: pybamm.dynamic_plot

docs/source/plotting/index.rst

+9
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Plotting
2+
========
3+
4+
.. toctree::
5+
6+
quick_plot
7+
dynamic_plot
8+
plot
9+
plot_2D

docs/source/plotting/plot.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Plot
2+
====
3+
4+
.. autofunction:: pybamm.plot

docs/source/plotting/plot_2D.rst

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
Plot 2D
2+
=======
3+
4+
.. autofunction:: pybamm.plot2D
File renamed without changes.

pybamm/__init__.py

+9-3
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def version(formatted=False):
6868
from .expression_tree.symbol import *
6969
from .expression_tree.binary_operators import *
7070
from .expression_tree.concatenations import *
71-
from .expression_tree.array import Array
71+
from .expression_tree.array import Array, linspace, meshgrid
7272
from .expression_tree.matrix import Matrix
7373
from .expression_tree.unary_operators import *
7474
from .expression_tree.functions import *
@@ -221,10 +221,16 @@ def version(formatted=False):
221221
from . import experiments
222222

223223
#
224-
# other
224+
# Plotting
225225
#
226-
from .quick_plot import QuickPlot, dynamic_plot, ax_min, ax_max
226+
from .plotting.quick_plot import QuickPlot
227+
from .plotting.plot import plot
228+
from .plotting.plot2D import plot2D
229+
from .plotting.dynamic_plot import dynamic_plot
227230

231+
#
232+
# Simulation
233+
#
228234
from .simulation import Simulation, load_sim, is_notebook
229235

230236
#

pybamm/expression_tree/array.py

+21
Original file line numberDiff line numberDiff line change
@@ -101,3 +101,24 @@ def new_copy(self):
101101
def _base_evaluate(self, t=None, y=None, y_dot=None, inputs=None):
102102
""" See :meth:`pybamm.Symbol._base_evaluate()`. """
103103
return self._entries
104+
105+
106+
def linspace(start, stop, num=50, **kwargs):
107+
"""
108+
Creates a linearly spaced array by calling `numpy.linspace` with keyword
109+
arguments 'kwargs'. For a list of 'kwargs' see the
110+
`numpy linspace documentation <https://tinyurl.com/yc4ne47x>`_
111+
"""
112+
return pybamm.Array(np.linspace(start, stop, num, **kwargs))
113+
114+
115+
def meshgrid(x, y, **kwargs):
116+
"""
117+
Return coordinate matrices as from coordinate vectors by calling
118+
`numpy.meshgrid` with keyword arguments 'kwargs'. For a list of 'kwargs'
119+
see the `numpy meshgrid documentation <https://tinyurl.com/y8azewrj>`_
120+
"""
121+
[X, Y] = np.meshgrid(x.entries, y.entries)
122+
X = pybamm.Array(X)
123+
Y = pybamm.Array(Y)
124+
return X, Y

pybamm/plotting/__init__.py

Whitespace-only changes.

pybamm/plotting/dynamic_plot.py

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#
2+
# Method for creating a dynamic plot
3+
#
4+
import pybamm
5+
6+
7+
def dynamic_plot(*args, **kwargs):
8+
"""
9+
Creates a :class:`pybamm.QuickPlot` object (with arguments 'args' and keyword
10+
arguments 'kwargs') and then calls :meth:`pybamm.QuickPlot.dynamic_plot`.
11+
The key-word argument 'testing' is passed to the 'dynamic_plot' method, not the
12+
`QuickPlot` class.
13+
14+
Returns
15+
-------
16+
plot : :class:`pybamm.QuickPlot`
17+
The 'QuickPlot' object that was created
18+
"""
19+
kwargs_for_class = {k: v for k, v in kwargs.items() if k != "testing"}
20+
plot = pybamm.QuickPlot(*args, **kwargs_for_class)
21+
plot.dynamic_plot(kwargs.get("testing", False))
22+
return plot

pybamm/plotting/plot.py

+44
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
#
2+
# Method for creating a 1D plot of pybamm arrays
3+
#
4+
import pybamm
5+
from .quick_plot import ax_min, ax_max
6+
7+
8+
def plot(x, y, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
9+
"""
10+
Generate a simple 1D plot. Calls `matplotlib.pyplot.plot` with keyword
11+
arguments 'kwargs'. For a list of 'kwargs' see the
12+
`matplotlib plot documentation <https://tinyurl.com/ycblw9bx>`_
13+
14+
Parameters
15+
----------
16+
x : :class:`pybamm.Array`
17+
The array to plot on the x axis
18+
y : :class:`pybamm.Array`
19+
The array to plot on the y axis
20+
xlabel : str, optional
21+
The label for the x axis
22+
ylabel : str, optional
23+
The label for the y axis
24+
testing : bool, optional
25+
Whether to actually make the plot (turned off for unit tests)
26+
27+
"""
28+
import matplotlib.pyplot as plt
29+
30+
if not isinstance(x, pybamm.Array):
31+
raise TypeError("x must be 'pybamm.Array'")
32+
if not isinstance(y, pybamm.Array):
33+
raise TypeError("y must be 'pybamm.Array'")
34+
35+
plt.plot(x.entries, y.entries, **kwargs)
36+
plt.ylim([ax_min(y.entries), ax_max(y.entries)])
37+
plt.xlabel(xlabel)
38+
plt.ylabel(ylabel)
39+
plt.title(title)
40+
41+
if not testing: # pragma: no cover
42+
plt.show()
43+
44+
return

pybamm/plotting/plot2D.py

+66
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,66 @@
1+
#
2+
# Method for creating a filled contour plot of pybamm arrays
3+
#
4+
import pybamm
5+
from .quick_plot import ax_min, ax_max
6+
7+
8+
def plot2D(x, y, z, xlabel=None, ylabel=None, title=None, testing=False, **kwargs):
9+
"""
10+
Generate a simple 2D plot. Calls `matplotlib.pyplot.contourf` with keyword
11+
arguments 'kwargs'. For a list of 'kwargs' see the
12+
`matplotlib contourf documentation <https://tinyurl.com/y8mnadtn>`_
13+
14+
Parameters
15+
----------
16+
x : :class:`pybamm.Array`
17+
The array to plot on the x axis. Can be of shape (M, N) or (N, 1)
18+
y : :class:`pybamm.Array`
19+
The array to plot on the y axis. Can be of shape (M, N) or (M, 1)
20+
z : :class:`pybamm.Array`
21+
The array to plot on the z axis. Is of shape (M, N)
22+
xlabel : str, optional
23+
The label for the x axis
24+
ylabel : str, optional
25+
The label for the y axis
26+
title : str, optional
27+
The title for the plot
28+
testing : bool, optional
29+
Whether to actually make the plot (turned off for unit tests)
30+
31+
"""
32+
import matplotlib.pyplot as plt
33+
34+
if not isinstance(x, pybamm.Array):
35+
raise TypeError("x must be 'pybamm.Array'")
36+
if not isinstance(y, pybamm.Array):
37+
raise TypeError("y must be 'pybamm.Array'")
38+
if not isinstance(z, pybamm.Array):
39+
raise TypeError("z must be 'pybamm.Array'")
40+
41+
# Get correct entries of x and y depending on shape
42+
if x.shape == y.shape == z.shape:
43+
x_entries = x.entries
44+
y_entries = y.entries
45+
else:
46+
x_entries = x.entries[:, 0]
47+
y_entries = y.entries[:, 0]
48+
49+
plt.contourf(
50+
x_entries,
51+
y_entries,
52+
z.entries,
53+
vmin=ax_min(z.entries),
54+
vmax=ax_max(z.entries),
55+
cmap="coolwarm",
56+
**kwargs
57+
)
58+
plt.xlabel(xlabel)
59+
plt.ylabel(ylabel)
60+
plt.title(title)
61+
plt.colorbar()
62+
63+
if not testing: # pragma: no cover
64+
plt.show()
65+
66+
return

pybamm/quick_plot.py pybamm/plotting/quick_plot.py

-18
Original file line numberDiff line numberDiff line change
@@ -44,24 +44,6 @@ def split_long_string(title, max_words=4):
4444
return first_line + "\n" + second_line
4545

4646

47-
def dynamic_plot(*args, **kwargs):
48-
"""
49-
Creates a :class:`pybamm.QuickPlot` object (with arguments 'args' and keyword
50-
arguments 'kwargs') and then calls :meth:`pybamm.QuickPlot.dynamic_plot`.
51-
The key-word argument 'testing' is passed to the 'dynamic_plot' method, not the
52-
`QuickPlot' class.
53-
54-
Returns
55-
-------
56-
plot : :class:`pybamm.QuickPlot`
57-
The 'QuickPlot' object that was created
58-
"""
59-
kwargs_for_class = {k: v for k, v in kwargs.items() if k != "testing"}
60-
plot = pybamm.QuickPlot(*args, **kwargs_for_class)
61-
plot.dynamic_plot(kwargs.get("testing", False))
62-
return plot
63-
64-
6547
class QuickPlot(object):
6648
"""
6749
Generates a quick plot of a subset of key outputs of the model so that the model

tests/integration/test_models/test_full_battery_models/test_lithium_ion/test_compare_outputs.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -99,7 +99,7 @@ def test_compare_outputs_thermal(self):
9999
solutions = []
100100
t_eval = np.linspace(0, 3600, 100)
101101
for model in models:
102-
solution = pybamm.CasadiSolver(dt_max=0.01).solve(model, t_eval)
102+
solution = pybamm.CasadiSolver().solve(model, t_eval)
103103
solutions.append(solution)
104104

105105
# compare outputs
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#
2+
# Tests for the Array class
3+
#
4+
import pybamm
5+
import numpy as np
6+
7+
import unittest
8+
9+
10+
class TestArray(unittest.TestCase):
11+
def test_name(self):
12+
arr = pybamm.Array(np.array([1, 2, 3]))
13+
self.assertEqual(arr.name, "Array of shape (3, 1)")
14+
15+
def test_linspace(self):
16+
x = np.linspace(0, 1, 100)[:, np.newaxis]
17+
y = pybamm.linspace(0, 1, 100)
18+
np.testing.assert_array_equal(x, y.entries)
19+
20+
def test_meshgrid(self):
21+
a = np.linspace(0, 5)
22+
b = np.linspace(0, 3)
23+
A, B = np.meshgrid(a, b)
24+
c = pybamm.linspace(0, 5)
25+
d = pybamm.linspace(0, 3)
26+
C, D = pybamm.meshgrid(c, d)
27+
np.testing.assert_array_equal(A, C.entries)
28+
np.testing.assert_array_equal(B, D.entries)
29+
30+
31+
if __name__ == "__main__":
32+
print("Add -v for more debug output")
33+
import sys
34+
35+
if "-v" in sys.argv:
36+
debug = True
37+
pybamm.settings.debug_mode = True
38+
unittest.main()

tests/unit/test_expression_tree/test_matrix.py

-6
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,6 @@ def test_matrix_operations(self):
3232
)
3333

3434

35-
class TestArray(unittest.TestCase):
36-
def test_name(self):
37-
arr = pybamm.Array(np.array([1, 2, 3]))
38-
self.assertEqual(arr.name, "Array of shape (3, 1)")
39-
40-
4135
if __name__ == "__main__":
4236
print("Add -v for more debug output")
4337
import sys

tests/unit/test_plot.py

+47
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import pybamm
2+
import unittest
3+
import numpy as np
4+
5+
6+
class TestPlot(unittest.TestCase):
7+
def test_plot(self):
8+
x = pybamm.Array(np.array([0, 3, 10]))
9+
y = pybamm.Array(np.array([6, 16, 78]))
10+
pybamm.plot(x, y, xlabel="x", ylabel="y", title="title", testing=True)
11+
12+
def test_plot_fail(self):
13+
x = pybamm.Array(np.array([0]))
14+
with self.assertRaisesRegex(TypeError, "x must be 'pybamm.Array'"):
15+
pybamm.plot("bad", x)
16+
with self.assertRaisesRegex(TypeError, "y must be 'pybamm.Array'"):
17+
pybamm.plot(x, "bad")
18+
19+
def test_plot2D(self):
20+
x = pybamm.Array(np.array([0, 3, 10]))
21+
y = pybamm.Array(np.array([6, 16, 78]))
22+
X, Y = pybamm.meshgrid(x, y)
23+
24+
# plot with array directly
25+
pybamm.plot2D(x, y, Y, xlabel="x", ylabel="y", title="title", testing=True)
26+
27+
# plot with meshgrid
28+
pybamm.plot2D(X, Y, Y, xlabel="x", ylabel="y", title="title", testing=True)
29+
30+
def test_plot2D_fail(self):
31+
x = pybamm.Array(np.array([0]))
32+
with self.assertRaisesRegex(TypeError, "x must be 'pybamm.Array'"):
33+
pybamm.plot2D("bad", x, x)
34+
with self.assertRaisesRegex(TypeError, "y must be 'pybamm.Array'"):
35+
pybamm.plot2D(x, "bad", x)
36+
with self.assertRaisesRegex(TypeError, "z must be 'pybamm.Array'"):
37+
pybamm.plot2D(x, x, "bad")
38+
39+
40+
if __name__ == "__main__":
41+
print("Add -v for more debug output")
42+
import sys
43+
44+
if "-v" in sys.argv:
45+
debug = True
46+
pybamm.settings.debug_mode = True
47+
unittest.main()

0 commit comments

Comments
 (0)