Skip to content

Commit

Permalink
[gym_jiminy/common] Fix cache auto-refresh (again). Fix 'DeltaQuantit…
Browse files Browse the repository at this point in the history
…y' equality.
  • Loading branch information
duburcqa committed Feb 8, 2025
1 parent cbc9f72 commit be7113f
Show file tree
Hide file tree
Showing 9 changed files with 137 additions and 69 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ jobs:
sudo apt update
sudo apt install -y gdb gnupg curl wget build-essential cmake doxygen graphviz texlive-latex-base
"${PYTHON_EXECUTABLE}" -m pip install setuptools wheel "pip>=20.3"
"${PYTHON_EXECUTABLE}" -m pip install setuptools wheel "pip>=20.3,<25.1"
git config --global advice.detachedHead false
- name: Install pre-compiled binaries for additional gym-jiminy dependencies
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/macos.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ jobs:
echo "JIMINY_PANDA3D_FORCE_TINYDISPLAY=" >> $GITHUB_ENV
fi
"${PYTHON_EXECUTABLE}" -m pip install setuptools wheel "pip>=20.3"
"${PYTHON_EXECUTABLE}" -m pip install setuptools wheel "pip>=20.3,<25.1"
"${PYTHON_EXECUTABLE}" -m pip install delocate twine
- name: Install pre-compiled binaries for additional gym-jiminy dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/manylinux.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
echo "RootDir=${GITHUB_WORKSPACE}" >> $GITHUB_ENV
echo "InstallDir=${GITHUB_WORKSPACE}/install" >> $GITHUB_ENV
"${PYTHON_EXECUTABLE}" -m pip install setuptools wheel "pip>=20.3"
"${PYTHON_EXECUTABLE}" -m pip install setuptools wheel "pip>=20.3,<25.1"
"${PYTHON_EXECUTABLE}" -m pip install twine cmake
- name: Install latest numpy version at build-time for run-time binary compatibility
run: |
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/win.yml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ jobs:
- name: Setup minimal build environment
run: |
git config --global advice.detachedHead false
python -m pip install setuptools wheel "pip>=20.3"
python -m pip install setuptools wheel "pip>=20.3,<25.1"
python -m pip install pefile machomachomangler
- name: Install pre-compiled binaries for additional gym-jiminy dependencies
run: |
Expand Down
2 changes: 1 addition & 1 deletion build_tools/easy_install_deps_ubuntu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ echo "-- Python writable site-packages: ${PYTHON_SITELIB}"
# Install Python 3 standard utilities
apt update && \
apt install -y python3-pip && \
${SUDO_CMD} python3 -m pip install setuptools wheel "pip>=20.3" && \
${SUDO_CMD} python3 -m pip install setuptools wheel "pip>=20.3,<25.1" && \
${SUDO_CMD} python3 -m pip install "numpy>=1.24" "numba>=0.54.0"

# Install standard linux utilities
Expand Down
49 changes: 34 additions & 15 deletions python/gym_jiminy/common/gym_jiminy/common/bases/quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,14 @@ def get(self) -> ValueT:

# Initialize quantity if not already done manually
if not owner._is_initialized:
owner.initialize()
try:
owner.initialize()
except:
# Revert initialization of this quantity as it failed
owner.reset(reset_tracking=False,
ignore_requirements=True,
ignore_others=True)
raise
assert owner._is_initialized

# Get first owning quantity systematically
Expand Down Expand Up @@ -577,7 +584,9 @@ def get(self) -> ValueT:

