diff --git a/paddle/cinn/frontend/cluster_ops/common_utils.cc b/paddle/cinn/frontend/cluster_ops/common_utils.cc index 0d4b9f604beabc..bea699f995566a 100644 --- a/paddle/cinn/frontend/cluster_ops/common_utils.cc +++ b/paddle/cinn/frontend/cluster_ops/common_utils.cc @@ -114,54 +114,4 @@ std::function MakePredicatorIsInThisFusionOp( }; } -std::function MakePredicatorIsInjectiveSource( - const OpTopo& op_topo) { - const auto& IsSource = [&](const pir::Operation* op) { - std::size_t num_inputs = 0; - op_topo.VisitInputOp(op, - [&](const pir::Operation* input) { ++num_inputs; }); - return num_inputs == 0; - }; - - const auto starts = [&] { - std::list starts; - for (const auto* op : *op_topo.ops) { - if (IsSource(op)) { - starts.push_back(op); - } else { - // do nothing. - } - } - return starts; - }(); - - std::unordered_map op_2_is_injective_source; - - auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) { - bool is_inputs_all_injective_source = true; - op_topo.VisitInputOp(op, [&](const pir::Operation* input) { - is_inputs_all_injective_source = (is_inputs_all_injective_source && - op_2_is_injective_source.at(input)); - }); - return is_inputs_all_injective_source; - }; - const auto VisitInput = [&](const pir::Operation* op, - const OpVisitor& DoEach) { - op_topo.VisitInputOp(op, DoEach); - }; - const auto VisitOutput = [&](const pir::Operation* op, - const OpVisitor& DoEach) { - op_topo.VisitOutputOp(op, DoEach); - }; - common::TopoWalker walker{VisitInput, VisitOutput}; - walker(starts.begin(), starts.end(), [&](const pir::Operation* op) { - op_2_is_injective_source[op] = - (IsGeneralInjective(op) && IsInputsAllInjectiveSource(op)); - }); - return [map = std::move(op_2_is_injective_source)](const pir::Operation* op) { - const auto& iter = map.find(op); - CHECK(iter != map.end()); - return iter->second; - }; -} } // namespace cinn::frontend::cluster_ops \ No newline at end of file diff --git a/paddle/cinn/frontend/cluster_ops/common_utils.h b/paddle/cinn/frontend/cluster_ops/common_utils.h index 0851e12bc1f5a8..8eb08625857518 100644 --- a/paddle/cinn/frontend/cluster_ops/common_utils.h +++ b/paddle/cinn/frontend/cluster_ops/common_utils.h @@ -103,7 +103,4 @@ struct OpTopo { std::function MakePredicatorIsInThisFusionOp( const std::vector& ops); -std::function MakePredicatorIsInjectiveSource( - const OpTopo& op_topo); - } // namespace cinn::frontend::cluster_ops diff --git a/paddle/cinn/frontend/cluster_ops/pattern_utils.cc b/paddle/cinn/frontend/cluster_ops/pattern_utils.cc index 2e1b9993ba5c0e..e3e4dee354cbdc 100644 --- a/paddle/cinn/frontend/cluster_ops/pattern_utils.cc +++ b/paddle/cinn/frontend/cluster_ops/pattern_utils.cc @@ -180,4 +180,56 @@ auto VisitCachedOutput = [stmt2outputs](const auto* stmt, }; return common::TopoWalker(VisitCachedInput, VisitCachedOutput); + } + +std::function MakePredicatorIsInjectiveSource( + const OpTopo& op_topo) { + const auto& IsSource = [&](const pir::Operation* op) { + std::size_t num_inputs = 0; + op_topo.VisitInputOp(op, + [&](const pir::Operation* input) { ++num_inputs; }); + return num_inputs == 0; + }; + + const auto starts = [&] { + std::list starts; + for (const auto* op : *op_topo.ops) { + if (IsSource(op)) { + starts.push_back(op); + } else { + // do nothing. + } + } + return starts; + }(); + + std::unordered_map op_2_is_injective_source; + + auto IsInputsAllInjectiveSource = [&](const pir::Operation* op) { + bool is_inputs_all_injective_source = true; + op_topo.VisitInputOp(op, [&](const pir::Operation* input) { + is_inputs_all_injective_source = (is_inputs_all_injective_source && + op_2_is_injective_source.at(input)); + }); + return is_inputs_all_injective_source; + }; + const auto VisitInput = [&](const pir::Operation* op, + const OpVisitor& DoEach) { + op_topo.VisitInputOp(op, DoEach); + }; + const auto VisitOutput = [&](const pir::Operation* op, + const OpVisitor& DoEach) { + op_topo.VisitOutputOp(op, DoEach); + }; + common::TopoWalker walker{VisitInput, VisitOutput}; + walker(starts.begin(), starts.end(), [&](const pir::Operation* op) { + op_2_is_injective_source[op] = + (IsGeneralInjective(op) && IsInputsAllInjectiveSource(op)); + }); + return [map = std::move(op_2_is_injective_source)](const pir::Operation* op) { + const auto& iter = map.find(op); + CHECK(iter != map.end()); + return iter->second; + }; +} \ No newline at end of file diff --git a/paddle/cinn/frontend/cluster_ops/pattern_utils.h b/paddle/cinn/frontend/cluster_ops/pattern_utils.h index 5a09ac0dc99206..4f6dc96d361ca9 100644 --- a/paddle/cinn/frontend/cluster_ops/pattern_utils.h +++ b/paddle/cinn/frontend/cluster_ops/pattern_utils.h @@ -1,2 +1,5 @@ common::TopoWalker MakeTopoWalker( const OpTopo& op_topo, const std::vector& stmt_patterns); + + std::function MakePredicatorIsInjectiveSource( + const OpTopo& op_topo);