
Integrate cudnn flash attention and add IR pass to fuse dot product attention #58680

Merged: 1 commit merged into PaddlePaddle:develop on Dec 12, 2023

Conversation

@Wong4j (Collaborator) commented Nov 3, 2023

PR types

New features

PR changes

OPs

Description

cuDNN v8 provides a fused (flash) attention kernel that accelerates TransformerLayer (refer to cudnn-frontend/fused_mha_sample). To be precise, it fuses only the scaled dot product attention computation: it does not cover the QKV projection or the output projection, only scale, bmm1, mask, softmax, dropout, and bmm2. So let's call it fused_dot_product_attention instead of fused_multi_head_attention.
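
For orientation, here is a minimal eager-mode sketch of the computation the fused op covers (scale, bmm1, mask, softmax, dropout, bmm2), assuming [b, s, h, d] inputs and a padding mask with 1 for valid and 0 for padded positions. It only illustrates the pattern the IR pass matches; it is not the fused kernel itself.

```python
import paddle
import paddle.nn.functional as F

def reference_dot_product_attention(q, k, v, mask, dropout_p=0.1):
    # q, k, v: [batch, seq_len, num_heads, head_size]
    # mask: broadcastable to [batch, num_heads, seq_len, seq_len], 1 = valid, 0 = padded
    head_size = q.shape[-1]
    q = q.transpose([0, 2, 1, 3])                                # [b, h, s, d]
    k = k.transpose([0, 2, 1, 3])
    v = v.transpose([0, 2, 1, 3])
    scores = paddle.matmul(q, k, transpose_y=True)               # bmm1
    scores = scores * (1.0 / head_size ** 0.5)                   # scale
    scores = scores + (1.0 - mask.astype(scores.dtype)) * -1e4   # mask
    probs = F.softmax(scores, axis=-1)                           # softmax
    probs = F.dropout(probs, p=dropout_p)                        # dropout
    out = paddle.matmul(probs, v)                                # bmm2
    return out.transpose([0, 2, 1, 3])                           # back to [b, s, h, d]
```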

  1. Added the following OPs to compute scaled dot product attention (forward + backward):
    • fused_dot_product_attention/fused_dot_product_attention_grad
  2. Added a new IR pass to detect above scaled dot product attention pattern.
  3. Added a class member fuse_dot_product_attention to BuildStrategy to enable fuse_dot_product_attention_pass.

TransformerLayer perf test

GPU: A100
config: batch size=32, hidden size=1024, seq len=1024

| original (time per step) | fuse dot product attention (time per step) | speedup |
|---|---|---|
| 576 ms | 406 ms | 1.40x |

Build strategy usage sample (CUDNN >= 8.9.6 is required):

```python
import numpy as np
import paddle

class TransformerLayer(paddle.nn.Layer):
    def __init__(self, hidden, num_heads):
        super().__init__()
        self.encoder_layer = paddle.nn.TransformerEncoderLayer(hidden, num_heads, 4*hidden, dropout=0.1)
        self.encoder = paddle.nn.TransformerEncoder(self.encoder_layer, 2)

    def forward(self, q, mask):
        out = self.encoder(q, mask)
        loss = paddle.mean(out)
        return loss

paddle.enable_static()
main_prog = paddle.static.Program()
startup_prog = paddle.static.Program()

batch_size = 32
head_size = 64
num_heads = 12
hidden = head_size * num_heads
seq_len = 128

with paddle.static.program_guard(main_prog, startup_prog):
    enc_input = paddle.static.data(name='enc_input', shape=[-1, -1, hidden], dtype="float32")
    attn_mask = paddle.static.data(name='mask', shape=[-1, 1, seq_len, seq_len], dtype="int32")
    model = TransformerLayer(hidden, num_heads)
    loss = model(enc_input, attn_mask)
    opt = paddle.optimizer.SGD(learning_rate=0.1)
    amp_list = paddle.static.amp.CustomOpLists(custom_white_list=['softmax'])
    opt = paddle.static.amp.decorate(
        optimizer=opt,
        amp_lists=amp_list,
        init_loss_scaling=128.0,
        use_dynamic_loss_scaling=True)
    opt.minimize(loss)


place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup_prog)

build_strategy = paddle.static.BuildStrategy()
build_strategy.fuse_dot_product_attention = True
main_prog = paddle.static.CompiledProgram(main_prog)
main_prog = main_prog.with_data_parallel(loss_name=loss.name,
                                         build_strategy=build_strategy,
                                         places=place)
feed = {
    'enc_input': np.random.rand(batch_size, seq_len, hidden).astype('float32'),
    'mask': np.ones([batch_size, 1, seq_len, seq_len], dtype='int32'),
}
exe.run(main_prog, feed=feed)
# 2 subgraphs are fused into the fused_dot_product_attention op
# 2 subgraphs are fused into the fused_dot_product_attention_grad op
```

paddle-bot commented Nov 3, 2023

Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI first. See the Paddle CI Manual for details.

@Wong4j Wong4j added the NVIDIA label Nov 3, 2023
@onecatcn onecatcn assigned heavengate and zyfncg and unassigned heavengate Nov 6, 2023
@Wong4j Wong4j closed this Nov 8, 2023
@Wong4j Wong4j force-pushed the jaywan/cudnn_fa_integration branch from 2fb1b23 to e5ef5d3 on November 8, 2023 14:39
@Wong4j Wong4j reopened this Nov 8, 2023
@Wong4j (Collaborator, Author) commented Nov 8, 2023

IR pass

The IR pass involves quite a few details, so here is some explanation to help review.

Design:

QKV input: fused_dot_product_attention does not compute the QKV projection or the output projection; it computes only the middle part: scale, bmm1, mask, softmax, dropout, and bmm2. Its inputs are the q, k, v projection results (already reshaped to [b, s, h, d]) and the mask.

The IR pass traverses the graph and replaces subgraphs: it matches all the ops and intermediate variables of scaled dot product attention and replaces them with the fused op. In the figure below, the left side is the original subgraph; the purple part is replaced by the green fused op on the right.
[figure: original subgraph (left, purple region) and the fused op that replaces it (right, green)]

Note:

  • cuDNN flash attention cannot take a mask as input directly; it takes actual_seqlen instead, so I added an extra CUDA kernel that converts the mask into actual_seqlen (a host-side sketch of this conversion follows the list).
  • cuDNN flash attention supports causal masking, and the fused_dot_product_attention op I wrote supports it too by setting is_causal_masking=True. However, this IR pass does not support causal masking, because it matches the subgraph inside nn.MultiHeadAttention, and nn.MultiHeadAttention has no causal-related parameter. In theory the IR pass could decide whether a mask is causal from its layout, but that is not a good approach, so the current design only handles an ordinary padding mask.
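
Below is a minimal host-side illustration (NumPy, with the padding mask assumed to be 1 for valid tokens and 0 for padding) of the mask-to-actual_seqlen conversion; the real conversion is done by a CUDA kernel and its exact layout conventions may differ.

```python
import numpy as np

def mask_to_actual_seqlen(mask):
    # mask: int array of shape [batch, 1, seq_len, seq_len], 1 = valid, 0 = padded.
    # For a padding mask, the number of valid key positions in the first row of
    # each [seq_len, seq_len] slice equals the actual sequence length.
    return mask[:, 0, 0, :].sum(axis=-1).astype(np.int32)  # shape: [batch]

mask = np.ones([2, 1, 8, 8], dtype=np.int32)
mask[1, :, :, 5:] = 0                 # second sample has only 5 valid tokens
print(mask_to_actual_seqlen(mask))    # [8 5]
```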

Implementation details:

fuse_dot_product_attention_pass.h
fuse_dot_product_attention_pass.cc

  1. Three OpCaches
    1. MaskCache: stores the mask node. For example, BERT-large has 24 layers, and for the same input batch all 24 MHA fwd/bwd computations use the same mask, so the mask has to be cached and used as the fwd/bwd input of the fused op after the subgraph is replaced.
    2. OutputCache: the cuDNN MHA bwd computation needs variables from fwd, namely SoftmaxStats, the fwd output, and the RNG state. They are stored in the cache during fwd and retrieved in bwd.
    3. QKVCache: likewise, q, k, v are stored in the cache during fwd and retrieved in the corresponding bwd to compute dq, dk, dv.
  2. nodes_to_remove
    As described above, taking BERT-large as an example, the subgraphs in all 24 fwd/bwd passes must be replaced with fused ops. Because pointers to some nodes are kept in the caches and used again during bwd, the replaced nodes cannot be deleted right after each subgraph replacement; they are deleted only after all subgraphs have been replaced. A small conceptual sketch of this caching scheme follows.
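
The sketch below is a hypothetical Python analogue of that scheme, not the actual C++ pass (the real caches live in fuse_dot_product_attention_pass.cc); the key and field names are made up for illustration. It shows the two ideas: the fwd-side rewrite records what the bwd-side rewrite will need, and node deletion is deferred until every subgraph has been rewritten.

```python
class OutputCache:
    """Hypothetical analogue of the pass's OutputCache: fwd stores what bwd needs."""

    def __init__(self):
        self._store = {}

    def insert(self, key, softmax_stats, fwd_output, rng_state):
        self._store[key] = (softmax_stats, fwd_output, rng_state)

    def get(self, key):
        # The bwd rewrite looks up what the matching fwd rewrite recorded.
        return self._store[key]


nodes_to_remove = set()
output_cache = OutputCache()

def rewrite_fwd_subgraph(key, matched_nodes):
    # Record the fwd-side values the bwd rewrite will need, and mark the old
    # nodes for removal without deleting them yet: the caches still reference
    # some of them until the matching bwd subgraph has been rewritten.
    output_cache.insert(key, matched_nodes["softmax_stats"],
                        matched_nodes["out"], matched_nodes["rng_state"])
    nodes_to_remove.update(matched_nodes.values())

# After all fwd and bwd subgraphs have been rewritten:
#     graph.remove_nodes(nodes_to_remove)
```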

@Wong4j Wong4j force-pushed the jaywan/cudnn_fa_integration branch 3 times, most recently from 1b3743b to 2595619 on November 13, 2023 13:03
@zhiqiu (Contributor) commented Nov 14, 2023

We are now developing a new IR, called PIR. Would you mind transferring the fusion pass to PIR? You can refer to fused_gemm_epilogue_pass.

paddle-ci-bot commented Nov 23, 2023

Sorry to inform you that 2595619's CIs have passed for more than 7 days. To prevent PR conflicts, you need to re-run all CIs manually.

@Wong4j (Collaborator, Author) commented Nov 26, 2023

> We are now developing a new IR, called PIR. Would you mind transferring the fusion pass to PIR? You can refer to fused_gemm_epilogue_pass.

@zhiqiu A question: following other PRs, I added the fusion pass under PIR, but the unit test did not manage to replace the subgraph with the fused op, and there is no error message either. Are there any debugging tips? With the old IR pass I can turn on GLOG_v to see the details of each subgraph-matching step, but with the new IR I don't know how to debug.

@zyfncg (Contributor) commented Nov 27, 2023

> @zhiqiu A question: following other PRs, I added the fusion pass under PIR, but the unit test did not manage to replace the subgraph with the fused op, and there is no error message either. Are there any debugging tips? With the old IR pass I can turn on GLOG_v to see the details of each subgraph-matching step, but with the new IR I don't know how to debug.

The matching logic of the new fusion pass is in paddle/fluid/pir/drr/drr_rewrite_pattern.cc. GLOG_v will print some information; if that is not enough, you can add GLOG statements manually to locate the problem.

@Wong4j (Collaborator, Author) commented Nov 30, 2023

Blocked by Issue#59467

@Wong4j Wong4j force-pushed the jaywan/cudnn_fa_integration branch 2 times, most recently from bf16deb to 92c5fdf on December 4, 2023 07:31
@Wong4j (Collaborator, Author) commented Dec 4, 2023

> We are now developing a new IR, called PIR. Would you mind transferring the fusion pass to PIR? You can refer to fused_gemm_epilogue_pass.

@zyfncg I have added this pass and a unit test under PIR, but it seems the new IR pass has to be applied via a PassManager run. In that case, can it still be enabled from the build strategy? One of the features I designed for is that adding a single line, build_strategy.fuse_dot_product_attention=True, automatically fuses the attention in the model.

@Wong4j (Collaborator, Author) commented Dec 5, 2023

@zyfncg This PR currently contains implementations for both the old and the new IR, because our JoC BERT code is still written with the old static-graph API and depends on the old IR. So I think this PR can be reviewed and merged first; I will then look into the engine, update our BERT code, and submit a new PR to migrate fully to PIR.

@Wong4j (Collaborator, Author) commented Dec 5, 2023

@zyfncg All CIs have passed. The remaining PR-CI-APPROVAL and PR-CI-Static-Check need extra approvals, and the failure in PR-CI-Coverage is unrelated to my changes; it is another test that timed out.

@onecatcn onecatcn requested a review from jeff41404 December 5, 2023 01:54
@@ -3245,4 +3245,48 @@ void SkipLayerNormInferMeta(const MetaTensor& x,
out->set_dtype(x.dtype());
}

void FusedDotProductAttentionInferMeta(const MetaTensor& q,
Contributor:

InferMeta functions should be placed in alphabetical order.

Collaborator Author (Wong4j):

Done

@Wong4j Wong4j force-pushed the jaywan/cudnn_fa_integration branch from 92c5fdf to c500ccf on December 5, 2023 07:31
zyfncg previously approved these changes Dec 6, 2023

@zyfncg left a comment: LGTM

@Wong4j Wong4j force-pushed the jaywan/cudnn_fa_integration branch from 4c1bd02 to 02da812 on December 6, 2023 08:35
@onecatcn onecatcn requested a review from zyfncg December 7, 2023 01:37
@Wong4j (Collaborator, Author) commented Dec 7, 2023

  • @onecatcn PR-CI-APPROVAL and PR-CI-Static-Check need the following approvals:
  1. one RD (XiaoguangHu01, jeff41404, lanxianghit or qingqing01) approval for API change,
    and one TPM approval for API change: jzhang533/ZhangJun, sunzhongkai588/SunZhongKai, Ligoml/LiMengLiu for general APIs.
  2. one TPM approval for API documents change: jzhang533/ZhangJun, sunzhongkai588/SunZhongKai, Ligoml/LiMengLiu for general API docs.
  3. print or std::cout is not recommended for direct use; please use logging or VLOG. If it is necessary, please contact tianshuo78520a (recommended) or zhangbo9674 to review and approve.
  4. must have raindrops2sea or XiaoguangHu01 approval for changing 20+ files or adding more than 1000 lines of content.
  5. one RD (XiaoguangHu01, chenwhql, zhiqiu, Xreki, luotao1, qili93, Aurelius84) approval for the usage of const_cast.
  6. Unittest is not allowed to be disabled. You must have one RD (kolinwei (recommended), wanghuancoder, luotao1, QingshuChen, qili93, ZzSean or Aurelius84) approval for the usage of @unittest.skip or @unittest.skipIf.
  • PR-CI-Coverage: my new unit tests run on an A100 + cuDNN 8.9.6 machine; they all pass in PR-CI-GpuPS and they cover the code that Coverage does not reach, so this case should be eligible for an exemption? @tianshuo78520a

zyfncg previously approved these changes Dec 7, 2023
from paddle.framework import LayerHelper, in_dynamic_mode


def fused_dot_product_attention(
Contributor:

From the user's perspective, is this another implementation of paddle.nn.functional.scaled_dot_product_attention?

Collaborator Author (Wong4j):

Yes, from the user's point of view the interface and functionality of the two are basically the same. The difference is that this one uses the cuDNN flash attention underneath, which has better support and performance on Ampere and Hopper GPUs.

Contributor:

Should we expose only paddle.nn.functional.scaled_dot_product_attention as the single API for users? As I recall, the name, parameters, and semantics of that API were settled after internal discussion.

If the technical implementation makes a single API impossible, then I suggest the signature of this new API be aligned as much as possible with the existing one, to make it easier for users to understand and use.

Any thoughts? @jeff41404 @sneaxiy @zhiqiu @liuzhenhai93

Collaborator Author (Wong4j):

There are some API differences, mainly: 1. there is an additional float scaling factor input; 2. the mask has to be bool or int. The other parameters are identical.
In addition, cuDNN flash attention only supports Ampere and Hopper GPUs, which may differ from Paddle's flash attention.
So it does not look like this can be unified directly into the nn.functional.scaled_dot_product_attention API. Any suggestions?
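
To make the comparison concrete, here is a hedged sketch of the two calls on an Ampere/Hopper GPU. The call to paddle.nn.functional.scaled_dot_product_attention uses its documented parameters; the parameter names for the incubate op (mask, scaling_factor, dropout_probability, is_causal_masking) are assumptions taken from this thread and the PR description, not a verified signature.

```python
import paddle
import paddle.incubate.nn.functional as incubate_F

b, s, h, d = 2, 128, 12, 64
q = paddle.randn([b, s, h, d]).astype('float16')
k = paddle.randn([b, s, h, d]).astype('float16')
v = paddle.randn([b, s, h, d]).astype('float16')

# Existing API: additive float mask, no explicit scaling factor parameter.
out = paddle.nn.functional.scaled_dot_product_attention(
    q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)

# New incubate op (assumed parameter names): mask must be bool/int, and an
# explicit float scaling factor is exposed; requires Ampere/Hopper + cuDNN >= 8.9.6.
mask = paddle.ones([b, 1, s, s], dtype='int32')
out = incubate_F.fused_dot_product_attention(
    q, k, v, mask,
    scaling_factor=1.0 / d ** 0.5,
    dropout_probability=0.0,
    is_causal_masking=False)
```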

Collaborator Author (Wong4j):

This PR mainly adds the C++ op and the IR pass; adding the nn.functional.fused_dot_product_attention API is not essential, and leaving it out would not affect the other changes. So if exposing this API needs more discussion, I think it can be done in a separate PR.
@jzhang533 should this change be removed from this PR for now?

Contributor:

I don't think it needs to be removed. In the current design of this PR the API sits under paddle.incubate, and it can still be adjusted later when it is formalized under paddle.nn.functional.

I just wanted to start the discussion on this question here. I will give an approve in a moment so this PR can pass CI.

We can keep discussing this~

Contributor:

@jzhang533's comment is a good one.

> There are some API differences, mainly: 1. there is an additional float scaling factor input; 2. the mask has to be bool or int. The other parameters are identical.

If these are the only two differences, I don't think they justify a separate API. A possible solution would be, for example, to add a scaling factor with a default value to scaled_dot_product_attention, and to add a use_cudnn=False parameter; when it is set to True, check the environment and the mask and tell the user if the requirements are not met.
So my suggestion is: if the parameters and usage of this API are not yet settled, move the code into the incubate directory for now and add a warning to the documentation saying it may change later; if they are settled, it is best to unify everything into scaled_dot_product_attention.

Collaborator Author (Wong4j):

I agree with @jeff41404's suggestion. The code is already in the incubate directory, and I have added the warning. If it later needs to be unified into scaled_dot_product_attention, I will submit a new PR.

Contributor:

"A possible solution would be, for example, to add a scaling factor with a default value to scaled_dot_product_attention": this can be considered.

"Add a use_cudnn=False parameter; when it is set to True": this is not recommended. In general, APIs and OPs should not contain device-specific parameters.

jzhang533 previously approved these changes Dec 7, 2023
@Wong4j Wong4j dismissed stale reviews from jzhang533 and zyfncg via 91d70ae on December 7, 2023 09:27
@Wong4j Wong4j force-pushed the jaywan/cudnn_fa_integration branch from 91d70ae to 1fa4ea6 on December 8, 2023 03:03
@Wong4j Wong4j force-pushed the jaywan/cudnn_fa_integration branch from 1fa4ea6 to bf802e0 on December 8, 2023 05:08
@onecatcn (Contributor) commented Dec 9, 2023

2023-12-08 13:12:17 0. You must have raindrops2sea or XiaoguangHu01 approval for change 20+ files or add than 1000+ lines of content.
2023-12-08 13:12:17 1. You must have one RD (XiaoguangHu01,chenwhql,zhiqiu,Xreki,luotao1,qili93,Aurelius84) approval for the usage of const_cast.
2023-12-08 13:12:17 2. Unittest is not allowed to be disabled.
2023-12-08 13:12:17 You must have one RD (kolinwei(Recommend), wanghuancoder, luotao1, QingshuChen, qili93 or ZzSean or Aurelius84) approval for the usage of @unittest.skip or @unittest.skipIf.

@onecatcn onecatcn requested review from chenwhql and wanghuancoder and removed request for tianshuo78520a December 9, 2023 07:04
@Wong4j (Collaborator, Author) commented Dec 12, 2023

Could you please take a look? @XiaoguangHu01

@XiaoguangHu01 (Contributor) left a comment: LGTM

@zyfncg zyfncg merged commit fd3e9ef into PaddlePaddle:develop Dec 12, 2023