[Paddle-Inference] support GQA in variable_length_memory_efficient_attention #58836

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
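Background for this diff: grouped-query attention (GQA) lets several query heads share a single key/value head, so the head counts no longer need to be equal; only num_heads % kv_num_heads == 0 is required. A minimal head-mapping sketch in plain Python (illustrative only, not code from this PR):

# Illustrative GQA head mapping; assumes num_heads is a multiple of kv_num_heads,
# which the InferMeta check below enforces.
def kv_head_for_query_head(query_head, num_heads, kv_num_heads):
    assert num_heads % kv_num_heads == 0
    qheads_per_kv_head = num_heads // kv_num_heads
    return query_head // qheads_per_kv_head

# Example: 8 query heads grouped onto 2 KV heads.
print([kv_head_for_query_head(h, 8, 2) for h in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]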
16 changes: 12 additions & 4 deletions paddle/phi/infermeta/multiary.cc
@@ -2930,11 +2930,19 @@ void VariableLengthMemoryEfficientAttentionInferMeta(
phi::errors::InvalidArgument(
"The batch size of Query, Key, Value should be equal."));

PADDLE_ENFORCE_EQ((key_num_head == value_num_head),
true,
phi::errors::InvalidArgument(
"The head number of Key, Value should be equal."));

PADDLE_ENFORCE_EQ(
((query_num_head == key_num_head) && (key_num_head == value_num_head)),
true,
phi::errors::InvalidArgument(
"The head number of Query, Key, Value should be equal."));
query_num_head % key_num_head,
0,
errors::InvalidArgument(
"The num_head of query must be divisible by the num_head of key, but "
"recived num_head of query is %d, and the num_head of key is %d",
query_num_head,
key_num_head));

PADDLE_ENFORCE_EQ(query_head_size == key_head_size,
true,
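With this change, the InferMeta no longer requires the query and key head counts to match: key and value must still agree, and the query head count must be an integer multiple of the key head count. A rough Python equivalent of the new validation (a sketch, not Paddle code):

def check_gqa_head_counts(query_num_head, key_num_head, value_num_head):
    # Key and Value must still have the same number of heads.
    if key_num_head != value_num_head:
        raise ValueError("The head number of Key, Value should be equal.")
    # Query heads only need to be a multiple of KV heads (GQA).
    if query_num_head % key_num_head != 0:
        raise ValueError(
            "The num_head of query must be divisible by the num_head of key, "
            f"but received num_head of query is {query_num_head}, "
            f"and the num_head of key is {key_num_head}"
        )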
@@ -164,6 +164,7 @@ struct FMHAGrouped {
int problem_count;
int threadblock_count;
int num_heads;
int kv_num_heads;

ElementQ *ptr_Q;
ElementK *ptr_K;
@@ -205,6 +206,7 @@ struct FMHAGrouped {
: problem_count(0),
threadblock_count(0),
num_heads(0),
kv_num_heads(0),
ptr_Q(nullptr),
ptr_K(nullptr),
ptr_P(nullptr),
@@ -234,6 +236,7 @@ struct FMHAGrouped {
int problem_count,
int threadblock_count,
int num_heads,
int kv_num_heads,
ElementQ *ptr_Q,
ElementK *ptr_K,
ElementM *ptr_M,
@@ -259,6 +262,7 @@ struct FMHAGrouped {
problem_count(problem_count),
threadblock_count(threadblock_count),
num_heads(num_heads),
kv_num_heads(kv_num_heads),
ptr_Q(ptr_Q),
ptr_K(ptr_K),
ptr_M(ptr_M),
@@ -307,6 +311,7 @@ struct FMHAGrouped {
typename ProblemVisitor::Params problem_visitor;
int threadblock_count;
int num_heads;
int kv_num_heads;

ElementQ *ptr_Q;
ElementK *ptr_K;
@@ -369,6 +374,7 @@ struct FMHAGrouped {
tile_count),
threadblock_count(args.threadblock_count),
num_heads(args.num_heads),
kv_num_heads(args.kv_num_heads),
ptr_Q(args.ptr_Q),
ptr_K(args.ptr_K),
ptr_P(args.ptr_P),
@@ -403,6 +409,7 @@ struct FMHAGrouped {
tile_count);
threadblock_count = args.threadblock_count;
num_heads = args.num_heads;
kv_num_heads = args.kv_num_heads;
ptr_Q = args.ptr_Q;
ptr_K = args.ptr_K;
ptr_P = args.ptr_P;
@@ -580,6 +587,8 @@ struct FMHAGrouped {

const int32_t problem_idx = problem_visitor.problem_index();
const int32_t batch_idx = problem_idx / params.num_heads;
// How many query heads share a KV head?
const int32_t qhead_per_kv_head = params.num_heads / params.kv_num_heads;

if (thread_id() < kQueriesPerBlock) {
s_prime[thread_id()] = ElementAccumulator(0);
@@ -639,7 +648,8 @@ struct FMHAGrouped {
auto prologueV = [&](int blockN) {
typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv)},
params.ptr_V + problem_idx * params.kElementV +
params.ptr_V +
(problem_idx / qhead_per_kv_head) * params.kElementV +
iter_key_start * params.ldv,
{problem_size_1_k, problem_size_1_n},
thread_id(),
@@ -679,7 +689,8 @@ struct FMHAGrouped {
typename MM0::IteratorB iterator_B(
typename MM0::IteratorB::Params(
typename MM0::MmaCore::LayoutB(params.ldk)),
params.ptr_K + problem_idx * params.kElementK +
params.ptr_K +
(problem_idx / qhead_per_kv_head) * params.kElementK +
iter_key_start * params.ldk,
{problem_size_0_k, problem_size_0_n},
thread_id(),
@@ -834,7 +845,8 @@ struct FMHAGrouped {

typename MM1::Mma::IteratorB iterator_V(
typename MM1::IteratorB::Params{MM1::LayoutB(params.ldv)},
params.ptr_V + problem_idx * params.kElementV +
params.ptr_V +
(problem_idx / qhead_per_kv_head) * params.kElementV +
iter_key_start * params.ldv,
{problem_size_1_k, problem_size_1_n},
thread_id(),
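In the grouped kernel above, each problem index corresponds to one (batch, query head) pair, with batch_idx = problem_idx / num_heads. Dividing problem_idx by qhead_per_kv_head collapses each group of query heads onto the K/V slice they share, which is what the new ptr_K / ptr_V offsets use. A small Python sketch of that index arithmetic (illustrative; the helper name is made up):

def kv_slice_index(problem_idx, num_heads, kv_num_heads):
    # problem_idx enumerates (batch, query_head) pairs: batch_idx * num_heads + head_idx.
    qhead_per_kv_head = num_heads // kv_num_heads  # how many query heads share a KV head
    # Integer division by the group size yields batch_idx * kv_num_heads + kv_head_idx,
    # i.e. the index of the shared K/V slice whose offset the iterators load.
    return problem_idx // qhead_per_kv_head

# Example: one batch, 8 query heads, 2 KV heads ->
# query heads 0-3 read KV slice 0, query heads 4-7 read KV slice 1.
print([kv_slice_index(p, 8, 2) for p in range(8)])  # [0, 0, 0, 0, 1, 1, 1, 1]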
@@ -179,6 +179,7 @@ def parse_args():
problem_count,
threadblock_count,
params.num_heads,
params.kv_num_heads,
const_cast<scalar_t*>(reinterpret_cast<const scalar_t*>(params.query_ptr)),
const_cast<scalar_t*>(reinterpret_cast<const scalar_t*>(params.key_ptr)),
params.mask_ptr
@@ -465,6 +466,7 @@ def write_main_header():
// Dimensions/strides
int32_t num_batches;
int32_t num_heads;
int32_t kv_num_heads;
int32_t query_seq_len;
int32_t key_value_seq_len;
int32_t head_size;
@@ -37,6 +37,7 @@ void MultiHeadAttentionVariableForwardKernel(

params.num_batches = query.dims()[0];
params.num_heads = query.dims()[1];
params.kv_num_heads = key.dims()[1];
params.query_seq_len = query.dims()[2];
params.head_size = query.dims()[3];
params.key_value_seq_len = key.dims()[2];
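The kernel entry point above reads the new kv_num_heads from the key tensor's second dimension, so query is laid out as [batch, num_heads, seq_len, head_dim] and key/value as [batch, kv_num_heads, seq_len, head_dim]. A NumPy shape sketch mirroring the first test case below (illustrative only):

import numpy as np

batch_size, num_head, kv_num_head, seq_len, dim_head = 1, 8, 2, 64, 16
q = np.random.random((batch_size, num_head, seq_len, dim_head)).astype("float32")
k = np.random.random((batch_size, kv_num_head, seq_len, dim_head)).astype("float32")
v = np.random.random((batch_size, kv_num_head, seq_len, dim_head)).astype("float32")
# params.num_heads    <- query.dims()[1] == 8
# params.kv_num_heads <- key.dims()[1]   == 2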
@@ -60,6 +60,20 @@ def create_attn_mask(


def naive_attention_impl(query, key, value, mask, scale):
batch = query.shape[0]
heads = query.shape[1]
seq_len = query.shape[2]
head_dim = query.shape[3]
kv_head = key.shape[1]

key = key.reshape([batch, kv_head, 1, seq_len, head_dim])
key = paddle.tile(key, [1, 1, heads // kv_head, 1, 1])
key = key.reshape([batch, heads, seq_len, head_dim])

value = value.reshape([batch, kv_head, 1, seq_len, head_dim])
value = paddle.tile(value, [1, 1, heads // kv_head, 1, 1])
value = value.reshape([batch, heads, seq_len, head_dim])

qk_res = paddle.matmul(query, key, transpose_y=True)
attention = qk_res * scale
attention = attention + mask
@@ -78,6 +92,7 @@ def setUp(self):
self.place = paddle.CUDAPlace(0)
self.batch_size = 1
self.num_head = 8
self.kv_num_head = 2
self.seq_len = 64
self.dim_head = 16
self.seq_lens = paddle.to_tensor(
@@ -93,6 +108,12 @@ def setUp(self):
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float32'
self.attention_mask = create_attn_mask(
self.dtype,
@@ -111,11 +132,11 @@ def test_all(self):
q = paddle.to_tensor(
query, place=self.place, dtype=self.dtype, stop_gradient=False
)
key = np.random.random(self.shape)
key = np.random.random(self.shape_kv)
k = paddle.to_tensor(
key, place=self.place, dtype=self.dtype, stop_gradient=False
)
value = np.random.random(self.shape)
value = np.random.random(self.shape_kv)
v = paddle.to_tensor(
value, place=self.place, dtype=self.dtype, stop_gradient=False
)
@@ -147,6 +168,7 @@ def setUp(self):
self.place = paddle.CUDAPlace(0)
self.batch_size = 3
self.num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
@@ -162,6 +184,12 @@ def setUp(self):
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
@@ -180,6 +208,7 @@ def setUp(self):
self.place = paddle.CUDAPlace(0)
self.batch_size = 2
self.num_head = 8
self.kv_num_head = 2
self.seq_len = 32
self.dim_head = 128
self.seq_lens = paddle.to_tensor(
@@ -195,6 +224,12 @@ def setUp(self):
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'bfloat16'
self.attention_mask = create_attn_mask(
self.dtype,
@@ -217,6 +252,7 @@ def setUp(self):
self.place = paddle.CUDAPlace(0)
self.batch_size = 3
self.num_head = 16
self.kv_num_head = 2
self.seq_len = 64
self.dim_head = 32
self.seq_lens = paddle.to_tensor(
@@ -232,6 +268,12 @@ def setUp(self):
self.seq_len,
self.dim_head,
)
self.shape_kv = (
self.batch_size,
self.kv_num_head,
self.seq_len,
self.dim_head,
)
self.dtype = 'float16'
self.attention_mask = create_attn_mask(
self.dtype,
@@ -242,8 +284,8 @@ def setUp(self):
* self.batch_size,
).numpy()
self.q = np.random.random(self.shape).astype(self.dtype)
self.k = np.random.random(self.shape).astype(self.dtype)
self.v = np.random.random(self.shape).astype(self.dtype)
self.k = np.random.random(self.shape_kv).astype(self.dtype)
self.v = np.random.random(self.shape_kv).astype(self.dtype)
self.scale = 1.0 / np.sqrt(self.shape[-1])

self.ref_out = naive_attention_impl(
@@ -261,10 +303,10 @@ def test_all(self):
name="query", shape=self.shape, dtype=self.dtype
)
k = paddle.static.data(
name="key", shape=self.shape, dtype=self.dtype
name="key", shape=self.shape_kv, dtype=self.dtype
)
v = paddle.static.data(
name="value", shape=self.shape, dtype=self.dtype
name="value", shape=self.shape_kv, dtype=self.dtype
)
mask = paddle.static.data(
name="mask",
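For completeness, a self-contained NumPy version of the GQA reference that these tests use: repeating each KV head across its group of query heads is equivalent to the reshape/tile/reshape in naive_attention_impl above. This is a sketch that assumes the reference finishes with a softmax over the masked scores followed by the value matmul:

import numpy as np

def naive_gqa_reference(query, key, value, mask, scale):
    # query: [batch, heads, seq, dim]; key/value: [batch, kv_heads, seq, dim]
    heads, kv_heads = query.shape[1], key.shape[1]
    rep = heads // kv_heads
    # Repeat each KV head rep times along the head axis (same effect as reshape -> tile -> reshape).
    key = np.repeat(key, rep, axis=1)
    value = np.repeat(value, rep, axis=1)
    scores = np.matmul(query, np.swapaxes(key, -1, -2)) * scale + mask
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    probs = np.exp(scores)
    probs = probs / probs.sum(axis=-1, keepdims=True)
    return np.matmul(probs, value)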