From da70bd13ed8fde2d65c528575a052893118b03d2 Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Fri, 7 Feb 2025 14:16:29 +0100 Subject: [PATCH] [python/dynamics] Fix bug making 'Trajectory.get' extremely inefficient. --- .../common/gym_jiminy/common/utils/spaces.py | 3 -- python/jiminy_py/src/jiminy_py/dynamics.py | 36 +++++++++---------- 2 files changed, 18 insertions(+), 21 deletions(-) diff --git a/python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py b/python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py index 325629da2..ae2b0a840 100644 --- a/python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py +++ b/python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py @@ -52,9 +52,6 @@ def _array_clip(value: np.ndarray, :param low: Optional lower bound. :param high: Optional upper bound. """ - # Note that in-place clipping is actually slower than out-of-place in - # Numba when 'fastmath' compilation flag is set. - # Short circuit if there is neither low or high bounds if low is None and high is None: return value.copy() diff --git a/python/jiminy_py/src/jiminy_py/dynamics.py b/python/jiminy_py/src/jiminy_py/dynamics.py index bcf8083d0..0813916ad 100644 --- a/python/jiminy_py/src/jiminy_py/dynamics.py +++ b/python/jiminy_py/src/jiminy_py/dynamics.py @@ -215,7 +215,7 @@ def __init__(self, self._t_prev = 0.0 self._index_prev = 1 - # List of optional state fields that are provided + # List of optional state fields that have been specified. # Note that looking for keys in such a small set is not worth the # hassle of using Python `set`, which breaks ordering and index access. fields_: List[str] = [] @@ -227,7 +227,7 @@ def __init__(self, raise ValueError( "The state information being set must be the same " "for all the timesteps of a given trajectory.") - else: + elif field not in fields_: fields_.append(field) self._fields = tuple(fields_) @@ -286,10 +286,11 @@ def time_interval(self) -> Tuple[float, float]: It raises an exception if no data is available. """ - if not self.has_data: + try: + return (self._times[0], self._times[-1]) + except IndexError as e: raise RuntimeError( - "State sequence is empty. Time interval undefined.") - return (self._times[0], self._times[-1]) + "State sequence is empty. Time interval undefined.") from e def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State: """Query the state at a given timestamp. @@ -311,20 +312,19 @@ def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State: """ # pylint: disable=possibly-used-before-assignment - # Raise exception if state sequence is empty - if not self.has_data: - raise RuntimeError( - "State sequence is empty. Impossible to interpolate data.") - # Backup the original query time t_orig = t # Handling of the desired mode n_steps = 0.0 - t_start, t_end = self.time_interval + try: + t_start, t_end = self._times[0], self._times[-1] + except IndexError as e: + raise RuntimeError( + "State sequence is empty. Impossible to interpolate data.") if mode == "raise": if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL: - raise RuntimeError("Time is out-of-range.") + raise RuntimeError("Query time out-of-range.") elif mode == "wrap": if t_end > t_start: n_steps, t_rel = divmod(t - t_start, t_end - t_start) @@ -365,17 +365,17 @@ def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State: else: position = pin.interpolate( self._pinocchio_model, s_left.q, s_right.q, alpha) - data = {"q": position} + state = dict(t=t_orig, q=position) for field in self._fields: value_left = getattr(s_left, field) if return_left: - data[field] = value_left.copy() + state[field] = value_left.copy() continue value_right = getattr(s_right, field) if return_right: - data[field] = value_right.copy() + state[field] = value_right.copy() else: - data[field] = value_left + alpha * (value_right - value_left) + state[field] = value_left + alpha * (value_right - value_left) # Perform odometry if time is wrapping if self._stride_offset_log6 is not None and n_steps: @@ -383,8 +383,8 @@ def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State: ff_xyzquat = stride_offset * pin.XYZQUATToSE3(position[:7]) position[:7] = pin.SE3ToXYZQUAT(ff_xyzquat) - # Return state instances bundling all data - return State(t=t_orig, **data) + # Return State object + return State(**state) # #####################################################################