[PHI] Support Multi Input and Output for InferShape #39870

Merged · 14 commits · Mar 1, 2022
69 changes: 44 additions & 25 deletions paddle/fluid/framework/infershape_utils.cc
@@ -308,22 +308,25 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
// TODO(chenweihang): support multiple inputs and outputs later
Contributor (review comment): I will remove this comment line later.

phi::InferMetaContext infer_meta_context;
for (auto& in_name : input_names) {
if (ctx->HasInput(in_name)) {
infer_meta_context.EmplaceBackInput(std::make_shared<CompatMetaTensor>(
ctx->GetInputVarPtrs(in_name)[0], ctx->IsRuntime()));
if (ctx->HasInputs(in_name)) {
auto input_var = ctx->GetInputVarPtrs(in_name);
if (input_var.size() == 1) {
infer_meta_context.EmplaceBackInput(
std::make_shared<CompatMetaTensor>(input_var[0], ctx->IsRuntime()));
} else {
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> inputs;
inputs.reserve(input_var.size());
for (const auto& in : input_var) {
inputs.push_back(
std::make_shared<CompatMetaTensor>(in, ctx->IsRuntime()));
}
infer_meta_context.EmplaceBackInputs(std::move(inputs));
}
} else {
infer_meta_context.EmplaceBackInput({nullptr});
}
}
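
The loop above dispatches on arity: a name backed by one variable goes in through EmplaceBackInput, while a name backed by several variables is packed into a SmallVector and handed to EmplaceBackInputs as a whole. A standalone sketch of the same dispatch with stand-in types (illustrative only, not part of this diff):

#include <cstddef>
#include <memory>
#include <vector>

struct MetaTensor {};  // stand-in for CompatMetaTensor
using MetaTensorPtr = std::shared_ptr<MetaTensor>;

// Each named input maps to one or more variables; singletons get a
// dedicated slot, larger groups are packed and emplaced together.
void PackInputs(const std::vector<std::vector<int>>& named_inputs,
                std::vector<MetaTensorPtr>* single_slots,
                std::vector<std::vector<MetaTensorPtr>>* packed_slots) {
  for (const auto& input_var : named_inputs) {
    if (input_var.size() == 1) {
      single_slots->push_back(std::make_shared<MetaTensor>());  // EmplaceBackInput
    } else {
      std::vector<MetaTensorPtr> inputs;
      inputs.reserve(input_var.size());
      for (std::size_t i = 0; i < input_var.size(); ++i) {
        inputs.push_back(std::make_shared<MetaTensor>());
      }
      packed_slots->push_back(std::move(inputs));  // EmplaceBackInputs
    }
  }
}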

for (auto& out_name : output_names) {
if (ctx->HasOutput(out_name)) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
ctx->GetOutputVarPtrs(out_name)[0], ctx->IsRuntime()));
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}
auto attr_reader = ctx->Attrs();
for (size_t i = 0; i < attr_names.size(); ++i) {
auto attr_name = attr_names[i];
@@ -348,30 +351,21 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
} else {
// If is not in runtime, we will set default value(-1) for ScalarArray
int64_t num_ele = 0;
std::vector<VarDesc*> vars;
vars.reserve(infershape_inputs.size());
for (size_t i = 0; i < infershape_inputs.size(); i++) {
for (size_t i = 0; i < infershape_inputs.size(); ++i) {
vars.push_back(BOOST_GET_CONST(VarDesc*, infershape_inputs[i]));
}

int64_t num_ele = 0;
if (vars.size() == 1) {
num_ele = 1;
const auto& tensor_dims = vars[0]->GetShape();
for (size_t i = 0; i < tensor_dims.size(); ++i) {
num_ele *= tensor_dims[i];
}
} else {
for (auto& var : vars) {
const auto& tensor_dims = var->GetShape();
PADDLE_ENFORCE_EQ(tensor_dims.size(), 1,
platform::errors::InvalidArgument(
"The shape is constructed by multi-tensor, "
"every tensor's dims should be 1. But your "
"shape has tensor that dims is %s.",
tensor_dims.size()));
num_ele += tensor_dims[0];
}
num_ele = vars.size();
}
phi::ScalarArray tensor_attr(std::vector<int32_t>(num_ele, -1));
tensor_attr.SetFromTensor(true);
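
At compile time only the attribute's length is recoverable: a single shape tensor contributes its element count, a list of shape tensors contributes one element apiece, and every element becomes the -1 placeholder. A worked sketch of that sizing rule (hypothetical helper, not in the PR):

#include <cstdint>
#include <vector>

// `shapes` holds the static dims of each variable backing the shape
// attribute; the result sizes ScalarArray(std::vector<int32_t>(n, -1)).
int64_t PlaceholderLength(const std::vector<std::vector<int64_t>>& shapes) {
  if (shapes.size() == 1) {
    int64_t num_ele = 1;
    for (int64_t d : shapes[0]) num_ele *= d;  // element count of the tensor
    return num_ele;
  }
  return static_cast<int64_t>(shapes.size());  // one element per tensor
}

// PlaceholderLength({{4}}) == 4; PlaceholderLength({{1}, {1}, {1}}) == 3.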
@@ -383,10 +377,14 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
std::type_index(typeid(std::vector<int32_t>))) {
infer_meta_context.EmplaceBackAttr(std::move(
phi::ScalarArray(BOOST_GET_CONST(std::vector<int32_t>, attr))));
} else if (std::type_index(attr.type()) ==
std::type_index(typeid(int))) {
infer_meta_context.EmplaceBackAttr(
phi::ScalarArray({BOOST_GET_CONST(int, attr)}));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported cast op attribute `%s` to ScalarArray when "
"construct KernelContext.",
"construct InferMetaContext.",
attr_name));
}
}
@@ -414,7 +412,6 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
}
} else if (ctx->HasInput(attr_name)) {
const auto& infershape_input = ctx->GetInputVarPtrs(attr_name);

if (infershape_input.size() == 1) {
if (ctx->IsRuntime()) {
Variable* var = BOOST_GET_CONST(Variable*, infershape_input[0]);
@@ -490,6 +487,28 @@ phi::InferMetaContext BuildInferMetaContext(InferShapeContext* ctx,
"Unsupported attribute type is received when call "
"InferShapeFunctor."));
}
} else {
// do nothing
}
}

for (auto& out_name : output_names) {
if (ctx->HasOutputs(out_name)) {
auto output_var = ctx->GetOutputVarPtrs(out_name);
if (output_var.size() == 1) {
infer_meta_context.EmplaceBackOutput(std::make_shared<CompatMetaTensor>(
output_var[0], ctx->IsRuntime()));
} else {
paddle::SmallVector<std::shared_ptr<phi::MetaTensor>> outputs;
outputs.reserve(output_var.size());
for (const auto& out : output_var) {
outputs.emplace_back(
std::make_shared<CompatMetaTensor>(out, ctx->IsRuntime()));
}
infer_meta_context.EmplaceBackOutputs(std::move(outputs));
}
} else {
infer_meta_context.EmplaceBackOutput({nullptr});
}
}

44 changes: 8 additions & 36 deletions paddle/fluid/operators/concat_op.cc
@@ -18,7 +18,9 @@ limitations under the License. */
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/infershape_utils.h"

#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/kernels/funcs/concat_funcs.h"

#ifdef PADDLE_WITH_MKLDNN
@@ -33,41 +35,6 @@ class ConcatOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
OP_INOUT_CHECK(ctx->HasInputs("X"), "Input", "X", "Concat");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "Concat");

auto inputs_dims = ctx->GetInputsDim("X");

const size_t inputs_num = inputs_dims.size();
PADDLE_ENFORCE_GT(
inputs_num, static_cast<size_t>(0),
platform::errors::InvalidArgument(
"The number of input tensors in concat op should > 0. But "
"received inputs' length is 0."));
if (inputs_num == 1) {
VLOG(3) << "Warning: concat op have only one input, may waste memory";
}

if (ctx->HasInput("AxisTensor")) {
auto out_dims =
phi::make_ddim(std::vector<int>(inputs_dims[0].size(), -1));
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
} else {
size_t axis =
ComputeAxis(static_cast<int64_t>(ctx->Attrs().Get<int>("axis")),
static_cast<int64_t>(inputs_dims[0].size()));
framework::DDim out_dims =
phi::funcs::ComputeAndCheckShape(ctx->IsRuntime(), inputs_dims, axis);
if (out_dims[axis] < 0) {
out_dims[axis] = -1;
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -237,9 +204,14 @@ class ConcatDoubleGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace paddle

namespace ops = paddle::operators;

DELCARE_INFER_SHAPE_FUNCTOR(concat, ConcatInferShapeFunctor,
PT_INFER_META(phi::ConcatInferMeta));

REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
ops::ConcatGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatGradOpMaker<paddle::imperative::OpBase>);
ops::ConcatGradOpMaker<paddle::imperative::OpBase>,
ConcatInferShapeFunctor);
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
ops::ConcatDoubleGradOpMaker<paddle::framework::OpDesc>,
ops::ConcatDoubleGradOpMaker<paddle::imperative::OpBase>,
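
Registering the functor as the trailing argument of REGISTER_OPERATOR is what lets the hand-written ConcatOp::InferShape above be deleted. A hypothetical registration for some other operator follows the same two steps (the op and function names here are illustrative, not from this PR):

DELCARE_INFER_SHAPE_FUNCTOR(my_op, MyOpInferShapeFunctor,
                            PT_INFER_META(phi::MyOpInferMeta));
REGISTER_OPERATOR(my_op, ops::MyOp, ops::MyOpMaker,
                  MyOpInferShapeFunctor);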
55 changes: 8 additions & 47 deletions paddle/fluid/operators/split_op.cc
@@ -15,6 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/split_op.h"
#include <string>

#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/infermeta/unary.h"

namespace paddle {
namespace operators {
using framework::Tensor;
@@ -23,52 +26,6 @@ class SplitOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
platform::errors::InvalidArgument(
"Input(X) of SplitOp should not be null."));
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
platform::errors::InvalidArgument(
"Outputs(Out) of SplitOp should not be empty."));
auto in_dims = ctx->GetInputDim("X");
auto outs_names = ctx->Outputs("Out");
size_t axis = static_cast<size_t>(ctx->Attrs().Get<int>("axis"));
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
std::vector<int> sections = static_cast<std::vector<int>>(
ctx->Attrs().Get<std::vector<int>>("sections"));
const size_t outs_number = outs_names.size();

if (sections.size() > 0) {
PADDLE_ENFORCE_EQ(
sections.size(), outs_number,
platform::errors::InvalidArgument("tensor split sections size "
"should be equal to output size."));
}

if (ctx->HasInput("AxisTensor")) {
auto out_dims = phi::make_ddim(std::vector<int>(in_dims.size(), -1));
std::vector<framework::DDim> outs_dims(outs_number, out_dims);
ctx->SetOutputsDim("Out", outs_dims);
for (size_t i = 0; i < outs_number; ++i) {
ctx->ShareLoD("X", "Out", 0, i);
}
return;
}

bool each_section_is_known =
(sections.size() > 0 && !ctx->HasInputs("SectionsTensorList"));

auto outs_dims = UpdateOutsDims(ctx->IsRuntime(), each_section_is_known,
in_dims, num, sections, axis, outs_number);
ctx->SetOutputsDim("Out", outs_dims);
if (axis != 0) {
// Only pass LoD when not splitting along the first dim.
for (size_t i = 0; i < outs_number; ++i) {
ctx->ShareLoD("X", "Out", 0, i);
}
}
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
@@ -168,6 +125,10 @@ This operator splits the input tensor into multiple sub-tensors.

namespace ops = paddle::operators;

DELCARE_INFER_SHAPE_FUNCTOR(split, SplitInferShapeFunctor,
PT_INFER_META(phi::SplitInferMeta));

REGISTER_OPERATOR(split, ops::SplitOp, ops::SplitOpMaker,
ops::SplitGradMaker<paddle::framework::OpDesc>,
ops::SplitGradMaker<paddle::imperative::OpBase>);
ops::SplitGradMaker<paddle::imperative::OpBase>,
SplitInferShapeFunctor);
2 changes: 1 addition & 1 deletion paddle/infrt/host_context/value.h
@@ -73,7 +73,7 @@ using ValueVariantType =
std::vector<phi::DenseTensor>,
paddle::experimental::ScalarBase<phi::DenseTensor>,
paddle::experimental::ScalarArrayBase<phi::DenseTensor>,
std::vector<phi::MetaTensor>,
std::vector<phi::MetaTensor*>,
phi::MetaConfig,
paddle::experimental::Backend,
paddle::experimental::DataLayout,
6 changes: 5 additions & 1 deletion paddle/phi/api/lib/api_custom_impl.cc
@@ -100,12 +100,16 @@ std::vector<Tensor> split_impl(const Tensor& x,
std::vector<Tensor> out;
auto dense_outs = SetKernelOutput(out_number, kernel_backend, &out);
std::vector<phi::MetaTensor> meta_outs;
meta_outs.reserve(out_number);
std::vector<phi::MetaTensor*> meta_out_ptrs;
meta_out_ptrs.reserve(out_number);
for (size_t i = 0; i < out_number; ++i) {
meta_outs.push_back(dense_outs[i]);
meta_out_ptrs.push_back(&meta_outs.back());
}

phi::SplitInferMeta(
MakeMetaTensor(*dense_x), num_or_sections, axis, &meta_outs);
MakeMetaTensor(*dense_x), num_or_sections, axis, meta_out_ptrs);

using kernel_signature = void (*)(const platform::DeviceContext&,
const phi::DenseTensor&,
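
A detail worth noting above: meta_outs.reserve(out_number) is what keeps meta_out_ptrs valid, because a reallocation of meta_outs during push_back would leave every previously taken &meta_outs.back() dangling. A standalone sketch of the buffer-plus-pointer-view pattern with a stand-in type:

#include <cstddef>
#include <vector>

struct MetaTensor {};  // stand-in

// Build a stable value buffer plus a non-owning pointer view over it.
// reserve() guarantees push_back never reallocates *buf, so the raw
// pointers collected below stay valid for the InferMeta call.
std::vector<MetaTensor*> MakePtrView(std::vector<MetaTensor>* buf,
                                     std::size_t n) {
  buf->reserve(n);
  std::vector<MetaTensor*> ptrs;
  ptrs.reserve(n);
  for (std::size_t i = 0; i < n; ++i) {
    buf->push_back(MetaTensor{});
    ptrs.push_back(&buf->back());
  }
  return ptrs;
}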
16 changes: 8 additions & 8 deletions paddle/phi/core/infermeta_utils.cc
@@ -75,13 +75,13 @@ paddle::optional<const phi::MetaTensor&> InferMetaContext::OptionalInputAt(
: paddle::optional<const phi::MetaTensor&>{paddle::none};
}

std::vector<MetaTensor> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor> result;
std::vector<MetaTensor*> InferMetaContext::InputsBetween(size_t start,
size_t end) const {
std::vector<MetaTensor*> result;
result.reserve(end - start);

for (size_t i = start; i < end; ++i) {
result.emplace_back(*inputs_.at(i));
result.push_back(inputs_.at(i).get());
}

return result;
Expand All @@ -91,12 +91,12 @@ MetaTensor* InferMetaContext::MutableOutputAt(size_t idx) {
return outputs_.at(idx).get();
}

std::vector<MetaTensor> InferMetaContext::MutableOutputBetween(size_t start,
size_t end) {
std::vector<MetaTensor> result;
std::vector<MetaTensor*> InferMetaContext::MutableOutputBetween(size_t start,
size_t end) {
std::vector<MetaTensor*> result;
result.reserve(end - start);
for (size_t i = start; i < end; ++i) {
result.emplace_back(*outputs_.at(i));
result.emplace_back(outputs_.at(i).get());
}
return result;
}
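
Returning std::vector<MetaTensor*> instead of std::vector<MetaTensor> hands callers non-owning views into the context-owned tensors, so whatever an InferMeta function writes through MutableOutputBetween's result actually lands in the context. A simplified standalone analogue (stand-in types, not the real phi API):

#include <cstddef>
#include <memory>
#include <vector>

struct MetaTensor {
  std::vector<long> dims;
};

// Simplified analogue of MutableOutputBetween: expose the context's
// tensors by pointer rather than by copy, so a callee that sets dims
// mutates the context's own state.
std::vector<MetaTensor*> OutputsBetween(
    std::vector<std::shared_ptr<MetaTensor>>& outputs,
    std::size_t start, std::size_t end) {
  std::vector<MetaTensor*> result;
  result.reserve(end - start);
  for (std::size_t i = start; i < end; ++i) {
    result.push_back(outputs.at(i).get());
  }
  return result;
}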
15 changes: 7 additions & 8 deletions paddle/phi/core/infermeta_utils.h
@@ -50,13 +50,13 @@ class InferMetaContext {
const std::pair<int, int>& OutputRangeAt(size_t idx) const;

const MetaConfig& GetMetaConfig() const;
const MetaTensor& InputAt(size_t idx) const;

const MetaTensor& InputAt(size_t idx) const;
paddle::optional<const phi::MetaTensor&> OptionalInputAt(size_t idx) const;
std::vector<MetaTensor*> InputsBetween(size_t start, size_t end) const;

std::vector<MetaTensor> InputsBetween(size_t start, size_t end) const;
MetaTensor* MutableOutputAt(size_t idx);
std::vector<MetaTensor> MutableOutputBetween(size_t start, size_t end);
std::vector<MetaTensor*> MutableOutputBetween(size_t start, size_t end);

template <typename AttrType>
AttrType AttrAt(size_t idx) {
@@ -157,15 +157,15 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
};

template <typename... Tail>
struct InferMetaFnCallHelper<const std::vector<MetaTensor>&, Tail...> {
struct InferMetaFnCallHelper<const std::vector<MetaTensor*>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
static_assert(attr_idx == 0,
"InferMeta's Input should appear before Attributes.");
static_assert(out_idx == 0,
"InferMeta's Input should appear before Outputs.");
const std::pair<int, int> range = ctx->InputRangeAt(in_idx);
std::vector<MetaTensor> arg =
std::vector<MetaTensor*> arg =
ctx->InputsBetween(range.first, range.second);
InferMetaFnCallHelper<
Tail...>::template Call<in_idx + 1, attr_idx, out_idx>(ctx,
@@ -210,13 +210,12 @@ struct InferMetaFnImpl<Return (*)(Args...), infer_meta_fn> {
};

template <typename... Tail>
struct InferMetaFnCallHelper<std::vector<MetaTensor>*, Tail...> {
struct InferMetaFnCallHelper<std::vector<MetaTensor*>, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Call(InferMetaContext* ctx, PreviousArgs&... pargs) {
const std::pair<int, int> range = ctx->OutputRangeAt(out_idx);
std::vector<MetaTensor> tmp =
std::vector<MetaTensor*> arg =
ctx->MutableOutputBetween(range.first, range.second);
std::vector<MetaTensor>* arg = &tmp;
InferMetaFnCallHelper<
Tail...>::template Call<in_idx, attr_idx, out_idx + 1>(ctx,
pargs...,
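
The specializations above all follow one recursive unpacking pattern: peel the first type off Args..., fetch a value of that type from the context at the running input/attr index, append it to the argument pack, and recurse until a terminating tag invokes the target function. A self-contained miniature of the pattern, assuming a toy context with int inputs and double attributes (illustrative only, not the phi machinery):

#include <iostream>
#include <vector>

struct Ctx {
  std::vector<int> inputs;
  std::vector<double> attrs;
};

template <typename T>
struct TypeTag {};

template <typename Fn, Fn fn>
struct FnImpl;

template <typename Return, typename... Args, Return (*fn)(Args...)>
struct FnImpl<Return (*)(Args...), fn> {
  static void Call(Ctx* ctx) {
    CallHelper<Args..., TypeTag<int>>::template Call<0, 0>(ctx);
  }

 private:
  template <typename... RemainingArgs>
  struct CallHelper;

  // Peel an int parameter: read it from ctx->inputs at in_idx.
  template <typename... Tail>
  struct CallHelper<int, Tail...> {
    template <int in_idx, int attr_idx, typename... PreviousArgs>
    static void Call(Ctx* ctx, PreviousArgs&... pargs) {
      int arg = ctx->inputs[in_idx];
      CallHelper<Tail...>::template Call<in_idx + 1, attr_idx>(ctx, pargs..., arg);
    }
  };

  // Peel a double parameter: read it from ctx->attrs at attr_idx.
  template <typename... Tail>
  struct CallHelper<double, Tail...> {
    template <int in_idx, int attr_idx, typename... PreviousArgs>
    static void Call(Ctx* ctx, PreviousArgs&... pargs) {
      double arg = ctx->attrs[attr_idx];
      CallHelper<Tail...>::template Call<in_idx, attr_idx + 1>(ctx, pargs..., arg);
    }
  };

  // All parameters peeled: invoke the target function.
  template <typename T>
  struct CallHelper<TypeTag<T>> {
    template <int in_idx, int attr_idx, typename... PreviousArgs>
    static void Call(Ctx* ctx, PreviousArgs&... pargs) {
      (void)ctx;
      fn(pargs...);
    }
  };
};

void AxpyMeta(int x, int y, double scale) {
  std::cout << (x + y) * scale << "\n";
}

int main() {
  Ctx ctx{{1, 2}, {0.5}};
  FnImpl<decltype(&AxpyMeta), &AxpyMeta>::Call(&ctx);  // prints 1.5
  return 0;
}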