From 7b44c72c9acdf4df5d8854e027ab2f734d8aadda Mon Sep 17 00:00:00 2001 From: Alexis Duburcq Date: Mon, 3 Feb 2025 18:04:53 +0100 Subject: [PATCH] [gym_jiminy/rllib] Make callbacks checkpointable. Restore curriculum state from checkpoints. --- .../rllib/gym_jiminy/rllib/curriculum.py | 213 +++++++++++++----- .../rllib/gym_jiminy/rllib/utilities.py | 201 +++++++++++------ 2 files changed, 290 insertions(+), 124 deletions(-) diff --git a/python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py b/python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py index 0602564e7..de92d1512 100644 --- a/python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py +++ b/python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py @@ -2,9 +2,11 @@ """ TODO: Write documentation. """ import math +from functools import partial from collections import defaultdict from typing import ( - List, Any, Dict, Tuple, Optional, Callable, DefaultDict, cast) + List, Any, Dict, Tuple, Optional, Callable, DefaultDict, Union, Collection, + cast) import numpy as np import gymnasium as gym @@ -12,12 +14,15 @@ from ray.rllib.core.rl_module import RLModule from ray.rllib.env.env_context import EnvContext from ray.rllib.env.env_runner import EnvRunner +from ray.rllib.env.env_runner_group import EnvRunnerGroup from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner from ray.rllib.algorithms.algorithm import Algorithm from ray.rllib.algorithms.callbacks import DefaultCallbacks +from ray.rllib.utils.checkpoints import Checkpointable from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS from ray.rllib.utils.metrics.metrics_logger import MetricsLogger -from ray.rllib.utils.typing import ResultDict, EpisodeID, EpisodeType +from ray.rllib.utils.typing import ( + ResultDict, EpisodeID, EpisodeType, StateDict) from jiminy_py import tree from gym_jiminy.common.bases import BasePipelineWrapper @@ -30,7 +35,53 @@ ) from e -class TaskSchedulingSamplingCallback(DefaultCallbacks): +def _update_proba_task_tree_from_runner( + env_runner: EnvRunner, + proba_task_tree: ProbaTaskTree) -> None: + """Update the probability task tree of all the environments being managed + by a given runner. + + :param env_runner: Environment runner to consider. + :param proba_task_tree: + Probability tree consistent with the task space of the underlying + environment, which must derive from + `gym_jiminy.toolbox.wrappers.meta_envs.BaseTaskSettableWrapper`. + """ + # FIXME: `set_attr` is buggy on`gymnasium<=1.0` and cannot be used + # reliability in conjunction with `BasePipelineWrapper`. + # See PR: https://github.com/Farama-Foundation/Gymnasium/pull/1294 + assert isinstance(env_runner, SingleAgentEnvRunner) + env = env_runner.env.unwrapped + assert isinstance(env, gym.vector.SyncVectorEnv) + for env in env.unwrapped.envs: + while not isinstance(env, BaseTaskSettableWrapper): + assert isinstance( + env, (gym.Wrapper, BasePipelineWrapper)) + env = env.env + env.proba_task_tree = proba_task_tree + + +def _update_proba_task_tree_from_runner_group( + workers: EnvRunnerGroup, + proba_task_tree: ProbaTaskTree) -> None: + """Update the probability tree for a group of environment runners. + + :param workers: Group of environment runners to be updated. + :param proba_task_tree: + Probability tree consistent with the task space of the underlying + environment, which must derive from + `gym_jiminy.toolbox.wrappers.meta_envs.BaseTaskSettableWrapper`. 
+    """
+    workers.foreach_worker(partial(
+        _update_proba_task_tree_from_runner,
+        proba_task_tree=proba_task_tree))
+    # workers.foreach_worker(
+    #     lambda worker: worker.env.unwrapped.set_attr(
+    #         'proba_task_tree',
+    #         (proba_task_tree,) * worker.num_envs))
+
+
+class TaskSchedulingSamplingCallback(DefaultCallbacks, Checkpointable):
     r"""Scheduler that automatically adapt the probability distribution of the
     tasks of a task-settable environment in order to maintain the same level
     of performance among all the task, no matter if some of them are much harder
@@ -131,14 +182,40 @@ def __init__(self,
         # Whether to clear all task metrics at the end of the next episode
         self._must_clear_metrics = False
 
+        # Whether the internal state of this callback instance has been
+        # initialized.
+        # Note that initialization must be postponed because it requires having
+        # access to attributes of the environments, but they are not available
+        # at this point.
         self._is_initialized = False
+
+        # Whether the internal state of this callback has just been restored
+        # but must still be propagated to all managed environments.
+        self._is_restored = False
+
+        # Maximum number of steps of the episodes.
+        # This is used to standardize the return between 0.0 and 1.0 (assuming
+        # the reward is normalized), or at least to make it independent of
+        # the maximum episode duration.
         self._max_num_steps_all: Tuple[int, ...] = ()
+
+        # Arbitrarily nested task space
         self._task_space = gym.spaces.Tuple([])
+
+        # Flattened task space representation for efficiency
        self._task_paths: Tuple[TaskPath, ...] = ()
         self._task_names: Tuple[str, ...] = ()
-        self._proba_task_tree: ProbaTaskTree = ()
+
+        # Current probability task tree
+        self.proba_task_tree: ProbaTaskTree = ()
+
+        # Flattened probability task tree representation for efficiency
         self._proba_task_tree_flat_map: Dict[str, int] = {}
 
+        # Use custom logger and aggregate stats
+        self.stats_logger = MetricsLogger()
+        self._buffer = MetricsLogger()
+
     def on_environment_created(self,
                                *,
                                env_runner: EnvRunner,
@@ -150,14 +227,22 @@ def on_environment_created(self,
         if self._is_initialized:
             return
 
+        # Add stats proxy at worker-level to ease centralization later on.
+        # Note that the proxy must NOT be added if it was already done before.
+        # This would be the case when no remote workers are available, causing
+        # the local one to be used for sample collection as a fallback.
+        if not hasattr(env_runner, "_task_stats_logger"):
+            env_runner.__dict__["_task_stats_logger"] = self._buffer
+
         # Backup tree information
         try:
             self._task_space, *_ = env.unwrapped.get_attr("task_space")
-            self._proba_task_tree, *_ = (
-                env.unwrapped.get_attr("proba_task_tree"))
         except AttributeError as e:
             raise RuntimeError("Base environment must be wrapped with "
                                "`BaseTaskSettableWrapper`.") from e
+        if not self.proba_task_tree:
+            self.proba_task_tree, *_ = (
+                env.unwrapped.get_attr("proba_task_tree"))
 
         # Get the maximum episode duration
         self._max_num_steps_all = tuple(
@@ -177,7 +262,7 @@ def on_environment_created(self,
         # Initialize proba task tree flat ordering map
         self._proba_task_tree_flat_map = {
             "/".join(map(str, path[::2])): i for i, (path, _) in enumerate(
-                tree.flatten_with_path(self._proba_task_tree))}
+                tree.flatten_with_path(self.proba_task_tree))}
 
         # The callback is now fully initialized
         self._is_initialized = True
@@ -191,6 +276,20 @@ def on_episode_start(self,
                          env_index: int,
                          rl_module: RLModule,
                          **kwargs: Any) -> None:
+        # Propagate task probability to the environments if not already done.
+        # FIXME: At this point, the environment managed by the runner at hand
+        # has already been reset. As a result, the tasks associated with the
+        # very first episodes to be collected after restoring the state of this
+        # callback instance would be sampled according to the old probability
+        # task tree. To address this issue, propagation of the probability tree
+        # at environment-level after restoring state should be moved to
+        # `on_episode_created`. However, this callback method is not available
+        # prior to `ray>=2.41`.
+        if self._is_restored:
+            _update_proba_task_tree_from_runner(
+                env_runner, self.proba_task_tree)
+            self._is_restored = False
+
         # Drop all partial episodes associated with the environment at hand
         # when starting a fresh new one since it will never be done anyway.
         if env_index in self._ongoing_episodes:
@@ -207,10 +306,9 @@ def on_episode_end(self,
                        env_index: int,
                        rl_module: RLModule,
                        **kwargs: Any) -> None:
-        # Force clearing all custom metrics at the beginning of every
-        # sampling iteration. See `MonitorEpisodeCallback.on_episode_end`.
+        # Clear all custom metrics at the beginning of every sampling iteration
         if self._must_clear_metrics:
-            metrics_logger.stats.pop("task_metrics", None)
+            self._buffer.reset()
             self._must_clear_metrics = False
 
         # Get all the chunks associated with the episode at hand
@@ -232,14 +330,14 @@ def on_episode_end(self,
             task_path = self._task_paths[task_index]
             for i in range(len(task_path)):
                 task_branch = "/".join(map(str, task_path[:(i + 1)]))
-                metrics_logger.log_value(
-                    ("task_metrics", "score", task_branch),
+                self._buffer.log_value(
+                    ("score", task_branch),
                     score,
                     reduce="mean",
                     window=self.history_length,
                     clear_on_reduce=False)
-                metrics_logger.log_value(
-                    ("task_metrics", "num", task_branch), 1, reduce="sum")
+                self._buffer.log_value(
+                    ("num", task_branch), 1, reduce="sum")
 
     def on_sample_end(self,
                       *,
@@ -247,8 +345,7 @@ def on_sample_end(self,
                       metrics_logger: MetricsLogger,
                       samples: List[EpisodeType],
                       **kwargs: Any) -> None:
-        # Store all the partial episodes that did not reached done yet.
-        # See `MonitorEpisodeCallback.on_episode_end`.
+        # Store all the partial episodes that did not reach done yet
         for episode in samples:
             if episode.is_done:
                 continue
@@ -271,9 +368,25 @@ def on_train_result(self,
                 env=algorithm.env_runner.env,
                 env_context=EnvContext({}, worker_index=0))
 
-        # Extract from metrics mean task scores aggregated across runners
-        metrics = result[ENV_RUNNER_RESULTS]
-        task_metrics = metrics.setdefault("task_metrics", {})
+        # Centralize reduced task statistics across all remote workers.
+        # Note that it is necessary to also include the local worker, as it
+        # would be used for sample collection if there is no remote worker.
+ workers = algorithm.env_runner_group + assert workers is not None + task_stats_all = workers.foreach_worker( + lambda worker: worker._task_stats_logger.reduce()) + self.stats_logger.merge_and_log_n_dicts(task_stats_all) + task_stats = self.stats_logger.reduce() + + # Extract task metrics and aggregate them with the results + task_metrics = { + key: { + task_path: stat.peek() + for task_path, stat in task_stats_group.items()} + for key, task_stats_group in task_stats.items()} + result[ENV_RUNNER_RESULTS]["task_metrics"] = task_metrics + + # Extract mean task scores aggregated across runners score_task_metrics = task_metrics.get("score", {}) # Re-order flat task tree and complete missing data with nan @@ -285,7 +398,7 @@ def on_train_result(self, # Unflatten mean task score tree score_task_tree = tree.unflatten_as( - self._proba_task_tree, score_task_tree_flat) + self.proba_task_tree, score_task_tree_flat) # Compute the probability tree proba_task_tree: ProbaTaskTree = [] @@ -328,37 +441,9 @@ def on_train_result(self, score_and_proba_task_branches.append(( score_task_branch_, proba_task_branch_, space)) - # Update the probability tree at runner-level. - # FIXME: `set_attr` is buggy on`gymnasium<=1.0` and cannot be used - # reliability in conjunction with `BasePipelineWrapper`. - # See PR: https://github.com/Farama-Foundation/Gymnasium/pull/1294 - self._proba_task_tree = proba_task_tree - workers = algorithm.env_runner_group - assert workers is not None - - def _update_runner_proba_task_tree( - env_runner: EnvRunner) -> None: - """Update the probability task tree of all the environments - being managed by a given runner. - - :param env_runner: Environment runner to consider. - """ - nonlocal proba_task_tree - assert isinstance(env_runner, SingleAgentEnvRunner) - env = env_runner.env.unwrapped - assert isinstance(env, gym.vector.SyncVectorEnv) - for env in env.unwrapped.envs: - while not isinstance(env, BaseTaskSettableWrapper): - assert isinstance( - env, (gym.Wrapper, BasePipelineWrapper)) - env = env.env - env.proba_task_tree = proba_task_tree - - workers.foreach_worker(_update_runner_proba_task_tree) - # workers.foreach_worker( - # lambda worker: worker.env.unwrapped.set_attr( - # 'proba_task_tree', - # (proba_task_tree,) * worker.num_envs)) + # Update the probability tree at runner-level + _update_proba_task_tree_from_runner_group(workers, proba_task_tree) + self.proba_task_tree = proba_task_tree # Compute flattened probability tree proba_task_tree_flat: List[float] = [] @@ -385,12 +470,36 @@ def _update_runner_proba_task_tree( for path in self._task_paths: num_task_metrics.setdefault("/".join(map(str, path)), 0) - # Filter out all non-leaf metrics to avoid cluttering plots + # Filter out all non-leaf task metrics to avoid cluttering plots for data in (score_task_metrics, num_task_metrics): for key in tuple(data.keys()): if key not in self._task_names: del data[key] + def get_state(self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs: Any) -> StateDict: + return dict( + proba_task_tree=self.proba_task_tree, + stats_logger=self.stats_logger.get_state()) + + def set_state(self, state: StateDict) -> None: + self.proba_task_tree = state["proba_task_tree"] + self.stats_logger.set_state(state["stats_logger"]) + self._is_restored = True + + def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]: + return ( + (), # *args + dict( # **kwargs + history_length=self.history_length, + 
softmin_beta=self.softmin_beta, + score_fn=self.score_fn + ) + ) + __all__ = [ "TaskSchedulingSamplingCallback", diff --git a/python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py b/python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py index 9bb808dea..fdcf0f5bf 100644 --- a/python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py +++ b/python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py @@ -25,8 +25,8 @@ from tempfile import mkdtemp from traceback import TracebackException from typing import ( - Optional, Any, Union, Sequence, Tuple, List, Literal, Dict, Set, Type, - DefaultDict, Iterable, overload, cast) + Optional, Any, Union, Sequence, Tuple, List, Literal, Dict, Set, Callable, + DefaultDict, Collection, Iterable, overload, cast) import tree import numpy as np @@ -63,13 +63,15 @@ from ray.rllib.env.env_runner import EnvRunner from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner from ray.rllib.env.env_runner_group import EnvRunnerGroup +from ray.rllib.utils.checkpoints import Checkpointable from ray.rllib.utils.filter import MeanStdFilter as _MeanStdFilter, RunningStat from ray.rllib.utils.metrics import ( NUM_ENV_STEPS_SAMPLED_LIFETIME, NUM_AGENT_STEPS_SAMPLED_LIFETIME, NUM_EPISODES_LIFETIME, EPISODE_RETURN_MEAN, EPISODE_RETURN_MAX, EPISODE_LEN_MEAN, EVALUATION_RESULTS, ENV_RUNNER_RESULTS, NUM_EPISODES) from ray.rllib.utils.metrics.metrics_logger import MetricsLogger -from ray.rllib.utils.typing import AgentID, EpisodeID, ResultDict, EpisodeType +from ray.rllib.utils.typing import ( + AgentID, EpisodeID, ResultDict, EpisodeType, StateDict) from jiminy_py.viewer import async_play_and_record_logs_files from gym_jiminy.common.bases import Obs, Act @@ -445,75 +447,118 @@ def initialize(num_cpus: int, return log_path -def make_multi_callbacks( - callback_class_list: Sequence[Type[DefaultCallbacks]] - ) -> Type[DefaultCallbacks]: - """Allows combining multiple sub-callbacks into one new callbacks class. +class MultiCallbacks(DefaultCallbacks, Checkpointable): + """Wrapper to combine multiple callback classes as one to fit with the + standard RLlib API. - The resulting DefaultCallbacks will call all the sub-callbacks' callbacks - when called. - - .. warning:: - This wrapper only supports the new API, unlike the original helper - `ray.rllib.algorithms.callbacks.make_multi_callbacks`. - - :param callback_class_list: The list of sub-classes of DefaultCallbacks to - be baked into the to-be-returned class. All of - these sub-classes' implemented methods will be - called in the given order. + .. note:: + Based on `ray.rllib.algorithms.callbacks.make_multi_callbacks`, which + has been extended to support stateful callbacks and the so-called new + API. """ - class MultiCallbacks(DefaultCallbacks): - """A DefaultCallbacks subclass that combines all the given sub-classes. + def __init__(self, + callbacks_list: Tuple[Callable[[], DefaultCallbacks], ...] 
+                 ) -> None:
+        """
+        :param callbacks_list: The list of sub-classes of DefaultCallbacks to
+                               be instantiated and combined by this wrapper.
+                               All of these sub-classes' implemented methods
+                               will be called in the given order.
+ """ + self._ctor_kwargs = dict(callbacks_list=callbacks_list) + self._callbacks_list = tuple( + callback_class() for callback_class in callbacks_list) + + def on_algorithm_init(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_algorithm_init(**kwargs) + + def on_workers_recreated(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_workers_recreated(**kwargs) + + def on_checkpoint_loaded(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_checkpoint_loaded(**kwargs) + + def on_environment_created(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_environment_created(**kwargs) + + def on_episode_start(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_episode_start(**kwargs) + + def on_episode_step(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_episode_step(**kwargs) + + def on_episode_end(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_episode_end(**kwargs) + + def on_evaluate_start(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_evaluate_start(**kwargs) + + def on_evaluate_end(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_evaluate_end(**kwargs) + + def on_sample_end(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_sample_end(**kwargs) + + def on_train_result(self, **kwargs: Any) -> None: + for callbacks in self._callbacks_list: + callbacks.on_train_result(**kwargs) + + def get_state(self, + components: Optional[Union[str, Collection[str]]] = None, + *, + not_components: Optional[Union[str, Collection[str]]] = None, + **kwargs: Any) -> StateDict: + # Sanitize input argument(s) + if isinstance(components, str): + components = (components,) + if isinstance(not_components, str): + not_components = (not_components,) + + # Aggregate sequentially states of all the wrapped callbacks if any. + # Note that the wrapper itself is stateless. 
+        state = {}
+        for i, callbacks in enumerate(self._callbacks_list):
+            # Skip individual callbacks that are not requested
+            key = str(i)
+            if components is not None and key not in components:
+                continue
+            if not_components is not None and key in not_components:
+                continue
 
-    return MultiCallbacks
+            # Append the state of the individual callback
+            if isinstance(callbacks, Checkpointable):
+                state[key] = callbacks.get_state()
+        return state
+
+    def set_state(self, state: StateDict) -> None:
+        for i, callbacks in enumerate(self._callbacks_list):
+            key = str(i)
+            state_i = state.get(key, None)
+            if state_i:
+                assert isinstance(callbacks, Checkpointable)
+                callbacks.set_state(state_i)
+
+    def get_checkpointable_components(
+            self) -> List[Tuple[str, Checkpointable]]:
+        return [(str(i), callbacks)
+                for i, callbacks in enumerate(self._callbacks_list)
+                if isinstance(callbacks, Checkpointable)]
+
+    def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
+        return (
+            (),  # *args
+            self._ctor_kwargs,  # **kwargs
+        )
 
 
 def train(algo_config: AlgorithmConfig,
@@ -615,8 +660,8 @@ def train(algo_config: AlgorithmConfig,
     if algo_config.callbacks_class is DefaultCallbacks:
         algo_config.callbacks(MonitorEpisodeCallback)
     else:
-        algo_config.callbacks(make_multi_callbacks(
-            [algo_config.callbacks_class, MonitorEpisodeCallback]))
+        algo_config.callbacks(partial(MultiCallbacks, (
+            algo_config.callbacks_class, MonitorEpisodeCallback)))
 
     # Configure evaluation
     algo_config.evaluation(
@@ -743,7 +788,14 @@ def train(algo_config: AlgorithmConfig,
         str(path) for path in Path(logdir).iterdir()
         if path.is_dir() and path.name.startswith("checkpoint_")])
     if checkpoints_paths:
-        algo.restore(checkpoints_paths[-1])
+        checkpoint_dir = checkpoints_paths[-1]
+        algo.restore(checkpoint_dir)
+        if isinstance(algo.callbacks, Checkpointable):
+            algo.callbacks.restore_from_path(
+                os.path.join(checkpoint_dir, "callbacks"))
+            state_callbacks = algo.callbacks.get_state()
+            algo.env_runner_group.foreach_worker(
+                lambda worker: worker._callbacks.set_state(state_callbacks))
 
     # Synchronize connectors of training and evaluation remote workers with the
     # local training runner. This is necessary if a checkpoint has just been
@@ -921,7 +973,12 @@ def disable_update_connectors(env_runner: EnvRunner) -> None:
         # Backup the policy
         iter_num = result[TRAINING_ITERATION]
         if checkpoint_interval > 0 and iter_num % checkpoint_interval == 0:
-            algo.save(os.path.join(logdir, f"checkpoint_{iter_num:06d}"))
+            checkpoint_dir = os.path.join(
+                logdir, f"checkpoint_{iter_num:06d}")
+            algo.save(checkpoint_dir)
+            if isinstance(algo.callbacks, Checkpointable):
+                algo.callbacks.save_to_path(
+                    os.path.join(checkpoint_dir, "callbacks"))
 
         # Check terminal conditions
         num_timesteps = result[NUM_ENV_STEPS_SAMPLED_LIFETIME]
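
Usage sketch (illustrative, not part of the patch): the snippet below shows how the
checkpointable callbacks introduced above are meant to be combined and how their
state can be exported and re-imported around a checkpoint, mirroring what `train`
now does internally. The environment id "my_task_settable_env" is hypothetical, and
`TaskSchedulingSamplingCallback` is assumed here to be constructible with default
arguments.

    from functools import partial

    from ray.rllib.algorithms.ppo import PPOConfig

    from gym_jiminy.rllib.curriculum import TaskSchedulingSamplingCallback
    from gym_jiminy.rllib.utilities import MultiCallbacks

    # Combine several callback classes into a single checkpointable one, as
    # `train` does when a user-provided callback class is already configured.
    algo_config = (
        PPOConfig()
        .environment("my_task_settable_env")  # hypothetical registered env id
        .callbacks(partial(MultiCallbacks, (
            TaskSchedulingSamplingCallback,   # stateful (Checkpointable)
        )))
    )
    algo = algo_config.build()
    algo.train()

    # The combined callback implements the `Checkpointable` API, so its state
    # (notably the current task probability tree) can be exported and restored
    # alongside the algorithm checkpoint, keyed by the position of each wrapped
    # callback in the tuple above.
    state = algo.callbacks.get_state()  # e.g. {"0": {"proba_task_tree": ...}}
    algo.callbacks.set_state(state)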