14
14
import examples .loop_optimizations_service as loop_optimizations_service
15
15
from compiler_gym .envs import CompilerEnv
16
16
from compiler_gym .service import SessionNotFound
17
- from compiler_gym .spaces import Box , NamedDiscrete , Scalar , Sequence
17
+ from compiler_gym .spaces import Dict , NamedDiscrete , Scalar , Sequence
18
+ from compiler_gym .third_party .autophase import AUTOPHASE_FEATURE_NAMES
18
19
from tests .test_main import main
19
20
20
21
@@ -83,14 +84,41 @@ def test_action_space(env: CompilerEnv):
83
84
def test_observation_spaces (env : CompilerEnv ):
84
85
"""Test that the environment reports the service's observation spaces."""
85
86
env .reset ()
86
- assert env .observation .spaces .keys () == {"ir" , "features" , "runtime" , "size" }
87
+ assert env .observation .spaces .keys () == {
88
+ "ir" ,
89
+ "Inst2vec" ,
90
+ "Autophase" ,
91
+ "AutophaseDict" ,
92
+ "Programl" ,
93
+ "runtime" ,
94
+ "size" ,
95
+ }
87
96
assert env .observation .spaces ["ir" ].space == Sequence (
88
97
name = "ir" ,
89
98
size_range = (0 , np .iinfo (int ).max ),
90
99
dtype = str ,
91
100
)
92
- assert env .observation .spaces ["features" ].space == Box (
93
- name = "features" , shape = (3 ,), low = 0 , high = 1e5 , dtype = int
101
+ assert env .observation .spaces ["Inst2vec" ].space == Sequence (
102
+ name = "Inst2vec" ,
103
+ size_range = (0 , np .iinfo (int ).max ),
104
+ dtype = int ,
105
+ )
106
+ assert env .observation .spaces ["Autophase" ].space == Sequence (
107
+ name = "Autophase" ,
108
+ size_range = (len (AUTOPHASE_FEATURE_NAMES ), len (AUTOPHASE_FEATURE_NAMES )),
109
+ dtype = int ,
110
+ )
111
+ assert env .observation .spaces ["AutophaseDict" ].space == Dict (
112
+ name = "AutophaseDict" ,
113
+ spaces = {
114
+ name : Scalar (name = "" , min = 0 , max = np .iinfo (np .int64 ).max , dtype = np .int64 )
115
+ for name in AUTOPHASE_FEATURE_NAMES
116
+ },
117
+ )
118
+ assert env .observation .spaces ["Programl" ].space == Sequence (
119
+ name = "Programl" ,
120
+ size_range = (0 , np .iinfo (int ).max ),
121
+ dtype = str ,
94
122
)
95
123
assert env .observation .spaces ["runtime" ].space == Scalar (
96
124
name = "runtime" , min = 0 , max = np .inf , dtype = float
@@ -160,7 +188,7 @@ def test_Step_out_of_range(env: CompilerEnv):
160
188
161
189
162
190
def test_default_ir_observation (env : CompilerEnv ):
163
- """Test default observation space."""
191
+ """Test default IR observation space."""
164
192
env .observation_space = "ir"
165
193
observation = env .reset ()
166
194
assert len (observation ) > 0
@@ -171,16 +199,48 @@ def test_default_ir_observation(env: CompilerEnv):
171
199
assert reward is None
172
200
173
201
174
- def test_default_features_observation (env : CompilerEnv ):
175
- """Test default observation space."""
176
- env .observation_space = "features"
202
+ def test_default_inst2vec_observation (env : CompilerEnv ):
203
+ """Test default inst2vec observation space."""
204
+ env .observation_space = "Inst2vec"
205
+ observation = env .reset ()
206
+ assert isinstance (observation , np .ndarray )
207
+ assert len (observation ) >= 0
208
+ assert observation .dtype == np .int64
209
+ assert all (obs >= 0 for obs in observation .tolist ())
210
+
211
+
212
+ def test_default_autophase_observation (env : CompilerEnv ):
213
+ """Test default autophase observation space."""
214
+ env .observation_space = "Autophase"
177
215
observation = env .reset ()
178
216
assert isinstance (observation , np .ndarray )
179
- assert observation .shape == (3 ,)
217
+ assert observation .shape == (len ( AUTOPHASE_FEATURE_NAMES ) ,)
180
218
assert observation .dtype == np .int64
181
219
assert all (obs >= 0 for obs in observation .tolist ())
182
220
183
221
222
+ def test_default_autophase_dict_observation (env : CompilerEnv ):
223
+ """Test default autophase dict observation space."""
224
+ env .observation_space = "AutophaseDict"
225
+ observation = env .reset ()
226
+ assert isinstance (observation , dict )
227
+ assert observation .keys () == AUTOPHASE_FEATURE_NAMES
228
+ assert len (observation .values ()) == len (AUTOPHASE_FEATURE_NAMES )
229
+ assert all (obs >= 0 for obs in observation .values ())
230
+
231
+
232
+ def test_default_programl_observation (env : CompilerEnv ):
233
+ """Test default observation space."""
234
+ env .observation_space = "Programl"
235
+ observation = env .reset ()
236
+ assert len (observation ) > 0
237
+
238
+ observation , reward , done , info = env .step (0 )
239
+ assert not done , info
240
+ assert len (observation ) > 0
241
+ assert reward is None
242
+
243
+
184
244
def test_default_reward (env : CompilerEnv ):
185
245
"""Test default reward space."""
186
246
env .reward_space = "runtime"
@@ -195,7 +255,9 @@ def test_observations(env: CompilerEnv):
195
255
"""Test observation spaces."""
196
256
env .reset ()
197
257
assert len (env .observation ["ir" ]) > 0
198
- np .testing .assert_array_less ([- 1 , - 1 , - 1 ], env .observation ["features" ])
258
+ assert all (env .observation ["Inst2vec" ] >= 0 )
259
+ assert all (env .observation ["Autophase" ] >= 0 )
260
+ assert len (env .observation ["Programl" ]) > 0
199
261
200
262
201
263
def test_rewards (env : CompilerEnv ):
0 commit comments