Skip to content

Commit 9b96886

Browse files
[PJRT] Add support of passing per-compilation compile options (#19438)
As discussed in #19418 (comment), #19418 (review) and #19418 (comment), here we support to read `env_option_overrides` as IREE compile flags from `compile_options` passed by frontends like JAX in a per-compilation basis. Most of these code already exists but has been commented due to some problems: `compile_options` was not yet available in that time, but it's now introduced by #19369. A simple use case is shown below, also as a test case: https://github.com/iree-org/iree/blob/c37a80212dd4a541762fc9fdaaa615b6d0a62829/integrations/pjrt/test/test_compile_options.py#L9-L15 ci-exactly: build_packages, test_pjrt --------- Signed-off-by: PragmaTwice <twice@apache.org> Co-authored-by: Scott Todd <scott.todd0@gmail.com>
1 parent 6b686c7 commit 9b96886

File tree

8 files changed

+94
-67
lines changed

8 files changed

+94
-67
lines changed

.github/workflows/pkgci_test_pjrt.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ jobs:
6161
source ${VENV_DIR}/bin/activate
6262
python -m pip install -v --no-deps -e integrations/pjrt/python_packages/iree_${{ matrix.pjrt_platform }}_plugin
6363
# install
64-
python -m pip install jax==0.4.35
64+
python -m pip install jax==0.4.36
6565
- name: Run tests
6666
run: |
6767
source ${VENV_DIR}/bin/activate

build_tools/testing/run_jax_tests.sh

+7
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,13 @@ diff_jax_test test/test_add.py
4848
diff_jax_test test/test_degenerate.py
4949
diff_jax_test test/test_simple.py
5050

51+
# here we test if the compile options is passed to IREE PJRT plugin successfully.
52+
# we pass --iree-scheduling-dump-statistics-format=csv via jax.jit,
53+
# and see if there's statistics in the output
54+
compile_options_test_tmp_out=$(mktemp /tmp/jax_test_result_compile_options.XXXXXX)
55+
JAX_PLATFORMS=$actual_jax_platform python test/test_compile_options.py 2>&1 | tee $compile_options_test_tmp_out
56+
cat $compile_options_test_tmp_out | grep '@main_dispatch'
57+
5158

5259
# FIXME: we can also utilize the native test cases from JAX,
5360
# e.g. `tests/nn_test.py` from the JAX repo, as below,

integrations/pjrt/src/iree_pjrt/common/CMakeLists.txt

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ iree_cc_library(
5656
iree::compiler::bindings::c::loader
5757
iree_pjrt::partitioner_api
5858
iree_pjrt::partitioner_api::loader
59+
iree_pjrt_deps::protos
5960
PUBLIC
6061
)
6162

integrations/pjrt/src/iree_pjrt/common/api_impl.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -1487,8 +1487,8 @@ PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
14871487
}
14881488

