[gym_jiminy/rllib] Restore curriculum state from checkpoints. #878

Merged · 1 commit · Feb 4, 2025
213 changes: 161 additions & 52 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/curriculum.py
@@ -2,22 +2,27 @@
""" TODO: Write documentation.
"""
import math
from functools import partial
from collections import defaultdict
from typing import (
- List, Any, Dict, Tuple, Optional, Callable, DefaultDict, cast)
List, Any, Dict, Tuple, Optional, Callable, DefaultDict, Union, Collection,
cast)

import numpy as np
import gymnasium as gym

from ray.rllib.core.rl_module import RLModule
from ray.rllib.env.env_context import EnvContext
from ray.rllib.env.env_runner import EnvRunner
from ray.rllib.env.env_runner_group import EnvRunnerGroup
from ray.rllib.env.single_agent_env_runner import SingleAgentEnvRunner
from ray.rllib.algorithms.algorithm import Algorithm
from ray.rllib.algorithms.callbacks import DefaultCallbacks
from ray.rllib.utils.checkpoints import Checkpointable
from ray.rllib.utils.metrics import ENV_RUNNER_RESULTS
from ray.rllib.utils.metrics.metrics_logger import MetricsLogger
- from ray.rllib.utils.typing import ResultDict, EpisodeID, EpisodeType
from ray.rllib.utils.typing import (
ResultDict, EpisodeID, EpisodeType, StateDict)

from jiminy_py import tree
from gym_jiminy.common.bases import BasePipelineWrapper
@@ -30,7 +35,53 @@
) from e


- class TaskSchedulingSamplingCallback(DefaultCallbacks):
def _update_proba_task_tree_from_runner(
env_runner: EnvRunner,
proba_task_tree: ProbaTaskTree) -> None:
"""Update the probability task tree of all the environments being managed
by a given runner.

:param env_runner: Environment runner to consider.
:param proba_task_tree:
Probability tree consistent with the task space of the underlying
environment, which must derive from
`gym_jiminy.toolbox.wrappers.meta_envs.BaseTaskSettableWrapper`.
"""
# FIXME: `set_attr` is buggy on `gymnasium<=1.0` and cannot be used
# reliably in conjunction with `BasePipelineWrapper`.
# See PR: https://github.com/Farama-Foundation/Gymnasium/pull/1294
assert isinstance(env_runner, SingleAgentEnvRunner)
env = env_runner.env.unwrapped
assert isinstance(env, gym.vector.SyncVectorEnv)
for env in env.unwrapped.envs:
while not isinstance(env, BaseTaskSettableWrapper):
assert isinstance(
env, (gym.Wrapper, BasePipelineWrapper))
env = env.env
env.proba_task_tree = proba_task_tree


def _update_proba_task_tree_from_runner_group(
workers: EnvRunnerGroup,
proba_task_tree: ProbaTaskTree) -> None:
"""Update the probability tree for a group of environment runners.

:param workers: Group of environment runners to be updated.
:param proba_task_tree:
Probability tree consistent with the task space of the underlying
environment, which must derive from
`gym_jiminy.toolbox.wrappers.meta_envs.BaseTaskSettableWrapper`.
"""
workers.foreach_worker(partial(
_update_proba_task_tree_from_runner,
proba_task_tree=proba_task_tree))
# workers.foreach_worker(
# lambda worker: worker.env.unwrapped.set_attr(
# 'proba_task_tree',
# (proba_task_tree,) * worker.num_envs))
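As an aside, the unwrap-until-target loop above can be illustrated on a plain `gymnasium` wrapper chain. The following is a hedged, self-contained sketch: `TaskWrapper` is a hypothetical stand-in for `BaseTaskSettableWrapper` and is not part of this PR.

import gymnasium as gym

class TaskWrapper(gym.Wrapper):
    """Hypothetical wrapper exposing a settable task distribution."""

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        self.proba_task_tree: tuple = ()

env = gym.wrappers.TimeLimit(TaskWrapper(gym.make("CartPole-v1")), 100)
wrapped = env
while not isinstance(wrapped, TaskWrapper):
    assert isinstance(wrapped, gym.Wrapper)
    wrapped = wrapped.env  # peel one wrapper layer
wrapped.proba_task_tree = (0.5, 0.5)  # update the target wrapper in place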


class TaskSchedulingSamplingCallback(DefaultCallbacks, Checkpointable):
r"""Scheduler that automatically adapt the probability distribution of the
tasks of a task-settable environment in order to maintain the same level of
performance among all the task, no matter if some of them are much harder
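To make the idea concrete, here is a minimal sketch of the softmin weighting such a scheduler can apply, assuming per-task scores normalized to [0.0, 1.0]; this is an illustrative reconstruction, not the exact update rule hidden in the collapsed part of the diff.

import numpy as np

def softmin_probas(scores: np.ndarray, beta: float) -> np.ndarray:
    """Map per-task scores to sampling probabilities, giving lower-scoring
    (harder) tasks a larger probability mass."""
    logits = -beta * scores
    logits = logits - logits.max()  # shift for numerical stability
    weights = np.exp(logits)
    return weights / weights.sum()

# With beta=5.0, scores (0.9, 0.5, 0.1) map to roughly (0.016, 0.117, 0.867),
# so the hardest task is sampled most often.
print(softmin_probas(np.array([0.9, 0.5, 0.1]), beta=5.0))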
Expand Down Expand Up @@ -131,14 +182,40 @@ def __init__(self,
# Whether to clear all task metrics at the end of the next episode
self._must_clear_metrics = False

# Whether the internal state of this callback instance has been
# initialized.
# Note that initialization must be postponed because it requires
# access to attributes of the environments, which are not available
# at this point.
self._is_initialized = False

# Whether the internal state of this callback has just been restored
# but has not yet been propagated to all managed environments.
self._is_restored = False

# Maximum number of steps of the episodes.
# This is used to standardize the return between 0.0 and 1.0 (assuming
# the reward is normalized), or at least to make it independent of the
# maximum episode duration. For instance, a return of 80 over an episode
# capped at 200 steps maps to a score of 0.4.
self._max_num_steps_all: Tuple[int, ...] = ()

# Arbitrarily nested task space
self._task_space = gym.spaces.Tuple([])

# Flattened task space representation for efficiency
self._task_paths: Tuple[TaskPath, ...] = ()
self._task_names: Tuple[str, ...] = ()
- self._proba_task_tree: ProbaTaskTree = ()

# Current probability task tree
self.proba_task_tree: ProbaTaskTree = ()

# Flattened probability task tree representation for efficiency
self._proba_task_tree_flat_map: Dict[str, int] = {}

# Use custom logger and aggregate stats
self.stats_logger = MetricsLogger()
self._buffer = MetricsLogger()

def on_environment_created(self,
*,
env_runner: EnvRunner,
@@ -150,14 +227,22 @@ def on_environment_created(self,
if self._is_initialized:
return

# Add stats proxy at worker-level to ease centralization later on.
# Note that the proxy must NOT be added if it was already done before,
# which is the case when no remote workers are available and the local
# one is used for sample collection as a fallback.
if not hasattr(env_runner, "_task_stats_logger"):
env_runner.__dict__["_task_stats_logger"] = self._buffer

# Backup tree information
try:
self._task_space, *_ = env.unwrapped.get_attr("task_space")
- self._proba_task_tree, *_ = (
-     env.unwrapped.get_attr("proba_task_tree"))
except AttributeError as e:
raise RuntimeError("Base environment must be wrapped with "
"`BaseTaskSettableWrapper`.") from e
if not self.proba_task_tree:
self.proba_task_tree, *_ = (
env.unwrapped.get_attr("proba_task_tree"))

# Get the maximum episode duration
self._max_num_steps_all = tuple(
@@ -177,7 +262,7 @@
# Initialize proba task tree flat ordering map
self._proba_task_tree_flat_map = {
"/".join(map(str, path[::2])): i for i, (path, _) in enumerate(
- tree.flatten_with_path(self._proba_task_tree))}
tree.flatten_with_path(self.proba_task_tree))}

# The callback is now fully initialized
self._is_initialized = True
@@ -191,6 +276,20 @@ def on_episode_start(self,
env_index: int,
rl_module: RLModule,
**kwargs: Any) -> None:
# Propagate task probability to the environments if not already done.
# FIXME: At this point, the environment managed by the runner at hand
# has already been reset. As a result, the tasks associated with the
# very first episodes to be collected after restoring the state of this
# callback instance would be sampled according to the old probability
# task tree. To address this issue, propagation of the probability tree
# at environment-level after restoring state should be moved to
# `on_episode_created`. However, this callback method is only available
# from `ray>=2.41` onward.
if self._is_restored:
_update_proba_task_tree_from_runner(
env_runner, self.proba_task_tree)
self._is_restored = False

# Drop all partial episodes associated with the environment at hand
# when starting a fresh one, since they will never be done anyway.
if env_index in self._ongoing_episodes:
@@ -207,10 +306,9 @@ def on_episode_end(self,
env_index: int,
rl_module: RLModule,
**kwargs: Any) -> None:
- # Force clearing all custom metrics at the beginning of every
- # sampling iteration. See `MonitorEpisodeCallback.on_episode_end`.
# Clear all custom metrics at the beginning of every sampling iteration
if self._must_clear_metrics:
- metrics_logger.stats.pop("task_metrics", None)
self._buffer.reset()
self._must_clear_metrics = False

# Get all the chunks associated with the episode at hand
@@ -232,23 +330,22 @@ def on_episode_end(self,
task_path = self._task_paths[task_index]
for i in range(len(task_path)):
task_branch = "/".join(map(str, task_path[:(i + 1)]))
- metrics_logger.log_value(
-     ("task_metrics", "score", task_branch),
self._buffer.log_value(
("score", task_branch),
score,
reduce="mean",
window=self.history_length,
clear_on_reduce=False)
- metrics_logger.log_value(
-     ("task_metrics", "num", task_branch), 1, reduce="sum")
self._buffer.log_value(
("num", task_branch), 1, reduce="sum")

def on_sample_end(self,
*,
env_runner: EnvRunner,
metrics_logger: MetricsLogger,
samples: List[EpisodeType],
**kwargs: Any) -> None:
- # Store all the partial episodes that did not reached done yet.
- # See `MonitorEpisodeCallback.on_episode_end`.
# Store all the partial episodes that have not reached done yet
for episode in samples:
if episode.is_done:
continue
@@ -271,9 +368,25 @@ def on_train_result(self,
env=algorithm.env_runner.env,
env_context=EnvContext({}, worker_index=0))

- # Extract from metrics mean task scores aggregated across runners
- metrics = result[ENV_RUNNER_RESULTS]
- task_metrics = metrics.setdefault("task_metrics", {})
# Centralize reduced task statistics across all workers.
# Note that it is necessary to also include the local worker, as it
# would be used for sample collection if there is no remote worker.
workers = algorithm.env_runner_group
assert workers is not None
task_stats_all = workers.foreach_worker(
lambda worker: worker._task_stats_logger.reduce())
self.stats_logger.merge_and_log_n_dicts(task_stats_all)
task_stats = self.stats_logger.reduce()

# Extract task metrics and aggregate them with the results
task_metrics = {
key: {
task_path: stat.peek()
for task_path, stat in task_stats_group.items()}
for key, task_stats_group in task_stats.items()}
result[ENV_RUNNER_RESULTS]["task_metrics"] = task_metrics

# Extract mean task scores aggregated across runners
score_task_metrics = task_metrics.get("score", {})

# Re-order flat task tree and complete missing data with nan
@@ -285,7 +398,7 @@

# Unflatten mean task score tree
score_task_tree = tree.unflatten_as(
- self._proba_task_tree, score_task_tree_flat)
self.proba_task_tree, score_task_tree_flat)

# Compute the probability tree
proba_task_tree: ProbaTaskTree = []
@@ -328,37 +441,9 @@ def on_train_result(self,
score_and_proba_task_branches.append((
score_task_branch_, proba_task_branch_, space))

- # Update the probability tree at runner-level.
- # FIXME: `set_attr` is buggy on `gymnasium<=1.0` and cannot be used
- # reliably in conjunction with `BasePipelineWrapper`.
- # See PR: https://github.com/Farama-Foundation/Gymnasium/pull/1294
- self._proba_task_tree = proba_task_tree
- workers = algorithm.env_runner_group
- assert workers is not None
-
- def _update_runner_proba_task_tree(
-         env_runner: EnvRunner) -> None:
-     """Update the probability task tree of all the environments
-     being managed by a given runner.
-
-     :param env_runner: Environment runner to consider.
-     """
-     nonlocal proba_task_tree
-     assert isinstance(env_runner, SingleAgentEnvRunner)
-     env = env_runner.env.unwrapped
-     assert isinstance(env, gym.vector.SyncVectorEnv)
-     for env in env.unwrapped.envs:
-         while not isinstance(env, BaseTaskSettableWrapper):
-             assert isinstance(
-                 env, (gym.Wrapper, BasePipelineWrapper))
-             env = env.env
-         env.proba_task_tree = proba_task_tree
-
- workers.foreach_worker(_update_runner_proba_task_tree)
- # workers.foreach_worker(
- #     lambda worker: worker.env.unwrapped.set_attr(
- #         'proba_task_tree',
- #         (proba_task_tree,) * worker.num_envs))
# Update the probability tree at runner-level
_update_proba_task_tree_from_runner_group(workers, proba_task_tree)
self.proba_task_tree = proba_task_tree

# Compute flattened probability tree
proba_task_tree_flat: List[float] = []
@@ -385,12 +470,36 @@ def _update_runner_proba_task_tree(
for path in self._task_paths:
num_task_metrics.setdefault("/".join(map(str, path)), 0)

- # Filter out all non-leaf metrics to avoid cluttering plots
# Filter out all non-leaf task metrics to avoid cluttering plots
for data in (score_task_metrics, num_task_metrics):
for key in tuple(data.keys()):
if key not in self._task_names:
del data[key]

def get_state(self,
components: Optional[Union[str, Collection[str]]] = None,
*,
not_components: Optional[Union[str, Collection[str]]] = None,
**kwargs: Any) -> StateDict:
return dict(
proba_task_tree=self.proba_task_tree,
stats_logger=self.stats_logger.get_state())

def set_state(self, state: StateDict) -> None:
self.proba_task_tree = state["proba_task_tree"]
self.stats_logger.set_state(state["stats_logger"])
self._is_restored = True

def get_ctor_args_and_kwargs(self) -> Tuple[Tuple, Dict[str, Any]]:
return (
(), # *args
dict( # **kwargs
history_length=self.history_length,
softmin_beta=self.softmin_beta,
score_fn=self.score_fn
)
)
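Together these three methods satisfy the `Checkpointable` interface. A minimal round-trip sketch, assuming `gym_jiminy.rllib` is importable and using the constructor arguments suggested by `get_ctor_args_and_kwargs`:

from gym_jiminy.rllib.curriculum import TaskSchedulingSamplingCallback

callback = TaskSchedulingSamplingCallback(history_length=100)
state = callback.get_state()  # {'proba_task_tree': ..., 'stats_logger': ...}

restored = TaskSchedulingSamplingCallback(history_length=100)
restored.set_state(state)  # also sets `_is_restored`, so the probability
                           # tree is re-propagated on the next episode start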


__all__ = [
"TaskSchedulingSamplingCallback",
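Finally, a hedged sketch of how a callback like this is typically wired into an RLlib training run; the environment id is a placeholder and this wiring is not shown in the PR:

from ray.rllib.algorithms.ppo import PPOConfig

from gym_jiminy.rllib.curriculum import TaskSchedulingSamplingCallback

config = (
    PPOConfig()
    .environment("my_task_settable_env")  # placeholder: a registered env id
    .callbacks(TaskSchedulingSamplingCallback)
)
algo = config.build()
algo.train()  # with this PR, the callback's curriculum state can be
              # saved and restored along with algorithm checkpoints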