Skip to content

Commit

Permalink
enhance the op_version_registry, test=develop (#28347)
Browse files Browse the repository at this point in the history
* enhance the op_version_registry, test=develop

* add unittests, test=develop

* enhance the op_version_registry, test=develop

* fix bugs, test=develop

* revert pybind_boost_headers.h, test=develop

* fix a attribute bug, test=develop
  • Loading branch information
Shixiaowei02 authored Nov 4, 2020
1 parent c1c3e21 commit 21a63f6
Show file tree
Hide file tree
Showing 11 changed files with 518 additions and 122 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/framework/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ function(pass_library TARGET DEST)

cmake_parse_arguments(pass_library "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN})
if(pass_library_DIR)
cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
cc_library(${TARGET} SRCS ${pass_library_DIR}/${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS})
else()
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base ${pass_library_DEPS})
cc_library(${TARGET} SRCS ${TARGET}.cc DEPS graph_pattern_detector pass fuse_pass_base op_version_registry ${pass_library_DEPS})
endif()

# add more DEST here, such as train, dist and collect USE_PASS into a file automatically.
Expand Down
72 changes: 72 additions & 0 deletions paddle/fluid/framework/op_version_registry.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,75 @@ See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/framework/op_version_registry.h"

namespace paddle {
namespace framework {
namespace compatible {

namespace {
template <OpUpdateType type__, typename InfoType>
OpUpdate<InfoType, type__>* new_update(InfoType&& info) {
return new OpUpdate<InfoType, type__>(info);
}
}

OpVersionDesc&& OpVersionDesc::ModifyAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kModifyAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}

OpVersionDesc&& OpVersionDesc::NewAttr(const std::string& name,
const std::string& remark,
const OpAttrVariantT& default_value) {
infos_.emplace_back(new_update<OpUpdateType::kNewAttr>(
OpAttrInfo(name, remark, default_value)));
return std::move(*this);
}

OpVersionDesc&& OpVersionDesc::NewInput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kNewInput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}

OpVersionDesc&& OpVersionDesc::NewOutput(const std::string& name,
const std::string& remark) {
infos_.emplace_back(
new_update<OpUpdateType::kNewOutput>(OpInputOutputInfo(name, remark)));
return std::move(*this);
}

OpVersionDesc&& OpVersionDesc::BugfixWithBehaviorChanged(
const std::string& remark) {
infos_.emplace_back(new_update<OpUpdateType::kBugfixWithBehaviorChanged>(
OpBugfixInfo(remark)));
return std::move(*this);
}

OpVersion& OpVersionRegistrar::Register(const std::string& op_type) {
PADDLE_ENFORCE_EQ(
op_version_map_.find(op_type), op_version_map_.end(),
platform::errors::AlreadyExists(
"'%s' is registered in operator version more than once.", op_type));
op_version_map_.insert(
std::pair<std::string, OpVersion>{op_type, OpVersion()});
return op_version_map_[op_type];
}
uint32_t OpVersionRegistrar::version_id(const std::string& op_type) const {
PADDLE_ENFORCE_NE(
op_version_map_.count(op_type), 0,
platform::errors::InvalidArgument(
"The version of operator type %s has not been registered.", op_type));
return op_version_map_.find(op_type)->second.version_id();
}

// Provide a fake registration item for pybind testing.
#include "paddle/fluid/framework/op_version_registry.inl"

} // namespace compatible
} // namespace framework
} // namespace paddle
Loading

0 comments on commit 21a63f6

Please sign in to comment.