Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python/dynamics] Fix bug making 'Trajectory.get' extremely inefficient. #881

Merged
merged 1 commit into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions python/gym_jiminy/common/gym_jiminy/common/utils/spaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,6 @@ def _array_clip(value: np.ndarray,
:param low: Optional lower bound.
:param high: Optional upper bound.
"""
# Note that in-place clipping is actually slower than out-of-place in
# Numba when 'fastmath' compilation flag is set.

# Short circuit if there is neither low or high bounds
if low is None and high is None:
return value.copy()
Expand Down
39 changes: 20 additions & 19 deletions python/jiminy_py/src/jiminy_py/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
import logging
from bisect import bisect_left
from dataclasses import dataclass, fields
from typing import List, Union, Optional, Tuple, Sequence, Callable, Literal
from typing import (
List, Union, Optional, Tuple, Sequence, Callable, Dict, Literal)

import numpy as np

Expand Down Expand Up @@ -215,7 +216,7 @@ def __init__(self,
self._t_prev = 0.0
self._index_prev = 1

# List of optional state fields that are provided
# List of optional state fields that have been specified.
# Note that looking for keys in such a small set is not worth the
# hassle of using Python `set`, which breaks ordering and index access.
fields_: List[str] = []
Expand All @@ -227,7 +228,7 @@ def __init__(self,
raise ValueError(
"The state information being set must be the same "
"for all the timesteps of a given trajectory.")
else:
elif field not in fields_:
fields_.append(field)
self._fields = tuple(fields_)

Expand Down Expand Up @@ -286,10 +287,11 @@ def time_interval(self) -> Tuple[float, float]:

It raises an exception if no data is available.
"""
if not self.has_data:
raise RuntimeError(
try:
return (self._times[0], self._times[-1])
except IndexError:
raise RuntimeError( # pylint: disable=raise-missing-from
"State sequence is empty. Time interval undefined.")
return (self._times[0], self._times[-1])

def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
"""Query the state at a given timestamp.
Expand All @@ -311,20 +313,19 @@ def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
"""
# pylint: disable=possibly-used-before-assignment

# Raise exception if state sequence is empty
if not self.has_data:
raise RuntimeError(
"State sequence is empty. Impossible to interpolate data.")

# Backup the original query time
t_orig = t

# Handling of the desired mode
n_steps = 0.0
t_start, t_end = self.time_interval
try:
t_start, t_end = self._times[0], self._times[-1]
except IndexError:
raise RuntimeError( # pylint: disable=raise-missing-from
"State sequence is empty. Impossible to interpolate data.")
if mode == "raise":
if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL:
raise RuntimeError("Time is out-of-range.")
raise RuntimeError("Query time out-of-range.")
elif mode == "wrap":
if t_end > t_start:
n_steps, t_rel = divmod(t - t_start, t_end - t_start)
Expand Down Expand Up @@ -365,26 +366,26 @@ def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
else:
position = pin.interpolate(
self._pinocchio_model, s_left.q, s_right.q, alpha)
data = {"q": position}
state: Dict[str, Union[float, np.ndarray]] = dict(t=t_orig, q=position)
for field in self._fields:
value_left = getattr(s_left, field)
if return_left:
data[field] = value_left.copy()
state[field] = value_left.copy()
continue
value_right = getattr(s_right, field)
if return_right:
data[field] = value_right.copy()
state[field] = value_right.copy()
else:
data[field] = value_left + alpha * (value_right - value_left)
state[field] = value_left + alpha * (value_right - value_left)

# Perform odometry if time is wrapping
if self._stride_offset_log6 is not None and n_steps:
stride_offset = pin.exp6(n_steps * self._stride_offset_log6)
ff_xyzquat = stride_offset * pin.XYZQUATToSE3(position[:7])
position[:7] = pin.SE3ToXYZQUAT(ff_xyzquat)

# Return state instances bundling all data
return State(t=t_orig, **data)
# Return a State object
return State(**state) # type: ignore[arg-type]


# #####################################################################
Expand Down
Loading