@@ -202,6 +202,7 @@ def __init__(
202
202
203
203
self ._service_endpoint : Union [str , Path ] = service
204
204
self ._connection_settings = connection_settings or ConnectionOpts ()
205
+ self ._params_to_send_on_reset : List [SessionParameter ] = []
205
206
206
207
self .service = service_connection or CompilerGymServiceConnection (
207
208
endpoint = self ._service_endpoint ,
@@ -788,6 +789,12 @@ def _call_with_error(
788
789
reply .new_action_space
789
790
)
790
791
792
+ # Re-send any session parameters that we marked as needing to be
793
+ # re-sent on reset(). Do this before any other initialization as they
794
+ # may affect the behavior of subsequent service calls.
795
+ if self ._params_to_send_on_reset :
796
+ self .send_params (* [(p .key , p .value ) for p in self ._params_to_send_on_reset ])
797
+
791
798
self .reward .reset (benchmark = self .benchmark , observation_view = self .observation )
792
799
if self .reward_space :
793
800
self .episode_reward = 0.0
@@ -1236,7 +1243,7 @@ def validate(self, state: Optional[CompilerEnvState] = None) -> ValidationResult
1236
1243
** validation ,
1237
1244
)
1238
1245
1239
- def send_param (self , key : str , value : str ) -> str :
1246
+ def send_param (self , key : str , value : str , resend_on_reset : bool = False ) -> str :
1240
1247
"""Send a single <key, value> parameter to the compiler service.
1241
1248
1242
1249
See :meth:`send_params() <compiler_gym.envs.ClientServiceCompilerEnv.send_params>`
@@ -1246,14 +1253,19 @@ def send_param(self, key: str, value: str) -> str:
1246
1253
1247
1254
:param value: The parameter value.
1248
1255
1256
+ :param resend_on_reset: Whether to resend this parameter to the compiler
1257
+ service on :code:`reset()`.
1258
+
1249
1259
:return: The response from the compiler service.
1250
1260
1251
1261
:raises SessionNotFound: If called before :meth:`reset()
1252
1262
<compiler_gym.envs.ClientServiceCompilerEnv.reset>`.
1253
1263
"""
1254
- return self .send_params ((key , value ))[0 ]
1264
+ return self .send_params ((key , value ), resend_on_reset = resend_on_reset )[0 ]
1255
1265
1256
- def send_params (self , * params : Iterable [Tuple [str , str ]]) -> List [str ]:
1266
+ def send_params (
1267
+ self , * params : Iterable [Tuple [str , str ]], resend_on_reset : bool = False
1268
+ ) -> List [str ]:
1257
1269
"""Send a list of <key, value> parameters to the compiler service.
1258
1270
1259
1271
This provides a mechanism to send messages to the backend compilation
@@ -1270,17 +1282,25 @@ def send_params(self, *params: Iterable[Tuple[str, str]]) -> List[str]:
1270
1282
:param params: A list of parameters, where each parameter is a
1271
1283
:code:`(key, value)` tuple.
1272
1284
1285
+ :param resend_on_reset: Whether to resend this parameter to the compiler
1286
+ service on :code:`reset()`.
1287
+
1273
1288
:return: A list of string responses, one per parameter.
1274
1289
1275
1290
:raises SessionNotFound: If called before :meth:`reset()
1276
1291
<compiler_gym.envs.ClientServiceCompilerEnv.reset>`.
1277
1292
"""
1293
+ params_to_send = [SessionParameter (key = k , value = v ) for (k , v ) in params ]
1294
+
1295
+ if resend_on_reset :
1296
+ self ._params_to_send_on_reset += params_to_send
1297
+
1278
1298
if not self .in_episode :
1279
1299
raise SessionNotFound ("Must call reset() before send_params()" )
1280
1300
1281
1301
request = SendSessionParameterRequest (
1282
1302
session_id = self ._session_id ,
1283
- parameter = [ SessionParameter ( key = k , value = v ) for ( k , v ) in params ] ,
1303
+ parameter = params_to_send ,
1284
1304
)
1285
1305
reply : SendSessionParameterReply = self .service (
1286
1306
self .service .stub .SendSessionParameter , request
0 commit comments