diff --git a/paddle/fluid/pybind/ir.cc b/paddle/fluid/pybind/ir.cc index 6c6957c3e00e08..c7f6c57af97500 100644 --- a/paddle/fluid/pybind/ir.cc +++ b/paddle/fluid/pybind/ir.cc @@ -39,6 +39,7 @@ #include "paddle/ir/core/value.h" #include "paddle/ir/pass/pass.h" #include "paddle/ir/pass/pass_manager.h" +#include "paddle/ir/pass/pass_registry.h" #include "paddle/ir/transforms/dead_code_elimination_pass.h" #include "paddle/phi/core/enforce.h" #include "pybind11/stl.h" @@ -488,15 +489,6 @@ void BindIrPass(pybind11::module *m) { [](const Pass &self) { return self.pass_info().dependents; }); } -// TODO(zhiqiu): refine pass registry -std::unique_ptr CreatePassByName(std::string name) { - if (name == "DeadCodeEliminationPass") { - return ir::CreateDeadCodeEliminationPass(); - } else { - IR_THROW("The %s pass is not registed", name); - } -} - void BindPassManager(pybind11::module *m) { py::class_> pass_manager( *m, @@ -514,7 +506,7 @@ void BindPassManager(pybind11::module *m) { py::arg("opt_level") = 2) .def("add_pass", [](PassManager &self, std::string pass_name) { - self.AddPass(std::move(CreatePassByName(pass_name))); + self.AddPass(std::move(PassRegistry::Instance().Get(name);)); }) .def("passes", [](PassManager &self) { diff --git a/paddle/ir/pass/pass.h b/paddle/ir/pass/pass.h index 7e01786d4738f6..4a4cbf629d678e 100644 --- a/paddle/ir/pass/pass.h +++ b/paddle/ir/pass/pass.h @@ -20,7 +20,7 @@ #include "paddle/ir/core/enforce.h" #include "paddle/ir/pass/analysis_manager.h" -#include "paddle/ir/pass/pass" +#include "paddle/ir/pass/pass_registry.h" #include "paddle/phi/core/enforce.h" namespace ir { diff --git a/paddle/ir/pass/pass_registry.cc b/paddle/ir/pass/pass_registry.cc index 413880ab6b6e31..a0239219a694d8 100644 --- a/paddle/ir/pass/pass_registry.cc +++ b/paddle/ir/pass/pass_registry.cc @@ -12,24 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -#pragma once - -#include -#include -#include -#include - #include "paddle/ir/pass/pass_registry.h" namespace ir { - -class IrContext; -class Operation; -class Program; -class Pass; -class PassInstrumentation; -class PassInstrumentor; - PassRegistry &PassRegistry::Instance() { static PassRegistry g_pass_info_map; return g_pass_info_map; diff --git a/paddle/ir/pass/pass_registry.h b/paddle/ir/pass/pass_registry.h index aeb174c1f5aea5..c35dc0ba90a302 100644 --- a/paddle/ir/pass/pass_registry.h +++ b/paddle/ir/pass/pass_registry.h @@ -14,10 +14,13 @@ #pragma once -#include -#include +#include #include -#include +#include + +#include "paddle/ir/core/enforce.h" +#include "paddle/ir/core/macros.h" +#include "paddle/ir/pass/pass.h" namespace ir { @@ -30,39 +33,24 @@ class PassRegistry { static PassRegistry &Instance(); bool Has(const std::string &pass_type) const { - return map_.find(pass_type) != map_.end(); + return pass_map_.find(pass_type) != pass_map_.end(); } void Insert(const std::string &pass_type, const PassCreator &pass_creator) { - PADDLE_ENFORCE_NE(Has(pass_type), - true, - platform::errors::AlreadyExists( - "Pass %s has been registered.", pass_type)); - map_.insert({pass_type, pass_creator}); + IR_ENFORCE( + Has(pass_type) != true, "Pass %s has been registered.", pass_type); + pass_map_.insert({pass_type, pass_creator}); } std::unique_ptr Get(const std::string &pass_type) const { - if (pass_type == "tensorrt_subgraph_pass") { - PADDLE_ENFORCE_EQ(Has(pass_type), - true, - platform::errors::InvalidArgument( - "Pass %s has not been registered. Please " - "use the paddle inference library " - "compiled with tensorrt or disable " - "the tensorrt engine in inference configuration! ", - pass_type)); - } else { - PADDLE_ENFORCE_EQ(Has(pass_type), - true, - platform::errors::InvalidArgument( - "Pass %s has not been registered.", pass_type)); - } - return map_.at(pass_type)(); + IR_ENFORCE( + Has(pass_type) == true, "Pass %s has not been registered.", pass_type); + return pass_map_.at(pass_type)(); } private: PassRegistry() = default; - std::unordered_map map_; + std::unordered_map pass_map_; DISABLE_COPY_AND_ASSIGN(PassRegistry); }; @@ -80,7 +68,8 @@ class PassRegistrar { // registrar variable won't be removed by the linker. void Touch() {} explicit PassRegistrar(const char *pass_type) { - PassRegistry::Instance().Insert(pass_type, std::make_unique()); + PassRegistry::Instance().Insert( + pass_type, []() { return std::make_unique(); }); } }; diff --git a/paddle/ir/transforms/dead_code_elimination_pass.cc b/paddle/ir/transforms/dead_code_elimination_pass.cc index c9278e904dadc4..d56b83b8446804 100644 --- a/paddle/ir/transforms/dead_code_elimination_pass.cc +++ b/paddle/ir/transforms/dead_code_elimination_pass.cc @@ -76,3 +76,5 @@ std::unique_ptr CreateDeadCodeEliminationPass() { } } // namespace ir + +REGISTER_PASS(dead_code_elimination, DeadCodeEliminationPass); diff --git a/test/ir/new_ir/test_pass_manager.py b/test/ir/new_ir/test_pass_manager.py index 580dea776772c4..2f31e945f31f40 100644 --- a/test/ir/new_ir/test_pass_manager.py +++ b/test/ir/new_ir/test_pass_manager.py @@ -51,11 +51,12 @@ def test_op(self): self.assertTrue('pd.uniform' in op_names) pm = ir.PassManager() pm.add_pass( - 'DeadCodeEliminationPass' + 'dead_code_elimination' ) # apply pass to elimitate dead code pm.run(new_program) op_names = [op.name() for op in new_program.block().ops] # print(op_names) + # TODO(zhiqiu): unify the name of pass self.assertEqual(pm.passes(), ['DeadCodeEliminationPass']) self.assertFalse(pm.empty()) self.assertTrue(