@@ -97,41 +97,44 @@ class OpenXLAPartitionerJob : public CompilerJob {
97
97
return true ;
98
98
}
99
99
100
- // TODO: Find another way to deal with this.
101
- // bool SetFlags(xla::CompileOptions options) override {
102
- // int num_partitions = options.executable_build_options.num_partitions();
103
- // int num_replicas = options.executable_build_options.num_replicas();
104
- // bool use_spmd_partitioning =
105
- // options.executable_build_options.use_spmd_partitioning();
106
- // auto allow_spmd_sharding_propagation_to_output =
107
- // options.executable_build_options
108
- // .allow_spmd_sharding_propagation_to_output();
109
- // if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-num-partitions=",
110
- // num_partitions)
111
- // .c_str())) {
112
- // return false;
113
- // }
114
- // if (!SetFlag(absl::StrCat("--openxla-partitioner-gspmd-replica-count=",
115
- // num_replicas)
116
- // .c_str())) {
117
- // return false;
118
- // }
119
- // if (!SetFlag(
120
- // absl::StrCat("--openxla-partitioner-gspmd-use-spmd-partitioning=",
121
- // use_spmd_partitioning)
122
- // .c_str())) {
123
- // return false;
124
- // }
125
- // if (!SetFlag(
126
- // absl::StrCat(
127
- // "--openxla-partitioner-gspmd-allow-spmd-"
128
- // "sharding-propagation-to-output=",
129
- // absl::StrJoin(allow_spmd_sharding_propagation_to_output,
130
- // ",")) .c_str())) {
131
- // return false;
132
- // }
133
- // return true;
134
- // }
100
+ bool SetFlags (xla::CompileOptionsProto options) override {
101
+ int num_partitions = options.executable_build_options ().num_partitions ();
102
+ int num_replicas = options.executable_build_options ().num_replicas ();
103
+ bool use_spmd_partitioning =
104
+ options.executable_build_options ().use_spmd_partitioning ();
105
+ auto allow_spmd_sharding_propagation_to_output =
106
+ options.executable_build_options ()
107
+ .allow_spmd_sharding_propagation_to_output ();
108
+ if (!SetFlag ((" --openxla-partitioner-gspmd-num-partitions=" +
109
+ std::to_string (num_partitions))
110
+ .c_str ())) {
111
+ return false ;
112
+ }
113
+ if (!SetFlag ((" --openxla-partitioner-gspmd-replica-count=" +
114
+ std::to_string (num_replicas))
115
+ .c_str ())) {
116
+ return false ;
117
+ }
118
+ if (!SetFlag ((" --openxla-partitioner-gspmd-use-spmd-partitioning=" +
119
+ std::to_string (use_spmd_partitioning))
120
+ .c_str ())) {
121
+ return false ;
122
+ }
123
+ std::string allow_spmd_sharding_propagation_to_output_str;
124
+ for (size_t i = 0 ; i < allow_spmd_sharding_propagation_to_output.size ();
125
+ ++i) {
126
+ if (i != 0 ) allow_spmd_sharding_propagation_to_output_str += " ," ;
127
+ allow_spmd_sharding_propagation_to_output_str +=
128
+ std::to_string (allow_spmd_sharding_propagation_to_output[i]);
129
+ }
130
+ if (!SetFlag ((" --openxla-partitioner-gspmd-allow-spmd-"
131
+ " sharding-propagation-to-output=" +
132
+ allow_spmd_sharding_propagation_to_output_str)
133
+ .c_str ())) {
134
+ return false ;
135
+ }
136
+ return true ;
137
+ }
135
138
136
139
std::string GetFlags () override {
137
140
std::string flags;
0 commit comments