Skip to content

Commit 4e37f18

Browse files
jreiffersGoogle-ML-Automation
authored andcommitted
PR #23681: Add an API to get a computation's caller(s).
Imported from GitHub PR #23681 This should replace CallGraph in most cases, and adds an alternative to the deprecated .*CallInstruction functions. Copybara import of the project: -- 4168abb by Johannes Reifferscheid <jreiffers@nvidia.com>: Add an API to get a computation's caller(s). This should replace CallGraph in most cases, and adds an alternative to the deprecated .*CallInstruction functions. Merging this change closes #23681 FUTURE_COPYBARA_INTEGRATE_REVIEW=#23681 from jreiffers:caller-instructions 4168abb PiperOrigin-RevId: 736448470
1 parent 49c16af commit 4e37f18

File tree

4 files changed

+171
-22
lines changed

4 files changed

+171
-22
lines changed

xla/hlo/ir/hlo_computation.cc

+64-5
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,9 @@ HloComputation::~HloComputation() {
204204
}
205205
CHECK(caller_computations_.empty());
206206

207+
// Delete the map from caller instructions to count, if it exists.
208+
delete GetCallersMap();
209+
207210
for (const auto& i : instructions_) {
208211
delete i.inst();
209212
}
@@ -277,8 +280,6 @@ static void IncrementCount(
277280
++map[key];
278281
}
279282

280-
// Returns true if the callee was present and its count was decremented; returns
281-
// false if the callee was not present.
282283
static void DecrementCount(
283284
absl::btree_map<HloComputation*, int, HloComputation::UniqueIdComparator>&
284285
map,
@@ -292,17 +293,75 @@ static void DecrementCount(
292293
}
293294
}
294295

