Skip to content

Commit

Permalink
add pass registry macro
Browse files Browse the repository at this point in the history
  • Loading branch information
zhiqiu committed Aug 28, 2023
1 parent 603ce9a commit 7233626
Show file tree
Hide file tree
Showing 6 changed files with 23 additions and 54 deletions.
12 changes: 2 additions & 10 deletions paddle/fluid/pybind/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
#include "paddle/ir/core/value.h"
#include "paddle/ir/pass/pass.h"
#include "paddle/ir/pass/pass_manager.h"
#include "paddle/ir/pass/pass_registry.h"
#include "paddle/ir/transforms/dead_code_elimination_pass.h"
#include "paddle/phi/core/enforce.h"
#include "pybind11/stl.h"
Expand Down Expand Up @@ -488,15 +489,6 @@ void BindIrPass(pybind11::module *m) {
[](const Pass &self) { return self.pass_info().dependents; });
}

// TODO(zhiqiu): refine pass registry
std::unique_ptr<Pass> CreatePassByName(std::string name) {
if (name == "DeadCodeEliminationPass") {
return ir::CreateDeadCodeEliminationPass();
} else {
IR_THROW("The %s pass is not registed", name);
}
}

void BindPassManager(pybind11::module *m) {
py::class_<PassManager, std::shared_ptr<PassManager>> pass_manager(
*m,
Expand All @@ -514,7 +506,7 @@ void BindPassManager(pybind11::module *m) {
py::arg("opt_level") = 2)
.def("add_pass",
[](PassManager &self, std::string pass_name) {
self.AddPass(std::move(CreatePassByName(pass_name)));
self.AddPass(std::move(PassRegistry::Instance().Get(name);));
})
.def("passes",
[](PassManager &self) {
Expand Down
2 changes: 1 addition & 1 deletion paddle/ir/pass/pass.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#include "paddle/ir/core/enforce.h"
#include "paddle/ir/pass/analysis_manager.h"
#include "paddle/ir/pass/pass"
#include "paddle/ir/pass/pass_registry.h"
#include "paddle/phi/core/enforce.h"

namespace ir {
Expand Down
15 changes: 0 additions & 15 deletions paddle/ir/pass/pass_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cstdint>
#include <iostream>
#include <memory>
#include <vector>

#include "paddle/ir/pass/pass_registry.h"

namespace ir {

class IrContext;
class Operation;
class Program;
class Pass;
class PassInstrumentation;
class PassInstrumentor;

PassRegistry &PassRegistry::Instance() {
static PassRegistry g_pass_info_map;
return g_pass_info_map;
Expand Down
43 changes: 16 additions & 27 deletions paddle/ir/pass/pass_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,13 @@

#pragma once

#include <cstdint>
#include <iostream>
#include <functional>
#include <memory>
#include <vector>
#include <unordered_map>

#include "paddle/ir/core/enforce.h"
#include "paddle/ir/core/macros.h"
#include "paddle/ir/pass/pass.h"

namespace ir {

Expand All @@ -30,39 +33,24 @@ class PassRegistry {
static PassRegistry &Instance();

bool Has(const std::string &pass_type) const {
return map_.find(pass_type) != map_.end();
return pass_map_.find(pass_type) != pass_map_.end();
}

void Insert(const std::string &pass_type, const PassCreator &pass_creator) {
PADDLE_ENFORCE_NE(Has(pass_type),
true,
platform::errors::AlreadyExists(
"Pass %s has been registered.", pass_type));
map_.insert({pass_type, pass_creator});
IR_ENFORCE(
Has(pass_type) != true, "Pass %s has been registered.", pass_type);
pass_map_.insert({pass_type, pass_creator});
}

std::unique_ptr<Pass> Get(const std::string &pass_type) const {
if (pass_type == "tensorrt_subgraph_pass") {
PADDLE_ENFORCE_EQ(Has(pass_type),
true,
platform::errors::InvalidArgument(
"Pass %s has not been registered. Please "
"use the paddle inference library "
"compiled with tensorrt or disable "
"the tensorrt engine in inference configuration! ",
pass_type));
} else {
PADDLE_ENFORCE_EQ(Has(pass_type),
true,
platform::errors::InvalidArgument(
"Pass %s has not been registered.", pass_type));
}
return map_.at(pass_type)();
IR_ENFORCE(
Has(pass_type) == true, "Pass %s has not been registered.", pass_type);
return pass_map_.at(pass_type)();
}

private:
PassRegistry() = default;
std::unordered_map<std::string, PassCreator> map_;
std::unordered_map<std::string, PassCreator> pass_map_;

DISABLE_COPY_AND_ASSIGN(PassRegistry);
};
Expand All @@ -80,7 +68,8 @@ class PassRegistrar {
// registrar variable won't be removed by the linker.
void Touch() {}
explicit PassRegistrar(const char *pass_type) {
PassRegistry::Instance().Insert(pass_type, std::make_unique<PassType>());
PassRegistry::Instance().Insert(
pass_type, []() { return std::make_unique<PassType>(); });
}
};

Expand Down
2 changes: 2 additions & 0 deletions paddle/ir/transforms/dead_code_elimination_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,5 @@ std::unique_ptr<Pass> CreateDeadCodeEliminationPass() {
}

} // namespace ir

REGISTER_PASS(dead_code_elimination, DeadCodeEliminationPass);
3 changes: 2 additions & 1 deletion test/ir/new_ir/test_pass_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def test_op(self):
self.assertTrue('pd.uniform' in op_names)
pm = ir.PassManager()
pm.add_pass(
'DeadCodeEliminationPass'
'dead_code_elimination'
) # apply pass to elimitate dead code
pm.run(new_program)
op_names = [op.name() for op in new_program.block().ops]
# print(op_names)
# TODO(zhiqiu): unify the name of pass
self.assertEqual(pm.passes(), ['DeadCodeEliminationPass'])
self.assertFalse(pm.empty())
self.assertTrue(
Expand Down

0 comments on commit 7233626

Please sign in to comment.