From e76178293661e294b29a832a07db95331ffecba6 Mon Sep 17 00:00:00 2001
From: co63oc
Date: Tue, 7 May 2024 19:19:50 +0800
Subject: [PATCH] 【Hackathon 6th Fundable Projects 3 No.82】fluid operator cudnn_lstm (#63936)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* Fix

* Fix

* Fix

* Fix
---
 paddle/fluid/eager/utils.cc                   |  12 +
 paddle/fluid/eager/utils.h                    |   2 +
 paddle/fluid/operators/cudnn_lstm_op.cc       | 285 ------------------
 .../operators/ops_signature/cudnn_lstm_sig.cc |  59 ----
 .../op_generator/op_infermeta_func_gen.py     |   5 +-
 paddle/phi/api/yaml/backward.yaml             |  12 +
 paddle/phi/api/yaml/op_compat.yaml            |   8 +
 paddle/phi/api/yaml/ops.yaml                  |  12 +
 paddle/phi/infermeta/backward.cc              |  23 ++
 paddle/phi/infermeta/backward.h               |  10 +
 10 files changed, 83 insertions(+), 345 deletions(-)
 delete mode 100644 paddle/fluid/operators/cudnn_lstm_op.cc
 delete mode 100644 paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc

diff --git a/paddle/fluid/eager/utils.cc b/paddle/fluid/eager/utils.cc
index 1659430d6216fc..4fa6480372739f 100644
--- a/paddle/fluid/eager/utils.cc
+++ b/paddle/fluid/eager/utils.cc
@@ -118,6 +118,18 @@ std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
   return metas;
 }
 
+std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
+    const paddle::optional<std::vector<paddle::Tensor>>& targets) {
+  std::vector<AutogradMeta*> metas;
+  if (targets.get_ptr() != nullptr) {
+    metas.reserve(targets.get_ptr()->size());
+    for (const paddle::Tensor& t : (*(targets.get_ptr()))) {
+      metas.emplace_back(nullable_autograd_meta(t));
+    }
+  }
+  return metas;
+}
+
 std::vector<AutogradMeta*> EagerUtils::nullable_autograd_meta(
     const std::vector<paddle::Tensor>& targets) {
   std::vector<AutogradMeta*> metas;
diff --git a/paddle/fluid/eager/utils.h b/paddle/fluid/eager/utils.h
index 147f7377508a7b..aa9c972d7fa200 100644
--- a/paddle/fluid/eager/utils.h
+++ b/paddle/fluid/eager/utils.h
@@ -148,6 +148,8 @@ class TEST_API EagerUtils {
       const paddle::optional<paddle::Tensor>& target);
   static std::vector<AutogradMeta*> nullable_autograd_meta(
       const std::vector<paddle::Tensor>& targets);
+  static std::vector<AutogradMeta*> nullable_autograd_meta(
+      const paddle::optional<std::vector<paddle::Tensor>>& targets);
   static std::vector<AutogradMeta*> nullable_autograd_meta(
       const std::vector<paddle::Tensor*>& targets);
   static AutogradMeta* unsafe_autograd_meta(const paddle::Tensor& target);
diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc
deleted file mode 100644
index a082dbbcb8bcb5..00000000000000
--- a/paddle/fluid/operators/cudnn_lstm_op.cc
+++ /dev/null
@@ -1,285 +0,0 @@
-/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#include <memory>
-#include <string>
-
-#include "paddle/fluid/framework/infershape_utils.h"
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/op_version_registry.h"
-#include "paddle/phi/core/infermeta_utils.h"
-
-#include "paddle/phi/infermeta/multiary.h"
-
-namespace paddle {
-namespace operators {
-
-class CudnnLSTMOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
- protected:
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "Input"),
-                          ctx.device_context().GetPlace());
-  }
-};
-
-class CudnnLSTMOpMaker : public framework::OpProtoAndCheckerMaker {
- public:
-  void Make() override {
-    AddInput(
-        "Input",
-        "(Tensor) RNN input tensor, which support variable-time length input "
-        "sequence."
-        "The shape of the Tensor MUST be ( seq_len * batch_size * input_size)"
-        "seq_len is the total time step in this mini-batch (CAN be change in "
-        "different batch)"
-        "batch_size is the instance number of this batch"
-        "input_size is the hidden size of the input."
-        "input_size and the hidden_size in the next may not be same");
-    AddInput("InitH",
-             "(Tensor) the initial hidden state of the LSTM"
-             "input. This is a tensor with shape (num_layers x batch_size x "
-             "hidden_size)"
-             "and When is_bidirec is True, the shape will be (num_layers*2 x "
-             "batch_size x hidden_size)");
-    AddInput("InitC",
-             "(Tensor) the initial cell state of the LSTm "
-             "input. This is a tensor with shape (num_layers x batch_size x "
-             "hidden_size)"
-             "and When is_bidirec is True, the shape will be (num_layers*2 x "
-             "batch_size x hidden_size)");
-    AddInput("W",
-             "(Tensor) the learnable hidden-hidden weights."
-             " The shape is (N), where N is total weight size of the LSTM. "
-             " cudnn concatenate all the weight to one Tensor")
-        .AsDispensable();
-    AddInput("WeightList",
-             "(vector<Tensor>), stores weight and bias data when the weight "
-             "use the list format. ")
-        .AsDispensable()
-        .AsDuplicable();
-    AddInput("SequenceLength",
-             "(Tensor) When the input data is padding, "
-             "set this parameter. This parameter represents "
-             "the variable sequence lengths in a batch. "
-             "The size of the vector has to equal the batch_size.")
-        .AsDispensable();
-    AddOutput("Reserve",
-              "(Tensor, a temporary output Tensor to store the reserve_data "
-              "of cudnn kernel.")
-        .AsIntermediate();
-    AddOutput("StateOut",
-              "Share memory with State. "
-              "Store the global drop state when training");
-    AddOutput("Out",
-              "(Tensor) the hidden state of LSTM operator. "
-              "The shape is ( seq_len x batch_size x hidden_size) if "
-              "is_bidirec is False"
-              "and When is_bidirec is True, the shape will be ( seq_len x "
-              "batch_size x hidden_size * 2) ");
-    AddOutput("LastH",
-              "(Tensor) the hidden state of the last step. "
" - "The shape is ( num_layers x batch_size x hidden_size) if " - "is_bidirec is False" - "and When is_bidirec is True, the shape will be (num_layers*2 x " - "batch_size x hidden_size)"); - AddOutput("LastC", - "(Tensor) the cell state of the last step" - "The shape is ( num_layers x batch_size x hidden_size) if " - "is_bidirec is False" - "and When is_bidirect is True, the shape will be (num_layers*2 x " - "batch_size x hidden_size*2)"); - AddAttr( - "dropout_prob", - "dropout prob of the dropout op" - "the dropout ONLY work between lstm layers, not between time steps" - "There is no dropout work on the Out tensor") - .SetDefault(0.0); - AddAttr("is_bidirec", - "is_bidirec" - "if it is bidirectional rnn" - "The will affect the shape of the Out, LastH, and LastC") - .SetDefault(false); - AddAttr("input_size", "input size ot the Input Tensor").SetDefault(10); - AddAttr("hidden_size", "hidden size of the LSTM").SetDefault(100); - AddAttr("num_layers", "the total layer number of the LSTM") - .SetDefault(1); - AddAttr("is_test", "True if in test phase.").SetDefault(false); - AddAttr("seed", "seed to used if fix_seed is True").SetDefault(0); - AddComment(R"DOC( -CUDNN LSTM implementation - -A four-gate Long Short-Term Memory network with no peephole connections. -In the forward pass the output ht and cell output ct for a given iteration can be computed from the recurrent input ht-1, -the cell input ct-1 and the previous layer input xt given matrices W, R and biases bW, bR from the following equations: - -$$ i_t = sigmoid(W_{ix}x_{t} + W_{ih}h_{t-1} + bx_i + bh_i) $$ - -$$ f_t = sigmoid(W_{fx}x_{t} + W_{fh}h_{t-1} + bx_f + bh_f) $$ - -$$ o_t = sigmoid(W_{ox}x_{t} + W_{oh}h_{t-1} + bx_o + bh_o) $$ - -$$ \\tilde{c_t} = tanh(W_{cx}x_t + W_{ch}h_{t-1} + bx_c + bh_c) $$ - -$$ c_t = f_t \\odot c_{t-1} + i_t \\odot \\tilde{c_t} $$ - -$$ h_t = o_t \\odot tanh(c_t) $$ - -- W terms denote weight matrices (e.g. $W_{ix}$ is the matrix - of weights from the input gate to the input) -- The b terms denote bias vectors ($bx_i$ and $bh_i$ are the input gate bias vector). -- sigmoid is the logistic sigmoid function. -- $i, f, o$ and $c$ are the input gate, forget gate, output gate, - and cell activation vectors, respectively, all of which have the same size as - the cell output activation vector $h$. -- The $\odot$ is the element-wise product of the vectors. -- `tanh` is the activation functions. -- $\tilde{c_t}$ is also called candidate hidden state, - which is computed based on the current input and the previous hidden state. 
-
-Where sigmoid is the sigmoid operator: sigmoid(x) = 1 / (1 + e^-x), * represents a point-wise multiplication,
-X represents a matrix multiplication
-
-
-)DOC");
-  }
-};
-
-class CudnnLSTMGradOp : public framework::OperatorWithKernel {
- public:
-  using framework::OperatorWithKernel::OperatorWithKernel;
-
-  void InferShape(framework::InferShapeContext* ctx) const override {
-    OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTMGrad");
-    OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTMGrad");
-    OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTMGrad");
-
-    auto SetOutGradDim = [&ctx](const std::string& name) {
-      auto g_name = framework::GradVarName(name);
-      if (ctx->HasOutput(g_name)) {
-        ctx->SetOutputDim(g_name, ctx->GetInputDim(name));
-      }
-    };
-
-    SetOutGradDim("Input");
-    if (ctx->HasInputs("WeightList")) {
-      ctx->SetOutputsDim(framework::GradVarName("WeightList"),
-                         ctx->GetInputsDim("WeightList"));
-    }
-    SetOutGradDim("InitH");
-    SetOutGradDim("InitC");
-  }
-  phi::KernelKey GetExpectedKernelType(
-      const framework::ExecutionContext& ctx) const override {
-    return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(
-                              ctx, framework::GradVarName("Out")),
-                          ctx.device_context().GetPlace());
-  }
-};
-
-template <typename T>
-class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
- public:
-  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;
-
- protected:
-  void Apply(GradOpPtr<T> op) const override {
-    op->SetType("cudnn_lstm_grad");
-    op->SetInput("Input", this->Input("Input"));
-    op->SetInput("InitH", this->Input("InitH"));
-    op->SetInput("InitC", this->Input("InitC"));
-    if (this->HasInput("WeightList")) {
-      op->SetInput("WeightList", this->Input("WeightList"));
-    }
-    if (this->HasInput("SequenceLength")) {
-      op->SetInput("SequenceLength", this->Input("SequenceLength"));
-    }
-    op->SetInput("Reserve", this->Output("Reserve"));
-    op->SetInput("StateOut", this->Output("StateOut"));
-    op->SetInput("Out", this->Output("Out"));
-    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
-    op->SetInput(framework::GradVarName("LastC"), this->OutputGrad("LastC"));
-    op->SetInput(framework::GradVarName("LastH"), this->OutputGrad("LastH"));
-
-    if (this->HasInput("WeightList")) {
-      op->SetOutput(framework::GradVarName("WeightList"),
-                    this->InputGrad("WeightList", false));
-    }
-
-    op->SetOutput(framework::GradVarName("Input"), this->InputGrad("Input"));
-    op->SetOutput(framework::GradVarName("InitH"), this->InputGrad("InitH"));
-    op->SetOutput(framework::GradVarName("InitC"), this->InputGrad("InitC"));
-    op->SetAttrMap(this->Attrs());
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-DECLARE_INFER_SHAPE_FUNCTOR(cudnn_lstm,
-                            CudnnLSTMInferShapeFunctor,
-                            PD_INFER_META(phi::CudnnLSTMInferMeta));
-
-namespace ops = paddle::operators;
-REGISTER_OPERATOR(cudnn_lstm,
-                  ops::CudnnLSTMOp,
-                  ops::CudnnLSTMOpMaker,
-                  ops::CudnnLSTMGradOpMaker<paddle::framework::OpDesc>,
-                  ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>,
-                  CudnnLSTMInferShapeFunctor);
-
-REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
-
-// TODO(Shixiaowei02) Add ModifyInput support
-REGISTER_OP_VERSION(cudnn_lstm)
-    .AddCheckpoint(
-        R"ROC(
-  Upgrade cudnn_lstm add new inputs [WeightList, SequenceLength], modify the input [W] to dispensable, delete the input [Cache].
-  Upgrade cudnn_lstm add new outputs [StateOut, Reserve, LastC, LastH], delete output [last_c, last_h].
- Upgrade cudnn_lstm modify the attr [seed] default value to 0, delete the attr [max_len].)ROC", - paddle::framework::compatible::OpVersionDesc() - .NewInput( - "WeightList", - "The WeightList stores weight and bias data. WeightList is " - "dispensable.") - .NewInput("SequenceLength", - "When the input data is padding, set this parameter. " - "SequenceLength is dispensable.") - .ModifyInput("W", - "The new LSTM use WeightList instead of W. The W " - "concatenate all the weight to one Tensor.") - .DeleteInput("Cache", - "The new LSTM use the Reserve Output to store the " - "data of dropout.") - .NewOutput("StateOut", "Store the global drop state when training") - .NewOutput("Reserve", - "A temporary output Tensor to store the reserve_data") - .DeleteOutput( - "last_c", - "Modify the name of the output from 'last_c' to 'LastC'.") - .NewOutput("LastC", "The cell state of the last step.") - .DeleteOutput( - "last_h", - "Modify the name of the output from 'last_h' to 'LastH'.") - .NewOutput("LastH", "The hidden state of the last step.") - .ModifyAttr("seed", - "Set the default value of seed from '-1' to '0'.", - 0) - .DeleteAttr("max_len", - "The length of Inputs is achieved form the input data " - "which is difficult to know the information in " - "advance.")); diff --git a/paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc b/paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc deleted file mode 100644 index 83e61b396ee537..00000000000000 --- a/paddle/fluid/operators/ops_signature/cudnn_lstm_sig.cc +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
-#include "paddle/phi/core/compat/op_utils.h" - -namespace phi { - -KernelSignature CudnnLSTMOpArgumentMapping(const ArgumentMappingContext& ctx) { - return KernelSignature( - "cudnn_lstm", - {"Input", "InitH", "InitC", "W", "WeightList", "SequenceLength"}, - {"dropout_prob", - "is_bidirec", - "hidden_size", - "num_layers", - "is_test", - "seed"}, - {"Out", "LastH", "LastC", "Reserve", "StateOut"}); -} - -KernelSignature CudnnLSTMGradOpArgumentMapping( - const ArgumentMappingContext& ctx) { - return KernelSignature( - "cudnn_lstm_grad", - {"Input", - "InitH", - "InitC", - "WeightList", - "SequenceLength", - "Out", - "Reserve", - "StateOut", - "Out@GRAD", - "LastH@GRAD", - "LastC@GRAD"}, - {"dropout_prob", - "is_bidirec", - "hidden_size", - "num_layers", - "is_test", - "seed"}, - {"Input@GRAD", "InitH@GRAD", "InitC@GRAD", "WeightList@GRAD"}); -} - -} // namespace phi - -PD_REGISTER_ARG_MAPPING_FN(cudnn_lstm, phi::CudnnLSTMOpArgumentMapping); -PD_REGISTER_ARG_MAPPING_FN(cudnn_lstm_grad, - phi::CudnnLSTMGradOpArgumentMapping); diff --git a/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py b/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py index 31078476b23e23..caa5a4387f63ea 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_infermeta_func_gen.py @@ -281,9 +281,12 @@ def GenBuildOutputsPart2( }} """ + # In cudnn_lstm operator, the output weight_list_grad requires the use of optional input weight_list, + # so "pir::VectorType {name}" outside the "if" block. CREATE_OPTIONAL_INPUT_VEC_METATENSOR_TEMPLATE = """ std::vector vec_ir_tensor_{name}; + pir::VectorType {name}; if ({name}_.impl() != nullptr) {{ - pir::VectorType {name} = {name}_.type().dyn_cast(); + {name} = {name}_.type().dyn_cast(); for (size_t i=0; i < static_cast({name}.size()); i++) {{ if({name}[i].isa()) {{ auto {name}_type = {name}[i].dyn_cast(); diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index e94aff346e0a86..38a243d6eefbe2 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -519,6 +519,18 @@ func : cross_grad data_type : out_grad +- backward_op : cudnn_lstm_grad + forward: cudnn_lstm (Tensor x, Tensor init_h, Tensor init_c, Tensor w, Tensor[] weight_list, Tensor sequence_length, float dropout_prob = 0.0, bool is_bidirec = false, int hidden_size = 100, int num_layers = 1, bool is_test = false, int seed = 0) -> Tensor (out), Tensor (last_h), Tensor (last_c), Tensor (reserve), Tensor (state_out) + args: (Tensor x, Tensor init_h, Tensor init_c, Tensor[] weight_list, Tensor sequence_length, Tensor out, Tensor reserve, Tensor state_out, Tensor out_grad, Tensor last_h_grad, Tensor last_c_grad, float dropout_prob = 0.0, bool is_bidirec = false, int hidden_size = 100, int num_layers = 1, bool is_test = false, int seed = 0) + output: Tensor (x_grad), Tensor (init_h_grad), Tensor (init_c_grad), Tensor[](weight_list_grad){weight_list.size()} + infer_meta: + func: CudnnLSTMGradInferMeta + param : [x, init_h, init_c, weight_list] + kernel: + func: cudnn_lstm_grad + data_type : out_grad + optional: weight_list, sequence_length, weight_list_grad + - backward_op : cummax_grad forward : cummax(Tensor x, int axis=-1, DataType dtype = DataType::INT64) -> Tensor(out), Tensor(indices) args : (Tensor x, Tensor indices, Tensor out_grad, int axis, DataType dtype) diff --git a/paddle/phi/api/yaml/op_compat.yaml b/paddle/phi/api/yaml/op_compat.yaml index 
--- a/paddle/phi/api/yaml/op_compat.yaml
+++ b/paddle/phi/api/yaml/op_compat.yaml
@@ -3825,6 +3825,14 @@
   outputs:
     {out: Out}
 
+- op: cudnn_lstm
+  backward: cudnn_lstm_grad
+  inputs:
+    {x: Input, init_h: InitH, init_c: InitC, w: W, weight_list: WeightList, sequence_length: SequenceLength}
+  outputs:
+    {reserve: Reserve, state_out: StateOut, out: Out, last_h: LastH, last_c: LastC}
+  drop_empty_grad : [weight_list_grad]
+
 - op: decayed_adagrad
   inputs:
     {param : Param, grad : Grad, moment : Moment, learning_rate : LearningRate}
diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml
index 59661f0be04c02..5c9a314435476d 100755
--- a/paddle/phi/api/yaml/ops.yaml
+++ b/paddle/phi/api/yaml/ops.yaml
@@ -708,6 +708,18 @@
     data_type : input
   backward : cross_entropy_with_softmax_grad
 
+- op : cudnn_lstm
+  args: (Tensor x, Tensor init_h, Tensor init_c, Tensor w, Tensor[] weight_list, Tensor sequence_length, float dropout_prob = 0.0, bool is_bidirec = false, int hidden_size = 100, int num_layers = 1, bool is_test = false, int seed = 0)
+  output: Tensor (out), Tensor (last_h), Tensor (last_c), Tensor (reserve), Tensor (state_out)
+  infer_meta:
+    func: CudnnLSTMInferMeta
+  kernel:
+    func: cudnn_lstm
+    data_type: x
+  optional: w, weight_list, sequence_length
+  intermediate: reserve
+  backward: cudnn_lstm_grad
+
 - op : cummax
   args : (Tensor x, int axis=-1, DataType dtype = DataType::INT64)
   output : Tensor(out), Tensor(indices)
diff --git a/paddle/phi/infermeta/backward.cc b/paddle/phi/infermeta/backward.cc
index 261b99512a0ffe..9b8e94fc1380ce 100644
--- a/paddle/phi/infermeta/backward.cc
+++ b/paddle/phi/infermeta/backward.cc
@@ -320,6 +320,29 @@ void CrossEntropyWithSoftmaxGradInferMeta(const MetaTensor& label,
   logits_grad->set_dtype(softmax.dtype());
 }
 
+void CudnnLSTMGradInferMeta(
+    const MetaTensor& x,
+    const MetaTensor& init_h,
+    const MetaTensor& init_c,
+    const paddle::optional<std::vector<const MetaTensor*>>& weight_list,
+    MetaTensor* x_grad,
+    MetaTensor* init_h_grad,
+    MetaTensor* init_c_grad,
+    std::vector<MetaTensor*> weight_list_grad) {
+  if (x_grad) {
+    x_grad->share_meta(x);
+  }
+  if (init_h_grad) {
+    init_h_grad->share_meta(init_h);
+  }
+  if (init_c_grad) {
+    init_c_grad->share_meta(init_c);
+  }
+  if (!weight_list_grad.empty()) {
+    UnchangedMultiInferMeta(weight_list.get(), weight_list_grad);
+  }
+}
+
 void DeformableConvGradInferMeta(const MetaTensor& x,
                                  const MetaTensor& offset,
                                  const MetaTensor& filter,
diff --git a/paddle/phi/infermeta/backward.h b/paddle/phi/infermeta/backward.h
index 88aea8f18181b6..789419a54fde6f 100644
--- a/paddle/phi/infermeta/backward.h
+++ b/paddle/phi/infermeta/backward.h
@@ -141,6 +141,16 @@ void CSoftmaxWithCrossEntropyGradInferMeta(const MetaTensor& softmax,
                                            MetaTensor* logits_grad,
                                            MetaConfig config = MetaConfig());
 
+void CudnnLSTMGradInferMeta(
+    const MetaTensor& x,
+    const MetaTensor& init_h,
+    const MetaTensor& init_c,
+    const paddle::optional<std::vector<const MetaTensor*>>& weight_list,
+    MetaTensor* x_grad,
+    MetaTensor* init_h_grad,
+    MetaTensor* init_c_grad,
+    std::vector<MetaTensor*> weight_list_grad);
+
 void DeformableConvGradInferMeta(const MetaTensor& x,
                                  const MetaTensor& offset,
                                  const MetaTensor& filter,
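
Reviewer note (not part of the patch): below is a minimal smoke-test sketch for the migrated operator. It is an assumption-laden illustration, not a definitive test: it presumes a CUDA build of Paddle, that the eager binding paddle._C_ops.cudnn_lstm keeps the argument order declared in the new ops.yaml entry above and returns all five declared outputs (including the intermediate "reserve"), and a hypothetical per-layer WeightList packing (W_ih, W_hh, b_ih, b_hh); the authoritative packing is whatever the cuDNN kernel expects.

    # Hypothetical smoke test for the migrated cudnn_lstm op; see the
    # stated assumptions above.
    import paddle

    if paddle.is_compiled_with_cuda():
        paddle.set_device("gpu")
        seq_len, batch, input_size, hidden_size, num_layers = 4, 2, 8, 16, 1
        x = paddle.randn([seq_len, batch, input_size])
        init_h = paddle.zeros([num_layers, batch, hidden_size])
        init_c = paddle.zeros([num_layers, batch, hidden_size])
        # Assumed per-layer parameter layout: W_ih, W_hh, b_ih, b_hh.
        weight_list = [
            paddle.randn([4 * hidden_size, input_size]),
            paddle.randn([4 * hidden_size, hidden_size]),
            paddle.randn([4 * hidden_size]),
            paddle.randn([4 * hidden_size]),
        ]
        for w in weight_list:
            w.stop_gradient = False
        x.stop_gradient = False
        # Positional arguments mirror the ops.yaml signature:
        # (x, init_h, init_c, w, weight_list, sequence_length,
        #  dropout_prob, is_bidirec, hidden_size, num_layers, is_test, seed);
        # "w" and "sequence_length" are optional, so None is passed here.
        out, last_h, last_c, reserve, state_out = paddle._C_ops.cudnn_lstm(
            x, init_h, init_c, None, weight_list, None,
            0.0, False, hidden_size, num_layers, False, 0)
        # Backward dispatches through the new cudnn_lstm_grad entry in
        # backward.yaml; weight_list_grad metadata is produced by
        # CudnnLSTMGradInferMeta, which mirrors each weight's meta.
        out.sum().backward()
        assert weight_list[0].grad.shape == weight_list[0].shape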