Skip to content

Commit

Permalink
[gym_jiminy/rllib] Fix eval runner connector states. (#870)
Browse files Browse the repository at this point in the history
* [gym_jiminy/common] Fix 'MechanicalPowerConsumption' not exposed. 
* [gym_jiminy/common] Add support of 'horizon=None' to MechanicalPowerConsumptionTermination.
* [gym_jiminy/rllib] Rename 'build_eval_runner_from_checkpoint' in 'build_runner_from_checkpoint'.
* [gym_jiminy/rllib] Fix eval runner connector stats not sync w/ eval. 
* [gym_jiminy/rllib] Fix eval runner connector updated during sample collection.
* [misc] Add pipeline benchmark script.
* [misc] Fix Sphinx doc (github-pages) broken refs.
  • Loading branch information
duburcqa authored Jan 23, 2025
1 parent fe51034 commit faea5bd
Show file tree
Hide file tree
Showing 12 changed files with 141 additions and 30 deletions.
2 changes: 0 additions & 2 deletions docs/api/gym_jiminy/common/compositions/generic.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,5 @@ Generic

.. automodule:: gym_jiminy.common.compositions.generic
:members:
:undoc-members:
:private-members:
:inherited-members:
:show-inheritance:
1 change: 0 additions & 1 deletion docs/api/gym_jiminy/common/compositions/locomotion.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,5 @@ Locomotion

.. automodule:: gym_jiminy.common.compositions.locomotion
:members:
:undoc-members:
:private-members:
:show-inheritance:
2 changes: 2 additions & 0 deletions docs/api/gym_jiminy/common/quantities/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,6 @@ Quantities
:maxdepth: 1

generic
manager
transform
locomotion
2 changes: 0 additions & 2 deletions docs/api/gym_jiminy/common/quantities/locomotion.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,4 @@ Locomotion

.. automodule:: gym_jiminy.common.quantities.locomotion
:members:
:undoc-members:
:private-members:
:show-inheritance:
6 changes: 6 additions & 0 deletions docs/api/gym_jiminy/common/quantities/manager.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Manager
=======

.. automodule:: gym_jiminy.common.quantities.manager
:members:
:show-inheritance:
6 changes: 6 additions & 0 deletions docs/api/gym_jiminy/common/quantities/transform.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
Transform
=========

.. automodule:: gym_jiminy.common.quantities.transform
:members:
:show-inheritance:
2 changes: 1 addition & 1 deletion docs/api/gym_jiminy/common/wrappers/observation_layout.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Observation Layout Adaptation
=============================

.. automodule:: gym_jiminy.common.wrappers.observation_filter
.. automodule:: gym_jiminy.common.wrappers.observation_layout
:members:
:show-inheritance:
25 changes: 19 additions & 6 deletions python/gym_jiminy/common/gym_jiminy/common/compositions/generic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
from ..bases.compositions import ArrayOrScalar, ArrayLikeOrScalar, Number
from ..quantities import (
EnergyGenerationMode, StackedQuantity, UnaryOpQuantity, BinaryOpQuantity,
MultiActuatedJointKinematic, AverageMechanicalPowerConsumption)
MultiActuatedJointKinematic, MechanicalPowerConsumption,
AverageMechanicalPowerConsumption)

from .mixin import radial_basis_function

