Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[cherry-pick] Fix the cumsum bugs with large shapes and random results #43777

Merged
merged 1 commit into from
Jun 24, 2022
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 15 additions & 13 deletions paddle/phi/kernels/gpu/cumsum_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -152,10 +152,8 @@ __global__ void BlockScanKernel(T* d_out,
} temp_storage;

int bx = blockIdx.x;
int by = blockIdx.y;

BlockPrefixCallbackOp<T> prefix_op(0);
T block_aggregate = static_cast<T>(0);

// Obtain this block's segment of consecutive keys (blocked across threads)
int item_per_block = BLOCK_THREADS * ITEMS_PER_THREAD;
Expand All @@ -168,7 +166,7 @@ __global__ void BlockScanKernel(T* d_out,
valid_item = scan_size;
}

int offset = bx * scan_size + block_offset + by * (inner_size * scan_size);
int offset = block_offset + bx * scan_size;

T thread_keys[ITEMS_PER_THREAD];
BlockLoadT(temp_storage.load)
Expand Down Expand Up @@ -260,8 +258,10 @@ void CumsumKernel(const Context& dev_ctx,
dim3 blocks(32, 8);
dim3 transpose_grids((width + tile_size - 1) / tile_size,
(height + tile_size - 1) / tile_size);
out->Resize(out_dims);
auto* tmp_data = out->data<T>();

DenseTensor tmp_tensor;
tmp_tensor.Resize(out_dims);
auto* tmp_data = dev_ctx.template Alloc<T>(&tmp_tensor);

T* next_in_data = out_data;
T* next_out_data = tmp_data;
Expand All @@ -281,6 +281,8 @@ void CumsumKernel(const Context& dev_ctx,
// The block size is 128 here, chosen with the size of shared memory in mind.
dim3 scan_grid(outer_size, inner_size);
dim3 reverse_grid = scan_grid;
int64_t grid_size = outer_size * inner_size;

if (reverse) {
if (transpose) {
reverse_grid.x = scan_grid.y;
Expand All @@ -295,17 +297,17 @@ void CumsumKernel(const Context& dev_ctx,
}
}
if (!transpose && !reverse) {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
BlockScanKernel<T, 128, 4><<<grid_size, 128, 0, dev_ctx.stream()>>>(
out_data, in_data, outer_size, inner_size, scan_size, exclusive);

} else {
BlockScanKernel<T, 128, 4><<<scan_grid, 128, 0, dev_ctx.stream()>>>(
next_out_data,
next_in_data,
outer_size,
inner_size,
scan_size,
exclusive);
BlockScanKernel<T, 128, 4>
<<<grid_size, 128, 0, dev_ctx.stream()>>>(next_out_data,
next_in_data,
outer_size,
inner_size,
scan_size,
exclusive);
}
swap_ptr(next_in_data, next_out_data);
if (reverse) {
Expand Down