[PHI] Optimize depthwise conv #71537

Open · wants to merge 1 commit into base: develop
251 changes: 179 additions & 72 deletions paddle/phi/kernels/gpu/depthwise_conv.h
@@ -88,6 +88,7 @@ class DepthwiseConvFilterGradFunctor {
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32
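// Output-channel count below which the small-channel NHWC filter-gradient
// kernel is selected.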
#define SMALL_THRESHOLD 64

template <typename T>
__forceinline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
@@ -181,39 +182,41 @@ __device__ __inline__ void KernelDepthwiseConvNCHW(

// A CUDA kernel to compute the depthwise convolution forward pass
// in NHWC format.
template <typename T, bool fuse_relu_before_conv>
template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvNHWC(
ARG_DEFINE_KernelDepthwiseConv) {
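// When c_filter != -1, the filter extents below are compile-time constants,
// which lets the filter loops be fully unrolled.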
const int fw_size = c_filter != -1 ? c_filter : filter_width;
const int fh_size = c_filter != -1 ? c_filter : filter_height;
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= (output_channels * batch_size * output_height * output_width))
if (idx >= (output_channels * batch_size * output_height * output_width)) {
return;
}

const int c_out = idx % output_channels;
const int w_out = (idx / output_channels) % output_width;
const int h_out = (idx / output_channels / output_width) % output_height;
const int batch = idx / output_width / output_height / output_channels;
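// Decompose the flat index into (batch, h_out, w_out, c_out) with one
// integer divide per dimension; each remainder is recovered with a
// multiply-and-subtract instead of a separate modulo.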
int tmp_1 = idx / output_channels;
const int c_out = idx - tmp_1 * output_channels;
int tmp_2 = tmp_1 / output_width;
const int w_out = tmp_1 - tmp_2 * output_width;
tmp_1 = tmp_2;
tmp_2 = tmp_1 / output_height;
const int h_out = tmp_1 - tmp_2 * output_height;
const int batch = tmp_2;

const int c_in = c_out / filter_multiplier;
T value(0);
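// in_offset addresses element (batch, 0, 0, c_in) of the NHWC input.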
const int in_offset =
batch * input_height * input_width * input_channels + c_in;
const int h_in_start = -padding_height + h_out * stride_height;
const int w_in_start = -padding_width + w_out * stride_width;
const int h_in_end = h_in_start + filter_height * dilate_height;
const int w_in_end = w_in_start + filter_width * dilate_width;

const int h_end = h_in_end < input_height ? h_in_end : input_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int w_start = w_in_start > 0 ? w_in_start : 0;
int weight_offset = 0;

#pragma unroll
for (int h_in = h_in_start; h_in < h_in_end; h_in += dilate_height) {
for (int fh = 0, h_in = h_in_start; fh < fh_size;
++fh, h_in += dilate_height) {
#pragma unroll
for (int w_in = w_in_start; w_in < w_in_end; w_in += dilate_width) {
if (h_in >= h_start && h_in < h_end && w_in >= w_start && w_in < w_end) {
int offset = ((batch * input_height + h_in) * input_width + w_in) *
input_channels +
c_in;
for (int fw = 0, w_in = w_in_start; fw < fw_size;
++fw, w_in += dilate_width) {
if (h_in >= 0 && h_in < input_height && w_in >= 0 && w_in < input_width) {
int offset = in_offset + (h_in * input_width + w_in) * input_channels;
T in_data = input_data[offset];
const T* weight = filter_data + weight_offset * output_channels + c_out;
if (fuse_relu_before_conv) {
@@ -226,10 +229,7 @@ __device__ __inline__ void KernelDepthwiseConvNHWC(
weight_offset++;
}
}
int index = batch * output_channels * output_height * output_width +
h_out * output_width * output_channels + w_out * output_channels +
c_out;
output_data[index] = value;
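// idx already equals the linear NHWC output offset, so write to it directly.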
output_data[idx] = value;
}

template <typename T, int c_filter, bool fuse_relu_before_conv>
@@ -251,17 +251,10 @@ __device__ __inline__ void KernelDepthwiseConvCFilterNCHW(
T value(0);
const int h_in_start = -padding_height + h_out * stride_height;
const int w_in_start = -padding_width + w_out * stride_width;
const int h_in_end = h_in_start + c_filter * dilate_height;
const int w_in_end = w_in_start + c_filter * dilate_width;

int in_offset =
((batch * input_channels + c_in) * input_height) * input_width;

const int h_end = h_in_end < input_height ? h_in_end : input_height;
const int w_end = w_in_end < input_width ? w_in_end : input_width;
const int h_start = h_in_start > 0 ? h_in_start : 0;
const int w_start = w_in_start > 0 ? w_in_start : 0;

for (int h_in = h_in_start, h_f = 0; h_f < c_filter;
h_in += dilate_height, h_f++) {
for (int w_in = w_in_start, w_f = 0; w_f < c_filter;
@@ -380,25 +373,26 @@ __global__ void KernelDepthwiseConvSp(ARG_DEFINE_KernelDepthwiseConv) {
dilate_width,
output_data);
} else {
KernelDepthwiseConvNHWC<T, fuse_relu_before_conv>(input_data,
filter_data,
batch_size,
output_channels,
output_height,
output_width,
input_channels,
input_height,
input_width,
final_filter_multiplier,
filter_height,
filter_width,
h_stride,
w_stride,
padding_height,
padding_width,
dilate_height,
dilate_width,
output_data);
KernelDepthwiseConvNHWC<T, c_filter, fuse_relu_before_conv>(
input_data,
filter_data,
batch_size,
output_channels,
output_height,
output_width,
input_channels,
input_height,
input_width,
final_filter_multiplier,
filter_height,
filter_width,
h_stride,
w_stride,
padding_height,
padding_width,
dilate_height,
dilate_width,
output_data);
}
} else {
if (data_layout != DataLayout::kNHWC) {
@@ -1020,7 +1014,43 @@ __device__ __forceinline__ void NoReturnAtomicAdd(T* tensor,

template <typename T,
typename index_t,
typename std::enable_if_t<!std::is_same_v<phi::dtype::float16, T>>* =
typename std::enable_if_t<std::is_same_v<phi::dtype::bfloat16, T>>* =
nullptr>
__device__ __forceinline__ void NoReturnAtomicAdd(T* tensor,
index_t index,
const index_t numel,
T value) {
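// Pre-Ampere GPUs (and HIP builds) fall back to phi::CudaAtomicAdd; newer
// architectures use a packed 32-bit atomic below.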
#if (defined(PADDLE_WITH_HIP) || \
(defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)))
phi::CudaAtomicAdd(tensor + index, value);
#else
// Check whether the address is 32-bit aligned, i.e. the low half of a
// __nv_bfloat162 pair.
__nv_bfloat16* target_addr = reinterpret_cast<__nv_bfloat16*>(tensor + index);
bool low_byte =
(reinterpret_cast<std::uintptr_t>(target_addr) % sizeof(__nv_bfloat162) ==
0);
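// Pack the bf16 value into one half of an aligned __nv_bfloat162 (the other
// half adds zero) so a native 32-bit atomicAdd can be used; only the tensor
// boundaries fall back to a 16-bit atomic.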

if (low_byte && index < (numel - 1)) {
__nv_bfloat162 value2;
value2.x = value.to_nv_bfloat16();
value2.y = __int2bfloat16_rz(0);
atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr), value2);
} else if (!low_byte && index > 0) {
__nv_bfloat162 value2;
value2.x = __int2bfloat16_rz(0);
value2.y = value.to_nv_bfloat16();
atomicAdd(reinterpret_cast<__nv_bfloat162*>(target_addr - 1), value2);
} else {
atomicAdd(reinterpret_cast<__nv_bfloat16*>(tensor) + index,
value.to_nv_bfloat16());
}
#endif
}

template <typename T,
typename index_t,
typename std::enable_if_t<!std::is_same_v<phi::dtype::float16, T> &&
!std::is_same_v<phi::dtype::bfloat16, T>>* =
nullptr>
__device__ __forceinline__ void NoReturnAtomicAdd(T* tensor,
index_t index,
@@ -1109,6 +1139,68 @@ __device__ __inline__ void KernelDepthwiseConvFilterGradCFilterNHWC(
}
}

template <typename T, int c_filter, bool fuse_relu_before_conv>
__device__ __inline__ void KernelDepthwiseConvFilterGradCFilterSmallChannelNHWC(
const T* output_grad_data,
const T* input_data,
const int num,
const int output_channels,
const int output_height,
const int output_width,
const int input_channels,
const int input_height,
const int input_width,
const int filter_multiplier,
const int filter_height,
const int filter_width,
const int stride_height,
const int stride_width,
const int padding_height,
const int padding_width,
const int dilate_height,
const int dilate_width,
T* filter_grad_data) {
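// Each thread owns one (filter_h, filter_w, out_channel) element of the
// filter gradient; blockIdx.y selects the batch image, and the thread
// accumulates over all output positions before a single atomic add.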
const int bid = blockIdx.y;
const int idx = blockIdx.x * blockDim.x + threadIdx.x;
const int numel = output_channels * c_filter * c_filter;
if (idx >= numel) {
return;
}
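// Decompose idx into (kernel_ih, kernel_iw, kernel_id) with the output
// channel innermost.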
const int tmp = idx / output_channels;
const int kernel_id = idx - tmp * output_channels;
const int kernel_ih = tmp / c_filter;
const int kernel_iw = tmp - kernel_ih * c_filter;

const int h_offset = kernel_ih * dilate_height - padding_height;
const int w_offset = kernel_iw * dilate_width - padding_width;

T s(0);
for (int og_h = 0; og_h < output_height; ++og_h) {
for (int og_w = 0; og_w < output_width; ++og_w) {
int image_hk = og_h * stride_height + h_offset;
int image_wk = og_w * stride_width + w_offset;
if (image_hk >= 0 && image_hk < input_height && image_wk >= 0 &&
image_wk < input_width) {
int input_id =
((bid * input_height + image_hk) * input_width + image_wk) *
input_channels +
kernel_id / filter_multiplier;
int output_id = ((bid * output_height + og_h) * output_width + og_w) *
output_channels +
kernel_id;
if (fuse_relu_before_conv) {
s += output_grad_data[output_id] *
static_cast<T>(
max(0.0f, static_cast<double>(input_data[input_id])));
} else {
s += output_grad_data[output_id] * input_data[input_id];
}
}
}
}
NoReturnAtomicAdd(filter_grad_data, idx, numel, s);
}

template <typename T,
int c_filter_multiplier,
int c_stride,
@@ -1209,28 +1301,35 @@ __global__ void KernelDepthwiseConvFilterGradSp(const T* output_grad_data,
dilate_width,
filter_grad_data);
} else {
KernelDepthwiseConvFilterGradCFilterNHWC<T,
c_filter,
fuse_relu_before_conv>(
output_grad_data,
input_data,
num,
output_channels,
output_height,
output_width,
input_channels,
input_height,
input_width,
final_filter_multiplier,
filter_height,
filter_width,
h_stride,
w_stride,
padding_height,
padding_width,
dilate_height,
dilate_width,
filter_grad_data);
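// Both NHWC filter-gradient kernels share the same signature, so the
// small-channel variant is selected through a function pointer when
// output_channels is below SMALL_THRESHOLD.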
auto kernel =
KernelDepthwiseConvFilterGradCFilterNHWC<T,
c_filter,
fuse_relu_before_conv>;
if (output_channels < SMALL_THRESHOLD) {
kernel = KernelDepthwiseConvFilterGradCFilterSmallChannelNHWC<
T,
c_filter,
fuse_relu_before_conv>;
}
kernel(output_grad_data,
input_data,
num,
output_channels,
output_height,
output_width,
input_channels,
input_height,
input_width,
final_filter_multiplier,
filter_height,
filter_width,
h_stride,
w_stride,
padding_height,
padding_width,
dilate_height,
dilate_width,
filter_grad_data);
}
}
}
@@ -1637,13 +1736,21 @@ class DepthwiseConvFilterGradFunctor<phi::GPUContext,
std::min(block_size, batch_size * output_height * output_width));
}
} else {
// A large block size increases contention on the atomic adds, so reduce it here.
block_size = 256;
blocks = std::min(
std::max(block_size / output_channels, 1),
((output_width + dilate_width - 1) / dilate_width) * dilate_width);
grid = dim3((output_height + dilate_height - 1) / dilate_height,
dilate_height,
batch_size);
threads = dim3(std::min(output_channels, block_size), blocks, 1);

if (output_channels < SMALL_THRESHOLD) {
const int hwc_size = ksize_height * ksize_width * output_channels;
grid = dim3((hwc_size + block_size - 1) / block_size, batch_size, 1);
threads = dim3(std::min(hwc_size, block_size));
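// Example: a 3x3 filter with 32 output channels gives hwc_size = 288, so a
// 256-thread block yields grid (2, batch_size, 1).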
}
}
int filter_multiplier = output_channels / input_channels;
