
Commit d7db851

Parameterize ONNX model tests. (#65)
Progress on #6. See how this is used downstream in iree-org/iree#19524.

## Overview

This replaces hardcoded flags like

```python
iree_compile_flags = [
    "--iree-hal-target-backends=llvm-cpu",
    "--iree-llvmcpu-target-cpu=host",
]
iree_run_module_flags = [
    "--device=local-task",
]
```

and inlined marks like

```python
@pytest.mark.xfail(raises=IreeCompileException)
def test_foo():
```

with a JSON config file passed to the test runner via the `--test-config-file` option or the `IREE_TEST_CONFIG_FILE` environment variable. During test case collection, each test case name is looked up in the config file to determine its expected outcome, one of `["skip" (special option), "pass", "fail-import", "fail-compile", "fail-run"]`. By default, all tests are skipped.

This design allows out-of-tree testing to be performed using explicit test lists (encoded in a file, unlike the [`-k` option](https://docs.pytest.org/en/latest/example/markers.html#using-k-expr-to-select-tests-based-on-their-name)), custom flags, and custom test expectations.

## Design details

Compare this implementation with these others:

* https://github.com/iree-org/iree-test-suites/tree/main/onnx_ops also uses config files, but with separate lists for `skip_compile_tests`, `skip_run_tests`, `expected_compile_failures`, and `expected_run_failures`. All tests are run by default.
* https://github.com/nod-ai/SHARK-TestSuite/blob/main/alt_e2eshark/run.py uses `--device=`, `--backend=`, `--target-chip=`, and `--test-filter=` arguments. Arbitrary flags are not supported, and test expectations are also not supported, so there is no way to directly signal if tests are unexpectedly passing or failing. A utility script can be used to diff the results of two test reports: https://github.com/nod-ai/SHARK-TestSuite/blob/main/alt_e2eshark/utils/check_regressions.py.
* https://github.com/iree-org/iree-test-suites/blob/main/sharktank_models/llama3.1/test_llama.py parameterizes test cases using `@pytest.fixture(params=[...])` with `pytest.mark.target_hip` and other custom marks. This is more standard pytest and supports fluent ways to express other test configurations, but it makes annotating large numbers of tests pretty verbose and doesn't allow for out-of-tree configuration.

I'm imagining a few usage styles:

* Nightly testing in this repository, running all test cases and tracking the current test results in a checked-in config file.
  * We could also go with an approach like https://github.com/nod-ai/SHARK-TestSuite/blob/main/alt_e2eshark/utils/check_regressions.py to diff test results, but this encodes the test results in the config files rather than in external reports. I see pros and cons to both approaches.
* Presubmit testing in https://github.com/iree-org/iree, running a subset of test cases that pass and ensuring that they do not start failing. We could also run with XFAIL to get early signal for tests that start to pass.
  * If we don't run with XFAIL, then we don't need the generalized `tests_and_expected_outcomes`; we could just limit testing to only models that are passing.
* Developer testing with arbitrary flags.

A sketch of such an out-of-tree config follows this description.

## Follow-up tasks

- [ ] Add job matrix to workflow (needs runners in this repo with GPUs)
- [ ] Add an easy way to update the list of XFAILs (maybe switch to https://github.com/gsnedders/pytest-expect and use its `--update-xfail`?)
- [ ] Triage some of the failures (e.g. can adjust tolerances on Vulkan)
- [ ] Adjust file downloading / caching behavior to avoid redownloading and using significant bandwidth when used together with persistent self-hosted runners or GitHub Actions caches

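For illustration, here is a sketch of what a custom, out-of-tree config following this schema could look like. The file itself is hypothetical; the field names match the configs added in this commit, and the test IDs are borrowed from examples elsewhere in the change. With `"default": "skip"`, only the explicitly listed tests run:

```json
{
  "config_name": "downstream_cpu_subset",
  "iree_compile_flags": [
    "--iree-hal-target-backends=llvm-cpu",
    "--iree-llvmcpu-target-cpu=host"
  ],
  "iree_run_module_flags": [
    "--device=local-task"
  ],
  "tests_and_expected_outcomes": {
    "default": "skip",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_alexnet": "pass",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[ssd/model/ssd-12.onnx]": "fail-compile"
  }
}
```

This matches the presubmit usage style above: unlisted tests are skipped, the listed passing test is required to keep passing, and the `fail-compile` entry is marked XFAIL so a fix would show up as an unexpected pass (XPASS).
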
1 parent 7e175a3 commit d7db851

11 files changed (+344, -327 lines)

onnx_models/README.md (+25 lines)

````diff
@@ -53,6 +53,31 @@ graph LR
 
 See https://docs.pytest.org/en/stable/how-to/usage.html for other options.
 
+## Customizing compile and run configurations
+
+By default, the
+[`onnx_models_cpu_llvm_task.json`](./configs/onnx_models_cpu_llvm_task.json)
+config is used, which runs the tests on IREE's CPU backend and sets some
+pass/fail test expectations. To change this, run pytest with the
+`--test-config-file=` option:
+
+```bash
+pytest \
+  -rA \
+  --log-cli-level=info \
+  --test-config-file=./onnx_models/configs/onnx_models_gpu_vulkan.json \
+  --durations=0
+```
+
+Note that these config files can be tracked independently from the
+iree-test-suites repository so you can, for example:
+
+* Run the tests from [iree-org/iree](https://github.com/iree-org/iree) at a
+  specific commit that impacts test outcomes and update the config file to
+  match the new results
+* Run the tests from another repository using a custom backend
+* Add custom flags to see if the test outcomes change
+
 ## Advanced pytest usage
 
 * The `log-cli-level` level can also be set to `debug`, `warning`, or `error`.
````
onnx_models/configs/onnx_models_cpu_llvm_task.json (new file, +21 lines)

```json
{
  "config_name": "cpu_llvm_task",
  "iree_compile_flags": [
    "--iree-hal-target-backends=llvm-cpu",
    "--iree-llvmcpu-target-cpu=host"
  ],
  "iree_run_module_flags": [
    "--device=local-task"
  ],
  "tests_and_expected_outcomes": {
    "default": "pass",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/inception_v1/model/inception-v1-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[faster-rcnn/model/FasterRCNN-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[fcn/model/fcn-resnet50-12.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[mask-rcnn/model/MaskRCNN-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[retinanet/model/retinanet-9.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[ssd/model/ssd-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[yolov4/model/yolov4.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/style_transfer_models_test.py::test_models[fast_neural_style/model/mosaic-9.onnx]": "fail-compile"
  }
}
```

onnx_models/configs/ (new `gpu_rocm_rdna3` config, +21 lines)

```json
{
  "config_name": "gpu_rocm_rdna3",
  "iree_compile_flags": [
    "--iree-hal-target-backends=rocm",
    "--iree-hip-target=gfx1100"
  ],
  "iree_run_module_flags": [
    "--device=hip"
  ],
  "tests_and_expected_outcomes": {
    "default": "pass",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/inception_v1/model/inception-v1-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[faster-rcnn/model/FasterRCNN-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[fcn/model/fcn-resnet50-12.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[mask-rcnn/model/MaskRCNN-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[retinanet/model/retinanet-9.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[ssd/model/ssd-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[yolov4/model/yolov4.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/style_transfer_models_test.py::test_models[fast_neural_style/model/mosaic-9.onnx]": "fail-compile"
  }
}
```

onnx_models/configs/onnx_models_gpu_vulkan.json (new file, +29 lines)

```json
{
  "config_name": "gpu_vulkan",
  "iree_compile_flags": [
    "--iree-hal-target-backends=vulkan-spirv"
  ],
  "iree_run_module_flags": [
    "--device=vulkan"
  ],
  "tests_and_expected_outcomes": {
    "default": "pass",
    "tests/model_zoo/validated/vision/body_analysis_models_test.py::test_models[age_gender/models/age_googlenet.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/body_analysis_models_test.py::test_models[age_gender/models/gender_googlenet.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[densenet-121/model/densenet-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[efficientnet-lite4/model/efficientnet-lite4-11.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/googlenet/model/googlenet-12.onnx]": "fail-run",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/inception_v1/model/inception-v1-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[inception_and_googlenet/inception_v2/model/inception-v2-9.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[shufflenet/model/shufflenet-9.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/classification_models_test.py::test_models[shufflenet/model/shufflenet-v2-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[faster-rcnn/model/FasterRCNN-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[fcn/model/fcn-resnet50-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[mask-rcnn/model/MaskRCNN-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[retinanet/model/retinanet-9.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[ssd/model/ssd-12.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/object_detection_segmentation_models_test.py::test_models[yolov4/model/yolov4.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/style_transfer_models_test.py::test_models[fast_neural_style/model/mosaic-9.onnx]": "fail-compile",
    "tests/model_zoo/validated/vision/super_resolution_models_test.py::test_models[sub_pixel_cnn_2016/model/super-resolution-10.onnx]": "fail-run"
  }
}
```
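As a sketch of how these configs are selected at run time: either the `--test-config-file=` option or the `IREE_TEST_CONFIG_FILE` environment variable (both handled in `conftest.py` below) points pytest at a config. The working directory (repository root) and the downstream config path in this example are assumptions:

```bash
# Select a checked-in config explicitly...
pytest onnx_models/ -rA --test-config-file=onnx_models/configs/onnx_models_gpu_vulkan.json

# ...or point the suite at an out-of-tree config via the environment variable.
export IREE_TEST_CONFIG_FILE=/path/to/downstream_config.json
pytest onnx_models/ -rA --log-cli-level=info
```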

onnx_models/conftest.py (+146, -52 lines)

```diff
@@ -4,12 +4,15 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+import json
 import logging
+import os
+import pyjson5
 import pytest
 import subprocess
 import urllib.request
 from dataclasses import dataclass
-from onnxruntime import InferenceSession
+from onnxruntime import InferenceSession, SessionOptions
 from pathlib import Path
 
 from .utils import *
```

```diff
@@ -20,6 +23,96 @@
 ARTIFACTS_ROOT = THIS_DIR / "artifacts"
 
 
+###############################################################################
+# Configuration
+###############################################################################
+
+
+def pytest_addoption(parser):
+    # List of configuration files following this schema:
+    # {
+    #   "config_name": str,
+    #   "iree_compile_flags": list of str,
+    #   "iree_run_module_flags": list of str,
+    #   "skip_compile_tests": list of str,
+    #   "skip_run_tests": list of str,
+    #   "tests_and_expected_outcomes": dict
+    # }
+    #
+    # For example, to run some tests on CPU with the `llvm-cpu` backend and
+    # `local-task` device:
+    # {
+    #   "config_name": "cpu_llvm_task",
+    #   "iree_compile_flags": ["--iree-hal-target-backends=llvm-cpu"],
+    #   "iree_run_module_flags": ["--device=local-task"],
+    #   "tests_and_expected_outcomes": {
+    #     "default": "skip",
+    #     "tests/foo/bar/baz.py::test_a": "pass",
+    #     "tests/foo/bar/baz.py::test_b[params/x]": "fail-import",
+    #     "tests/foo/bar/baz.py::test_b[params/y]": "fail-import",
+    #     "tests/foo/bar/baz.py::test_b[params/z]": "fail-import",
+    #     "tests/foo/bar/baz.py::test_c": "fail-compile",
+    #     "tests/foo/bar/baz.py::test_d": "fail-run"
+    #   }
+    # }
+    #
+    # The file can be specified in (by order of preference):
+    #   1. The `--test-config-file` argument
+    #      e.g. `pytest ... --test-config-file foo.json`
+    #   2. The `IREE_TEST_CONFIG_FILE` environment variable
+    #      e.g. `export IREE_TEST_CONFIG_FILE=foo.json`
+    #   3. A default config file used for testing the test suite itself
+    default_config_file = os.getenv(
+        "IREE_TEST_CONFIG_FILE", THIS_DIR / "configs" / "onnx_models_cpu_llvm_task.json"
+    )
+    parser.addoption(
+        "--test-config-file",
+        type=Path,
+        default=default_config_file,
+        help="Config JSON file used to parameterize test cases",
+    )
+
+
+def pytest_sessionstart(session):
+    config_file_path = session.config.getoption("test_config_file")
+    with open(config_file_path) as config_file:
+        test_config = pyjson5.load(config_file)
+    session.config.iree_test_config = test_config
+
+
+def pytest_collection_modifyitems(session, config, items):
+    logger.debug(f"pytest_collection_modifyitems with {len(items)} items:")
+
+    tests_and_expected_outcomes = config.iree_test_config["tests_and_expected_outcomes"]
+    default_outcome = tests_and_expected_outcomes.get("default", "skip")
+
+    for item in items:
+        # Build a test name from the test item location, matching how the test
+        # appears in logs, e.g.
+        #   "tests/model_zoo/validated/vision/classification_models_test.py::test_alexnet"
+        # https://docs.pytest.org/en/stable/reference/reference.html#pytest.Item
+        standardized_location_0 = item.location[0].replace("\\", "/")
+        item_path = f"{standardized_location_0}::{item.location[2]}"
+
+        expected_outcome = tests_and_expected_outcomes.get(item_path, default_outcome)
+        logger.debug(f"Expected outcome for {item_path} is {expected_outcome}")
+
+        if expected_outcome == "skip":
+            mark = pytest.mark.skip(reason="Test not included in config")
+            item.add_marker(mark)
+        elif expected_outcome == "pass":
+            pass
+        elif expected_outcome == "fail-import":
+            mark = pytest.mark.xfail(raises=IreeImportOnnxException)
+            item.add_marker(mark)
+        elif expected_outcome == "fail-compile":
+            mark = pytest.mark.xfail(raises=IreeCompileException)
+            item.add_marker(mark)
+        elif expected_outcome == "fail-run":
+            mark = pytest.mark.xfail(raises=IreeRunException)
+            item.add_marker(mark)
+
+
 ###############################################################################
 # ONNX loading, running, import, etc.
 ###############################################################################
```

```diff
@@ -60,7 +153,9 @@ def get_onnx_model_metadata(onnx_path: Path) -> OnnxModelMetadata:
     # C) Get metadata on demand from the InferenceSession using 'onnxruntime'
     # This is option C.
 
-    onnx_session = InferenceSession(onnx_path)
+    so = SessionOptions()
+    so.log_severity_level = 3  # ignore warnings
+    onnx_session = InferenceSession(onnx_path, so)
     logger.info(f"Getting model metadata for '{onnx_path.relative_to(THIS_DIR)}'")
     inputs = []
     onnx_inputs = {}
```

```diff
@@ -161,56 +256,55 @@ def run_iree_module(iree_module_path: Path, run_flags: list[str]):
         raise IreeRunException(f" '{iree_module_path.name}' run failed")
 
 
-def compare_between_iree_and_onnxruntime_fn(model_url: str, artifacts_subdir=""):
-    test_artifacts_dir = ARTIFACTS_ROOT / artifacts_subdir
-    if not test_artifacts_dir.is_dir():
-        test_artifacts_dir.mkdir(parents=True)
-
-    # Extract path and file components from the model URL.
-    # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx"
-    model_file_name = model_url.rsplit("/", 1)[-1]
-    # "mobilenetv2-12.onnx" --> "mobilenetv2-12"
-    model_name = model_file_name.rsplit(".", 1)[0]
-
-    # Download the model as needed.
-    # TODO(scotttodd): move to fixture with cache / download on demand
-    # TODO(scotttodd): overwrite if already existing? check SHA?
-    onnx_path = test_artifacts_dir / f"{model_name}.onnx"
-    if not onnx_path.exists():
-        urllib.request.urlretrieve(model_url, onnx_path)
-
-    # TODO(scotttodd): cache ONNX metadata and runtime results (pickle?)
-    onnx_model_metadata = get_onnx_model_metadata(onnx_path)
-    logger.debug(onnx_model_metadata)
-
-    # Prepare inputs and expected outputs for running through IREE.
-    run_module_args = []
-    for input in onnx_model_metadata.inputs:
-        run_module_args.append(
-            f"--input={input.type}=@{input.data_file.relative_to(THIS_DIR)}"
-        )
-    for output in onnx_model_metadata.outputs:
-        run_module_args.append(
-            f"--expected_output={output.type}=@{output.data_file.relative_to(THIS_DIR)}"
-        )
-
-    # Import, compile, then run with IREE.
-    imported_mlir_path = import_onnx_model_to_mlir(onnx_path)
-    iree_module_path = compile_mlir_with_iree(
-        imported_mlir_path,
-        "cpu",
-        [
-            "--iree-hal-target-backends=llvm-cpu",
-            "--iree-llvmcpu-target-cpu=host",
-        ],
-    )
-    # Note: could load the output into memory here and compare using numpy
-    # if the pass/fail criteria is difficult to model in the native tooling.
-    run_flags = ["--device=local-task"]
-    run_flags.extend(run_module_args)
-    run_iree_module(iree_module_path, run_flags)
+@pytest.fixture
+def compare_between_iree_and_onnxruntime(pytestconfig):
+    config_name = pytestconfig.iree_test_config["config_name"]
+    iree_compile_flags = pytestconfig.iree_test_config["iree_compile_flags"]
+    iree_run_module_flags = pytestconfig.iree_test_config["iree_run_module_flags"]
+
+    def compare_between_iree_and_onnxruntime_fn(model_url: str, artifacts_subdir=""):
+        test_artifacts_dir = ARTIFACTS_ROOT / artifacts_subdir
+        if not test_artifacts_dir.is_dir():
+            test_artifacts_dir.mkdir(parents=True)
+
+        # Extract path and file components from the model URL.
+        # "https://github.com/.../mobilenetv2-12.onnx" --> "mobilenetv2-12.onnx"
+        model_file_name = model_url.rsplit("/", 1)[-1]
+        # "mobilenetv2-12.onnx" --> "mobilenetv2-12"
+        model_name = model_file_name.rsplit(".", 1)[0]
+
+        # Download the model as needed.
+        # TODO(scotttodd): move to fixture with cache / download on demand
+        # TODO(scotttodd): overwrite if already existing? check SHA?
+        # TODO(scotttodd): redownload if file is corrupted (e.g. partial download)
+        onnx_path = test_artifacts_dir / f"{model_name}.onnx"
+        if not onnx_path.exists():
+            urllib.request.urlretrieve(model_url, onnx_path)
+
+        # TODO(scotttodd): cache ONNX metadata and runtime results (pickle?)
+        onnx_model_metadata = get_onnx_model_metadata(onnx_path)
+        logger.debug(onnx_model_metadata)
+
+        # Prepare inputs and expected outputs for running through IREE.
+        run_module_args = []
+        for input in onnx_model_metadata.inputs:
+            run_module_args.append(
+                f"--input={input.type}=@{input.data_file.relative_to(THIS_DIR)}"
+            )
+        for output in onnx_model_metadata.outputs:
+            run_module_args.append(
+                f"--expected_output={output.type}=@{output.data_file.relative_to(THIS_DIR)}"
+            )
 
+        # Import, compile, then run with IREE.
+        imported_mlir_path = import_onnx_model_to_mlir(onnx_path)
+        iree_module_path = compile_mlir_with_iree(
+            imported_mlir_path, config_name, iree_compile_flags.copy()
+        )
+        # Note: could load the output into memory here and compare using numpy
+        # if the pass/fail criteria is difficult to model in the native tooling.
+        run_flags = iree_run_module_flags.copy()
+        run_flags.extend(run_module_args)
+        run_iree_module(iree_module_path, run_flags)
 
-@pytest.fixture
-def compare_between_iree_and_onnxruntime():
     return compare_between_iree_and_onnxruntime_fn
```
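
The test modules themselves are not part of this diff, so the following is a hypothetical sketch of how a test consumes the reworked fixture. Only the fixture name and its `(model_url, artifacts_subdir)` signature are taken from the conftest.py change above; the base URL constant and the parametrize entry are assumptions:

```python
import pytest

# Assumed base URL for validated vision models; real tests may build URLs differently.
MODEL_ZOO_BASE = "https://github.com/onnx/models/raw/main/validated/vision/classification"


@pytest.mark.parametrize(
    "model_relative_path",
    ["mobilenet/model/mobilenetv2-12.onnx"],  # hypothetical single entry
)
def test_models(compare_between_iree_and_onnxruntime, model_relative_path):
    # The fixture downloads the model, imports it to MLIR, compiles it with the
    # flags from the active config, and compares IREE's outputs against
    # onnxruntime's outputs for the same inputs.
    compare_between_iree_and_onnxruntime(
        model_url=f"{MODEL_ZOO_BASE}/{model_relative_path}",
        artifacts_subdir="model_zoo/validated/vision/classification",
    )
```

Written this way, the resulting test ID (`test_models[mobilenet/model/mobilenetv2-12.onnx]`) has the same shape as the keys in `tests_and_expected_outcomes`, which is what `pytest_collection_modifyitems` matches against.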

onnx_models/requirements.txt (+2 lines)

```diff
@@ -3,6 +3,8 @@
 
 onnx
 onnxruntime
+
+pyjson5
 pytest
 pytest-html
 pytest-reportlog
```
