Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue 784 simplify solver #800

Merged
merged 34 commits into from
Feb 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
991e08b
#784 start reformatting solvers
valentinsulzer Jan 15, 2020
6105fb1
Merge branch 'issue-760-better-processed-variable' into issue-784-sim…
valentinsulzer Jan 15, 2020
5d3446a
#784 more work on base solver
valentinsulzer Jan 16, 2020
e0efb3f
#774 removing duplicates from base solver
valentinsulzer Jan 17, 2020
f280e99
Merge branch 'master' into issue-784-simplify-solver
valentinsulzer Jan 17, 2020
5fc7244
#784 set up works for odes and daes
valentinsulzer Jan 18, 2020
0d9fb1c
#784 remove compute_solution
valentinsulzer Jan 19, 2020
a48ec76
#784 fix some casadi tests
valentinsulzer Jan 20, 2020
932c17a
#784 reformat scikits solvers
valentinsulzer Jan 20, 2020
8bd265e
#784 fixing syntax
valentinsulzer Jan 20, 2020
8368bb6
#784 make integral a private method
valentinsulzer Jan 21, 2020
1335b58
#784 fixing tests
valentinsulzer Jan 21, 2020
718f9ff
#784 fix notebooks (some solver fixes still required)
valentinsulzer Jan 21, 2020
2cc373c
#784 add examples to tests without dependencies
valentinsulzer Jan 21, 2020
1f73792
#784 reformat external variables discretisation
valentinsulzer Jan 22, 2020
ecc21ae
#784 reformat inputs and external for solvers
valentinsulzer Jan 22, 2020
838c3f6
#784 fixed casadi solvers, except stepping
valentinsulzer Jan 22, 2020
f90ffe5
Merge branch 'master' into issue-784-simplify-solver
valentinsulzer Jan 23, 2020
83b8b11
#784 fix external variable concatenations
valentinsulzer Jan 24, 2020
903f63f
#784 reformatting step
valentinsulzer Jan 24, 2020
2954ed6
#784 merge 793
valentinsulzer Jan 24, 2020
66963f2
#784 fix stepping
valentinsulzer Jan 26, 2020
4a552dc
#784 tests and falke8
valentinsulzer Jan 26, 2020
5832658
#784 fixing more tests
valentinsulzer Jan 27, 2020
5bdc5a6
Merge branch 'master' into issue-784-simplify-solver
valentinsulzer Jan 28, 2020
070d46f
Merge branch 'master' into issue-784-simplify-solver
valentinsulzer Jan 28, 2020
9e76142
Merge branch 'master' into issue-784-simplify-solver
valentinsulzer Jan 28, 2020
0ef6eac
#784 fix scikit tests
valentinsulzer Jan 28, 2020
cc1a87a
#784 coverage
valentinsulzer Jan 30, 2020
927ce97
#784 more coverage and changelog
valentinsulzer Jan 31, 2020
998ef58
#784 make sure inputs used for evaluating events
valentinsulzer Feb 1, 2020
1090c99
Merge branch 'issue-784-simplify-solver' of github.com:pybamm-team/Py…
valentinsulzer Feb 1, 2020
0b5ab30
#709 fix some more examples
valentinsulzer Feb 1, 2020
b1e40e5
#784 fix casadi steps with input
valentinsulzer Feb 4, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ matrix:
env:
- PYBAMM_UNIT=true
- PYBAMM_SCIKITS_ODES=true
# Unit testing on Python3.7 on Ubuntu without scikit odes
- PYBAMM_KLU=true
# Unit and example testing on Python3.7 on Ubuntu without optional dependencies
- python: "3.7"
addons:
apt:
Expand All @@ -95,6 +96,7 @@ matrix:
- libsuitesparse-dev
env:
- PYBAMM_UNIT=true
- PYBAMM_EXAMPLES=true
if: type != cron
# Cover, docs and style checking, latest Python version only
- python: "3.7"
Expand Down
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@

## Optimizations

