Skip to content

Commit 6fa1b9b

Browse files
committed
[llvm] Fix the number of runtime observations in reward reset.
This adds a `resend_on_reset` flag to the `send_params()` method that enables parameters to be resent immediately after a service is reset. This is required by the LLVM environment to ensure that the runtime observation parameters are sent before the reward space is reset. Fixes #756.
1 parent ac1b52d commit 6fa1b9b

File tree

3 files changed

+48
-29
lines changed

3 files changed

+48
-29
lines changed

compiler_gym/envs/llvm/llvm_env.py

+20-21
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
CostFunctionReward,
2929
NormalizedReward,
3030
)
31-
from compiler_gym.errors import BenchmarkInitError
31+
from compiler_gym.errors import BenchmarkInitError, SessionNotFound
3232
from compiler_gym.service.client_service_compiler_env import ClientServiceCompilerEnv
3333
from compiler_gym.spaces import Box, Commandline
3434
from compiler_gym.spaces import Dict as DictSpace
@@ -363,7 +363,7 @@ def __init__(
363363

364364
def reset(self, *args, **kwargs):
365365
try:
366-
observation = super().reset(*args, **kwargs)
366+
return super().reset(*args, **kwargs)
367367
except ValueError as e:
368368
# Catch and re-raise some known benchmark initialization errors with
369369
# a more informative error type.
@@ -379,15 +379,6 @@ def reset(self, *args, **kwargs):
379379
raise BenchmarkInitError(str(e)) from e
380380
raise
381381

382-
# Resend the runtimes-per-observation session parameter, if it is a
383-
# non-default value.
384-
if self._runtimes_per_observation_count is not None:
385-
self.runtime_observation_count = self._runtimes_per_observation_count
386-
if self._runtimes_warmup_per_observation_count is not None:
387-
self.runtime_warmup_runs_count = self._runtimes_warmup_per_observation_count
388-
389-
return observation
390-
391382
def make_benchmark(
392383
self,
393384
inputs: Union[
@@ -612,11 +603,15 @@ def runtime_observation_count(self) -> int:
612603

613604
@runtime_observation_count.setter
614605
def runtime_observation_count(self, n: int) -> None:
615-
if self.in_episode:
616-
self.send_param("llvm.set_runtimes_per_observation_count", str(n))
617-
# NOTE(cummins): Keep this after the send_param() call because
618-
# send_param() will raise an error if the valid is invalid.
619-
self._runtimes_per_observation_count = n
606+
try:
607+
self.send_param(
608+
"llvm.set_runtimes_per_observation_count", str(n), resend_on_reset=True
609+
)
610+
# NOTE(cummins): Keep this after the send_param() call because
611+
# send_param() will raise an error if the valid is invalid.
612+
self._runtimes_per_observation_count = n
613+
except SessionNotFound:
614+
pass # Not in session yet, will be sent on reset().
620615

621616
@property
622617
def runtime_warmup_runs_count(self) -> int:
@@ -648,13 +643,17 @@ def runtime_warmup_runs_count(self) -> int:
648643

649644
@runtime_warmup_runs_count.setter
650645
def runtime_warmup_runs_count(self, n: int) -> None:
651-
if self.in_episode:
646+
try:
652647
self.send_param(
653-
"llvm.set_warmup_runs_count_per_runtime_observation", str(n)
648+
"llvm.set_warmup_runs_count_per_runtime_observation",
649+
str(n),
650+
resend_on_reset=True,
654651
)
655-
# NOTE(cummins): Keep this after the send_param() call because
656-
# send_param() will raise an error if the valid is invalid.
657-
self._runtimes_warmup_per_observation_count = n
652+
# NOTE(cummins): Keep this after the send_param() call because
653+
# send_param() will raise an error if the valid is invalid.
654+
self._runtimes_warmup_per_observation_count = n
655+
except SessionNotFound:
656+
pass # Not in session yet, will be sent on reset().
658657

659658
def fork(self):
660659
fkd = super().fork()

compiler_gym/service/client_service_compiler_env.py

+24-4
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ def __init__(
202202

203203
self._service_endpoint: Union[str, Path] = service
204204
self._connection_settings = connection_settings or ConnectionOpts()
205+
self._params_to_send_on_reset: List[SessionParameter] = []
205206

206207
self.service = service_connection or CompilerGymServiceConnection(
207208
endpoint=self._service_endpoint,
@@ -788,6 +789,12 @@ def _call_with_error(
788789
reply.new_action_space
789790
)
790791

792+
# Re-send any session parameters that we marked as needing to be
793+
# re-sent on reset(). Do this before any other initialization as they
794+
# may affect the behavior of subsequent service calls.
795+
if self._params_to_send_on_reset:
796+
self.send_params(*[(p.key, p.value) for p in self._params_to_send_on_reset])
797+
791798
self.reward.reset(benchmark=self.benchmark, observation_view=self.observation)
792799
if self.reward_space:
793800
self.episode_reward = 0.0
@@ -1236,7 +1243,7 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult
12361243
**validation,
12371244
)
12381245

1239-
def send_param(self, key: str, value: str) -> str:
1246+
def send_param(self, key: str, value: str, resend_on_reset: bool = False) -> str:
12401247
"""Send a single <key, value> parameter to the compiler service.
12411248
12421249
See :meth:`send_params() <compiler_gym.envs.ClientServiceCompilerEnv.send_params>`
@@ -1246,14 +1253,19 @@ def send_param(self, key: str, value: str) -> str:
12461253
12471254
:param value: The parameter value.
12481255
1256+
:param resend_on_reset: Whether to resend this parameter to the compiler
1257+
service on :code:`reset()`.
1258+
12491259
:return: The response from the compiler service.
12501260
12511261
:raises SessionNotFound: If called before :meth:`reset()
12521262
<compiler_gym.envs.ClientServiceCompilerEnv.reset>`.
12531263
"""
1254-
return self.send_params((key, value))[0]
1264+
return self.send_params((key, value), resend_on_reset=resend_on_reset)[0]
12551265

1256-
def send_params(self, *params: Iterable[Tuple[str, str]]) -> List[str]:
1266+
def send_params(
1267+
self, *params: Iterable[Tuple[str, str]], resend_on_reset: bool = False
1268+
) -> List[str]:
12571269
"""Send a list of <key, value> parameters to the compiler service.
12581270
12591271
This provides a mechanism to send messages to the backend compilation
@@ -1270,17 +1282,25 @@ def send_params(self, *params: Iterable[Tuple[str, str]]) -> List[str]:
12701282
:param params: A list of parameters, where each parameter is a
12711283
:code:`(key, value)` tuple.
12721284
1285+
:param resend_on_reset: Whether to resend this parameter to the compiler
1286+
service on :code:`reset()`.
1287+
12731288
:return: A list of string responses, one per parameter.
12741289
12751290
:raises SessionNotFound: If called before :meth:`reset()
12761291
<compiler_gym.envs.ClientServiceCompilerEnv.reset>`.
12771292
"""
1293+
params_to_send = [SessionParameter(key=k, value=v) for (k, v) in params]
1294+
1295+
if resend_on_reset:
1296+
self._params_to_send_on_reset += params_to_send
1297+
12781298
if not self.in_episode:
12791299
raise SessionNotFound("Must call reset() before send_params()")
12801300

12811301
request = SendSessionParameterRequest(
12821302
session_id=self._session_id,
1283-
parameter=[SessionParameter(key=k, value=v) for (k, v) in params],
1303+
parameter=params_to_send,
12841304
)
12851305
reply: SendSessionParameterReply = self.service(
12861306
self.service.stub.SendSessionParameter, request

tests/llvm/runtime_test.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -191,10 +191,10 @@ def test_correct_number_of_observations_during_reset(
191191

192192
# Check that the number of observations that you are receive during reset()
193193
# matches the amount that you asked for.
194-
# FIXME(github.com/facebookresearch/CompilerGym/issues/756): This is broken.
195-
# Only a single observation is received, irrespective of how many you ask
196-
# for.
197-
assert len(env.reward.spaces["runtimeseries"].last_runtime_observation) == 1
194+
assert (
195+
len(env.reward.spaces["runtimeseries"].last_runtime_observation)
196+
== runtime_observation_count
197+
)
198198

199199
# Check that the number of observations that you are receive during step()
200200
# matches the amount that you asked for.

0 commit comments

Comments
 (0)