Skip to content

Commit 67ed37a

Browse files
authored
Merge pull request #612 from ChrisCummins/fix/reward-name-compat
Add backwards compatibility for Reward.id.
2 parents 0b881cb + a20e405 commit 67ed37a

File tree

4 files changed

+81
-8
lines changed

4 files changed

+81
-8
lines changed

compiler_gym/spaces/reward.py

+36-8
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5+
import warnings
56
from typing import List, Optional, Tuple, Union
67

78
import numpy as np
@@ -38,7 +39,10 @@ class Reward(Scalar):
3839

3940
def __init__(
4041
self,
41-
name: str,
42+
# NOTE(github.com/facebookresearch/CompilerGym/issues/381): Once `id`
43+
# argument has been removed, the default value for `name` can be
44+
# removed.
45+
name: str = None,
4246
observation_spaces: Optional[List[str]] = None,
4347
default_value: RewardType = 0,
4448
min: Optional[RewardType] = None,
@@ -47,6 +51,10 @@ def __init__(
4751
success_threshold: Optional[RewardType] = None,
4852
deterministic: bool = False,
4953
platform_dependent: bool = True,
54+
# NOTE(github.com/facebookresearch/CompilerGym/issues/381): Backwards
55+
# compatability workaround for deprecated parameter, will be removed in
56+
# v0.2.4.
57+
id: Optional[str] = None,
5058
):
5159
"""Constructor.
5260
@@ -56,33 +64,53 @@ def __init__(
5664
(:class:`space.id <compiler_gym.views.ObservationSpaceSpec>` values)
5765
that are used to compute the reward. May be an empty list if no
5866
observations are requested. Requested observations will be provided
59-
to the :code:`observations` argument of
60-
:meth:`reward.update() <compiler_gym.spaces.Reward.update>`.
67+
to the :code:`observations` argument of :meth:`reward.update()
68+
<compiler_gym.spaces.Reward.update>`.
6169
:param default_value: A default reward. This value will be returned by
62-
:meth:`env.step() <compiler_gym.envs.CompilerEnv.step>` if
63-
the service terminates.
70+
:meth:`env.step() <compiler_gym.envs.CompilerEnv.step>` if the
71+
service terminates.
6472
:param min: The lower bound of the reward.
6573
:param max: The upper bound of the reward.
6674
:param default_negates_returns: If true, the default value will be
6775
offset by the sum of all rewards for the current episode. For
6876
example, given a default reward value of *-10.0* and an episode with
69-
prior rewards *[0.1, 0.3, -0.15]*, the default value is:
70-
*-10.0 - sum(0.1, 0.3, -0.15)*.
77+
prior rewards *[0.1, 0.3, -0.15]*, the default value is: *-10.0 -
78+
sum(0.1, 0.3, -0.15)*.
7179
:param success_threshold: The cumulative reward threshold before an
7280
episode is considered successful. For example, episodes where reward
7381
is scaled to an existing heuristic can be considered “successful”
7482
when the reward exceeds the existing heuristic.
7583
:param deterministic: Whether the reward space is deterministic.
7684
:param platform_dependent: Whether the reward values depend on the
7785
execution environment of the service.
86+
:param id: The name of the reward space.
87+
88+
.. deprecated:: 0.2.3
89+
Use :code:`name` instead.
7890
"""
7991
super().__init__(
8092
name=name,
8193
min=-np.inf if min is None else min,
8294
max=np.inf if max is None else max,
8395
dtype=np.float64,
8496
)
85-
self.name = name
97+
98+
# NOTE(github.com/facebookresearch/CompilerGym/issues/381): Backwards
99+
# compatability workaround for deprecated parameter, will be removed in
100+
# v0.2.4.
101+
if id is not None:
102+
warnings.warn(
103+
"The `id` argument of "
104+
"compiler_gym.spaces.Reward.__init__() "
105+
"has been renamed `name`. This will break in a future release, "
106+
"please update your code.",
107+
DeprecationWarning,
108+
)
109+
self.name = name or id
110+
self.id = self.name
111+
if not self.name:
112+
raise TypeError("No name given")
113+
86114
self.observation_spaces = observation_spaces or []
87115
self.default_value: RewardType = default_value
88116
self.default_negates_returns: bool = default_negates_returns

tests/spaces/BUILD

+10
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ py_test(
2424
],
2525
)
2626

27+
py_test(
28+
name = "reward_test",
29+
timeout = "short",
30+
srcs = ["reward_test.py"],
31+
deps = [
32+
"//compiler_gym/spaces",
33+
"//tests:test_main",
34+
],
35+
)
36+
2737
py_test(
2838
name = "scalar_test",
2939
timeout = "short",

tests/spaces/CMakeLists.txt

+10
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,16 @@ cg_py_test(
2525
tests::test_main
2626
)
2727

28+
cg_py_test(
29+
NAME
30+
reward_test
31+
SRCS
32+
"reward_test.py"
33+
DEPS
34+
compiler_gym::spaces::spaces
35+
tests::test_main
36+
)
37+
2838
cg_py_test(
2939
NAME
3040
scalar_test

tests/spaces/reward_test.py

+25
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright (c) Facebook, Inc. and its affiliates.
2+
#
3+
# This source code is licensed under the MIT license found in the
4+
# LICENSE file in the root directory of this source tree.
5+
"""Unit tests for compiler_gym.spaces.Reward."""
6+
import pytest
7+
8+
from compiler_gym.spaces import Reward
9+
from tests.test_main import main
10+
11+
12+
def test_reward_id_backwards_compatibility():
13+
"""Test that Reward.id is backwards compatible with Reward.name.
14+
15+
See: github.com/facebookresearch/CompilerGym/issues/381
16+
"""
17+
with pytest.warns(DeprecationWarning, match="renamed `name`"):
18+
reward = Reward(id="foo")
19+
20+
assert reward.id == "foo"
21+
assert reward.name == "foo"
22+
23+
24+
if __name__ == "__main__":
25+
main()

0 commit comments

Comments
 (0)