Commit

[gym_jiminy/rllib] Restore curriculum state from checkpoints.
* [gym_jiminy/rllib] Make callbacks checkpointable.
* [gym_jiminy/rllib] Restore curriculum state from checkpoints.
duburcqa authored Feb 4, 2025
1 parent 72feeaf commit 70939e4
Showing 2 changed files with 290 additions and 124 deletions.
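In short, this commit makes `TaskSchedulingSamplingCallback` implement RLlib's `Checkpointable` interface: `get_state`/`set_state` carry the current probability task tree and the aggregated task statistics, while `get_ctor_args_and_kwargs` exposes what is needed to rebuild an equivalent instance. Below is a minimal sketch of the intended round trip, using only the methods added in this diff; the helper function itself and its docstring are illustrative, not part of the commit.

from gym_jiminy.rllib.curriculum import TaskSchedulingSamplingCallback

def clone_via_checkpoint_state(
        callback: TaskSchedulingSamplingCallback) -> TaskSchedulingSamplingCallback:
    """Rebuild a callback from its checkpointable state (illustration only)."""
    # Constructor arguments (history_length, softmin_beta, score_fn) and the
    # mutable state (probability task tree + task statistics) are kept separate.
    args, kwargs = callback.get_ctor_args_and_kwargs()
    state = callback.get_state()
    restored = TaskSchedulingSamplingCallback(*args, **kwargs)
    restored.set_state(state)
    # `set_state` marks the instance as restored, so the probability tree gets
    # propagated to the managed environments at the start of the next episode.
    return restored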
213 changes: 161 additions & 52 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py
@@ -2,22 +2,27 @@
""" TODO: Write documentation.
"""
import math
from functools import partial
from collections import defaultdict
from typing import (
List, Any, Dict, Tuple, Optional, Callable, DefaultDict, cast)
List, Any, Dict, Tuple, Optional, Callable, DefaultDict, Union, Collection,
cast)

import numpy as np
import gymnasium as gym

from ray.rllib.core.rl_module import RLModule
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.env.env_runner_group import EnvRunnerGroup
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
from ray.rllib.utils.typing import ResultDict, EpisodeID, EpisodeType
from ray.rllib.utils.typing import (
ResultDict, EpisodeID, EpisodeType, StateDict)

from jiminy_py import tree
from gym_jiminy.common.bases import BasePipelineWrapper
@@ -30,7 +35,53 @@
) from e


class TaskSchedulingSamplingCallback(DefaultCallbacks):
def _update_proba_task_tree_from_runner(
env_runner: EnvRunner,
proba_task_tree: ProbaTaskTree) -> None:
"""Update the probability task tree of all the environments being managed
by a given runner.
:param env_runner: Environment runner to consider.
:param proba_task_tree:
Probability tree consistent with the task space of the underlying
environment, which must derive from
`gym_jiminy.toolbox.wrappers.meta_envs.BaseTaskSettableWrapper`.
"""
# FIXME: `set_attr` is buggy on `gymnasium<=1.0` and cannot be used
# reliably in conjunction with `BasePipelineWrapper`.
# See PR: https://github.com/Farama-Foundation/Gymnasium/pull/1294
assert isinstance(env_runner, SingleAgentEnvRunner)
env = env_runner.env.unwrapped
assert isinstance(env, gym.vector.SyncVectorEnv)
for env in env.unwrapped.envs:
while not isinstance(env, BaseTaskSettableWrapper):
assert isinstance(
env, (gym.Wrapper, BasePipelineWrapper))
env = env.env
env.proba_task_tree = proba_task_tree


def _update_proba_task_tree_from_runner_group(
workers: EnvRunnerGroup,
proba_task_tree: ProbaTaskTree) -> None:
"""Update the probability tree for a group of environment runners.
:param workers: Group of environment runners to be updated.
:param proba_task_tree:
Probability tree consistent with the task space of the underlying
environment, which must derive from
`gym_jiminy.toolbox.wrappers.meta_envs.BaseTaskSettableWrapper`.
"""
workers.foreach_worker(partial(
_update_proba_task_tree_from_runner,
proba_task_tree=proba_task_tree))
# workers.foreach_worker(
# lambda worker: worker.env.unwrapped.set_attr(
# 'proba_task_tree',
# (proba_task_tree,) * worker.num_envs))


class TaskSchedulingSamplingCallback(DefaultCallbacks, Checkpointable):
r"""Scheduler that automatically adapt the probability distribution of the
tasks of a task-settable environment in order to maintain the same level of
performance among all the task, no matter if some of them are much harder
@@ -131,14 +182,40 @@ def __init__(self,
# Whether to clear all task metrics at the end of the next episode
self._must_clear_metrics = False

# Whether the internal state of this callback instance has been
# initialized.
# Note that initialization must be postponed because it requires having
# access to attributes of the environments, but they are not available
# at this point.
self._is_initialized = False

# Whether the internal state of this callback has just been restored
# but still must be propagated to all managed environments.
self._is_restored = False

# Maximum number of steps of the episodes.
# This is used to standardize the return between 0.0 and 1.0 (assuming
# the reward is normalized), or at least to make it independent of
# the maximum episode duration.
self._max_num_steps_all: Tuple[int, ...] = ()

# Arbitrarily nested task space
self._task_space = gym.spaces.Tuple([])

# Flattened task space representation for efficiency
self._task_paths: Tuple[TaskPath, ...] = ()
self._task_names: Tuple[str, ...] = ()
self._proba_task_tree: ProbaTaskTree = ()

# Current probability task tree
self.proba_task_tree: ProbaTaskTree = ()

# Flattened probability task tree representation for efficiency
self._proba_task_tree_flat_map: Dict[str, int] = {}

# Use custom logger and aggregate stats
self.stats_logger = MetricsLogger()
self._buffer = MetricsLogger()

def on_environment_created(self,
*,
env_runner: EnvRunner,
@@ -150,14 +227,22 @@ def on_environment_created,
if self._is_initialized:
return

# Add stats proxy at worker-level to ease centralization later on.
# Note that the proxy must NOT be added if it was already done before.
# This happens when no remote workers are available and the local one
# is used for sample collection as a fallback.
if not hasattr(env_runner, "_task_stats_logger"):
env_runner.__dict__["_task_stats_logger"] = self._buffer

# Backup tree information
try:
self._task_space, *_ = env.unwrapped.get_attr("task_space")
self._proba_task_tree, *_ = (
env.unwrapped.get_attr("proba_task_tree"))
except AttributeError as e:
raise RuntimeError("Base environment must be wrapped with "
"`BaseTaskSettableWrapper`.") from e
if not self.proba_task_tree:
self.proba_task_tree, *_ = (
env.unwrapped.get_attr("proba_task_tree"))

# Get the maximum episode duration
self._max_num_steps_all = tuple(
@@ -177,7 +262,7 @@
# Initialize proba task tree flat ordering map
self._proba_task_tree_flat_map = {
"/".join(map(str, path[::2])): i for i, (path, _) in enumerate(
tree.flatten_with_path(self._proba_task_tree))}
tree.flatten_with_path(self.proba_task_tree))}

# The callback is now fully initialized
self._is_initialized = True
@@ -191,6 +276,20 @@ def on_episode_start(self,
env_index: int,
rl_module: RLModule,
**kwargs: Any) -> None:
# Propagate task probability to the environments if not already done.
# FIXME: At this point, the environment managed by the runner at hand
# has already been reset. As a result, the tasks associated with the
# very first episodes to be collected after restoring the state of this
# callback instance would be sampled according to the old probability
# task tree. To address this issue, propagation of the probability tree
# at environment-level after restoring state should be moved to
# `on_episode_created`. However, this callback method requires
# `ray>=2.41`.
if self._is_restored:
_update_proba_task_tree_from_runner(
env_runner, self.proba_task_tree)
self._is_restored = False

# Drop all partial episodes associated with the environment at hand
# when starting a fresh new one since it will never be done anyway.
if env_index in self._ongoing_episodes:
@@ -207,10 +306,9 @@ def on_episode_end(self,
env_index: int,
rl_module: RLModule,
**kwargs: Any) -> None:
# Force clearing all custom metrics at the beginning of every
# sampling iteration. See `MonitorEpisodeCallback.on_episode_end`.
# Clear all custom metrics at the beginning of every sampling iteration
if self._must_clear_metrics:
metrics_logger.stats.pop("task_metrics", None)
self._buffer.reset()
self._must_clear_metrics = False

# Get all the chunks associated with the episode at hand
@@ -232,23 +330,22 @@
task_path = self._task_paths[task_index]
for i in range(len(task_path)):
task_branch = "/".join(map(str, task_path[:(i + 1)]))
metrics_logger.log_value(
("task_metrics", "score", task_branch),
self._buffer.log_value(
("score", task_branch),
score,
reduce="mean",
window=self.history_length,
clear_on_reduce=False)
metrics_logger.log_value(
("task_metrics", "num", task_branch), 1, reduce="sum")
self._buffer.log_value(
("num", task_branch), 1, reduce="sum")

def on_sample_end(self,
*,
env_runner: EnvRunner,
metrics_logger: MetricsLogger,
samples: List[EpisodeType],
**kwargs: Any) -> None:
# Store all the partial episodes that did not reached done yet.
# See `MonitorEpisodeCallback.on_episode_end`.
# Store all the partial episodes that have not reached done yet
for episode in samples:
if episode.is_done:
continue
@@ -271,9 +368,25 @@ def on_train_result(self,
env=algorithm.env_runner.env,
env_context=EnvContext({}, worker_index=0))

# Extract from metrics mean task scores aggregated across runners
metrics = result[ENV_RUNNER_RESULTS]
task_metrics = metrics.setdefault("task_metrics", {})
# Centralize reduced task statistics across all remote workers.
# Note that it is necessary to also include the local worker, as it
# would be used for sample collection if there is no remote worker.
workers = algorithm.env_runner_group
assert workers is not None
task_stats_all = workers.foreach_worker(
lambda worker: worker._task_stats_logger.reduce())
self.stats_logger.merge_and_log_n_dicts(task_stats_all)
task_stats = self.stats_logger.reduce()

# Extract task metrics and aggregate them with the results
task_metrics = {
key: {
task_path: stat.peek()
for task_path, stat in task_stats_group.items()}
for key, task_stats_group in task_stats.items()}
result[ENV_RUNNER_RESULTS]["task_metrics"] = task_metrics

# Extract mean task scores aggregated across runners
score_task_metrics = task_metrics.get("score", {})

# Re-order flat task tree and complete missing data with nan
@@ -285,7 +398,7 @@

# Unflatten mean task score tree
score_task_tree = tree.unflatten_as(
self._proba_task_tree, score_task_tree_flat)
self.proba_task_tree, score_task_tree_flat)

# Compute the probability tree
proba_task_tree: ProbaTaskTree = []
@@ -328,37 +441,9 @@ def on_train_result(self,
score_and_proba_task_branches.append((
score_task_branch_, proba_task_branch_, space))

# Update the probability tree at runner-level.
# FIXME: `set_attr` is buggy on`gymnasium<=1.0` and cannot be used
# reliability in conjunction with `BasePipelineWrapper`.
# See PR: https://github.com/Farama-Foundation/Gymnasium/pull/1294
self._proba_task_tree = proba_task_tree
workers = algorithm.env_runner_group
assert workers is not None

def _update_runner_proba_task_tree(
env_runner: EnvRunner) -> None:
"""Update the probability task tree of all the environments
being managed by a given runner.
:param env_runner: Environment runner to consider.
"""
nonlocal proba_task_tree
assert isinstance(env_runner, SingleAgentEnvRunner)
env = env_runner.env.unwrapped
assert isinstance(env, gym.vector.SyncVectorEnv)
for env in env.unwrapped.envs:
while not isinstance(env, BaseTaskSettableWrapper):
assert isinstance(
env, (gym.Wrapper, BasePipelineWrapper))
env = env.env
env.proba_task_tree = proba_task_tree

workers.foreach_worker(_update_runner_proba_task_tree)
# workers.foreach_worker(
# lambda worker: worker.env.unwrapped.set_attr(
# 'proba_task_tree',
# (proba_task_tree,) * worker.num_envs))
# Update the probability tree at runner-level
_update_proba_task_tree_from_runner_group(workers, proba_task_tree)
self.proba_task_tree = proba_task_tree

# Compute flattened probability tree
proba_task_tree_flat: List[float] = []
@@ -385,12 +470,36 @@ def _update_runner_proba_task_tree(
for path in self._task_paths:
num_task_metrics.setdefault("/".join(map(str, path)), 0)

# Filter out all non-leaf metrics to avoid cluttering plots
# Filter out all non-leaf task metrics to avoid cluttering plots
for data in (score_task_metrics, num_task_metrics):
for key in tuple(data.keys()):
if key not in self._task_names:
del data[key]

def get_state(self,
components: Optional[Union[str, Collection[str]]] = None,
*,
not_components: Optional[Union[str, Collection[str]]] = None,
**kwargs: Any) -> StateDict:
return dict(
proba_task_tree=self.proba_task_tree,
stats_logger=self.stats_logger.get_state())

def set_state(self, state: StateDict) -> None:
self.proba_task_tree = state["proba_task_tree"]
self.stats_logger.set_state(state["stats_logger"])
self._is_restored = True

def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
return (
(), # *args
dict( # **kwargs
history_length=self.history_length,
softmin_beta=self.softmin_beta,
score_fn=self.score_fn
)
)


__all__ = [
"TaskSchedulingSamplingCallback",
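For context, the callback is meant to be used like any other RLlib callback, with the curriculum state now travelling with the algorithm checkpoint. The sketch below is a rough usage outline, not taken from this commit: it assumes a recent RLlib (the diff mentions `ray>=2.41`), the environment id and save directory are placeholders, constructor parameters of the callback are assumed to have usable defaults (otherwise they would need to be bound, e.g. with `functools.partial`), and the wiring that actually stores the callback state inside the algorithm checkpoint lives in the second changed file, which is not shown above.

from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.ppo import PPOConfig

from gym_jiminy.rllib.curriculum import TaskSchedulingSamplingCallback

# Placeholder environment id: any environment wrapped with
# `BaseTaskSettableWrapper`, i.e. exposing `task_space` and `proba_task_tree`.
config = (
    PPOConfig()
    .environment("MyTaskSettableEnv-v0")
    .callbacks(TaskSchedulingSamplingCallback)
)
algo = config.build()
algo.train()

# Save a checkpoint; with this commit, the task probability tree and task
# statistics of the curriculum callback are part of the saved state.
checkpoint_path = algo.save_to_path("/tmp/curriculum_checkpoint")

# Resume later: the restored callback picks up the curriculum where it
# left off instead of restarting from the initial task distribution.
restored = Algorithm.from_checkpoint(checkpoint_path)
restored.train()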