Skip to content

Commit

Permalink
[gym_jiminy/common] Fix stack length computation from horizon.
Browse files Browse the repository at this point in the history
  • Loading branch information
duburcqa committed Feb 5, 2025
1 parent 24fb968 commit d57a26b
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 14 deletions.
17 changes: 13 additions & 4 deletions python/gym_jiminy/common/gym_jiminy/common/bases/compositions.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,19 @@ def __new__(cls,
self = super(partial_hashable, cls).__new__(cls, func, *args, **kwargs)

# Pre-compute normalized arguments once and for all
sig = inspect.signature(self.func)
bound = sig.bind_partial(*self.args, **(self.keywords or {}))
bound.apply_defaults()
self._normalized_args = tuple(bound.arguments.values())
try:
sig = inspect.signature(self.func)
bound = sig.bind_partial(*self.args, **(self.keywords or {}))
bound.apply_defaults()
self._normalized_args = tuple(bound.arguments.values())
except ValueError as e:
# The signature of Python bindings cannot be inspected. Keyword-only
# arguments are enforced so that equality check can still be implemented.
if self.args:
raise ValueError(
"Specifying position arguments is not supported for "
"methods whose signature cannot be inspected.") from e
_, self._normalized_args = zip(*sorted(self.keywords.items()))

return self

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -430,7 +430,8 @@ def __init__(self,
self.op = op

# Convert horizon into stack length, assuming a constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
assert horizon >= env.step_dt
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Define drift of quantity
stack_creator = lambda mode: (StackedQuantity, dict( # noqa: E731
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1862,7 +1862,7 @@ def __init__(
:param mode: Desired mode of evaluation for this quantity.
"""
# Convert horizon into stack length, assuming a constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Backup some of the user-arguments
self.max_stack = max_stack
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1569,7 +1569,8 @@ def __init__(self,
Optional: 'QuantityEvalMode.TRUE' by default.
"""
# Convert horizon into stack length, assuming a constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
assert horizon >= env.step_dt
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Backup some of the user-arguments
self.max_stack = max_stack
Expand Down Expand Up @@ -1675,7 +1676,8 @@ def __init__(self,
Optional: 'QuantityEvalMode.TRUE' by default.
"""
# Convert horizon into stack length, assuming a constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
assert horizon >= env.step_dt
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Backup some of the user-arguments
self.max_stack = max_stack
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,8 @@ def refresh(self) -> OtherValueT:
raise RuntimeError(
"Previous step missing in the stack. Please reset the "
"environment after adding this quantity.")
else:
must_refresh = False

# Extract contiguous slice of (future) available data if necessary
if self.as_array:
Expand All @@ -227,18 +229,18 @@ def refresh(self) -> OtherValueT:
if num_stack < self.max_stack:
data = self._data[..., :num_stack]

# Get current index if wrapping around
if self.is_wrapping:
index = num_steps % self.max_stack

# Append current value of the quantity to the history buffer or update
# aggregated continuous array directly if necessary.
is_stack_full = num_steps >= self.max_stack
if must_refresh:
# Get the current value of the quantity
value = self.quantity.get()

# Get current index if wrapping around
if self.is_wrapping:
index = num_steps % self.max_stack

# Append value to the history or aggregate data directly
is_stack_full = num_steps >= self.max_stack
if self.as_array:
if self.is_wrapping:
array_copyto(data[..., index], value)
Expand Down Expand Up @@ -798,7 +800,7 @@ def __init__(self,
Optional: True by default.
"""
# Convert horizon into stack length, assuming a constant env timestep
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1)
max_stack = max(int(np.ceil(horizon / env.step_dt)), 1) + 1

# Backup some of the user-arguments
self.op = op
Expand Down

0 comments on commit d57a26b

Please sign in to comment.