14891489
// Set flags.
1490-
// TODO: Plumb CompileOptions through.
1491-
// if (!job->SetFlags(options)) return MakeCompilerError(*job);
1490+
if (!job->SetFlags(options)) return MakeCompilerError(*job);
1491+
14921492
if (artifact_tx) {
14931493
artifact_tx->WriteArtifact(
14941494
/*label=*/"partitioner_flags", /*extension=*/"txt", /*index=*/-1,
@@ -1538,8 +1538,8 @@ PJRT_Error* ClientInstance::Compile(const PJRT_Program* program,
15381538
if (!SetDefaultCompilerFlags(job.get())) {
15391539
return MakeCompilerError(*job);
15401540
}
1541-
// TODO: Plumb CompileOptions through.
1542-
// if (!job->SetFlags(options)) return MakeCompilerError(*job);
1541+
if (!job->SetFlags(options)) return MakeCompilerError(*job);
1542+
15431543
if (artifact_tx) {
15441544
artifact_tx->WriteArtifact(
15451545
/*label=*/"flags", /*extension=*/"txt", /*index=*/-1,

integrations/pjrt/src/iree_pjrt/common/compiler.h

+2-4
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,7 @@
1111
#include <string>
1212

1313
#include "iree_pjrt/common/debugging.h"
14-
// TODO: Excise.
15-
// #include "xla/pjrt/pjrt_executable.h"
14+
#include "xla/pjrt/compile_options.pb.h"
1615

1716
namespace iree::pjrt {
1817

@@ -37,8 +36,7 @@ class CompilerJob {
3736
// setup of a job (or if the underlying session will not be re-used).
3837
// Returns false on failure.
3938
virtual bool SetFlag(const char* flag) = 0;
40-
// TODO: Excise.
41-
// virtual bool SetFlags(xla::CompileOptions options) = 0;
39+
virtual bool SetFlags(xla::CompileOptionsProto options) = 0;
4240

4341
// Gets all flags as a string. This is intended for debug printing a plausible
4442
// command line to reproduce compilation.

integrations/pjrt/src/iree_pjrt/common/hlo_partitioner.cc

+38-35
Original file line numberDiff line numberDiff line change
@@ -97,41 +97,44 @@ class OpenXLAPartitionerJob : public CompilerJob {
9797
return true;
9898
}
9999

100-
// TODO: Find another way to deal with this.
101-
// bool SetFlags(xla::CompileOptions options) override {
102-
// int num_partitions = options.executable_build_options.num_partitions();
103-
// int num_replicas = options.executable_build_options.num_replicas();
104-
// bool use_spmd_partitioning =
105-
// options.executable_build_options.use_spmd_partitioning();
106-
// auto allow_spmd_sharding_propagation_to_output =
107-
// options.executable_build_options
108-
// .allow_spmd_sharding_propagation_to_output();
109-
// if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-num-partitions=",
110-
// num_partitions)
111-
// .c_str())) {
112-
// return false;
113-
// }
114-
// if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-replica-count=",
115-
// num_replicas)
116-
// .c_str())) {
117-
// return false;
118-
// }
119-
// if (!SetFlag(
120-
// absl::StrCat("--openxla-partitioner-gspmd-use-spmd-partitioning=",
121-
// use_spmd_partitioning)
122-
// .c_str())) {
123-
// return false;
124-
// }
125-
// if (!SetFlag(
126-
// absl::StrCat(
127-
// "--openxla-partitioner-gspmd-allow-spmd-"
128-
// "sharding-propagation-to-output=",
129-
// absl::StrJoin(allow_spmd_sharding_propagation_to_output,
130-
// ",")) .c_str())) {
131-
// return false;
132-
// }
133-
// return true;
134-
// }
100+
bool SetFlags(xla::CompileOptionsProto options) override {
101+
int num_partitions = options.executable_build_options().num_partitions();
102+
int num_replicas = options.executable_build_options().num_replicas();
103+
bool use_spmd_partitioning =
104+
options.executable_build_options().use_spmd_partitioning();
105+
auto allow_spmd_sharding_propagation_to_output =
106+
options.executable_build_options()
107+
.allow_spmd_sharding_propagation_to_output();
108+
if (!SetFlag(("--openxla-partitioner-gspmd-num-partitions=" +
109+
std::to_string(num_partitions))
110+
.c_str())) {
111+
return false;
112+
}
113+
if (!SetFlag(("--openxla-partitioner-gspmd-replica-count=" +
114+
std::to_string(num_replicas))
115+
.c_str())) {
116+
return false;
117+
}
118+
if (!SetFlag(("--openxla-partitioner-gspmd-use-spmd-partitioning=" +
119+
std::to_string(use_spmd_partitioning))
120+
.c_str())) {
121+
return false;
122+
}
123+
std::string allow_spmd_sharding_propagation_to_output_str;
124+
for (size_t i = 0; i < allow_spmd_sharding_propagation_to_output.size();
125+
++i) {
126+
if (i != 0) allow_spmd_sharding_propagation_to_output_str += ",";
127+
allow_spmd_sharding_propagation_to_output_str +=
128+
std::to_string(allow_spmd_sharding_propagation_to_output[i]);
129+
}
130+
if (!SetFlag(("--openxla-partitioner-gspmd-allow-spmd-"
131+
"sharding-propagation-to-output=" +
132+
allow_spmd_sharding_propagation_to_output_str)
133+
.c_str())) {
134+
return false;
135+
}
136+
return true;
137+
}
135138

136139
std::string GetFlags() override {
137140
std::string flags;

integrations/pjrt/src/iree_pjrt/common/iree_compiler.cc

+22-23
Original file line numberDiff line numberDiff line change
@@ -95,29 +95,28 @@ class IREECompilerJob : public CompilerJob {
9595
return true;
9696
}
9797

98-
// TODO: Excise: Cannot dep on an internal XLA structure.
99-
// bool SetFlags(xla::CompileOptions options) override {
100-
// // Set extra options, overriding env variables if appropriate.
101-
// for (auto [option, option_override] : options.env_option_overrides) {
102-
// std::string override_string;
103-
// if (auto override_val = std::get_if<std::string>(&option_override)) {
104-
// override_string = *override_val;
105-
// } else if (auto override_val = std::get_if<bool>(&option_override)) {
106-
// override_string = *override_val ? "true" : "false";
107-
// } else if (auto override_val = std::get_if<int64_t>(&option_override))
108-
// {
109-
// override_string = std::to_string(*override_val);
110-
// } else {
111-
// assert(false &&
112-
// "option value should be of type string, bool, or int64");
113-
// }
114-
// if (!SetFlag(absl::StrCat("--", option, "=", override_string).c_str()))
115-
// {
116-
// return false;
117-
// }
118-
// }
119-
// return true;
120-
// }
98+
bool SetFlags(xla::CompileOptionsProto options) override {
99+
// Set extra options, overriding env variables if appropriate.
100+
for (auto [option, option_override] : options.env_option_overrides()) {
101+
std::string override_string;
102+
if (option_override.has_string_field()) {
103+
override_string = option_override.string_field();
104+
} else if (option_override.has_bool_field()) {
105+
override_string = option_override.bool_field() ? "true" : "false";
106+
} else if (option_override.has_int_field()) {
107+
override_string = std::to_string(option_override.int_field());
108+
} else if (option_override.has_double_field()) {
109+
override_string = std::to_string(option_override.double_field());
110+
} else {
111+
assert(false &&
112+
"option value should be of type string, bool, int, or double");
113+
}
114+
if (!SetFlag(("--" + option + "=" + override_string).c_str())) {
115+
return false;
116+
}
117+
}
118+
return true;
119+
}
121120

122121
std::string GetFlags() override {
123122
std::string flags;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2024 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
from functools import partial
8+
import jax.numpy as jnp
9+
from jax import jit
10+
11+
a = jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8, 9])
12+
13+
14+
@partial(jit, compiler_options={"iree-scheduling-dump-statistics-format": "csv"})
15+
def f(a, b):
16+
return a + b
17+
18+
19+
print(f(a, a))

0 commit comments

Comments
 (0)