- Simplified solver interface ([#800](https://github.com/pybamm-team/PyBaMM/pull/800))
- Added caching for shape evaluation, used during discretisation ([#780](https://github.com/pybamm-team/PyBaMM/pull/780))
- Added an option to skip model checks during discretisation, which could be slow for large models ([#739](https://github.com/pybamm-team/PyBaMM/pull/739))
- Use CasADi's automatic differentation algorithms by default when solving a model ([#714](https://github.com/pybamm-team/PyBaMM/pull/714))
Expand All @@ -44,6 +45,8 @@

## Bug fixes

- Fixed examples to run with basic pip installation ([#800](https://github.com/pybamm-team/PyBaMM/pull/800))
- Added events for CasADi solver when stepping ([#800](https://github.com/pybamm-team/PyBaMM/pull/800))
- Improved implementation of broadcasts ([#776](https://github.com/pybamm-team/PyBaMM/pull/776))
- Fixed a bug which meant that the Ohmic heating in the current collectors was incorrect if using the Finite Element Method ([#767](https://github.com/pybamm-team/PyBaMM/pull/767))
- Improved automatic broadcasting ([#747](https://github.com/pybamm-team/PyBaMM/pull/747))
Expand Down
3 changes: 3 additions & 0 deletions docs/source/expression_tree/variable.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,6 @@ Variable

.. autoclass:: pybamm.Variable
:members:

.. autoclass:: pybamm.ExternalVariable
:members:
6 changes: 0 additions & 6 deletions docs/source/solvers/base_solvers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,3 @@ Base Solvers

.. autoclass:: pybamm.BaseSolver
:members:

.. autoclass:: pybamm.OdeSolver
:members:

.. autoclass:: pybamm.DaeSolver
:members:
10 changes: 5 additions & 5 deletions docs/tutorials/add-solver.rst
Original file line number Diff line number Diff line change
Expand Up @@ -27,29 +27,29 @@ The role of solvers is to solve a model at a given set of time points, returning
Base solver classes vs specific solver classes
----------------------------------------------

There is one general base solver class, :class:`pybamm.BaseSolver`, and two specialised base classes, :class:`pybamm.OdeSolver` and :class:`pybamm.DaeSolver`. The general base class simply sets up some useful solver properties such as tolerances. The specialised base classes implement a method :meth:`self.solve()` that solves a model at a given set of time points.
There is one general base solver class, :class:`pybamm.BaseSolver`, which sets up some useful solver properties such as tolerances and implement a method :meth:`self.solve()` that solves a model at a given set of time points.

The ``solve`` method unpacks the model, simplifies it by removing extraneous operations, (optionally) creates or calls the mass matrix and/or jacobian, and passes the appropriate attributes to another method, called ``integrate``, which does the time-stepping. The role of specific solver classes is simply to implement this ``integrate`` method for an arbitrary set of derivative function, initial conditions etc.

The base DAE solver class also computes a consistent set of initial conditions for the algebraic equations, using ``model.concatenated_initial_conditions`` as an initial guess.
The base solver class also computes a consistent set of initial conditions for the algebraic equations, using ``model.concatenated_initial_conditions`` as an initial guess.

Implementing a new solver
-------------------------

To add a new solver (e.g. My Fast DAE Solver), first create a new file (``my_fast_dae_solver.py``) in ``pybamm/solvers/``,
with a single class that inherits from either :class:`pybamm.OdeSolver` or :class:`pybamm.DaeSolver`, depending on whether the new solver can solve DAE systems. For example:
with a single class that inherits from :class:`pybamm.BaseSolver`. For example:

.. code-block:: python

def MyFastDaeSolver(pybamm.DaeSolver):
def MyFastDaeSolver(pybamm.BaseSolver):

Also add the class to ``pybamm/__init__.py``:

.. code-block:: python

from .solvers.my_fast_dae_solver import MyFastDaeSolver

You can then start implementing the solver by adding the ``integrate`` function to the class (the interfaces are slightly different for an ODE Solver and a DAE Solver, see :meth:`pybamm.OdeSolver.integrate` vs :meth:`pybamm.DaeSolver.integrate`)
You can then start implementing the solver by adding the ``integrate`` function to the class.

For an example of an existing solver implementation, see the Scikits DAE solver
`API docs <https://pybamm.readthedocs.io/en/latest/source/solvers/scikits_solvers.html>`_
Expand Down
9 changes: 4 additions & 5 deletions examples/notebooks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,9 @@ See [here](https://pybamm.readthedocs.io/en/latest/tutorials/add-spatial-method.

### Solvers

The following solvers are implemented
- Scipy ODE solver
- [Scikits ODE solver](./solvers/scikits-ode-solver.ipynb)
- [Scikits DAE solver](./solvers/scikits-dae-solver.ipynb)
- CasADi DAE solver
The following notebooks show examples for generic ODE and DAE solvers. Several solvers are implemented in PyBaMM and we encourage users to try different ones to find the most appropriate one for their models.

- [ODE solver](./solvers/ode-solver.ipynb)
- [DAE solver](./solvers/dae-solver.ipynb)

See [here](https://pybamm.readthedocs.io/en/latest/tutorials/add-solver.html) for instructions on adding new solvers.
131 changes: 86 additions & 45 deletions examples/notebooks/solution-data-and-processed-variables.ipynb

Large diffs are not rendered by default.

310 changes: 310 additions & 0 deletions examples/notebooks/solvers/dae-solver.ipynb

Large diffs are not rendered by default.

291 changes: 291 additions & 0 deletions examples/notebooks/solvers/ode-solver.ipynb

Large diffs are not rendered by default.

357 changes: 0 additions & 357 deletions examples/notebooks/solvers/scikits-dae-solver.ipynb

This file was deleted.

277 changes: 0 additions & 277 deletions examples/notebooks/solvers/scikits-ode-solver.ipynb

This file was deleted.

8 changes: 1 addition & 7 deletions examples/scripts/SPMe_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,7 @@
step_solver = model.default_solver
step_solution = None
while time < end_time:
current_step_sol = step_solver.step(model, dt=dt, npts=10)
if not step_solution:
# create solution object on first step
step_solution = current_step_sol
else:
# append solution from the current step to step_solution
step_solution.append(current_step_sol)
step_solution = step_solver.step(step_solution, model, dt=dt, npts=10)
time += dt

# plot
Expand Down
25 changes: 22 additions & 3 deletions examples/scripts/compare-dae-solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,29 @@
t_eval = np.linspace(0, 0.25, 100)

casadi_sol = pybamm.CasadiSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
klu_sol = pybamm.IDAKLUSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
scikits_sol = pybamm.ScikitsDaeSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
solutions = [casadi_sol]

if pybamm.have_idaklu():
klu_sol = pybamm.IDAKLUSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
solutions.append(klu_sol)
else:
pybamm.logger.error(
"""
Cannot solve model with IDA KLU solver as solver is not installed.
Please consult installation instructions on GitHub.
"""
)
if pybamm.have_scikits_odes():
scikits_sol = pybamm.ScikitsDaeSolver(atol=1e-8, rtol=1e-8).solve(model, t_eval)
solutions.append(scikits_sol)
else:
pybamm.logger.error(
"""
Cannot solve model with Scikits DAE solver as solver is not installed.
Please consult installation instructions on GitHub.
"""
)

# plot
solutions = [casadi_sol, klu_sol, casadi_sol]
plot = pybamm.QuickPlot(solutions)
plot.dynamic_plot()
4 changes: 1 addition & 3 deletions pybamm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def version(formatted=False):
ones_like,
)
from .expression_tree.scalar import Scalar
from .expression_tree.variable import Variable
from .expression_tree.variable import Variable, ExternalVariable
from .expression_tree.independent_variable import (
IndependentVariable,
Time,
Expand Down Expand Up @@ -237,8 +237,6 @@ def version(formatted=False):
#
from .solvers.solution import Solution
from .solvers.base_solver import BaseSolver
from .solvers.ode_solver import OdeSolver
from .solvers.dae_solver import DaeSolver
from .solvers.algebraic_solver import AlgebraicSolver
from .solvers.casadi_solver import CasadiSolver
from .solvers.scikits_dae_solver import ScikitsDaeSolver
Expand Down
147 changes: 93 additions & 54 deletions pybamm/discretisations/discretisation.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def __init__(self, mesh=None, spatial_methods=None):
self.bcs = {}
self.y_slices = {}
self._discretised_symbols = {}
self.external_variables = []
self.external_variables = {}

@property
def mesh(self):
Expand Down Expand Up @@ -127,11 +127,7 @@ def process_model(self, model, inplace=True, check_model=True):

# Prepare discretisation
# set variables (we require the full variable not just id)
variables = (
list(model.rhs.keys())
+ list(model.algebraic.keys())
+ model.external_variables
)
variables = list(model.rhs.keys()) + list(model.algebraic.keys())

# Set the y split for variables
pybamm.logger.info("Set variable slices for {}".format(model.name))
Expand All @@ -140,6 +136,7 @@ def process_model(self, model, inplace=True, check_model=True):
# now add extrapolated external variables to the boundary conditions
# if required by the spatial method
self._preprocess_external_variables(model)
self.set_external_variables(model)

# set boundary conditions (only need key ids for boundary_conditions)
pybamm.logger.info("Discretise boundary conditions for {}".format(model.name))
Expand All @@ -158,28 +155,6 @@ def process_model(self, model, inplace=True, check_model=True):

model_disc.bcs = self.bcs

self.external_variables = model.external_variables
# find where external variables begin in state vector
# we always append external variables to the end, so
# it is sufficient to only know the starting location
start_vals = []
for var in self.external_variables:
if isinstance(var, pybamm.Concatenation):
for child in var.children:
start_vals += [self.y_slices[child.id][0].start]
elif isinstance(var, pybamm.Variable):
start_vals += [self.y_slices[var.id][0].start]

# attach properties of the state vector so that it
# can be divided correctly during the solving stage
model_disc.external_variables = model.external_variables
model_disc.y_length = self.y_length
model_disc.y_slices = self.y_slices
if start_vals:
model_disc.external_start = min(start_vals)
else:
model_disc.external_start = self.y_length

pybamm.logger.info("Discretise initial conditions for {}".format(model.name))
ics, concat_ics = self.process_initial_conditions(model)
model_disc.initial_conditions = ics
Expand Down Expand Up @@ -235,13 +210,8 @@ def set_variable_slices(self, variables):
end = 0
# Iterate through unpacked variables, adding appropriate slices to y_slices
for variable in variables:
# If domain is empty then variable has size 1
if variable.domain == []:
end += 1
y_slices[variable.id].append(slice(start, end))
start = end
# Otherwise, add up the size of all the domains in variable.domain
elif isinstance(variable, pybamm.Concatenation):
# Add up the size of all the domains in variable.domain
if isinstance(variable, pybamm.Concatenation):
children = variable.children
meshes = OrderedDict()
for child in children:
Expand All @@ -257,18 +227,27 @@ def set_variable_slices(self, variables):
y_slices[child.id].append(slice(start, end))
start = end
else:
for dom in variable.domain:
for submesh in self.spatial_methods[dom].mesh[dom]:
end += submesh.npts_for_broadcast
end += self._get_variable_size(variable)
y_slices[variable.id].append(slice(start, end))
start = end

self.y_slices = y_slices
self.y_length = end

# reset discretised_symbols
self._discretised_symbols = {}

def _get_variable_size(self, variable):
"Helper function to determine what size a variable should be"
# If domain is empty then variable has size 1
if variable.domain == []:
return 1
else:
size = 0
for dom in variable.domain:
for submesh in self.spatial_methods[dom].mesh[dom]:
size += submesh.npts_for_broadcast
return size

def _preprocess_external_variables(self, model):
"""
A method to preprocess external variables so that they are
Expand All @@ -291,6 +270,43 @@ def _preprocess_external_variables(self, model):

model.boundary_conditions.update(new_bcs)

def set_external_variables(self, model):
"""
Add external variables to the list of variables to account for, being careful
about concatenations
"""
for var in model.external_variables:
# Find the name in the model variables
# Look up dictionary key based on value
try:
idx = [x.id for x in model.variables.values()].index(var.id)
except ValueError:
raise ValueError(
"""
Variable {} must be in the model.variables dictionary to be set
as an external variable
""".format(
var
)
)
name = list(model.variables.keys())[idx]
if isinstance(var, pybamm.Variable):
# No need to keep track of the parent
self.external_variables[(name, None)] = var
elif isinstance(var, pybamm.Concatenation):
start = 0
end = 0
for child in var.children:
dom = child.domain[0]
if len(self.spatial_methods[dom].mesh[dom]) > 1:
raise NotImplementedError(
"Cannot create 2D external variable with concatenations"
)
end += self._get_variable_size(child)
# Keep a record of the parent
self.external_variables[(name, (var, start, end))] = child
start = end

def set_internal_boundary_conditions(self, model):
"""
A method to set the internal boundary conditions for the submodel.
Expand Down Expand Up @@ -683,9 +699,7 @@ def process_dict(self, var_eqn_dict):
# Broadcast if the equation evaluates to a number(e.g. Scalar)
if eqn.evaluates_to_number() and not isinstance(eqn_key, str):
eqn = pybamm.FullBroadcast(
eqn,
eqn_key.domain,
eqn_key.auxiliary_domains,
eqn, eqn_key.domain, eqn_key.auxiliary_domains
)

# note we are sending in the key.id here so we don't have to
Expand Down Expand Up @@ -839,11 +853,40 @@ def _process_symbol(self, symbol):
return symbol._function_new_copy(disc_children)

elif isinstance(symbol, pybamm.Variable):
return pybamm.StateVector(
*self.y_slices[symbol.id],
domain=symbol.domain,
auxiliary_domains=symbol.auxiliary_domains
)
# Check if variable is a standard variable or an external variable
if any(symbol.id == var.id for var in self.external_variables.values()):
# Look up dictionary key based on value
idx = [x.id for x in self.external_variables.values()].index(symbol.id)
name, parent_and_slice = list(self.external_variables.keys())[idx]
if parent_and_slice is None:
# Variable didn't come from a concatenation so we can just create a
# normal external variable using the symbol's name
return pybamm.ExternalVariable(
symbol.name,
size=self._get_variable_size(symbol),
domain=symbol.domain,
auxiliary_domains=symbol.auxiliary_domains,
)
else:
# We have to use a special name since the concatenation doesn't have
# a very informative name. Needs improving
parent, start, end = parent_and_slice
ext = pybamm.ExternalVariable(
name,
size=self._get_variable_size(parent),
domain=parent.domain,
auxiliary_domains=parent.auxiliary_domains,
)
out = ext[slice(start, end)]
out.domain = symbol.domain
return out

else:
return pybamm.StateVector(
*self.y_slices[symbol.id],
domain=symbol.domain,
auxiliary_domains=symbol.auxiliary_domains
)

elif isinstance(symbol, pybamm.SpatialVariable):
return spatial_method.spatial_variable(symbol)
Expand Down Expand Up @@ -917,8 +960,8 @@ def _concatenate_in_order(self, var_eqn_dict, check_complete=False, sparse=False
if check_complete:
# Check keys from the given var_eqn_dict against self.y_slices
ids = {v.id for v in unpacked_variables}
external_id = {v.id for v in self.external_variables}
for var in self.external_variables:
external_id = {v.id for v in self.external_variables.values()}
for var in self.external_variables.values():
child_ids = {child.id for child in var.children}
external_id = external_id.union(child_ids)
y_slices_with_external_removed = set(self.y_slices.keys()).difference(
Expand Down Expand Up @@ -1022,11 +1065,7 @@ def check_variables(self, model):
and np.all(var.right.entries == 1)
)

if (
different_shapes
and not_concatenation
and not_mult_by_one_vec
):
if different_shapes and not_concatenation and not_mult_by_one_vec:
raise pybamm.ModelError(
"""
variable and its eqn must have the same shape after discretisation
Expand Down
Loading