Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize nearest_interp forward #38528

Merged
merged 38 commits into from
Jan 25, 2022
Merged
Show file tree
Hide file tree
Changes from 27 commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
8f532b0
Merge pull request #1 from PaddlePaddle/develop
AshburnLee Sep 8, 2020
5b5804d
Merge pull request #2 from PaddlePaddle/develop
AshburnLee Sep 17, 2020
cee2470
Merge pull request #3 from PaddlePaddle/develop
AshburnLee Sep 30, 2020
5be3a45
Merge pull request #4 from PaddlePaddle/develop
AshburnLee Oct 13, 2020
a1d92b7
Merge pull request #5 from PaddlePaddle/develop
AshburnLee Oct 20, 2020
e674a5d
Merge pull request #6 from PaddlePaddle/develop
AshburnLee Nov 15, 2020
855d00b
Merge pull request #7 from PaddlePaddle/develop
AshburnLee Nov 18, 2020
7cb2c97
Merge pull request #8 from PaddlePaddle/develop
AshburnLee Mar 31, 2021
db9fc91
Merge pull request #9 from PaddlePaddle/develop
AshburnLee Apr 7, 2021
c7b68c8
Merge branch 'develop' of https://github.com/PaddlePaddle/paddle into…
AshburnLee Apr 26, 2021
0fd630e
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Aug 16, 2021
4bbb33b
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Sep 28, 2021
30a1a89
Merge branch 'PaddlePaddle:develop' into develop
AshburnLee Nov 22, 2021
ce3deec
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 21, 2021
8c3620b
init commit
AshburnLee Dec 28, 2021
5719490
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Dec 28, 2021
cb7cc51
remove comments
AshburnLee Dec 28, 2021
26a2aa8
Merge branches 'develop' and 'develop' of https://github.com/PaddlePa…
AshburnLee Dec 28, 2021
76bed9b
remove nchw branch
AshburnLee Dec 28, 2021
b7fd119
optimize code
AshburnLee Dec 28, 2021
3899e0b
apply fast div mod in 1D kernel, rm 3D kernel
AshburnLee Jan 5, 2022
88ff573
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 5, 2022
5844cc5
move init of FastDivMode to CPU
AshburnLee Jan 7, 2022
42c4038
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 7, 2022
cee38bf
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 12, 2022
35086d6
3D kernel for nchw, FastDiv for 1D kernel
AshburnLee Jan 12, 2022
5a79f12
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 12, 2022
5e08a97
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 12, 2022
0fd2b3f
debug done. process boundary
AshburnLee Jan 18, 2022
214b4a0
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 18, 2022
86dd9e1
2^n
AshburnLee Jan 18, 2022
4a07e34
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 18, 2022
b2b85dd
optimize
AshburnLee Jan 19, 2022
45716d7
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 19, 2022
a39400c
optimize
AshburnLee Jan 20, 2022
a0d9431
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 20, 2022
efa4297
change code & optimize code
AshburnLee Jan 21, 2022
14f6927
Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into…
AshburnLee Jan 21, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
116 changes: 89 additions & 27 deletions paddle/fluid/operators/interpolate_v2_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,39 +16,92 @@
#include "paddle/fluid/platform/device/gpu/gpu_device_function.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/fluid/platform/fast_divmod.h"

namespace paddle {
namespace operators {

using framework::Tensor;
using platform::FastDivMod;
using DataLayout = framework::DataLayout;

struct FastDivModForInterpolate {
public:
FastDivMod channels_div;
FastDivMod output_w_div;
// FastDivMod outimg_w_;
// FastDivMod out_size_;
FastDivMod outimgw_mul_channel_div;

explicit HOSTDEVICE FastDivModForInterpolate(const int channels,
const int output_w,
/*const int outimg_w,
const int out_size,*/
const int outimgw_mul_chann)
: channels_div(FastDivMod(channels)),
output_w_div(FastDivMod(output_w)),
// outimg_w_(FastDivMod(outimg_w)),
// out_size_(FastDivMod(out_size)),
outimgw_mul_channel_div(FastDivMod(outimgw_mul_chann)) {}
};

template <typename T>
__global__ void KeNearestNeighborInterpNCHWFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t num_batchs, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t num_channels, const float ratio_h,
const float ratio_w, const bool align_corners) {
int out_img_idx = threadIdx.x + blockIdx.x * blockDim.x;
int out_img_idy = threadIdx.y + blockIdx.y * blockDim.y;
int nc_id = threadIdx.z + blockIdx.z * blockDim.z;
int nc_stride = blockDim.z * gridDim.z;
int nc = num_batchs * num_channels;

// nearest_sampling by multiple read in_addr and write to out_addr
int in_img_idx = (align_corners)
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);
int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
: static_cast<int>(ratio_h * out_img_idy);

int in_index = (nc_id * in_img_h + in_img_idy) * in_img_w + in_img_idx;
int in_index_stride = nc_stride * in_img_h * in_img_w;

int out_index = (nc_id * out_img_h + out_img_idy) * out_img_w + out_img_idx;
int out_index_stride = nc_stride * out_img_h * out_img_w;

while (nc_id < nc) {
out[out_index] = in[in_index];
in_index += in_index_stride;
out_index += out_index_stride;
nc_id += nc_stride;
}
}

