elementwise op support fp16 (#45496)
xiaohemaikoo authored Sep 6, 2022
1 parent 72b5b5b commit f6d9ec2
Showing 11 changed files with 233 additions and 19 deletions.
9 changes: 6 additions & 3 deletions paddle/phi/kernels/elementwise_kernel.cc
@@ -237,7 +237,8 @@ PD_REGISTER_KERNEL(remainder,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
floor_divide, KPS, ALL_LAYOUT, phi::FloorDivideKernel, int, int64_t) {}
PD_REGISTER_KERNEL(elementwise_heaviside,
@@ -247,15 +248,17 @@ PD_REGISTER_KERNEL(elementwise_heaviside,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(elementwise_pow,
KPS,
ALL_LAYOUT,
phi::ElementwisePowKernel,
float,
double,
int,
int64_t) {}
int64_t,
phi::dtype::float16) {}

#endif

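A minimal usage sketch of what the float16 registrations above enable (not part of this diff; assumes a CUDA build that includes this commit, with paddle.remainder, paddle.pow, and paddle.heaviside as the public APIs backed by these kernels):

import numpy as np
import paddle

paddle.set_device("gpu")
x = paddle.to_tensor(np.random.uniform(1, 2, [4, 4]), dtype="float16")
y = paddle.to_tensor(np.random.uniform(1, 2, [4, 4]), dtype="float16")

# With the float16 entries registered above, no cast to float32 is needed.
print(paddle.remainder(x, y).dtype)  # paddle.float16
print(paddle.pow(x, y).dtype)        # paddle.float16
print(paddle.heaviside(x, y).dtype)  # paddle.float16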
26 changes: 25 additions & 1 deletion paddle/phi/kernels/funcs/elementwise_functor.h
@@ -524,6 +524,19 @@ struct RemainderFunctor<
}
};

template <>
struct RemainderFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float b_float = static_cast<float>(b);
float res = fmod(static_cast<float>(a), b_float);
// According to #PR26732: in dividend % divisor,
// the remainder shall have the same sign as the divisor.
if ((res != 0.0f) && ((res < 0.0f) != (b_float < 0.0f))) res += b_float;
return static_cast<dtype::float16>(res);
}
};

template <typename T, typename Enable = void>
struct InverseRemainderFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
@@ -547,7 +560,7 @@ struct InverseRemainderFunctor<
template <typename T>
struct ElementwiseHeavisideFunctor {
inline HOSTDEVICE T operator()(const T a, const T b) const {
return a == static_cast<T>(0) ? b : static_cast<T>(a > 0);
return a == static_cast<T>(0) ? b : static_cast<T>(a > static_cast<T>(0));
}
};

@@ -592,5 +605,16 @@ struct ElementwisePowFunctor {
return std::pow(a, b);
}
};

template <>
struct ElementwisePowFunctor<dtype::float16> {
inline HOSTDEVICE dtype::float16 operator()(const dtype::float16 a,
const dtype::float16 b) const {
float f_a = static_cast<float>(a);
float f_b = static_cast<float>(b);
return static_cast<dtype::float16>(std::pow(f_a, f_b));
}
};

} // namespace funcs
} // namespace phi
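For reference, a small Python sketch of the sign adjustment in the float16 RemainderFunctor above (illustrative only, not the shipped kernel): fmod keeps the sign of the dividend, so the result is shifted by the divisor whenever their signs differ, which matches Python-style modulo.

import math

def fp16_style_remainder(a: float, b: float) -> float:
    # Mirrors the functor: compute fmod in float, then add the divisor
    # when the raw result and the divisor have different signs.
    res = math.fmod(a, b)
    if res != 0.0 and (res < 0.0) != (b < 0.0):
        res += b
    return res

print(math.fmod(-7.0, 3.0))             # -1.0, takes the sign of the dividend
print(fp16_style_remainder(-7.0, 3.0))  # 2.0, takes the sign of the divisor, like -7 % 3 in Python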
5 changes: 5 additions & 0 deletions paddle/phi/kernels/gpu/elementwise_grad_kernel.cu
@@ -35,6 +35,7 @@ void MaximumGradKernel(const Context& dev_ctx,
DenseTensor* dx,
DenseTensor* dy) {
const auto place = dev_ctx.GetPlace();

if (dx != nullptr && dy != nullptr) {
std::vector<const DenseTensor*> ins = {&x, &y, &dout};
GetGradXAndYOut<ElementwiseType::kTernary, T>(
@@ -96,6 +97,7 @@ PD_REGISTER_KERNEL(fmax_grad,
float,
double,
int,
phi::dtype::float16,
int64_t) {}

PD_REGISTER_KERNEL(fmin_grad,
@@ -105,6 +107,7 @@ PD_REGISTER_KERNEL(fmin_grad,
float,
double,
int,
phi::dtype::float16,
int64_t) {}

PD_REGISTER_KERNEL(maximum_grad,
@@ -136,6 +139,7 @@ PD_REGISTER_KERNEL(elementwise_heaviside_grad,
float,
double,
int,
phi::dtype::float16,
int64_t) {}

PD_REGISTER_KERNEL(elementwise_pow_grad,
@@ -145,4 +149,5 @@ PD_REGISTER_KERNEL(elementwise_pow_grad,
float,
double,
int,
phi::dtype::float16,
int64_t) {}
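A dygraph sketch of what the float16 grad registrations above enable (illustrative only; assumes a GPU build that includes this commit, and that gradients keep the forward dtype):

import numpy as np
import paddle

paddle.set_device("gpu")
x = paddle.to_tensor(np.random.uniform(1, 2, [4, 4]).astype("float16"),
                     stop_gradient=False)
y = paddle.to_tensor(np.random.uniform(1, 2, [4, 4]).astype("float16"),
                     stop_gradient=False)

out = paddle.pow(x, y)
out.backward()  # non-scalar output: Paddle seeds the initial gradient with ones
print(x.grad.dtype, y.grad.dtype)  # expected: paddle.float16 paddle.float16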
30 changes: 30 additions & 0 deletions paddle/phi/kernels/impl/elementwise_grad_kernel_impl.h
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once

#include "paddle/phi/common/complex.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/broadcast_function.h"
@@ -753,6 +754,20 @@ struct PowGradDX {
}
};

template <>
struct PowGradDX<dtype::float16> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
float tmp_y = static_cast<float>(y);
float tmp_dout = static_cast<float>(dout);
float tmp_x = static_cast<float>(x);
float result = tmp_dout * tmp_y * std::pow(tmp_x, tmp_y - 1.0f);
return static_cast<dtype::float16>(result);
}
};

template <typename T, typename Enable = void>
struct PowGradDY {
HOSTDEVICE T operator()(T x, T y, T out, T dout) const {
@@ -766,6 +781,21 @@ struct PowGradDY {
}
};

template <>
struct PowGradDY<dtype::float16, void> {
HOSTDEVICE dtype::float16 operator()(dtype::float16 x,
dtype::float16 y,
dtype::float16 out,
dtype::float16 dout) const {
float tmp_y = static_cast<float>(y);
float tmp_dout = static_cast<float>(dout);
float tmp_x = static_cast<float>(x);
float tmp_pow = std::pow(tmp_x, tmp_y);
float result = tmp_pow * tmp_dout * std::log(tmp_x);
return static_cast<dtype::float16>(result);
}
};

template <typename T, typename Context>
void ElementwisePowGradKernel(const Context& dev_ctx,
const DenseTensor& x,
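A quick NumPy sanity check of the formulas used by the float16 PowGradDX / PowGradDY specializations above (illustrative only): dx = dout * y * x**(y - 1) and dy = dout * x**y * log(x), computed in float32 and cast back to float16.

import numpy as np

x = np.random.uniform(1, 2, [3, 3]).astype(np.float16)
y = np.random.uniform(1, 2, [3, 3]).astype(np.float16)
dout = np.ones_like(x)

# Promote to float32 for the math, exactly as the specializations do.
xf, yf, df = x.astype(np.float32), y.astype(np.float32), dout.astype(np.float32)
dx = (df * yf * np.power(xf, yf - 1.0)).astype(np.float16)
dy = (df * np.power(xf, yf) * np.log(xf)).astype(np.float16)
print(dx.dtype, dy.dtype)  # float16 float16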
26 changes: 22 additions & 4 deletions paddle/phi/kernels/kps/elementwise_kernel.cu
@@ -32,6 +32,7 @@ void MaximumKernel(const Context& dev_ctx,
int axis = -1;
MaximumRawKernel<T>(dev_ctx, x, y, axis, out);
}

// Create the definition of Minimum
DEFINE_CUDA_ELEMENTWISE_OP(Minimum)
template <typename T, typename Context>
@@ -92,11 +93,25 @@ using bfloat16 = phi::dtype::bfloat16;
using complex64 = ::phi::dtype::complex<float>;
using complex128 = ::phi::dtype::complex<double>;

PD_REGISTER_KERNEL(
fmax, KPS, ALL_LAYOUT, phi::FMaxKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(fmax,
KPS,
ALL_LAYOUT,
phi::FMaxKernel,
float,
double,
int,
float16,
int64_t) {}

PD_REGISTER_KERNEL(
fmin, KPS, ALL_LAYOUT, phi::FMinKernel, float, double, int, int64_t) {}
PD_REGISTER_KERNEL(fmin,
KPS,
ALL_LAYOUT,
phi::FMinKernel,
float,
double,
int,
float16,
int64_t) {}

PD_REGISTER_KERNEL(maximum_raw,
KPS,
@@ -125,6 +140,7 @@ PD_REGISTER_KERNEL(remainder_raw,
float,
double,
int,
float16,
int64_t) {}
PD_REGISTER_KERNEL(floor_divide_raw,
KPS,
@@ -139,6 +155,7 @@ PD_REGISTER_KERNEL(elementwise_heaviside_raw,
float,
double,
int,
float16,
int64_t) {}
PD_REGISTER_KERNEL(elementwise_pow_raw,
KPS,
@@ -147,5 +164,6 @@ PD_REGISTER_KERNEL(elementwise_pow_raw,
float,
double,
int,
float16,
int64_t) {}
#endif
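A minimal sketch showing fmax/fmin accepting float16 once the entries above are registered (not part of this diff; assumes a GPU build that includes this commit):

import paddle

paddle.set_device("gpu")
a = paddle.cast(paddle.rand([2, 3]), "float16")
b = paddle.cast(paddle.rand([2, 3]), "float16")
print(paddle.fmax(a, b).dtype)  # paddle.float16
print(paddle.fmin(a, b).dtype)  # paddle.float16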
@@ -18,6 +18,13 @@
import paddle


def Heaviside_grad(x, y, dout):
tmp = np.zeros(x.shape).astype("float16")
dx = np.multiply(tmp, dout)
dy = np.multiply(np.equal(x, 0), dout).astype("float16")
return dx, dy


class TestElementwiseOp(OpTest):

def setUp(self):
@@ -152,6 +159,30 @@ def setUp(self):
self.dtype = "int32"


class TestHeavisideAPI_float16(OpTest):

def setUp(self):
self.dtype = np.float16
self.op_type = "elementwise_heaviside"
self.python_api = paddle.heaviside
self.inputs = {
'X': np.random.uniform(1, 2, [20, 5]).astype("float16"),
'Y': np.random.uniform(1, 2, [20, 5]).astype("float16")
}
self.outputs = {'Out': np.heaviside(self.inputs['X'], self.inputs['Y'])}

def test_check_output(self):
self.check_output()

def test_check_grad(self):
self.check_grad(['X', 'Y'],
'Out',
user_defined_grads=Heaviside_grad(
self.inputs['X'], self.inputs['Y'],
1 / self.inputs['X'].size),
check_eager=True)


class TestHeavisideError(unittest.TestCase):

def test_input(self):
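An aside on the Heaviside_grad helper above (illustrative NumPy check, not part of the diff): heaviside(x, y) is flat in x almost everywhere, so the x-gradient is all zeros, and the output equals y exactly where x == 0, so the y-gradient is dout masked by (x == 0).

import numpy as np

x = np.array([-1.0, 0.0, 2.0], dtype=np.float16)
y = np.array([0.5, 0.5, 0.5], dtype=np.float16)
print(np.heaviside(x, y))  # [0. 0.5 1.]

dout = np.full_like(x, 0.1)
dy = np.multiply(np.equal(x, 0), dout).astype("float16")
print(dy)  # gradient reaches y only where x == 0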
17 changes: 17 additions & 0 deletions python/paddle/fluid/tests/unittests/test_elementwise_mod_op.py
@@ -89,6 +89,23 @@ def test_check_output(self):
self.check_output(check_eager=False)


class TestElementwiseModOpFp16(TestElementwiseModOp):

def init_dtype(self):
self.dtype = np.float16

def init_input_output(self):
self.x = np.random.uniform(-1000, 1000, [10, 10]).astype(self.dtype)
self.y = np.random.uniform(-100, 100, [10, 10]).astype(self.dtype)
self.out = np.mod(self.x, self.y)

def test_check_output(self):
if self.attrs['axis'] == -1:
self.check_output(check_eager=True)
else:
self.check_output(check_eager=False)


class TestElementwiseModOpDouble(TestElementwiseModOpFloat):

def init_dtype(self):
34 changes: 32 additions & 2 deletions python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
@@ -20,6 +20,12 @@
import paddle


def pow_grad(x, y, dout):
dx = dout * y * np.power(x, (y - 1))
dy = dout * np.log(x) * np.power(x, y)
return dx, dy


class TestElementwisePowOp(OpTest):

def setUp(self):
@@ -194,7 +200,6 @@ def setUp(self):
# dy = dout * log(x) * pow(x, y)
self.grad_y = (self.grad_res * np.log(self.x) *
(self.x**self.y)).astype("int")
print(self.grad_res, self.grad_x, self.grad_y)

def test_grad(self):
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": True})
@@ -205,7 +210,6 @@ def test_grad(self):
with fluid.dygraph.guard(place):
x = fluid.dygraph.to_variable(self.x, zero_copy=False)
y = fluid.dygraph.to_variable(self.y, zero_copy=False)
print(x, y)
x.stop_gradient = False
y.stop_gradient = False
res = x**y
@@ -216,5 +220,31 @@
fluid.set_flags({"FLAGS_retain_grad_for_all_tensor": False})


class TestElementwisePowOpFP16(OpTest):

def setUp(self):
self.op_type = "elementwise_pow"
self.python_api = paddle.pow
self.inputs = {
'X': np.random.uniform(1, 2, [20, 5]).astype("float16"),
'Y': np.random.uniform(1, 2, [20, 5]).astype("float16")
}
self.outputs = {'Out': np.power(self.inputs['X'], self.inputs['Y'])}

def test_check_output(self):
if hasattr(self, 'attrs'):
self.check_output(check_eager=False)
else:
self.check_output(check_eager=True)

def test_check_grad(self):
self.check_grad(['X', 'Y'],
'Out',
user_defined_grads=pow_grad(self.inputs['X'],
self.inputs['Y'],
1 / self.inputs['X'].size),
check_eager=True)


if __name__ == '__main__':
unittest.main()
30 changes: 30 additions & 0 deletions python/paddle/fluid/tests/unittests/test_fmax_op.py
@@ -209,3 +209,33 @@ def test_check_grad_ingore_y(self):
max_relative_error=0.005,
no_grad_set=set('Y'),
check_eager=True)


class TestElementwiseFmax3Op(OpTest):
"""TestElementwiseFmax3Op"""

def setUp(self):
"""setUp"""
self.op_type = "elementwise_fmax"
self.python_api = paddle.fmax
# If x and y have the same value, the max() is not differentiable.
# So we generate test data by the following method
# to avoid them being too close to each other.
x = np.random.uniform(0.1, 1, [13, 17]).astype("float16")
sgn = np.random.choice([-1, 1], [13, 17]).astype("float16")
y = x + sgn * np.random.uniform(0.1, 1, [13, 17]).astype("float16")

self.inputs = {'X': x, 'Y': y}
self.outputs = {'Out': np.fmax(self.inputs['X'], self.inputs['Y'])}

def test_check_output(self):
"""test_check_output"""
self.check_output(check_eager=True)

def test_check_grad_normal(self):
"""test_check_grad_normal"""
self.check_grad(['X', 'Y'], 'Out', check_eager=True)


if __name__ == "__main__":
unittest.main()