11
11
#include < pybind11/functional.h>
12
12
#include < pybind11/numpy.h>
13
13
#include < pybind11/pybind11.h>
14
+ #include < pybind11/stl_bind.h>
14
15
15
- #include < iostream>
16
+ // #include <iostream>
16
17
namespace py = pybind11;
17
18
18
- using residual_type = std::function<py::array_t <double >(
19
- double , py::array_t <double >, py::array_t <double >)>;
19
+
20
+ using np_array = py::array_t <realtype>;
21
+ PYBIND11_MAKE_OPAQUE (std::vector<np_array>);
22
+ using residual_type = std::function<np_array(realtype, np_array, np_array)>;
20
23
using sensitivities_type = std::function<void (
21
- py:: array_t <realtype> , realtype,
22
- py:: array_t <realtype>, py:: array_t <realtype> ,
23
- py:: array_t <realtype>, py:: array_t <realtype>
24
+ std::vector<np_array>& , realtype, const np_array& ,
25
+ const np_array&, const std::vector<np_array>& ,
26
+ const std::vector<np_array>&
24
27
)>;
25
- using jacobian_type =
26
- std::function<py::array_t <double >(double , py::array_t <double >, double )>;
27
-
28
+ using jacobian_type = std::function<np_array(realtype, np_array, realtype)>;
28
29
29
30
using event_type =
30
- std::function<py::array_t <double >(double , py::array_t <double >)>;
31
- using np_array = py::array_t <double >;
31
+ std::function<np_array(realtype, np_array)>;
32
32
33
33
using jac_get_type = std::function<np_array()>;
34
34
@@ -59,14 +59,12 @@ class PybammFunctions
59
59
py::array_t <double > operator ()(double t, py::array_t <double > y,
60
60
py::array_t <double > yp)
61
61
{
62
- std::cout << " calling res()" << std::endl;
63
62
return py_res (t, y, yp);
64
63
}
65
64
66
65
py::array_t <double > res (double t, py::array_t <double > y,
67
66
py::array_t <double > yp)
68
67
{
69
- std::cout << " calling res" << std::endl;
70
68
return py_res (t, y, yp);
71
69
}
72
70
@@ -79,10 +77,9 @@ class PybammFunctions
79
77
}
80
78
81
79
void sensitivities (
82
- py::array_t <realtype> resvalS,
83
- double t,
84
- py::array_t <realtype> y, py::array_t <realtype> yp,
85
- py::array_t <realtype> yS, py::array_t <realtype> ypS)
80
+ std::vector<np_array>& resvalS,
81
+ const double t, const np_array& y, const np_array& yp,
82
+ const std::vector<np_array>& yS, const std::vector<np_array>& ypS)
86
83
{
87
84
// this function evaluates the sensitivity equations required by IDAS,
88
85
// returning them in resvalS, which is preallocated as a numpy array
@@ -92,7 +89,6 @@ class PybammFunctions
92
89
// yS and ypS are also shape (np, n), y and yp are shape (n)
93
90
//
94
91
// dF/dy * s_i + dF/dyd * sd + dFdp_i for i in range(np)
95
- std::cout << " calling sensitivity" << std::endl;
96
92
py_sens (resvalS, t, y, yp, yS, ypS);
97
93
}
98
94
@@ -117,7 +113,6 @@ class PybammFunctions
117
113
int residual (realtype tres, N_Vector yy, N_Vector yp, N_Vector rr,
118
114
void *user_data)
119
115
{
120
- std::cout << " calling orignal res" <<std::endl;
121
116
PybammFunctions *python_functions_ptr =
122
117
static_cast <PybammFunctions *>(user_data);
123
118
PybammFunctions python_functions = *python_functions_ptr;
@@ -143,15 +138,13 @@ int residual(realtype tres, N_Vector yy, N_Vector yp, N_Vector rr,
143
138
{
144
139
rval[i] = r_np_ptr[i];
145
140
}
146
- std::cout << " back in original res" <<std::endl;
147
141
return 0 ;
148
142
}
149
143
150
144
int jacobian (realtype tt, realtype cj, N_Vector yy, N_Vector yp,
151
145
N_Vector resvec, SUNMatrix JJ, void *user_data, N_Vector tempv1,
152
146
N_Vector tempv2, N_Vector tempv3)
153
147
{
154
- std::cout << " calling orignal jacobian" <<std::endl;
155
148
realtype *yval;
156
149
yval = N_VGetArrayPointer (yy);
157
150
@@ -259,50 +252,59 @@ int sensitivities(int Ns, realtype t, N_Vector yy, N_Vector yp,
259
252
// occurred (in which case idas will attempt to correct),
260
253
// or a negative value if it failed unrecoverably (in which case the integration is halted and IDA SRES FAIL is returned)
261
254
//
262
- std::cout << " calling orignal sensitivity" <<std::endl;
263
255
PybammFunctions *python_functions_ptr =
264
256
static_cast <PybammFunctions *>(user_data);
265
257
PybammFunctions python_functions = *python_functions_ptr;
266
258
267
- realtype *yval = N_VGetArrayPointer (yy);
268
- realtype *ypval = N_VGetArrayPointer (yp);
269
- realtype *ySval = N_VGetArrayPointer (yS[0 ]);
270
- realtype *ypSval = N_VGetArrayPointer (ypS[0 ]);
271
- realtype *resvalSval = N_VGetArrayPointer (resvalS[0 ]);
272
-
273
259
int n = python_functions.number_of_states ;
274
260
int np = python_functions.number_of_parameters ;
275
261
276
- py::array_t <realtype> y_np = py::array_t <realtype>(n, yval);
277
- py::array_t <realtype> yp_np = py::array_t <realtype>(n, ypval);
278
- py::array_t <realtype> yS_np = py::array_t <realtype>(
279
- std::vector<ptrdiff_t >{np, n}, ySval
280
- );
281
- py::array_t <realtype> ypS_np = py::array_t <realtype>(
282
- std::vector<ptrdiff_t >{np, n}, ypSval
283
- );
284
- py::array_t <realtype> resvalS_np = py::array_t <realtype>(
285
- std::vector<ptrdiff_t >{np, n}, resvalSval
286
- );
287
-
288
- python_functions.sensitivities (
289
- resvalS_np, t, y_np, yp_np, yS_np, ypS_np
290
- );
262
+ // memory managed by sundials, so pass a destructor that does nothing
263
+ auto state_vector_shape = std::vector<ptrdiff_t >{n, 1 };
264
+ np_array y_np = np_array (state_vector_shape, N_VGetArrayPointer (yy),
265
+ py::capsule (&yy, [](void * p) {}));
266
+ np_array yp_np = np_array (state_vector_shape, N_VGetArrayPointer (yp),
267
+ py::capsule (&yp, [](void * p) {}));
268
+
269
+ std::vector<np_array> yS_np (np);
270
+ for (int i = 0 ; i < np; i++) {
271
+ auto capsule = py::capsule (yS + i, [](void * p) {});
272
+ yS_np[i] = np_array (state_vector_shape, N_VGetArrayPointer (yS[i]), capsule);
273
+ }
274
+
275
+ std::vector<np_array> ypS_np (np);
276
+ for (int i = 0 ; i < np; i++) {
277
+ auto capsule = py::capsule (ypS + i, [](void * p) {});
278
+ ypS_np[i] = np_array (state_vector_shape, N_VGetArrayPointer (ypS[i]), capsule);
279
+ }
280
+
281
+ std::vector<np_array> resvalS_np (np);
282
+ for (int i = 0 ; i < np; i++) {
283
+ auto capsule = py::capsule (resvalS + i, [](void * p) {});
284
+ resvalS_np[i] = np_array (state_vector_shape,
285
+ N_VGetArrayPointer (resvalS[i]), capsule);
286
+ }
287
+
288
+ realtype *ptr1 = static_cast <realtype *>(resvalS_np[0 ].request ().ptr );
289
+ const realtype* resvalSval = N_VGetArrayPointer (resvalS[0 ]);
290
+
291
+ python_functions.sensitivities (resvalS_np, t, y_np, yp_np, yS_np, ypS_np);
291
292
292
293
return 0 ;
293
294
}
294
295
295
296
class Solution
296
297
{
297
298
public:
298
- Solution (int retval, np_array t_np, np_array y_np)
299
- : flag(retval), t(t_np), y(y_np)
299
+ Solution (int retval, np_array t_np, np_array y_np, np_array yS_np )
300
+ : flag(retval), t(t_np), y(y_np), yS(yS_np)
300
301
{
301
302
}
302
303
303
304
int flag;
304
305
np_array t;
305
306
np_array y;
307
+ np_array yS;
306
308
};
307
309
308
310
/* main program */
@@ -324,7 +326,7 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
324
326
void *ida_mem; // pointer to memory
325
327
N_Vector yy, yp, avtol; // y, y', and absolute tolerance
326
328
N_Vector *yyS, *ypS; // y, y' for sensitivities
327
- realtype rtol, *yval, *ypval, *atval;
329
+ realtype rtol, *yval, *ypval, *atval, *ySval ;
328
330
int retval;
329
331
SUNMatrix J;
330
332
SUNLinearSolver LS;
@@ -341,6 +343,7 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
341
343
342
344
// set initial value
343
345
yval = N_VGetArrayPointer (yy);
346
+ ySval = N_VGetArrayPointer (yyS[0 ]);
344
347
ypval = N_VGetArrayPointer (yp);
345
348
atval = N_VGetArrayPointer (avtol);
346
349
int i;
@@ -351,11 +354,9 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
351
354
atval[i] = atol [i];
352
355
}
353
356
354
- if (number_of_parameters > 0 ) {
355
- for (int is = 0 ; is < number_of_parameters; is++) {
356
- N_VConst (NZERO, yyS[is]);
357
- N_VConst (NZERO, ypS[is]);
358
- }
357
+ for (int is = 0 ; is < number_of_parameters; is++) {
358
+ N_VConst (RCONST (0.0 ), yyS[is]);
359
+ N_VConst (RCONST (0.0 ), ypS[is]);
359
360
}
360
361
361
362
// allocate memory for solver
@@ -393,12 +394,9 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
393
394
394
395
if (number_of_parameters > 0 )
395
396
{
396
- std::cout << " running sensitivities with np = " << number_of_parameters << std::endl;
397
397
retval = IDASensInit (ida_mem, number_of_parameters,
398
398
IDA_SIMULTANEOUS, sensitivities, yyS, ypS);
399
- std::cout << " retval from IDASensInit is " << retval << std::endl;
400
399
retval = IDASensEEtolerances (ida_mem);
401
- std::cout << " retval from IDASensEEtolerances is " << retval << std::endl;
402
400
}
403
401
404
402
int t_i = 1 ;
@@ -409,16 +407,21 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
409
407
// set return vectors
410
408
std::vector<double > t_return (number_of_timesteps);
411
409
std::vector<double > y_return (number_of_timesteps * number_of_states);
410
+ std::vector<double > yS_return (number_of_parameters * number_of_timesteps * number_of_states);
412
411
413
412
t_return[0 ] = t (0 );
414
- int j;
415
- for (j = 0 ; j < number_of_states; j++)
413
+ for (int j = 0 ; j < number_of_states; j++)
416
414
{
417
415
y_return[j] = yval[j];
418
416
}
417
+ for (int j = 0 ; j < number_of_parameters; j++) {
418
+ const int base_index = j * number_of_timesteps * number_of_states;
419
+ for (int k = 0 ; k < number_of_states; k++) {
420
+ yS_return[base_index + k] = ySval[j * number_of_states + k];
421
+ }
422
+ }
419
423
420
424
// calculate consistent initial conditions
421
- std::cout << " calculating ICs" << std::endl;
422
425
N_Vector id;
423
426
auto id_np_val = rhs_alg_id.unchecked <1 >();
424
427
id = N_VNew_Serial (number_of_states);
@@ -433,33 +436,34 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
433
436
434
437
IDASetId (ida_mem, id);
435
438
IDACalcIC (ida_mem, IDA_YA_YDP_INIT, t (1 ));
436
- std::cout << " finished calculating ICs" << std::endl;
437
439
438
440
while (true )
439
441
{
440
442
t_next = t (t_i);
441
- std::cout << " next time step " <<t_next<<std::endl;
442
443
IDASetStopTime (ida_mem, t_next);
443
444
retval = IDASolve (ida_mem, t_final, &tret, yy, yp, IDA_NORMAL);
444
445
445
- if (retval == IDA_TSTOP_RETURN)
446
+ if (retval == IDA_TSTOP_RETURN || retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN )
446
447
{
448
+ IDAGetSens (ida_mem, &tret, yyS);
449
+
447
450
t_return[t_i] = tret;
448
- for (j = 0 ; j < number_of_states; j++)
451
+ for (int j = 0 ; j < number_of_states; j++)
449
452
{
450
453
y_return[t_i * number_of_states + j] = yval[j];
451
454
}
455
+ for (int j = 0 ; j < number_of_parameters; j++) {
456
+ const int base_index = j * number_of_timesteps * number_of_states
457
+ + t_i * number_of_states;
458
+ for (int k = 0 ; k < number_of_states; k++) {
459
+ yS_return[base_index + k] = ySval[j * number_of_states + k];
460
+ }
461
+ }
452
462
t_i += 1 ;
453
- }
454
-
455
- if (retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN)
456
- {
457
- t_return[t_i] = tret;
458
- for (j = 0 ; j < number_of_states; j++)
459
- {
460
- y_return[t_i * number_of_states + j] = yval[j];
463
+ if (retval == IDA_SUCCESS || retval == IDA_ROOT_RETURN) {
464
+ break ;
461
465
}
462
- break ;
466
+
463
467
}
464
468
}
465
469
@@ -472,12 +476,19 @@ Solution solve(np_array t_np, np_array y0_np, np_array yp0_np,
472
476
SUNMatDestroy (J);
473
477
N_VDestroy (avtol);
474
478
N_VDestroy (yp);
479
+ if (number_of_parameters > 0 ) {
480
+ N_VDestroyVectorArray (yyS, number_of_parameters);
481
+ N_VDestroyVectorArray (ypS, number_of_parameters);
482
+ }
475
483
476
- py::array_t <double > t_ret = py::array_t <double >((t_i + 1 ), &t_return[0 ]);
477
- py::array_t <double > y_ret =
478
- py::array_t <double >((t_i + 1 ) * number_of_states, &y_return[0 ]);
484
+ np_array t_ret = np_array (t_i, &t_return[0 ]);
485
+ np_array y_ret = np_array (t_i * number_of_states, &y_return[0 ]);
486
+ np_array yS_ret = np_array (
487
+ std::vector<ptrdiff_t >{number_of_parameters, t_i, number_of_states},
488
+ &yS_return[0 ]
489
+ );
479
490
480
- Solution sol (retval, t_ret, y_ret);
491
+ Solution sol (retval, t_ret, y_ret, yS_ret );
481
492
482
493
return sol;
483
494
}
@@ -486,6 +497,8 @@ PYBIND11_MODULE(idaklu, m)
486
497
{
487
498
m.doc () = " sundials solvers" ; // optional module docstring
488
499
500
+ py::bind_vector<std::vector<np_array>>(m, " VectorNdArray" );
501
+
489
502
m.def (" solve" , &solve, " The solve function" , py::arg (" t" ), py::arg (" y0" ),
490
503
py::arg (" yp0" ), py::arg (" res" ), py::arg (" jac" ), py::arg (" sens" ),
491
504
py::arg (" get_jac_data" ),
@@ -498,5 +511,6 @@ PYBIND11_MODULE(idaklu, m)
498
511
py::class_<Solution>(m, " solution" )
499
512
.def_readwrite (" t" , &Solution::t)
500
513
.def_readwrite (" y" , &Solution::y)
514
+ .def_readwrite (" yS" , &Solution::yS)
501
515
.def_readwrite (" flag" , &Solution::flag);
502
516
}
0 commit comments