Skip to content

Commit

Permalink
remove bfloat16 when defined paddle_with_hip
Browse files Browse the repository at this point in the history
  • Loading branch information
MingMingShangTian committed Nov 29, 2021
1 parent 0c3a892 commit 77a8e8b
Showing 1 changed file with 26 additions and 17 deletions.
43 changes: 26 additions & 17 deletions paddle/pten/kernels/cuda/manipulation.cu
Original file line number Diff line number Diff line change
Expand Up @@ -157,23 +157,32 @@ PT_REGISTER_KERNEL("flatten_contiguous_range.mid",
int8_t,
int,
int64_t) {}
PT_REGISTER_KERNEL("cast",
CUDA,
ANY,
pten::Cast,
float,
double,
int,
int64_t,
int16_t,
bool,
uint8_t,
paddle::platform::float16,
paddle::platform::bfloat16,
paddle::platform::complex<float>,
paddle::platform::complex<double>) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}

#define PTEN_REGISTER_CAST_CUDA_BASE_TYPE(op_name, ...) \
PT_REGISTER_KERNEL("cast", \
CUDA, \
ANY, \
pten::Cast, \
float, \
double, \
int, \
int64_t, \
int16_t, \
bool, \
uint8_t, \
paddle::platform::float16, \
paddle::platform::complex<float>, \
paddle::platform::complex<double>, \
##__VA_ARGS__) { \
kernel->OutputAt(0).SetDataType( \
paddle::experimental::DataType::UNDEFINED); \
}

#if !defined(PADDLE_WITH_HIP)
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast, paddle::platform::bfloat16)
#else
PTEN_REGISTER_CAST_CUDA_BASE_TYPE(cast)
#endif

PT_REGISTER_KERNEL_WITH_NO_TYPE("reshape2",
CUDA,
Expand Down

1 comment on commit 77a8e8b

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.