Skip to content

Commit 9db514a

Browse files
Tongfei-GuoGoogle-ML-Automation
authored andcommitted
[XLA:Test] Attach default device assignment in base test classes.
PiperOrigin-RevId: 736443782
1 parent 7351c47 commit 9db514a

9 files changed

+98
-44
lines changed

xla/hlo/testlib/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -66,9 +66,11 @@ cc_library(
6666
"//xla:xla_proto_cc",
6767
"//xla/hlo/ir:hlo",
6868
"//xla/hlo/ir:hlo_module_group",
69+
"//xla/hlo/parser:hlo_parser",
6970
"//xla/hlo/pass:hlo_pass",
7071
"//xla/hlo/utils:hlo_query",
7172
"//xla/service:computation_layout",
73+
"//xla/service:computation_placer_hdr",
7274
"//xla/service:hlo_module_config",
7375
"//xla/service:hlo_verifier",
7476
"//xla/tsl/platform:errors",

xla/hlo/testlib/hlo_hardware_independent_test_base.cc

+42-8
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,15 @@ limitations under the License.
3636
#include "absl/types/span.h"
3737
#include "xla/debug_options_flags.h"
3838
#include "xla/hlo/ir/hlo_instruction.h"
39+
#include "xla/hlo/ir/hlo_module.h"
3940
#include "xla/hlo/ir/hlo_module_group.h"
4041
#include "xla/hlo/ir/hlo_opcode.h"
42+
#include "xla/hlo/parser/hlo_parser.h"
4143
#include "xla/hlo/pass/hlo_pass_interface.h"
4244
#include "xla/hlo/testlib/filecheck.h"
4345
#include "xla/hlo/testlib/verified_hlo_module.h"
4446
#include "xla/hlo/utils/hlo_query.h"
47+
#include "xla/service/computation_placer.h"
4548
#include "xla/service/hlo_module_config.h"
4649
#include "xla/service/hlo_verifier.h"
4750
#include "xla/shape.h"
@@ -84,12 +87,24 @@ HloHardwareIndependentTestBase::CreateNewVerifiedModule(
8487
instruction_can_change_layout_func_);
8588
}
8689

90+
/* static */ DeviceAssignment
91+
HloHardwareIndependentTestBase::GetDefaultDeviceAssignment(
92+
int64_t replica_count, int64_t num_partitions) {
93+
DeviceAssignment device_assignment(replica_count, num_partitions);
94+
device_assignment.FillIota(0);
95+
return device_assignment;
96+
}
97+
8798
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
8899
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
89-
absl::string_view hlo_text, int64_t replica_count,
90-
int64_t num_partitions) const {
91-
return ParseAndReturnVerifiedModule(
92-
hlo_text, GetModuleConfigForTest(replica_count, num_partitions));
100+
absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions,
101+
std::optional<DeviceAssignment> device_assignment) const {
102+
HloModuleConfig config =
103+
GetModuleConfigForTest(replica_count, num_partitions);
104+
if (device_assignment.has_value()) {
105+
config.set_static_device_assignment(device_assignment.value());
106+
}
107+
return ParseAndReturnVerifiedModule(hlo_text, config);
93108
}
94109

95110
absl::Status HloHardwareIndependentTestBase::
@@ -117,12 +132,31 @@ absl::Status HloHardwareIndependentTestBase::
117132

