resolve conflicts
MingMingShangTian committed Oct 14, 2020
2 parents 2119342 + ae6ad23 commit 6e50e93
Showing 10 changed files with 338 additions and 18 deletions.
12 changes: 11 additions & 1 deletion paddle/fluid/operators/fill_constant_op.h
@@ -66,7 +66,9 @@ class FillConstantKernel : public framework::OpKernel<T> {
              value_tensor->numel()));
      const T *tensor_data = value_tensor->data<T>();
      framework::Tensor cpu_tensor;
      if (platform::is_gpu_place(value_tensor->place())) {
      auto tmp_place = value_tensor->place();
      if (platform::is_gpu_place(tmp_place) ||
          platform::is_xpu_place(tmp_place)) {
        TensorCopySync(*value_tensor, platform::CPUPlace(), &cpu_tensor);
        tensor_data = cpu_tensor.data<T>();
      }
@@ -102,6 +104,14 @@ class FillConstantKernel : public framework::OpKernel<T> {
      functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
              tensor, static_cast<T>(value));
    }
#endif
#ifdef PADDLE_WITH_XPU
    if (!cpu_place) {
      tensor->mutable_data(ctx.GetPlace(), data_type);
      math::SetConstant<platform::XPUDeviceContext, T> functor;
      functor(reinterpret_cast<const platform::XPUDeviceContext &>(dev_ctx),
              tensor, static_cast<T>(value));
    }
#endif
  }
};
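The new XPU branch mirrors the CUDA one: allocate the output on the kernel's place, then let math::SetConstant specialized on XPUDeviceContext do the fill. Below is a minimal sketch of that call pattern, assuming a build with PADDLE_WITH_XPU and an XPUDeviceContext already obtained from the device context pool; the helper name FillOnXPU is illustrative, not part of this commit.

#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/device_context.h"

#ifdef PADDLE_WITH_XPU
// Sketch only: fill a float tensor with a constant on an XPU device.
void FillOnXPU(const paddle::platform::XPUDeviceContext& dev_ctx,
               paddle::framework::Tensor* tensor, float value) {
  tensor->mutable_data<float>(dev_ctx.GetPlace());  // allocate on the XPU
  paddle::operators::math::SetConstant<paddle::platform::XPUDeviceContext,
                                       float>
      set_value;
  set_value(dev_ctx, tensor, value);  // routed to TensorSetConstantXPU below
}
#endif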
23 changes: 23 additions & 0 deletions paddle/fluid/operators/fill_constant_op_xpu.cc
@@ -0,0 +1,23 @@
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fill_constant_op.h"

namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(fill_constant, ops::FillConstantKernel<float>,
                       ops::FillConstantKernel<int64_t>,
                       ops::FillConstantKernel<double>,
                       ops::FillConstantKernel<bool>,
                       ops::FillConstantKernel<int>);
#endif
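Registering these instantiations with REGISTER_OP_XPU_KERNEL is what lets the framework select this kernel whenever a fill_constant op runs on an XPU place. A hedged sketch of driving the op directly from C++ follows; the attribute names and types ("shape" as vector<int64_t>, "dtype" as int, "value" as float) match the fill_constant op definition as I understand it, so treat the details as assumptions rather than part of the patch.

#include <vector>
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/scope.h"

// Sketch only: run fill_constant on XPU device 0 through the operator registry.
void RunFillConstantOnXPU() {
  paddle::framework::Scope scope;
  scope.Var("Out")->GetMutable<paddle::framework::LoDTensor>();

  paddle::framework::AttributeMap attrs;
  attrs["shape"] = std::vector<int64_t>{2, 3};
  attrs["dtype"] = static_cast<int>(paddle::framework::proto::VarType::FP32);
  attrs["value"] = 1.5f;

  auto op = paddle::framework::OpRegistry::CreateOp(
      "fill_constant", /*inputs=*/{}, /*outputs=*/{{"Out", {"Out"}}}, attrs);
  op->Run(scope, paddle::platform::XPUPlace(0));  // dispatches to the kernel above
}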
10 changes: 10 additions & 0 deletions paddle/fluid/operators/math/math_function.cc
@@ -22,6 +22,7 @@ limitations under the License. */
#include <cblas.h>
#endif

#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
@@ -44,6 +45,15 @@ template struct SetConstant<platform::CPUDeviceContext, int64_t>;
template struct SetConstant<platform::CPUDeviceContext, bool>;
template struct SetConstant<platform::CPUDeviceContext, uint8_t>;

#ifdef PADDLE_WITH_XPU
template struct SetConstant<platform::XPUDeviceContext, platform::float16>;
template struct SetConstant<platform::XPUDeviceContext, float>;
template struct SetConstant<platform::XPUDeviceContext, double>;
template struct SetConstant<platform::XPUDeviceContext, int>;
template struct SetConstant<platform::XPUDeviceContext, int64_t>;
template struct SetConstant<platform::XPUDeviceContext, bool>;
#endif

#define DEFINE_CPU_TRANS(RANK) \
template struct Transpose<platform::CPUDeviceContext, platform::float16, \
RANK>; \
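These lines are explicit template instantiations: SetConstant's definition lives in math_function_impl.h, which only this .cc file includes, so the XPUDeviceContext specializations have to be emitted here, one per supported dtype, for other translation units to link against. A self-contained reminder of that pattern, using placeholder names rather than Paddle's:

#include <iostream>

// Hedged illustration of explicit instantiation. In Paddle the template body
// would live in an *_impl.h header included only by the .cc file.
template <typename T>
struct Fill {
  void operator()(T v) { std::cout << "fill with " << v << "\n"; }
};

// Emit the definitions in this translation unit so callers that only see the
// declaration can still link: exactly what math_function.cc does for
// SetConstant<platform::XPUDeviceContext, T>.
template struct Fill<float>;
template struct Fill<double>;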
28 changes: 28 additions & 0 deletions paddle/fluid/operators/math/math_function.h
@@ -14,6 +14,7 @@ limitations under the License. */

#pragma once
#include <cmath>
#include <memory>
#include <vector>

#include "paddle/fluid/framework/eigen.h"
@@ -84,6 +85,33 @@ struct RowwiseMean {
framework::Tensor* vec);
};

#ifdef PADDLE_WITH_XPU
template <typename U>
struct TensorSetConstantXPU {
  TensorSetConstantXPU(framework::Tensor* tensor, U value)
      : tensor_(tensor), value_(value) {}
  template <typename T>
  void apply() const {
    int dev_id = -1;
    xpu_current_device(&dev_id);
    if (dev_id >= 64) {
      // if dev_id >= 64, the device is a simulator device, -64 to get real
      // dev_id
      dev_id -= 64;
    }
    auto xpu = platform::XPUPlace(dev_id);
    auto* begin = tensor_->mutable_data<T>(xpu);
    int numel = tensor_->numel();
    std::unique_ptr<T[]> data_cpu(new T[numel]);
    std::fill(data_cpu.get(), data_cpu.get() + numel, static_cast<T>(value_));
    memory::Copy(xpu, begin, platform::CPUPlace(),
                 static_cast<void*>(data_cpu.get()), numel * sizeof(T));
  }
  framework::Tensor* tensor_;
  U value_;
};
#endif

} // namespace math
} // namespace operators
} // namespace paddle
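TensorSetConstantXPU does the fill host-side: it resolves the real device id (simulator ids are offset by 64), allocates the output on that XPUPlace, builds a CPU buffer holding numel copies of the value, and pushes it to the device with a single memory::Copy. Its apply<T>() is meant to be driven by framework::VisitDataType, which picks T from the tensor's runtime dtype. A minimal sketch of that call, assuming PADDLE_WITH_XPU and an attached XPU device 0; the helper name is illustrative.

#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"

#ifdef PADDLE_WITH_XPU
// Sketch only: fill `t` with 0.25f whatever its registered dtype is.
void FillWithQuarter(paddle::framework::Tensor* t) {
  t->Resize(paddle::framework::make_ddim({4, 4}));
  t->mutable_data<float>(paddle::platform::XPUPlace(0));  // fixes t->type()
  paddle::framework::VisitDataType(
      t->type(),
      paddle::operators::math::TensorSetConstantXPU<float>(t, 0.25f));
}
#endif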
15 changes: 13 additions & 2 deletions paddle/fluid/operators/math/math_function_impl.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <memory>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/operators/math/math_function.h"
@@ -27,8 +28,18 @@ template <typename DeviceContext, typename T>
void SetConstant<DeviceContext, T>::operator()(const DeviceContext& context,
                                               framework::Tensor* tensor,
                                               T num) {
  auto t = framework::EigenVector<T>::Flatten(*tensor);
  t.device(*context.eigen_device()) = t.constant(static_cast<T>(num));
  bool xpu_place = false;
#ifdef PADDLE_WITH_XPU
  if (context.GetPlace() == platform::XPUPlace()) {
    xpu_place = true;
    framework::VisitDataType(tensor->type(),
                             TensorSetConstantXPU<T>(tensor, num));
  }
#endif
  if (!xpu_place) {
    auto t = framework::EigenVector<T>::Flatten(*tensor);
    t.device(*context.eigen_device()) = t.constant(static_cast<T>(num));
  }
}

template <typename DeviceContext, typename T, int Rank>
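Callers of SetConstant do not change: the same functor call now performs a runtime place check and either routes through VisitDataType and TensorSetConstantXPU or falls back to the original Eigen expression. A device-agnostic helper like the hedged sketch below therefore keeps working for both CPUDeviceContext and XPUDeviceContext; the helper name is illustrative.

#include "paddle/fluid/operators/math/math_function.h"

// Sketch only: one code path, any device context. For the Eigen branch the
// tensor must already own memory on the context's place; the XPU branch
// allocates it inside TensorSetConstantXPU.
template <typename DeviceContext>
void ZeroOut(const DeviceContext& ctx, paddle::framework::Tensor* t) {
  paddle::operators::math::SetConstant<DeviceContext, float> set_zero;
  set_zero(ctx, t, 0.0f);  // XPU contexts take the host-fill path added above
}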
8 changes: 4 additions & 4 deletions paddle/fluid/operators/utils.h
@@ -26,15 +26,15 @@ inline std::vector<T> GetDataFromTensor(const framework::Tensor* x) {
if (x->type() == framework::proto::VarType::INT32) {
auto* data = x->data<int>();
framework::Tensor cpu_attr_tensor;
if (platform::is_gpu_place(x->place())) {
if (!platform::is_cpu_place(x->place())) {
TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
data = cpu_attr_tensor.data<int>();
}
vec_new_data = std::vector<T>(data, data + x->numel());
} else if (x->type() == framework::proto::VarType::INT64) {
auto* data = x->data<int64_t>();
framework::Tensor cpu_attr_tensor;
if (platform::is_gpu_place(x->place())) {
if (!platform::is_cpu_place(x->place())) {
TensorCopySync(*x, platform::CPUPlace(), &cpu_attr_tensor);
data = cpu_attr_tensor.data<int64_t>();
}
@@ -62,15 +62,15 @@ inline std::vector<T> GetDataFromTensorList(
tensor->dims()));

if (tensor->type() == framework::proto::VarType::INT32) {
if (platform::is_gpu_place(tensor->place())) {
if (!platform::is_cpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
vec_new_data.push_back(static_cast<T>(*temp.data<int>()));
} else {
vec_new_data.push_back(static_cast<T>(*tensor->data<int>()));
}
} else if (tensor->type() == framework::proto::VarType::INT64) {
if (platform::is_gpu_place(tensor->place())) {
if (!platform::is_cpu_place(tensor->place())) {
framework::Tensor temp;
TensorCopySync(*tensor, platform::CPUPlace(), &temp);
// NOTE: Converting int64 to int32 may cause data overflow.
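The utils.h change widens the copy-back condition from "is a GPU tensor" to "is not a CPU tensor", so shape and attribute tensors living on an XPU are also staged through host memory before their raw pointer is read. Distilled into one standalone helper (a sketch with an illustrative name, not code from this commit):

#include <vector>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/place.h"

// Sketch only: never dereference device memory directly; stage any non-CPU
// tensor (GPU, XPU, ...) through a host copy first.
inline std::vector<int> ToCpuVector(const paddle::framework::Tensor& x) {
  const int* data = x.data<int>();
  paddle::framework::Tensor cpu_copy;
  if (!paddle::platform::is_cpu_place(x.place())) {
    paddle::framework::TensorCopySync(x, paddle::platform::CPUPlace(),
                                      &cpu_copy);
    data = cpu_copy.data<int>();
  }
  return std::vector<int>(data, data + x.numel());
}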
7 changes: 4 additions & 3 deletions python/paddle/fluid/layers/nn.py
@@ -7333,10 +7333,10 @@ def image_resize(input,
.. code-block:: python

#declarative mode
import paddle
import paddle.fluid as fluid
import numpy as np
import paddle
paddle.enable_static()
paddle.enable_static()
input = fluid.data(name="input", shape=[None,3,6,10])

#1
@@ -7951,8 +7951,8 @@ def resize_trilinear(input,

#declarative mode
import paddle.fluid as fluid
import numpy as np
import paddle
import numpy as np
paddle.enable_static()
input = fluid.data(name="input", shape=[None,3,6,8,10])

@@ -8110,6 +8110,7 @@ def resize_nearest(input,
import numpy as np
import paddle
paddle.enable_static()

input = fluid.data(name="input", shape=[None,3,6,10])

#1
6 changes: 0 additions & 6 deletions python/paddle/fluid/tests/unittests/CMakeLists.txt
@@ -336,11 +336,6 @@ list(REMOVE_ITEM TEST_OPS test_conv3d_transpose_op)
list(REMOVE_ITEM TEST_OPS test_imperative_data_loader_exception)
list(REMOVE_ITEM TEST_OPS test_sampling_id_op)
list(REMOVE_ITEM TEST_OPS test_parallel_executor_fetch_isolated_var)
if (APPLE)
list(REMOVE_ITEM TEST_OPS test_imperative_framework)
list(REMOVE_ITEM TEST_OPS test_learning_rate_scheduler)
list(REMOVE_ITEM TEST_OPS test_var_base)
endif()

if (APPLE OR WIN32)
list(REMOVE_ITEM TEST_OPS test_dataset)
@@ -606,4 +601,3 @@ if(NOT WIN32 AND NOT APPLE)
set_tests_properties(test_pool3d_op PROPERTIES TIMEOUT 150)
set_tests_properties(test_regularizer PROPERTIES TIMEOUT 150)
endif()
