-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support MaskedSelectGrad op with Kernel Primitive API #40617
Support MaskedSelectGrad op with Kernel Primitive API #40617
Conversation
Thanks for your contribution! |
71a52f0
to
c8967e6
Compare
typename Functor, | ||
int VecSize, | ||
int IsBoundary, | ||
int IsMaskData> | ||
int MaskData> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
建议在下个PR注释里说明下 maskdata =0,1,2 对应的情况
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的
@@ -123,6 +123,15 @@ __device__ __forceinline__ void WriteData(T* dst, | |||
dst[i] = src[i]; | |||
} | |||
} | |||
|
|||
template <typename T> |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个和之前的 readdata 可以复用吗 ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
不可以 这个是线程级别的API
#include "paddle/phi/kernels/masked_select_grad_kernel.h" | ||
|
||
#include "paddle/phi/backends/cpu/cpu_context.h" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个头文件应该没用到,可以删掉试试
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的下个PR再修改
|
||
SelectGradWithPrefixMask<T><<<grid, threads, 0, stream>>>( | ||
mask_prefix_sum_data, mask_data, input_data, out_data, mask_size); | ||
auto* out_data = x_grad->mutable_data<T>(dev_ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
kernel里分配内存调用新接口:dev_ctx.template Alloc(x_grad)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
好的 下个PR再修改
PR types
Others
PR changes
Others
Describe
Support MaskedSelectGrad op with Kernel Primitive API