Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cbench] Fix return type of Benchmark.validate(). #604

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions compiler_gym/bin/validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ def progress_message(i):
progress_message(len(states))
result_dicts = []

- def dump_result_dicst_to_json():
+ def dump_result_dicts_to_json():
with open(FLAGS.validation_logfile, "w") as f:
json.dump(result_dicts, f)

Expand All @@ -223,9 +223,9 @@ def dump_result_dicst_to_json():
walltimes.append(result.state.walltime)

if not i % 10:
- dump_result_dicst_to_json()
+ dump_result_dicts_to_json()

- dump_result_dicst_to_json()
+ dump_result_dicts_to_json()

# Print a summary footer.
intermediate_print("\r\033[K----", "-" * name_col_width, "-----------", sep="")
Expand Down
25 changes: 19 additions & 6 deletions compiler_gym/envs/llvm/datasets/cbench.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
from collections import defaultdict
from pathlib import Path
from threading import Lock
- from typing import Callable, Dict, List, NamedTuple, Optional
+ from typing import Callable, Dict, Iterable, List, NamedTuple, Optional

import fasteners

from compiler_gym.datasets import Benchmark, TarDatasetWithManifest
+ from compiler_gym.datasets.benchmark import ValidationCallback
from compiler_gym.datasets.uri import BenchmarkUri
from compiler_gym.envs.llvm import llvm_benchmark
from compiler_gym.service.proto import BenchmarkDynamicConfig, Command
Expand Down Expand Up @@ -269,7 +270,7 @@ def _make_cBench_validator(
pre_execution_callback: Optional[Callable[[Path], None]] = None,
sanitizer: Optional[LlvmSanitizer] = None,
flakiness: int = 5,
- ) -> Callable[["LlvmEnv"], Optional[ValidationError]]: # noqa: F821
+ ) -> ValidationCallback:
"""Construct a validation callback for a cBench benchmark. See validator() for usage."""
input_files = input_files or []
output_files = output_files or []
Expand Down Expand Up @@ -407,6 +408,11 @@ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
# Timeout errors can be raised by the environment in case of a
# slow step / observation, and should be retried.
pass

+ # No point in repeating compilation failures as they are not flaky.
+ if error.type == "Compilation failed":
+ return error

logger.warning(
"Validation callback failed (%s), attempt=%d/%d",
error.type,
Expand All @@ -415,7 +421,16 @@ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
)
return error

- return flaky_wrapped_cb
+ # The flaky_wrapped_cb() function takes an environment and produces a single
+ # error. We need the validator to produce an iterable of errors.
+ def adapt_validator_return_type(
+ env: "LlvmEnv", # noqa: F821
+ ) -> Iterable[ValidationError]:
+ error = flaky_wrapped_cb(env)
+ if error:
+ yield error
+
+ return adapt_validator_return_type


def validator(
Expand Down Expand Up @@ -658,9 +673,7 @@ def __init__(self, site_data_base: Path):


# A map from benchmark name to validation callbacks.
- VALIDATORS: Dict[
- str, List[Callable[["LlvmEnv"], Optional[str]]] # noqa: F821
- ] = defaultdict(list)
+ VALIDATORS: Dict[str, List[ValidationCallback]] = defaultdict(list)


# A map from cBench benchmark path to a list of BenchmarkDynamicConfig messages,
Expand Down
4 changes: 4 additions & 0 deletions tests/llvm/datasets/cbench_validate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,16 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""Test for cBench semantics validation."""
+ import pytest

from compiler_gym import ValidationResult
from compiler_gym.envs.llvm import LlvmEnv
from tests.test_main import main

pytest_plugins = ["tests.pytest_plugins.llvm"]


+ @pytest.mark.timeout(600)
def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str):
"""Run the validation routine on all benchmarks."""
env.reward_space = "IrInstructionCount"
Expand All @@ -29,6 +32,7 @@ def test_validate_benchmark_semantics(env: LlvmEnv, validatable_cbench_uri: str)
assert result.okay()


+ @pytest.mark.timeout(600)
def test_non_validatable_benchmark_validate(
env: LlvmEnv, non_validatable_cbench_uri: str
):
Expand Down