Expand Down Expand Up @@ -591,7 +592,7 @@ def __init__(
self,
env: InterfaceJiminyEnv,
max_power: float,
horizon: float,
horizon: Optional[float] = None,
generator_mode: EnergyGenerationMode = EnergyGenerationMode.CHARGE,
grace_period: float = 0.0,
*,
Expand All @@ -606,7 +607,9 @@ def __init__(
to continue whatever happens.
Optional: 0.0 by default.
:param horizon: Horizon over which values of the quantity will be
stacked before computing the average.
stacked before computing the average. `None` to
consider the instantaneous power consumption.
Optional: `None` by default.
:param training_only: Whether the termination condition should be
completely by-passed if the environment is in
evaluation mode.
Expand All @@ -617,13 +620,23 @@ def __init__(
self.horizon = horizon
self.generator_mode = generator_mode

# Pick the right quantity creator depending on the horizon
quantity_creator: QuantityCreator
if horizon is None:
quantity_creator = (AverageMechanicalPowerConsumption, dict(
horizon=self.horizon,
generator_mode=self.generator_mode,
mode=QuantityEvalMode.TRUE))
else:
quantity_creator = (MechanicalPowerConsumption, dict(
generator_mode=self.generator_mode,
mode=QuantityEvalMode.TRUE))

# Call base implementation
super().__init__(
env,
"termination_power_consumption",
(AverageMechanicalPowerConsumption, dict( # type: ignore[arg-type]
horizon=self.horizon,
generator_mode=self.generator_mode)),
quantity_creator,
None,
self.max_power,
grace_period,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
FrameOrientation,
FramePosition,
FrameXYZQuat,
MechanicalPowerConsumption,
MultiFrameOrientation,
MultiFramePosition,
MultiFrameXYZQuat,
Expand Down Expand Up @@ -54,6 +55,7 @@
'FramePosition',
'FrameXYZQuat',
'ReferencePositionWithTrueOdometryPose',
'MechanicalPowerConsumption',
'MultiFramePosition',
'MultiFrameOrientation',
'MultiFrameXYZQuat',
Expand Down
52 changes: 52 additions & 0 deletions python/gym_jiminy/examples/pipeline_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
import os

os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["NUMBA_NUM_THREADS"] = "1"

import time
from functools import reduce, partial

from gym_jiminy.envs import AtlasPDControlJiminyEnv
from gym_jiminy.common.wrappers import (
FilterObservation,
NormalizeAction,
NormalizeObservation,
FlattenAction,
FlattenObservation)

env = reduce(
lambda env, wrapper: wrapper(env), (
FlattenObservation,
FlattenAction,
partial(NormalizeObservation, ignore_unbounded=True),
NormalizeAction
), FilterObservation(
AtlasPDControlJiminyEnv(
# std_ratio={'disturbance': 4.0},
debug=False
),
nested_filter_keys=(
# 't',
# ('states', 'agent', 'q'),
# ('states', 'agent', 'v'),
("states", "pd_controller"),
# ('states', 'mahony_filter'),
# ('measurements', 'ImuSensor'),
# ('measurements', 'ForceSensor'),
("measurements", "EncoderSensor"),
("features", "mahony_filter"),
),
),
)

# Run in 30.7s on jiminy==1.8.11 (29.7s with PGO, 28.4s with eigen-dev)
env.reset()
action = env.action
time_start = time.time()
for _ in range(100000):
env.step(action)
print("time elapsed:", time.time() - time_start)
4 changes: 2 additions & 2 deletions python/gym_jiminy/examples/rllib/acrobot_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from gym_jiminy.rllib.utilities import (initialize,
train,
evaluate_from_runner,
build_eval_runner_from_checkpoint,
build_runner_from_checkpoint,
build_module_from_checkpoint,
build_module_wrapper)

Expand Down Expand Up @@ -267,7 +267,7 @@

# Build a standalone local evaluation worker (not requiring ray backend)
register_env("env", env_creator)
env_runner = build_eval_runner_from_checkpoint(checkpoint_path)
env_runner = build_runner_from_checkpoint(checkpoint_path)
evaluate_from_runner(env_runner,
num_episodes=1,
close_backend=True,
Expand Down
67 changes: 51 additions & 16 deletions python/gym_jiminy/rllib/gym_jiminy/rllib/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -744,6 +744,23 @@ def train(algo_config: AlgorithmConfig,
sort_keys=True,
cls=SafeFallbackEncoder)

# Disable connector update for evaluation runner
def disable_update_connectors(env_runner: EnvRunner) -> None:
"""Internal helper to disable automatic update of statistics (mean,
std) when collecting samples used by MeanStdFilter to empirically
normalized the observation.
:param env_runner: Environment runner to consider.
"""
assert isinstance(env_runner, SingleAgentEnvRunner)
for connector in env_runner._env_to_module:
if isinstance(connector, MeanStdFilter):
connector._update_stats = False
break

if algo.eval_env_runner_group is not None:
algo.eval_env_runner_group.foreach_worker(disable_update_connectors)

# Monitor memory allocations to detect leaks if any
if debug:
tracemalloc.start(10)
Expand All @@ -756,6 +773,24 @@ def train(algo_config: AlgorithmConfig,
# Perform one iteration of training the policy
result = algo.train()

# Synchronize evaluation connectors with training connectors
if algo.eval_env_runner_group is not None:
def sync_connectors(state_connectors: Dict[str, Any],
env_runner: EnvRunner) -> None:
"""Internal helper to synchronise all the env-to-module
connectors of a given runner with a given state.
:param state_connectors: Expected state of the connectors
after synchronization.
:param env_runner: Environment runner to consider.
"""
assert isinstance(env_runner, SingleAgentEnvRunner)
env_runner._env_to_module.set_state(state_connectors)

algo.eval_env_runner_group.foreach_worker(partial(
sync_connectors,
algo.env_runner._env_to_module.get_state()))

# Log results
num_timesteps = result[NUM_ENV_STEPS_SAMPLED_LIFETIME]
if file_writer is not None:
Expand Down Expand Up @@ -1406,8 +1441,9 @@ def evaluate_from_runner(env_runner: EnvRunner,
return all_episodes, all_log_paths


def build_eval_runner_from_checkpoint(checkpoint_path: str) -> EnvRunner:
"""Build a local evaluation runner from a checkpoint generated by calling
def build_runner_from_checkpoint(checkpoint_path: str,
is_eval_runner: bool = True) -> EnvRunner:
"""Build a local runner from a checkpoint generated by calling
`algo.save()` during training of the policy.
This local env runner can then be passed to `evaluate_from_runner` for
Expand All @@ -1421,27 +1457,26 @@ def build_eval_runner_from_checkpoint(checkpoint_path: str) -> EnvRunner:
prior to calling this method, otherwise it will raise an exception.
:param checkpoint_path: Checkpoint directory to be restored.
:param is_eval_runner: Whether to restore the evaluation runner in place
of the training one.
Optional: True by default.
"""
# Restore evaluation runner
eval_runner_checkpoint_path = Path(checkpoint_path) / "eval_env_runner"
# Restore runner
env_runner_checkpoint_path = Path(checkpoint_path) / (
"eval_env_runner" if is_eval_runner else "env_runner")
class_and_ctor_args_fullpath = (
eval_runner_checkpoint_path / "class_and_ctor_args.pkl")
env_runner_checkpoint_path / "class_and_ctor_args.pkl")
with open(class_and_ctor_args_fullpath, "rb") as f:
ctor_info = pickle.load(f)
ctor = ctor_info["class"]
env_runner = ctor.from_checkpoint(eval_runner_checkpoint_path)

# Restore trained RLModule
rl_module = RLModule.from_checkpoint(
Path(checkpoint_path) / COMPONENT_LEARNER_GROUP / COMPONENT_LEARNER /
COMPONENT_RL_MODULE / DEFAULT_MODULE_ID)
env_runner = ctor.from_checkpoint(env_runner_checkpoint_path)

# Sync the weights from the learner to the evaluation runner.
# Sync the weights from the learner to the runner.
# Note that it is necessary to load the learner module because weights are
# not up-to-date at runner-level.
rl_module = RLModule.from_checkpoint(
Path(checkpoint_path) / COMPONENT_LEARNER_GROUP / COMPONENT_LEARNER /
COMPONENT_RL_MODULE / DEFAULT_MODULE_ID)
Path(checkpoint_path) / COMPONENT_LEARNER_GROUP /
COMPONENT_LEARNER / COMPONENT_RL_MODULE / DEFAULT_MODULE_ID)
env_runner.set_state({COMPONENT_RL_MODULE: rl_module.get_state()})

return env_runner
Expand Down Expand Up @@ -1472,7 +1507,7 @@ def build_module_from_checkpoint(checkpoint_path: str) -> RLModule:
"""
# Restore a complete runner instead of just the policy, in order to perform
# checks regarding the pre- and post- processing of the policy.
env_runner = build_eval_runner_from_checkpoint(checkpoint_path)
env_runner = build_runner_from_checkpoint(checkpoint_path)
config = env_runner.config

# Assert(s) for type checker
Expand Down Expand Up @@ -1627,7 +1662,7 @@ def forward(obs: Obs,
"sample_from_runner_group",
"evaluate_from_algo",
"evaluate_from_runner",
"build_eval_runner_from_checkpoint",
"build_runner_from_checkpoint",
"build_module_from_checkpoint",
"build_module_wrapper",
]

0 comments on commit faea5bd

Please sign in to comment.