Skip to content

Commit

Permalink
[PTen] Add standard kernel suffix set (#39404)
Browse files Browse the repository at this point in the history
* add standard_suffix_set_and_remove_reshape_with_xshape

* revert reshape change

* polish reduce name
  • Loading branch information
chenwhql authored Feb 10, 2022
1 parent 63d2333 commit c7c1db3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
5 changes: 5 additions & 0 deletions paddle/pten/core/compat/op_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ limitations under the License. */

namespace pten {

const std::unordered_set<std::string> standard_kernel_suffixs({
"sr", // SelectedRows kernel
"raw" // fallback kernel of origfinal fluid op
});

/**
* Some fluid ops are no longer used under the corresponding official API
* system of 2.0. These names need to correspond to the official API names
Expand Down
8 changes: 4 additions & 4 deletions paddle/pten/ops/compat/elementwise_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,10 +66,10 @@ KernelSignature ElementwiseDivOpArgumentMapping(

} // namespace pten

PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide_raw);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add, add);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub, subtract);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_mul, multiply);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_div, divide);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_add_grad, add_grad);
PT_REGISTER_BASE_KERNEL_NAME(elementwise_sub_grad, subtract_grad);

Expand Down
4 changes: 2 additions & 2 deletions paddle/pten/ops/compat/reduce_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ KernelSignature ReduceMeanOpArgumentMapping(const ArgumentMappingContext& ctx) {

} // namespace pten

PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum_raw);
PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean_raw);
PT_REGISTER_BASE_KERNEL_NAME(reduce_sum, sum);
PT_REGISTER_BASE_KERNEL_NAME(reduce_mean, mean);

PT_REGISTER_ARG_MAPPING_FN(reduce_sum, pten::ReduceSumOpArgumentMapping);
PT_REGISTER_ARG_MAPPING_FN(reduce_mean, pten::ReduceMeanOpArgumentMapping);

0 comments on commit c7c1db3

Please sign in to comment.