Skip to content

Commit

Permalink
[gym_jiminy/rllib] Fix OOM errors for both RAM and disk space. (#877)
Browse files Browse the repository at this point in the history
* [gym_jiminy/common] Remove LRU cache for trajectory state getter as it never hits in practice.
* [gym_jiminy/rllib] Fix log paths not deleted during evaluation.
* [gym_jiminy/rllib] Add option to 'evaluate_from_runner' to delete log files automatically.
  • Loading branch information
duburcqa authored Feb 3, 2025
1 parent 954d583 commit 72feeaf
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 87 deletions.
57 changes: 38 additions & 19 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,22 +1298,22 @@ def evaluate_from_algo(algo: Algorithm,
step_dt = None
_pretty_print_episode_metrics(all_episodes, step_dt)

# Backup only the log file corresponding to the best and worst trial
# Backup only the log file corresponding to the best and worst trial, while
# deleting all the others.
all_returns = [
episode.get_return() for episode in all_episodes]
idx_worst, idx_best = np.argsort(all_returns)[[0, -1]]
log_labels, log_paths = [], []
for label, idx in (
("best", idx_best), ("worst", idx_worst))[:num_episodes]:
ext = Path(all_log_paths[idx]).suffix
for idx, log_path_orig in tuple(enumerate(all_log_paths))[::-1]:
if idx not in (idx_worst, idx_best):
os.remove(log_path_orig)
continue
ext = Path(log_path_orig).suffix
label = "best" if idx == idx_best else "worst"
log_path = f"{algo.logdir}/iter_{algo.iteration}-{label}{ext}"
try:
shutil.move(all_log_paths[idx], log_path)
except FileNotFoundError:
LOGGER.warning("Failed to save log file during evaluation.")
else:
log_paths.append(log_path)
log_labels.append(label)
shutil.move(log_path_orig, log_path)
log_paths.append(log_path)
log_labels.append(label)

# Replay and/or record a video of the best and worst trials if requested.
# Async to enable replaying and recording while training keeps going.
Expand Down Expand Up @@ -1378,13 +1378,15 @@ def evaluate_from_algo(algo: Algorithm,
return results


def evaluate_from_runner(env_runner: EnvRunner,
num_episodes: int = 1,
print_stats: Optional[bool] = None,
enable_replay: Optional[bool] = None,
block: bool = True,
**kwargs: Any
) -> Tuple[Sequence[EpisodeType], Sequence[str]]:
def evaluate_from_runner(
env_runner: EnvRunner,
num_episodes: int = 1,
print_stats: Optional[bool] = None,
enable_replay: Optional[bool] = None,
delete_log_files: bool = True,
block: bool = True,
**kwargs: Any
) -> Tuple[Sequence[EpisodeType], Optional[Sequence[str]]]:
"""Evaluates the performance of a given local worker.
This method is specifically tailored for Gym environments inheriting from
Expand All @@ -1400,20 +1402,31 @@ def evaluate_from_runner(env_runner: EnvRunner,
Optional: True by default if `record_video_path` is
not provided and the default/current backend supports
it, False otherwise.
:param delete_log_files: Whether to delete log files instead of returning
them. Note that this option is not supported if
`enable_replay=True` and `block=False`.
:param block: Whether calling this method should be blocking.
Optional: True by default.
:param kwargs: Extra keyword arguments to forward to the viewer if any.
:returns: Tuple gathering the sequences of episodes and log files.
:returns: Sequences of episodes, along with the sequence of corresponding
log files if `delete_log_files=False`, None otherwise.
"""
# Assert(s) for type checker
assert isinstance(env_runner, SingleAgentEnvRunner)

# Make sure that the input arguments are valid
if delete_log_files and enable_replay and not block:
raise ValueError(
"Specifying `delete_log_files=True` is not available in "
"conjunction with `enable_replay=True` and `block=False`.")

# Handling of default argument(s)
if print_stats is None:
print_stats = num_episodes >= 10

# Sample episodes
all_log_paths: Optional[Sequence[str]]
_, all_episodes, all_log_paths = (
sample_from_runner(env_runner, num_episodes))

Expand Down Expand Up @@ -1452,6 +1465,12 @@ def evaluate_from_runner(env_runner: EnvRunner,
ctypes.pythonapi.PyThreadState_SetAsyncExc(
ctypes.c_long(thread.ident), ctypes.py_object(SystemExit))

# Delete log files if requested
if delete_log_files:
for log_path in all_log_paths:
os.remove(log_path)
all_log_paths = None

# Return all collected data
return all_episodes, all_log_paths

Expand Down
108 changes: 40 additions & 68 deletions python/jiminy_py/src/jiminy_py/dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,8 @@
# pylint: disable=invalid-name,no-member
import logging
from bisect import bisect_left
from functools import lru_cache
from dataclasses import dataclass, fields
from typing import (
List, Union, Optional, Tuple, Sequence, Dict, Callable, Literal)
from typing import List, Union, Optional, Tuple, Sequence, Callable, Literal

import numpy as np

Expand Down Expand Up @@ -233,9 +231,6 @@ def __init__(self,
fields_.append(field)
self._fields = tuple(fields_)

# Hacky way to enable argument-based function caching at instance-level
self.__dict__['_get'] = lru_cache(maxsize=None)(self._get)

@property
def has_data(self) -> bool:
"""Whether the trajectory has data, ie the state sequence is not empty.
Expand Down Expand Up @@ -296,17 +291,49 @@ def time_interval(self) -> Tuple[float, float]:
"State sequence is empty. Time interval undefined.")
return (self._times[0], self._times[-1])

def _get(self, t: float) -> Dict[str, np.ndarray]:
def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
"""Query the state at a given timestamp.
.. note::
This method is used internally by `get`. It is not meant to be
called manually.
Internally, the nearest neighbor states are linearly interpolated,
taking into account the corresponding Lie Group of all state attributes
that are available.
:param t: Time of the state to extract from the trajectory.
:param mode: Fallback strategy when the query time is not in the time
interval 'time_interval' of the trajectory. 'raise' raises
an exception if the query time is out-of-bound wrt the
underlying state sequence of the selected trajectory.
'clip' forces clipping of the query time before
interpolation of the state sequence. 'wrap' wraps around
the query time wrt the time span of the trajectory. This
is useful to store periodic trajectories as finite state
sequences.
"""
# pylint: disable=possibly-used-before-assignment

# Raise exception if state sequence is empty
if not self.has_data:
raise RuntimeError(
"State sequence is empty. Impossible to interpolate data.")

# Backup the original query time
t_orig = t

# Handling of the desired mode
n_steps = 0.0
t_start, t_end = self.time_interval
if mode == "raise":
if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL:
raise RuntimeError("Time is out-of-range.")
elif mode == "wrap":
if t_end > t_start:
n_steps, t_rel = divmod(t - t_start, t_end - t_start)
t = t_rel + t_start
else:
t = t_start
else:
t = min(max(t, t_start), t_end)

# Get nearest neighbors timesteps for linear interpolation.
# Note that the left and right data points may be associated with the
# same timestamp, corresponding respectively t- and t+. These values
Expand All @@ -330,7 +357,7 @@ def _get(self, t: float) -> Dict[str, np.ndarray]:
return_right = t_right - t < TRAJ_INTERP_TOL
alpha = (t - t_left) / (t_right - t_left)

# Interpolate state
# Interpolate state data
if return_left:
position = s_left.q.copy()
elif return_right:
Expand All @@ -350,69 +377,14 @@ def _get(self, t: float) -> Dict[str, np.ndarray]:
else:
data[field] = value_left + alpha * (value_right - value_left)

# Make sure that data are immutable.
# This is essential to make sure that cached values cannot be altered.
for arr in data.values():
arr.setflags(write=False)

return data

def get(self, t: float, mode: TrajectoryTimeMode = 'raise') -> State:
"""Query the state at a given timestamp.
Internally, the nearest neighbor states are linearly interpolated,
taking into account the corresponding Lie Group of all state attributes
that are available.
:param t: Time of the state to extract from the trajectory.
:param mode: Fallback strategy when the query time is not in the time
interval 'time_interval' of the trajectory. 'raise' raises
an exception if the query time is out-of-bound wrt the
underlying state sequence of the selected trajectory.
'clip' forces clipping of the query time before
interpolation of the state sequence. 'wrap' wraps around
the query time wrt the time span of the trajectory. This
is useful to store periodic trajectories as finite state
sequences.
"""
# Raise exception if state sequence is empty
if not self.has_data:
raise RuntimeError(
"State sequence is empty. Impossible to interpolate data.")

# Backup the original query time
t_orig = t

# Handling of the desired mode
n_steps = 0.0
t_start, t_end = self.time_interval
if mode == "raise":
if t - t_end > TRAJ_INTERP_TOL or t_start - t > TRAJ_INTERP_TOL:
raise RuntimeError("Time is out-of-range.")
elif mode == "wrap":
if t_end > t_start:
n_steps, t_rel = divmod(t - t_start, t_end - t_start)
t = t_rel + t_start
else:
t = t_start
else:
t = min(max(t, t_start), t_end)

# Rounding time to avoid cache miss issues
# Note that `int(x + 0.5)` is faster than `round(x)`.
t = int(t / TRAJ_INTERP_TOL + 0.5) * TRAJ_INTERP_TOL

# Interpolate state at the desired time
state = State(t=t_orig, **self._get(t))

# Perform odometry if time is wrapping
if self._stride_offset_log6 is not None and n_steps:
state.q = position = state.q.copy()
stride_offset = pin.exp6(n_steps * self._stride_offset_log6)
ff_xyzquat = stride_offset * pin.XYZQUATToSE3(position[:7])
position[:7] = pin.SE3ToXYZQUAT(ff_xyzquat)

return state
# Return a state instance bundling all data
return State(t=t_orig, **data)


# #####################################################################
Expand Down

0 comments on commit 72feeaf

Please sign in to comment.