15
15
from collections import defaultdict
16
16
from pathlib import Path
17
17
from threading import Lock
18
- from typing import Callable , Dict , List , NamedTuple , Optional
18
+ from typing import Callable , Dict , Iterable , List , NamedTuple , Optional
19
19
20
20
import fasteners
21
21
22
22
from compiler_gym .datasets import Benchmark , TarDatasetWithManifest
23
+ from compiler_gym .datasets .benchmark import ValidationCallback
23
24
from compiler_gym .datasets .uri import BenchmarkUri
24
25
from compiler_gym .envs .llvm import llvm_benchmark
25
26
from compiler_gym .service .proto import BenchmarkDynamicConfig , Command
@@ -269,7 +270,7 @@ def _make_cBench_validator(
269
270
pre_execution_callback : Optional [Callable [[Path ], None ]] = None ,
270
271
sanitizer : Optional [LlvmSanitizer ] = None ,
271
272
flakiness : int = 5 ,
272
- ) -> Callable [[ "LlvmEnv" ], Optional [ ValidationError ]]: # noqa: F821
273
+ ) -> ValidationCallback :
273
274
"""Construct a validation callback for a cBench benchmark. See validator() for usage."""
274
275
input_files = input_files or []
275
276
output_files = output_files or []
@@ -407,6 +408,11 @@ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
407
408
# Timeout errors can be raised by the environment in case of a
408
409
# slow step / observation, and should be retried.
409
410
pass
411
+
412
+ # No point in repeating compilation failures as they are not flaky.
413
+ if error .type == "Compilation failed" :
414
+ return error
415
+
410
416
logger .warning (
411
417
"Validation callback failed (%s), attempt=%d/%d" ,
412
418
error .type ,
@@ -415,7 +421,16 @@ def flaky_wrapped_cb(env: "LlvmEnv") -> Optional[ValidationError]: # noqa: F821
415
421
)
416
422
return error
417
423
418
- return flaky_wrapped_cb
424
+ # The flaky_wrapped_cb() function takes an environment and produces a single
425
+ # error. We need the validator to produce an iterable of errors.
426
+ def adapt_validator_return_type (
427
+ env : "LlvmEnv" , # noqa: F821
428
+ ) -> Iterable [ValidationError ]:
429
+ error = flaky_wrapped_cb (env )
430
+ if error :
431
+ yield error
432
+
433
+ return adapt_validator_return_type
419
434
420
435
421
436
def validator (
@@ -658,9 +673,7 @@ def __init__(self, site_data_base: Path):
658
673
659
674
660
675
# A map from benchmark name to validation callbacks.
661
- VALIDATORS : Dict [
662
- str , List [Callable [["LlvmEnv" ], Optional [str ]]] # noqa: F821
663
- ] = defaultdict (list )
676
+ VALIDATORS : Dict [str , List [ValidationCallback ]] = defaultdict (list )
664
677
665
678
666
679
# A map from cBench benchmark path to a list of BenchmarkDynamicConfig messages,
0 commit comments