def reset(self,
reset_tracking: bool = False,
*, ignore_other_instances: bool = False) -> None:
*,
ignore_requirements: bool = False,
ignore_others: bool = False) -> None:
"""Consider that the quantity must be re-initialized before being
evaluated once again.
Expand All @@ -595,9 +604,11 @@ def reset(self,
:param reset_tracking: Do not consider this quantity as active anymore
until the `get` method gets called once again.
Optional: False by default.
:param ignore_other_instances:
Whether to skip reset of intermediary quantities as well as any
shared cache co-owner quantity instances.
:param ignore_requirements:
Whether to skip reset of intermediary quantities.
Optional: False by default.
:param ignore_others:
Whether to ignore any shared cache co-owner quantity instances.
Optional: False by default.
"""
# Make sure that auto-refresh can be honored
Expand All @@ -609,9 +620,14 @@ def reset(self,
# Reset all requirements first.
# This is necessary to avoid auto-refreshing quantities with deprecated
# cache if enabled.
if not ignore_other_instances:
if not ignore_requirements:
for quantity in self.requirements.values():
quantity.reset(reset_tracking, ignore_other_instances=False)
quantity.reset(reset_tracking,
ignore_requirements=False,
ignore_others=ignore_others)

# No longer consider this exact instance as initialized
self._is_initialized = False

# Skip reset if dynamic computation graph update is not allowed
if self.env.is_simulation_running and not self.allow_update_graph:
Expand All @@ -621,27 +637,25 @@ def reset(self,
if reset_tracking:
self._is_active = False

# No longer consider this exact instance as initialized
self._is_initialized = False

# More work must be done if this quantity has a shared cache that
# has not been completely reset yet.
if self.has_cache and self.cache.sm_state is not _IS_RESET:
# Reset shared cache state machine first, to avoid triggering reset
# propagation to all identical quantities.
self.cache.reset(
ignore_auto_refresh=True, reset_state_machine=True)
self.cache.reset(ignore_auto_refresh=True,
reset_state_machine=True)

# Reset all identical quantities except itself since already done
for owner in self.cache.owners:
if owner is not self:
owner.reset(reset_tracking=reset_tracking,
ignore_other_instances=True)
ignore_requirements=True,
ignore_others=True)

# Reset shared cache afterward with auto-refresh enabled if needed
if self.env.is_simulation_running:
self.cache.reset(
ignore_auto_refresh=False, reset_state_machine=False)
self.cache.reset(ignore_auto_refresh=False,
reset_state_machine=False)

def initialize(self) -> None:
"""Initialize internal buffers.
Expand Down Expand Up @@ -793,6 +807,11 @@ def initialize(self) -> None:
try:
self.state.initialize()
except RuntimeError:
# Revert state initialization
self.state.reset(reset_tracking=False,
ignore_requirements=True,
ignore_others=True)

# It may have failed because no simulation running, which may be
# problematic but not blocking at this point. Just checking that
# the pinocchio model has been properly initialized.
Expand Down
139 changes: 91 additions & 48 deletions python/gym_jiminy/common/gym_jiminy/common/quantities/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def __init__(self, env: InterfaceJiminyEnv) -> None:
# Backup user argument(s)
self.env = env

# List of instantiated quantities to manager
# List of managed top-level quantities
self._registry: Dict[str, InterfaceQuantity] = {}

# Initialize shared caches for all managed quantities.
Expand All @@ -64,45 +64,62 @@ def __init__(self, env: InterfaceJiminyEnv) -> None:
# using `hash(dataclasses.astuple(quantity))`. This is clearly not
# unique, as all it requires to be the same is being built from the
# same nested ordered arguments. To get around this issue, we need to
# store (key, value) pairs in a list.
self._caches: List[Tuple[
Tuple[Type[InterfaceQuantity], int], SharedCache]] = []
# store keys in a list.
self._cache_keys: List[Tuple[Type[InterfaceQuantity], int]] = []
self._caches: List[SharedCache] = []

# Instantiate trajectory database.
# Note that this quantity is not added to the global registry to avoid
# exposing directly to the user. This way, it cannot be deleted.
# exposing it directly to the user. This way, it cannot be deleted.
self.trajectory_dataset = cast(
DatasetTrajectoryQuantity, self._build_quantity(
(DatasetTrajectoryQuantity, {})))

def reset(self, reset_tracking: bool = False) -> None:
"""Consider that all managed quantity must be re-initialized before
being able to evaluate them once again.
.. note::
The cache is cleared automatically by the quantities themselves.
.. note::
This method is supposed to be called before starting a simulation.
# Ordered list of all managed quantities including dependencies.
# Note that quantities are ordered from highest to lowest-level
# dependency. This way, intermediate quantities are guaranteed to be
# reset before their parent without having to resort to downstream
# reset propagation at quantity-level. This avoids resetting the same
# quantity multiple times if it is a dependency of multiple quantities
# for which auto-refresh is enabled.
self._quantity_chain = self._get_managed_quantities()

:param reset_tracking: Do not consider any quantity as active anymore.
Optional: False by default.
"""
for quantity in self._registry.values():
quantity.reset(reset_tracking)

def clear(self) -> None:
"""Clear internal cache of quantities to force re-evaluating them the
next time their value is fetched.
def _get_managed_quantities(self) -> Tuple[InterfaceQuantity, ...]:
"""Get the list of all managed quantities including dependencies.
.. note::
This method is supposed to be called every time the state of the
environment has changed (ie either the agent or world itself),
thereby invalidating the value currently stored in cache if any.
This method is not meant to be called manually. It is used
internally to determine in which order quantities should be reset
for optimal efficiency.
"""
ignore_auto_refresh = not self.env.is_simulation_running
for _, cache in self._caches:
cache.reset(ignore_auto_refresh=ignore_auto_refresh)
# Get all dependency branches, sorted from highest to lowest level
quantity_paths = []
quantity_stack: List[Tuple[InterfaceQuantity, ...]] = [
(quantity,) for quantity in (
self.trajectory_dataset, *self._registry.values())]
while quantity_stack:
quantity_path = quantity_stack.pop()
quantities = quantity_path[-1].requirements.values()
if quantities:
for quantity in quantities:
quantity_stack.append((*quantity_path, quantity))
else:
quantity_paths.append(quantity_path)

# Merge each ordered dependencies list in a single ordered chain
quantities_sorted: List[InterfaceQuantity] = []
for quantity_path in quantity_paths:
parent_index = len(quantities_sorted)
for quantity in quantity_path:
for i, quantity_ in tuple(enumerate(
quantities_sorted))[:parent_index][::-1]:
if quantity == quantity_:
parent_index = i
break
else:
assert not quantity in quantities_sorted[parent_index:]
quantities_sorted.insert(parent_index, quantity)
return tuple(quantities_sorted)

def _build_quantity(
self, quantity_creator: QuantityCreator) -> InterfaceQuantity:
Expand Down Expand Up @@ -130,41 +147,31 @@ def _build_quantity(
top_quantity = quantity_cls(self.env, None, **(quantity_kwargs or {}))

# Get the list of all quantities involved in computations of the top
# level quantity, sorted from highest level to lowest level.
quantities_all, quantities_sorted_all = [top_quantity], [top_quantity]
# level quantity, sorted from highest to lowest level.
quantities_all, quantity_path = [top_quantity], [top_quantity]
while quantities_all:
quantities = quantities_all.pop().requirements.values()
quantities_all += quantities
quantities_sorted_all += quantities
quantity_path += quantities

# Set a shared cache entry for all quantities involved in computations.
# Make sure that the cache associated with requirements precedes their
# parents in global cache registry. This is essential for automatic
# refresh, to ensure that cached values of all the intermediary
# quantities have been cleared before refresh.
for quantity in quantities_sorted_all[::-1]:
for quantity in quantity_path[::-1]:
# Get already available cache entry if any, otherwise create it
key = (type(quantity), hash(quantity))
for cache_key, cache in self._caches:
for cache_key, cache in zip(self._cache_keys, self._caches):
if key == cache_key:
owner, *_ = cache.owners
if quantity == owner:
break
else:
# Partially sort cache entries to reset all quantity instances
# of the same class at once.
# The objective is to avoid resetting multiple times the same
# quantity because of the auto-refresh mechanism.
# Create new cache entry
cache = SharedCache()
is_key_found = False
for i, (cache_key, _) in enumerate(self._caches):
if key[0] == cache_key[0]:
is_key_found = True
elif is_key_found:
self._caches.insert(i, (key, cache))
break
else:
self._caches.append((key, cache))
self._cache_keys.append(key)
self._caches.append(cache)

# Set shared cache of the quantity
quantity.cache = cache
Expand Down Expand Up @@ -197,6 +204,9 @@ def add(self,
# Add it to the global registry of already managed quantities
self._registry[name] = quantity

# Backup the updated sequence of managed quantities
self._quantity_chain = self._get_managed_quantities()

return quantity

def discard(self, name: str) -> None:
Expand All @@ -223,12 +233,45 @@ def discard(self, name: str) -> None:
cache = quantity.cache
quantity.cache = None # type: ignore[assignment]
if len(cache.owners) == 0:
for i, (_, _cache) in enumerate(self._caches):
for i, _cache in enumerate(self._caches):
if cache is _cache:
del self._cache_keys[i]
del self._caches[i]
break
quantities_all += quantity.requirements.values()

# Update global quantity chain
self._quantity_chain = self._get_managed_quantities()

def reset(self, reset_tracking: bool = False) -> None:
"""Consider that all managed quantity must be re-initialized before
being able to evaluate them once again.
.. note::
The cache is cleared automatically by the quantities themselves.
.. note::
This method is supposed to be called before starting a simulation.
:param reset_tracking: Do not consider any quantity as active anymore.
Optional: False by default.
"""
for quantity in self._quantity_chain:
quantity.reset(reset_tracking, ignore_requirements=True)

def clear(self) -> None:
"""Clear internal cache of quantities to force re-evaluating them the
next time their value is fetched.
.. note::
This method is supposed to be called every time the state of the
environment has changed (ie either the agent or world itself),
thereby invalidating the value currently stored in cache if any.
"""
ignore_auto_refresh = not self.env.is_simulation_running
for cache in self._caches:
cache.reset(ignore_auto_refresh=ignore_auto_refresh)

def get(self, name: str) -> Any:
"""Fetch the value of a given quantity.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -738,6 +738,10 @@ class DeltaQuantity(InterfaceQuantity[ArrayOrScalar]):
difference between them.
"""

max_stack: int
"""Time horizon over which to compute the variation.
"""

bounds_only: bool
"""Whether to compute the total variation as the difference between the
most recent and oldest value stored in the history, or the sum of
Expand Down Expand Up @@ -784,6 +788,7 @@ def __init__(

# Backup some of the user-arguments
self.op = op
self.max_stack = max_stack
self.bounds_only = bounds_only

# Define the appropriate quantity
Expand Down
3 changes: 2 additions & 1 deletion python/gym_jiminy/unit_py/test_quantities.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,8 @@ def test_discard(self):
assert len(registry["rpy_2"].data.cache.owners) == 1

quantity_manager.discard("rpy_2")
for (cls, _), cache in quantity_manager._caches:
for (cls, _), cache in zip(
quantity_manager._cache_keys, quantity_manager._caches):
assert len(cache.owners) == (cls is DatasetTrajectoryQuantity)

def test_env(self):
Expand Down

0 comments on commit be7113f

Please sign in to comment.