diff --git a/compiler_gym/bin/validate.py b/compiler_gym/bin/validate.py
index 94612324c..4e9fe83ed 100644
--- a/compiler_gym/bin/validate.py
+++ b/compiler_gym/bin/validate.py
@@ -207,7 +207,7 @@ def progress_message(i):
     progress_message(len(states))
     result_dicts = []

-    def dump_result_dicst_to_json():
+    def dump_result_dicts_to_json():
         with open(FLAGS.validation_logfile, "w") as f:
             json.dump(result_dicts, f)

@@ -223,9 +223,9 @@ def dump_result_dicst_to_json():
             walltimes.append(result.state.walltime)

         if not i % 10:
-            dump_result_dicst_to_json()
+            dump_result_dicts_to_json()

-    dump_result_dicst_to_json()
+    dump_result_dicts_to_json()

     # Print a summary footer.
     intermediate_print("\r\033[K----", "-" * name_col_width, "-----------", sep="")
diff --git a/compiler_gym/envs/llvm/datasets/cbench.py b/compiler_gym/envs/llvm/datasets/cbench.py
index fd001255c..fb48e5c67 100644
--- a/compiler_gym/envs/llvm/datasets/cbench.py
+++ b/compiler_gym/envs/llvm/datasets/cbench.py
@@ -15,11 +15,12 @@
 from collections import defaultdict
 from pathlib import Path
 from threading import Lock
-from typing import Callable, Dict, List, NamedTuple, Optional
+from typing import Callable, Dict, Iterable, List, NamedTuple, Optional

 import fasteners

 from compiler_gym.datasets import Benchmark, TarDatasetWithManifest
+from compiler_gym.datasets.benchmark import ValidationCallback
 from compiler_gym.datasets.uri import BenchmarkUri
 from compiler_gym.envs.llvm import llvm_benchmark
 from compiler_gym.service.proto import BenchmarkDynamicConfig, Command
@@ -269,7 +270,7 @@ def _make_cBench_validator(
     pre_execution_callback: Optional[Callable[[Path], None]] = None,
     sanitizer: Optional[LlvmSanitizer] = None,
     flakiness: int = 5,
-) -> Callable[["LlvmEnv"], Optional[ValidationError]]:  # noqa: F821
+) -> ValidationCallback:
     """Construct a validation callback for a cBench benchmark. See validator() for usage."""
     input_files = input_files or []
     output_files = output_files or []
@@ -407,6 +408,11 @@ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]:  # noqa: F821
                 # Timeout errors can be raised by the environment in case of a
                 # slow step / observation, and should be retried.
                 pass
+
+            # No point in repeating compilation failures as they are not flaky.
+            if error.type == "Compilation failed":
+                return error
+
             logger.warning(
                 "Validation callback failed (%s), attempt=%d/%d",
                 error.type,
@@ -415,7 +421,16 @@ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]:  # noqa: F821
             )
         return error

-    return flaky_wrapped_cb
+    # The flaky_wrapped_cb() function takes an environment and produces a single
+    # error. We need the validator to produce an iterable of errors.
+    def adapt_validator_return_type(
+        env: "LlvmEnv",  # noqa: F821
+    ) -> Iterable[ValidationError]:
+        error = flaky_wrapped_cb(env)
+        if error:
+            yield error
+
+    return adapt_validator_return_type


 def validator(
@@ -658,9 +673,7 @@ def __init__(self, site_data_base: Path):


 # A map from benchmark name to validation callbacks.
-VALIDATORS: Dict[
-    str, List[Callable[["LlvmEnv"], Optional[str]]]  # noqa: F821
-] = defaultdict(list)
+VALIDATORS: Dict[str, List[ValidationCallback]] = defaultdict(list)


 # A map from cBench benchmark path to a list of BenchmarkDynamicConfig messages,
diff --git a/tests/llvm/datasets/cbench_validate_test.py b/tests/llvm/datasets/cbench_validate_test.py
index 156e47b46..377f720d3 100644
--- a/tests/llvm/datasets/cbench_validate_test.py
+++ b/tests/llvm/datasets/cbench_validate_test.py
@@ -3,6 +3,8 @@
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
 """Test for cBench semantics validation."""
+import pytest
+
 from compiler_gym import ValidationResult
 from compiler_gym.envs.llvm import LlvmEnv
 from tests.test_main import main
@@ -10,6 +12,7 @@
 pytest_plugins = ["tests.pytest_plugins.llvm"]


+@pytest.mark.timeout(600)
 def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str):
     """Run the validation routine on all benchmarks."""
     env.reward_space = "IrInstructionCount"
@@ -29,6 +32,7 @@ def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str)
     assert result.okay()


+@pytest.mark.timeout(600)
 def test_non_validatable_benchmark_validate(
     env: LlvmEnv, non_validatable_cbench_uri: str
 ):
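
The adapt_validator_return_type() shim above is the heart of this change: a
callback returning a single Optional[ValidationError] is wrapped in a
generator so that every validator uniformly produces an iterable of errors.
Below is a minimal, self-contained sketch of that pattern under stated
assumptions: SimpleError and as_iterable_validator() are illustrative
stand-ins for this example, not CompilerGym APIs.

from typing import Callable, Iterable, NamedTuple, Optional


class SimpleError(NamedTuple):
    """Illustrative stand-in for compiler_gym's ValidationError."""

    type: str


def as_iterable_validator(
    cb: Callable[[object], Optional[SimpleError]]
) -> Callable[[object], Iterable[SimpleError]]:
    """Wrap a single-error callback so it yields zero or one errors."""

    def adapted(env: object) -> Iterable[SimpleError]:
        error = cb(env)
        if error:
            yield error  # A single error becomes a one-element iterable.
        # Falling off the end without yielding produces an empty iterable.

    return adapted


# Usage: a failing callback yields one error; a passing one yields none.
failing = as_iterable_validator(lambda env: SimpleError(type="Compilation failed"))
passing = as_iterable_validator(lambda env: None)
assert [e.type for e in failing(None)] == ["Compilation failed"]
assert list(passing(None)) == []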