Skip to content

Commit

Permalink
Update attn_gemm.h
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG authored Sep 21, 2023
1 parent 8abb807 commit 841e553
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions paddle/fluid/operators/fused/attn_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,10 +255,10 @@ class AttnMatMul {
gpuStream_t stream = dev_ctx_.stream();
if (support_case_1 || support_case_2) {
phi::SumKernel<T>(
dev_ctx, *d_output, {0, 1}, d_output->dtype(), false, d_bias);
dev_ctx_, *d_output, {0, 1}, d_output->dtype(), false, d_bias);
} else if (support_case_3 || support_case_4) {
phi::SumKernel<T>(
dev_ctx, *d_output, {0, 2}, d_output->dtype(), false, d_bias);
dev_ctx_, *d_output, {0, 2}, d_output->dtype(), false, d_bias);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Only support reduce when the input dims are [0,1,2,3,4] and "
Expand Down

0 comments on commit 841e553

Please sign in to comment.