Commit 15dcc39

Merge pull request #604 from ChrisCummins/fix/cbench-validate-return-type
[cbench] Fix return type of Benchmark.validate().
2 parents 67ed37a + cd838bb commit 15dcc39
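
The fix narrows the cbench validator signatures to the ValidationCallback type that Benchmark.validate() expects. Judging from the typing changes below, the old callbacks returned at most one error, while a ValidationCallback produces an iterable of errors. A minimal sketch of the assumed before/after shapes (the spelled-out aliases are assumptions, not copied from compiler_gym/datasets/benchmark.py; forward-reference strings keep the sketch free of CompilerGym imports):

from typing import Callable, Iterable, Optional

# Assumed shapes only, for orientation; the real definitions live in CompilerGym.
OldValidator = Callable[["LlvmEnv"], Optional["ValidationError"]]  # noqa: F821
ValidationCallbackSketch = Callable[["LlvmEnv"], Iterable["ValidationError"]]  # noqa: F821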

3 files changed: +26 -9 lines

compiler_gym/bin/validate.py

+3 -3
@@ -207,7 +207,7 @@ def progress_message(i):
     progress_message(len(states))
     result_dicts = []
 
-    def dump_result_dicst_to_json():
+    def dump_result_dicts_to_json():
         with open(FLAGS.validation_logfile, "w") as f:
             json.dump(result_dicts, f)
 
@@ -223,9 +223,9 @@ def dump_result_dicst_to_json():
         walltimes.append(result.state.walltime)
 
         if not i % 10:
-            dump_result_dicst_to_json()
+            dump_result_dicts_to_json()
 
-    dump_result_dicst_to_json()
+    dump_result_dicts_to_json()
 
     # Print a summary footer.
     intermediate_print("\r\033[K----", "-" * name_col_width, "-----------", sep="")

compiler_gym/envs/llvm/datasets/cbench.py

+19 -6
@@ -15,11 +15,12 @@
 from collections import defaultdict
 from pathlib import Path
 from threading import Lock
-from typing import Callable, Dict, List, NamedTuple, Optional
+from typing import Callable, Dict, Iterable, List, NamedTuple, Optional
 
 import fasteners
 
 from compiler_gym.datasets import Benchmark, TarDatasetWithManifest
+from compiler_gym.datasets.benchmark import ValidationCallback
 from compiler_gym.datasets.uri import BenchmarkUri
 from compiler_gym.envs.llvm import llvm_benchmark
 from compiler_gym.service.proto import BenchmarkDynamicConfig, Command
@@ -269,7 +270,7 @@ def _make_cBench_validator(
     pre_execution_callback: Optional[Callable[[Path], None]] = None,
     sanitizer: Optional[LlvmSanitizer] = None,
     flakiness: int = 5,
-) -> Callable[["LlvmEnv"], Optional[ValidationError]]:  # noqa: F821
+) -> ValidationCallback:
     """Construct a validation callback for a cBench benchmark. See validator() for usage."""
     input_files = input_files or []
     output_files = output_files or []
@@ -407,6 +408,11 @@ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
                 # Timeout errors can be raised by the environment in case of a
                 # slow step / observation, and should be retried.
                 pass
+
+            # No point in repeating compilation failures as they are not flaky.
+            if error.type == "Compilation failed":
+                return error
+
             logger.warning(
                 "Validation callback failed (%s), attempt=%d/%d",
                 error.type,
@@ -415,7 +421,16 @@ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
             )
         return error
 
-    return flaky_wrapped_cb
+    # The flaky_wrapped_cb() function takes an environment and produces a single
+    # error. We need the validator to produce an iterable of errors.
+    def adapt_validator_return_type(
+        env: "LlvmEnv",  # noqa: F821
+    ) -> Iterable[ValidationError]:
+        error = flaky_wrapped_cb(env)
+        if error:
+            yield error
+
+    return adapt_validator_return_type
 
 
 def validator(
@@ -658,9 +673,7 @@ def __init__(self, site_data_base: Path):
 
 
 # A map from benchmark name to validation callbacks.
-VALIDATORS: Dict[
-    str, List[Callable[["LlvmEnv"], Optional[str]]]  # noqa: F821
-] = defaultdict(list)
+VALIDATORS: Dict[str, List[ValidationCallback]] = defaultdict(list)
 
 
 # A map from cBench benchmark path to a list of BenchmarkDynamicConfig messages,
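
The comment added alongside adapt_validator_return_type() captures the heart of the fix: flaky_wrapped_cb() returns at most one error, but the ValidationCallback contract expects an iterable. The same generator-adapter pattern in isolation, with hypothetical names and a stand-in error class so the sketch runs without CompilerGym installed:

from typing import Callable, Iterable, Optional


class StubError:
    # Stand-in for ValidationError, used only to keep this sketch self-contained.
    def __init__(self, type: str):
        self.type = type


def adapt_single_result(
    cb: Callable[[object], Optional[StubError]]
) -> Callable[[object], Iterable[StubError]]:
    # Wrap a callback that returns at most one error into one that yields errors.
    def adapted(env: object) -> Iterable[StubError]:
        error = cb(env)
        if error:
            yield error

    return adapted


failing = adapt_single_result(lambda env: StubError("Compilation failed"))
passing = adapt_single_result(lambda env: None)
assert [e.type for e in failing(None)] == ["Compilation failed"]
assert list(passing(None)) == []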

tests/llvm/datasets/cbench_validate_test.py

+4
@@ -3,13 +3,16 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 """Test for cBench semantics validation."""
+import pytest
+
 from compiler_gym import ValidationResult
 from compiler_gym.envs.llvm import LlvmEnv
 from tests.test_main import main
 
 pytest_plugins = ["tests.pytest_plugins.llvm"]
 
 
+@pytest.mark.timeout(600)
 def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str):
     """Run the validation routine on all benchmarks."""
     env.reward_space = "IrInstructionCount"
@@ -29,6 +32,7 @@ def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str)
     assert result.okay()
 
 
+@pytest.mark.timeout(600)
 def test_non_validatable_benchmark_validate(
     env: LlvmEnv, non_validatable_cbench_uri: str
 ):
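
These tests drive the full validation pipeline end to end, which is why they get a ten-minute timeout. For orientation, a hedged sketch of the same path outside the test harness, assuming CompilerGym's documented compiler_gym.make() and env.validate() entry points; the benchmark URI and the sampled action are illustrative:

import compiler_gym  # assumes the compiler_gym package is installed

# "cbench-v1/crc32" is one of the cBench benchmarks with registered validators.
with compiler_gym.make("llvm-v0", benchmark="cbench-v1/crc32") as env:
    env.reset()
    env.step(env.action_space.sample())  # apply an arbitrary optimization action
    result = env.validate()  # runs the benchmark's validation callbacks
    print(result.okay(), result)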
