diff --git a/.github/workflows/test_on_push.yml b/.github/workflows/test_on_push.yml
index 0c0562fd01..962c4b6ad6 100644
--- a/.github/workflows/test_on_push.yml
+++ b/.github/workflows/test_on_push.yml
@@ -83,10 +83,10 @@ jobs:
       run: tox -e examples
         
     - name: Install and run coverage
-      if: success() && (matrix.os == 'ubuntu-latest' && matrix.python-version == 3.7)
+      if: success() && (matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9)
       run: tox -e coverage
 
     - name: Upload coverage report
-      if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.7
+      if: matrix.os == 'ubuntu-latest' && matrix.python-version == 3.9
       uses: codecov/codecov-action@v1
 
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 996e4b4232..e53477673c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,7 @@
 
 ## Features
 
+-   `Solution` objects can now be created by stepping *different* models ([#1408](https://github.com/pybamm-team/PyBaMM/pull/1408))
 -   Added support for Python 3.9 and dropped support for Python 3.6. Python 3.6 may still work but is now untested ([#1370](https://github.com/pybamm-team/PyBaMM/pull/1370))
 -   Added the electrolyte overpotential and Ohmic losses for full conductivity, including surface form ([#1350](https://github.com/pybamm-team/PyBaMM/pull/1350))
 -   Added functionality to `Citations` to print formatted citations ([#1340](https://github.com/pybamm-team/PyBaMM/pull/1340))
@@ -22,6 +23,7 @@
 
 ## Optimizations
 
+-   Improved the way an `Experiment` is simulated to reduce solve time (at the cost of slightly higher set-up time) ([#1408](https://github.com/pybamm-team/PyBaMM/pull/1408))
 -   Add script and workflow to automatically update parameter_sets.py docstrings ([#1371](https://github.com/pybamm-team/PyBaMM/pull/1371))
 -   Add URLs checker in workflows ([#1347](https://github.com/pybamm-team/PyBaMM/pull/1347))
 -   The `Solution` class now only creates the concatenated `y` when the user asks for it. This is an optimization step as the concatenation can be slow, especially with larger experiments ([#1331](https://github.com/pybamm-team/PyBaMM/pull/1331))
diff --git a/examples/scripts/DFN.py b/examples/scripts/DFN.py
index 2e01f1df93..01d6e9a0e8 100644
--- a/examples/scripts/DFN.py
+++ b/examples/scripts/DFN.py
@@ -15,8 +15,8 @@
 
 # load parameter values and process model and geometry
 param = model.default_parameter_values
-param.process_model(model)
 param.process_geometry(geometry)
+param.process_model(model)
 
 # set mesh
 var = pybamm.standard_spatial_vars
diff --git a/examples/scripts/experimental_protocols/cccv.py b/examples/scripts/experimental_protocols/cccv.py
index 2006695296..2db64c055f 100644
--- a/examples/scripts/experimental_protocols/cccv.py
+++ b/examples/scripts/experimental_protocols/cccv.py
@@ -15,7 +15,7 @@
             "Rest for 1 hour",
         ),
     ]
-    * 3
+    * 3,
 )
 model = pybamm.lithium_ion.DFN()
 sim = pybamm.Simulation(model, experiment=experiment, solver=pybamm.CasadiSolver())
diff --git a/examples/scripts/experimental_protocols/gitt.py b/examples/scripts/experimental_protocols/gitt.py
index 5a13a2a1b2..2b1634cf2c 100644
--- a/examples/scripts/experimental_protocols/gitt.py
+++ b/examples/scripts/experimental_protocols/gitt.py
@@ -5,7 +5,7 @@
 
 pybamm.set_logging_level("INFO")
 experiment = pybamm.Experiment(
-    [("Discharge at C/20 for 1 hour", "Rest for 1 hour")] * 20
+    [("Discharge at C/20 for 1 hour", "Rest for 1 hour")] * 20,
 )
 model = pybamm.lithium_ion.DFN()
 sim = pybamm.Simulation(model, experiment=experiment, solver=pybamm.CasadiSolver())
diff --git a/pybamm/experiments/experiment.py b/pybamm/experiments/experiment.py
index b525f4d524..97071045c4 100644
--- a/pybamm/experiments/experiment.py
+++ b/pybamm/experiments/experiment.py
@@ -40,10 +40,18 @@ class Experiment:
     period : string, optional
         Period (1/frequency) at which to record outputs. Default is 1 minute. Can be
         overwritten by individual operating conditions.
-
+    use_simulation_setup_type : str
+        Whether to use the "new" (default) or "old" simulation set-up type. "new" is
+        faster at simulating individual steps but has higher set-up overhead
     """
 
-    def __init__(self, operating_conditions, parameters=None, period="1 minute"):
+    def __init__(
+        self,
+        operating_conditions,
+        parameters=None,
+        period="1 minute",
+        use_simulation_setup_type="new",
+    ):
         self.period = self.convert_time_to_seconds(period.split())
         operating_conditions_cycles = []
         for cycle in operating_conditions:
@@ -84,6 +92,8 @@ def __init__(self, operating_conditions, parameters=None, period="1 minute"):
         else:
             raise TypeError("experimental parameters should be a dictionary")
 
+        self.use_simulation_setup_type = use_simulation_setup_type
+
     def __str__(self):
         return str(self.operating_conditions_strings)
 
diff --git a/pybamm/models/base_model.py b/pybamm/models/base_model.py
index 842f024d2a..24b6757ae3 100644
--- a/pybamm/models/base_model.py
+++ b/pybamm/models/base_model.py
@@ -93,7 +93,7 @@ def __init__(self, name="Unnamed model"):
         self._algebraic = {}
         self._initial_conditions = {}
         self._boundary_conditions = {}
-        self._variables = {}
+        self._variables = pybamm.FuzzyDict({})
         self._events = []
         self._concatenated_rhs = None
         self._concatenated_algebraic = None
@@ -382,13 +382,25 @@ def set_initial_conditions_from(self, solution, inplace=True):
         else:
             model = self.new_copy()
 
+        if isinstance(solution, pybamm.Solution):
+            solution = solution.last_state
+        else:
+            solution = pybamm.FuzzyDict(solution)
         for var, equation in model.initial_conditions.items():
             if isinstance(var, pybamm.Variable):
-                final_state = solution[var.name]
+                try:
+                    final_state = solution[var.name]
+                except KeyError as e:
+                    raise pybamm.ModelError(
+                        "To update a model from a solution, each variable in "
+                        "model.initial_conditions must appear in the solution with "
+                        "the same key as the variable name. In the solution provided, "
+                        f"{e.args[0]}"
+                    )
                 if isinstance(solution, pybamm.Solution):
                     final_state = final_state.data
                 if final_state.ndim == 1:
-                    final_state_eval = np.array([final_state[-1]])
+                    final_state_eval = final_state[-1:]
                 elif final_state.ndim == 2:
                     final_state_eval = final_state[:, -1]
                 elif final_state.ndim == 3:
@@ -399,7 +411,15 @@ def set_initial_conditions_from(self, solution, inplace=True):
             elif isinstance(var, pybamm.Concatenation):
                 children = []
                 for child in var.orphans:
-                    final_state = solution[child.name]
+                    try:
+                        final_state = solution[child.name]
+                    except KeyError as e:
+                        raise pybamm.ModelError(
+                            "To update a model from a solution, each variable in "
+                            "model.initial_conditions must appear in the solution with "
+                            "the same key as the variable name. In the solution "
+                            f"provided, {e.args[0]}"
+                        )
                     if isinstance(solution, pybamm.Solution):
                         final_state = final_state.data
                     if final_state.ndim == 2:
diff --git a/pybamm/models/standard_variables.py b/pybamm/models/standard_variables.py
index 44c86e1002..a9312fbd02 100644
--- a/pybamm/models/standard_variables.py
+++ b/pybamm/models/standard_variables.py
@@ -265,18 +265,18 @@
 
 # SEI variables
 L_inner_av = pybamm.Variable(
-    "X-averaged inner SEI thickness", domain="current collector"
+    "X-averaged inner negative electrode SEI thickness", domain="current collector"
 )
 L_inner = pybamm.Variable(
-    "Inner SEI thickness",
+    "Inner negative electrode SEI thickness",
     domain=["negative electrode"],
     auxiliary_domains={"secondary": "current collector"},
 )
 L_outer_av = pybamm.Variable(
-    "X-averaged outer SEI thickness", domain="current collector"
+    "X-averaged outer negative electrode SEI thickness", domain="current collector"
 )
 L_outer = pybamm.Variable(
-    "Outer SEI thickness",
+    "Outer negative electrode SEI thickness",
     domain=["negative electrode"],
     auxiliary_domains={"secondary": "current collector"},
 )
diff --git a/pybamm/models/submodels/interface/kinetics/base_kinetics.py b/pybamm/models/submodels/interface/kinetics/base_kinetics.py
index 52f710e68e..21869579cb 100644
--- a/pybamm/models/submodels/interface/kinetics/base_kinetics.py
+++ b/pybamm/models/submodels/interface/kinetics/base_kinetics.py
@@ -37,7 +37,7 @@ def get_fundamental_variables(self):
             j = pybamm.Variable(
                 "Total "
                 + self.domain.lower()
-                + " electrode interfacial current density",
+                + " electrode interfacial current density variable",
                 domain=self.domain.lower() + " electrode",
                 auxiliary_domains={"secondary": "current collector"},
             )
diff --git a/pybamm/plotting/quick_plot.py b/pybamm/plotting/quick_plot.py
index 14ba189863..c765fe8ff3 100644
--- a/pybamm/plotting/quick_plot.py
+++ b/pybamm/plotting/quick_plot.py
@@ -7,7 +7,8 @@
 
 
 class LoopList(list):
-    """A list which loops over itself when accessing an index so that it never runs out
+    """
+    A list which loops over itself when accessing an index so that it never runs out
     """
 
     def __getitem__(self, i):
@@ -114,7 +115,7 @@ def __init__(
                 # attribute
                 solutions[idx] = sol.solution
 
-        models = [solution.model for solution in solutions]
+        models = [solution.all_models[0] for solution in solutions]
 
         # Set labels
         if labels is None:
diff --git a/pybamm/simulation.py b/pybamm/simulation.py
index ab6b15d28b..0ecd62cece 100644
--- a/pybamm/simulation.py
+++ b/pybamm/simulation.py
@@ -41,6 +41,19 @@ def constant_current_constant_voltage_constant_power(variables):
     )
 
 
+def constant_voltage(variables, V_applied):
+    V = variables["Terminal voltage [V]"]
+    n_cells = pybamm.Parameter("Number of cells connected in series to make a battery")
+    return V - V_applied / n_cells
+
+
+def constant_power(variables, P_applied):
+    I = variables["Current [A]"]
+    V = variables["Terminal voltage [V]"]
+    n_cells = pybamm.Parameter("Number of cells connected in series to make a battery")
+    return V * I - P_applied / n_cells
+
+
 class Simulation:
     """A Simulation class for easy building and running of PyBaMM simulations.
 
@@ -145,61 +158,6 @@ def set_up_experiment(self, model, experiment):
         """
         self.operating_mode = "with experiment"
 
-        # Create a new model where the current density is now a variable
-        # To do so, we replace all instances of the current density in the
-        # model with a current density variable, which is obtained from the
-        # FunctionControl submodel
-        # create the FunctionControl submodel and extract variables
-        external_circuit_variables = pybamm.external_circuit.FunctionControl(
-            model.param, None
-        ).get_fundamental_variables()
-
-        # Perform the replacement
-        symbol_replacement_map = {
-            model.variables[name]: variable
-            for name, variable in external_circuit_variables.items()
-        }
-        replacer = pybamm.SymbolReplacer(symbol_replacement_map)
-        new_model = replacer.process_model(model, inplace=False)
-
-        # Update the algebraic equation and initial conditions for FunctionControl
-        # This creates an algebraic equation for the current to allow current, voltage,
-        # or power control, together with the appropriate guess for the
-        # initial condition.
-        # External circuit submodels are always equations on the current
-        # The external circuit function should fix either the current, or the voltage,
-        # or a combination (e.g. I*V for power control)
-        i_cell = new_model.variables["Total current density"]
-        new_model.initial_conditions[i_cell] = new_model.param.current_with_time
-        new_model.algebraic[i_cell] = constant_current_constant_voltage_constant_power(
-            new_model.variables
-        )
-
-        # add current and voltage events to the model
-        # current events both negative and positive to catch specification
-        new_model.events.extend(
-            [
-                pybamm.Event(
-                    "Current cut-off (positive) [A] [experiment]",
-                    new_model.variables["Current [A]"]
-                    - abs(pybamm.InputParameter("Current cut-off [A]")),
-                ),
-                pybamm.Event(
-                    "Current cut-off (negative) [A] [experiment]",
-                    new_model.variables["Current [A]"]
-                    + abs(pybamm.InputParameter("Current cut-off [A]")),
-                ),
-                pybamm.Event(
-                    "Voltage cut-off [V] [experiment]",
-                    new_model.variables["Terminal voltage [V]"]
-                    - pybamm.InputParameter("Voltage cut-off [V]")
-                    / model.param.n_cells,
-                ),
-            ]
-        )
-        self._unprocessed_model = new_model
-        self.model = new_model
-
         if not isinstance(experiment, pybamm.Experiment):
             raise TypeError("experiment must be a pybamm `Experiment` instance")
 
@@ -284,6 +242,210 @@ def set_up_experiment(self, model, experiment):
                 dt = 7 * 24 * 3600
             self._experiment_times.append(dt)
 
+        # Set up model for experiment
+        if experiment.use_simulation_setup_type == "old":
+            self.set_up_model_for_experiment_old(model)
+        elif experiment.use_simulation_setup_type == "new":
+            self.set_up_model_for_experiment_new(model)
+
+    def set_up_model_for_experiment_old(self, model):
+        """
+        Set up self.model to be able to run the experiment (old version).
+        In this version, a single model is created which can then be called with
+        different inputs for current-control, voltage-control, or power-control.
+
+        This reduces set-up time since only one model needs to be processed, but
+        increases simulation time since the model formulation is inefficient
+        """
+        # Create a new model where the current density is now a variable
+        # To do so, we replace all instances of the current density in the
+        # model with a current density variable, which is obtained from the
+        # FunctionControl submodel
+        # create the FunctionControl submodel and extract variables
+        external_circuit_variables = pybamm.external_circuit.FunctionControl(
+            model.param, None
+        ).get_fundamental_variables()
+
+        # Perform the replacement
+        symbol_replacement_map = {
+            model.variables[name]: variable
+            for name, variable in external_circuit_variables.items()
+        }
+        replacer = pybamm.SymbolReplacer(symbol_replacement_map)
+        new_model = replacer.process_model(model, inplace=False)
+
+        # Update the algebraic equation and initial conditions for FunctionControl
+        # This creates an algebraic equation for the current to allow current, voltage,
+        # or power control, together with the appropriate guess for the
+        # initial condition.
+        # External circuit submodels are always equations on the current
+        # The external circuit function should fix either the current, or the voltage,
+        # or a combination (e.g. I*V for power control)
+        i_cell = new_model.variables["Total current density"]
+        new_model.initial_conditions[i_cell] = new_model.param.current_with_time
+        new_model.algebraic[i_cell] = constant_current_constant_voltage_constant_power(
+            new_model.variables
+        )
+
+        # Remove upper and lower voltage cut-offs that are *not* part of the experiment
+        new_model.events = [
+            event
+            for event in model.events
+            if event.name not in ["Minimum voltage", "Maximum voltage"]
+        ]
+        # add current and voltage events to the model
+        # current events both negative and positive to catch specification
+        new_model.events.extend(
+            [
+                pybamm.Event(
+                    "Current cut-off (positive) [A] [experiment]",
+                    new_model.variables["Current [A]"]
+                    - abs(pybamm.InputParameter("Current cut-off [A]")),
+                ),
+                pybamm.Event(
+                    "Current cut-off (negative) [A] [experiment]",
+                    new_model.variables["Current [A]"]
+                    + abs(pybamm.InputParameter("Current cut-off [A]")),
+                ),
+                pybamm.Event(
+                    "Voltage cut-off [V] [experiment]",
+                    new_model.variables["Terminal voltage [V]"]
+                    - pybamm.InputParameter("Voltage cut-off [V]")
+                    / model.param.n_cells,
+                ),
+            ]
+        )
+
+        self.model = new_model
+
+        self.op_conds_to_model_and_param = {
+            op_cond[:2]: (new_model, self.parameter_values)
+            for op_cond in set(self.experiment.operating_conditions)
+        }
+        self.op_conds_to_built_models = None
+
+    def set_up_model_for_experiment_new(self, model):
+        """
+        Set up self.model to be able to run the experiment (new version).
+        In this version, a new model is created for each step.
+
+        This increases set-up time since several models to be processed, but
+        reduces simulation time since the model formulation is efficient.
+        """
+        self.op_conds_to_model_and_param = {}
+        self.op_conds_to_built_models = None
+        for op_cond, op_inputs in zip(
+            self.experiment.operating_conditions, self._experiment_inputs
+        ):
+            # Create model for this operating condition if it has not already been seen
+            # before
+            if op_cond[:2] not in self.op_conds_to_model_and_param:
+                if op_inputs["Current switch"] == 1:
+                    # Current control
+                    # Make a new copy of the model (we will update events later))
+                    new_model = model.new_copy()
+                else:
+                    # Voltage or power control
+                    # Create a new model where the current density is now a variable
+                    # To do so, we replace all instances of the current density in the
+                    # model with a current density variable, which is obtained from the
+                    # FunctionControl submodel
+                    # create the FunctionControl submodel and extract variables
+                    external_circuit_variables = (
+                        pybamm.external_circuit.FunctionControl(
+                            model.param, None
+                        ).get_fundamental_variables()
+                    )
+
+                    # Perform the replacement
+                    symbol_replacement_map = {
+                        model.variables[name]: variable
+                        for name, variable in external_circuit_variables.items()
+                    }
+                    replacer = pybamm.SymbolReplacer(symbol_replacement_map)
+                    new_model = replacer.process_model(model, inplace=False)
+
+                    # Update the algebraic equation and initial conditions for
+                    # FunctionControl
+                    # This creates an algebraic equation for the current to allow
+                    # current, voltage, or power control, together with the appropriate
+                    # guess for the initial condition.
+                    # External circuit submodels are always equations on the current
+                    # The external circuit function should fix either the current, or
+                    # the voltage, or a combination (e.g. I*V for power control)
+                    i_cell = new_model.variables["Total current density"]
+                    new_model.initial_conditions[
+                        i_cell
+                    ] = new_model.param.current_with_time
+
+                    # add current events to the model
+                    # current events both negative and positive to catch specification
+                    new_model.events.extend(
+                        [
+                            pybamm.Event(
+                                "Current cut-off (positive) [A] [experiment]",
+                                new_model.variables["Current [A]"]
+                                - abs(pybamm.InputParameter("Current cut-off [A]")),
+                            ),
+                            pybamm.Event(
+                                "Current cut-off (negative) [A] [experiment]",
+                                new_model.variables["Current [A]"]
+                                + abs(pybamm.InputParameter("Current cut-off [A]")),
+                            ),
+                        ]
+                    )
+                    if op_inputs["Voltage switch"] == 1:
+                        new_model.algebraic[i_cell] = constant_voltage(
+                            new_model.variables,
+                            pybamm.Parameter("Voltage function [V]"),
+                        )
+                    elif op_inputs["Power switch"] == 1:
+                        new_model.algebraic[i_cell] = constant_power(
+                            new_model.variables,
+                            pybamm.Parameter("Power function [W]"),
+                        )
+
+                # add voltage events to the model
+                if op_inputs["Power switch"] == 1 or op_inputs["Current switch"] == 1:
+                    new_model.events.append(
+                        pybamm.Event(
+                            "Voltage cut-off [V] [experiment]",
+                            new_model.variables["Terminal voltage [V]"]
+                            - op_inputs["Voltage cut-off [V]"] / model.param.n_cells,
+                        )
+                    )
+
+                # Remove upper and lower voltage cut-offs that are *not* part of the
+                # experiment
+                new_model.events = [
+                    event
+                    for event in new_model.events
+                    if event.name not in ["Minimum voltage", "Maximum voltage"]
+                ]
+
+                # Update parameter values
+                new_parameter_values = self.parameter_values.copy()
+                if op_inputs["Current switch"] == 1:
+                    new_parameter_values.update(
+                        {"Current function [A]": op_inputs["Current input [A]"]}
+                    )
+                elif op_inputs["Voltage switch"] == 1:
+                    new_parameter_values.update(
+                        {"Voltage function [V]": op_inputs["Voltage input [V]"]},
+                        check_already_exists=False,
+                    )
+                elif op_inputs["Power switch"] == 1:
+                    new_parameter_values.update(
+                        {"Power function [W]": op_inputs["Power input [W]"]},
+                        check_already_exists=False,
+                    )
+
+                self.op_conds_to_model_and_param[op_cond[:2]] = (
+                    new_model,
+                    new_parameter_values,
+                )
+        self.model = model
+
     def set_parameters(self):
         """
         A method to set the parameters in the model and the associated geometry.
@@ -330,7 +492,51 @@ def build(self, check_model=True):
                 self._model_with_set_params, inplace=False, check_model=check_model
             )
 
-    def solve(self, t_eval=None, solver=None, check_model=True, **kwargs):
+    def build_for_experiment(self, check_model=True):
+        """
+        Similar to :meth:`Simulation.build`, but for the case of simulating an
+        experiment, where there may be several models to build
+        """
+        if self.op_conds_to_built_models:
+            return None
+        else:
+            # Can process geometry with default parameter values (only electrical
+            # parameters change between parameter values)
+            self._parameter_values.process_geometry(self._geometry)
+            # Only needs to set up mesh and discretisation once
+            self._mesh = pybamm.Mesh(self._geometry, self._submesh_types, self._var_pts)
+            self._disc = pybamm.Discretisation(self._mesh, self._spatial_methods)
+            # Process all the different models
+            self.op_conds_to_built_models = {}
+            processed_models = {}
+            for op_cond, (
+                unbuilt_model,
+                parameter_values,
+            ) in self.op_conds_to_model_and_param.items():
+                if unbuilt_model in processed_models:
+                    built_model = processed_models[unbuilt_model]
+                else:
+                    # It's ok to modify the models in-place as they are not accessible
+                    # from outside the simulation
+                    model_with_set_params = parameter_values.process_model(
+                        unbuilt_model, inplace=True
+                    )
+                    built_model = self._disc.process_model(
+                        model_with_set_params, inplace=True, check_model=check_model
+                    )
+                    processed_models[unbuilt_model] = built_model
+
+                self.op_conds_to_built_models[op_cond] = built_model
+
+    def solve(
+        self,
+        t_eval=None,
+        solver=None,
+        check_model=True,
+        save_at_cycles=None,
+        starting_solution=None,
+        **kwargs,
+    ):
         """
         A method to solve the model. This method will automatically build
         and set the model parameters if not already done so.
@@ -353,22 +559,37 @@ def solve(self, t_eval=None, solver=None, check_model=True, **kwargs):
             If None and the parameter "Current function [A]" is read from data
             (i.e. drive cycle simulation) the model will be solved at the times
             provided in the data.
-        solver : :class:`pybamm.BaseSolver`
-            The solver to use to solve the model.
+        solver : :class:`pybamm.BaseSolver`, optional
+            The solver to use to solve the model. If None, Simulation.solver is used
         check_model : bool, optional
             If True, model checks are performed after discretisation (see
             :meth:`pybamm.Discretisation.process_model`). Default is True.
+        save_at_cycles : int or list of ints, optional
+            Which cycles to save the full sub-solutions for. If None, all cycles are
+            saved. If int, every multiple of save_at_cycles is saved. If list, every
+            cycle in the list is saved.
+        starting_solution : :class:`pybamm.Solution`
+            The solution to start stepping from. If None (default), then self._solution
+            is used. Must be None if not using an experiment.
         **kwargs
             Additional key-word arguments passed to `solver.solve`.
             See :meth:`pybamm.BaseSolver.solve`.
         """
         # Setup
-        self.build(check_model=check_model)
         if solver is None:
             solver = self.solver
 
         if self.operating_mode in ["without experiment", "drive cycle"]:
-
+            self.build(check_model=check_model)
+            if save_at_cycles is not None:
+                raise ValueError(
+                    "'save_at_cycles' option can only be used if simulating an "
+                    "Experiment "
+                )
+            if starting_solution is not None:
+                raise ValueError(
+                    "starting_solution can only be provided if simulating an Experiment"
+                )
             if self.operating_mode == "without experiment":
                 if t_eval is None:
                     raise pybamm.SolverError(
@@ -430,80 +651,75 @@ def solve(self, t_eval=None, solver=None, check_model=True, **kwargs):
             self._solution = solver.solve(self.built_model, t_eval, **kwargs)
 
         elif self.operating_mode == "with experiment":
+            self.build_for_experiment(check_model=check_model)
             if t_eval is not None:
                 pybamm.logger.warning(
                     "Ignoring t_eval as solution times are specified by the experiment"
                 )
             # Re-initialize solution, e.g. for solving multiple times with different
             # inputs without having to build the simulation again
-            self._solution = None
-            previous_num_subsolutions = 0
+            self._solution = starting_solution
             # Step through all experimental conditions
             inputs = kwargs.get("inputs", {})
             pybamm.logger.info("Start running experiment")
             timer = pybamm.Timer()
 
-            all_cycle_solutions = []
+            if starting_solution is None:
+                starting_solution_cycles = []
+            else:
+                starting_solution_cycles = starting_solution.cycles
+
+            cycle_offset = len(starting_solution_cycles)
+            all_cycle_solutions = starting_solution_cycles
+            current_solution = starting_solution
 
             idx = 0
             num_cycles = len(self.experiment.cycle_lengths)
             feasible = True  # simulation will stop if experiment is infeasible
-            for cycle_num, cycle_length in enumerate(self.experiment.cycle_lengths):
-                pybamm.logger.info(
-                    f"Cycle {cycle_num+1}/{num_cycles} ({timer.time()} elapsed) "
-                    + "-" * 20
+            for cycle_num, cycle_length in enumerate(
+                self.experiment.cycle_lengths, start=1
+            ):
+                pybamm.logger.notice(
+                    f"Cycle {cycle_num+cycle_offset}/{num_cycles+cycle_offset} "
+                    f"({timer.time()} elapsed) " + "-" * 20
                 )
                 steps = []
                 cycle_solution = None
-                for step_num in range(cycle_length):
+
+                for step_num in range(1, cycle_length + 1):
                     exp_inputs = self._experiment_inputs[idx]
                     dt = self._experiment_times[idx]
+                    op_conds_str = self.experiment.operating_conditions_strings[idx]
+                    op_conds_elec = self.experiment.operating_conditions[idx][:2]
+                    model = self.op_conds_to_built_models[op_conds_elec]
                     # Use 1-indexing for printing cycle number as it is more
                     # human-intuitive
-                    pybamm.logger.info(
-                        f"Cycle {cycle_num+1}/{num_cycles}, "
-                        f"step {step_num+1}/{cycle_length}: "
-                        f"{self.experiment.operating_conditions_strings[idx]}"
+                    pybamm.logger.notice(
+                        f"Cycle {cycle_num+cycle_offset}/{num_cycles+cycle_offset}, "
+                        f"step {step_num}/{cycle_length}: {op_conds_str}"
                     )
                     inputs.update(exp_inputs)
                     kwargs["inputs"] = inputs
                     # Make sure we take at least 2 timesteps
                     npts = max(int(round(dt / exp_inputs["period"])) + 1, 2)
-                    self.step(dt, solver=solver, npts=npts, **kwargs)
-
-                    # Extract the new parts of the solution
-                    # to construct the entire "step"
-                    sol = self.solution
-                    new_num_subsolutions = len(sol.sub_solutions)
-                    diff_num_subsolutions = (
-                        new_num_subsolutions - previous_num_subsolutions
-                    )
-                    previous_num_subsolutions = new_num_subsolutions
-
-                    step_solution = pybamm.Solution(
-                        sol.all_ts[-diff_num_subsolutions:],
-                        sol.all_ys[-diff_num_subsolutions:],
-                        sol.model,
-                        sol.all_inputs[-diff_num_subsolutions:],
-                        sol.t_event,
-                        sol.y_event,
-                        sol.termination,
+                    step_solution = solver.step(
+                        current_solution,
+                        model,
+                        dt,
+                        npts=npts,
+                        save=False,
+                        **kwargs,
                     )
-                    step_solution.solve_time = 0
-                    step_solution.integration_time = 0
                     steps.append(step_solution)
+                    current_solution = step_solution
 
-                    # Construct cycle solutions (a list of solutions corresponding to
-                    # cycles) from sub_solutions
-                    if step_num == 0:
-                        cycle_solution = step_solution
-                    else:
-                        cycle_solution = cycle_solution + step_solution
+                    cycle_solution = cycle_solution + step_solution
 
                     # Only allow events specified by experiment
                     if not (
-                        self._solution.termination == "final time"
-                        or "[experiment]" in self._solution.termination
+                        cycle_solution is None
+                        or cycle_solution.termination == "final time"
+                        or "[experiment]" in cycle_solution.termination
                     ):
                         feasible = False
                         break
@@ -515,23 +731,25 @@ def solve(self, t_eval=None, solver=None, check_model=True, **kwargs):
                 if feasible is False:
                     pybamm.logger.warning(
                         "\n\n\tExperiment is infeasible: '{}' ".format(
-                            self._solution.termination
+                            cycle_solution.termination
                         )
                         + "was triggered during '{}'. ".format(
                             self.experiment.operating_conditions_strings[idx]
                         )
                         + "The returned solution only contains the first "
-                        "{} cycles. ".format(cycle_num)
+                        "{} cycles. ".format(cycle_num - 1 + cycle_offset)
                         + "Try reducing the current, shortening the time interval, "
                         "or reducing the period.\n\n"
                     )
                     break
 
                 # At the final step of the inner loop we save the cycle
+                self._solution = self.solution + cycle_solution
                 cycle_solution.steps = steps
                 all_cycle_solutions.append(cycle_solution)
 
-            self.solution.cycles = all_cycle_solutions
+            if self.solution is not None:
+                self.solution.cycles = all_cycle_solutions
 
             pybamm.logger.notice(
                 "Finish experiment simulation, took {}".format(timer.time())
@@ -539,7 +757,9 @@ def solve(self, t_eval=None, solver=None, check_model=True, **kwargs):
 
         return self.solution
 
-    def step(self, dt, solver=None, npts=2, save=True, **kwargs):
+    def step(
+        self, dt, solver=None, npts=2, save=True, starting_solution=None, **kwargs
+    ):
         """
         A method to step the model forward one timestep. This method will
         automatically build and set the model parameters if not already done so.
@@ -555,17 +775,24 @@ def step(self, dt, solver=None, npts=2, save=True, **kwargs):
             the step dt. Default is 2 (returns the solution at t0 and t0 + dt).
         save : bool
             Turn on to store the solution of all previous timesteps
+        starting_solution : :class:`pybamm.Solution`
+            The solution to start stepping from. If None (default), then self._solution
+            is used
         **kwargs
             Additional key-word arguments passed to `solver.solve`.
             See :meth:`pybamm.BaseSolver.step`.
         """
-        self.build()
+        if self.operating_mode in ["without experiment", "drive cycle"]:
+            self.build()
 
         if solver is None:
             solver = self.solver
 
+        if starting_solution is None:
+            starting_solution = self._solution
+
         self._solution = solver.step(
-            self._solution, self.built_model, dt, npts=npts, save=save, **kwargs
+            starting_solution, self.built_model, dt, npts=npts, save=save, **kwargs
         )
 
         return self.solution
diff --git a/pybamm/solvers/base_solver.py b/pybamm/solvers/base_solver.py
index 030f569ac6..77069c78ce 100644
--- a/pybamm/solvers/base_solver.py
+++ b/pybamm/solvers/base_solver.py
@@ -932,17 +932,34 @@ def step(
                         "steps!".format(domain)
                     )
 
-        # Run set up on first step
         if old_solution is None:
+            # Run set up on first step
             pybamm.logger.verbose(
                 "Start stepping {} with {}".format(model.name, self.name)
             )
             self.set_up(model, ext_and_inputs)
+            self.models_set_up.update(
+                {model: {"initial conditions": model.concatenated_initial_conditions}}
+            )
             t = 0.0
-        else:
-            # initialize with old solution
+        elif model not in self.models_set_up:
+            # Run set up if the model has changed
+            self.set_up(model, ext_and_inputs)
+            self.models_set_up.update(
+                {model: {"initial conditions": model.concatenated_initial_conditions}}
+            )
+
+        if old_solution is not None:
             t = old_solution.all_ts[-1][-1]
-            model.y0 = old_solution.all_ys[-1][:, -1]
+            if old_solution.all_models[-1] == model:
+                # initialize with old solution
+                model.y0 = old_solution.all_ys[-1][:, -1]
+            else:
+                model.y0 = (
+                    model.set_initial_conditions_from(old_solution)
+                    .concatenated_initial_conditions.evaluate(0, inputs=ext_and_inputs)
+                    .flatten()
+                )
         set_up_time = timer.time()
 
         # (Re-)calculate consistent initial conditions
@@ -995,7 +1012,7 @@ def step(
         )
 
         # Return solution
-        if save is False or old_solution is None:
+        if save is False:
             return solution
         else:
             return old_solution + solution
@@ -1044,7 +1061,7 @@ def get_termination_reason(self, solution, events):
                 event_sol = pybamm.Solution(
                     solution.t_event,
                     solution.y_event,
-                    solution.model,
+                    solution.all_models[-1],
                     solution.all_inputs[-1],
                     solution.t_event,
                     solution.y_event,
diff --git a/pybamm/solvers/casadi_solver.py b/pybamm/solvers/casadi_solver.py
index a2903cb484..7e20a66c41 100644
--- a/pybamm/solvers/casadi_solver.py
+++ b/pybamm/solvers/casadi_solver.py
@@ -336,11 +336,8 @@ def event_fun(t):
                     # assign temporary solve time
                     current_step_sol.solve_time = np.nan
 
-                    if solution is None:
-                        solution = current_step_sol
-                    else:
-                        # append solution from the current step to solution
-                        solution = solution + current_step_sol
+                    # append solution from the current step to solution
+                    solution = solution + current_step_sol
                     solution.termination = "event"
                     solution.t_event = np.array([t_event])
                     solution.y_event = y_event[:, np.newaxis]
@@ -349,11 +346,8 @@ def event_fun(t):
                 else:
                     # assign temporary solve time
                     current_step_sol.solve_time = np.nan
-                    if solution is None:
-                        solution = current_step_sol
-                    else:
-                        # append solution from the current step to solution
-                        solution = solution + current_step_sol
+                    # append solution from the current step to solution
+                    solution = solution + current_step_sol
                     # update time
                     t = t_window[-1]
                     # update y0
diff --git a/pybamm/solvers/processed_variable.py b/pybamm/solvers/processed_variable.py
index ddc6b408bf..97a5a9f0f0 100644
--- a/pybamm/solvers/processed_variable.py
+++ b/pybamm/solvers/processed_variable.py
@@ -34,13 +34,16 @@ class ProcessedVariable(object):
 
     Parameters
     ----------
-    base_variable : :class:`pybamm.Symbol`
-        A base variable with a method `evaluate(t,y)` that returns the value of that
-        variable. Note that this can be any kind of node in the expression tree, not
+    base_variables : list of :class:`pybamm.Symbol`
+        A list of base variables with a method `evaluate(t,y)`, each entry of which
+        returns the value of that variable for that particular sub-solution.
+        A Solution can be comprised of sub-solutions which are the solutions of
+        different models.
+        Note that this can be any kind of node in the expression tree, not
         just a :class:`pybamm.Variable`.
         When evaluated, returns an array of size (m,n)
-    base_variable_casadi : :class:`casadi.Function`
-        A casadi function. When evaluated, returns the same thing as
+    base_variable_casadis : list of :class:`casadi.Function`
+        A list of casadi functions. When evaluated, returns the same thing as
         `base_Variable.evaluate` (but more efficiently).
     solution : :class:`pybamm.Solution`
         The solution object to be used to create the processed variables
@@ -49,17 +52,17 @@ class ProcessedVariable(object):
         Default is True.
     """
 
-    def __init__(self, base_variable, base_variable_casadi, solution, warn=True):
-        self.base_variable = base_variable
-        self.base_variable_casadi = base_variable_casadi
+    def __init__(self, base_variables, base_variables_casadi, solution, warn=True):
+        self.base_variables = base_variables
+        self.base_variables_casadi = base_variables_casadi
 
         self.all_ts = solution.all_ts
         self.all_ys = solution.all_ys
         self.all_inputs_casadi = solution.all_inputs_casadi
 
-        self.mesh = base_variable.mesh
-        self.domain = base_variable.domain
-        self.auxiliary_domains = base_variable.auxiliary_domains
+        self.mesh = base_variables[0].mesh
+        self.domain = base_variables[0].domain
+        self.auxiliary_domains = base_variables[0].auxiliary_domains
         self.warn = warn
 
         # Set timescale
@@ -67,11 +70,10 @@ def __init__(self, base_variable, base_variable_casadi, solution, warn=True):
         self.t_pts = solution.t * self.timescale
 
         # Store length scales
-        if solution.model:
-            self.length_scales = solution.length_scales_eval
+        self.length_scales = solution.length_scales_eval
 
         # Evaluate base variable at initial time
-        self.base_eval = self.base_variable_casadi(
+        self.base_eval = self.base_variables_casadi[0](
             self.all_ts[0][0], self.all_ys[0][:, 0], self.all_inputs_casadi[0]
         ).full()
 
@@ -101,7 +103,7 @@ def __init__(self, base_variable, base_variable_casadi, solution, warn=True):
                     # Try some shapes that could make the variable a 2D variable
                     first_dim_nodes = self.mesh.nodes
                     first_dim_edges = self.mesh.edges
-                    second_dim_pts = self.base_variable.secondary_mesh.nodes
+                    second_dim_pts = self.base_variables[0].secondary_mesh.nodes
                     if self.base_eval.size // len(second_dim_pts) in [
                         len(first_dim_nodes),
                         len(first_dim_edges),
@@ -110,7 +112,7 @@ def __init__(self, base_variable, base_variable_casadi, solution, warn=True):
                     else:
                         # Raise error for 3D variable
                         raise NotImplementedError(
-                            "Shape not recognized for {} ".format(base_variable)
+                            "Shape not recognized for {} ".format(base_variables[0])
                             + "(note processing of 3D variables is not yet implemented)"
                         )
 
@@ -119,11 +121,13 @@ def initialise_0D(self):
         entries = np.empty(len(self.t_pts))
         idx = 0
         # Evaluate the base_variable index-by-index
-        for ts, ys, inputs in zip(self.all_ts, self.all_ys, self.all_inputs_casadi):
+        for ts, ys, inputs, base_var_casadi in zip(
+            self.all_ts, self.all_ys, self.all_inputs_casadi, self.base_variables_casadi
+        ):
             for inner_idx, t in enumerate(ts):
                 t = ts[inner_idx]
                 y = ys[:, inner_idx]
-                entries[idx] = self.base_variable_casadi(t, y, inputs).full()[0, 0]
+                entries[idx] = base_var_casadi(t, y, inputs).full()[0, 0]
                 idx += 1
 
         # set up interpolation
@@ -152,11 +156,13 @@ def initialise_1D(self, fixed_t=False):
 
         # Evaluate the base_variable index-by-index
         idx = 0
-        for ts, ys, inputs in zip(self.all_ts, self.all_ys, self.all_inputs_casadi):
+        for ts, ys, inputs, base_var_casadi in zip(
+            self.all_ts, self.all_ys, self.all_inputs_casadi, self.base_variables_casadi
+        ):
             for inner_idx, t in enumerate(ts):
                 t = ts[inner_idx]
                 y = ys[:, inner_idx]
-                entries[:, idx] = self.base_variable_casadi(t, y, inputs).full()[:, 0]
+                entries[:, idx] = base_var_casadi(t, y, inputs).full()[:, 0]
                 idx += 1
 
         # Get node and edge values
@@ -243,8 +249,8 @@ def initialise_2D(self):
         """
         first_dim_nodes = self.mesh.nodes
         first_dim_edges = self.mesh.edges
-        second_dim_nodes = self.base_variable.secondary_mesh.nodes
-        second_dim_edges = self.base_variable.secondary_mesh.edges
+        second_dim_nodes = self.base_variables[0].secondary_mesh.nodes
+        second_dim_edges = self.base_variables[0].secondary_mesh.edges
         if self.base_eval.size // len(second_dim_nodes) == len(first_dim_nodes):
             first_dim_pts = first_dim_nodes
         elif self.base_eval.size // len(second_dim_nodes) == len(first_dim_edges):
@@ -257,12 +263,14 @@ def initialise_2D(self):
 
         # Evaluate the base_variable index-by-index
         idx = 0
-        for ts, ys, inputs in zip(self.all_ts, self.all_ys, self.all_inputs_casadi):
+        for ts, ys, inputs, base_var_casadi in zip(
+            self.all_ts, self.all_ys, self.all_inputs_casadi, self.base_variables_casadi
+        ):
             for inner_idx, t in enumerate(ts):
                 t = ts[inner_idx]
                 y = ys[:, inner_idx]
                 entries[:, :, idx] = np.reshape(
-                    self.base_variable_casadi(t, y, inputs).full(),
+                    base_var_casadi(t, y, inputs).full(),
                     [first_dim_size, second_dim_size],
                     order="F",
                 )
@@ -401,12 +409,14 @@ def initialise_2D_scikit_fem(self):
 
         # Evaluate the base_variable index-by-index
         idx = 0
-        for ts, ys, inputs in zip(self.all_ts, self.all_ys, self.all_inputs_casadi):
+        for ts, ys, inputs, base_var_casadi in zip(
+            self.all_ts, self.all_ys, self.all_inputs_casadi, self.base_variables_casadi
+        ):
             for inner_idx, t in enumerate(ts):
                 t = ts[inner_idx]
                 y = ys[:, inner_idx]
                 entries[:, :, idx] = np.reshape(
-                    self.base_variable_casadi(t, y, inputs).full(),
+                    base_var_casadi(t, y, inputs).full(),
                     [len_y, len_z],
                     order="F",
                 )
@@ -462,11 +472,11 @@ def __call__(self, t=None, x=None, r=None, y=None, z=None, warn=True):
         if t is None:
             if len(self.t_pts) == 1:
                 t = self.t_pts
-            elif self.base_variable.is_constant():
+            elif len(self.base_variables) == 1 and self.base_variables[0].is_constant():
                 t = self.t_pts[0]
             else:
                 raise ValueError(
-                    "t cannot be None for variable {}".format(self.base_variable)
+                    "t cannot be None for variable {}".format(self.base_variables)
                 )
 
         # Call interpolant of correct spatial dimension
diff --git a/pybamm/solvers/solution.py b/pybamm/solvers/solution.py
index d295f79b1d..2495c3ea04 100644
--- a/pybamm/solvers/solution.py
+++ b/pybamm/solvers/solution.py
@@ -26,8 +26,10 @@ class Solution(object):
         vector of solutions at time t[i].
         A list of ys can be provided instead to initialize a solution with
         sub-solutions.
-    model : :class:`pybamm.BaseModel`
-        The model that was used to calculate the solution
+    all_models : :class:`pybamm.BaseModel`
+        The model that was used to calculate the solution.
+        A list of models can be provided instead to initialize a solution with
+        sub-solutions that have been calculated using those models.
     all_inputs : dict (or list of these)
         The inputs that were used to calculate the solution
         A list of inputs can be provided instead to initialize a solution with
@@ -46,7 +48,7 @@ def __init__(
         self,
         all_ts,
         all_ys,
-        model,
+        all_models,
         all_inputs,
         t_event=None,
         y_event=None,
@@ -56,8 +58,11 @@ def __init__(
             all_ts = [all_ts]
         if not isinstance(all_ys, list):
             all_ys = [all_ys]
-        self.all_ts = all_ts
-        self.all_ys = all_ys
+        if not isinstance(all_models, list):
+            all_models = [all_models]
+        self._all_ts = all_ts
+        self._all_ys = all_ys
+        self._all_models = all_models
 
         self._t_event = t_event
         self._y_event = y_event
@@ -74,21 +79,18 @@ def __init__(
             isinstance(v, casadi.MX) for v in all_inputs[0].values()
         )
 
-        # Set up model
-        self._model = model
-
         # Copy the timescale_eval and lengthscale_evals if they exist
-        if hasattr(model, "timescale_eval"):
-            self.timescale_eval = model.timescale_eval
+        if hasattr(all_models[0], "timescale_eval"):
+            self.timescale_eval = all_models[0].timescale_eval
         else:
-            self.timescale_eval = model.timescale.evaluate()
-        # self.timescale_eval = model.timescale_eval
-        if hasattr(model, "length_scales_eval"):
-            self.length_scales_eval = model.length_scales_eval
+            self.timescale_eval = all_models[0].timescale.evaluate()
+
+        if hasattr(all_models[0], "length_scales_eval"):
+            self.length_scales_eval = all_models[0].length_scales_eval
         else:
             self.length_scales_eval = {
                 domain: scale.evaluate()
-                for domain, scale in model.length_scales.items()
+                for domain, scale in all_models[0].length_scales.items()
             }
 
         self.set_up_time = None
@@ -129,15 +131,29 @@ def y(self):
             return self._y
 
     def set_y(self):
-        if isinstance(self.all_ys[0], (casadi.DM, casadi.MX)):
-            self._y = casadi.horzcat(*self.all_ys)
-        else:
-            self._y = np.hstack(self.all_ys)
+        try:
+            if isinstance(self.all_ys[0], (casadi.DM, casadi.MX)):
+                self._y = casadi.horzcat(*self.all_ys)
+            else:
+                self._y = np.hstack(self.all_ys)
+        except ValueError:
+            raise pybamm.SolverError(
+                "The solution is made up from different models, so `y` cannot be "
+                "computed explicitly."
+            )
+
+    @property
+    def all_ts(self):
+        return self._all_ts
 
     @property
-    def model(self):
-        """Model used for solution"""
-        return self._model
+    def all_ys(self):
+        return self._all_ys
+
+    @property
+    def all_models(self):
+        """Model(s) used for solution"""
+        return self._all_models
 
     @property
     def all_inputs_casadi(self):
@@ -179,6 +195,35 @@ def termination(self, value):
         """Updates the reason for termination"""
         self._termination = value
 
+    @property
+    def last_state(self):
+        """
+        A Solution object that only contains the final state. This is faster to evaluate
+        than the full solution when only the final state is needed (e.g. to initialize
+        a model with the solution)
+        """
+        try:
+            return self._last_state
+        except AttributeError:
+            new_sol = Solution(
+                self.all_ts[-1][-1:],
+                self.all_ys[-1][:, -1:],
+                self.all_models[-1:],
+                self.all_inputs[-1:],
+                self.t_event,
+                self.y_event,
+                self.termination,
+            )
+            new_sol._all_inputs_casadi = self.all_inputs_casadi[-1:]
+            new_sol._sub_solutions = self.sub_solutions
+
+            new_sol.solve_time = 0
+            new_sol.integration_time = 0
+            new_sol.set_up_time = 0
+
+            self._last_state = new_sol
+            return self._last_state
+
     @property
     def total_time(self):
         return self.set_up_time + self.solve_time
@@ -194,39 +239,48 @@ def update(self, variables):
             # If there are symbolic inputs then we need to make a
             # ProcessedSymbolicVariable
             if self.has_symbolic_inputs is True:
-                var = pybamm.ProcessedSymbolicVariable(self.model.variables[key], self)
+                var = pybamm.ProcessedSymbolicVariable(
+                    self.all_models[0].variables[key], self
+                )
 
             # Otherwise a standard ProcessedVariable is ok
             else:
-                var_pybamm = self.model.variables[key]
-
-                if key in self.model._variables_casadi:
-                    var_casadi = self.model._variables_casadi[key]
-                else:
-                    self._t_MX = casadi.MX.sym("t")
-                    self._y_MX = casadi.MX.sym("y", self.all_ys[0].shape[0])
-                    self._symbolic_inputs_dict = {
-                        key: casadi.MX.sym("input", value.shape[0])
-                        for key, value in self.all_inputs[0].items()
-                    }
-                    self._symbolic_inputs = casadi.vertcat(
-                        *[p for p in self._symbolic_inputs_dict.values()]
-                    )
+                vars_pybamm = [model.variables[key] for model in self.all_models]
+
+                # Iterate through all models, some may be in the list several times and
+                # therefore only get set up once
+                vars_casadi = []
+                for model, ys, inputs, var_pybamm in zip(
+                    self.all_models, self.all_ys, self.all_inputs, vars_pybamm
+                ):
+                    if key in model._variables_casadi:
+                        var_casadi = model._variables_casadi[key]
+                    else:
+                        self._t_MX = casadi.MX.sym("t")
+                        self._y_MX = casadi.MX.sym("y", ys.shape[0])
+                        self._symbolic_inputs_dict = {
+                            key: casadi.MX.sym("input", value.shape[0])
+                            for key, value in inputs.items()
+                        }
+                        self._symbolic_inputs = casadi.vertcat(
+                            *[p for p in self._symbolic_inputs_dict.values()]
+                        )
 
-                    # Convert variable to casadi
-                    # Make all inputs symbolic first for converting to casadi
-                    var_sym = var_pybamm.to_casadi(
-                        self._t_MX, self._y_MX, inputs=self._symbolic_inputs_dict
-                    )
+                        # Convert variable to casadi
+                        # Make all inputs symbolic first for converting to casadi
+                        var_sym = var_pybamm.to_casadi(
+                            self._t_MX, self._y_MX, inputs=self._symbolic_inputs_dict
+                        )
 
-                    var_casadi = casadi.Function(
-                        "variable",
-                        [self._t_MX, self._y_MX, self._symbolic_inputs],
-                        [var_sym],
-                    )
-                    self.model._variables_casadi[key] = var_casadi
+                        var_casadi = casadi.Function(
+                            "variable",
+                            [self._t_MX, self._y_MX, self._symbolic_inputs],
+                            [var_sym],
+                        )
+                        model._variables_casadi[key] = var_casadi
+                    vars_casadi.append(var_casadi)
 
-                var = pybamm.ProcessedVariable(var_pybamm, var_casadi, self)
+                var = pybamm.ProcessedVariable(vars_pybamm, vars_casadi, self)
 
             # Save variable and data
             self._variables[key] = var
@@ -385,12 +439,17 @@ def save_data(self, filename, variables=None, to_format="pickle", short_names=No
 
     @property
     def sub_solutions(self):
-        """List of sub solutions that have been concatenated to form the full solution
+        """
+        List of sub solutions that have been concatenated to form the full solution
         """
         return self._sub_solutions
 
     def __add__(self, other):
         """ Adds two solutions together, e.g. when stepping """
+        if not isinstance(other, Solution):
+            raise pybamm.SolverError(
+                "Only a Solution or None can be added to a Solution"
+            )
         # Special case: new solution only has one timestep and it is already in the
         # existing solution. In this case, return a copy of the existing solution
         if (
@@ -412,7 +471,7 @@ def __add__(self, other):
         new_sol = Solution(
             all_ts,
             all_ys,
-            self.model,
+            self.all_models + other.all_models,
             self.all_inputs + other.all_inputs,
             self.t_event,
             self.y_event,
@@ -435,11 +494,22 @@ def __add__(self, other):
 
         return new_sol
 
+    def __radd__(self, other):
+        """
+        Function to deal with the case `None + Solution` (returns `Solution`)
+        """
+        if other is None:
+            return self.copy()
+        else:
+            raise pybamm.SolverError(
+                "Only a Solution or None can be added to a Solution"
+            )
+
     def copy(self):
         new_sol = Solution(
             self.all_ts,
             self.all_ys,
-            self.model,
+            self.all_models,
             self.all_inputs,
             self.t_event,
             self.y_event,
diff --git a/tests/integration/test_models/standard_output_comparison.py b/tests/integration/test_models/standard_output_comparison.py
index 73842db6ee..ec77cf0ca6 100644
--- a/tests/integration/test_models/standard_output_comparison.py
+++ b/tests/integration/test_models/standard_output_comparison.py
@@ -11,9 +11,9 @@ class StandardOutputComparison(object):
     def __init__(self, solutions):
         self.solutions = solutions
 
-        if isinstance(solutions[0].model, pybamm.lithium_ion.BaseModel):
+        if isinstance(solutions[0].all_models[0], pybamm.lithium_ion.BaseModel):
             self.chemistry = "Lithium-ion"
-        elif isinstance(solutions[0].model, pybamm.lead_acid.BaseModel):
+        elif isinstance(solutions[0].all_models[0], pybamm.lead_acid.BaseModel):
             self.chemistry = "Lead acid"
 
         self.t = self.get_output_times()
@@ -33,7 +33,7 @@ def get_output_times(self):
             np.testing.assert_array_equal(t_common, solution.t[:max_index])
 
         # Get timescale
-        timescale = self.solutions[0].model.timescale_eval
+        timescale = self.solutions[0].timescale_eval
 
         return t_common * timescale
 
diff --git a/tests/unit/test_experiments/test_simulation_with_experiment.py b/tests/unit/test_experiments/test_simulation_with_experiment.py
index 0bd4270845..b15ef04c73 100644
--- a/tests/unit/test_experiments/test_simulation_with_experiment.py
+++ b/tests/unit/test_experiments/test_simulation_with_experiment.py
@@ -52,17 +52,19 @@ def test_set_up(self):
             sim._experiment_times, [3600, 7 * 24 * 3600, 7 * 24 * 3600, 3600]
         )
 
+        model_I = sim.op_conds_to_model_and_param[(-1.0, "A")][0]
+        model_V = sim.op_conds_to_model_and_param[(4.1, "V")][0]
         self.assertIn(
             "Current cut-off (positive) [A] [experiment]",
-            [event.name for event in sim.model.events],
+            [event.name for event in model_V.events],
         )
         self.assertIn(
             "Current cut-off (negative) [A] [experiment]",
-            [event.name for event in sim.model.events],
+            [event.name for event in model_V.events],
         )
         self.assertIn(
             "Voltage cut-off [V] [experiment]",
-            [event.name for event in sim.model.events],
+            [event.name for event in model_I.events],
         )
 
         # fails if trying to set up with something that isn't an experiment
@@ -70,13 +72,38 @@ def test_set_up(self):
             pybamm.Simulation(model, experiment=0)
 
     def test_run_experiment(self):
+        experiment = pybamm.Experiment(
+            [
+                (
+                    "Discharge at C/20 for 1 hour",
+                    "Charge at 1 A until 4.1 V",
+                    "Hold at 4.1 V until C/2",
+                    "Discharge at 2 W for 1 hour",
+                )
+            ]
+        )
+        model = pybamm.lithium_ion.SPM()
+        sim = pybamm.Simulation(model, experiment=experiment)
+        sol = sim.solve()
+        self.assertEqual(sol.termination, "final time")
+        self.assertEqual(len(sol.cycles), 1)
+
+        # Solve again starting from solution
+        sol2 = sim.solve(starting_solution=sol)
+        self.assertEqual(sol2.termination, "final time")
+        self.assertGreater(sol2.t[-1], sol.t[-1])
+        self.assertEqual(sol2.cycles[0], sol.cycles[0])
+        self.assertEqual(len(sol2.cycles), 2)
+
+    def test_run_experiment_old_setup_type(self):
         experiment = pybamm.Experiment(
             [
                 "Discharge at C/20 for 1 hour",
                 "Charge at 1 A until 4.1 V",
                 "Hold at 4.1 V until C/2",
                 "Discharge at 2 W for 1 hour",
-            ]
+            ],
+            use_simulation_setup_type="old",
         )
         model = pybamm.lithium_ion.SPM()
         sim = pybamm.Simulation(model, experiment=experiment)
@@ -92,7 +119,7 @@ def test_run_experiment_breaks_early(self):
         t_eval = [0, 1]
         sim.solve(t_eval, solver=pybamm.CasadiSolver())
         pybamm.set_logging_level("WARNING")
-        self.assertIn("event", sim._solution.termination)
+        self.assertEqual(sim._solution, None)
 
     def test_inputs(self):
         experiment = pybamm.Experiment(
diff --git a/tests/unit/test_models/test_base_model.py b/tests/unit/test_models/test_base_model.py
index 121c30b9cd..e0f2c7b8a6 100644
--- a/tests/unit/test_models/test_base_model.py
+++ b/tests/unit/test_models/test_base_model.py
@@ -927,6 +927,21 @@ def test_set_initial_condition_errors(self):
         ):
             model.set_initial_conditions_from({"var_concat_neg": np.ones((5, 6, 7))})
 
+        # Inconsistent model and variable names
+        model = pybamm.BaseModel()
+        var = pybamm.Variable("var")
+        model.rhs = {var: -var}
+        model.initial_conditions = {var: pybamm.Scalar(1)}
+        with self.assertRaisesRegex(pybamm.ModelError, "must appear in the solution"):
+            model.set_initial_conditions_from({"wrong var": 2})
+        var = pybamm.Concatenation(
+            pybamm.Variable("var", "test"), pybamm.Variable("var2", "test2")
+        )
+        model.rhs = {var: -var}
+        model.initial_conditions = {var: pybamm.Scalar(1)}
+        with self.assertRaisesRegex(pybamm.ModelError, "must appear in the solution"):
+            model.set_initial_conditions_from({"wrong var": 2})
+
 
 class TestStandardBatteryBaseModel(unittest.TestCase):
     def test_default_solver(self):
diff --git a/tests/unit/test_simulation.py b/tests/unit/test_simulation.py
index 4c9f473f42..6dd6d3be76 100644
--- a/tests/unit/test_simulation.py
+++ b/tests/unit/test_simulation.py
@@ -71,13 +71,19 @@ def test_solve(self):
 
         # test solve without check
         sim = pybamm.Simulation(pybamm.lithium_ion.SPM())
-        sim.solve(t_eval=[0, 600], check_model=False)
+        sol = sim.solve(t_eval=[0, 600], check_model=False)
         for val in list(sim.built_model.rhs.values()):
             self.assertFalse(val.has_symbol_of_classes(pybamm.Parameter))
             # skip test for scalar variables (e.g. discharge capacity)
             if val.size > 1:
                 self.assertTrue(val.has_symbol_of_classes(pybamm.Matrix))
 
+        # Test options that are only available when simulating an experiment
+        with self.assertRaisesRegex(ValueError, "save_at_cycles"):
+            sim.solve(save_at_cycles=2)
+        with self.assertRaisesRegex(ValueError, "starting_solution"):
+            sim.solve(starting_solution=sol)
+
     def test_solve_non_battery_model(self):
 
         model = pybamm.BaseModel()
diff --git a/tests/unit/test_solvers/test_processed_variable.py b/tests/unit/test_solvers/test_processed_variable.py
index 0a760a2657..6d692da6ea 100644
--- a/tests/unit/test_solvers/test_processed_variable.py
+++ b/tests/unit/test_solvers/test_processed_variable.py
@@ -37,8 +37,8 @@ def test_processed_variable_0D(self):
         y_sol = np.array([np.linspace(0, 5)])
         var_casadi = to_casadi(var, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var,
-            var_casadi,
+            [var],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -51,8 +51,8 @@ def test_processed_variable_0D(self):
         y_sol = np.array([1])[:, np.newaxis]
         var_casadi = to_casadi(var, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var,
-            var_casadi,
+            [var],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -75,8 +75,8 @@ def test_processed_variable_1D(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -84,8 +84,8 @@ def test_processed_variable_1D(self):
         np.testing.assert_array_equal(processed_var(t_sol, x_sol), y_sol)
         eqn_casadi = to_casadi(eqn_sol, y_sol)
         processed_eqn = pybamm.ProcessedVariable(
-            eqn_sol,
-            eqn_casadi,
+            [eqn_sol],
+            [eqn_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -104,8 +104,8 @@ def test_processed_variable_1D(self):
         x_s_edge.mesh = disc.mesh["separator"]
         x_s_casadi = to_casadi(x_s_edge, y_sol)
         processed_x_s_edge = pybamm.ProcessedVariable(
-            x_s_edge,
-            x_s_casadi,
+            [x_s_edge],
+            [x_s_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -120,8 +120,8 @@ def test_processed_variable_1D(self):
         y_sol = np.ones_like(x_sol)[:, np.newaxis]
         eqn_casadi = to_casadi(eqn_sol, y_sol)
         processed_eqn2 = pybamm.ProcessedVariable(
-            eqn_sol,
-            eqn_casadi,
+            [eqn_sol],
+            [eqn_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -155,7 +155,7 @@ def test_processed_variable_1D_unknown_domain(self):
         c = pybamm.StateVector(slice(0, var_pts[x]), domain=["SEI layer"])
         c.mesh = mesh["SEI layer"]
         c_casadi = to_casadi(c, y_sol)
-        pybamm.ProcessedVariable(c, c_casadi, solution, warn=False)
+        pybamm.ProcessedVariable([c], [c_casadi], solution, warn=False)
 
     def test_processed_variable_2D_x_r(self):
         var = pybamm.Variable(
@@ -182,8 +182,8 @@ def test_processed_variable_2D_x_r(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -217,8 +217,8 @@ def test_processed_variable_2D_x_z(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -237,8 +237,8 @@ def test_processed_variable_2D_x_z(self):
         x_s_edge.secondary_mesh = disc.mesh["current collector"]
         x_s_casadi = to_casadi(x_s_edge, y_sol)
         processed_x_s_edge = pybamm.ProcessedVariable(
-            x_s_edge,
-            x_s_casadi,
+            [x_s_edge],
+            [x_s_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -271,8 +271,8 @@ def test_processed_variable_2D_space_only(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -295,8 +295,8 @@ def test_processed_variable_2D_scikit(self):
 
         var_casadi = to_casadi(var_sol, u_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, u_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -318,8 +318,8 @@ def test_processed_variable_2D_fixed_t_scikit(self):
 
         var_casadi = to_casadi(var_sol, u_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, u_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -340,8 +340,8 @@ def test_processed_var_0D_interpolation(self):
         y_sol = np.array([np.linspace(0, 5, 1000)])
         var_casadi = to_casadi(var, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var,
-            var_casadi,
+            [var],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -353,8 +353,8 @@ def test_processed_var_0D_interpolation(self):
 
         eqn_casadi = to_casadi(eqn, y_sol)
         processed_eqn = pybamm.ProcessedVariable(
-            eqn,
-            eqn_casadi,
+            [eqn],
+            [eqn_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -377,8 +377,8 @@ def test_processed_var_0D_fixed_t_interpolation(self):
         y_sol = np.array([[100]])
         eqn_casadi = to_casadi(eqn, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            eqn,
-            eqn_casadi,
+            [eqn],
+            [eqn_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -401,8 +401,8 @@ def test_processed_var_1D_interpolation(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -421,8 +421,8 @@ def test_processed_var_1D_interpolation(self):
         )
         eqn_casadi = to_casadi(eqn_sol, y_sol)
         processed_eqn = pybamm.ProcessedVariable(
-            eqn_sol,
-            eqn_casadi,
+            [eqn_sol],
+            [eqn_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -441,8 +441,8 @@ def test_processed_var_1D_interpolation(self):
         x_casadi = to_casadi(x_disc, y_sol)
 
         processed_x = pybamm.ProcessedVariable(
-            x_disc,
-            x_casadi,
+            [x_disc],
+            [x_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -455,8 +455,8 @@ def test_processed_var_1D_interpolation(self):
         r_n.mesh = disc.mesh["negative particle"]
         r_n_casadi = to_casadi(r_n, y_sol)
         processed_r_n = pybamm.ProcessedVariable(
-            r_n,
-            r_n_casadi,
+            [r_n],
+            [r_n_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -479,8 +479,8 @@ def test_processed_var_1D_fixed_t_interpolation(self):
 
         eqn_casadi = to_casadi(eqn_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            eqn_sol,
-            eqn_casadi,
+            [eqn_sol],
+            [eqn_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -517,8 +517,8 @@ def test_processed_var_2D_interpolation(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -565,8 +565,8 @@ def test_processed_var_2D_interpolation(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -600,8 +600,8 @@ def test_processed_var_2D_fixed_t_interpolation(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -629,8 +629,8 @@ def test_processed_var_2D_secondary_broadcast(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -668,8 +668,8 @@ def test_processed_var_2D_secondary_broadcast(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -692,8 +692,8 @@ def test_processed_var_2D_scikit_interpolation(self):
 
         var_casadi = to_casadi(var_sol, u_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, u_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -736,8 +736,8 @@ def test_processed_var_2D_fixed_t_scikit_interpolation(self):
 
         var_casadi = to_casadi(var_sol, u_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, u_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -813,8 +813,8 @@ def test_call_failure(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -836,8 +836,8 @@ def test_call_failure(self):
 
         var_casadi = to_casadi(var_sol, y_sol)
         processed_var = pybamm.ProcessedVariable(
-            var_sol,
-            var_casadi,
+            [var_sol],
+            [var_casadi],
             pybamm.Solution(t_sol, y_sol, pybamm.BaseModel(), {}),
             warn=False,
         )
@@ -866,8 +866,8 @@ def test_3D_raises_error(self):
 
         with self.assertRaisesRegex(NotImplementedError, "Shape not recognized"):
             pybamm.ProcessedVariable(
-                var_sol,
-                var_casadi,
+                [var_sol],
+                [var_casadi],
                 pybamm.Solution(t_sol, u_sol, pybamm.BaseModel(), {}),
                 warn=False,
             )
diff --git a/tests/unit/test_solvers/test_scipy_solver.py b/tests/unit/test_solvers/test_scipy_solver.py
index 983fcacff3..d947224c95 100644
--- a/tests/unit/test_solvers/test_scipy_solver.py
+++ b/tests/unit/test_solvers/test_scipy_solver.py
@@ -164,7 +164,6 @@ def test_model_step_python(self):
         var = pybamm.Variable("var", domain=domain)
         model.rhs = {var: 0.1 * var}
         model.initial_conditions = {var: 1}
-        # No need to set parameters; can use base discretisation (no spatial operators)
 
         # create discretisation
         mesh = get_mesh_for_testing()
@@ -194,6 +193,44 @@ def test_model_step_python(self):
         solution = solver.solve(model, t_eval)
         np.testing.assert_array_almost_equal(solution.y[0], step_sol.y[0])
 
+    def test_step_different_model(self):
+        disc = pybamm.Discretisation()
+
+        # Create and discretise model1
+        model1 = pybamm.BaseModel()
+        var = pybamm.Variable("var")
+        var2 = pybamm.Variable("var2")
+        model1.rhs = {var: 0.1 * var}
+        model1.initial_conditions = {var: 1}
+        model1.variables = {"var": var, "mul_var": 2 * var, "var2": var}
+        disc.process_model(model1)
+
+        # Create and discretise model2, which is slightly different
+        model2 = pybamm.BaseModel()
+        var = pybamm.Variable("var")
+        var2 = pybamm.Variable("var2")
+        model2.rhs = {var: 0.2 * var, var2: -0.5 * var2}
+        model2.initial_conditions = {var: 1, var2: 1}
+        model2.variables = {"var": var, "mul_var": 3 * var, "var2": var2}
+        disc.process_model(model2)
+
+        solver = pybamm.ScipySolver(rtol=1e-8, atol=1e-8, method="RK45")
+
+        # Step once
+        dt = 1
+        step_sol1 = solver.step(None, model1, dt)
+        np.testing.assert_array_equal(step_sol1.t, [0, dt])
+        np.testing.assert_array_almost_equal(step_sol1.y[0], np.exp(0.1 * step_sol1.t))
+
+        # Step again, the model has changed
+        step_sol2 = solver.step(step_sol1, model2, dt)
+        np.testing.assert_array_equal(step_sol2.t, [0, dt, 2 * dt])
+        np.testing.assert_array_almost_equal(
+            step_sol2.all_ys[0][0], np.exp(0.1 * step_sol1.t)
+        )
+        print(step_sol2.all_ys)
+        print(step_sol2["mul_var"].data)
+
     def test_model_solver_with_inputs(self):
         # Create model
         model = pybamm.BaseModel()
diff --git a/tests/unit/test_solvers/test_solution.py b/tests/unit/test_solvers/test_solution.py
index 2b03339a5b..40ba22a24d 100644
--- a/tests/unit/test_solvers/test_solution.py
+++ b/tests/unit/test_solvers/test_solution.py
@@ -20,7 +20,7 @@ def test_init(self):
         self.assertEqual(sol.y_event, None)
         self.assertEqual(sol.termination, "final time")
         self.assertEqual(sol.all_inputs, [{}])
-        self.assertIsInstance(sol.model, pybamm.BaseModel)
+        self.assertIsInstance(sol.all_models[0], pybamm.BaseModel)
 
     def test_errors(self):
         bad_ts = [np.array([1, 2, 3]), np.array([3, 4, 5])]
@@ -61,9 +61,10 @@ def test_add_solutions(self):
         self.assertEqual(len(sol_sum.sub_solutions), 2)
         np.testing.assert_array_equal(sol_sum.sub_solutions[0].t, t1)
         np.testing.assert_array_equal(sol_sum.sub_solutions[1].t, t2)
-        self.assertEqual(sol_sum.sub_solutions[0].model, sol_sum.model)
+        self.assertEqual(sol_sum.sub_solutions[0].all_models[0], sol_sum.all_models[0])
         np.testing.assert_array_equal(sol_sum.sub_solutions[0].all_inputs[0]["a"], 1)
-        self.assertEqual(sol_sum.sub_solutions[1].model, sol2.model)
+        self.assertEqual(sol_sum.sub_solutions[1].all_models[0], sol2.all_models[0])
+        self.assertEqual(sol_sum.all_models[1], sol2.all_models[0])
         np.testing.assert_array_equal(sol_sum.sub_solutions[1].all_inputs[0]["a"], 2)
 
         # Add solution already contained in existing solution
@@ -72,6 +73,43 @@ def test_add_solutions(self):
         sol3 = pybamm.Solution(t3, y3, pybamm.BaseModel(), {"a": 3})
         self.assertEqual((sol_sum + sol3).all_ts, sol_sum.copy().all_ts)
 
+        # radd
+        sol4 = None + sol3
+        self.assertEqual(sol3.all_ys, sol4.all_ys)
+
+        # radd failure
+        with self.assertRaisesRegex(
+            pybamm.SolverError, "Only a Solution or None can be added to a Solution"
+        ):
+            sol3 + 2
+        with self.assertRaisesRegex(
+            pybamm.SolverError, "Only a Solution or None can be added to a Solution"
+        ):
+            2 + sol3
+
+    def test_add_solutions_different_models(self):
+        # Set up first solution
+        t1 = np.linspace(0, 1)
+        y1 = np.tile(t1, (20, 1))
+        sol1 = pybamm.Solution(t1, y1, pybamm.BaseModel(), {"a": 1})
+        sol1.solve_time = 1.5
+        sol1.integration_time = 0.3
+
+        # Set up second solution
+        t2 = np.linspace(1, 2)
+        y2 = np.tile(t2, (10, 1))
+        sol2 = pybamm.Solution(t2, y2, pybamm.BaseModel(), {"a": 2})
+        sol2.solve_time = 1
+        sol2.integration_time = 0.5
+        sol_sum = sol1 + sol2
+
+        # Test
+        np.testing.assert_array_equal(sol_sum.t, np.concatenate([t1, t2[1:]]))
+        with self.assertRaisesRegex(
+            pybamm.SolverError, "The solution is made up from different models"
+        ):
+            sol_sum.y
+
     def test_copy(self):
         # Set up first solution
         t1 = [np.linspace(0, 1), np.linspace(1, 2, 5)]
@@ -91,6 +129,26 @@ def test_copy(self):
         self.assertEqual(sol_copy.solve_time, sol1.solve_time)
         self.assertEqual(sol_copy.integration_time, sol1.integration_time)
 
+    def test_last_state(self):
+        # Set up first solution
+        t1 = [np.linspace(0, 1), np.linspace(1, 2, 5)]
+        y1 = [np.tile(t1[0], (20, 1)), np.tile(t1[1], (20, 1))]
+        sol1 = pybamm.Solution(t1, y1, pybamm.BaseModel(), [{"a": 1}, {"a": 2}])
+
+        sol1.set_up_time = 0.5
+        sol1.solve_time = 1.5
+        sol1.integration_time = 0.3
+
+        sol_last_state = sol1.last_state
+        self.assertEqual(sol_last_state.all_ts[0], 2)
+        np.testing.assert_array_equal(sol_last_state.all_ys[0], 2)
+        self.assertEqual(sol_last_state.all_inputs, sol1.all_inputs[-1:])
+        self.assertEqual(sol_last_state.all_inputs_casadi, sol1.all_inputs_casadi[-1:])
+        self.assertEqual(sol_last_state.all_models, sol1.all_models[-1:])
+        self.assertEqual(sol_last_state.set_up_time, 0)
+        self.assertEqual(sol_last_state.solve_time, 0)
+        self.assertEqual(sol_last_state.integration_time, 0)
+
     def test_cycles(self):
         model = pybamm.lithium_ion.SPM()
         experiment = pybamm.Experiment(
@@ -109,8 +167,8 @@ def test_cycles(self):
         np.testing.assert_array_equal(sol.cycles[0].y, sol.y[:, :len_cycle_1])
 
         self.assertIsInstance(sol.cycles[1], pybamm.Solution)
-        np.testing.assert_array_equal(sol.cycles[1].t, sol.t[len_cycle_1:])
-        np.testing.assert_array_equal(sol.cycles[1].y, sol.y[:, len_cycle_1:])
+        np.testing.assert_array_equal(sol.cycles[1].t, sol.t[len_cycle_1 - 1 :])
+        np.testing.assert_array_equal(sol.cycles[1].y, sol.y[:, len_cycle_1 - 1 :])
 
     def test_total_time(self):
         sol = pybamm.Solution(np.array([0]), np.array([[1, 2]]), pybamm.BaseModel(), {})
@@ -216,7 +274,7 @@ def test_save(self):
         # test save whole solution
         solution.save("test.pickle")
         solution_load = pybamm.load("test.pickle")
-        self.assertEqual(solution.model.name, solution_load.model.name)
+        self.assertEqual(solution.all_models[0].name, solution_load.all_models[0].name)
         np.testing.assert_array_equal(solution["c"].entries, solution_load["c"].entries)
         np.testing.assert_array_equal(solution["d"].entries, solution_load["d"].entries)