Skip to content

Commit

Permalink
spliting
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangbaizhou committed Mar 20, 2024
1 parent 6a01a19 commit d86e15e
Show file tree
Hide file tree
Showing 4 changed files with 55 additions and 53 deletions.
50 changes: 0 additions & 50 deletions paddle/cinn/frontend/cluster_ops/common_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -114,54 +114,4 @@ std::function<bool(const pir::Operation*)> MakePredicatorIsInThisFusionOp(
};
}

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo) {
const auto& IsSource = [&](const pir::Operation* op) {
std::size_t num_inputs = 0;
op_topo.VisitInputOp(op,
[&](const pir::Operation* input) { ++num_inputs; });
return num_inputs == 0;
};

const auto starts = [&] {
std::list<const pir::Operation*> starts;
for (const auto* op : *op_topo.ops) {
if (IsSource(op)) {
starts.push_back(op);
} else {
// do nothing.
}
}
return starts;
}();

std::unordered_map<const pir::Operation*, bool> op_2_is_injective_source;

auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) {
bool is_inputs_all_injective_source = true;
op_topo.VisitInputOp(op, [&](const pir::Operation* input) {
is_inputs_all_injective_source = (is_inputs_all_injective_source &&
op_2_is_injective_source.at(input));
});
return is_inputs_all_injective_source;
};
const auto VisitInput = [&](const pir::Operation* op,
const OpVisitor& DoEach) {
op_topo.VisitInputOp(op, DoEach);
};
const auto VisitOutput = [&](const pir::Operation* op,
const OpVisitor& DoEach) {
op_topo.VisitOutputOp(op, DoEach);
};
common::TopoWalker<const pir::Operation*> walker{VisitInput, VisitOutput};
walker(starts.begin(), starts.end(), [&](const pir::Operation* op) {
op_2_is_injective_source[op] =
(IsGeneralInjective(op) && IsInputsAllInjectiveSource(op));
});
return [map = std::move(op_2_is_injective_source)](const pir::Operation* op) {
const auto& iter = map.find(op);
CHECK(iter != map.end());
return iter->second;
};
}
} // namespace cinn::frontend::cluster_ops
3 changes: 0 additions & 3 deletions paddle/cinn/frontend/cluster_ops/common_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,4 @@ struct OpTopo {
std::function<bool(const pir::Operation*)> MakePredicatorIsInThisFusionOp(
const std::vector<const pir::Operation*>& ops);

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo);

} // namespace cinn::frontend::cluster_ops
52 changes: 52 additions & 0 deletions paddle/cinn/frontend/cluster_ops/pattern_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -180,4 +180,56 @@ auto VisitCachedOutput = [stmt2outputs](const auto* stmt,
};
return common::TopoWalker<const StmtPattern*>(VisitCachedInput,
VisitCachedOutput);

}

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo) {
const auto& IsSource = [&](const pir::Operation* op) {
std::size_t num_inputs = 0;
op_topo.VisitInputOp(op,
[&](const pir::Operation* input) { ++num_inputs; });
return num_inputs == 0;
};

const auto starts = [&] {
std::list<const pir::Operation*> starts;
for (const auto* op : *op_topo.ops) {
if (IsSource(op)) {
starts.push_back(op);
} else {
// do nothing.
}
}
return starts;
}();

std::unordered_map<const pir::Operation*, bool> op_2_is_injective_source;

auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) {
bool is_inputs_all_injective_source = true;
op_topo.VisitInputOp(op, [&](const pir::Operation* input) {
is_inputs_all_injective_source = (is_inputs_all_injective_source &&
op_2_is_injective_source.at(input));
});
return is_inputs_all_injective_source;
};
const auto VisitInput = [&](const pir::Operation* op,
const OpVisitor& DoEach) {
op_topo.VisitInputOp(op, DoEach);
};
const auto VisitOutput = [&](const pir::Operation* op,
const OpVisitor& DoEach) {
op_topo.VisitOutputOp(op, DoEach);
};
common::TopoWalker<const pir::Operation*> walker{VisitInput, VisitOutput};
walker(starts.begin(), starts.end(), [&](const pir::Operation* op) {
op_2_is_injective_source[op] =
(IsGeneralInjective(op) && IsInputsAllInjectiveSource(op));
});
return [map = std::move(op_2_is_injective_source)](const pir::Operation* op) {
const auto& iter = map.find(op);
CHECK(iter != map.end());
return iter->second;
};
}
3 changes: 3 additions & 0 deletions paddle/cinn/frontend/cluster_ops/pattern_utils.h
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
common::TopoWalker<const StmtPattern*> MakeTopoWalker(
const OpTopo& op_topo, const std::vector<StmtPattern>& stmt_patterns);

std::function<bool(const pir::Operation*)> MakePredicatorIsInjectiveSource(
const OpTopo& op_topo);

0 comments on commit d86e15e

Please sign in to comment.