Add yaml for flatten_contiguous_range OP #41345

Merged
2 changes: 1 addition & 1 deletion paddle/phi/kernels/flatten_grad_kernel.cc
@@ -21,8 +21,8 @@ namespace phi {
 
 template <typename T, typename Context>
 void FlattenGradKernel(const Context& dev_ctx,
-                       const DenseTensor& out_grad,
                        const DenseTensor& xshape,
+                       const DenseTensor& out_grad,
                        DenseTensor* x_grad) {
   auto xshape_dims = xshape.dims();
   dev_ctx.Alloc(x_grad, out_grad.dtype());
2 changes: 1 addition & 1 deletion paddle/phi/kernels/flatten_grad_kernel.h
@@ -20,8 +20,8 @@ namespace phi {
 
 template <typename T, typename Context>
 void FlattenGradKernel(const Context& dev_ctx,
-                       const DenseTensor& out_grad,
                        const DenseTensor& xshape,
+                       const DenseTensor& out_grad,
                        DenseTensor* x_grad);
 
 }  // namespace phi
2 changes: 1 addition & 1 deletion paddle/phi/ops/compat/flatten_sig.cc
@@ -31,7 +31,7 @@ KernelSignature FlattenOpArgumentMapping(const ArgumentMappingContext& ctx) {
 KernelSignature FlattenGradOpArgumentMapping(
     const ArgumentMappingContext& ctx) {
   return KernelSignature(
-      "flatten_grad", {GradVarName("Out"), "XShape"}, {}, {GradVarName("X")});
+      "flatten_grad", {"XShape", GradVarName("Out")}, {}, {GradVarName("X")});
 }
 
 }  // namespace phi
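Note that this reordering is consistent across all three layers touched so far: the grad-kernel definition (.cc), its declaration (.h), and the static-graph argument mapping now all pass XShape before Out@GRAD, matching the `args : (Tensor xshape, Tensor out_grad)` order declared for flatten_grad in backward.yaml further down. The yaml code generator derives the grad-kernel call from that args list, so all three orderings must agree or the kernel would receive its inputs swapped.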
1 change: 0 additions & 1 deletion paddle/phi/tests/api/CMakeLists.txt
@@ -12,7 +12,6 @@ cc_test(test_dot_api SRCS test_dot_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_matmul_api SRCS test_matmul_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_empty_api SRCS test_empty_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_fill_api SRCS test_fill_api.cc DEPS ${COMMON_API_TEST_DEPS})
-cc_test(test_flatten_api SRCS test_flatten_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_elementwise_api SRCS test_elementwise_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_cast_api SRCS test_cast_api.cc DEPS ${COMMON_API_TEST_DEPS})
 cc_test(test_reshape_api SRCS test_reshape_api.cc DEPS ${COMMON_API_TEST_DEPS})
75 changes: 0 additions & 75 deletions paddle/phi/tests/api/test_flatten_api.cc

This file was deleted.

6 changes: 4 additions & 2 deletions python/paddle/fluid/tests/unittests/test_flatten_contiguous_range_op.py

@@ -23,6 +23,8 @@
 
 class TestFlattenOp(OpTest):
     def setUp(self):
+        self.python_api = paddle.flatten
+        self.python_out_sig = ["Out"]
         self.op_type = "flatten_contiguous_range"
         self.start_axis = 0
         self.stop_axis = -1
@@ -35,10 +37,10 @@ def setUp(self):
         }
 
     def test_check_output(self):
-        self.check_output(no_check_set=["XShape"])
+        self.check_output(no_check_set=["XShape"], check_eager=True)
 
     def test_check_grad(self):
-        self.check_grad(["X"], "Out")
+        self.check_grad(["X"], "Out", check_eager=True)
 
     def init_test_case(self):
         self.in_shape = (3, 2, 5, 4)
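With check_eager=True, OpTest additionally runs each case through the eager final-state API named in self.python_api; python_out_sig = ["Out"] limits the comparison to Out, since paddle.flatten returns only the flattened tensor while the underlying op also emits the intermediate XShape. A minimal standalone sketch of the property being exercised (the shape mirrors init_test_case; nothing here is part of the test file itself):

import numpy as np
import paddle

# Eager-mode paddle.flatten should match a plain numpy reshape for the
# default axis range used by the test (start_axis=0, stop_axis=-1).
x_np = np.random.random((3, 2, 5, 4)).astype("float64")
out = paddle.flatten(paddle.to_tensor(x_np), start_axis=0, stop_axis=-1)
np.testing.assert_allclose(out.numpy(), x_np.reshape(-1))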
6 changes: 5 additions & 1 deletion python/paddle/tensor/manipulation.py
@@ -697,7 +697,11 @@ def flatten(x, start_axis=0, stop_axis=-1, name=None):
     if start_axis > stop_axis:
         raise ValueError("The stop_axis should be larger than stat_axis")
 
-    if paddle.in_dynamic_mode():
+    if in_dygraph_mode():
+        dy_out, _ = _C_ops.final_state_flatten(x, start_axis, stop_axis)
+        return dy_out
+
+    if _in_legacy_dygraph():
         dy_out, _ = _C_ops.flatten_contiguous_range(x, 'start_axis', start_axis,
                                                     'stop_axis', stop_axis)
         return dy_out
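At the user level the new branch is behavior-preserving: in eager mode the call now routes through the generated final_state_flatten, which returns (out, xshape), and the wrapper discards the intermediate xshape exactly as the legacy branch did. A quick usage sketch:

import paddle

x = paddle.rand([3, 2, 5, 4])
out = paddle.flatten(x, start_axis=1, stop_axis=2)  # collapse axes 1..2
print(out.shape)  # [3, 10, 4]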
10 changes: 7 additions & 3 deletions python/paddle/utils/code_gen/api.yaml
@@ -547,11 +547,15 @@
 
 - api : flatten
   args : (Tensor x, int start_axis, int stop_axis)
-  output : Tensor
+  output : Tensor(out), Tensor(xshape)
   infer_meta :
-    func : FlattenInferMeta
+    func : FlattenWithXShapeInferMeta
   kernel :
-    func : flatten
+    func : flatten_with_xshape
+    backend : x
+  inplace : (x -> out)
+  view : (x -> out)
+  backward : flatten_grad
 
 # flip
 - api : flip
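The new annotations register xshape as a byproduct for the backward pass and describe flatten's aliasing: `view : (x -> out)` records that out can share x's storage (flatten is a pure reshape), and `inplace : (x -> out)` lets the generator emit an in-place variant. A hedged illustration of the view semantics this is meant to enable in eager mode — assumed behavior, not something this diff itself tests:

import paddle

x = paddle.ones([2, 3])
out = paddle.flatten(x)  # with the view annotation, out may alias x's storage
out[0] = 42.0            # a write through the view...
print(float(x[0, 0]))    # ...is then expected to read back as 42.0 from x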
13 changes: 13 additions & 0 deletions python/paddle/utils/code_gen/backward.yaml
@@ -349,6 +349,19 @@
     kernel :
       func : expm1_grad
 
+- backward_api : flatten_grad
+  forward : flatten(Tensor x, int start_axis, int stop_axis) -> Tensor(out), Tensor(xshape)
+  args : (Tensor xshape, Tensor out_grad)
+  output : Tensor(x_grad)
+  infer_meta :
+    func : KernelWithXShapeInferMeta
+    param : [xshape]
+  kernel :
+    func : flatten_grad
+    data_type: out_grad
+    backend: out_grad
+    layout: out_grad
+
 - backward_api : floor_grad
   forward : floor(Tensor x) -> Tensor(out)
   args : (Tensor out_grad)
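Since flatten is shape-only, flatten_grad just reshapes out_grad back to x's dims; KernelWithXShapeInferMeta recovers those dims from xshape, and the kernel's data_type/backend/layout follow out_grad because xshape carries no real data. A minimal eager-mode check of the end-to-end gradient behavior:

import paddle

x = paddle.rand([3, 2, 5, 4])
x.stop_gradient = False
y = paddle.flatten(x, start_axis=0, stop_axis=-1)
y.sum().backward()
print(x.grad.shape)  # [3, 2, 5, 4] -- gradient restored to x's original shape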
2 changes: 1 addition & 1 deletion tools/infrt/skipped_phi_api.json
@@ -1,4 +1,4 @@
 {
-"phi_apis":["conj", "nll_loss"],
+"phi_apis":["conj", "nll_loss", "flatten"],
 "phi_kernels":["equal_all"]
 }
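Adding flatten to infrt's phi_apis skip list keeps the infrt code generator from consuming the updated entry, presumably because it does not yet handle APIs that carry an intermediate output such as xshape.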