Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the cumsum bug for large size #43722

Merged
merged 1 commit into from
Jun 22, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 7 additions & 12 deletions paddle/phi/kernels/gpu/cum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -176,10 +176,8 @@ __global__ void BlockScanKernel(T* d_out,
} temp_storage;

int bx = blockIdx.x;
int by = blockIdx.y;

BlockPrefixCallbackOp<T, Op> prefix_op(Identity<T, Op>::value, op);
T block_aggregate = static_cast<T>(0);

// Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
Expand All @@ -192,7 +190,7 @@ __global__ void BlockScanKernel(T* d_out,
valid_item = scan_size;
}

int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
int offset = block_offset + bx * scan_size;

T thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
Expand Down Expand Up @@ -271,7 +269,6 @@ void ScanKernel(const Context& dev_ctx,
return;
}


size_t height = 1;
size_t width = 1;
for (size_t i = 0; i <= axis; i++) {
Expand Down Expand Up @@ -308,6 +305,7 @@ void ScanKernel(const Context& dev_ctx,
int outer_size = height / scan_size;
int inner_size = width;
// The block size is fixed at 128 threads to keep shared-memory usage within limits.

dim3 scan_grid(outer_size, inner_size);
dim3 reverse_grid = scan_grid;
if (reverse) {
Expand All @@ -323,13 +321,14 @@ void ScanKernel(const Context& dev_ctx,
in_data, out_data, scan_size, outer_size, inner_size);
}
}
int64_t grid_size = outer_size * inner_size;
if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4, Op><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
BlockScanKernel<T, 128, 4, Op><<<grid_size, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive, op);

} else {
BlockScanKernel<T, 128, 4, Op>
<<<scan_grid, 128, 0, dev_ctx.stream()>>>(next_out_data,
<<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
next_in_data,
outer_size,
inner_size,
Expand Down Expand Up @@ -391,9 +390,5 @@ PD_REGISTER_KERNEL(cumsum,
int,
int64_t) {}

PD_REGISTER_KERNEL(logcumsumexp,
GPU,
ALL_LAYOUT,
phi::LogcumsumexpKernel,
float,
double) {}
PD_REGISTER_KERNEL(
logcumsumexp, GPU, ALL_LAYOUT, phi::LogcumsumexpKernel, float, double) {}