Skip to content

Commit

Permalink
fix bugs, test=develop
Browse files Browse the repository at this point in the history
  • Loading branch information
Shixiaowei02 committed Nov 2, 2020
1 parent 134c11d commit 159cd27
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 19 deletions.
36 changes: 21 additions & 15 deletions paddle/fluid/framework/op_version_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ class OpVersionRegistrar {
const std::unordered_map<std::string, OpVersion>& GetVersionMap() {
return op_version_map_;
}
bool Has(const std::string& op_type) const {
return op_version_map_.count(op_type);
}
uint32_t version_id(const std::string& op_type) const;

private:
Expand All @@ -202,21 +205,24 @@ class OpVersionComparator {
virtual ~OpVersionComparator() = default;
};

#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \
class OpVersion##cmp_name##Comparator : public OpVersionComparator { \
public: \
explicit OpVersion##cmp_name##Comparator(const std::string op_name, \
uint32_t target_version) \
: op_name_(op_name), target_version_(target_version) {} \
virtual bool operator()() { \
return OpVersionRegistrar::GetInstance().version_id(op_name_) \
cmp_math target_version_; \
} \
virtual ~OpVersion##cmp_name##Comparator() {} \
\
private: \
std::string op_name_; \
uint32_t target_version_; \
#define ADD_OP_VERSION_COMPARATOR(cmp_name, cmp_math) \
class OpVersion##cmp_name##Comparator : public OpVersionComparator { \
public: \
explicit OpVersion##cmp_name##Comparator(const std::string op_name, \
uint32_t target_version) \
: op_name_(op_name), target_version_(target_version) {} \
virtual bool operator()() { \
uint32_t version_id = 0; \
if (OpVersionRegistrar::GetInstance().Has(op_name_)) { \
version_id = OpVersionRegistrar::GetInstance().version_id(op_name_); \
} \
return version_id cmp_math target_version_; \
} \
virtual ~OpVersion##cmp_name##Comparator() {} \
\
private: \
std::string op_name_; \
uint32_t target_version_; \
};

ADD_OP_VERSION_COMPARATOR(LE, <=);
Expand Down
9 changes: 5 additions & 4 deletions paddle/fluid/framework/op_version_registry_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ TEST(test_operator_version, test_operator_version) {
}

TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
const std::string fake_op_name{"op_name__"};
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"no_bind_pass"));

Expand Down Expand Up @@ -90,31 +91,31 @@ TEST(test_pass_op_version_checker, test_pass_op_version_checker) {
REGISTER_PASS_CAPABILITY(test_pass4)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 5)
.GE(fake_op_name, 5)
.EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass4"));

REGISTER_PASS_CAPABILITY(test_pass5)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.GE("test__", 4)
.GE(fake_op_name, 4)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass5"));

REGISTER_PASS_CAPABILITY(test_pass6)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("test__", 4)
.EQ(fake_op_name, 4)
.EQ("fc", 0));
ASSERT_TRUE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass6"));

REGISTER_PASS_CAPABILITY(test_pass7)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.NE("test__", 4)
.NE(fake_op_name, 4)
.EQ("fc", 0));
ASSERT_FALSE(PassVersionCheckerRegistrar::GetInstance().IsPassCompatible(
"test_pass7"));
Expand Down

0 comments on commit 159cd27

Please sign in to comment.