@@ -368,16 +368,22 @@ def _get_weighted_function_inputs(penalty, penalty_idx, ocp, nlp, scaled):
368
368
369
369
if nlp :
370
370
x = PenaltyHelpers .states (
371
- penalty , penalty_idx , lambda p_idx , n_idx , sn_idx : _get_x (ocp , p_idx , n_idx , sn_idx , scaled )
371
+ penalty ,
372
+ penalty_idx ,
373
+ lambda p_idx , n_idx , sn_idx : _get_x (ocp , p_idx , n_idx , sn_idx , scaled , penalty ),
372
374
)
373
375
u = PenaltyHelpers .controls (
374
- penalty , penalty_idx , lambda p_idx , n_idx , sn_idx : _get_u (ocp , p_idx , n_idx , sn_idx , scaled )
376
+ penalty ,
377
+ penalty_idx ,
378
+ lambda p_idx , n_idx , sn_idx : _get_u (ocp , p_idx , n_idx , sn_idx , scaled , penalty ),
375
379
)
376
380
p = PenaltyHelpers .parameters (
377
381
penalty , penalty_idx , lambda p_idx , n_idx , sn_idx : _get_p (ocp , p_idx , n_idx , sn_idx , scaled )
378
382
)
379
383
a = PenaltyHelpers .states (
380
- penalty , penalty_idx , lambda p_idx , n_idx , sn_idx : _get_a (ocp , p_idx , n_idx , sn_idx , scaled )
384
+ penalty ,
385
+ penalty_idx ,
386
+ lambda p_idx , n_idx , sn_idx : _get_a (ocp , p_idx , n_idx , sn_idx , scaled , penalty ),
381
387
)
382
388
d = PenaltyHelpers .numerical_timeseries (
383
389
penalty ,
@@ -396,7 +402,9 @@ def _get_weighted_function_inputs(penalty, penalty_idx, ocp, nlp, scaled):
396
402
return t0 , x , u , p , a , d , weight , target
397
403
398
404
399
- def _get_x (ocp , phase_idx , node_idx , subnodes_idx , scaled ):
405
+ def _get_x (ocp , phase_idx , node_idx , subnodes_idx , scaled , penalty ):
406
+ idx = 0 if not penalty .is_multinode_penalty else penalty .nodes_phase .index (phase_idx )
407
+ subnodes_are_decision_states = penalty .subnodes_are_decision_states [idx ] and not penalty .is_transition
400
408
values = ocp .nlp [phase_idx ].X_scaled if scaled else ocp .nlp [phase_idx ].X
401
409
if subnodes_idx .stop == - 1 :
402
410
if subnodes_idx .start == 0 :
@@ -407,11 +415,16 @@ def _get_x(ocp, phase_idx, node_idx, subnodes_idx, scaled):
407
415
else :
408
416
raise RuntimeError ("only subnodes_idx.start == 0 is supported for subnodes_idx.stop == -1" )
409
417
else :
410
- x = values [node_idx ][:, subnodes_idx ] if node_idx < len (values ) else ocp .cx ()
418
+ if subnodes_are_decision_states :
419
+ x = values [node_idx ][:, subnodes_idx ] if node_idx < len (values ) else ocp .cx ()
420
+ else :
421
+ x = values [node_idx ][:, 0 ] if node_idx < len (values ) else ocp .cx ()
411
422
return x
412
423
413
424
414
- def _get_u (ocp , phase_idx , node_idx , subnodes_idx , scaled ):
425
+ def _get_u (ocp , phase_idx , node_idx , subnodes_idx , scaled , penalty ):
426
+ idx = 0 if not penalty .is_multinode_penalty else penalty .nodes_phase .index (phase_idx )
427
+ subnodes_are_decision_states = penalty .subnodes_are_decision_states [idx ] and not penalty .is_transition
415
428
values = ocp .nlp [phase_idx ].U_scaled if scaled else ocp .nlp [phase_idx ].U
416
429
if subnodes_idx .stop == - 1 :
417
430
if subnodes_idx .start == 0 :
@@ -422,15 +435,20 @@ def _get_u(ocp, phase_idx, node_idx, subnodes_idx, scaled):
422
435
else :
423
436
raise RuntimeError ("only subnodes_idx.start == 0 is supported for subnodes_idx.stop == -1" )
424
437
else :
425
- u = values [node_idx ][:, subnodes_idx ] if node_idx < len (values ) else ocp .cx ()
438
+ if subnodes_are_decision_states :
439
+ u = values [node_idx ][:, subnodes_idx ] if node_idx < len (values ) else ocp .cx ()
440
+ else :
441
+ u = values [node_idx ][:, 0 ] if node_idx < len (values ) else ocp .cx ()
426
442
return u
427
443
428
444
429
445
def _get_p (ocp , phase_idx , node_idx , subnodes_idx , scaled ):
430
446
return ocp .parameters .scaled .cx if scaled else ocp .parameters .scaled
431
447
432
448
433
- def _get_a (ocp , phase_idx , node_idx , subnodes_idx , scaled ):
449
+ def _get_a (ocp , phase_idx , node_idx , subnodes_idx , scaled , penalty ):
450
+ idx = 0 if not penalty .is_multinode_penalty else penalty .nodes_phase .index (phase_idx )
451
+ subnodes_are_decision_states = penalty .subnodes_are_decision_states [idx ] and not penalty .is_transition
434
452
values = ocp .nlp [phase_idx ].A_scaled if scaled else ocp .nlp [phase_idx ].A
435
453
if subnodes_idx .stop == - 1 :
436
454
if subnodes_idx .start == 0 :
@@ -441,7 +459,10 @@ def _get_a(ocp, phase_idx, node_idx, subnodes_idx, scaled):
441
459
else :
442
460
raise RuntimeError ("only subnodes_idx.start == 0 is supported for subnodes_idx.stop == -1" )
443
461
else :
444
- a = values [node_idx ][:, subnodes_idx ] if node_idx < len (values ) else ocp .cx ()
462
+ if subnodes_are_decision_states :
463
+ a = values [node_idx ][:, subnodes_idx ] if node_idx < len (values ) else ocp .cx ()
464
+ else :
465
+ a = values [node_idx ][:, 0 ] if node_idx < len (values ) else ocp .cx ()
445
466
return a
446
467
447
468
0 commit comments