Skip to content

Commit f2b652c

Browse files
committed
Merge remote-tracking branch 'origin/cx_intermediates' into cx_intermediates
2 parents 8f23404 + 2b8364b commit f2b652c

17 files changed

+210
-185
lines changed

bioptim/examples/getting_started/example_external_forces.py

+11-11
Original file line numberDiff line numberDiff line change
@@ -149,45 +149,45 @@ def setup_external_forces(
149149
# Add appropriate forces based on method
150150
if external_force_method == "translational_force":
151151
external_force_set.add_translational_force(
152-
"contact0",
152+
"g",
153153
"Seg1",
154154
Seg1_force,
155155
point_of_application_in_local=Seg1_point_of_application if use_point_of_applications else None,
156156
)
157157
external_force_set.add_translational_force(
158-
"contact0",
158+
"h",
159159
"Test",
160160
Test_force,
161161
point_of_application_in_local=Test_point_of_application if use_point_of_applications else None,
162162
)
163163

164164
elif external_force_method == "translational_force_on_a_marker":
165-
external_force_set.add_translational_force("contact0","Test", Test_force, point_of_application_in_local="m0")
165+
external_force_set.add_translational_force("q","Test", Test_force, point_of_application_in_local="m0")
166166

167167
elif external_force_method == "in_global":
168168
external_force_set.add(
169-
"contact0",
169+
"i",
170170
"Seg1",
171171
np.concatenate((Seg1_force, Seg1_force), axis=0),
172172
point_of_application=Seg1_point_of_application if use_point_of_applications else None,
173173
)
174174
external_force_set.add(
175-
"contact0",
175+
"j",
176176
"Test",
177177
np.concatenate((Test_force, Test_force), axis=0),
178178
point_of_application=Test_point_of_application if use_point_of_applications else None,
179179
)
180180
elif external_force_method == "in_global_torque":
181-
external_force_set.add_torque("contact0", "Seg1", Seg1_force)
182-
external_force_set.add_torque("contact0", "Test", Test_force)
181+
external_force_set.add_torque("k", "Seg1", Seg1_force)
182+
external_force_set.add_torque("l", "Test", Test_force)
183183

184184
elif external_force_method == "in_segment_torque":
185-
external_force_set.add_torque_in_segment_frame("contact0", "Seg1", Seg1_force)
186-
external_force_set.add_torque_in_segment_frame("contact0", "Test", Test_force)
185+
external_force_set.add_torque_in_segment_frame("m", "Seg1", Seg1_force)
186+
external_force_set.add_torque_in_segment_frame("n", "Test", Test_force)
187187

188188
elif external_force_method == "in_segment":
189-
external_force_set.add_in_segment_frame("contact0", "Seg1", np.concatenate((Seg1_force, Seg1_force), axis=0))
190-
external_force_set.add_in_segment_frame("contact0", "Test", np.concatenate((Test_force, Test_force), axis=0))
189+
external_force_set.add_in_segment_frame("o", "Seg1", np.concatenate((Seg1_force, Seg1_force), axis=0))
190+
external_force_set.add_in_segment_frame("p", "Test", np.concatenate((Test_force, Test_force), axis=0))
191191

192192
return external_force_set
193193

bioptim/examples/muscle_driven_with_contact/contact_forces_inverse_dynamics_constraint_muscle.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,12 @@ def custom_configure(ocp: OptimalControlProgram, nlp: NonLinearProgram, numerica
4848

4949
# Implicit variables
5050
ConfigureProblem.configure_rigid_contact_forces(
51-
ocp, nlp, as_states=False, as_algebraic_states=True, as_controls=False
51+
ocp,
52+
nlp,
53+
as_states=False,
54+
as_algebraic_states=True,
55+
as_controls=False,
56+
as_states_dot=False,
5257
)
5358

5459
# Dynamics
@@ -72,7 +77,6 @@ def custom_dynamics(
7277
qdot = nlp.get_var_from_states_or_controls("qdot", states, controls)
7378
residual_tau = nlp.get_var_from_states_or_controls("tau", states, controls)
7479
mus_activations = nlp.get_var_from_states_or_controls("muscles", states, controls)
75-
rigid_contact_forces_derivatives = nlp.get_var_from_states_or_controls()
7680

7781
# Get external forces from the states
7882
rigid_contact_forces = nlp.get_external_forces(
@@ -158,7 +162,7 @@ def prepare_ocp(biorbd_model_path, phase_time, n_shooting, expand_dynamics=True)
158162
node=Node.ALL_SHOOTING,
159163
)
160164
multinode_constraints = MultinodeConstraintList()
161-
for i_node in range(n_shooting):
165+
for i_node in range(n_shooting - 1):
162166
multinode_constraints.add(
163167
MultinodeConstraintFcn.ALGEBRAIC_STATES_CONTINUITY,
164168
nodes_phase=(0, 0),

bioptim/interfaces/interface_utils.py

+30-9
Original file line numberDiff line numberDiff line change
@@ -368,16 +368,22 @@ def _get_weighted_function_inputs(penalty, penalty_idx, ocp, nlp, scaled):
368368

369369
if nlp:
370370
x = PenaltyHelpers.states(
371-
penalty, penalty_idx, lambda p_idx, n_idx, sn_idx: _get_x(ocp, p_idx, n_idx, sn_idx, scaled)
371+
penalty,
372+
penalty_idx,
373+
lambda p_idx, n_idx, sn_idx: _get_x(ocp, p_idx, n_idx, sn_idx, scaled, penalty),
372374
)
373375
u = PenaltyHelpers.controls(
374-
penalty, penalty_idx, lambda p_idx, n_idx, sn_idx: _get_u(ocp, p_idx, n_idx, sn_idx, scaled)
376+
penalty,
377+
penalty_idx,
378+
lambda p_idx, n_idx, sn_idx: _get_u(ocp, p_idx, n_idx, sn_idx, scaled, penalty),
375379
)
376380
p = PenaltyHelpers.parameters(
377381
penalty, penalty_idx, lambda p_idx, n_idx, sn_idx: _get_p(ocp, p_idx, n_idx, sn_idx, scaled)
378382
)
379383
a = PenaltyHelpers.states(
380-
penalty, penalty_idx, lambda p_idx, n_idx, sn_idx: _get_a(ocp, p_idx, n_idx, sn_idx, scaled)
384+
penalty,
385+
penalty_idx,
386+
lambda p_idx, n_idx, sn_idx: _get_a(ocp, p_idx, n_idx, sn_idx, scaled, penalty),
381387
)
382388
d = PenaltyHelpers.numerical_timeseries(
383389
penalty,
@@ -396,7 +402,9 @@ def _get_weighted_function_inputs(penalty, penalty_idx, ocp, nlp, scaled):
396402
return t0, x, u, p, a, d, weight, target
397403

398404

399-
def _get_x(ocp, phase_idx, node_idx, subnodes_idx, scaled):
405+
def _get_x(ocp, phase_idx, node_idx, subnodes_idx, scaled, penalty):
406+
idx = 0 if not penalty.is_multinode_penalty else penalty.nodes_phase.index(phase_idx)
407+
subnodes_are_decision_states = penalty.subnodes_are_decision_states[idx] and not penalty.is_transition
400408
values = ocp.nlp[phase_idx].X_scaled if scaled else ocp.nlp[phase_idx].X
401409
if subnodes_idx.stop == -1:
402410
if subnodes_idx.start == 0:
@@ -407,11 +415,16 @@ def _get_x(ocp, phase_idx, node_idx, subnodes_idx, scaled):
407415
else:
408416
raise RuntimeError("only subnodes_idx.start == 0 is supported for subnodes_idx.stop == -1")
409417
else:
410-
x = values[node_idx][:, subnodes_idx] if node_idx < len(values) else ocp.cx()
418+
if subnodes_are_decision_states:
419+
x = values[node_idx][:, subnodes_idx] if node_idx < len(values) else ocp.cx()
420+
else:
421+
x = values[node_idx][:, 0] if node_idx < len(values) else ocp.cx()
411422
return x
412423

413424

414-
def _get_u(ocp, phase_idx, node_idx, subnodes_idx, scaled):
425+
def _get_u(ocp, phase_idx, node_idx, subnodes_idx, scaled, penalty):
426+
idx = 0 if not penalty.is_multinode_penalty else penalty.nodes_phase.index(phase_idx)
427+
subnodes_are_decision_states = penalty.subnodes_are_decision_states[idx] and not penalty.is_transition
415428
values = ocp.nlp[phase_idx].U_scaled if scaled else ocp.nlp[phase_idx].U
416429
if subnodes_idx.stop == -1:
417430
if subnodes_idx.start == 0:
@@ -422,15 +435,20 @@ def _get_u(ocp, phase_idx, node_idx, subnodes_idx, scaled):
422435
else:
423436
raise RuntimeError("only subnodes_idx.start == 0 is supported for subnodes_idx.stop == -1")
424437
else:
425-
u = values[node_idx][:, subnodes_idx] if node_idx < len(values) else ocp.cx()
438+
if subnodes_are_decision_states:
439+
u = values[node_idx][:, subnodes_idx] if node_idx < len(values) else ocp.cx()
440+
else:
441+
u = values[node_idx][:, 0] if node_idx < len(values) else ocp.cx()
426442
return u
427443

428444

429445
def _get_p(ocp, phase_idx, node_idx, subnodes_idx, scaled):
430446
return ocp.parameters.scaled.cx if scaled else ocp.parameters.scaled
431447

432448

433-
def _get_a(ocp, phase_idx, node_idx, subnodes_idx, scaled):
449+
def _get_a(ocp, phase_idx, node_idx, subnodes_idx, scaled, penalty):
450+
idx = 0 if not penalty.is_multinode_penalty else penalty.nodes_phase.index(phase_idx)
451+
subnodes_are_decision_states = penalty.subnodes_are_decision_states[idx] and not penalty.is_transition
434452
values = ocp.nlp[phase_idx].A_scaled if scaled else ocp.nlp[phase_idx].A
435453
if subnodes_idx.stop == -1:
436454
if subnodes_idx.start == 0:
@@ -441,7 +459,10 @@ def _get_a(ocp, phase_idx, node_idx, subnodes_idx, scaled):
441459
else:
442460
raise RuntimeError("only subnodes_idx.start == 0 is supported for subnodes_idx.stop == -1")
443461
else:
444-
a = values[node_idx][:, subnodes_idx] if node_idx < len(values) else ocp.cx()
462+
if subnodes_are_decision_states:
463+
a = values[node_idx][:, subnodes_idx] if node_idx < len(values) else ocp.cx()
464+
else:
465+
a = values[node_idx][:, 0] if node_idx < len(values) else ocp.cx()
445466
return a
446467

447468

bioptim/limits/multinode_penalty.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,8 @@ def __init__(
6868
if len(nodes) != len(nodes_phase):
6969
raise ValueError("Each of the nodes must have a corresponding nodes_phase")
7070

71-
self.multinode_penalty = True
71+
self.is_multinode_penalty = True
72+
self.is_transition = False
7273

7374
self.nodes_phase = nodes_phase
7475
self.nodes = nodes

bioptim/limits/penalty_helpers.py

+41-17
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def t0(penalty, index, get_t0: Callable):
2929
This method returns the t0 of a penalty.
3030
"""
3131

32-
if penalty.multinode_penalty:
32+
if penalty.is_multinode_penalty:
3333
phases, nodes, _ = _get_multinode_indices(penalty, is_constructing_penalty=False)
3434
phase, node = phases[0], nodes[0]
3535
else:
@@ -68,20 +68,34 @@ def states(penalty, index, get_state_decision: Callable, is_constructing_penalty
6868

6969
node = penalty.node_idx[index]
7070

71-
if penalty.multinode_penalty:
71+
if penalty.is_multinode_penalty:
7272
x = []
7373
phases, nodes, subnodes = _get_multinode_indices(penalty, is_constructing_penalty)
7474
idx = 0
7575
for phase, node, sub in zip(phases, nodes, subnodes):
76-
if not is_constructing_penalty and node == penalty.ns[idx] and (penalty.control_types[idx] != ControlType.LINEAR_CONTINUOUS and penalty.control_types[idx] != ControlType.CONSTANT_WITH_LAST_NODE):
76+
if (
77+
not is_constructing_penalty
78+
and node == penalty.ns[idx]
79+
and (
80+
penalty.control_types[idx] != ControlType.LINEAR_CONTINUOUS
81+
and penalty.control_types[idx] != ControlType.CONSTANT_WITH_LAST_NODE
82+
)
83+
):
7784
x.append(_reshape_to_vector(get_state_decision(phase, node, range(0, 1))))
7885
else:
7986
x.append(_reshape_to_vector(get_state_decision(phase, node, sub)))
8087
idx += 1
8188
return _vertcat(x)
8289

8390
else:
84-
subnodes = slice(0, None if node < penalty.ns[0] and penalty.subnodes_are_decision_states[0] else 1)
91+
subnodes = slice(
92+
0,
93+
(
94+
None
95+
if node < penalty.ns[0] and penalty.subnodes_are_decision_states[0] and not penalty.is_transition
96+
else 1
97+
),
98+
)
8599
x0 = _reshape_to_vector(get_state_decision(penalty.phase, node, subnodes))
86100

87101
if is_constructing_penalty:
@@ -97,12 +111,23 @@ def states(penalty, index, get_state_decision: Callable, is_constructing_penalty
97111
def controls(penalty, index, get_control_decision: Callable, is_constructing_penalty: bool = False):
98112
node = penalty.node_idx[index]
99113

100-
if penalty.multinode_penalty:
114+
if penalty.is_multinode_penalty:
101115
u = []
102116
phases, nodes, subnodes = _get_multinode_indices(penalty, is_constructing_penalty)
117+
idx = 0
103118
for phase, node, sub in zip(phases, nodes, subnodes):
104-
# No need to test for control types as this is never integrated (so we only need the starting value)
105-
u.append(_reshape_to_vector(get_control_decision(phase, node, sub)))
119+
if (
120+
not is_constructing_penalty
121+
and node == penalty.ns[idx]
122+
and (
123+
penalty.control_types[idx] != ControlType.LINEAR_CONTINUOUS
124+
and penalty.control_types[idx] != ControlType.CONSTANT_WITH_LAST_NODE
125+
)
126+
):
127+
u.append(_reshape_to_vector(get_control_decision(phase, node, range(0, 1))))
128+
else:
129+
u.append(_reshape_to_vector(get_control_decision(phase, node, sub)))
130+
idx += 1
106131
return _vertcat(u)
107132

108133
if is_constructing_penalty:
@@ -136,7 +161,7 @@ def parameters(penalty, index, get_parameter_decision: Callable):
136161
@staticmethod
137162
def numerical_timeseries(penalty, index, get_numerical_timeseries: Callable):
138163
node = penalty.node_idx[index]
139-
if penalty.multinode_penalty:
164+
if penalty.is_multinode_penalty:
140165
# numerical timeseries are expected to be provided only at the shooting node.
141166
for i_phase in penalty.nodes_phase:
142167
d = get_numerical_timeseries(i_phase, node, 0) # cx_start
@@ -219,27 +244,26 @@ def get_multinode_penalty_subnodes_starting_index(p):
219244

220245

221246
def _get_multinode_indices(penalty, is_constructing_penalty: bool):
222-
if not penalty.multinode_penalty:
247+
if not penalty.is_multinode_penalty:
223248
raise RuntimeError("This function should only be called for multinode penalties")
224249

225-
if not (all(penalty.subnodes_are_decision_states) or sum(penalty.subnodes_are_decision_states) == 0):
226-
# This check allows to test only for penalty.subnodes_are_decision_states[0] below
227-
raise NotImplementedError(
228-
"All controllers must be of the same type (either all or none should have subnodes_are_decision_states)"
229-
)
230-
231250
phases = penalty.nodes_phase
232251
nodes = penalty.multinode_idx
233252

234253
startings = PenaltyHelpers.get_multinode_penalty_subnodes_starting_index(penalty)
235254
subnodes = []
236255
for i_starting, starting in enumerate(startings):
237-
if starting < 0 or starting == 2: # The last cx accessible
256+
if starting < 0: # The last cx accessible (cx_end)
238257
if is_constructing_penalty:
239258
subnodes.append(slice(-1, None))
240259
else:
241260
subnodes.append(slice(0, 1))
242-
elif penalty.subnodes_are_decision_states[0]:
261+
elif starting == 2: # Also the last cx accessible (cx_end) since there are only 3 cx available
262+
if is_constructing_penalty:
263+
subnodes.append(slice(2, 3))
264+
else:
265+
subnodes.append(slice(0, 1))
266+
elif penalty.subnodes_are_decision_states[i_starting] and not penalty.is_transition:
243267
if nodes[i_starting] >= penalty.ns[i_starting]:
244268
subnodes.append(slice(0, 1))
245269
else:

bioptim/limits/penalty_option.py

+6-5
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,8 @@ def __init__(
196196
self.weighted_function: list[Function | None] = []
197197
self.weighted_function_non_threaded: list[Function | None] = []
198198

199-
self.multinode_penalty = False
199+
self.is_multinode_penalty = False
200+
self.is_transition = False
200201
self.nodes_phase = None # This is relevant for multinodes
201202
self.nodes = None # This is relevant for multinodes
202203
if self.derivative and self.explicit_derivative:
@@ -606,9 +607,9 @@ def _set_penalty_function(self, controllers: list[PenaltyController], fcn: MX |
606607
self.weighted_function[node] = self.weighted_function[node].expand()
607608

608609
def _check_sanity_of_penalty_interactions(self, controller):
609-
if self.multinode_penalty and self.explicit_derivative:
610+
if self.is_multinode_penalty and self.explicit_derivative:
610611
raise ValueError("multinode_penalty and explicit_derivative cannot be true simultaneously")
611-
if self.multinode_penalty and self.derivative:
612+
if self.is_multinode_penalty and self.derivative:
612613
raise ValueError("multinode_penalty and derivative cannot be true simultaneously")
613614
if self.derivative and self.explicit_derivative:
614615
raise ValueError("derivative and explicit_derivative cannot be true simultaneously")
@@ -625,7 +626,7 @@ def _check_sanity_of_penalty_interactions(self, controller):
625626
)
626627

627628
def get_variable_inputs(self, controllers: list[PenaltyController]):
628-
if self.multinode_penalty:
629+
if self.is_multinode_penalty:
629630
controller = controllers[0] # Recast controller as a normal variable (instead of a list)
630631
self.node_idx[0] = controller.node_index
631632

@@ -733,7 +734,7 @@ def vertcat_cx_end():
733734
# performing some kind of integration or derivative and this last node does not exist
734735
if nlp.control_type in (ControlType.CONSTANT_WITH_LAST_NODE,):
735736
return vertcat(u, controls.scaled.cx_end)
736-
if self.integrate or self.derivative or self.explicit_derivative or self.multinode_penalty:
737+
if self.integrate or self.derivative or self.explicit_derivative or self.is_multinode_penalty:
737738
return u
738739
else:
739740
return vertcat(u, controls.scaled.cx_end)

bioptim/limits/phase_transition.py

+1
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ def __init__(
7171
self.bounds = Bounds("phase_transition", interpolation=InterpolationType.CONSTANT)
7272
self.node = Node.TRANSITION
7373
self.quadratic = True
74+
self.is_transition = True
7475

7576
def add_or_replace_to_penalty_pool(self, ocp, nlp):
7677
super(PhaseTransition, self).add_or_replace_to_penalty_pool(ocp, nlp)

bioptim/models/biorbd/external_forces.py

+1-6
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,6 @@
44

55
class ExternalForceSet:
66

7-
@property
8-
def nb_external_forces(self) -> int:
9-
attributes = ["in_global", "torque_in_global", "translational_in_global", "in_local", "torque_in_local"]
10-
return sum([len(values) for attr in attributes for values in getattr(self, attr).values()])
11-
127
@property
138
def nb_external_forces_components(self) -> int:
149
"""Return the number of vertical components of the external forces if concatenated in a unique vector"""
@@ -38,7 +33,7 @@ def check_segment_names(self, segment_names: tuple[str, ...]) -> None:
3833
for attr in attributes:
3934
for force_name , force in getattr(self, attr).items():
4035
if force["segment"] not in segment_names:
41-
wrong_segments.append(force[0]["segment"])
36+
wrong_segments.append(force["segment"])
4237

4338
if wrong_segments:
4439
raise ValueError(

0 commit comments

Comments
 (0)