295-
void HloComputation::AddCallee(HloComputation* callee) {
296+
void HloComputation::AddCallee(const HloInstruction* caller,
297+
HloComputation* callee) {
296298
IncrementCount(callee_computations_, callee);
297299
IncrementCount(callee->caller_computations_, this);
300+
301+
if (auto* map = callee->GetCallersMap()) {
302+
++(*map)[caller];
303+
} else if (callee->callers_ == 0) {
304+
callee->callers_ = reinterpret_cast<uintptr_t>(caller);
305+
} else {
306+
// Convert the single instruction to a map.
307+
auto* current_caller = reinterpret_cast<const HloInstruction*>(
308+
callee->callers_ & ~kCallerTypeMask);
309+
auto* map = new absl::flat_hash_map<const HloInstruction*, int>();
310+
(*map)[current_caller] = 1;
311+
++(*map)[caller];
312+
callee->callers_ = reinterpret_cast<uintptr_t>(map) |
313+
static_cast<uintptr_t>(CallersType::kCallerCountHashMap);
314+
}
315+
298316
if (parent() != nullptr && callee->parent() == parent()) {
299317
parent()->topological_sort_.AddEdge(this, callee);
300318
}
301319
}
302320

303-
void HloComputation::RemoveCallee(HloComputation* callee) {
321+
void HloComputation::RemoveCallee(const HloInstruction* caller,
322+
HloComputation* callee) {
323+
CHECK(caller);
324+
CHECK(callee);
304325
DecrementCount(callee_computations_, callee);
305326
DecrementCount(callee->caller_computations_, this);
327+
328+
if (callee->callers_ == reinterpret_cast<uintptr_t>(caller)) {
329+
// The callee had just this single caller, so we reset it to 0 (no caller).
330+
callee->callers_ = 0;
331+
} else {
332+
auto* map = callee->GetCallersMap();
333+
CHECK(map) << "Attempted to remove a caller " << caller->name()
334+
<< " that did not call the computation " << name() << "."
335+
<< callee->callers_;
336+
auto it = map->find(caller);
337+
CHECK(it != map->end())
338+
<< "Attempted to remove a caller " << caller->name()
339+
<< " that did not call the computation " << name() << ".";
340+
--it->second;
341+
// We don't convert back to the inline representation, since this case
342+
// should be rare.
343+
}
344+
}
345+
346+
absl::flat_hash_map<const HloInstruction*, int>*
347+
HloComputation::GetCallersMap() {
348+
if (static_cast<CallersType>(callers_ & kCallerTypeMask) ==
349+
CallersType::kCallerCountHashMap) {
350+
return reinterpret_cast<absl::flat_hash_map<const HloInstruction*, int>*>(
351+
callers_ & ~kCallerTypeMask);
352+
}
353+
return nullptr;
354+
}
355+
356+
absl::flat_hash_map<const HloInstruction*, int>* const
357+
HloComputation::GetCallersMap() const {
358+
if (static_cast<CallersType>(callers_ & kCallerTypeMask) ==
359+
CallersType::kCallerCountHashMap) {
360+
return reinterpret_cast<
361+
absl::flat_hash_map<const HloInstruction*, int>* const>(
362+
callers_ & ~kCallerTypeMask);
363+
}
364+
return nullptr;
306365
}
307366

308367
HloInstruction* HloComputation::AddInstructionInternal(
@@ -329,7 +388,7 @@ HloInstruction* HloComputation::AddInstructionInternal(
329388
CHECK(parent() == nullptr || called_computation->parent() == parent())
330389
<< "Called computation " << called_computation->name()
331390
<< " is not in the same module as " << name();
332-
AddCallee(called_computation);
391+
AddCallee(pinst, called_computation);
333392
}
334393
return pinst;
335394
}

xla/hlo/ir/hlo_computation.h

+36-2
Original file line numberDiff line numberDiff line change
@@ -982,11 +982,34 @@ class HloComputation {
982982
return caller_computations_;
983983
}
984984

985+
// The returned callers are in no particular order.
986+
absl::InlinedVector<const HloInstruction*, 1> caller_instructions() const {
987+
if (const auto* map = GetCallersMap()) {
988+
absl::InlinedVector<const HloInstruction*, 1> result;
989+
for (const auto& [instr, _] : *map) {
990+
result.push_back(instr);
991+
}
992+
return result;
993+
}
994+
995+
if (callers_ == 0) {
996+
return {};
997+
}
998+
return {
999+
reinterpret_cast<const HloInstruction*>(callers_ & ~kCallerTypeMask)};
1000+
}
1001+
9851002
void ClearCalledComputations();
9861003

9871004
private:
9881005
friend class HloModule;
9891006

1007+
enum class CallersType : uint8_t {
1008+
kHloInstruction = 0,
1009+
kCallerCountHashMap = 1,
1010+
};
1011+
static constexpr uintptr_t kCallerTypeMask = 0b1;
1012+
9901013
explicit HloComputation(
9911014
const std::string& name, int parameter_count,
9921015
std::vector<std::unique_ptr<HloInstruction>>* instructions,
@@ -1051,8 +1074,14 @@ class HloComputation {
10511074
void SetUniqueIdHelper(int64_t id);
10521075

10531076
friend class HloInstruction;
1054-
void AddCallee(HloComputation* callee);
1055-
void RemoveCallee(HloComputation* callee);
1077+
// Add/remove call from `caller`, which must be in this computation, to
1078+
// `callee`.
1079+
void AddCallee(const HloInstruction* caller, HloComputation* callee);
1080+
void RemoveCallee(const HloInstruction* caller, HloComputation* callee);
1081+
1082+
// Returns nullptr if `callers_` is not a map.
1083+
absl::flat_hash_map<const HloInstruction*, int>* GetCallersMap();
1084+
absl::flat_hash_map<const HloInstruction*, int>* const GetCallersMap() const;
10561085

10571086
// Unique ID of this computation.
10581087
// This is set to -1 if the computation is not in a module. Should only be
@@ -1067,6 +1096,11 @@ class HloComputation {
10671096
// The respective type in the least significant three bits.
10681097
uintptr_t instruction_and_type_ = 0;
10691098

1099+
// Contains an HloInstruction* or an absl::flat_hash_map<HloInstruction*,
1100+
// /*count=*/int> in the high bits and a CallersType in the least significant
1101+
// bit.
1102+
uintptr_t callers_ = 0;
1103+
10701104
// If this computation is an async computation, this field points to the
10711105
// first async instruction (async-start) in the asynchronous op chain that
10721106
// calls this computation.

xla/hlo/ir/hlo_instruction.cc

+4-4
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ void HloInstruction::AppendComputation(HloComputation* computation) {
220220
// of T and hlo_instruction.h does not include hlo_computation.h.
221221
mutable_rare()->called_computations.push_back(computation);
222222
if (parent()) {
223-
parent()->AddCallee(computation);
223+
parent()->AddCallee(this, computation);
224224
}
225225
}
226226

@@ -235,10 +235,10 @@ void HloInstruction::set_called_computation(int index,
235235
std::swap(old_computation, mutable_rare()->called_computations[index]);
236236
if (parent()) {
237237
if (old_computation) {
238-
parent()->RemoveCallee(old_computation);
238+
parent()->RemoveCallee(this, old_computation);
239239
}
240240
if (computation) {
241-
parent()->AddCallee(computation);
241+
parent()->AddCallee(this, computation);
242242
}
243243
}
244244
}
@@ -255,7 +255,7 @@ void HloInstruction::ClearCalledComputations() {
255255
if (parent()) {
256256
for (HloComputation* computation : called_computations()) {
257257
if (computation) {
258-
parent()->RemoveCallee(computation);
258+
parent()->RemoveCallee(this, computation);
259259
}
260260
}
261261
}

xla/hlo/ir/hlo_module_test.cc

+67-11
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,8 @@ namespace xla {
4343
namespace {
4444

4545
using ::testing::ElementsAre;
46+
using ::testing::IsEmpty;
47+
using ::testing::UnorderedElementsAre;
4648

4749
TEST(HloModuleTest, AbslHashValue) {
4850
HloModule module1("temp_module", HloModuleConfig());
@@ -477,41 +479,95 @@ TEST(HloModuleTest, CheckToStringHonorsDebugOptions) {
477479
}
478480

479481
TEST(HloModuleTest, TestCallersAndCallees) {
480-
// Check that the debug options xla_dump_large_constants,
481-
// xla_syntax_sugar_async_ops are honored.
482482
const char* hlo = R"(
483483
HloModule jit_h
484484
485485
f {
486-
Arg_0.3 = f32[] parameter(0)
487-
ROOT sine.4 = f32[] sine(Arg_0.3)
486+
p0 = f32[] parameter(0)
487+
ROOT sine.4 = f32[] sine(p0)
488488
}
489489
490490
g {
491-
Arg_0.13 = f32[] parameter(0)
492-
call.14 = f32[] call(Arg_0.13), to_apply=f
493-
ROOT call.15 = f32[] call(call.14), to_apply=f
491+
p0 = f32[] parameter(0)
492+
call.f.0 = f32[] call(p0), to_apply=f
493+
ROOT call.f.1 = f32[] call(call.f.0), to_apply=f
494+
}
495+
496+
h {
497+
ROOT p0 = f32[] parameter(0)
498+
}
499+
500+
uncalled {
501+
p0 = f32[] parameter(0)
502+
ROOT call.h = f32[] call(p0), to_apply=h
494503
}
495504
496505
ENTRY main {
497506
Arg_0.1 = f32[] parameter(0)
498-
call.5 = f32[] call(Arg_0.1), to_apply=f
499-
call.16 = f32[] call(call.5), to_apply=g
500-
ROOT call.27 = f32[] call(call.16), to_apply=g
507+
call.f.2 = f32[] call(Arg_0.1), to_apply=f
508+
call.g.0 = f32[] call(call.f.2), to_apply=g
509+
ROOT call.g.1 = f32[] call(call.g.0), to_apply=g
501510
})";
502511
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
503512
ParseAndReturnUnverifiedModule(hlo));
504-
EXPECT_EQ(module->computation_count(), 3);
513+
EXPECT_EQ(module->computation_count(), 5);
505514
HloComputation* main = module->GetComputationWithName("main");
506515
HloComputation* f = module->GetComputationWithName("f");
507516
HloComputation* g = module->GetComputationWithName("g");
517+
HloComputation* h = module->GetComputationWithName("h");
518+
HloComputation* uncalled = module->GetComputationWithName("uncalled");
508519
EXPECT_THAT(main->callee_computations(),
509520
ElementsAre(std::make_pair(f, 1), std::make_pair(g, 2)));
510521
EXPECT_THAT(f->callee_computations(), ElementsAre());
511522
EXPECT_THAT(g->callee_computations(), ElementsAre(std::make_pair(f, 2)));
512523
EXPECT_THAT(f->caller_computations(),
513524
ElementsAre(std::make_pair(g, 2), std::make_pair(main, 1)));
514525
EXPECT_THAT(g->caller_computations(), ElementsAre(std::make_pair(main, 2)));
526+
527+
HloInstruction* call_f_0 = g->GetInstructionWithName("call.f.0");
528+
HloInstruction* call_f_1 = g->GetInstructionWithName("call.f.1");
529+
HloInstruction* call_f_2 = main->GetInstructionWithName("call.f.2");
530+
HloInstruction* call_g_0 = main->GetInstructionWithName("call.g.0");
531+
HloInstruction* call_g_1 = main->GetInstructionWithName("call.g.1");
532+
HloInstruction* call_h = uncalled->GetInstructionWithName("call.h");
533+
534+
EXPECT_THAT(f->caller_instructions(),
535+
UnorderedElementsAre(call_f_0, call_f_1, call_f_2));
536+
EXPECT_THAT(g->caller_instructions(),
537+
UnorderedElementsAre(call_g_0, call_g_1));
538+
EXPECT_THAT(h->caller_instructions(), ElementsAre(call_h));
539+
EXPECT_THAT(uncalled->caller_instructions(), IsEmpty());
540+
}
541+
542+
TEST(HloModuleTest, MultipleCallsFromOneInstruction) {
543+
const char* hlo = R"(
544+
f {
545+
tparam = f32[4] parameter(0)
546+
ROOT tuple = (f32[4]) tuple(tparam)
547+
}
548+
549+
g {
550+
fparam = f32[4] parameter(0)
551+
ROOT tuple = (f32[4]) tuple(fparam)
552+
}
553+
554+
ENTRY main {
555+
p0 = f32[4] parameter(0)
556+
b0 = s32[] parameter(1)
557+
ROOT conditional = (f32[4]) conditional(b0, p0, p0, p0),
558+
branch_computations={f, f, g}
559+
})";
560+
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
561+
ParseAndReturnUnverifiedModule(hlo));
562+
EXPECT_EQ(module->computation_count(), 3);
563+
HloComputation* main = module->GetComputationWithName("main");
564+
HloComputation* f = module->GetComputationWithName("f");
565+
HloComputation* g = module->GetComputationWithName("g");
566+
567+
HloInstruction* conditional = main->GetInstructionWithName("conditional");
568+
569+
EXPECT_THAT(f->caller_instructions(), ElementsAre(conditional));
570+
EXPECT_THAT(g->caller_instructions(), ElementsAre(conditional));
515571
}
516572

517573
} // namespace

0 commit comments

Comments
 (0)