[gym_jiminy/common] Fix cache auto-refresh edge-case. (#880)
* [gym_jiminy/common] Fix cache auto-refresh edge-case.
* [gym_jiminy/common] Fix stack length computation from horizon.
* [gym_jiminy/common] Refactor 'DeltaQuantity' to improve performance for long horizon.
* [misc] Fix RL tutorial notebook.
duburcqa authored Feb 6, 2025
1 parent 70939e4 commit ec07134
Showing 13 changed files with 144 additions and 144 deletions.
17 changes: 13 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py
@@ -59,10 +59,19 @@ def __new__(cls,
self = super(partial_hashable, cls).__new__(cls, func, *args, **kwargs)

# Pre-compute normalized arguments once and for all
sig = inspect.signature(self.func)
bound = sig.bind_partial(*self.args, **(self.keywords or {}))
bound.apply_defaults()
self._normalized_args = tuple(bound.arguments.values())
try:
sig = inspect.signature(self.func)
bound = sig.bind_partial(*self.args, **(self.keywords or {}))
bound.apply_defaults()
self._normalized_args = tuple(bound.arguments.values())
except ValueError as e:
# Impossible to get the signature from Python bindings. Keyword-only
# arguments are enforced to ensure that the equality check can still be
# implemented.
if self.args:
raise ValueError(
"Specifying position arguments is not supported for "
"methods whose signature cannot be inspected.") from e
_, self._normalized_args = zip(*sorted(self.keywords.items()))

return self
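The normalization above is what lets two partials compare equal regardless of whether their arguments were passed positionally or by keyword; when the signature cannot be retrieved, sorting the keyword items provides a deterministic fallback. A minimal standalone sketch of the same idea (only the normalization logic mirrors the hunk above; `normalized_args` and `f` are illustrative):

import inspect
from functools import partial

def normalized_args(p: partial) -> tuple:
    """Best-effort canonical form of the arguments bound by a partial."""
    try:
        sig = inspect.signature(p.func)
        bound = sig.bind_partial(*p.args, **(p.keywords or {}))
        bound.apply_defaults()
        return tuple(bound.arguments.values())
    except ValueError:
        # Signature unavailable (e.g. some C bindings): fall back to
        # keyword-only arguments sorted by name for a stable ordering.
        if p.args:
            raise
        return tuple(value for _, value in sorted(p.keywords.items()))

def f(x, y=0):
    return x + y

# Both partials normalize to (1, 2), so an equality check based on
# `normalized_args` treats them as equivalent.
assert normalized_args(partial(f, 1, y=2)) == normalized_args(partial(f, x=1, y=2))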

@@ -6,7 +6,7 @@
from collections import OrderedDict
from typing import (
Dict, Any, Tuple, List, TypeVar, Generic, TypedDict, Optional, Callable,
Mapping, no_type_check, TYPE_CHECKING)
Mapping, SupportsFloat, no_type_check, TYPE_CHECKING)

import numpy as np
import numpy.typing as npt
@@ -39,7 +39,7 @@
InfoType = Dict[str, Any]

PolicyCallbackFun = Callable[
[Obs, Optional[Act], Optional[float], bool, bool, InfoType], Act]
[Obs, Optional[Act], Optional[SupportsFloat], bool, bool, InfoType], Act]


class EngineObsType(TypedDict):
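For illustration, a callback matching the updated `PolicyCallbackFun` alias now receives a reward that may be any `SupportsFloat` (e.g. a NumPy scalar) rather than a plain `float`. A hedged sketch with a hypothetical zero-action policy (the observation type and the 3-dimensional action are purely illustrative):

from typing import Any, Dict, Optional, SupportsFloat

import numpy as np

def zero_policy(obs: Any,
                action: Optional[np.ndarray],
                reward: Optional[SupportsFloat],
                terminated: bool,
                truncated: bool,
                info: Dict[str, Any]) -> np.ndarray:
    """Ignore the inputs and always return a zero action."""
    if reward is not None:
        # Convert only where a true float is actually needed, e.g. logging.
        print("last reward:", float(reward))
    return np.zeros(3)

# Such a callback could then be passed to the `evaluate` method shown in the
# hunks below, e.g. `env.evaluate(zero_policy, seed=0, horizon=10.0)`.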
@@ -321,7 +321,6 @@ def _observer_handle(self,
# that is supposed to be executed before `refresh_observation` is being
# called for the first time of an episode.
if not self.__is_observation_refreshed:
# Refresh observation
try:
self.refresh_observation(self.measurement)
except RuntimeError as e:
@@ -441,7 +440,7 @@ def evaluate(self,
horizon: Optional[float] = None,
enable_stats: bool = True,
enable_replay: Optional[bool] = None,
**kwargs: Any) -> Tuple[List[float], List[InfoType]]:
**kwargs: Any) -> Tuple[List[SupportsFloat], List[InfoType]]:
r"""Evaluate a policy on the environment over a complete episode.
.. warning::
@@ -435,7 +435,7 @@ def evaluate(self,
horizon: Optional[float] = None,
enable_stats: bool = True,
enable_replay: Optional[bool] = None,
**kwargs: Any) -> Tuple[List[float], List[InfoType]]:
**kwargs: Any) -> Tuple[List[SupportsFloat], List[InfoType]]:
# Ensure that this layer is already declared as part of the pipeline
# environment. If not, update the pipeline manually, considering this
# layer as top-most. This would be the case if `reset` has never been
27 changes: 12 additions & 15 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
@@ -261,21 +261,17 @@ def __init__(self,
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param op: Any callable taking as input argument either the complete
history of true or reference value of the quantity or only
the most recent and oldest value stored in the history (in
that exact order) depending on whether `bounds_only` is
False or True respectively, and returning its variation over
the whole span of the history. For instance, the difference
between the most recent and oldest values stored in the
history is appropriate for position in Euclidean space,
but not for orientation as it is important to count turns.
:param op: Any callable taking as input argument the current and some
previous value of the quantity in that exact order, and
returning the signed difference between them. Typically,
the subtraction operation is appropriate for position in
Euclidean space, but not for orientation as it is important
to count turns.
Optional: `sub` by default.
:param bounds_only: Whether to pass only the recent and oldest value
stored in the history as input argument of `op`
instead of the complete history (stacked as last
dimension).
Optional: True by default.
:param bounds_only: Whether to compute the total variation as the
difference between the most recent and oldest value
stored in the history, or the sum of differences
between successive timesteps.
:param is_truncation: Whether the episode should be considered
terminated or truncated whenever the termination
condition is triggered.
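To make the new contract concrete: `op` receives the current value and a previous value, in that order, and returns their signed difference, so plain subtraction fits Euclidean positions while a wrap-aware difference is needed for angles; `bounds_only` then selects between a single difference over the whole history and the sum of differences between successive timesteps. A hedged sketch of those two modes, not the library code (`total_variation`, `angle_op` and the sample values are illustrative):

from operator import sub
from typing import Callable, Sequence

import numpy as np

def total_variation(history: Sequence[float],
                    op: Callable[[float, float], float] = sub,
                    bounds_only: bool = True) -> float:
    """Variation of a quantity over a stacked history (oldest sample first)."""
    if bounds_only:
        # Single difference between the most recent and oldest values.
        return op(history[-1], history[0])
    # Sum of signed differences between successive timesteps, required e.g.
    # for orientations where full turns must be counted.
    return sum(op(curr, prev) for prev, curr in zip(history[:-1], history[1:]))

# Wrap-aware angular difference, mapping the result into [-pi, pi).
def angle_op(curr: float, prev: float) -> float:
    return (curr - prev + np.pi) % (2 * np.pi) - np.pi

angles = [0.0, 2.0, -2.0, 0.5]                                  # crosses +pi once
print(total_variation(angles, op=angle_op, bounds_only=False))  # ~ 2*pi + 0.5
print(total_variation(angles, op=angle_op, bounds_only=True))   # 0.5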
@@ -430,7 +426,8 @@ def __init__(self,
self.op = op

# Convert horizon into stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
assert horizon >= env.step_dt
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Define drift of quantity
stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731
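The extra `+ 1` in this hunk is the stack-length fix announced in the commit message: covering a horizon of N timestep intervals requires N + 1 stacked samples, since k samples only span k - 1 intervals. A hedged numeric check of the corrected formula (the timestep and horizon values are illustrative):

import numpy as np

step_dt, horizon = 0.25, 1.0
assert horizon >= step_dt
max_stack = max(int(np.ceil(horizon / step_dt)), 1) + 1
# 1.0s horizon / 0.25s per step = 4 intervals, hence 5 stacked samples so
# that the oldest and newest samples are really `horizon` seconds apart.
assert max_stack == 5
assert (max_stack - 1) * step_dt >= horizon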
@@ -26,7 +26,7 @@
from .generic import (
TrackingQuantityReward, QuantityTermination,
DriftTrackingQuantityTermination, ShiftTrackingQuantityTermination)
from ..quantities.locomotion import angle_total
from ..quantities.locomotion import angle_difference
from .mixin import radial_basis_function


@@ -730,7 +730,7 @@ def __init__(self,
max_orientation_err,
horizon,
grace_period,
op=angle_total,
op=angle_difference,
bounds_only=False,
is_truncation=False,
training_only=training_only)
9 changes: 5 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
@@ -1029,7 +1029,7 @@ def evaluate(self,
horizon: Optional[float] = None,
enable_stats: bool = True,
enable_replay: Optional[bool] = None,
**kwargs: Any) -> Tuple[List[float], List[InfoType]]:
**kwargs: Any) -> Tuple[List[SupportsFloat], List[InfoType]]:
# Handling of default arguments
if enable_replay is None:
enable_replay = (
@@ -1046,11 +1046,12 @@
self.eval()

# Initialize the simulation
reward: Optional[SupportsFloat]
obs, info = env.reset(seed=seed)
action, reward, terminated, truncated = None, None, False, False

# Run the simulation
reward_episode: List[float] = []
reward_episode: List[SupportsFloat] = []
info_episode = [info]
try:
while horizon is None or self.stepper_state.t < horizon:
@@ -1063,7 +1064,7 @@
break
obs, reward, terminated, truncated, info = env.step(action)
info_episode.append(info)
reward_episode.append(float(reward))
reward_episode.append(reward)
except KeyboardInterrupt:
pass

@@ -1077,7 +1078,7 @@
# Display some statistic if requested
if enable_stats:
print("env.num_steps:", self.num_steps)
print("cumulative reward:", sum(reward_episode))
print("cumulative reward:", sum(map(float, reward_episode)))

# Replay the result if requested
if enable_replay:
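The reward handling above keeps whatever `SupportsFloat` the environment returns (typically a NumPy scalar) in the episode log and converts only when aggregating, as in `sum(map(float, reward_episode))`. A short hedged illustration of that conversion point (the reward values are made up):

from typing import List, SupportsFloat

import numpy as np

# Rewards returned by `step` may be plain floats or NumPy scalars.
reward_episode: List[SupportsFloat] = [np.float64(0.5), 0.25, np.float32(0.25)]

# Conversion to `float` is deferred to the point where it is actually needed.
print("cumulative reward:", sum(map(float, reward_episode)))  # 1.0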
@@ -1862,7 +1862,7 @@ def __init__(
:param mode: Desired mode of evaluation for this quantity.
"""
# Convert horizon into stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Backup some of the user-arguments
self.max_stack = max_stack
45 changes: 17 additions & 28 deletions python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py
@@ -1229,7 +1229,7 @@ def initialize(self) -> None:
if index_first is None:
if is_contact:
index_first = i
elif index_last is None: # type: ignore[unreachable]
elif index_last is None:
if not is_contact:
index_last = i
elif is_contact:
@@ -1569,7 +1569,8 @@ def __init__(self,
Optional: 'QuantityEvalMode.TRUE' by default.
"""
# Convert horizon into stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
assert horizon >= env.step_dt
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Backup some of the user-arguments
self.max_stack = max_stack
@@ -1599,44 +1600,31 @@ def refresh(self) -> ArrayOrScalar:


@nb.jit(nopython=True, cache=True, fastmath=True, inline='always')
def angle_difference(delta: ArrayOrScalar) -> ArrayOrScalar:
def angle_difference(angle_left: ArrayOrScalar,
angle_right: ArrayOrScalar) -> ArrayOrScalar:
"""Compute the signed element-wise difference (aka. oriented angle) between
two batches of angles.
The oriented angle is defined as the smallest angle in absolute value
between right and left angles (ignoring multi-turns), signed in accordance
with the rotation going from the right angle to the left angle.
.. warning::
This method is fully compliant with angles restricted between
[-pi, pi], but it requires the "physical" distance between the two
angles to be smaller than pi.
.. seealso::
This proposed implementation is the most efficient one for a batch size
of 1000. See this post for reference about other implementations:
https://stackoverflow.com/a/7869457/4820605
:param delta: Pre-computed difference between left and right angles.
"""
return delta - np.floor((delta + np.pi) / (2 * np.pi)) * (2 * np.pi)


@nb.jit(nopython=True, cache=True, fastmath=True)
def angle_total(angles: np.ndarray) -> np.ndarray:
"""Compute the total signed multi-turn angle from start to end of
time-series of angles.
The method is fully compliant with individual angles restricted between
[-pi, pi], but it requires the distance between the angles at successive
timesteps to be smaller than pi.
.. seealso::
See `angle_difference` documentation for details.
:param angles: Temporal sequence of angles as a multi-dimensional array
whose last dimension gathers all the successive timesteps.
:param angle_left: Left-hand side angles.
:param angle_right: Right-hand side angles.
"""
# Note that `angle_difference` has been manually inlined as it results in
# about 50% speedup, which is surprising.
delta = angles[..., 1:] - angles[..., :-1]
delta = angle_left - angle_right
delta -= np.floor((delta + np.pi) / (2.0 * np.pi)) * (2 * np.pi)
return np.sum(delta, axis=-1)
return delta
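With this refactor, the wrapped per-step difference is what the drift termination accumulates through the stacked quantity (`bounds_only=False`), replacing the removed `angle_total` helper that consumed the whole history at once. A hedged check of the wrap-around behaviour, using a plain-Python copy of the formula above (without the numba decoration):

import numpy as np

def angle_difference(angle_left, angle_right):
    # Same formula as in the hunk above, evaluated eagerly with NumPy.
    delta = angle_left - angle_right
    return delta - np.floor((delta + np.pi) / (2 * np.pi)) * (2 * np.pi)

# Crossing +pi: the signed difference stays small instead of jumping by ~2*pi.
print(angle_difference(-3.1, 3.1))   # ~ +0.08 rad (short way round)

# Summing successive differences over a history counts full turns, which is
# what the `bounds_only=False` termination relies on.
raw = np.linspace(0.0, 4 * np.pi, 50)            # two full turns, small steps
angles = (raw + np.pi) % (2 * np.pi) - np.pi     # wrapped into [-pi, pi)
print(np.sum(angle_difference(angles[1:], angles[:-1])))  # ~ 4*pi (two turns)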


@dataclass(unsafe_hash=True)
@@ -1675,7 +1663,8 @@ def __init__(self,
Optional: 'QuantityEvalMode.TRUE' by default.
"""
# Convert horizon into stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
assert horizon >= env.step_dt
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Backup some of the user-arguments
self.max_stack = max_stack
Expand All @@ -1693,7 +1682,7 @@ def __init__(self,
axis=0,
keys=(2,))),
horizon=horizon,
op=angle_total,
op=angle_difference,
bounds_only=False))),
auto_refresh=False)

@@ -156,9 +156,12 @@ def _build_quantity(
# The objective is to avoid resetting the same quantity multiple times
# because of the auto-refresh mechanism.
cache = SharedCache()
is_key_found = False
for i, (cache_key, _) in enumerate(self._caches):
if key[0] == cache_key[0]:
self._caches.insert(i + 1, (key, cache))
is_key_found = True
elif is_key_found:
self._caches.insert(i, (key, cache))
break
else:
self._caches.append((key, cache))
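This is the cache auto-refresh fix from the commit title: a newly created cache is inserted after the end of the group of caches sharing the same quantity class, rather than right after the first match, so that entries owned by the same class stay contiguous and the same quantity is not reset multiple times by the auto-refresh mechanism. A standalone sketch of the grouped-insertion pattern on plain tuples (`insert_grouped` and the string keys are illustrative, not the library API):

from typing import Any, List, Tuple

def insert_grouped(caches: List[Tuple[str, Any]], key: str, cache: Any) -> None:
    """Insert `(key, cache)` right after the last consecutive entry sharing
    the same key, or append it if no such group exists or it ends the list."""
    is_key_found = False
    for i, (cache_key, _) in enumerate(caches):
        if key == cache_key:
            is_key_found = True
        elif is_key_found:
            caches.insert(i, (key, cache))
            break
    else:
        caches.append((key, cache))

caches = [("mass", 0), ("mass", 1), ("com", 2)]
insert_grouped(caches, "mass", 3)
# The new "mass" cache lands at the end of the existing "mass" group, not
# right after the first match.
assert caches == [("mass", 0), ("mass", 1), ("mass", 3), ("com", 2)]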
