Skip to content

Commit 1f3bc9d

Browse files
committed
#1477 idaklu sensitivities works and tested for python, casadi and jax
1 parent f10fdfc commit 1f3bc9d

File tree

4 files changed

+121
-103
lines changed

4 files changed

+121
-103
lines changed

pybamm/solvers/c_solvers/idaklu.cpp

+88-74
Original file line numberDiff line numberDiff line change
@@ -11,24 +11,24 @@
1111
#include <pybind11/functional.h>
1212
#include <pybind11/numpy.h>
1313
#include <pybind11/pybind11.h>
14+
#include <pybind11/stl_bind.h>
1415

15-
#include <iostream>
16+
//#include <iostream>
1617
namespace py = pybind11;
1718

18-
using residual_type = std::function<py::array_t<double>(
19-
double, py::array_t<double>, py::array_t<double>)>;
19+
20+
using np_array = py::array_t<realtype>;
21+
PYBIND11_MAKE_OPAQUE(std::vector<np_array>);
22+
using residual_type = std::function<np_array(realtype, np_array, np_array)>;
2023
using sensitivities_type = std::function<void(
21-
py::array_t<realtype>, realtype,
22-
py::array_t<realtype>, py::array_t<realtype>,
23-
py::array_t<realtype>, py::array_t<realtype>
24+
std::vector<np_array>&, realtype, const np_array&,
25+
const np_array&, const std::vector<np_array>&,
26+
const std::vector<np_array>&
2427
)>;
25-
using jacobian_type =
26-
std::function<py::array_t<double>(double, py::array_t<double>, double)>;
27-
28+
using jacobian_type = std::function<np_array(realtype, np_array, realtype)>;
2829

2930
using event_type =
30-
std::function<py::array_t<double>(double, py::array_t<double>)>;
31-
using np_array = py::array_t<double>;
31+
std::function<np_array(realtype, np_array)>;
3232

3333
using jac_get_type = std::function<np_array()>;
3434

@@ -59,14 +59,12 @@ class PybammFunctions
5959
py::array_t<double> operator()(double t, py::array_t<double> y,
6060
py::array_t<double> yp)
6161
{
62-
std::cout << "calling res()" << std::endl;
6362
return py_res(t, y, yp);
6463
}
6564

6665
py::array_t<double> res(double t, py::array_t<double> y,
6766
py::array_t<double> yp)
6867
{
69-
std::cout << "calling res" << std::endl;
7068
return py_res(t, y, yp);
7169
}
7270

@@ -79,10 +77,9 @@ class PybammFunctions
7977
}
8078

8179
void sensitivities(
82-
py::array_t<realtype> resvalS,
83-
double t,
84-
py::array_t<realtype> y, py::array_t<realtype> yp,
85-
py::array_t<realtype> yS, py::array_t<realtype> ypS)
80+
std::vector<np_array>& resvalS,
81+
const double t, const np_array& y, const np_array& yp,
82+
const std::vector<np_array>& yS, const std::vector<np_array>& ypS)
8683
{
8784
// this function evaluates the sensitivity equations required by IDAS,
8885
// returning them in resvalS, which is preallocated as a numpy array
@@ -92,7 +89,6 @@ class PybammFunctions
9289
// yS and ypS are also shape (np, n), y and yp are shape (n)
9390
//
9491
// dF/dy * s_i + dF/dyd * sd + dFdp_i for i in range(np)
95-
std::cout << "calling sensitivity" << std::endl;
9692
py_sens(resvalS, t, y, yp, yS, ypS);
9793
}
9894

