Skip to content

Commit

Permalink
refine test=kunlun
Browse files Browse the repository at this point in the history
  • Loading branch information
wangchaochaohu committed Oct 14, 2020
1 parent 377fadc commit 17225e0
Showing 1 changed file with 4 additions and 11 deletions.
15 changes: 4 additions & 11 deletions paddle/fluid/operators/utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,18 @@ namespace operators {
template <typename T = int32_t>
inline std::vector<T> GetDataFromTensor(const framework::Tensor* x) {
std::vector<T> vec_new_data;
auto tmp_place = x->place();
if (x->type() == framework::proto::VarType::INT32) {
auto* data = x->data<int>();
framework::Tensor cpu_attr_tensor;
if (platform::is_gpu_place(tmp_place) ||
platform::is_xpu_place(tmp_place)) {
if (!platform::is_cpu_place(x->place())) {
TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
data = cpu_attr_tensor.data<int>();
}

vec_new_data = std::vector<T>(data, data + x->numel());
} else if (x->type() == framework::proto::VarType::INT64) {
auto* data = x->data<int64_t>();
framework::Tensor cpu_attr_tensor;
if (platform::is_cpu_place(tmp_place) ||
platform::is_xpu_place(tmp_place)) {
if (!platform::is_cpu_place(x->place())) {
TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
data = cpu_attr_tensor.data<int64_t>();
}
Expand All @@ -58,7 +54,6 @@ inline std::vector<T> GetDataFromTensorList(
std::vector<T> vec_new_data;
for (size_t i = 0; i < list_tensor.size(); ++i) {
auto tensor = list_tensor[i];
auto tmp_place = tensor->place();
PADDLE_ENFORCE_EQ(tensor->dims(), framework::make_ddim({1}),
platform::errors::InvalidArgument(
"The shape of Tensor in list must be [1]. "
Expand All @@ -67,17 +62,15 @@ inline std::vector<T> GetDataFromTensorList(
tensor->dims()));

if (tensor->type() == framework::proto::VarType::INT32) {
if (platform::is_xpu_place(tmp_place) ||
platform::is_gpu_place(tmp_place)) {
if (!platform::is_cpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_data.push_back(static_cast<T>(*temp.data<int>()));
} else {
vec_new_data.push_back(static_cast<T>(*tensor->data<int>()));
}
} else if (tensor->type() == framework::proto::VarType::INT64) {
if (platform::is_xpu_place(tmp_place) ||
platform::is_gpu_place(tmp_place)) {
if (!platform::is_cpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
// NOTE: Converting int64 to int32 may cause data overflow.
Expand Down

1 comment on commit 17225e0

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.