Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Pten] Add selected_rows kernel for Full #39465

Merged
merged 14 commits into from
Feb 19, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cmake/pten.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ function(append_op_util_declare TARGET)
string(REGEX MATCH "(PT_REGISTER_BASE_KERNEL_NAME|PT_REGISTER_ARG_MAPPING_FN)\\([ \t\r\n]*[a-z0-9_]*" util_registrar "${target_content}")
string(REPLACE "PT_REGISTER_ARG_MAPPING_FN" "PT_DECLARE_ARG_MAPPING_FN" util_declare "${util_registrar}")
string(REPLACE "PT_REGISTER_BASE_KERNEL_NAME" "PT_DECLARE_BASE_KERNEL_NAME" util_declare "${util_declare}")
string(APPEND util_declare ");")
string(APPEND util_declare ");\n")
file(APPEND ${op_utils_header} "${util_declare}")
endfunction()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"

USE_OP_ITSELF(elementwise_add);
USE_OP(fill_constant);
USE_OP_ITSELF(fill_constant);

namespace paddle {
namespace distributed {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

#include "paddle/fluid/framework/new_executor/standalone_executor.h"

USE_OP(fill_constant);
USE_OP_ITSELF(fill_constant);
USE_OP(uniform_random);
USE_OP(lookup_table);
USE_OP(transpose2);
Expand Down
10 changes: 0 additions & 10 deletions paddle/fluid/operators/fill_constant_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,6 @@ REGISTER_OPERATOR(
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);

REGISTER_OP_CPU_KERNEL(
fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::bfloat16>,
ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex<double>>);

REGISTER_OP_VERSION(fill_constant)
.AddCheckpoint(
R"ROC(
Expand Down
25 changes: 0 additions & 25 deletions paddle/fluid/operators/fill_constant_op.cu.cc

This file was deleted.

2 changes: 1 addition & 1 deletion paddle/fluid/operators/math/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ if (WITH_ASCEND_CL)
else()
math_library(beam_search DEPS math_function)
endif()
math_library(fc DEPS blas)
math_library(fc DEPS blas jit_kernel_helper)
math_library(matrix_bit_code)

math_library(unpooling)
Expand Down
8 changes: 8 additions & 0 deletions paddle/pten/kernels/full_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "paddle/pten/common/scalar.h"
#include "paddle/pten/common/scalar_array.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/selected_rows.h"

#include "paddle/pten/infermeta/nullary.h"
#include "paddle/pten/kernels/empty_kernel.h"
Expand All @@ -30,6 +31,13 @@ void FullKernel(const Context& dev_ctx,
DataType dtype,
DenseTensor* out);

template <typename T, typename Context>
void FullSR(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype,
SelectedRows* out);

template <typename T, typename Context>
void FullLikeKernel(const Context& dev_ctx,
const DenseTensor& x,
Expand Down
70 changes: 70 additions & 0 deletions paddle/pten/kernels/selected_rows/full_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/kernels/full_kernel.h"

#include "paddle/pten/backends/cpu/cpu_context.h"
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/pten/backends/gpu/gpu_context.h"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里引用gpu_context.h 头文件,是否需要加 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) 宏判断?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里可以不加宏应该也没有影响,严谨一些这里先加上宏判断,后面再看情况是否保留

#endif
#include "paddle/pten/core/kernel_registry.h"

#include "paddle/pten/common/bfloat16.h"
#include "paddle/pten/common/complex.h"

namespace pten {

// Fill kernel for SelectedRows outputs: fills the underlying dense value
// tensor of `out` with a constant.
//
// dev_ctx: device context for the target backend (CPU or GPU).
// shape:   shape of the dense value tensor to create and fill.
// val:     scalar fill value.
// dtype:   requested output data type.
// out:     destination SelectedRows; only its value tensor is written here —
//          the rows vector is not touched by this kernel.
template <typename T, typename Context>
void FullSR(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DataType dtype,
SelectedRows* out) {
// Delegate to the DenseTensor full kernel on the mutable value tensor.
pten::FullKernel<T>(dev_ctx, shape, val, dtype, out->mutable_value());
}

} // namespace pten

// Register the SelectedRows full kernel ("full_sr") for CPU across all
// supported element types, including float16/bfloat16 and complex.
PT_REGISTER_KERNEL(full_sr,
CPU,
ALL_LAYOUT,
pten::FullSR,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
pten::dtype::float16,
pten::dtype::bfloat16,
pten::dtype::complex<float>,
pten::dtype::complex<double>) {}

// GPU registration of the SelectedRows full kernel, compiled only for
// CUDA/HIP builds.
// NOTE(review): unlike the CPU registration above, bfloat16 is not listed
// here — presumably a device dtype-support limitation; confirm intentional.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PT_REGISTER_KERNEL(full_sr,
GPU,
ALL_LAYOUT,
pten::FullSR,
float,
double,
uint8_t,
int16_t,
int,
int64_t,
bool,
pten::dtype::float16,
pten::dtype::complex<float>,
pten::dtype::complex<double>) {}
#endif
51 changes: 51 additions & 0 deletions paddle/pten/ops/compat/fill_constant_sig.cc
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,57 @@ KernelSignature FillConstantOpArgumentMapping(
}
}
}
} else if (ctx.IsSelectedRowsOutput("Out")) {
if (ctx.HasInput("ShapeTensor")) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature(
"full_sr", {}, {"ShapeTensor", "ValueTensor", "dtype"}, {"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
Comment on lines +75 to +76
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个有用到吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有用到,case比较特殊

if (str_value.empty()) {
return KernelSignature(
"full_sr", {}, {"ShapeTensor", "value", "dtype"}, {"Out"});
} else {
return KernelSignature(
"full_sr", {}, {"ShapeTensor", "str_value", "dtype"}, {"Out"});
}
}
} else if (ctx.InputSize("ShapeTensorList") > 0) {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature("full_sr",
{},
{"ShapeTensorList", "ValueTensor", "dtype"},
{"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) {
return KernelSignature(
"full_sr", {}, {"ShapeTensorList", "value", "dtype"}, {"Out"});
} else {
return KernelSignature("full_sr",
{},
{"ShapeTensorList", "str_value", "dtype"},
{"Out"});
}
}
} else {
if (ctx.HasInput("ValueTensor")) {
return KernelSignature(
"full_sr", {}, {"shape", "ValueTensor", "dtype"}, {"Out"});
} else {
const auto& str_value =
paddle::any_cast<std::string>(ctx.Attr("str_value"));
if (str_value.empty()) {
return KernelSignature(
"full_sr", {}, {"shape", "value", "dtype"}, {"Out"});
} else {
return KernelSignature(
"full_sr", {}, {"shape", "str_value", "dtype"}, {"Out"});
}
}
}
}
return KernelSignature("unregistered", {}, {}, {});
}
Expand Down
1 change: 1 addition & 0 deletions paddle/pten/tests/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ add_subdirectory(api)
add_subdirectory(common)
add_subdirectory(core)
add_subdirectory(kernels)
add_subdirectory(ops_signature)
1 change: 1 addition & 0 deletions paddle/pten/tests/ops_signature/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
cc_test(test_op_signature SRCS test_op_signature.cc DEPS op_utils)
118 changes: 118 additions & 0 deletions paddle/pten/tests/ops_signature/test_op_signature.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/pten/tests/ops_signature/test_op_signature.h"