118133
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
119134
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
120-
absl::string_view hlo_text, const HloModuleConfig& config) const {
135+
absl::string_view hlo_text, const HloModuleConfig& config,
136+
const HloParserOptions& parser_options) const {
137+
return ParseAndReturnVerifiedModule(hlo_text, config, parser_options,
138+
ShapeUtil::ByteSizeOfElements);
139+
}
140+
141+
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
142+
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
143+
absl::string_view hlo_text, const HloModuleConfig& config,
144+
const HloParserOptions& parser_options,
145+
std::function<int64_t(const xla::Shape&)> shape_size_fn) const {
146+
HloModuleConfig config_with_device_assignment = config;
147+
if (!config.has_static_device_assignment()) {
148+
default_device_assignment_ =
149+
std::make_unique<DeviceAssignment>(GetDefaultDeviceAssignment(
150+
config.replica_count(), config.num_partitions()));
151+
config_with_device_assignment.set_static_device_assignment(
152+
*default_device_assignment_);
153+
}
121154
auto module = std::make_unique<VerifiedHloModule>(
122-
TestName(), config, verifier_layout_sensitive_,
123-
allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements,
155+
TestName(), config_with_device_assignment, verifier_layout_sensitive_,
156+
allow_mixed_precision_in_hlo_verifier_, shape_size_fn,
124157
instruction_can_change_layout_func_);
125-
TF_RETURN_IF_ERROR(module->ParseHloStringAndVerifyModule(hlo_text));
158+
TF_RETURN_IF_ERROR(
159+
module->ParseHloStringAndVerifyModule(hlo_text, parser_options));
126160
return module;
127161
}
128162

xla/hlo/testlib/hlo_hardware_independent_test_base.h

+31-7
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,12 @@ limitations under the License.
3535
#include "xla/hlo/ir/hlo_module.h"
3636
#include "xla/hlo/ir/hlo_module_group.h"
3737
#include "xla/hlo/ir/hlo_opcode.h"
38+
#include "xla/hlo/parser/hlo_parser.h"
3839
#include "xla/hlo/pass/hlo_pass_interface.h"
3940
#include "xla/hlo/testlib/verified_hlo_module.h"
4041
#include "xla/layout.h"
4142
#include "xla/service/computation_layout.h"
43+
#include "xla/service/computation_placer.h"
4244
#include "xla/service/hlo_module_config.h"
4345
#include "xla/service/hlo_verifier.h"
4446
#include "xla/shape_layout.h"
@@ -97,14 +99,26 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
9799
std::unique_ptr<VerifiedHloModule> CreateNewVerifiedModule(
98100
const std::string& name = TestName(), int64_t replica_count = 1) const;
99101

102+
// Returns a default device assignment for the given replica and partition
103+
// counts.
104+
static DeviceAssignment GetDefaultDeviceAssignment(int64_t replica_count,
105+
int64_t num_partitions);
106+
100107
// Parses the given string and returns module as a VerifiedHloModule.
101108
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
102-
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
103-
int64_t replica_count = 1,
104-
int64_t num_partitions = 1) const;
109+
ParseAndReturnVerifiedModule(
110+
absl::string_view hlo_text, int64_t replica_count = 1,
111+
int64_t num_partitions = 1,
112+
std::optional<DeviceAssignment> device_assignment = std::nullopt) const;
113+
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
114+
ParseAndReturnVerifiedModule(
115+
absl::string_view hlo_text, const HloModuleConfig& config,
116+
const HloParserOptions& parser_options = HloParserOptions()) const;
105117
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
106-
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
107-
const HloModuleConfig& config) const;
118+
ParseAndReturnVerifiedModule(
119+
absl::string_view hlo_text, const HloModuleConfig& config,
120+
const HloParserOptions& parser_options,
121+
std::function<int64_t(const xla::Shape&)> shape_size_fn) const;
108122

109123
// Runs the hlo_pass with the provided module and returns the result. This
110124
// function also verifies that the module remains unchanged when hlo_pass
@@ -194,13 +208,22 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
194208
// options (e.g. disabling additional passes).
195209
virtual DebugOptions GetDebugOptionsForTest() const;
196210

