[gym_jiminy/common] Refactor 'DeltaQuantity' to improve performance for long horizon.
duburcqa committed Feb 6, 2025
1 parent 5b6293a commit e56a934
Showing 8 changed files with 106 additions and 124 deletions.
@@ -6,7 +6,7 @@
from collections import OrderedDict
from typing import (
Dict, Any, Tuple, List, TypeVar, Generic, TypedDict, Optional, Callable,
Mapping, no_type_check, TYPE_CHECKING)
Mapping, SupportsFloat, no_type_check, TYPE_CHECKING)

import numpy as np
import numpy.typing as npt
@@ -39,7 +39,7 @@
InfoType = Dict[str, Any]

PolicyCallbackFun = Callable[
[Obs, Optional[Act], Optional[float], bool, bool, InfoType], Act]
[Obs, Optional[Act], Optional[SupportsFloat], bool, bool, InfoType], Act]


class EngineObsType(TypedDict):
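
For context on the `float` → `SupportsFloat` substitutions in the hunks above: `SupportsFloat` is the standard `typing` protocol matching any object that implements `__float__`, so the policy callback and `evaluate` may now receive rewards as plain floats or numpy scalars alike. A minimal sketch of the contract (the `as_scalar` helper is hypothetical, for illustration only):

```python
from typing import SupportsFloat

import numpy as np


def as_scalar(reward: SupportsFloat) -> float:
    """Normalize any float-like reward to a builtin float."""
    return float(reward)


# Builtin ints, floats and numpy scalars all implement `__float__`, hence
# they all satisfy the `SupportsFloat` protocol without any explicit cast.
assert as_scalar(1) == 1.0
assert as_scalar(np.float64(0.5)) == 0.5
```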
@@ -321,7 +321,6 @@ def _observer_handle(self,
# that is supposed to be executed before `refresh_observation` is being
# called for the first time of an episode.
if not self.__is_observation_refreshed:
# Refresh observation
try:
self.refresh_observation(self.measurement)
except RuntimeError as e:
@@ -441,7 +440,7 @@ def evaluate(self,
horizon: Optional[float] = None,
enable_stats: bool = True,
enable_replay: Optional[bool] = None,
**kwargs: Any) -> Tuple[List[float], List[InfoType]]:
**kwargs: Any) -> Tuple[List[SupportsFloat], List[InfoType]]:
r"""Evaluate a policy on the environment over a complete episode.
.. warning::
@@ -435,7 +435,7 @@ def evaluate(self,
horizon: Optional[float] = None,
enable_stats: bool = True,
enable_replay: Optional[bool] = None,
**kwargs: Any) -> Tuple[List[float], List[InfoType]]:
**kwargs: Any) -> Tuple[List[SupportsFloat], List[InfoType]]:
# Ensure that this layer is already declared as part of the pipeline
# environment. If not, update the pipeline manually, considering this
# layer as top-most. This would be the case if `reset` has never been
24 changes: 10 additions & 14 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
@@ -261,21 +261,17 @@ def __init__(self,
of the episode, during which the latter is bound
to continue whatever happens.
Optional: 0.0 by default.
:param op: Any callable taking as input argument either the complete
history of true or reference value of the quantity or only
the most recent and oldest value stored in the history (in
that exact order) depending on whether `bounds_only` is
False or True respectively, and returning its variation over
the whole span of the history. For instance, the difference
between the most recent and oldest values stored in the
history is appropriate for position in Euclidean space,
but not for orientation as it is important to count turns.
:param op: Any callable taking as input argument the current and some
previous value of the quantity in that exact order, and
returning the signed difference between them. Typically,
the subtraction operation is appropriate for position in
Euclidean space, but not for orientation as it is important
to count turns.
Optional: `sub` by default.
:param bounds_only: Whether to pass only the recent and oldest value
stored in the history as input argument of `op`
instead of the complete history (stacked as last
dimension).
Optional: True by default.
:param bounds_only: Whether to compute the total variation as the
difference between the most recent and oldest value
stored in the history, or the sum of differences
between successive timesteps.
:param is_truncation: Whether the episode should be considered
terminated or truncated whenever the termination
condition is triggered.
@@ -26,7 +26,7 @@
from .generic import (
TrackingQuantityReward, QuantityTermination,
DriftTrackingQuantityTermination, ShiftTrackingQuantityTermination)
from ..quantities.locomotion import angle_total
from ..quantities.locomotion import angle_difference
from .mixin import radial_basis_function


@@ -730,7 +730,7 @@ def __init__(self,
max_orientation_err,
horizon,
grace_period,
op=angle_total,
op=angle_difference,
bounds_only=False,
is_truncation=False,
training_only=training_only)
9 changes: 5 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/envs/generic.py
@@ -1029,7 +1029,7 @@ def evaluate(self,
horizon: Optional[float] = None,
enable_stats: bool = True,
enable_replay: Optional[bool] = None,
**kwargs: Any) -> Tuple[List[float], List[InfoType]]:
**kwargs: Any) -> Tuple[List[SupportsFloat], List[InfoType]]:
# Handling of default arguments
if enable_replay is None:
enable_replay = (
@@ -1046,11 +1046,12 @@
self.eval()

# Initialize the simulation
reward: Optional[SupportsFloat]
obs, info = env.reset(seed=seed)
action, reward, terminated, truncated = None, None, False, False

# Run the simulation
reward_episode: List[float] = []
reward_episode: List[SupportsFloat] = []
info_episode = [info]
try:
while horizon is None or self.stepper_state.t < horizon:
@@ -1063,7 +1064,7 @@
break
obs, reward, terminated, truncated, info = env.step(action)
info_episode.append(info)
reward_episode.append(float(reward))
reward_episode.append(reward)
except KeyboardInterrupt:
pass

@@ -1077,7 +1078,7 @@
# Display some statistic if requested
if enable_stats:
print("env.num_steps:", self.num_steps)
print("cumulative reward:", sum(reward_episode))
print("cumulative reward:", sum(map(float, reward_episode)))

# Replay the result if requested
if enable_replay:
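As a usage sketch of the updated `evaluate` signature: the leading policy-callback argument and the `seed` keyword are assumptions inferred from `PolicyCallbackFun` and the `env.reset(seed=seed)` call above, not shown in this diff. Note the explicit cast when aggregating, since rewards now come back as `SupportsFloat`:

```python
# Hypothetical usage, given some pipeline environment `env`:
def policy_fn(obs, action, reward, terminated, truncated, info):
    # Random policy matching the `PolicyCallbackFun` signature.
    return env.action_space.sample()

rewards, infos = env.evaluate(
    policy_fn, seed=0, horizon=5.0, enable_stats=True, enable_replay=False)
total_reward = sum(map(float, rewards))  # rewards are `SupportsFloat`
```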
39 changes: 13 additions & 26 deletions python/gym_jiminy/common/gym_jiminy/common/quantities/locomotion.py
@@ -1229,7 +1229,7 @@ def initialize(self) -> None:
if index_first is None:
if is_contact:
index_first = i
elif index_last is None: # type: ignore[unreachable]
elif index_last is None:
if not is_contact:
index_last = i
elif is_contact:
@@ -1600,44 +1600,31 @@ def refresh(self) -> ArrayOrScalar:


@nb.jit(nopython=True, cache=True, fastmath=True, inline='always')
def angle_difference(delta: ArrayOrScalar) -> ArrayOrScalar:
def angle_difference(angle_left: ArrayOrScalar,
angle_right: ArrayOrScalar) -> ArrayOrScalar:
"""Compute the signed element-wise difference (aka. oriented angle) between
two batches of angles.
The oriented angle is defined as the smallest angle in absolute value
between the right and left angles (ignoring multi-turns), signed in
accordance with the rotation going from the right to the left angle.
.. warning::
This method is fully compliant with angles restricted between
[-pi, pi], but it requires the "physical" distance between the two
angles to be smaller than pi.
.. seealso::
This proposed implementation is the most efficient one for a batch size
of 1000. See this post for reference about other implementations:
https://stackoverflow.com/a/7869457/4820605
:param delta: Pre-computed difference between left and right angles.
"""
return delta - np.floor((delta + np.pi) / (2 * np.pi)) * (2 * np.pi)


@nb.jit(nopython=True, cache=True, fastmath=True)
def angle_total(angles: np.ndarray) -> np.ndarray:
"""Compute the total signed multi-turn angle from start to end of
a time-series of angles.
The method is fully compliant with individual angles restricted between
[-pi, pi], but it requires the distance between the angles at successive
timesteps to be smaller than pi.
.. seealso::
See `angle_difference` documentation for details.
:param angles: Temporal sequence of angles as a multi-dimensional array
whose last dimension gathers all the successive timesteps.
:param angle_left: Left-hand side angles.
:param angle_right: Right-hand side angles.
"""
# Note that `angle_difference` has been manually inlined as it results in
# about 50% speedup, which is surprising.
delta = angles[..., 1:] - angles[..., :-1]
delta = angle_left - angle_right
delta -= np.floor((delta + np.pi) / (2.0 * np.pi)) * (2 * np.pi)
return np.sum(delta, axis=-1)
return delta
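
As a sanity check of the new `angle_difference`, re-implemented below without the numba decorator so it runs standalone: a single call wraps the difference across the ±pi seam, and summing successive wrapped differences (what `bounds_only=False` now does) recovers a full turn that the bounds-only difference cannot see:

```python
import numpy as np


def angle_difference(angle_left, angle_right):
    # Same arithmetic as the jitted version above.
    delta = angle_left - angle_right
    return delta - np.floor((delta + np.pi) / (2 * np.pi)) * (2 * np.pi)


# Wrapping across the +/-pi seam: the signed difference is +0.2, not -2*pi+0.2.
assert np.isclose(angle_difference(-np.pi + 0.1, np.pi - 0.1), 0.2)

# One full counter-clockwise turn sampled at 8 steps, wrapped to (-pi, pi]:
angles = np.angle(np.exp(1j * np.linspace(0.0, 2 * np.pi, 9)))
# Summing successive wrapped differences counts the turn...
assert np.isclose(angle_difference(angles[1:], angles[:-1]).sum(), 2 * np.pi)
# ...whereas the difference between bounds alone sees no motion at all.
assert np.isclose(angle_difference(angles[-1], angles[0]), 0.0)
```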


@dataclass(unsafe_hash=True)
@@ -1695,7 +1682,7 @@ def __init__(self,
axis=0,
keys=(2,))),
horizon=horizon,
op=angle_total,
op=angle_difference,
bounds_only=False))),
auto_refresh=False)

141 changes: 70 additions & 71 deletions python/gym_jiminy/common/gym_jiminy/common/quantities/transform.py
@@ -715,65 +715,44 @@ def refresh(self) -> ValueT:
class DeltaQuantity(InterfaceQuantity[ArrayOrScalar]):
"""Variation of a given quantity over the whole span of a horizon.
The value of the quantity is accumulated over a variable-length history
bounded by 'max_stack', which is basically a sliding window. This variation
is computed from the whole history. For Euclidean spaces, this variation is
If `bounds_only=False`, then the differences of the value of the quantity
between successive timesteps are accumulated over a variable-length history
bounded by 'max_stack', which is basically a sliding window. The total
variation over this horizon is defined as the sum of all the successive
differences stored in the history.
If `bounds_only=True`, then the value of the quantity is accumulated over
a variable-length history bounded by 'max_stack'. The total variation is
simply computed as the difference between the most recent and oldest values
stored in the history.
"""

quantity_stack: InterfaceQuantity[
quantity: InterfaceQuantity[
Union[np.ndarray, Sequence[ArrayOrScalar]]]
"""Stacked quantity from which to compute the variation over the history.
"""Quantity from which to compute the total variation over the history.
"""

op: Union[
Callable[[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar],
Callable[[np.ndarray], np.ndarray]]
"""Any callable taking as input argument either the complete history of
values of the quantity or only the most recent and oldest value stored in
the history (in that exact order) depending on `bounds_only`, and returning
its variation over the whole history.
op: Callable[[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar]
"""Any callable taking as input argument the current and some previous
value of the quantity in that exact order, and returning the signed
difference between them.
"""

bounds_only: bool
"""Whether to pass only the recent and oldest value stored in the history
as input argument of `op` instead of the complete history (stacked as last
dimension).
"""Whether to compute the total variation as the difference between the
most recent and oldest value stored in the history, or the sum of
differences between successive timesteps.
"""

@overload
def __init__(self: "DeltaQuantity",
env: InterfaceJiminyEnv,
parent: Optional[InterfaceQuantity],
quantity: QuantityCreator[ArrayOrScalar],
horizon: float,
*,
op: Callable[[np.ndarray], np.ndarray],
bounds_only: Literal[False]) -> None:
...

@overload
def __init__(self: "DeltaQuantity",
env: InterfaceJiminyEnv,
parent: Optional[InterfaceQuantity],
quantity: QuantityCreator[ArrayOrScalar],
horizon: float,
*,
op: Callable[[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar],
bounds_only: Literal[True]) -> None:
...

def __init__(self,
env: InterfaceJiminyEnv,
parent: Optional[InterfaceQuantity],
quantity: QuantityCreator[ArrayOrScalar],
horizon: float,
*,
op: Union[
Callable[[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar],
Callable[[np.ndarray], np.ndarray]] = sub,
bounds_only: bool = True) -> None:
def __init__(
self,
env: InterfaceJiminyEnv,
parent: Optional[InterfaceQuantity],
quantity: QuantityCreator[ArrayOrScalar],
horizon: Optional[float],
*,
op: Callable[[ArrayOrScalar, ArrayOrScalar], ArrayOrScalar] = sub,
bounds_only: bool = True) -> None:
"""
:param env: Base or wrapped jiminy environment.
:param parent: Higher-level quantity from which this quantity is a
@@ -782,46 +761,66 @@ def __init__(self,
to compute the variation, plus any keyword-arguments
of its constructor except 'env' and 'parent'.
:param horizon: Horizon over which values of the quantity will be
stacked before computing the drift.
:param op: Any callable taking as input argument either the complete
history of true or reference value of the quantity or only
the most recent and oldest value stored in the history (in
that exact order) depending on whether `bounds_only` is
False or True respectively, and returning its variation over
the whole span of the history. For instance, the difference
between the most recent and oldest values stored in the
history is appropriate for position in Euclidean space,
but not for orientation as it is important to count turns.
stacked before computing the drift. `None` to consider
only two successive timesteps.
:param op: Any callable taking as input argument the current and some
previous value of the quantity in that exact order, and
returning the signed difference between them. Typically,
the subtraction operation is appropriate for position in
Euclidean space, but not for orientation as it is important
to count turns.
Optional: `sub` by default.
:param bounds_only: Whether to pass only the recent and oldest value
stored in the history as input argument of `op`
instead of the complete history (stacked as last
dimension).
:param bounds_only: Whether to compute the total variation as the
difference between the most recent and oldest value
stored in the history, or the sum of differences
between successive timesteps.
Optional: True by default.
"""
# Convert horizon into stack length, assuming constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1
if horizon is None:
max_stack = 2
else:
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Backup some of the user-arguments
self.op = op
self.bounds_only = bounds_only

# Define the appropriate quantity
quantity_stack: QuantityCreator
if bounds_only:
quantity_stack = (StackedQuantity, dict(
quantity=quantity,
max_stack=max_stack,
is_wrapping=False,
as_array=False))
else:
quantity_stack = (StackedQuantity, dict(
quantity=(DeltaQuantity, dict(
quantity=quantity,
horizon=None,
bounds_only=True,
op=op)),
max_stack=(max_stack - 1),
is_wrapping=True,
as_array=True))

# Call base implementation
super().__init__(
env,
parent,
requirements=dict(
quantity_stack=(StackedQuantity, dict(
quantity=quantity,
max_stack=max_stack,
is_wrapping=False,
as_array=not bounds_only))),
quantity_stack=quantity_stack),
auto_refresh=False)

# Keep track of the underlying quantity for equality checks
if bounds_only:
self.quantity = self.quantity_stack.quantity
else:
self.quantity = self.quantity_stack.quantity.quantity

def refresh(self) -> ArrayOrScalar:
quantity_stack = self.quantity_stack.get()
if self.bounds_only:
return self.op(
quantity_stack[-1], # type: ignore[call-arg, arg-type]
quantity_stack[0])
return self.op(quantity_stack) # type: ignore[call-arg, arg-type]
return self.op(quantity_stack[-1], quantity_stack[0])
return quantity_stack.sum(axis=-1)
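
To make the two code paths concrete, here is a plain-numpy sketch of what the refactored class effectively computes in each mode, including the horizon-to-window conversion from `__init__` (the array window stands in for `StackedQuantity`; with `op=sub` the two modes agree by telescoping, which is what makes the `bounds_only=True` shortcut valid for Euclidean quantities):

```python
import numpy as np

# Horizon-to-window conversion from `__init__`: e.g. a 1.0s horizon with a
# 0.04s env timestep gives ceil(1.0 / 0.04) + 1 = 26 stacked samples.
step_dt, horizon = 0.04, 1.0
max_stack = max(int(np.ceil(horizon / step_dt)), 1) + 1

values = np.random.randn(100)   # one scalar quantity sample per env step
window = values[-max_stack:]    # what the `StackedQuantity` would hold

# bounds_only=True: one call to `op` on the newest and oldest samples.
delta_bounds = window[-1] - window[0]

# bounds_only=False: `max_stack - 1` pairwise `op` outputs (one per
# timestep, kept in a wrapping array stack), summed over the last axis.
pairwise = window[1:] - window[:-1]
delta_sum = pairwise.sum(axis=-1)

# With `op=sub` the sum telescopes to the bounds difference; only a
# wrapping `op` such as `angle_difference` makes the two modes differ.
assert np.isclose(delta_bounds, delta_sum)
```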