template <typename T>
__global__ void KeNearestNeighborInterpFw(
const T* in, const size_t in_img_h, const size_t in_img_w,
const size_t input_h, const size_t input_w, T* out, const size_t out_img_h,
const size_t out_img_w, const size_t output_h, const size_t output_w,
const size_t num_channels, const float ratio_h, const float ratio_w,
const bool align_corners, const DataLayout data_layout) {
const bool align_corners, FastDivModForInterpolate divmods) {
int nthreads = output_h * output_w;
int tid = blockIdx.x * blockDim.x + threadIdx.x;
int stride = blockDim.x * gridDim.x;
int in_img_size = in_img_h * in_img_w;
int out_img_size = out_img_h * out_img_w;

for (; tid < nthreads; tid += stride) {
int out_id_h = tid / output_w;
int out_id_w = tid % output_w;
int in_img_size = input_w / num_channels;
int out_img_size = output_w / num_channels;
auto out_id_divmod = divmods.output_w_div.Divmod(tid);
int out_id_h = out_id_divmod.val[0];
int out_id_w = out_id_divmod.val[1];

int channel_id, out_img_idy, out_img_idx;
if (data_layout == DataLayout::kNCHW) {
channel_id = out_id_w / out_img_size;
out_img_idy = (out_id_w % out_img_size) / out_img_w;
out_img_idx = tid % out_img_w;
} else {
out_img_idy = out_id_w / (out_img_w * num_channels);
out_img_idx = out_id_w % (out_img_w * num_channels) / num_channels;
channel_id = tid % num_channels;
}
int channel_id = divmods.channels_div.Divmod(tid).val[1];
auto outimg_id_divmod = divmods.outimgw_mul_channel_div.Divmod(out_id_w);
int out_img_idy = outimg_id_divmod.val[0];
int out_img_idx =
divmods.channels_div.Divmod(outimg_id_divmod.val[1]).val[0];

int in_img_idy = (align_corners)
? static_cast<int>(ratio_h * out_img_idy + 0.5)
Expand All @@ -57,13 +110,8 @@ __global__ void KeNearestNeighborInterpFw(
? static_cast<int>(ratio_w * out_img_idx + 0.5)
: static_cast<int>(ratio_w * out_img_idx);

if (data_layout == DataLayout::kNCHW) {
out[tid] = in[out_id_h * input_w + channel_id * in_img_size +
in_img_idy * in_img_w + in_img_idx];
} else {
out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
out[tid] = in[out_id_h * input_w + in_img_idy * in_img_w * num_channels +
in_img_idx * num_channels + channel_id];
}
}

Expand Down Expand Up @@ -1180,11 +1228,25 @@ static void Interpolate2DCUDAFwd(const framework::ExecutionContext& ctx,
platform::GetGpuLaunchConfig1D(ctx.cuda_device_context(), pixelNum);

if ("nearest" == interp_method) {
KeNearestNeighborInterpFw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, data_layout);
if (data_layout == DataLayout::kNCHW) {
// get launch 3D config
platform::GpuLaunchConfig config_3d = platform::GetCpuLaunchConfig3D(
ctx.cuda_device_context(), n * c, out_h, out_w);
KeNearestNeighborInterpNCHWFw<
T><<<config_3d.block_per_grid, config_3d.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, output_data, out_h, out_w, c, ratio_h,
ratio_w, align_corners);
} else {
int64_t cw = c * out_w;
auto interp_divmods =
FastDivModForInterpolate(c, out_chw, /*out_w, out_hw, */ cw);
KeNearestNeighborInterpFw<
T><<<config.block_per_grid, config.thread_per_block, 0,
ctx.cuda_device_context().stream()>>>(
input_data, in_h, in_w, n, in_chw, output_data, out_h, out_w, n,
out_chw, c, ratio_h, ratio_w, align_corners, interp_divmods);
}
} else if ("bilinear" == interp_method) {
dim3 thread_num = config.thread_per_block;
#ifdef WITH_NV_JETSON
Expand Down
24 changes: 23 additions & 1 deletion paddle/fluid/platform/device/gpu/gpu_launch_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,29 @@ inline GpuLaunchConfig GetGpuLaunchConfig2D(
return config;
}

// TODO(wangchaochaohu): 3D will add later
inline GpuLaunchConfig GetCpuLaunchConfig3D(
const platform::CUDADeviceContext& context, int num_img, int height,
int width) {
const int kThreadsPerBlock = 256;
int max_threads_per_block = context.GetMaxThreadsPerBlock(); // 1024
int max_threads = std::min(kThreadsPerBlock, max_threads_per_block);

int block_x = std::min(width, max_threads);
int block_y = std::min(height, max_threads / block_x);
int block_z = std::min(num_img, max_threads / block_x / block_y);

dim3 max_grid_dim = context.GetCUDAMaxGridDimSize();
int grid_x = std::min<int>(max_grid_dim.x, DivUp(width, block_x));
int grid_y = std::min<int>(max_grid_dim.y, DivUp(height, block_y));
int grid_z = std::min<int>(max_grid_dim.z, DivUp(num_img, block_z));

const int capability = context.GetComputeCapability();
GpuLaunchConfig config;
config.compute_capability = capability;
config.thread_per_block = dim3(block_x, block_y, block_z);
config.block_per_grid = dim3(grid_x, grid_y, grid_z);
return config;
}

} // namespace platform
} // namespace paddle
Expand Down