211+
void TearDown() override { default_device_assignment_.reset(); }
197212
// Gets an HloModuleConfig with options appropriate for tests.
198-
HloModuleConfig GetModuleConfigForTest(int64_t replica_count = 1,
199-
int64_t num_partitions = 1) const {
213+
HloModuleConfig GetModuleConfigForTest(
214+
int64_t replica_count = 1, int64_t num_partitions = 1,
215+
std::optional<DeviceAssignment> device_assignment = std::nullopt) const {
200216
HloModuleConfig config;
201217
config.set_debug_options(GetDebugOptionsForTest());
202218
config.set_replica_count(replica_count);
203219
config.set_num_partitions(num_partitions);
220+
if (device_assignment.has_value()) {
221+
config.set_static_device_assignment(*device_assignment);
222+
} else {
223+
default_device_assignment_ = std::make_unique<DeviceAssignment>(
224+
GetDefaultDeviceAssignment(replica_count, num_partitions));
225+
config.set_static_device_assignment(*default_device_assignment_);
226+
}
204227
return config;
205228
}
206229

@@ -282,6 +305,7 @@ class HloHardwareIndependentTestBase : public ::testing::Test {
282305
bool allow_mixed_precision_in_hlo_verifier_;
283306
HloPredicate instruction_can_change_layout_func_;
284307
std::unique_ptr<HloVerifier> hlo_verifier_;
308+
mutable std::unique_ptr<DeviceAssignment> default_device_assignment_;
285309
};
286310

287311
} // namespace xla

xla/service/gpu/transforms/BUILD

+2
Original file line numberDiff line numberDiff line change
@@ -1912,10 +1912,12 @@ xla_test(
19121912
"//xla/hlo/testlib:pattern_matcher_gmock",
19131913
"//xla/hlo/testlib:test",
19141914
"//xla/hlo/testlib:verified_hlo_module",
1915+
"//xla/service:computation_placer_hdr",
19151916
"//xla/service:hlo_module_config",
19161917
"//xla/service:pattern_matcher",
19171918
"//xla/stream_executor:device_description",
19181919
"//xla/stream_executor:semantic_version",
1920+
"//xla/tests:hlo_runner_agnostic_test_base",
19191921
"//xla/tsl/platform:statusor",
19201922
"@com_google_absl//absl/container:flat_hash_map",
19211923
"@com_google_absl//absl/status:statusor",

xla/service/gpu/transforms/gemm_rewriter_fp8_test.cc

+10-4
Original file line numberDiff line numberDiff line change
@@ -35,12 +35,14 @@ limitations under the License.
3535
#include "xla/hlo/testlib/pattern_matcher_gmock.h"
3636
#include "xla/hlo/testlib/test.h"
3737
#include "xla/hlo/testlib/verified_hlo_module.h"
38+
#include "xla/service/computation_placer.h"
3839
#include "xla/service/gpu/transforms/gemm_rewriter.h"
3940
#include "xla/service/gpu/transforms/gemm_rewriter_test_lib.h"
4041
#include "xla/service/hlo_module_config.h"
4142
#include "xla/service/pattern_matcher.h"
4243
#include "xla/stream_executor/device_description.h"
4344
#include "xla/stream_executor/semantic_version.h"
45+
#include "xla/tests/hlo_runner_agnostic_test_base.h"
4446
#include "xla/tsl/platform/statusor.h"
4547
#include "xla/xla.pb.h"
4648

@@ -118,12 +120,16 @@ class ParameterizedFp8GemmRewriteTest
118120
}
119121
}
120122

123+
using ParameterizedGemmRewriteTestBase::ParseAndReturnVerifiedModule;
124+
121125
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
122-
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
123-
int64_t replica_count = 1,
124-
int64_t num_partitions = 1) {
126+
ParseAndReturnVerifiedModule(
127+
absl::string_view hlo_text, int64_t replica_count = 1,
128+
int64_t num_partitions = 1,
129+
std::optional<DeviceAssignment> device_assignment = std::nullopt) const {
125130
return GemmRewriteTestBase::ParseAndReturnVerifiedModule(
126-
absl::StrReplaceAll(hlo_text, replacements_));
131+
absl::StrReplaceAll(hlo_text, replacements_), replica_count,
132+
num_partitions, device_assignment);
127133
}
128134

129135
private:

xla/service/hlo_computation_test.cc

