Integrate cudnn flash attention and add IR pass to fuse dot product attention #58680
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
Force-pushed from 2fb1b23 to e5ef5d3
Force-pushed from 1b3743b to 2595619
We are now developing a new IR, called PIR. Would you mind transferring the fusion pass to PIR? You can refer to fused_gemm_epilogue_pass.
Sorry to inform you that 2595619's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.
@zhiqiu A question: following other PRs, I added the fusion pass under PIR, but the unit test did not end up replacing the subgraph with the fused op, and there was no error message either. Are there any debugging tips? With the old IR pass I could turn on GLOG_v to see the details of every step of the subgraph matching, but with the new IR I don't know how to debug.
The matching logic of the new fusion pass is in
blocked by Issue#59467 |
Force-pushed from bf16deb to 92c5fdf
@zyfncg I have added this pass and a unit test under PIR, but it seems the new IR pass has to be applied by running a PassManager. In that case, can it still be enabled from the build strategy?
@zyfncg All the CIs have passed; the remaining
paddle/phi/infermeta/fusion.cc
Outdated
@@ -3245,4 +3245,48 @@ void SkipLayerNormInferMeta(const MetaTensor& x,
  out->set_dtype(x.dtype());
}

void FusedDotProductAttentionInferMeta(const MetaTensor& q,
Place the InferMeta functions in alphabetical order.
Done
Force-pushed from 92c5fdf to c500ccf
LGTM
Force-pushed from 4c1bd02 to 02da812
from paddle.framework import LayerHelper, in_dynamic_mode

def fused_dot_product_attention(
From the user's perspective, is this another implementation of paddle.nn.functional.scaled_dot_product_attention?
Yes. From the user's perspective, the interface and functionality of the two are basically the same. The difference is that this one uses the cuDNN implementation of flash attention under the hood, which has better support and performance on Ampere and Hopper GPUs.
Should we expose only paddle.nn.functional.scaled_dot_product_attention as the single API for users? As I recall, the name, parameters, and semantics of that API were settled after internal discussion.
If the technical implementation makes a single API impossible, then I suggest the new API's signature be aligned with the existing API as much as possible, so that it is easy for users to understand and use.
Any thoughts? @jeff41404 @sneaxiy @zhiqiu @liuzhenhai93
There are some differences in the API, mainly: 1. there is an extra float scaling factor input; 2. the mask must be bool or int. All other parameters are identical.
Also, cuDNN's flash attention only supports Ampere and Hopper GPUs, which may differ from Paddle's flash attention.
So it does not look like this can be unified directly into the nn.functional.scaled_dot_product_attention API. Any suggestions?
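To make the comparison concrete, here is a minimal sketch of the two calls under the differences described above. The fused variant's import path and parameter names (mask, scaling_factor) are assumptions for illustration only, not the final signature, and the snippet assumes a GPU build with flash attention support.

```python
import paddle
import paddle.nn.functional as F

# Layout expected by paddle.nn.functional.scaled_dot_product_attention:
# [batch, seq_len, num_heads, head_dim]
q = paddle.randn([2, 128, 8, 64]).astype('float16')
k = paddle.randn([2, 128, 8, 64]).astype('float16')
v = paddle.randn([2, 128, 8, 64]).astype('float16')

# Existing public API: additive float mask, scale fixed to 1/sqrt(head_dim).
out = F.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# cuDNN-backed variant discussed in this thread (hypothetical signature):
# extra float scaling factor, bool/int mask, Ampere/Hopper GPUs only.
# out = paddle.incubate.nn.functional.fused_dot_product_attention(
#     q, k, v, mask=bool_mask, scaling_factor=64 ** -0.5, dropout_p=0.0)
```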
This PR mainly adds the C++ op and the IR pass; adding the nn.functional.fused_dot_product_attention API is not strictly necessary, and leaving it out does not affect the other changes. So if exposing this API requires more discussion, I think we can open a separate PR for that.
@jzhang533 Should this change be removed from this PR for now?
I don't think it needs to be removed. In the current design of this PR, the API lives under paddle.incubate; when it is later formalized under paddle.nn.functional, it can still be adjusted.
I just wanted to start the discussion on this question here. I will give an approve in a moment so this PR can pass CI.
We can keep discussing this.
@jzhang533's comment is a good one.
"There are some differences in the API, mainly: 1. there is an extra float scaling factor input; 2. the mask must be bool or int. All other parameters are identical."
If these are the only two differences, I don't think they justify a standalone API. A possible solution would be to add a scaling factor (with a default value) to scaled_dot_product_attention, plus a use_cudnn = False parameter; when it is set to True, check the environment and the mask, and inform the user if the requirements are not met.
So my suggestion is: if the parameters and usage of this API are not settled yet, move the code into the incubate directory for now and add a warning in the documentation saying it may change later; if they are settled, it would be best to unify everything into scaled_dot_product_attention.
Agreed with @jeff41404's suggestion. The code is already in the incubate directory, and I have added the warning. If it later needs to be unified into scaled_dot_product_attention, I will open a new PR.
"A possible solution would be to add a scaling factor, with a default value, to scaled_dot_product_attention": this is worth considering.
"Add a use_cudnn = False parameter; when it is set to True...": this I would not recommend; as a rule, APIs and ops should not take device-specific parameters.
Force-pushed from 91d70ae to 1fa4ea6
Force-pushed from 1fa4ea6 to bf802e0
2023-12-08 13:12:17 0. You must have raindrops2sea or XiaoguangHu01 approval for changing 20+ files or adding more than 1000 lines of content.
Could you please take a look? @XiaoguangHu01
LGTM
PR types
New features
PR changes
OPs
Description
cuDNN v8 provides a fused (flash) attention kernel to accelerate TransformerLayer (refer to cudnn-frontend/fused_mha_sample). To be exact, it only fuses the computation of scaled dot product attention: it does not cover the QKV projection or the output projection, but only scale, bmm1, mask, softmax, dropout, and bmm2. So let's call it fused_dot_product_attention instead of fused_multi_head_attention.
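For orientation, here is a minimal sketch of the computation that gets fused, written with plain Paddle ops rather than the cuDNN kernel; the tensor layout and the bool-mask convention (True means keep) are assumptions for illustration.

```python
import paddle
import paddle.nn.functional as F

def dot_product_attention_reference(q, k, v, mask=None, dropout_p=0.0):
    # q, k, v: [batch, num_heads, seq_len, head_dim] (layout assumed for illustration)
    head_dim = q.shape[-1]
    scores = paddle.matmul(q, k, transpose_y=True) * (head_dim ** -0.5)  # scale + bmm1
    if mask is not None:
        # mask: bool tensor broadcastable to scores, True = keep
        scores = paddle.where(mask, scores, paddle.full_like(scores, float('-inf')))
    probs = F.softmax(scores, axis=-1)       # softmax
    probs = F.dropout(probs, p=dropout_p)    # dropout
    return paddle.matmul(probs, v)           # bmm2
```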
TransformerLayer perf test
GPU: A100
config: batch size=32, hidden size=1024, seq len=1024
Build strategy usage sample (cuDNN >= 8.9.6 is required):
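A hedged sketch of what enabling the fusion from a build strategy might look like; the flag name fuse_dot_product_attention is an assumption based on this PR's discussion and should be checked against the merged BuildStrategy attributes.

```python
import paddle

paddle.enable_static()

build_strategy = paddle.static.BuildStrategy()
# Flag name assumed from this PR; verify against the merged code.
build_strategy.fuse_dot_product_attention = True

main_program = paddle.static.default_main_program()
compiled_program = paddle.static.CompiledProgram(
    main_program, build_strategy=build_strategy)
# compiled_program can then be run with a paddle.static.Executor in place of main_program.
```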