
Lml/add prim ops #41201

Merged: 96 commits, Apr 13, 2022
Changes from 93 commits
Commits (96)
80b884e
native commit for triple grad of sigmod
veyron95 Sep 22, 2021
d52b81c
Updated unittests files
veyron95 Sep 22, 2021
19d6b05
init functional jacobian api
Sep 22, 2021
f47b48f
merge upstream/develop
Sep 22, 2021
16c048a
Merge pull request #2 from veyron95/ops_derivative
JiabinYang Sep 22, 2021
a6a9053
Merge branch 'support_derivative' of https://github.com/JiabinYang/Pa…
JiabinYang Sep 22, 2021
4febae7
Updated trible_test func
veyron95 Sep 22, 2021
be9da74
Updated gradient_checker & test_script
veyron95 Sep 22, 2021
be2b30d
finish test with dtype float32
Sep 23, 2021
36b8c34
add float64 test case
Sep 23, 2021
35b1ce8
polish code
Sep 24, 2021
3a35a00
use atol=1e-5 with dtype float64
Sep 24, 2021
a3ea12e
fix for ci
Sep 24, 2021
8738cf8
set timeout for test_jacobian
Sep 24, 2021
d6e771e
fix dygraph grad to support high differential
JiabinYang Sep 24, 2021
0bd8287
polish API docstring
Sep 26, 2021
83c8395
Merge branch 'support_derivative' of https://github.com/JiabinYang/Pa…
veyron95 Sep 26, 2021
4109fc5
Updated gradient checker and some related files
veyron95 Sep 26, 2021
19e471c
Merge pull request #4 from veyron95/ops_derivative
JiabinYang Sep 26, 2021
1573b2c
Merge branch 'lml/jacobian' of https://github.com/levi131/Paddle into…
JiabinYang Sep 26, 2021
1408ef5
fix double grad strip error for high differential
JiabinYang Sep 26, 2021
ea78b6e
fix double grad strip error for high differential
JiabinYang Sep 26, 2021
2351a99
Add Sigmoid triple grad tests
veyron95 Sep 26, 2021
7a3fbd1
fix dygraph double grad dtype error when calling for high differentia…
JiabinYang Sep 26, 2021
42df611
Merge pull request #8 from veyron95/ops_derivative
JiabinYang Sep 26, 2021
a6dde75
Updated triple grad teses func
veyron95 Sep 27, 2021
848efcf
Use np.random to initialize ddx
veyron95 Sep 27, 2021
04eab89
Updated triple_grad_check func
veyron95 Sep 28, 2021
38ca20a
Merge pull request #9 from veyron95/ops_derivative
JiabinYang Sep 28, 2021
886d9fb
merge develop
JiabinYang Sep 28, 2021
e9f643d
add todo for gradient checker and refine some comments
JiabinYang Sep 28, 2021
2d6370b
remove additional code
JiabinYang Sep 28, 2021
a3b8e4e
add test for infer_var dtype warning
JiabinYang Sep 29, 2021
13af3ed
Merge branch 'support_derivative' of https://github.com/JiabinYang/Pa…
JiabinYang Sep 29, 2021
20ca8e7
add test for warnging in backward.py
JiabinYang Sep 29, 2021
a961e3c
format python code
JiabinYang Oct 11, 2021
ee5489d
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
JiabinYang Oct 11, 2021
a495960
support multi input in triple gradient checker
JiabinYang Oct 12, 2021
ebe8559
Add matmul triple grad kernel
veyron95 Oct 14, 2021
4f31159
Merge branch 'support_derivative' of https://github.com/JiabinYang/Pa…
veyron95 Oct 14, 2021
4d56a30
Updated comments of TODO
veyron95 Oct 14, 2021
15f2a32
Merge develop branch and all conflicts fixed
veyron95 Oct 14, 2021
07d1490
Supported some special tests
veyron95 Oct 14, 2021
d5fdd20
merge develop
JiabinYang Oct 15, 2021
0e44f39
merge jiabin/support_derivative branch
veyron95 Oct 15, 2021
b52794e
Change code-format to follow CI std
veyron95 Oct 18, 2021
4202d96
Updated gradient_checker.py
veyron95 Oct 19, 2021
91149a7
Fix conflicts
veyron95 Oct 19, 2021
e20ef17
Merge develop and fix conflicts
veyron95 Oct 19, 2021
d0741f4
Removed unnecessary printing log
veyron95 Oct 19, 2021
46dbd64
Change code style to follow CI std
veyron95 Oct 20, 2021
e32e10e
Merge remote-tracking branch '3rd_order/ops_derivative' into develop
Oct 20, 2021
46607df
Merge remote-tracking branch 'upstream/develop' into develop
Oct 20, 2021
9da53dd
Merge branch 'develop' of github.com:PaddlePaddle/Paddle into develop
Oct 22, 2021
528ef73
Merge remote-tracking branch 'upstream/develop' into develop
Nov 13, 2021
34116c1
Merge remote-tracking branch 'upstream/develop' into develop
Nov 30, 2021
affdf9f
Merge remote-tracking branch 'upstream/develop' into develop
Dec 8, 2021
6d14deb
Merge remote-tracking branch 'upstream/develop' into develop
Dec 16, 2021
ff8fe05
Merge remote-tracking branch 'upstream/develop' into develop
Dec 27, 2021
3aab3fd
Merge remote-tracking branch 'upstream/develop' into develop
Jan 25, 2022
7d9fe02
Merge remote-tracking branch 'upstream/develop' into develop
Feb 8, 2022
de96a19
Merge remote-tracking branch 'upstream/develop' into develop
Mar 3, 2022
2923e06
merge upstream
Mar 3, 2022
157a098
Merge remote-tracking branch 'upstream/develop' into develop
Mar 3, 2022
d1269ee
Merge remote-tracking branch 'upstream/develop' into develop
Mar 18, 2022
578c013
Merge remote-tracking branch 'upstream/develop' into develop
Mar 28, 2022
568a7cc
add_p
Mar 31, 2022
659f6d0
Merge remote-tracking branch 'upstream/develop' into lml/add_prim_ops
Mar 31, 2022
6754266
rm useless files
Mar 31, 2022
9bb5341
add sub_p mul_p div_p
Mar 31, 2022
94d463b
add sqrt_p and tanh_p
Mar 31, 2022
0a1819c
add reshape_p
Mar 31, 2022
36cc630
add broadcast_p
Mar 31, 2022
8b06ec8
add broadcast_p fill_constant_p matmul_p reduce_p reshape_p transpose_p
Apr 5, 2022
ccbeb71
add split_p and concat_p
Apr 5, 2022
25c2b9a
add gather_p and scatter_add_p
Apr 5, 2022
712984d
add slice_select_p and slice_assign_p
Apr 5, 2022
40e43a2
add multi input check for add_p, sub_p, mul_p, div_p
Apr 6, 2022
a36e487
update concat_p
Apr 6, 2022
7c44dfa
refine gather_p and scatter_add_p
Apr 6, 2022
14b400f
refine slice_assign_p and slice_select_p
Apr 7, 2022
8530425
Merge remote-tracking branch 'upstream/develop' into lml/add_prim_ops
Apr 7, 2022
04e40b8
add 9 test for prim ops
Apr 10, 2022
8afd272
add more test and fix some bug
Apr 10, 2022
86a05d0
add more test
Apr 11, 2022
a6ee225
register proto
Apr 11, 2022
ed6e816
Merge remote-tracking branch 'upstream/develop' into lml/add_prim_ops
Apr 11, 2022
c070a27
add shape valid check for broadcast_p op, and add keepdim attr into r…
Apr 12, 2022
d70a640
support multi input and multi output for split_p and concat_p
Apr 12, 2022
a22b82c
fix slice bug for slice_select_p and slice_assign_p
Apr 12, 2022
675957d
dtype for axis attr should be long int
Apr 12, 2022
5706f3d
update dtype for axis attr int64_t
Apr 12, 2022
e4e91ea
update for iscan CI
Apr 13, 2022
b1f06b4
add more shape and dtype check
Apr 13, 2022
48f3bf1
Merge remote-tracking branch 'upstream/develop' into lml/add_prim_ops
Apr 13, 2022
a8802e7
change IndexTensor into int32 dtype
Apr 13, 2022
1 change: 1 addition & 0 deletions paddle/fluid/operators/CMakeLists.txt
@@ -22,6 +22,7 @@ add_subdirectory(reduce_ops)
add_subdirectory(sequence_ops)
add_subdirectory(string)
add_subdirectory(jit)
add_subdirectory(prim_ops)
if(WITH_MKLDNN)
add_subdirectory(mkldnn)
endif()
28 changes: 28 additions & 0 deletions paddle/fluid/operators/prim_ops/CMakeLists.txt
@@ -0,0 +1,28 @@
include(operators)
if(WITH_UNITY_BUILD)
# Load Unity Build rules for operators in paddle/fluid/operators/prim_ops.
include(unity_build_rule.cmake)
endif()
register_operators()

SET(PRIM_OP_SRCS
reshape_p_op.cc
broadcast_p_op.cc
reduce_p_op.cc
transpose_p_op.cc
split_p_op.cc
concat_p_op.cc
slice_select_p_op.cc
slice_assign_p_op.cc
gather_p_op.cc
scatter_add_p_op.cc
add_p_op.cc
sub_p_op.cc
mul_p_op.cc
div_p_op.cc
sqrt_p_op.cc
tanh_p_op.cc
matmul_p_op.cc
fill_constant_p_op.cc)

cc_test(prim_op_test SRCS prim_op_test.cc ${PRIM_OP_SRCS} DEPS op_registry)
116 changes: 116 additions & 0 deletions paddle/fluid/operators/prim_ops/add_p_op.cc
@@ -0,0 +1,116 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace operators {
class AddPrimOp : public framework::OperatorBase {
public:
AddPrimOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator add_p should not be excuted directly"));
}
};

class AddPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of add_p op.");
AddInput("Y", "(Tensor), The input tensor of add_p op.");
AddOutput("Z", "(Tensor), The output tensor of add_p op.");
AddComment(R"DOC(
Autograd primitive add_p operator.
)DOC");
}
};

class AddPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetInputVarPtrs("Y")[0];
framework::InferShapeVarPtr z_var_ptr = ctx->GetOutputVarPtrs("Z")[0];

framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
framework::VarDesc *y_var = BOOST_GET(framework::VarDesc *, y_var_ptr);
auto x_shape = x_var->GetShape();
auto y_shape = y_var->GetShape();
size_t x_rank = x_shape.size();
size_t y_rank = y_shape.size();
PADDLE_ENFORCE_EQ(x_rank, y_rank,
platform::errors::InvalidArgument(
"The dimensions of two input tensor should be same, "
"but get %d and %d",
x_rank, y_rank));
for (size_t i = 0; i < x_rank; ++i) {
PADDLE_ENFORCE_EQ(
x_shape[i], y_shape[i],
platform::errors::InvalidArgument(
"The shape of two input tensor at dimension %d should be same, "
"but get %d and %d",
i, x_rank, y_rank));
}

BOOST_GET(framework::VarDesc *, z_var_ptr)->SetShape(x_shape);
}
};

class AddPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Input(ctx, "Y")[0];
auto z_name = Output(ctx, "Z")[0];
auto x_type = GetType(ctx, x_name);
auto y_type = GetType(ctx, y_name);
auto x_dtype = GetDataType(ctx, x_name);
auto y_dtype = GetDataType(ctx, y_name);
PADDLE_ENFORCE_EQ(x_type, y_type,
platform::errors::InvalidArgument(
"The type of two input tensor should be same, "
"but get %d and %d",
x_type, y_type));
PADDLE_ENFORCE_EQ(x_dtype, y_dtype,
platform::errors::InvalidArgument(
"The datatype of two input tensor should be same, "
"but get %d and %d",
x_dtype, y_dtype));

SetType(ctx, z_name, x_type);
SetDataType(ctx, z_name, x_dtype);
}
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(add_p, paddle::operators::AddPrimOp,
paddle::operators::AddPrimOpMaker,
paddle::operators::AddPrimOpShapeInference,
paddle::operators::AddPrimOpVarTypeInference);
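
Since RunImpl only throws, add_p is meant to be consumed at the descriptor level (shape, variable-type, and dtype inference). Below is a minimal sketch of driving that inference through ProgramDesc/OpDesc, in the spirit of the prim_op_test target declared above; the test body, includes, and variable names are illustrative assumptions, not code from this PR.

// Illustrative sketch only: exercises add_p's InferShape/InferVarType via
// descriptors. Names and structure are assumed, not taken from this PR.
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"

USE_OP_ITSELF(add_p);

namespace paddle {
namespace framework {

TEST(PrimOp, add_p_infer) {
  ProgramDesc program;
  auto *block = program.MutableBlock(0);

  // Declare two inputs with identical shapes; add_p enforces this.
  std::vector<int64_t> shape{3, 4, 5};
  for (auto name : {"x", "y"}) {
    auto *var = block->Var(name);
    var->SetType(proto::VarType::LOD_TENSOR);
    var->SetDataType(proto::VarType::FP32);
    var->SetShape(shape);
  }
  block->Var("z");

  // Build the op descriptor and run inference only; RunImpl would throw,
  // because prim ops are not meant to be executed directly.
  auto *op = block->AppendOp();
  op->SetType("add_p");
  op->SetInput("X", {"x"});
  op->SetInput("Y", {"y"});
  op->SetOutput("Z", {"z"});
  op->InferVarType(block);
  op->InferShape(*block);

  EXPECT_EQ(block->Var("z")->GetShape(), shape);
}

}  // namespace framework
}  // namespace paddle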
110 changes: 110 additions & 0 deletions paddle/fluid/operators/prim_ops/broadcast_p_op.cc
@@ -0,0 +1,110 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace operators {
class BroadcastPrimOp : public framework::OperatorBase {
public:
BroadcastPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator broadcast_p should not be excuted directly"));
}
};

class BroadcastPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(Tensor), The input tensor of broadcast_p op.");
AddOutput("Y", "(Tensor), The output tensor of broadcast_p op.");
AddAttr<std::vector<int64_t>>(
"shape",
"(std::vector<int64_t>) Target shape of broadcast_p operator.");
AddComment(R"DOC(
Autograd primitive broadcast_p operator.
)DOC");
}
};

static void CheckShapeValid(const std::vector<int64_t> &x_shape,
const std::vector<int64_t> &target_shape) {
size_t x_rank = x_shape.size();
size_t target_rank = target_shape.size();
PADDLE_ENFORCE_GE(target_rank, x_rank,
platform::errors::InvalidArgument(
"The rank of target shape should be greater than or "
"equal to input tensor's dimensions, "
"but received %d and %d",
target_rank, x_rank));
std::vector<int64_t>::const_iterator it = target_shape.begin();
for (size_t i = 0; i < x_rank; i++, it++) {
if (x_shape[i] != 1) {
it = std::find(it, target_shape.end(), x_shape[i]);
}
PADDLE_ENFORCE_EQ(
it != target_shape.end(), true,
platform::errors::InvalidArgument(
"Invalid shape, can not broadcast input tensor into target shape,"
"the first dismatching shape %d is shape of input tensor at "
"dimension %d",
x_shape[i], i));
}
}

class BroadcastPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
framework::InferShapeVarPtr x_var_ptr = ctx->GetInputVarPtrs("X")[0];
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
framework::VarDesc *x_var = BOOST_GET(framework::VarDesc *, x_var_ptr);
auto x_shape = x_var->GetShape();
auto target_shape = ctx->Attrs().Get<std::vector<int64_t>>("shape");
CheckShapeValid(x_shape, target_shape);
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(target_shape);
}
};

class BroadcastPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_name = Input(ctx, "X")[0];
auto y_name = Output(ctx, "Y")[0];
SetType(ctx, y_name, GetType(ctx, x_name));
SetDataType(ctx, y_name, GetDataType(ctx, x_name));
}
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(broadcast_p, paddle::operators::BroadcastPrimOp,
paddle::operators::BroadcastPrimOpMaker,
paddle::operators::BroadcastPrimOpShapeInference,
paddle::operators::BroadcastPrimOpVarTypeInference);
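
The compatibility rule enforced by CheckShapeValid can be read as: scanning left to right, every input dimension whose size is not 1 must be matched at or after the current position in the target shape. The following standalone snippet mirrors that rule as an illustrative sketch; it is not code reused from this PR.

// Illustrative mirror of the broadcast_p shape-compatibility check.
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

static bool BroadcastableTo(const std::vector<int64_t> &x_shape,
                            const std::vector<int64_t> &target_shape) {
  if (target_shape.size() < x_shape.size()) return false;
  auto it = target_shape.begin();
  for (size_t i = 0; i < x_shape.size(); ++i, ++it) {
    if (x_shape[i] != 1) {
      // A non-1 input size must appear at or after the current position.
      it = std::find(it, target_shape.end(), x_shape[i]);
    }
    if (it == target_shape.end()) return false;
  }
  return true;
}

int main() {
  assert(BroadcastableTo({3, 1}, {2, 3, 4}));   // size-1 dims broadcast freely
  assert(!BroadcastableTo({3, 4}, {2, 3, 5}));  // 4 has no match after the 3
  return 0;
}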
134 changes: 134 additions & 0 deletions paddle/fluid/operators/prim_ops/concat_p_op.cc
@@ -0,0 +1,134 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
class InferShapeContext;
class VarDesc;
} // namespace framework
} // namespace paddle

namespace paddle {
namespace operators {
class ConcatPrimOp : public framework::OperatorBase {
public:
ConcatPrimOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: framework::OperatorBase(type, inputs, outputs, attrs) {}
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
PADDLE_THROW(platform::errors::Unimplemented(
"Prim operator concat_p should not be excuted directly"));
}
};

class ConcatPrimOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("XS", "(Tensor), The input tensors of concat_p op.")
.AsDuplicable();
AddOutput("Y", "(Tensor), The output tensor of concat_p op.");
AddAttr<int64_t>("axis", "(int64_t), The axis along which to concat.");
AddComment(R"DOC(
Autograd primitive concat_p operator.
)DOC");
}
};

class ConcatPrimOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
auto x_var_ptrs = ctx->GetInputVarPtrs("XS");
framework::InferShapeVarPtr y_var_ptr = ctx->GetOutputVarPtrs("Y")[0];
auto axis = ctx->Attrs().Get<int64_t>("axis");
int64_t cnt_along_axis = 0;
framework::VarDesc *first_x_var =
BOOST_GET(framework::VarDesc *, x_var_ptrs[0]);
auto first_x_shape = first_x_var->GetShape();
cnt_along_axis += first_x_shape[axis];
size_t first_x_rank = first_x_shape.size();
for (size_t i = 1; i < x_var_ptrs.size(); ++i) {
framework::VarDesc *x_var =
BOOST_GET(framework::VarDesc *, x_var_ptrs[i]);
auto x_shape = x_var->GetShape();
cnt_along_axis += x_shape[axis];
size_t x_rank = x_shape.size();
PADDLE_ENFORCE_EQ(
x_rank, first_x_rank,
platform::errors::InvalidArgument("The dimensions of %d input tensor "
"should be same as the dimensions "
"of 1st input tensor's, "
"but get %d and %d",
i + 1, x_rank, first_x_rank));
for (size_t j = 0; j < x_rank; ++j) {
if (j != size_t(axis)) {
PADDLE_ENFORCE_EQ(x_shape[j], first_x_shape[j],
platform::errors::InvalidArgument(
"The shape of %d input tensor at dimension %d "
"should be same as the 1st input tensor's, "
"but get %d and %d",
i + 1, j, x_shape[j], first_x_shape[j]));
}
}
}

std::vector<int64_t> y_shape(first_x_shape);
y_shape[axis] = cnt_along_axis;
BOOST_GET(framework::VarDesc *, y_var_ptr)->SetShape(y_shape);
}
};

class ConcatPrimOpVarTypeInference
: public framework::StaticGraphVarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
auto x_names = Input(ctx, "XS");
auto y_name = Output(ctx, "Y")[0];
auto first_x_name = x_names[0];
auto first_x_type = GetType(ctx, first_x_name);
auto first_x_dtype = GetDataType(ctx, first_x_name);
for (size_t i = 1; i < x_names.size(); ++i) {
auto x_name = x_names[i];
auto x_type = GetType(ctx, x_name);
auto x_dtype = GetDataType(ctx, x_name);
PADDLE_ENFORCE_EQ(x_type, first_x_type,
platform::errors::InvalidArgument(
"The type of %d input tensor should be same as the "
"first input tensor's, "
"but get %d and %d",
i + 1, x_type, first_x_type));
PADDLE_ENFORCE_EQ(x_dtype, first_x_dtype,
platform::errors::InvalidArgument(
"The datatype of %d input tensor should be same as "
"the first input tensor's, "
"but get %d and %d",
i + 1, x_dtype, first_x_dtype));
}
SetType(ctx, y_name, GetType(ctx, first_x_name));
SetDataType(ctx, y_name, GetDataType(ctx, first_x_name));
}
};

} // namespace operators
} // namespace paddle

REGISTER_OPERATOR(concat_p, paddle::operators::ConcatPrimOp,
paddle::operators::ConcatPrimOpMaker,
paddle::operators::ConcatPrimOpShapeInference,
paddle::operators::ConcatPrimOpVarTypeInference);
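
concat_p is the first op in this set with a duplicable input (XS) and an int64_t attribute (axis). The sketch below, again an illustrative assumption rather than code from this PR, shows how such an OpDesc would be populated and what shape inference should produce: sizes along the concat axis add up while the other dimensions stay fixed.

// Illustrative sketch only: concat_p with two inputs along axis 0.
#include <vector>

#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/program_desc.h"

USE_OP_ITSELF(concat_p);

namespace paddle {
namespace framework {

TEST(PrimOp, concat_p_infer) {
  ProgramDesc program;
  auto *block = program.MutableBlock(0);

  // Two inputs that agree on every dimension except the concat axis.
  auto *x0 = block->Var("x0");
  x0->SetType(proto::VarType::LOD_TENSOR);
  x0->SetDataType(proto::VarType::FP32);
  x0->SetShape({2, 3});
  auto *x1 = block->Var("x1");
  x1->SetType(proto::VarType::LOD_TENSOR);
  x1->SetDataType(proto::VarType::FP32);
  x1->SetShape({4, 3});
  block->Var("y");

  auto *op = block->AppendOp();
  op->SetType("concat_p");
  op->SetInput("XS", {"x0", "x1"});  // duplicable input: a list of variables
  op->SetOutput("Y", {"y"});
  op->SetAttr("axis", static_cast<int64_t>(0));
  op->InferVarType(block);
  op->InferShape(*block);

  // 2 + 4 along axis 0, the trailing dimension unchanged.
  EXPECT_EQ(block->Var("y")->GetShape(), std::vector<int64_t>({6, 3}));
}

}  // namespace framework
}  // namespace paddle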