
[Paddle Inference] Support GQA Decoder (PaddlePaddle#58472)
Support GQA Decoder in masked_multihead_attention.cu
zhoutianzi666 authored Oct 31, 2023
1 parent a2e6d53 commit a37b2b0
Showing 2 changed files with 64 additions and 16 deletions.
14 changes: 13 additions & 1 deletion paddle/phi/infermeta/multiary.cc
@@ -4283,9 +4283,21 @@ void MaskedMultiheadAttentionInferMeta(const MetaTensor& x,
MetaTensor* beam_cache_offset_out) {
int bsz = static_cast<int>(x.dims()[0]);
auto cache_kv_dims = cache_kv.dims();
int num_head = static_cast<int>(cache_kv.dims()[2]);
int k_num_head = static_cast<int>(cache_kv.dims()[2]);
int v_num_head = k_num_head;
int dim_head = static_cast<int>(cache_kv.dims()[4]);
// The num_head below is the number of query heads.
int num_head =
x.dims()[x.dims().size() - 1] / dim_head - k_num_head - v_num_head;

PADDLE_ENFORCE_EQ(
num_head % k_num_head,
0,
errors::InvalidArgument(
"The num_head of query must be divisible by the num_head of key, but "
"recived num_head of query is %d, and the num_head of key is %d",
num_head,
k_num_head));
PADDLE_ENFORCE_EQ(
cache_kv_dims.size(),
5,
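
To make the new shape logic concrete, here is a small standalone C++ sketch of the same computation (not the Paddle code itself; the head counts and dim_head below are made-up example values). The fused qkv activation x has last dimension (num_head + 2 * kv_num_head) * dim_head, so the number of query heads is recovered from x's width and the key/value cache shape:

#include <cassert>
#include <cstdio>

int main() {
  // Hypothetical GQA configuration: 32 query heads sharing 4 key/value heads.
  const int dim_head = 128;
  const int k_num_head = 4;           // cache_kv.dims()[2]
  const int v_num_head = k_num_head;  // K and V carry the same head count
  const int x_last_dim = (32 + k_num_head + v_num_head) * dim_head;  // fused qkv width

  // Recover the query head count the same way the InferMeta code does.
  const int num_head = x_last_dim / dim_head - k_num_head - v_num_head;
  assert(num_head % k_num_head == 0);  // every key/value head serves a whole group

  printf("num_head = %d, kv_num_head = %d, heads per group = %d\n",
         num_head, k_num_head, num_head / k_num_head);
  return 0;
}
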
66 changes: 51 additions & 15 deletions paddle/phi/kernels/fusion/gpu/masked_multihead_attention.cu
@@ -92,6 +92,9 @@ struct Masked_multihead_attention_params {
int beam_width;
int cache_batch_size;
int num_head;
// k_num_head and v_num_head must be equal, so we unify them:
// kv_num_head == k_num_head && kv_num_head == v_num_head
int kv_num_head;
int timestep; // cache_seq_length
int seq_len;
int max_seq_length;
@@ -403,6 +406,14 @@ __global__ void masked_multihead_attention_kernel(
const int bbi = bi / params.beam_width;
const int hi = blockIdx.x;
const int bhi = bi * params.num_head + hi;

const int kv_num_head = params.kv_num_head;
const int num_head_per_group = params.num_head / kv_num_head;
// hi is the query head index processed by this thread block.
// kv_bhi is the merged batch-and-head index of the key/value head that this
// query head maps to.
const int kv_bhi = bi * kv_num_head + hi / num_head_per_group;

const int bbhi = bbi * params.beam_width * params.num_head + hi;
const int ti =
params.cum_offsets ? bi * params.seq_len - params.cum_offsets[bi] : -1;
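
For illustration, the new batch/head indexing can be reduced to one line of arithmetic (a hypothetical helper, not part of the patch; the example sizes are assumed):

// Map a (batch, query head) pair to the merged index of its key/value head.
int query_to_kv_bhi(int bi, int hi, int num_head, int kv_num_head) {
  const int num_head_per_group = num_head / kv_num_head;
  return bi * kv_num_head + hi / num_head_per_group;
}
// E.g. with num_head = 8 and kv_num_head = 2, query heads 0-3 of batch bi map
// to kv_bhi = bi * 2 + 0, and query heads 4-7 map to kv_bhi = bi * 2 + 1.
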
@@ -418,8 +429,9 @@ __global__ void masked_multihead_attention_kernel(
? params.timestep
: params.sequence_lengths[bi];

// qkv [B, S=1, 3, num_head, head_dim]
int qkv_base_offset = bi * 3 * params.num_head * Dh + hi * Dh;
// qkv [B, S=1, num_head + 2 * kv_num_head, head_dim]
// Here hi is still the head index within the query.
int qkv_base_offset = bi * (params.num_head + 2 * kv_num_head) * Dh + hi * Dh;

constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T);
static_assert(Dh_MAX % QK_VEC_SIZE == 0, "");
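
Per token, the fused qkv row is now laid out as [num_head query heads | kv_num_head key heads | kv_num_head value heads], each slice Dh wide. A sketch of the three base offsets this layout implies (hypothetical helper and names, shown only to spell out the arithmetic):

struct QkvOffsets { int q, k, v; };

// Element offsets of query head hi and its key/value head inside the fused qkv
// buffer, whose row width is (num_head + 2 * kv_num_head) * Dh.
QkvOffsets gqa_offsets(int bi, int hi, int num_head, int kv_num_head,
                       int num_head_per_group, int Dh) {
  const int row = bi * (num_head + 2 * kv_num_head) * Dh;
  const int kv_hi = hi / num_head_per_group;
  QkvOffsets o;
  o.q = row + hi * Dh;                                // query head hi
  o.k = row + (num_head + kv_hi) * Dh;                // its key head
  o.v = row + (num_head + kv_num_head + kv_hi) * Dh;  // its value head
  return o;
}

The k and v loads later in this kernel reach the same offsets by starting from qkv_base_offset (which already contains hi * Dh) and correcting it with the "- hi * Dh + hi / num_head_per_group * Dh" terms.
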
@@ -444,7 +456,8 @@ __global__ void masked_multihead_attention_kernel(

if (tid < QK_VECS_PER_WARP) {
int qk_offset = qkv_base_offset + tid * QK_VEC_SIZE;
int qk_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
int q_bias_offset = hi * Dh + tid * QK_VEC_SIZE;
int k_bias_offset = hi / num_head_per_group * Dh + tid * QK_VEC_SIZE;

Qk_vec q;
zero(q);
@@ -461,7 +474,10 @@ __global__ void masked_multihead_attention_kernel(
// ? *reinterpret_cast<const Qk_vec *>(&k_base[qk_offset])
// : k;
if (Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh) {
load_func.template load<Qk_vec>(k, params.num_head * Dh + qk_offset);
load_func.template load<Qk_vec>(k,
params.num_head * Dh + qk_offset -
hi * Dh +
hi / num_head_per_group * Dh);
}

if (params.add_qkv_bias) {
@@ -472,11 +488,11 @@ __global__ void masked_multihead_attention_kernel(

q_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&q_bias_base[qk_bias_offset])
? *reinterpret_cast<const Qk_vec *>(&q_bias_base[q_bias_offset])
: q_bias;
k_bias =
(Dh == Dh_MAX || tid * QK_VEC_SIZE < Dh)
? *reinterpret_cast<const Qk_vec *>(&k_bias_base[qk_bias_offset])
? *reinterpret_cast<const Qk_vec *>(&k_bias_base[k_bias_offset])
: k_bias;

q = add(q, q_bias);
@@ -582,7 +598,7 @@ __global__ void masked_multihead_attention_kernel(

int co = tid / QK_VECS_IN_16B;
int ci = (tid % QK_VECS_IN_16B) * QK_VEC_SIZE;
int offset = bhi * params.max_seq_length * Dh +
int offset = kv_bhi * params.max_seq_length * Dh +
co * params.max_seq_length * QK_ELTS_IN_16B +
act_time_step * QK_ELTS_IN_16B + ci;
if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
@@ -640,7 +656,7 @@ __global__ void masked_multihead_attention_kernel(
constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;

T *k_cache = &params.cache_kv[bhi * params.max_seq_length * Dh + ki];
T *k_cache = &params.cache_kv[kv_bhi * params.max_seq_length * Dh + ki];
T *k_cache_batch = &params.cache_kv[bbhi * params.max_seq_length * Dh + ki];
int ti_end = div_up(act_time_step, K_PER_WARP) * K_PER_WARP;

@@ -737,12 +753,20 @@ __global__ void masked_multihead_attention_kernel(
constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;

// Now we have the [1, seq] logits, stored in logits_smem.
// Next we compute [1, seq] * [seq, head_dim] = [1, head_dim].
// THREADS_PER_VALUE is the number of threads covering one value's head_dim.
// We split the seq dimension so that more CUDA threads take part in the
// computation.
// vo is the first seq index of the value processed by this thread.
// vi is the head_dim index of the value processed by this thread.
// So this thread computes [1, k] * [k, vi:vi+V_VEC_SIZE], where k starts at vo
// and advances in steps of V_PER_ITER.
int vo = tid / THREADS_PER_VALUE;
int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE;

T *v_cache = &params.cache_kv[params.cache_batch_size * params.num_head *
T *v_cache = &params.cache_kv[params.cache_batch_size * kv_num_head *
params.max_seq_length * Dh +
bhi * params.max_seq_length * Dh + vi];
kv_bhi * params.max_seq_length * Dh + vi];
T *v_cache_batch = &params.cache_kv[params.batch_size * params.num_head *
params.max_seq_length * Dh +
bbhi * params.max_seq_length * Dh + vi];
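
The comments at the top of this hunk can be made concrete with a tiny host-side sketch of the thread-to-work mapping (THREADS_PER_BLOCK, THREADS_PER_VALUE, and Dh_MAX are assumed values picked only for illustration; the real kernel does this with vectorized V_vec loads):

#include <cstdio>

int main() {
  // Assumed launch configuration, for illustration only.
  const int THREADS_PER_BLOCK = 128;
  const int THREADS_PER_VALUE = 16;  // threads covering one value's head_dim
  const int Dh_MAX = 128;
  const int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;             // 8 elements per thread
  const int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;  // 8 seq rows per iteration

  for (int tid = 0; tid < THREADS_PER_BLOCK; tid += 37) {  // sample a few threads
    const int vo = tid / THREADS_PER_VALUE;                 // first seq row for this thread
    const int vi = (tid % THREADS_PER_VALUE) * V_VEC_SIZE;  // head_dim slice start
    printf("tid %3d: seq rows %d, %d, %d, ...  head_dim slice [%d, %d)\n",
           tid, vo, vo + V_PER_ITER, vo + 2 * V_PER_ITER, vi, vi + V_VEC_SIZE);
  }
  return 0;
}
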
@@ -755,7 +779,7 @@ __global__ void masked_multihead_attention_kernel(

V_vec_acum out;
zero(out);

// V_PER_ITER is used to strip-mine the seq dimension.
constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
if (Dh == Dh_MAX || vi < Dh) {
for (int ti = vo; ti < act_time_step; ti += V_PER_ITER) {
@@ -783,15 +807,19 @@ __global__ void masked_multihead_attention_kernel(

V_vec v_bias;
zero(v_bias);
// Now we process the current timestep's v.
if (vo == (act_time_step % V_PER_ITER) && (Dh == Dh_MAX || vi < Dh)) {
// V_vec v = *reinterpret_cast<const V_vec *>(
// &params.qkv[2 * params.num_head * Dh + qkv_base_offset + vi]);
V_vec v;
load_func.template load<V_vec>(
v, 2 * params.num_head * Dh + qkv_base_offset + vi);
load_func.template load<V_vec>(v,
qkv_base_offset + vi - hi * Dh +
params.num_head * Dh + kv_num_head * Dh +
hi / num_head_per_group * Dh);
if (params.add_qkv_bias) {
v_bias = *reinterpret_cast<const V_vec *>(
&params.qkv_bias[2 * params.num_head * Dh + hi * Dh + vi]);
&params
.qkv_bias[(kv_num_head + params.num_head) * Dh + hi * Dh + vi]);
v = add(v, v_bias);
}

@@ -806,6 +834,7 @@ __global__ void masked_multihead_attention_kernel(

__syncthreads();

// Now we do the reduction over the seq dimension to get [1, head_dim].
if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
for (int active_groups = V_PER_ITER; active_groups >= 2;
@@ -830,6 +859,7 @@ __global__ void masked_multihead_attention_kernel(
}
}

// write the [1, head_dim] result back to global memory.
if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
// convert_from_float(*reinterpret_cast<V_vec *>(&params.out[bhi * Dh +
@@ -1319,12 +1349,17 @@ void DispatchWithDtype(const Context &dev_ctx,
const auto &x_dims = x.dims();
int bsz = x_dims[0];
int cache_bsz = cache_kv.dims()[1];
int num_head = cache_kv.dims()[2];
int max_seq_len = cache_kv.dims()[3];
int dim_head = cache_kv.dims()[4];
int timestep = max_seq_len;
float inv_sqrt_dh = 1. / sqrt(dim_head);

int k_num_head = cache_kv.dims()[2];
int v_num_head = k_num_head;
// This num_head is the number of query heads.
int num_head =
x.dims()[x.dims().size() - 1] / dim_head - k_num_head - v_num_head;

Masked_multihead_attention_params<T> params;
bool mask_broadcast_num_heads = true;
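
As a concrete example (hypothetical sizes, not taken from this patch): with 32 query heads, k_num_head = 4, and dim_head = 128, cache_kv carries cache_bsz, 4, max_seq_len, and 128 in dims 1-4, and x's last dimension is (32 + 2 * 4) * 128 = 5120, so this code recovers num_head = 5120 / 128 - 4 - 4 = 32.
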

@@ -1385,6 +1420,7 @@ void DispatchWithDtype(const Context &dev_ctx,
params.batch_size = bsz;
params.cache_batch_size = cache_bsz;
params.num_head = num_head;
params.kv_num_head = k_num_head;
params.timestep = timestep;
params.seq_len = seq_len;
params.max_seq_length = max_seq_len;
