@@ -36,12 +36,15 @@ limitations under the License.
36
36
#include " absl/types/span.h"
37
37
#include " xla/debug_options_flags.h"
38
38
#include " xla/hlo/ir/hlo_instruction.h"
39
+ #include " xla/hlo/ir/hlo_module.h"
39
40
#include " xla/hlo/ir/hlo_module_group.h"
40
41
#include " xla/hlo/ir/hlo_opcode.h"
42
+ #include " xla/hlo/parser/hlo_parser.h"
41
43
#include " xla/hlo/pass/hlo_pass_interface.h"
42
44
#include " xla/hlo/testlib/filecheck.h"
43
45
#include " xla/hlo/testlib/verified_hlo_module.h"
44
46
#include " xla/hlo/utils/hlo_query.h"
47
+ #include " xla/service/computation_placer.h"
45
48
#include " xla/service/hlo_module_config.h"
46
49
#include " xla/service/hlo_verifier.h"
47
50
#include " xla/shape.h"
@@ -84,12 +87,24 @@ HloHardwareIndependentTestBase::CreateNewVerifiedModule(
84
87
instruction_can_change_layout_func_);
85
88
}
86
89
90
+ /* static */ DeviceAssignment
91
+ HloHardwareIndependentTestBase::GetDefaultDeviceAssignment (
92
+ int64_t replica_count, int64_t num_partitions) {
93
+ DeviceAssignment device_assignment (replica_count, num_partitions);
94
+ device_assignment.FillIota (0 );
95
+ return device_assignment;
96
+ }
97
+
87
98
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
88
99
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule (
89
- absl::string_view hlo_text, int64_t replica_count,
90
- int64_t num_partitions) const {
91
- return ParseAndReturnVerifiedModule (
92
- hlo_text, GetModuleConfigForTest (replica_count, num_partitions));
100
+ absl::string_view hlo_text, int64_t replica_count, int64_t num_partitions,
101
+ std::optional<DeviceAssignment> device_assignment) const {
102
+ HloModuleConfig config =
103
+ GetModuleConfigForTest (replica_count, num_partitions);
104
+ if (device_assignment.has_value ()) {
105
+ config.set_static_device_assignment (device_assignment.value ());
106
+ }
107
+ return ParseAndReturnVerifiedModule (hlo_text, config);
93
108
}
94
109
95
110
absl::Status HloHardwareIndependentTestBase::
@@ -117,12 +132,31 @@ absl::Status HloHardwareIndependentTestBase::
117
132
118
133
absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
119
134
HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule (
120
- absl::string_view hlo_text, const HloModuleConfig& config) const {
135
+ absl::string_view hlo_text, const HloModuleConfig& config,
136
+ const HloParserOptions& parser_options) const {
137
+ return ParseAndReturnVerifiedModule (hlo_text, config, parser_options,
138
+ ShapeUtil::ByteSizeOfElements);
139
+ }
140
+
141
+ absl::StatusOr<std::unique_ptr<VerifiedHloModule>>
142
+ HloHardwareIndependentTestBase::ParseAndReturnVerifiedModule (
143
+ absl::string_view hlo_text, const HloModuleConfig& config,
144
+ const HloParserOptions& parser_options,
145
+ std::function<int64_t (const xla::Shape&)> shape_size_fn) const {
146
+ HloModuleConfig config_with_device_assignment = config;
147
+ if (!config.has_static_device_assignment ()) {
148
+ default_device_assignment_ =
149
+ std::make_unique<DeviceAssignment>(GetDefaultDeviceAssignment (
150
+ config.replica_count (), config.num_partitions ()));
151
+ config_with_device_assignment.set_static_device_assignment (
152
+ *default_device_assignment_);
153
+ }
121
154
auto module = std::make_unique<VerifiedHloModule>(
122
- TestName (), config , verifier_layout_sensitive_,
123
- allow_mixed_precision_in_hlo_verifier_, ShapeUtil::ByteSizeOfElements ,
155
+ TestName (), config_with_device_assignment , verifier_layout_sensitive_,
156
+ allow_mixed_precision_in_hlo_verifier_, shape_size_fn ,
124
157
instruction_can_change_layout_func_);
125
- TF_RETURN_IF_ERROR (module->ParseHloStringAndVerifyModule (hlo_text));
158
+ TF_RETURN_IF_ERROR (
159
+ module->ParseHloStringAndVerifyModule (hlo_text, parser_options));
126
160
return module;
127
161
}
128
162
0 commit comments