[Paddle Inference] Add bias input of mmha and simplify mmha. #56411
Conversation
Your PR was submitted successfully. Thank you for your contribution to the open-source project!
@@ -3987,6 +3987,7 @@ void WeightOnlyMatmulInferMeta(const MetaTensor& x,
The full quantization requirement is:
input: int32/float16/float32
output: int8/float16/float32
I see that the int8-output case is handled here, but the int32-input case doesn't seem to be covered, right?
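For illustration, here is a minimal Python sketch (assumed logic, not Paddle's actual InferMeta code) of the dtype matrix described above: inputs may be int32, float16, or float32, and outputs may be int8, float16, or float32.

```python
def infer_out_dtype(x_dtype: str, compute_dtype: str, quantize_out: bool) -> str:
    """Illustrative only: pick the output dtype for the quantized matmul."""
    # An int32 input is a quantized GEMM result, so its real compute dtype
    # must come from compute_dtype rather than from x_dtype itself.
    base = compute_dtype if x_dtype == "int32" else x_dtype
    assert base in ("float16", "float32"), f"unsupported compute dtype {base}"
    # An int8 output corresponds to the quantized-output case handled in the diff.
    return "int8" if quantize_out else base

assert infer_out_dtype("float16", "default", quantize_out=True) == "int8"
assert infer_out_dtype("int32", "float16", quantize_out=False) == "float16"
```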
#include "paddle/phi/core/kernel_registry.h" | ||
#include "paddle/phi/kernels/funcs/aligned_vector.h" | ||
#include "paddle/phi/kernels/fusion/gpu/mmha_util.cu.h" | ||
|
||
namespace phi { | ||
namespace fusion { | ||
|
Why were these header files removed?
There is no need to include redundant headers; removing them prevents them from being pulled in by other callers.
@@ -43,6 +45,7 @@ def masked_multihead_attention(
    Args:
        x (Tensor): The input tensor could be 2-D tensor. Its shape is [batch_size, 3 * num_head * head_dim].
        cache_kvs (list(Tensor)|tuple(Tensor)): The cache structure tensors for the generation model. Its shape is [2, batch_size, num_head, max_seq_len, head_dim].
        bias (Tensor, optional): The bias tensor. Its shape is [3, num_head, head_dim].
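For context, a minimal usage sketch of the new bias argument, based on the shapes in this docstring (the tensor sizes, dtype, and exact parameter names are assumptions; the released Paddle signature may use cache_kv rather than cache_kvs):

```python
import paddle
import paddle.incubate.nn.functional as F

batch_size, num_head, head_dim, max_seq_len = 2, 8, 64, 128

# Packed QKV activations: [batch_size, 3 * num_head * head_dim]
x = paddle.randn([batch_size, 3 * num_head * head_dim], dtype="float16")
# K/V cache: [2, batch_size, num_head, max_seq_len, head_dim]
cache_kv = paddle.zeros(
    [2, batch_size, num_head, max_seq_len, head_dim], dtype="float16"
)
# Added by this PR: the QKV bias, shape [3, num_head, head_dim]
bias = paddle.zeros([3, num_head, head_dim], dtype="float16")

out = F.masked_multihead_attention(x, cache_kv, bias=bias)
```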
Please also add a parameter description for compute_dtype~
Done~
@@ -77,7 +80,7 @@ def setUp(self):
        self.seq_len = 1
        self.rotary_emb_dims = 0
        self.use_neox_rotary_style = False

        self.compute_dtype = "default"
Should the unit tests also add a case where the input is int32?
The unit tests already cover the int32 case.
@@ -53,6 +56,7 @@ def masked_multihead_attention(
        seq_len (int, optional): The seq_len, used to get input length. Default 1.
        rotary_emb_dims (int, optional): The rotary_emb_dims. Default 1.
        use_neox_rotary_style (bool, optional): A flag indicating whether neox_rotary_style is needed or not. Default False.
        compute_dtype (string): A compute dtype, used to represent the input data type.
Why can't compute_dtype be inferred from the input tensor's dtype?
In the PTQ case, the input x may be int32. If we inferred the dtype from cache_kv instead, we would have to change this again when cache_kv quantization is supported later.
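To make the trade-off concrete, a small sketch (assumed logic, not Paddle's code) of why inferring the dtype from either x or cache_kv is fragile under quantization:

```python
def pick_compute_dtype(x_dtype: str, cache_kv_dtype: str,
                       compute_dtype: str = "default") -> str:
    """Illustrative only: resolve the kernel's compute dtype."""
    if compute_dtype != "default":
        return compute_dtype  # explicit value survives any quantization change
    if x_dtype != "int32":
        return x_dtype        # non-PTQ path: x still carries the dtype
    # PTQ path: x is int32, so fall back to the cache dtype. This works today,
    # but breaks once cache_kv itself is quantized to int8, which is exactly
    # the follow-up concern raised in the reply above.
    return cache_kv_dtype

assert pick_compute_dtype("float16", "float16") == "float16"
assert pick_compute_dtype("int32", "float16") == "float16"
assert pick_compute_dtype("int32", "int8", compute_dtype="float16") == "float16"
```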
LGTM for API change
LGTM for new args
PR types
Others
PR changes
Others
Description
Add bias input of mmha and simplify mmha.
Related PR: #55344
Pcard-71502