|
12 | 12 | import os
|
13 | 13 | import subprocess
|
14 | 14 | import sys
|
| 15 | +import time |
15 | 16 | from pathlib import Path
|
16 | 17 | from typing import Any, Dict, List
|
17 | 18 |
|
|
38 | 39 | MODEL_SIZE_METRICS = "model_size_metrics"
|
39 | 40 | PERFORMANCE_METRICS = "performance_metrics"
|
40 | 41 |
|
| 42 | +NUM_RETRY_ON_CONNECTION_ERROR = 2 |
| 43 | +RETRY_TIMEOUT = 60 |
| 44 | + |
41 | 45 |
|
42 | 46 | def example_test_cases():
|
43 | 47 | example_scope = load_json(EXAMPLE_SCOPE_PATH)
|
44 | 48 | for example_name, example_params in example_scope.items():
|
45 | 49 | yield pytest.param(example_name, example_params, id=example_name)
|
46 | 50 |
|
47 | 51 |
|
| 52 | +def _is_connection_error(txt: str) -> bool: |
| 53 | + error_list = [ |
| 54 | + "ReadTimeoutError", |
| 55 | + "HTTPError", |
| 56 | + "URL fetch failure", |
| 57 | + ] |
| 58 | + for line in txt.split()[::-1]: |
| 59 | + if any(e in line for e in error_list): |
| 60 | + print("-------------------------------") |
| 61 | + print(f"Detect connection error: {line}") |
| 62 | + return True |
| 63 | + return False |
| 64 | + |
| 65 | + |
48 | 66 | @pytest.mark.parametrize("example_name, example_params", example_test_cases())
|
49 | 67 | def test_examples(
|
50 | 68 | tmp_path: Path,
|
@@ -93,8 +111,20 @@ def test_examples(
|
93 | 111 | run_cmd_line = f"{python_executable_with_venv} {run_example_py} --name {example_name} --output {metrics_file_path}"
|
94 | 112 | if data is not None:
|
95 | 113 | run_cmd_line += f" --data {data}"
|
96 |
| - cmd = Command(run_cmd_line, cwd=PROJECT_ROOT, env=env) |
97 |
| - cmd.run() |
| 114 | + |
| 115 | + retry_count = 0 |
| 116 | + while True: |
| 117 | + cmd = Command(run_cmd_line, cwd=PROJECT_ROOT, env=env) |
| 118 | + try: |
| 119 | + ret = cmd.run() |
| 120 | + if ret == 0: |
| 121 | + break |
| 122 | + except Exception as e: |
| 123 | + if retry_count >= NUM_RETRY_ON_CONNECTION_ERROR or not _is_connection_error(str(e)): |
| 124 | + raise e |
| 125 | + retry_count += 1 |
| 126 | + print(f"Retry {retry_count} after {RETRY_TIMEOUT} seconds") |
| 127 | + time.sleep(RETRY_TIMEOUT) |
98 | 128 |
|
99 | 129 | measured_metrics = load_json(metrics_file_path)
|
100 | 130 | print(measured_metrics)
|
|
0 commit comments