@@ -117,7 +113,6 @@ class PybammFunctions
117113
int residual(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr,
118114
void *user_data)
119115
{
120-
std::cout << "calling orignal res" <<std::endl;
121116
PybammFunctions *python_functions_ptr =
122117
static_cast<PybammFunctions *>(user_data);
123118
PybammFunctions python_functions = *python_functions_ptr;
@@ -143,15 +138,13 @@ int residual(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr,
143138
{
144139
rval[i] = r_np_ptr[i];
145140
}
146-
std::cout << "back in original res" <<std::endl;
147141
return 0;
148142
}
149143

150144
int jacobian(realtype tt, realtype cj, N_Vector yy, N_Vector yp,
151145
N_Vector resvec, SUNMatrix JJ, void *user_data, N_Vector tempv1,
152146
N_Vector tempv2, N_Vector tempv3)
153147
{
154-
std::cout << "calling orignal jacobian" <<std::endl;
155148
realtype *yval;
156149
yval = N_VGetArrayPointer(yy);
157150

@@ -259,50 +252,59 @@ int sensitivities(int Ns, realtype t, N_Vector yy, N_Vector yp,
259252
// occurred (in which case idas will attempt to correct),
260253
// or a negative value if it failed unrecoverably (in which case the integration is halted and IDA SRES FAIL is returned)
261254
//
262-
std::cout << "calling orignal sensitivity" <<std::endl;
263255
PybammFunctions *python_functions_ptr =
264256
static_cast<PybammFunctions *>(user_data);
265257
PybammFunctions python_functions = *python_functions_ptr;
266258

267-
realtype *yval = N_VGetArrayPointer(yy);
268-
realtype *ypval = N_VGetArrayPointer(yp);
269-
realtype *ySval = N_VGetArrayPointer(yS[0]);
270-
realtype *ypSval = N_VGetArrayPointer(ypS[0]);
271-
realtype *resvalSval = N_VGetArrayPointer(resvalS[0]);
272-
273259
int n = python_functions.number_of_states;
274260
int np = python_functions.number_of_parameters;
275261

276-
py::array_t<realtype> y_np = py::array_t<realtype>(n, yval);
277-
py::array_t<realtype> yp_np = py::array_t<realtype>(n, ypval);
278-
py::array_t<realtype> yS_np = py::array_t<realtype>(
279-
std::vector<ptrdiff_t>{np, n}, ySval
280-
);
281-
py::array_t<realtype> ypS_np = py::array_t<realtype>(
282-
std::vector<ptrdiff_t>{np, n}, ypSval
283-
);
284-
py::array_t<realtype> resvalS_np = py::array_t<realtype>(
285-
std::vector<ptrdiff_t>{np, n}, resvalSval
286-
);
287-
288-
python_functions.sensitivities(
289-
resvalS_np, t, y_np, yp_np, yS_np, ypS_np
290-
);
262+
// memory managed by sundials, so pass a destructor that does nothing
263+
auto state_vector_shape = std::vector<ptrdiff_t>{n, 1};
264+
np_array y_np = np_array(state_vector_shape, N_VGetArrayPointer(yy),
265+
py::capsule(&yy, [](void* p) {}));
266+
np_array yp_np = np_array(state_vector_shape, N_VGetArrayPointer(yp),
267+
py::capsule(&yp, [](void* p) {}));
268+
269+
std::vector<np_array> yS_np(np);
270+
for (int i = 0; i < np; i++) {
271+
auto capsule = py::capsule(yS + i, [](void* p) {});
272+
yS_np[i] = np_array(state_vector_shape, N_VGetArrayPointer(yS[i]), capsule);
273+
}
274+
275+
std::vector<np_array> ypS_np(np);
276+
for (int i = 0; i < np; i++) {
277+
auto capsule = py::capsule(ypS + i, [](void* p) {});
278+
ypS_np[i] = np_array(state_vector_shape, N_VGetArrayPointer(ypS[i]), capsule);
279+
}
280+
281+
std::vector<np_array> resvalS_np(np);
282+
for (int i = 0; i < np; i++) {
283+
auto capsule = py::capsule(resvalS + i, [](void* p) {});
284+
resvalS_np[i] = np_array(state_vector_shape,
285+
N_VGetArrayPointer(resvalS[i]), capsule);
286+
}
287+
288+
realtype *ptr1 = static_cast<realtype *>(resvalS_np[0].request().ptr);
289+
const realtype* resvalSval = N_VGetArrayPointer(resvalS[0]);
290+
291+
python_functions.sensitivities(resvalS_np, t, y_np, yp_np, yS_np, ypS_np);
291292

292293
return 0;
293294
}
294295

295296
class Solution
296297
{
297298
public:
298-
Solution(int retval, np_array t_np, np_array y_np)
299-
: flag(retval), t(t_np), y(y_np)
299+
Solution(int retval, np_array t_np, np_array y_np, np_array yS_np)
300+
: flag(retval), t(t_np), y(y_np), yS(yS_np)
300301
{
301302
}
302303

303304
int flag;
304305
np_array t;
305306
np_array y;
307+
np_array yS;
306308
};
307309

308310
/* main program */
@@ -324,7 +326,7 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
324326
void *ida_mem; // pointer to memory
325327
N_Vector yy, yp, avtol; // y, y', and absolute tolerance
326328
N_Vector *yyS, *ypS; // y, y' for sensitivities
327-
realtype rtol, *yval, *ypval, *atval;
329+
realtype rtol, *yval, *ypval, *atval, *ySval;
328330
int retval;
329331
SUNMatrix J;
330332
SUNLinearSolver LS;
@@ -341,6 +343,7 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
341343

342344
// set initial value
343345
yval = N_VGetArrayPointer(yy);
346+
ySval = N_VGetArrayPointer(yyS[0]);
344347
ypval = N_VGetArrayPointer(yp);
345348
atval = N_VGetArrayPointer(avtol);
346349
int i;
@@ -351,11 +354,9 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
351354
atval[i] = atol[i];
352355
}
353356

354-
if (number_of_parameters > 0) {
355-
for (int is = 0 ; is < number_of_parameters; is++) {
356-
N_VConst(NZERO, yyS[is]);
357-
N_VConst(NZERO, ypS[is]);
358-
}
357+
for (int is = 0 ; is < number_of_parameters; is++) {
358+
N_VConst(RCONST(0.0), yyS[is]);
359+
N_VConst(RCONST(0.0), ypS[is]);
359360
}
360361

361362
// allocate memory for solver
@@ -393,12 +394,9 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
393394

394395
if (number_of_parameters > 0)
395396
{
396-
std::cout << "running sensitivities with np = " << number_of_parameters << std::endl;
397397
retval = IDASensInit(ida_mem, number_of_parameters,
398398
IDA_SIMULTANEOUS, sensitivities, yyS, ypS);
399-
std::cout << "retval from IDASensInit is " << retval << std::endl;
400399
retval = IDASensEEtolerances(ida_mem);
401-
std::cout << "retval from IDASensEEtolerances is " << retval << std::endl;
402400
}
403401

404402
int t_i = 1;
@@ -409,16 +407,21 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
409407
// set return vectors
410408
std::vector<double> t_return(number_of_timesteps);
411409
std::vector<double> y_return(number_of_timesteps * number_of_states);
410+
std::vector<double> yS_return(number_of_parameters * number_of_timesteps * number_of_states);
412411

413412
t_return[0] = t(0);
414-
int j;
415-
for (j = 0; j < number_of_states; j++)
413+
for (int j = 0; j < number_of_states; j++)
416414
{
417415
y_return[j] = yval[j];
418416
}
417+
for (int j = 0; j < number_of_parameters; j++) {
418+
const int base_index = j * number_of_timesteps * number_of_states;
419+
for (int k = 0; k < number_of_states; k++) {
420+
yS_return[base_index + k] = ySval[j * number_of_states + k];
421+
}
422+
}
419423

420424
// calculate consistent initial conditions
421-
std::cout << "calculating ICs" << std::endl;
422425
N_Vector id;
423426
auto id_np_val = rhs_alg_id.unchecked<1>();
424427
id = N_VNew_Serial(number_of_states);
@@ -433,33 +436,34 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
433436

434437
IDASetId(ida_mem, id);
435438
IDACalcIC(ida_mem, IDA_YA_YDP_INIT, t(1));
436-
std::cout << "finished calculating ICs" << std::endl;
437439

438440
while (true)
439441
{
440442
t_next = t(t_i);
441-
std::cout << "next time step "<<t_next<<std::endl;
442443
IDASetStopTime(ida_mem, t_next);
443444
retval = IDASolve(ida_mem, t_final, &tret, yy, yp, IDA_NORMAL);
444445

445-
if (retval == IDA_TSTOP_RETURN)
446+
if (retval == IDA_TSTOP_RETURN || retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN)
446447
{
448+
IDAGetSens(ida_mem, &tret, yyS);
449+
447450
t_return[t_i] = tret;
448-
for (j = 0; j < number_of_states; j++)
451+
for (int j = 0; j < number_of_states; j++)
449452
{
450453
y_return[t_i * number_of_states + j] = yval[j];
451454
}
455+
for (int j = 0; j < number_of_parameters; j++) {
456+
const int base_index = j * number_of_timesteps * number_of_states
457+
+ t_i * number_of_states;
458+
for (int k = 0; k < number_of_states; k++) {
459+
yS_return[base_index + k] = ySval[j * number_of_states + k];
460+
}
461+
}
452462
t_i += 1;
453-
}
454-
455-
if (retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN)
456-
{
457-
t_return[t_i] = tret;
458-
for (j = 0; j < number_of_states; j++)
459-
{
460-
y_return[t_i * number_of_states + j] = yval[j];
463+
if (retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN) {
464+
break;
461465
}
462-
break;
466+
463467
}
464468
}
465469

@@ -472,12 +476,19 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
472476
SUNMatDestroy(J);
473477
N_VDestroy(avtol);
474478
N_VDestroy(yp);
479+
if (number_of_parameters > 0) {
480+
N_VDestroyVectorArray(yyS, number_of_parameters);
481+
N_VDestroyVectorArray(ypS, number_of_parameters);
482+
}
475483

476-
py::array_t<double> t_ret = py::array_t<double>((t_i + 1), &t_return[0]);
477-
py::array_t<double> y_ret =
478-
py::array_t<double>((t_i + 1) * number_of_states, &y_return[0]);
484+
np_array t_ret = np_array(t_i, &t_return[0]);
485+
np_array y_ret = np_array(t_i * number_of_states, &y_return[0]);
486+
np_array yS_ret = np_array(
487+
std::vector<ptrdiff_t>{number_of_parameters, t_i, number_of_states},
488+
&yS_return[0]
489+
);
479490

480-
Solution sol(retval, t_ret, y_ret);
491+
Solution sol(retval, t_ret, y_ret, yS_ret);
481492

482493
return sol;
483494
}
@@ -486,6 +497,8 @@ PYBIND11_MODULE(idaklu, m)
486497
{
487498
m.doc() = "sundials solvers"; // optional module docstring
488499

500+
py::bind_vector<std::vector<np_array>>(m, "VectorNdArray");
501+
489502
m.def("solve", &solve, "The solve function", py::arg("t"), py::arg("y0"),
490503
py::arg("yp0"), py::arg("res"), py::arg("jac"), py::arg("sens"),
491504
py::arg("get_jac_data"),
@@ -498,5 +511,6 @@ PYBIND11_MODULE(idaklu, m)
498511
py::class_<Solution>(m, "solution")
499512
.def_readwrite("t", &Solution::t)
500513
.def_readwrite("y", &Solution::y)
514+
.def_readwrite("yS", &Solution::yS)
501515
.def_readwrite("flag", &Solution::flag);
502516
}

0 commit comments

Comments
 (0)