From a37b2b059384f43027931df1fea67229c01a2480 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E5=91=A8=E5=91=A8=E5=91=A8?= <39978853+zhoutianzi666@users.noreply.github.com>
Date: Tue, 31 Oct 2023 09:56:47 +0800
Subject: [PATCH] [Paddle Inference] Support GQA Decoder (#58472)

Support GQA Decoder in masked_multihead_attention.cu
---
 paddle/phi/infermeta/multiary.cc              | 14 +++-
 .../fusion/gpu/masked_multihead_attention.cu  | 66 ++++++++++++++-----
 2 files changed, 64 insertions(+), 16 deletions(-)

diff --git a/paddle/phi/infermeta/multiary.cc b/paddle/phi/infermeta/multiary.cc
index cece7dd8807933..64cf9b010ae07e 100644
--- a/paddle/phi/infermeta/multiary.cc
+++ b/paddle/phi/infermeta/multiary.cc
@@ -4283,9 +4283,21 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
                                        MetaTensor* beam_cache_offset_out) {
   int bsz = static_cast<int>(x.dims()[0]);
   auto cache_kv_dims = cache_kv.dims();
-  int num_head = static_cast<int>(cache_kv.dims()[2]);
+  int k_num_head = static_cast<int>(cache_kv.dims()[2]);
+  int v_num_head = k_num_head;
   int dim_head = static_cast<int>(cache_kv.dims()[4]);
+  // The num_head below is actually the number of query heads.
+  int num_head =
+      x.dims()[x.dims().size() - 1] / dim_head - k_num_head - v_num_head;
+  PADDLE_ENFORCE_EQ(
+      num_head % k_num_head,
+      0,
+      errors::InvalidArgument(
+          "The num_head of query must be divisible by the num_head of key, but "
+          "received num_head of query is %d, and the num_head of key is %d",
+          num_head,
+          k_num_head));
   PADDLE_ENFORCE_EQ(
       cache_kv_dims.size(),
       5,
diff --git a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu
index 47ceb7ba1fdbce..0d65c4436b23dc 100644
--- a/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu
+++ b/paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu
@@ -92,6 +92,9 @@ struct Masked_multihead_attention_params {
   int beam_width;
   int cache_batch_size;
   int num_head;
+  // k_num_head and v_num_head must be equal, so we unify them:
+  // kv_num_head == k_num_head && kv_num_head == v_num_head
+  int kv_num_head;
   int timestep;  // cache_seq_length
   int seq_len;
   int max_seq_length;
@@ -403,6 +406,14 @@ __global__ void masked_multihead_attention_kernel(
   const int bbi = bi / params.beam_width;
   const int hi = blockIdx.x;
   const int bhi = bi * params.num_head + hi;
+
+  const int kv_num_head = params.kv_num_head;
+  const int num_head_per_group = params.num_head / kv_num_head;
+  // hi means the head index in the query processed by this cuda thread.
+  // kv_bhi means the merged batch and head index in key and value processed by
+  // this cuda thread.
+  const int kv_bhi = bi * kv_num_head + hi / num_head_per_group;
+
   const int bbhi = bbi * params.beam_width * params.num_head + hi;
   const int ti =
       params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1;
@@ -418,8 +429,9 @@ __global__ void masked_multihead_attention_kernel(
           ? params.timestep
           : params.sequence_lengths[bi];
 
-  // qkv [B, S=1, 3, num_head, head_dim]
-  int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
+  // qkv [B, S=1, num_head + 2 * kv_num_head, head_dim]
+  // this hi means the head index in the query.
+  int qkv_base_offset = bi * (params.num_head + 2 * kv_num_head) * Dh + hi * Dh;
 
   constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
   static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
@@ -444,7 +456,8 @@ __global__ void masked_multihead_attention_kernel(
 
   if (tid < QK_VECS_PER_WARP) {
     int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
-    int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
+    int q_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
+    int k_bias_offset = hi / num_head_per_group * Dh + tid * QK_VEC_SIZE;
 
     Qk_vec q;
     zero(q);
@@ -461,7 +474,10 @@ __global__ void masked_multihead_attention_kernel(
     //         ? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
     //         : k;
     if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) {
-      load_func.template load<Qk_vec>(k, params.num_head * Dh + qk_offset);
+      load_func.template load<Qk_vec>(k,
+                                      params.num_head * Dh + qk_offset -
+                                          hi * Dh +
+                                          hi / num_head_per_group * Dh);
     }
 
     if (params.add_qkv_bias) {
@@ -472,11 +488,11 @@ __global__ void masked_multihead_attention_kernel(
 
       q_bias =
           (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
-              ? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
+              ? *reinterpret_cast<const Qk_vec *>(&q_bias_base[q_bias_offset])
              : q_bias;
       k_bias =
           (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
-              ? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
+              ? *reinterpret_cast<const Qk_vec *>(&k_bias_base[k_bias_offset])
              : k_bias;
 
       q = add(q, q_bias);
@@ -582,7 +598,7 @@ __global__ void masked_multihead_attention_kernel(
 
     int co = tid / QK_VECS_IN_16B;
     int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
-    int offset = bhi * params.max_seq_length * Dh +
+    int offset = kv_bhi * params.max_seq_length * Dh +
                  co * params.max_seq_length * QK_ELTS_IN_16B +
                  act_time_step * QK_ELTS_IN_16B + ci;
     if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
@@ -640,7 +656,7 @@ __global__ void masked_multihead_attention_kernel(
   constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
   constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;
 
-  T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
+  T *k_cache = &params.cache_kv[kv_bhi * params.max_seq_length * Dh + ki];
   T *k_cache_batch = &params.cache_kv[bbhi * params.max_seq_length * Dh + ki];
 
   int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP;
@@ -737,12 +753,20 @@ __global__ void masked_multihead_attention_kernel(
   constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
   using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;
 
+  // now we have got [1, seq], distributed in logits_smem.
+  // next we compute [1, seq] * [seq, head_dim] = [1, head_dim]
+  // THREADS_PER_VALUE means num of threads per value's head_dim.
+  // we split the seq dimension for more cuda threads to compute.
+  // vo means the first seq index processed by this cuda thread in the value.
+  // vi means the head_dim index processed by this cuda thread in the value.
+  // so this cuda thread computes [1, k] * [k, vi:vi+V_VEC_SIZE] and k starts
+  // from vo and increases by a step of V_PER_ITER.
   int vo = tid / THREADS_PER_VALUE;
   int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE;
 
-  T *v_cache = &params.cache_kv[params.cache_batch_size * params.num_head *
+  T *v_cache = &params.cache_kv[params.cache_batch_size * kv_num_head *
                                     params.max_seq_length * Dh +
-                                bhi * params.max_seq_length * Dh + vi];
+                                kv_bhi * params.max_seq_length * Dh + vi];
   T *v_cache_batch = &params.cache_kv[params.batch_size * params.num_head *
                                           params.max_seq_length * Dh +
                                       bbhi * params.max_seq_length * Dh + vi];
@@ -755,7 +779,7 @@ __global__ void masked_multihead_attention_kernel(
 
   V_vec_acum out;
   zero(out);
-
+  // V_PER_ITER is used to strip-mine the seq dimension.
   constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
   if (Dh == Dh_MAX || vi < Dh) {
     for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) {
@@ -783,15 +807,19 @@ __global__ void masked_multihead_attention_kernel(
     V_vec v_bias;
     zero(v_bias);
+    // now we process the last v.
     if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
       // V_vec v = *reinterpret_cast<const V_vec *>(
       //     &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
       V_vec v;
-      load_func.template load<V_vec>(
-          v, 2 * params.num_head * Dh + qkv_base_offset + vi);
+      load_func.template load<V_vec>(v,
+                                     qkv_base_offset + vi - hi * Dh +
+                                         params.num_head * Dh + kv_num_head * Dh +
+                                         hi / num_head_per_group * Dh);
       if (params.add_qkv_bias) {
         v_bias = *reinterpret_cast<const V_vec *>(
-            &params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
+            &params
+                 .qkv_bias[(kv_num_head + params.num_head) * Dh + hi * Dh + vi]);
         v = add(v, v_bias);
       }
@@ -806,6 +834,7 @@ __global__ void masked_multihead_attention_kernel(
 
   __syncthreads();
 
+  // now we do the reduction in the seq dimension to get [1, head_dim].
   if (Dh == Dh_MAX || vi < Dh) {
 #pragma unroll
     for (int active_groups = V_PER_ITER; active_groups >= 2;
@@ -830,6 +859,7 @@ __global__ void masked_multihead_attention_kernel(
     }
   }
 
+  // write the [1, head_dim] result back to global memory.
   if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
 #ifdef MMHA_USE_FP32_ACUM_FOR_OUT
     // convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh +
@@ -1319,12 +1349,17 @@ void DispatchWithDtype(const Context &dev_ctx,
   const auto &x_dims = x.dims();
   int bsz = x_dims[0];
   int cache_bsz = cache_kv.dims()[1];
-  int num_head = cache_kv.dims()[2];
   int max_seq_len = cache_kv.dims()[3];
   int dim_head = cache_kv.dims()[4];
   int timestep = max_seq_len;
   float inv_sqrt_dh = 1. / sqrt(dim_head);
 
+  int k_num_head = cache_kv.dims()[2];
+  int v_num_head = k_num_head;
+  // this num_head means the number of query heads
+  int num_head =
+      x.dims()[x.dims().size() - 1] / dim_head - k_num_head - v_num_head;
+
   Masked_multihead_attention_params<T> params;
   bool mask_broadcast_num_heads = true;
@@ -1385,6 +1420,7 @@ void DispatchWithDtype(const Context &dev_ctx,
   params.batch_size = bsz;
   params.cache_batch_size = cache_bsz;
   params.num_head = num_head;
+  params.kv_num_head = k_num_head;
   params.timestep = timestep;
   params.seq_len = seq_len;
   params.max_seq_length = max_seq_len;
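
The InferMeta change above stops reading the query head count from the cache shape and instead derives it from the width of the fused QKV input: the last dimension of x is (num_head + k_num_head + v_num_head) * dim_head, and num_head must be an exact multiple of k_num_head. Below is a minimal standalone C++ sketch of that arithmetic, for illustration only; q_head_count is a hypothetical helper name, not a Paddle API.

#include <cassert>

// Mirrors the head-count derivation added to MaskedMultiheadAttentionInferMeta:
// x packs [q | k | v] along its last axis, so its width is
// (num_head + k_num_head + v_num_head) * dim_head.
int q_head_count(int x_last_dim, int dim_head, int k_num_head, int v_num_head) {
  int num_head = x_last_dim / dim_head - k_num_head - v_num_head;
  // GQA constraint enforced by the patch: the number of query heads must be
  // an exact multiple of the number of key heads.
  assert(num_head % k_num_head == 0);
  return num_head;
}

int main() {
  // Example: 32 query heads sharing 4 key/value heads, head_dim = 128.
  const int dim_head = 128, k_num_head = 4, v_num_head = 4;
  const int x_last_dim = (32 + k_num_head + v_num_head) * dim_head;
  assert(q_head_count(x_last_dim, dim_head, k_num_head, v_num_head) == 32);
  return 0;
}

When k_num_head equals num_head (plain MHA), the formula reduces to the old 3 * num_head * dim_head layout and the divisibility check is trivially satisfied.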
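
In masked_multihead_attention_kernel, the fused QKV row of one token is now laid out as [num_head + 2 * kv_num_head, head_dim]: rows [0, num_head) hold Q, the next kv_num_head rows hold K, and the last kv_num_head rows hold V, while the KV cache is indexed by kv_bhi = bi * kv_num_head + hi / num_head_per_group. The rewritten load offsets in the patch are this layout with qkv_base_offset factored out. The standalone sketch below checks that the patch's rearranged expressions match the direct row formulas; the sizes are hypothetical, chosen only for illustration.

#include <cassert>
#include <cstdio>

int main() {
  const int bsz = 2;          // batch size (hypothetical)
  const int num_head = 32;    // query heads
  const int kv_num_head = 4;  // shared key/value heads (GQA)
  const int Dh = 128;         // head_dim

  const int num_head_per_group = num_head / kv_num_head;
  const int row_stride = (num_head + 2 * kv_num_head) * Dh;  // fused QKV width per token

  for (int bi = 0; bi < bsz; ++bi) {
    for (int hi = 0; hi < num_head; ++hi) {
      // Offsets written the way the patch computes them.
      const int qkv_base_offset = bi * row_stride + hi * Dh;
      const int q_off = qkv_base_offset;
      const int k_off = num_head * Dh + qkv_base_offset - hi * Dh +
                        hi / num_head_per_group * Dh;
      const int v_off = qkv_base_offset - hi * Dh + num_head * Dh +
                        kv_num_head * Dh + hi / num_head_per_group * Dh;

      // The same offsets read directly off the
      // [num_head + 2 * kv_num_head, head_dim] layout.
      const int kv_hi = hi / num_head_per_group;
      assert(q_off == bi * row_stride + hi * Dh);
      assert(k_off == bi * row_stride + (num_head + kv_hi) * Dh);
      assert(v_off == bi * row_stride + (num_head + kv_num_head + kv_hi) * Dh);

      // Every group of num_head_per_group query heads shares one cache slot.
      const int kv_bhi = bi * kv_num_head + kv_hi;
      assert(kv_bhi == bi * kv_num_head + hi / num_head_per_group);
    }
  }
  std::printf("offset formulas agree\n");
  return 0;
}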
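
The new comments around vo and vi describe how the [1, seq] x [seq, head_dim] product is strip-mined: each thread owns a V_VEC_SIZE slice of head_dim (vi) and a starting timestep (vo), accumulates every V_PER_ITER-th timestep, and the per-group partial sums are then reduced into the final [1, head_dim] vector. The CPU emulation below reproduces only that scheduling under hypothetical block sizes; it ignores the separately handled last timestep and is a sketch of the work split, not of the kernel itself.

#include <cassert>
#include <cmath>
#include <vector>

int main() {
  const int seq = 37;       // act_time_step (hypothetical)
  const int head_dim = 64;  // Dh (hypothetical)
  const int THREADS_PER_VALUE = 8;
  const int THREADS_PER_BLOCK = 128;
  const int V_VEC_SIZE = head_dim / THREADS_PER_VALUE;
  const int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;

  std::vector<float> logits(seq), v(seq * head_dim), out(head_dim, 0.f);
  for (int i = 0; i < seq; ++i) logits[i] = 0.01f * i;
  for (int i = 0; i < seq * head_dim; ++i) v[i] = 0.001f * (i % 97);

  // Partial sums: one row per vo group, like the per-thread "out" registers
  // before the cross-group reduction.
  std::vector<float> partial(V_PER_ITER * head_dim, 0.f);
  for (int tid = 0; tid < THREADS_PER_BLOCK; ++tid) {
    int vo = tid / THREADS_PER_VALUE;
    int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE;
    for (int ti = vo; ti < seq; ti += V_PER_ITER)
      for (int e = 0; e < V_VEC_SIZE; ++e)
        partial[vo * head_dim + vi + e] += logits[ti] * v[ti * head_dim + vi + e];
  }
  // Reduction over the vo groups gives the [1, head_dim] result.
  for (int vo = 0; vo < V_PER_ITER; ++vo)
    for (int d = 0; d < head_dim; ++d) out[d] += partial[vo * head_dim + d];

  // Reference: direct [1, seq] x [seq, head_dim] product.
  for (int d = 0; d < head_dim; ++d) {
    float ref = 0.f;
    for (int ti = 0; ti < seq; ++ti) ref += logits[ti] * v[ti * head_dim + d];
    assert(std::fabs(ref - out[d]) < 1e-3f);
  }
  return 0;
}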