Skip to content

Commit 0e4e245

Browse files
authoredOct 29, 2024··
jax-toolbox-triage: minor usability/doc improvements (#1125)
- Print the stdout/stderr of the first execution of the test case, which is supposed to fail, at INFO level along with a message encouraging the user to check that it is the correct failure. - Print the path to the DEBUG log file at INFO level and, therefore, to the console. - Expand the documentation. - Add `--passing-container` and `--failing-container` arguments, which allow the container-level search to be skipped and non-dated containers to be triaged.
1 parent bde47a4 commit 0e4e245

File tree

6 files changed

+226
-33
lines changed

6 files changed

+226
-33
lines changed
 

‎.github/triage/jax_toolbox_triage/args.py

+46-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import tempfile
77

88

9-
def parse_args():
9+
def parse_args(args=None):
1010
parser = argparse.ArgumentParser(
1111
description="""
1212
Triage failures in JAX/XLA-related tests. The expectation is that the given
@@ -37,7 +37,6 @@ def parse_args():
3737
help="""
3838
Container to use. Example: jax, pax, triton. Used to construct the URLs of
3939
nightly containers, like ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD.""",
40-
required=True,
4140
)
4241
parser.add_argument(
4342
"--output-prefix",
@@ -67,6 +66,15 @@ def parse_args():
6766
Command to execute inside the container. This should be as targeted as
6867
possible.""",
6968
)
69+
container_search_args.add_argument(
70+
"--failing-container",
71+
help="""
72+
Skip the container-level search and pass this container to the commit-level
73+
search. If this is passed, --passing-container must be too, but --container
74+
is not required. This can be used to apply the commit-level bisection
75+
search to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD
76+
series, although they must have a similar structure.""",
77+
)
7078
container_search_args.add_argument(
7179
"--end-date",
7280
help="""
@@ -76,6 +84,15 @@ def parse_args():
7684
test case fails on this date.""",
7785
type=lambda s: datetime.date.fromisoformat(s),
7886
)
87+
container_search_args.add_argument(
88+
"--passing-container",
89+
help="""
90+
Skip the container-level search and pass this container to the commit-level
91+
search. If this is passed, --failing-container must be too, but --container is
92+
not required. This can be used to apply the commit-level bisection search
93+
to containers not from the ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD series,
94+
although they must have a similar structure.""",
95+
)
7996
container_search_args.add_argument(
8097
"--start-date",
8198
help="""
@@ -109,4 +126,30 @@ def parse_args():
109126
significantly speed up the commit-level search. By default, uses a temporary
110127
directory including the name of the current user.""",
111128
)
112-
return parser.parse_args()
129+
args = parser.parse_args(args=args)
130+
num_explicit_containers = (args.passing_container is not None) + (
131+
args.failing_container is not None
132+
)
133+
if num_explicit_containers == 1:
134+
raise Exception(
135+
"--passing-container and --failing-container must both be passed if either is"
136+
)
137+
if num_explicit_containers == 2:
138+
# Explicit mode, --container, --start-date and --end-date are all ignored
139+
if args.container:
140+
raise Exception(
141+
"--container must not be passed if --passing-container and --failing-container are"
142+
)
143+
if args.start_date:
144+
raise Exception(
145+
"--start-date must not be passed if --passing-container and --failing-container are"
146+
)
147+
if args.end_date:
148+
raise Exception(
149+
"--end-date must not be passed if --passing-container and --failing-container are"
150+
)
151+
elif num_explicit_containers == 0 and args.container is None:
152+
raise Exception(
153+
"--container must be passed if --passing-container and --failing-container are not"
154+
)
155+
return args

‎.github/triage/jax_toolbox_triage/logic.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
1+
from dataclasses import dataclass
12
import datetime
23
import functools
34
import logging
45
import typing
56

67

8+
@dataclass
9+
class TestResult:
10+
"""
11+
Hold the result/stdout/stderr of a test execution
12+
"""
13+
14+
__test__ = False # stop pytest gathering this
15+
result: bool
16+
stdout: typing.Optional[str] = None
17+
stderr: typing.Optional[str] = None
18+
19+
720
def as_datetime(date: datetime.date) -> datetime.datetime:
821
return datetime.datetime.combine(date, datetime.time())
922

@@ -59,7 +72,7 @@ def adjust_date(
5972
def container_search(
6073
*,
6174
container_exists: typing.Callable[[datetime.date], bool],
62-
container_passes: typing.Callable[[datetime.date], bool],
75+
container_passes: typing.Callable[[datetime.date], TestResult],
6376
start_date: typing.Optional[datetime.date],
6477
end_date: typing.Optional[datetime.date],
6578
logger: logging.Logger,
@@ -88,8 +101,17 @@ def container_search(
88101
logger.info(f"Skipping check for end-of-range failure in {end_date}")
89102
else:
90103
logger.info(f"Checking end-of-range failure in {end_date}")
91-
if container_passes(end_date):
104+
test_end_date = container_passes(end_date)
105+
logger.info(f"stdout: {test_end_date.stdout}")
106+
logger.info(f"stderr: {test_end_date.stderr}")
107+
if test_end_date.result:
92108
raise Exception(f"Could not reproduce failure in {end_date}")
109+
logger.info(
110+
"IMPORTANT: you should check that the test output above shows the "
111+
f"*expected* failure of your test case in the {end_date} container. It is "
112+
"very easy to accidentally provide a test case that fails for the wrong "
113+
"reason, which will not triage the correct issue!"
114+
)
93115

94116
# Start the coarse, container-level, search for a starting point to the bisection range
95117
earliest_failure = end_date
@@ -127,7 +149,7 @@ def container_search(
127149
logger.info(f"Skipping check that the test passes on start_date={start_date}")
128150
else:
129151
# While condition prints an info message
130-
while not container_passes(search_date):
152+
while not container_passes(search_date).result:
131153
# Test failed on `search_date`, go further into the past
132154
earliest_failure = search_date
133155
new_search_date = adjust(
@@ -155,7 +177,7 @@ def container_search(
155177
if range_mid is None:
156178
# It wasn't possible to refine further.
157179
break
158-
result = container_passes(range_mid)
180+
result = container_passes(range_mid).result
159181
if result:
160182
range_start = range_mid
161183
else:

‎.github/triage/jax_toolbox_triage/main.py

+30-18
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from .args import parse_args
1010
from .docker import DockerContainer
11-
from .logic import commit_search, container_search
11+
from .logic import commit_search, container_search, TestResult
1212
from .utils import (
1313
container_exists as container_exists_base,
1414
container_url as container_url_base,
@@ -21,6 +21,10 @@ def main():
2121
args = parse_args()
2222
bazel_cache_mounts = prepare_bazel_cache_mounts(args.bazel_cache)
2323
logger = get_logger(args.output_prefix)
24+
logger.info(
25+
"Verbose output, including stdout/err of triage commands, will be written to "
26+
f'{(args.output_prefix / "debug.log").resolve()}'
27+
)
2428
container_url = functools.partial(container_url_base, container=args.container)
2529
container_exists = functools.partial(
2630
container_exists_base, container=args.container, logger=logger
@@ -75,7 +79,7 @@ def get_commit(container: DockerContainer, repo: str) -> typing.Tuple[str, str]:
7579
f"Could not extract commit of {repo} from {args.container} container {container}"
7680
)
7781

78-
def check_container(date: datetime.date) -> bool:
82+
def check_container(date: datetime.date) -> TestResult:
7983
"""
8084
See if the test passes in the given container.
8185
"""
@@ -100,37 +104,45 @@ def check_container(date: datetime.date) -> bool:
100104
"xla": xla_commit,
101105
},
102106
)
103-
return test_pass
104-
105-
# Search through the published containers, narrowing down to a pair of dates with
106-
# the property that the test passed on `range_start` and fails on `range_end`.
107-
range_start, range_end = container_search(
108-
container_exists=container_exists,
109-
container_passes=check_container,
110-
start_date=args.start_date,
111-
end_date=args.end_date,
112-
logger=logger,
113-
skip_precondition_checks=args.skip_precondition_checks,
114-
threshold_days=args.threshold_days,
115-
)
107+
return TestResult(result=test_pass, stdout=result.stdout, stderr=result.stderr)
108+
109+
if args.passing_container is not None:
110+
assert args.failing_container is not None
111+
# Skip the container-level search because explicit end points were given
112+
passing_url = args.passing_container
113+
failing_url = args.failing_container
114+
else:
115+
# Search through the published containers, narrowing down to a pair of dates with
116+
# the property that the test passed on `range_start` and fails on `range_end`.
117+
range_start, range_end = container_search(
118+
container_exists=container_exists,
119+
container_passes=check_container,
120+
start_date=args.start_date,
121+
end_date=args.end_date,
122+
logger=logger,
123+
skip_precondition_checks=args.skip_precondition_checks,
124+
threshold_days=args.threshold_days,
125+
)
126+
passing_url = container_url(range_start)
127+
failing_url = container_url(range_end)
116128

117129
# Container-level search is now complete. Triage proceeds inside the `range_end``
118130
# container. First, we check that rewinding JAX and XLA inside the `range_end``
119131
# container to the commits used in the `range_start` container passes, whereas
120132
# using the `range_end` commits reproduces the failure.
121133

122-
with Container(container_url(range_start)) as worker:
134+
with Container(passing_url) as worker:
123135
start_jax_commit, _ = get_commit(worker, "jax")
124136
start_xla_commit, _ = get_commit(worker, "xla")
125137

126138
# Fire up the container that will be used for the fine search.
127-
with Container(container_url(range_end)) as worker:
139+
with Container(failing_url) as worker:
128140
end_jax_commit, jax_dir = get_commit(worker, "jax")
129141
end_xla_commit, xla_dir = get_commit(worker, "xla")
130142
logger.info(
131143
(
132144
f"Bisecting JAX [{start_jax_commit}, {end_jax_commit}] and "
133-
f"XLA [{start_xla_commit}, {end_xla_commit}] using {container_url(range_end)}"
145+
f"XLA [{start_xla_commit}, {end_xla_commit}] using {failing_url}"
134146
)
135147
)
136148

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import pytest
2+
from jax_toolbox_triage.args import parse_args
3+
4+
test_command = ["my-test-command"]
5+
valid_start_end_container = [
6+
"--passing-container",
7+
"passing-url",
8+
"--failing-container",
9+
"failing-url",
10+
]
11+
valid_start_end_date_args = [
12+
["--container", "jax"],
13+
["--container", "jax", "--start-date", "2024-10-02"],
14+
["--container", "jax", "--end-date", "2024-10-02"],
15+
["--container", "jax", "--start-date", "2024-10-01", "--end-date", "2024-10-02"],
16+
]
17+
18+
19+
@pytest.mark.parametrize(
20+
"good_args", [valid_start_end_container] + valid_start_end_date_args
21+
)
22+
def test_good_container_args(good_args):
23+
args = parse_args(good_args + test_command)
24+
assert args.test_command == test_command
25+
26+
27+
@pytest.mark.parametrize("date_args", valid_start_end_date_args)
28+
def test_bad_container_arg_combinations_across_groups(date_args):
29+
# Can't combine --{start,end}-container with --container/--{start,end}-date
30+
with pytest.raises(Exception):
31+
parse_args(valid_start_end_container + date_args + test_command)
32+
33+
34+
@pytest.mark.parametrize(
35+
"container_args",
36+
[
37+
# Need --container
38+
[],
39+
["--start-date", "2024-10-01"],
40+
["--end-date", "2024-10-02"],
41+
["--start-date", "2024-10-01", "--end-date", "2024-10-02"],
42+
# Need both if either is passed
43+
["--passing-container", "passing-url"],
44+
["--failing-container", "failing-url"],
45+
],
46+
)
47+
def test_bad_container_arg_combinations_within_groups(container_args):
48+
with pytest.raises(Exception):
49+
parse_args(container_args + test_command)
50+
51+
52+
@pytest.mark.parametrize(
53+
"container_args",
54+
[
55+
# Need valid ISO dates
56+
["--container", "jax", "--start-date", "a-blue-moon-ago"],
57+
["--container", "jax", "--end-date", "a-year-ago-last-thursday"],
58+
],
59+
)
60+
def test_unparsable_container_args(container_args):
61+
with pytest.raises(SystemExit):
62+
parse_args(container_args + test_command)

‎.github/triage/tests/test_triage_logic.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import logging
44
import pytest
55
import random
6-
from jax_toolbox_triage.logic import commit_search, container_search
6+
from jax_toolbox_triage.logic import commit_search, container_search, TestResult
77

88

99
def wrap(b):
@@ -306,7 +306,7 @@ def test_container_search_limits(
306306
with pytest.raises(Exception, match=match_string):
307307
container_search(
308308
container_exists=lambda dt: dt in dates_that_exist,
309-
container_passes=lambda dt: False,
309+
container_passes=lambda dt: TestResult(result=False),
310310
start_date=start_date,
311311
end_date=end_date,
312312
logger=logger,
@@ -353,7 +353,7 @@ def test_container_search_checks(
353353
with pytest.raises(Exception, match=match_string):
354354
container_search(
355355
container_exists=lambda dt: True,
356-
container_passes=lambda dt: dt in dates_that_pass,
356+
container_passes=lambda dt: TestResult(result=dt in dates_that_pass),
357357
start_date=start_date,
358358
end_date=end_date,
359359
logger=logger,
@@ -374,7 +374,7 @@ def test_container_search(logger, start_date, days_of_failure, threshold_days):
374374
assert start_date is None or threshold_date >= start_date
375375
good_date, bad_date = container_search(
376376
container_exists=lambda dt: True,
377-
container_passes=lambda dt: dt < threshold_date,
377+
container_passes=lambda dt: TestResult(result=dt < threshold_date),
378378
start_date=start_date,
379379
end_date=end_date,
380380
logger=logger,

‎docs/triage-tool.md

+58-4
Original file line numberDiff line numberDiff line change
@@ -20,33 +20,65 @@ The tool follows a three-step process:
2020
failing, and a reference commit of XLA (JAX) that can be used to reproduce the
2121
regression.
2222

23+
The third step can also be used on its own, via the `--passing-container` and
24+
`--failing-container` options, which allows it to be used between private container
25+
tags, without the dependency on the `ghcr.io/nvidia/jax` registry. This assumes that
26+
the given containers are closely related to those from JAX-Toolbox
27+
(`ghcr.io/nvidia/jax:XXX`):
28+
* JAX and XLA sources at `/opt/{jax,xla}[-source]`
29+
* `build-jax.sh` script from JAX-Toolbox available in the container
30+
2331
## Installation
2432

2533
The triage tool can be installed using `pip`:
2634
```bash
2735
pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage
2836
```
2937
or directly from a checkout of the JAX-Toolbox repository.
38+
39+
You should make sure `pip` is up to date, for example with `pip install -U pip`. The
40+
versions of `pip` installed on cluster head/compute nodes can be quite old. The
41+
recommended installation method, using `virtualenv`, should take care of this for you.
42+
3043
Because the tool needs to orchestrate running commands in multiple containers, it is
3144
most convenient to install it in a virtual environment on the host system, rather than
3245
attempting to install it inside a container.
3346

47+
The recommended installation method is to install `virtualenv` natively on the host
48+
system, and then use that to create an isolated environment on the host system for the
49+
triage tool, *i.e.*:
50+
```bash
51+
virtualenv triage-venv
52+
./triage-venv/bin/pip install git+https://github.com/NVIDIA/JAX-Toolbox.git#subdirectory=.github/triage
53+
./triage-venv/bin/jax-toolbox-triage ...
54+
```
55+
3456
The tool should be invoked on a machine with `docker` available and whatever GPUs are
3557
needed to execute the test case.
3658

3759
## Usage
3860

39-
To use the tool, there are two compulsory arguments:
40-
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
41-
families to execute the test command in. Example: `jax` for a JAX unit test
42-
failure, `maxtext` for a MaxText model execution failure
61+
To use the tool, there are two compulsory inputs:
4362
* A test command to triage.
63+
* A specification of which containers to triage in. There are two choices here:
64+
* `--container`: which of the `ghcr.io/nvidia/jax:CONTAINER-YYYY-MM-DD` container
65+
families to execute the test command in. Example: `jax` for a JAX unit test
66+
failure, `maxtext` for a MaxText model execution failure. The `--start-date` and
67+
`--end-date` options can be combined with `--container` to tune the search; see
68+
below for more details.
69+
* `--passing-container` and `--failing-container`: a pair of URLs to containers to
70+
use in the commit-level search; if these are passed then no container-level
71+
search is performed.
4472

4573
The test command will be executed directly in the container, not inside a shell, so be
4674
sure not to add excessive quotation marks (*i.e.* run
4775
`jax-toolbox-triage --container=jax test-jax.sh foo` not
4876
`jax-toolbox-triage --container=jax "test-jax.sh foo"`), and you should aim to make it
4977
as fast and targeted as possible.
78+
79+
If you want to run multiple commands, you might want to use something like
80+
`jax-toolbox-triage --container=jax sh -c "command1 && command2"`.
81+
5082
The expectation is that the test case will be executed successfully several times as
5183
part of the triage, so you may want to tune some parameters to reduce the execution
5284
time in the successful case.
@@ -55,6 +87,28 @@ probably reduce `--steps` to optimise execution time in the successful case.
5587

5688
A JSON status file and both info-level and debug-level logfiles are written to the
5789
directory given by `--output-prefix`.
90+
Info-level output is also written to the console, and includes the path to the debug
91+
log file.
92+
93+
You should pay attention to the first execution of your test case, to make sure it is
94+
failing for the correct reason. For example:
95+
```console
96+
$ jax-toolbox-triage --container jax command-you-forgot-to-install
97+
```
98+
will not immediately abort, because the tool is **expecting** the command to fail in
99+
the early stages of the triage:
100+
```
101+
[INFO] 2024-10-29 01:49:01 Verbose output, including stdout/err of triage commands, will be written to /home/olupton/JAX-Toolbox/triage-2024-10-29-01-49-01/debug.log
102+
[INFO] 2024-10-29 01:49:05 Checking end-of-range failure in 2024-10-27
103+
[INFO] 2024-10-29 01:49:05 Ran test case in 2024-10-27 in 0.4s, pass=False
104+
[INFO] 2024-10-29 01:49:05 stdout: OCI runtime exec failed: exec failed: unable to start container process: exec: "command-you-forgot-to-install": executable file not found in $PATH: unknown
105+
106+
[INFO] 2024-10-29 01:49:05 stderr:
107+
[INFO] 2024-10-29 01:49:05 IMPORTANT: you should check that the test output above shows the *expected* failure of your test case in the 2024-10-27 container. It is very easy to accidentally provide a test case that fails for the wrong reason, which will not triage the correct issue!
108+
[INFO] 2024-10-29 01:49:06 Starting coarse search with 2024-10-26 based on end_date=2024-10-27
109+
[INFO] 2024-10-29 01:49:06 Ran test case in 2024-10-26 in 0.4s, pass=False
110+
```
111+
where, notably, the triage search is continuing.
58112

59113
### Optimising container-level search performance
60114

0 commit comments

Comments
 (0)
Please sign in to comment.