diff --git a/paddle/fluid/operators/elementwise/elementwise_op_function.h b/paddle/fluid/operators/elementwise/elementwise_op_function.h
index c69acb89750c93..4894dff4b971ca 100644
--- a/paddle/fluid/operators/elementwise/elementwise_op_function.h
+++ b/paddle/fluid/operators/elementwise/elementwise_op_function.h
@@ -41,9 +41,9 @@ limitations under the License. */
 #include <thrust/iterator/iterator_adaptor.h>
 
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/gpu/elementwise_grad.h"
 #endif

diff --git a/paddle/fluid/operators/fused/attn_bias_add.cu.h b/paddle/fluid/operators/fused/attn_bias_add.cu.h
index 53001b24930847..8ea1e11cd29f41 100644
--- a/paddle/fluid/operators/fused/attn_bias_add.cu.h
+++ b/paddle/fluid/operators/fused/attn_bias_add.cu.h
@@ -33,8 +33,8 @@ namespace cub = hipcub;
 #include "paddle/fluid/operators/elementwise/elementwise_functor.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
-#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
 #include "paddle/phi/kernels/funcs/fast_divmod.h"
+#include "paddle/phi/kernels/funcs/reduce_function.h"
 
 namespace paddle {
 namespace operators {
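Note: the two include swaps above accompany the deletion of `reduce_op.cu.h` below; callers stop going through the fluid wrapper and use the phi reduce utilities directly. A minimal migration sketch, assuming a sum over the two leading axes; the wrapper name `SumOverLeadingDims` and the functor choices are illustrative, not part of this diff:

```cpp
#include "paddle/phi/kernels/funcs/reduce_function.h"

// Before this diff, a caller went through the fluid wrapper:
//   paddle::operators::TensorReduceImpl<T, T, kps::AddFunctor,
//                                       kps::IdentityFunctor<T>>(
//       dev_ctx, x, &out, kps::IdentityFunctor<T>(), {0, 1}, stream);
// The wrapper only allocated the output and forwarded its arguments,
// so the direct phi call is a drop-in replacement:
template <typename T>
void SumOverLeadingDims(const phi::GPUContext& dev_ctx,
                        const phi::DenseTensor& x,
                        phi::DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);  // the old wrapper did this for the caller
  phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
      dev_ctx, x, out, kps::IdentityFunctor<T>(), {0, 1});
}
```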
diff --git a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h b/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
deleted file mode 100644
index 21646d08db3962..00000000000000
--- a/paddle/fluid/operators/reduce_ops/reduce_op.cu.h
+++ /dev/null
@@ -1,53 +0,0 @@
-// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-//     http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-#pragma once
-
-#include <algorithm>
-#include <cmath>
-#include <numeric>
-#include <set>
-#include <vector>
-
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/phi/core/dense_tensor.h"
-#include "paddle/phi/kernels/funcs/reduce_function.h"
-namespace paddle {
-namespace operators {
-
-template <typename Tx,
-          typename Ty,
-          template <typename> class ReduceBaseOp,
-          typename TransformOp>
-void TensorReduceImpl(const phi::GPUContext& dev_ctx,
-                      const phi::DenseTensor& x,
-                      phi::DenseTensor* y,
-                      const TransformOp& transform,
-                      const std::vector<int>& origin_reduce_dims,
-                      gpuStream_t stream,
-                      bool is_mean = false) {
-  y->mutable_data<Ty>(x.place());
-
-  phi::funcs::ReduceKernel<Tx, Ty, ReduceBaseOp, TransformOp>(
-      static_cast<const phi::GPUContext&>(dev_ctx),
-      x,
-      y,
-      transform,
-      origin_reduce_dims,
-      is_mean);
-}
-
-}  // namespace operators
-}  // namespace paddle

diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 16e46fc201ea8c..a755cc83bba658 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -925,6 +925,7 @@ static void LaunchReduceKernel(const Tx* x_data,
 }
 
 #if !defined(PADDLE_WITH_XPU_KP)
+
 template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
@@ -983,7 +984,6 @@ CubTensorReduceImpl(const Tx* x_data,
   PADDLE_THROW(phi::errors::InvalidArgument(
       "Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
 }
-
 template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
@@ -1002,17 +1002,53 @@ CubTensorReduceImpl(const Tx* x_data,
 }
 #endif  // PADDLE_WITH_XPU_KP
 
+template <typename Tx,
+          typename Ty,
+          template <typename> class ReduceOp,
+          typename TransformOp,
+          bool IsMean = false>
+struct CubTensorReduce {
+  static void apply(const Tx* x_data,
+                    Ty* y_data,
+                    const TransformOp& transform,
+                    int reduce_num,
+                    const KPDevice& dev_ctx,
+                    KPStream stream) {
+    CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
+        x_data, y_data, transform, reduce_num, dev_ctx, stream);
+  }
+};
+
+template <typename Tx,
+          typename Ty,
+          template <typename> class ReduceOp,
+          typename TransformOp>
+struct CubTensorReduce<Tx, Ty, ReduceOp, TransformOp, true> {
+  static void apply(const Tx* x_data,
+                    Ty* y_data,
+                    const TransformOp& transform,
+                    int reduce_num,
+                    const KPDevice& dev_ctx,
+                    KPStream stream) {
+    using Div = kps::DivideFunctor<Tx>;
+    CubTensorReduceImpl<Tx, Ty, ReduceOp, Div>(
+        x_data, y_data, Div(reduce_num), reduce_num, dev_ctx, stream);
+  }
+};
+
 template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
-          typename TransformOp>
+          typename TransformOp,
+          bool IsMean = false>
 void ReduceKernel(const KPDevice& dev_ctx,
                   const phi::DenseTensor& x,
                   phi::DenseTensor* y,
                   const TransformOp& transform,
-                  const std::vector<int>& origin_reduce_dims,
-                  bool is_mean = false) {
+                  const std::vector<int>& origin_reduce_dims) {
   PADDLE_ENFORCE_GT(
       x.numel(),
       0,
@@ -1061,18 +1097,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
   bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16;
 #ifndef PADDLE_WITH_XPU_KP
   if (use_cub_reduce) {
-    if (is_mean) {
-      using Div = kps::DivideFunctor<Tx>;
-      CubTensorReduceImpl<Tx, Ty, ReduceOp, Div>(x_data,
-                                                 y_data,
-                                                 Div(config.reduce_num),
-                                                 config.reduce_num,
-                                                 dev_ctx,
-                                                 stream);
-    } else {
-      CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
-          x_data, y_data, transform, config.reduce_num, dev_ctx, stream);
-    }
+    CubTensorReduce<Tx, Ty, ReduceOp, TransformOp, IsMean>::apply(
+        x_data, y_data, transform, config.reduce_num, dev_ctx, stream);
     return;
   }
 #endif
@@ -1115,7 +1141,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
         config.blocking_size,
         dim,
         config.reduce_num,
-        is_mean && (!config.should_reduce_again),
+        IsMean && (!config.should_reduce_again),
         config.tmp_data,
         config.should_reduce_again);
@@ -1149,7 +1175,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
         config.grid.y,
         dim2,
         config.reduce_num,
-        is_mean,
+        IsMean,
         config.tmp_data,
         false);
   }
@@ -1167,29 +1193,28 @@ void ReduceKernel(const KPDevice& dev_ctx,
       reducer.initial(),
       stream,
       config,
-      is_mean);
+      IsMean);
 }
 
 template <typename Tx,
           typename Ty,
           template <typename> class ReduceOp,
-          typename TransformOp>
+          typename TransformOp,
+          bool IsMean = false>
 void TensorReduceImpl(const phi::GPUContext& dev_ctx,
                       const phi::DenseTensor& x,
                       phi::DenseTensor* y,
                       const TransformOp& transform,
                       const std::vector<int>& origin_reduce_dims,
-                      gpuStream_t stream,
-                      bool is_mean = false) {
+                      gpuStream_t stream) {
   dev_ctx.template Alloc<Ty>(y);
 
-  ReduceKernel<Tx, Ty, ReduceOp, TransformOp>(
+  ReduceKernel<Tx, Ty, ReduceOp, TransformOp, IsMean>(
      static_cast<const phi::GPUContext&>(dev_ctx),
       x,
       y,
       transform,
-      origin_reduce_dims,
-      is_mean);
+      origin_reduce_dims);
 }
 
 #endif
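Note: the reduce_function.h change above is the core of this diff. The runtime `bool is_mean` argument of `ReduceKernel` becomes a compile-time `IsMean` template parameter, and the cub fast path selects the sum or mean transform through partial specialization of `CubTensorReduce` instead of an `if (is_mean)` branch. A self-contained sketch of the same dispatch pattern (toy code, not Paddle's API):

```cpp
#include <cstdio>
#include <vector>

// Primary template: plain sum when IsMean == false (the default).
template <typename T, bool IsMean = false>
struct VecReduce {
  static T apply(const std::vector<T>& v) {
    T sum = T(0);
    for (const T& e : v) sum += e;
    return sum;
  }
};

// Partial specialization selected at compile time when IsMean == true:
// reuse the sum path, then divide by the element count.
template <typename T>
struct VecReduce<T, true> {
  static T apply(const std::vector<T>& v) {
    return VecReduce<T, false>::apply(v) / static_cast<T>(v.size());
  }
};

int main() {
  std::vector<float> v{1.f, 2.f, 3.f, 4.f};
  std::printf("sum = %g, mean = %g\n",
              VecReduce<float>::apply(v), VecReduce<float, true>::apply(v));
}
```

Resolving the branch at compile time removes a runtime check from the hot path and lets each specialization pick its transform type (`TransformOp` vs. `kps::DivideFunctor<Tx>`) statically, which is exactly what the two `CubTensorReduce::apply` bodies above do.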
diff --git a/paddle/phi/kernels/fusion/gpu/attn_gemm.h b/paddle/phi/kernels/fusion/gpu/attn_gemm.h
index 8b83ddab93b9b1..27d972bc1d7407 100644
--- a/paddle/phi/kernels/fusion/gpu/attn_gemm.h
+++ b/paddle/phi/kernels/fusion/gpu/attn_gemm.h
@@ -28,6 +28,7 @@
 #include "paddle/phi/kernels/funcs/reduce_function.h"
 #include "paddle/phi/kernels/primitive/functor_primitives.h"
 #include "paddle/phi/kernels/primitive/kernel_primitives.h"
+#include "paddle/phi/kernels/reduce_sum_kernel.h"
 
 namespace phi {
 namespace fusion {
@@ -259,23 +260,12 @@ class AttnMatMul {
     gpuStream_t stream = dev_ctx_.stream();
 
     if (support_case_1 || support_case_2) {
-      phi::funcs::
-          TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-              dev_ctx_,
-              *d_output,
-              d_bias,
-              kps::IdentityFunctor<T>(),
-              {0, 1},
-              stream);
+      phi::SumKernel<T, phi::GPUContext>(
+          dev_ctx_, *d_output, {0, 1}, d_output->dtype(), false, d_bias);
+
     } else if (support_case_3 || support_case_4) {
-      phi::funcs::
-          TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-              dev_ctx_,
-              *d_output,
-              d_bias,
-              kps::IdentityFunctor<T>(),
-              {0, 1, 2},
-              stream);
+      phi::SumKernel<T, phi::GPUContext>(
+          dev_ctx_, *d_output, {0, 1, 2}, d_output->dtype(), false, d_bias);
     } else {
       PADDLE_THROW(phi::errors::InvalidArgument(
           "Only support reduce when the input dims are [0,1,2,3,4] and "

diff --git a/paddle/phi/kernels/gpu/mean_all_kernel.cu b/paddle/phi/kernels/gpu/mean_all_kernel.cu
index 4f85a89047aed1..82405d964737c7 100644
--- a/paddle/phi/kernels/gpu/mean_all_kernel.cu
+++ b/paddle/phi/kernels/gpu/mean_all_kernel.cu
@@ -43,13 +43,12 @@ void MeanAllKernel(const Context& dev_ctx,
   for (decltype(rank) i = 0; i < rank; ++i) {
     reduce_dims.push_back(i);
   }
-  funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
-      dev_ctx,
-      x,
-      out,
-      kps::IdentityFunctor<T>(),
-      reduce_dims,
-      /*is_mean=*/true);
+  funcs::ReduceKernel<T,
+                      T,
+                      kps::AddFunctor,
+                      kps::IdentityFunctor<T>,
+                      /*is_mean*/ true>(
+      dev_ctx, x, out, kps::IdentityFunctor<T>(), reduce_dims);
 }
 
 }  // namespace phi
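Note: in attn_gemm.h the backward-bias reduction now goes through the public `phi::SumKernel` instead of the low-level `TensorReduceImpl`; the axes `{0, 1}` / `{0, 1, 2}` and `keep_dim = false` match the old calls, and `SumKernel` handles output allocation itself. mean_all_kernel.cu correspondingly requests the mean through the new `/*is_mean*/ true` template argument rather than the removed runtime flag. A hedged sketch of the resulting bias-grad path (the free function `ComputeBiasGrad` is illustrative, not part of this diff):

```cpp
#include "paddle/phi/kernels/reduce_sum_kernel.h"

// Sum d_output over all leading axes so only the trailing bias dimension
// survives; equivalent to the removed TensorReduceImpl call with
// kps::AddFunctor + kps::IdentityFunctor<T>.
template <typename T>
void ComputeBiasGrad(const phi::GPUContext& dev_ctx,
                     const phi::DenseTensor& d_output,
                     phi::DenseTensor* d_bias) {
  phi::SumKernel<T, phi::GPUContext>(
      dev_ctx, d_output, {0, 1}, d_output.dtype(), /*keep_dim=*/false, d_bias);
}
```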
diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h
index cc3cad38f46fbd..79c7381edab192 100644
--- a/paddle/phi/kernels/gpu/reduce.h
+++ b/paddle/phi/kernels/gpu/reduce.h
@@ -27,15 +27,15 @@
 template <typename T,
           template <typename> class ReduceOp,
           template <typename, typename>
-          class TransformOp>
+          class TransformOp,
+          bool IsMean = false>
 void Reduce(const KPDevice& dev_ctx,
             const DenseTensor& x,
             bool reduce_all,
             const std::vector<int64_t>& dims,
             bool keep_dim,
             DataType out_dtype,
-            DenseTensor* out,
-            bool is_mean = false) {
+            DenseTensor* out) {
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
   std::vector<int> reduce_dims =
       phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
@@ -59,33 +59,23 @@ void Reduce(const KPDevice& dev_ctx,
           phi::funcs::ReduceKernel<data_t,
                                    data_t,
                                    ReduceOp,
-                                   TransformOp<data_t, MPType>>(
+                                   TransformOp<data_t, MPType>,
+                                   IsMean>(
               dev_ctx,
               tmp_tensor,
               out,
               TransformOp<data_t, MPType>(reduce_num),
-              reduce_dims,
-              is_mean);
+              reduce_dims);
         }));
   } else {
     using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
-    phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
-        dev_ctx,
-        x,
-        out,
-        TransformOp<T, MPType>(reduce_num),
-        reduce_dims,
-        is_mean);
+    phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>, IsMean>(
+        dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
   }
 #else
   using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
-  phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
-      dev_ctx,
-      x,
-      out,
-      TransformOp<T, MPType>(reduce_num),
-      reduce_dims,
-      is_mean);
+  phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>, IsMean>(
+      dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
 #endif
 }
 
 }  // namespace phi

diff --git a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
index 8964c2547886b8..b04267030b2846 100644
--- a/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
+++ b/paddle/phi/kernels/gpu/reduce_amin_amax_common.h
@@ -90,7 +90,6 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
       equal_out_tensor.dtype(),
       false,
       equal_count);
-
   // 3. dx = dout * 1
   phi::MultiplyKernel<T, Context>(
       dev_ctx, new_dout, equal_out_tensor, &equal_out_tensor);

diff --git a/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu b/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
index 38f84bd5d0d9d5..7f8e985695818b 100644
--- a/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
+++ b/paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
@@ -30,7 +30,7 @@ void SquaredL2NormKernel(const Context& dev_ctx,
     origin_reduce_dims.push_back(i);
   }
   phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::SquareFunctor<T, T>>(
-      dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims, false);
+      dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims);
 }
 
 }  // namespace phi

diff --git a/paddle/phi/kernels/kps/reduce_kernel.cu b/paddle/phi/kernels/kps/reduce_kernel.cu
index d5d0fd9de3b282..3440b53c68b488 100644
--- a/paddle/phi/kernels/kps/reduce_kernel.cu
+++ b/paddle/phi/kernels/kps/reduce_kernel.cu
@@ -116,8 +116,8 @@ void MeanRawKernel(const Context& dev_ctx,
                    DenseTensor* out) {
   reduce_all = recompute_reduce_all(x, dims, reduce_all);
   auto out_dtype = x.dtype();
-  phi::Reduce<T, kps::AddFunctor, kps::DivideFunctor>(
-      dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true);
+  phi::Reduce<T, kps::AddFunctor, kps::DivideFunctor, true>(
+      dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
 }
 
 template
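Note: across these call sites the pattern is uniform. Callers that previously passed `is_mean` (or a literal `true`) as a trailing runtime argument now select the behavior as the last template argument of `phi::Reduce` / `phi::funcs::ReduceKernel`, as in `MeanRawKernel` above; callers that passed `false`, such as `SquaredL2NormKernel`, simply drop the argument because `IsMean` defaults to `false`. A usage sketch mirroring the updated `MeanRawKernel` (the wrapper `MeanAll` is illustrative, not part of this diff):

```cpp
#include "paddle/phi/kernels/gpu/reduce.h"

// Mean over every axis of x: reduce_all = true, an empty dim list, and the
// mean requested at compile time via the trailing IsMean template argument.
template <typename T>
void MeanAll(const phi::GPUContext& dev_ctx,
             const phi::DenseTensor& x,
             phi::DenseTensor* out) {
  phi::Reduce<T, kps::AddFunctor, kps::DivideFunctor, /*IsMean=*/true>(
      dev_ctx, x, /*reduce_all=*/true, /*dims=*/{}, /*keep_dim=*/false,
      x.dtype(), out);
}
```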