Skip to content

Commit 4c6bf59

Browse files
#1100 starting to get SDAEs working with casadi
1 parent d438609 commit 4c6bf59

File tree

5 files changed

+1132
-815
lines changed

5 files changed

+1132
-815
lines changed

pybamm/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ def version(formatted=False):
108108
to_python,
109109
EvaluatorPython,
110110
)
111+
111112
if system() != "Windows":
112113
from .expression_tree.operations.evaluate import EvaluatorJax
113114

pybamm/solvers/base_solver.py

+34-5
Original file line numberDiff line numberDiff line change
@@ -287,19 +287,44 @@ def report(string):
287287
func = func.to_casadi(t_casadi, y_casadi, inputs=p_casadi)
288288
# Add sensitivity vectors to the rhs and algebraic equations
289289
if self.solve_sensitivity_equations is True:
290-
if name == "rhs":
290+
if name == "rhs" and model.len_rhs > 0:
291291
report("Creating sensitivity equations for rhs using CasADi")
292292
df_dx = casadi.jacobian(func, y_diff)
293293
df_dp = casadi.jacobian(func, p_casadi_stacked)
294294
S_x_mat = S_x.reshape(
295-
(model.len_rhs_and_alg, p_casadi_stacked.shape[0])
295+
(model.len_rhs, p_casadi_stacked.shape[0])
296296
)
297297
if model.len_alg == 0:
298298
S_rhs = (df_dx @ S_x_mat + df_dp).reshape((-1, 1))
299299
else:
300300
df_dz = casadi.jacobian(func, y_alg)
301-
S_rhs = df_dx @ S_x_mat + df_dz @ S_z + df_dp
301+
S_z_mat = S_z.reshape(
302+
(model.len_rhs, p_casadi_stacked.shape[0])
303+
)
304+
S_rhs = (df_dx @ S_x_mat + df_dz @ S_z_mat + df_dp).reshape(
305+
(-1, 1)
306+
)
302307
func = casadi.vertcat(func, S_rhs)
308+
if name == "algebraic" and model.len_alg > 0:
309+
report(
310+
"Creating sensitivity equations for algebraic using CasADi"
311+
)
312+
dg_dz = casadi.jacobian(func, y_alg)
313+
dg_dp = casadi.jacobian(func, p_casadi_stacked)
314+
S_z_mat = S_z.reshape(
315+
(model.len_rhs, p_casadi_stacked.shape[0])
316+
)
317+
if model.len_rhs == 0:
318+
S_alg = (dg_dz @ S_z_mat + dg_dp).reshape((-1, 1))
319+
else:
320+
dg_dx = casadi.jacobian(func, y_diff)
321+
S_x_mat = S_x.reshape(
322+
(model.len_rhs, p_casadi_stacked.shape[0])
323+
)
324+
S_alg = (dg_dx @ S_x_mat + dg_dz @ S_z_mat + dg_dp).reshape(
325+
(-1, 1)
326+
)
327+
func = casadi.vertcat(func, S_alg)
303328
elif name == "initial_conditions":
304329
if model.len_rhs == 0 or model.len_alg == 0:
305330
S_0 = casadi.jacobian(func, p_casadi_stacked).reshape(
@@ -309,8 +334,12 @@ def report(string):
309334
else:
310335
x0 = func[: model.len_rhs]
311336
z0 = func[model.len_rhs :]
312-
Sx_0 = casadi.jacobian(x0, p_casadi_stacked)
313-
Sz_0 = casadi.jacobian(z0, p_casadi_stacked)
337+
Sx_0 = casadi.jacobian(x0, p_casadi_stacked).reshape(
338+
(-1, 1)
339+
)
340+
Sz_0 = casadi.jacobian(z0, p_casadi_stacked).reshape(
341+
(-1, 1)
342+
)
314343
func = casadi.vertcat(x0, Sx_0, z0, Sz_0)
315344
if use_jacobian:
316345
report(f"Calculating jacobian for {name} using CasADi")

pybamm/solvers/casadi_algebraic_solver.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,11 @@ def _integrate(self, model, t_eval, inputs=None):
7575
y0_diff = casadi.DM()
7676
y0_alg = y0
7777
else:
78-
len_rhs = model.concatenated_rhs.size
78+
# Check y0 to see if it includes sensitivities
79+
if model.len_rhs_and_alg == y0.shape[0]:
80+
len_rhs = model.len_rhs
81+
else:
82+
len_rhs = model.len_rhs * (inputs.shape[0] + 1)
7983
y0_diff = y0[:len_rhs]
8084
y0_alg = y0[len_rhs:]
8185

pybamm/solvers/solution.py

+12-17
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,8 @@ def __init__(
6969
self.sensitivity = {}
7070
else:
7171
n_states = model.len_rhs_and_alg
72+
n_rhs = model.len_rhs
73+
n_alg = model.len_alg
7274
n_t = len(t)
7375
n_p = np.vstack(list(inputs.values())).size
7476
# Get the point where the algebraic equations start
@@ -97,26 +99,19 @@ def __init__(
9799
# tn_x1_p0, tn_x1_p1, ..., tn_x1_pn
98100
# ...
99101
# tn_xn_p0, tn_xn_p1, ..., tn_xn_pn
100-
# 1. Extract the relevant parts of y
101-
# This makes a (n_states * n_p, n_t) matrix
102-
full_sens_matrix = np.vstack(
103-
[
104-
y[model.len_rhs : len_rhs_and_sens, :],
105-
y[len_rhs_and_sens + model.len_alg :, :],
106-
]
102+
# 1, Extract rhs and alg sensitivities and reshape into 3D matrices
103+
# with shape (n_p, n_states, n_t)
104+
ode_sens = y[n_rhs:len_rhs_and_sens, :].reshape(n_p, n_rhs, n_t)
105+
alg_sens = y[len_rhs_and_sens + n_alg :, :].reshape(n_p, n_alg, n_t)
106+
# 2. Concatenate into a single 3D matrix with shape (n_p, n_states, n_t)
107+
# i.e. along first axis
108+
full_sens_matrix = np.concatenate([ode_sens, alg_sens], axis=1)
109+
# Transpose and reshape into a (n_states * n_t, n_p) matrix
110+
full_sens_matrix = full_sens_matrix.transpose(2, 1, 0).reshape(
111+
n_t * n_states, n_p
107112
)
108-
# 2. Transpose into a (n_t, n_states * n_p) matrix
109-
full_sens_matrix = full_sens_matrix.T
110-
# 3. Reshape into a (n_t, n_p, n_states) matrix,
111-
# then tranpose n_p and n_states to get (n_t, n_states, n_p) matrix
112-
full_sens_matrix = full_sens_matrix.reshape(n_t, n_p, n_states).transpose(
113-
0, 2, 1
114-
)
115-
# 3. Stack time and space to get a (n_t * n_states, n_p) matrix
116-
full_sens_matrix = full_sens_matrix.reshape(n_t * n_states, n_p)
117113

118114
# Save the full sensitivity matrix
119-
120115
sensitivity = {"all": full_sens_matrix}
121116
# also save the sensitivity wrt each parameter (read the columns of the
122117
# sensitivity matrix)

0 commit comments

Comments
 (0)