Skip to content

Commit

Permalink
[Phi] move flip op to phi kernel (#39822)
Browse files Browse the repository at this point in the history
  • Loading branch information
m3ngyang authored Feb 23, 2022
1 parent 64ed92b commit ad294a8
Show file tree
Hide file tree
Showing 6 changed files with 250 additions and 222 deletions.
13 changes: 3 additions & 10 deletions paddle/fluid/operators/flip_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include "paddle/fluid/operators/flip_op.h"
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/complex.h"

namespace paddle {
namespace operators {
Expand All @@ -29,6 +29,7 @@ class FlipOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;

// TODO(m3ngyang): move InferShape to the phi kernel as well.
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE_EQ(
ctx->HasInput("X"), true,
Expand Down Expand Up @@ -150,14 +151,6 @@ namespace plat = paddle::platform;
// Register the flip operator with its maker, var-type inference, and the
// gradient makers for both static-graph (OpDesc) and imperative (OpBase) modes.
// Note: the per-device compute kernels are registered separately in phi
// (paddle/phi/kernels/*/flip_kernel.*), not here.
REGISTER_OPERATOR(flip, ops::FlipOp, ops::FlipOpMaker, ops::FlipOpInferVarType,
                  ops::FlipOpGradMaker<paddle::framework::OpDesc>,
                  ops::FlipOpGradMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(
flip, ops::FlipKernel<paddle::platform::CPUDeviceContext, float>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, double>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int32_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, int64_t>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, bool>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<float>>,
ops::FlipKernel<paddle::platform::CPUDeviceContext, plat::complex<double>>);

/* ========================== register checkpoint ===========================*/
REGISTER_OP_VERSION(flip)
Expand Down
129 changes: 0 additions & 129 deletions paddle/fluid/operators/flip_op.cu

This file was deleted.

83 changes: 0 additions & 83 deletions paddle/fluid/operators/flip_op.h

This file was deleted.

77 changes: 77 additions & 0 deletions paddle/phi/kernels/cpu/flip_kernel.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/flip_kernel.h"

#include <bitset>

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

constexpr size_t dim_bitset_size = 64;

// Reverses `x` along every dimension listed in `axis` (negative values wrap
// around, Python-style) and writes the result to `out`.
//
// Works directly on linear offsets: for each output element i, the index is
// decomposed into per-dimension coordinates via the row-major strides, the
// coordinate of every flipped dimension is mirrored, and the recomposed
// offset selects the source element.
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const std::vector<int>& axis,
                DenseTensor* out) {
  const auto in_dims = x.dims();
  const int rank = in_dims.size();

  // Mark which dimensions must be reversed; negative axes count from the end.
  std::bitset<dim_bitset_size> flip_dim;
  for (size_t i = 0; i < axis.size(); ++i) {
    int d = axis[i];
    if (axis[i] < 0) {
      d += rank;
    }
    flip_dim[d] = true;
  }

  const auto strides = phi::stride(in_dims);
  const auto count = x.numel();
  const T* src = x.data<T>();
  T* dst = dev_ctx.template Alloc<T>(out);
#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for
#endif
  for (int64_t i = 0; i < count; ++i) {
    int64_t remaining = i;  // linear index not yet decomposed
    int64_t src_offset = 0;
    for (int d = 0; d < rank; ++d) {
      const int64_t coord = remaining / strides[d];
      remaining -= coord * strides[d];
      // Mirror the coordinate on flipped dimensions, keep it otherwise.
      const int64_t picked = flip_dim[d] ? (in_dims[d] - 1 - coord) : coord;
      src_offset += picked * strides[d];
    }
    dst[i] = src[src_offset];
  }
}

} // namespace phi

// Register the CPU flip kernel for every dtype the op supports; ALL_LAYOUT
// means the kernel is layout-agnostic (it only permutes linear offsets).
PD_REGISTER_KERNEL(flip,
                   CPU,
                   ALL_LAYOUT,
                   phi::FlipKernel,
                   float,
                   double,
                   int32_t,
                   int64_t,
                   bool,
                   phi::dtype::complex<float>,
                   phi::dtype::complex<double>) {}
29 changes: 29 additions & 0 deletions paddle/phi/kernels/flip_kernel.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "paddle/phi/core/dense_tensor.h"

namespace phi {

// Reverses the order of elements of `x` along each dimension given in `axis`
// (negative axis values count from the last dimension) and stores the result
// in `out`. `out` has the same shape and dtype as `x`.
template <typename T, typename Context>
void FlipKernel(const Context& dev_ctx,
                const DenseTensor& x,
                const std::vector<int>& axis,
                DenseTensor* out);

}  // namespace phi
Loading

0 comments on commit ad294a8

Please sign in to comment.