@@ -4,12 +4,15 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

+import json
 import logging
+import os
+import pyjson5
 import pytest
 import subprocess
 import urllib.request
 from dataclasses import dataclass
-from onnxruntime import InferenceSession
+from onnxruntime import InferenceSession, SessionOptions
 from pathlib import Path

 from .utils import *
@@ -20,6 +23,96 @@
 ARTIFACTS_ROOT = THIS_DIR / "artifacts"


+###############################################################################
+# Configuration
+###############################################################################
+
+
+def pytest_addoption(parser):
+    # List of configuration files following this schema:
+    #   {
+    #     "config_name": str,
+    #     "iree_compile_flags": list of str,
+    #     "iree_run_module_flags": list of str,
+    #     "skip_compile_tests": list of str,
+    #     "skip_run_tests": list of str,
+    #     "tests_and_expected_outcomes": dict
+    #   }
+    #
+    # For example, to run some tests on CPU with the `llvm-cpu` backend and
+    # `local-task` device:
+    #   {
+    #     "config_name": "cpu_llvm_task",
+    #     "iree_compile_flags": ["--iree-hal-target-backends=llvm-cpu"],
+    #     "iree_run_module_flags": ["--device=local-task"],
+    #     "tests_and_expected_outcomes": {
+    #       "default": "skip",
+    #       "tests/foo/bar/baz.py::test_a": "pass",
+    #       "tests/foo/bar/baz.py::test_b[params/x]": "fail-import",
+    #       "tests/foo/bar/baz.py::test_b[params/y]": "fail-import",
+    #       "tests/foo/bar/baz.py::test_b[params/z]": "fail-import",
+    #       "tests/foo/bar/baz.py::test_c": "fail-compile",
+    #       "tests/foo/bar/baz.py::test_d": "fail-run"
+    #     }
+    #   }
+    #
+    # The file can be specified in (by order of preference):
+    #   1. The `--test-config-file` argument
+    #      e.g. `pytest ... --test-config-file foo.json`
+    #   2. The `IREE_TEST_CONFIG_FILE` environment variable
+    #      e.g. `export IREE_TEST_CONFIG_FILE=foo.json`
+    #   3. A default config file used for testing the test suite itself
+    default_config_file = os.getenv(
+        "IREE_TEST_CONFIG_FILE", THIS_DIR / "configs" / "onnx_models_cpu_llvm_task.json"
+    )
+    parser.addoption(
+        "--test-config-file",
+        type=Path,
+        default=default_config_file,
+        help="Config JSON file used to parameterize test cases",
+    )
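For reference, a complete config file following this schema might look like the sketch below. The file name and test entry are illustrative; the compile and run flags are the ones this diff removes from the old hardcoded CPU path, and since the file is parsed with `pyjson5`, comments are tolerated:

```json5
// configs/onnx_models_cpu_llvm_task.json (illustrative contents)
{
  "config_name": "cpu_llvm_task",
  "iree_compile_flags": [
    "--iree-hal-target-backends=llvm-cpu",
    "--iree-llvmcpu-target-cpu=host"
  ],
  "iree_run_module_flags": ["--device=local-task"],
  "skip_compile_tests": [],
  "skip_run_tests": [],
  "tests_and_expected_outcomes": {
    "default": "skip",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_alexnet": "pass"
  }
}
```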
+
+
+def pytest_sessionstart(session):
+    config_file_path = session.config.getoption("test_config_file")
+    with open(config_file_path) as config_file:
+        test_config = pyjson5.load(config_file)
+    session.config.iree_test_config = test_config
+
+
+def pytest_collection_modifyitems(session, config, items):
+    logger.debug(f"pytest_collection_modifyitems with {len(items)} items:")
+
+    tests_and_expected_outcomes = config.iree_test_config["tests_and_expected_outcomes"]
+    default_outcome = tests_and_expected_outcomes.get("default", "skip")
+
+    for item in items:
+        # Build a test name from the test item location, matching how the test
+        # appears in logs, e.g.
+        #   "tests/model_zoo/validated/vision/classification_models_test.py::test_alexnet"
+        # https://docs.pytest.org/en/stable/reference/reference.html#pytest.Item
+        standardized_location_0 = item.location[0].replace("\\", "/")
+        item_path = f"{standardized_location_0}::{item.location[2]}"
+
+        expected_outcome = tests_and_expected_outcomes.get(item_path, default_outcome)
+        logger.debug(f"Expected outcome for {item_path} is {expected_outcome}")
+
+        if expected_outcome == "skip":
+            mark = pytest.mark.skip(reason="Test not included in config")
+            item.add_marker(mark)
+        elif expected_outcome == "pass":
+            pass
+        elif expected_outcome == "fail-import":
+            mark = pytest.mark.xfail(raises=IreeImportOnnxException)
+            item.add_marker(mark)
+        elif expected_outcome == "fail-compile":
+            mark = pytest.mark.xfail(raises=IreeCompileException)
+            item.add_marker(mark)
+        elif expected_outcome == "fail-run":
+            mark = pytest.mark.xfail(raises=IreeRunException)
+            item.add_marker(mark)
+
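Listing a test under `tests_and_expected_outcomes` is thus roughly equivalent to decorating it by hand. A sketch of what the hook effectively applies for a `"fail-compile"` entry (assuming the exception types are importable from the suite's `utils` module):

```python
import pytest

from .utils import IreeCompileException

# Hand-written equivalent of the config entry:
#   "tests/foo/bar/baz.py::test_c": "fail-compile"
@pytest.mark.xfail(raises=IreeCompileException)
def test_c(compare_between_iree_and_onnxruntime):
    ...
```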
 ###############################################################################
 # ONNX loading, running, import, etc.
 ###############################################################################
@@ -60,7 +153,9 @@ def get_onnx_model_metadata(onnx_path: Path) -> OnnxModelMetadata:
     # C) Get metadata on demand from the InferenceSession using 'onnxruntime'
     # This is option C.

-    onnx_session = InferenceSession(onnx_path)
+    so = SessionOptions()
+    so.log_severity_level = 3  # ignore warnings
+    onnx_session = InferenceSession(onnx_path, so)
     logger.info(f"Getting model metadata for '{onnx_path.relative_to(THIS_DIR)}'")
     inputs = []
     onnx_inputs = {}
@@ -161,56 +256,55 @@ def run_iree_module(iree_module_path: Path, run_flags: list[str]):
         raise IreeRunException(f" '{iree_module_path.name}' run failed")


-def compare_between_iree_and_onnxruntime_fn(model_url: str, artifacts_subdir=""):
-    test_artifacts_dir = ARTIFACTS_ROOT / artifacts_subdir
-    if not test_artifacts_dir.is_dir():
-        test_artifacts_dir.mkdir(parents=True)
-
-    # Extract path and file components from the model URL.
-    # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx"
-    model_file_name = model_url.rsplit("/", 1)[-1]
-    # "mobilenetv2-12.onnx" --> "mobilenetv2-12"
-    model_name = model_file_name.rsplit(".", 1)[0]
-
-    # Download the model as needed.
-    # TODO(scotttodd): move to fixture with cache / download on demand
-    # TODO(scotttodd): overwrite if already existing? check SHA?
-    onnx_path = test_artifacts_dir / f"{model_name}.onnx"
-    if not onnx_path.exists():
-        urllib.request.urlretrieve(model_url, onnx_path)
-
-    # TODO(scotttodd): cache ONNX metadata and runtime results (pickle?)
-    onnx_model_metadata = get_onnx_model_metadata(onnx_path)
-    logger.debug(onnx_model_metadata)
-
-    # Prepare inputs and expected outputs for running through IREE.
-    run_module_args = []
-    for input in onnx_model_metadata.inputs:
-        run_module_args.append(
-            f"--input={input.type}=@{input.data_file.relative_to(THIS_DIR)}"
-        )
-    for output in onnx_model_metadata.outputs:
-        run_module_args.append(
-            f"--expected_output={output.type}=@{output.data_file.relative_to(THIS_DIR)}"
-        )
-
-    # Import, compile, then run with IREE.
-    imported_mlir_path = import_onnx_model_to_mlir(onnx_path)
-    iree_module_path = compile_mlir_with_iree(
-        imported_mlir_path,
-        "cpu",
-        [
-            "--iree-hal-target-backends=llvm-cpu",
-            "--iree-llvmcpu-target-cpu=host",
-        ],
-    )
-    # Note: could load the output into memory here and compare using numpy
-    # if the pass/fail criteria is difficult to model in the native tooling.
-    run_flags = ["--device=local-task"]
-    run_flags.extend(run_module_args)
-    run_iree_module(iree_module_path, run_flags)
+@pytest.fixture
+def compare_between_iree_and_onnxruntime(pytestconfig):
+    config_name = pytestconfig.iree_test_config["config_name"]
+    iree_compile_flags = pytestconfig.iree_test_config["iree_compile_flags"]
+    iree_run_module_flags = pytestconfig.iree_test_config["iree_run_module_flags"]
+
+    def compare_between_iree_and_onnxruntime_fn(model_url: str, artifacts_subdir=""):
+        test_artifacts_dir = ARTIFACTS_ROOT / artifacts_subdir
+        if not test_artifacts_dir.is_dir():
+            test_artifacts_dir.mkdir(parents=True)
+
+        # Extract path and file components from the model URL.
+        # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx"
+        model_file_name = model_url.rsplit("/", 1)[-1]
+        # "mobilenetv2-12.onnx" --> "mobilenetv2-12"
+        model_name = model_file_name.rsplit(".", 1)[0]
+
+        # Download the model as needed.
+        # TODO(scotttodd): move to fixture with cache / download on demand
+        # TODO(scotttodd): overwrite if already existing? check SHA?
+        # TODO(scotttodd): redownload if file is corrupted (e.g. partial download)
+        onnx_path = test_artifacts_dir / f"{model_name}.onnx"
+        if not onnx_path.exists():
+            urllib.request.urlretrieve(model_url, onnx_path)
+
+        # TODO(scotttodd): cache ONNX metadata and runtime results (pickle?)
+        onnx_model_metadata = get_onnx_model_metadata(onnx_path)
+        logger.debug(onnx_model_metadata)
+
+        # Prepare inputs and expected outputs for running through IREE.
+        run_module_args = []
+        for input in onnx_model_metadata.inputs:
+            run_module_args.append(
+                f"--input={input.type}=@{input.data_file.relative_to(THIS_DIR)}"
+            )
+        for output in onnx_model_metadata.outputs:
+            run_module_args.append(
+                f"--expected_output={output.type}=@{output.data_file.relative_to(THIS_DIR)}"
+            )

+        # Import, compile, then run with IREE.
+        imported_mlir_path = import_onnx_model_to_mlir(onnx_path)
+        iree_module_path = compile_mlir_with_iree(
+            imported_mlir_path, config_name, iree_compile_flags.copy()
+        )
+        # Note: could load the output into memory here and compare using numpy
+        # if the pass/fail criteria is difficult to model in the native tooling.
+        run_flags = iree_run_module_flags.copy()
+        run_flags.extend(run_module_args)
+        run_iree_module(iree_module_path, run_flags)

-@pytest.fixture
-def compare_between_iree_and_onnxruntime():
     return compare_between_iree_and_onnxruntime_fn
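A test then requests the fixture by name and calls the returned function; a minimal sketch (the test name is hypothetical, and the elided URL follows the comment in the fixture above):

```python
# classification_models_test.py (hypothetical)
def test_mobilenet(compare_between_iree_and_onnxruntime):
    compare_between_iree_and_onnxruntime(
        "https://github.com/.../mobilenetv2-12.onnx",
        artifacts_subdir="vision",
    )
```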