-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【New IR】 backward gradients accumulate test and pulish append_backward_ops func for op_pattern #56265
【New IR】 backward gradients accumulate test and pulish append_backward_ops func for op_pattern #56265
Conversation
…e/Paddle into vjp_for_new_ir
… vjp_for_new_ir
你的PR提交成功,感谢你对开源项目的贡献! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@@ -313,30 +311,57 @@ def append_backward_ops( | |||
v2_g = call_vjp(op3, [v3_g], [v2_stopgradient]) | |||
|
|||
|
|||
special pattern 1: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
像这种case描述的话,可以整理成 design doc,然后这里援引下就可以,可以参考这种:https://github.com/PaddlePaddle/PaddleSOT/pull/303/files#diff-826111385ea376dc22312877b7c73b22c107d8da6a0069617a396f5ccda02f3b
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
后续统一写样例文档引用
zero_flag = [False] * op.num_results() | ||
for i, value in enumerate(op.results()): | ||
if ( | ||
value not in state.value_to_valuegrad |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
State 除了只是数据封装外,是否也可以承担一些必要的逻辑接口?比如类似这里的 if 判断可以直接调用对应的接口,是个建议,也取决于类似的判断是否在多处出现。
如果State 不便于承担接口,也可以交给某个 Helper类,来与 State 交互
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
待功能完备后进行接口统一;报错提示等完善
return zero_flag, output_grad | ||
|
||
def make_input_stopgradient(combine_op, op): | ||
input_grad_stopgradient_list = [] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
变量命令是不是不必以_list
为结尾? 容器类加 s 即可? no_grad_set 应该比较特殊,set 本身也可以理解为 集合的意思
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
下个pr 修改
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
Others
Description
增加梯度聚合测试样例 full -> 2tanh -> add -> mean / full -> mean_grad -> add_grad -> 2tanh_grad -> combine -> add_n
整理反向添加op代码为多个子函数,以支持vector 输入,vector输出op 的 op pattern 添加。
增加相应样例解释,待补齐算子验证
pcard-67164