#include <gtest/gtest.h>
#include <memory>
#include <unordered_set>

#include "paddle/pten/core/compat/op_utils.h"
#include "paddle/pten/ops/compat/signatures.h"

namespace pten {
namespace tests {

// The unittests in this file exist just in order to pass the CI-Coverage
// check, so it isn't necessary to cover all the cases.

// Checks that fill_constant's argument-mapping function resolves to the
// "full_sr" kernel signature for every SelectedRows-output input/attr
// combination: shape from ShapeTensor, ShapeTensorList, or the "shape"
// attribute, crossed with value from ValueTensor, "str_value", or "value".
TEST(ARG_MAP, fill_constant) {
// Case 1: shape from ShapeTensor input, value from ValueTensor input.
TestArgumentMappingContext arg_case1(
{"ShapeTensor", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature1 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case1);
ASSERT_EQ(signature1.name, "full_sr");

// Case 2: shape from ShapeTensor, value from non-empty "str_value" attr.
TestArgumentMappingContext arg_case2(
{"ShapeTensor"},
{},
{{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature2 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case2);
ASSERT_EQ(signature2.name, "full_sr");

// Case 3: shape from ShapeTensor; empty "str_value" falls back to "value".
TestArgumentMappingContext arg_case3(
{"ShapeTensor"},
{},
{{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature3 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case3);
ASSERT_EQ(signature3.name, "full_sr");

// Case 4: shape from ShapeTensorList input, value from ValueTensor input.
TestArgumentMappingContext arg_case4(
{"ShapeTensorList", "ValueTensor"}, {}, {}, {}, {"Out"});
auto signature4 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case4);
ASSERT_EQ(signature4.name, "full_sr");

// Case 5: shape from ShapeTensorList, value from non-empty "str_value".
TestArgumentMappingContext arg_case5(
{"ShapeTensorList"},
{},
{{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature5 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case5);
ASSERT_EQ(signature5.name, "full_sr");

// Case 6: shape from ShapeTensorList; empty "str_value" falls back to
// "value".
TestArgumentMappingContext arg_case6(
{"ShapeTensorList"},
{},
{{"value", paddle::any{0}}, {"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature6 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case6);
ASSERT_EQ(signature6.name, "full_sr");

// Case 7: shape from the "shape" attribute, value from ValueTensor input.
TestArgumentMappingContext arg_case7(
{"ValueTensor"},
{},
{{"shape", paddle::any{std::vector<int64_t>{2, 3}}}},
{},
{"Out"});
auto signature7 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case7);
ASSERT_EQ(signature7.name, "full_sr");

// Case 8: shape attr; empty "str_value" falls back to the "value" attr.
TestArgumentMappingContext arg_case8(
{},
{},
{{"shape", paddle::any{std::vector<int64_t>{2, 3}}},
{"value", paddle::any{0}},
{"str_value", paddle::any{std::string{""}}}},
{},
{"Out"});
auto signature8 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case8);
ASSERT_EQ(signature8.name, "full_sr");

// Case 9: shape attr with a non-empty "str_value" attr.
TestArgumentMappingContext arg_case9(
{},
{},
{{"shape", paddle::any{std::vector<int64_t>{2, 3}}},
{"str_value", paddle::any{std::string{"10"}}}},
{},
{"Out"});
auto signature9 =
OpUtilsMap::Instance().GetArgumentMappingFn("fill_constant")(arg_case9);
ASSERT_EQ(signature9.name, "full_sr");
}

} // namespace tests
} // namespace pten
Loading