+2-2
Original file line numberDiff line numberDiff line change
@@ -843,8 +843,8 @@ ENTRY entry {
843843
TF_ASSERT_OK_AND_ASSIGN(auto module,
844844
ParseAndReturnVerifiedModule(hlo_string));
845845
EXPECT_THAT(module->entry_computation()->MakeInstructionPostOrder(),
846-
ElementsAre(op::Parameter(), op::AllReduce(), op::AllReduce(),
847-
op::Add(), op::Tuple()));
846+
ElementsAre(op::Parameter(), op::AllReduce(), op::Add(),
847+
op::AllReduce(), op::Tuple()));
848848
}
849849

850850
TEST_F(HloComputationTest, ComparisonWithCustomComparator) {

xla/service/hlo_module_config.h

+3
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,9 @@ class HloModuleConfig {
287287
void set_static_device_assignment(const DeviceAssignment& device_assignment) {
288288
static_device_assignment_ = device_assignment;
289289
}
290+
void reset_static_device_assignment() {
291+
static_device_assignment_ = std::nullopt;
292+
}
290293

291294
// Checks if this config has a simulated device assignment.
292295
bool has_pre_simulation_device_assignment() const {

xla/tests/hlo_runner_agnostic_test_base.cc

+3-16
Original file line numberDiff line numberDiff line change
@@ -76,25 +76,12 @@ HloRunnerAgnosticTestBase::CreateNewVerifiedModule(
7676
instruction_can_change_layout_func());
7777
}
7878

79-
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
80-
HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule(
81-
absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions) {
82-
return ParseAndReturnVerifiedModule(
83-
hlo_text, GetModuleConfigForTest(replica_count, num_partitions));
84-
}
85-
8679
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
8780
HloRunnerAgnosticTestBase::ParseAndReturnVerifiedModule(
8881
absl::string_view hlo_text, const HloModuleConfig& config,
89-
const HloParserOptions& parser_options) {
90-
auto module = std::make_unique<VerifiedHloModule>(
91-
TestName(), config, verifier_layout_sensitive(),
92-
allow_mixed_precision_in_hlo_verifier(),
93-
test_runner_->device_shape_size_fn(),
94-
instruction_can_change_layout_func());
95-
TF_RETURN_IF_ERROR(
96-
module->ParseHloStringAndVerifyModule(hlo_text, parser_options));
97-
return std::move(module);
82+
const HloParserOptions& parser_options) const {
83+
return HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule(
84+
hlo_text, config, parser_options, test_runner_->device_shape_size_fn());
9885
}
9986

10087
HloComputation*

xla/tests/hlo_runner_agnostic_test_base.h

+3-7
Original file line numberDiff line numberDiff line change
@@ -101,20 +101,16 @@ class HloRunnerAgnosticTestBase : public HloHardwareIndependentTestBase {
101101
const std::string& name = TestName(), int64_t replica_count = 1);
102102

103103
// Parses the given string and returns module as a VerifiedHloModule.
104-
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
105-
ParseAndReturnVerifiedModule(absl::string_view hlo_text,
106-
int64_t replica_count = 1,
107-
int64_t num_partitions = 1);
108-
// Parses the given string and returns module as a VerifiedHloModule.
109-
//
104+
using HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule;
105+
110106
// To obtain a HloModuleConfig with a specific replica and partition count and
111107
// no further customization, either use the overload above or use
112108
// GetModuleConfigForTest. The latter option may be useful if you want to pass
113109
// custom HloParserOptions as well.
114110
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
115111
ParseAndReturnVerifiedModule(
116112
absl::string_view hlo_text, const HloModuleConfig& config,
117-
const HloParserOptions& parser_options = HloParserOptions());
113+
const HloParserOptions& parser_options = HloParserOptions()) const;
118114

119115
HloComputation* AddEntryComputationAndUpdateEntryComputationLayout(
120116
HloModule*, std::unique_ptr<HloComputation> computation);

0 commit comments

Comments
 (0)