Added concat BF16/FP32 BWD OneDNN kernel #35889

Merged: 11 commits, Oct 5, 2021
paddle/fluid/operators/concat_op.cc (15 additions, 3 deletions)

@@ -169,9 +169,21 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out")),
ctx.GetPlace());
auto input_data_type = OperatorWithKernel::IndicateVarDataType(
ctx, framework::GradVarName("Out"));

#ifdef PADDLE_WITH_MKLDNN
// An extra check that the "use_mkldnn" attribute exists is needed,
// because test_reverse_op calls the concat_grad kernel without
// setting "use_mkldnn" to any value.
if (ctx.HasAttr("use_mkldnn") &&
this->CanMKLDNNBeUsed(ctx, input_data_type)) {
return framework::OpKernelType(input_data_type, ctx.GetPlace(),
framework::DataLayout::kMKLDNN,
framework::LibraryType::kMKLDNN);
}
#endif
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}

framework::OpKernelType GetKernelTypeForVar(
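The guarded dispatch above means concat_grad keeps its default CPU kernel unless the use_mkldnn attribute is present and CanMKLDNNBeUsed agrees. A toy Python model of that decision, with hypothetical names standing in for Paddle's C++ API (illustration only):

def pick_kernel(attrs, mkldnn_usable):
    # Guard on attribute presence first: test_reverse_op invokes the
    # concat_grad kernel without setting "use_mkldnn" at all.
    if attrs.get("use_mkldnn") and mkldnn_usable:
        return ("kMKLDNN", "kMKLDNN")  # (data layout, kernel library)
    return ("kAnyLayout", "kPlain")    # default CPU kernel type

assert pick_kernel({}, True) == ("kAnyLayout", "kPlain")
assert pick_kernel({"use_mkldnn": True}, True) == ("kMKLDNN", "kMKLDNN")
assert pick_kernel({"use_mkldnn": True}, False) == ("kAnyLayout", "kPlain")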
paddle/fluid/operators/mkldnn/concat_mkldnn_op.cc (71 additions)

@@ -23,6 +23,7 @@ namespace operators {

using framework::DataLayout;
using framework::Tensor;
using framework::LoDTensor;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::concat;
@@ -149,6 +150,72 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_format(platform::GetMKLDNNFormat(*dst_mem));
}
};

template <typename T>
class ConcatGradMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
const auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& onednn_engine = dev_ctx.GetEngine();

auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();

auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));

const auto x = ctx.MultiInput<LoDTensor>("X");
const auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto dx = ctx.MultiOutput<LoDTensor>(framework::GradVarName("X"));

for (size_t i = 0; i < dx.size(); ++i) {
if (dx[i] != nullptr) {
dx[i]->set_lod(x[i]->lod());
}
}

int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("AxisTensor")) {
auto* axis_tensor = ctx.Input<Tensor>("AxisTensor");
axis = GetDataFromTensor<int>(axis_tensor)[0];
}

auto dout_vec_dims = framework::vectorize(dout->dims());

axis = ComputeAxis(axis, dout_vec_dims.size());

std::vector<int64_t> offset(dout_vec_dims.size(), 0);

mkldnn::memory::data_type dout_type =
framework::ToMKLDNNDataType(dout->type());
platform::ReorderMKLDNNHandler reorder_handler(dout_vec_dims, dout->type(),
dout_type, onednn_engine);
auto reorder_src_memory_p = reorder_handler.AcquireSrcMemory(
dout->format(), platform::to_void_cast(dout->data<T>()));

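// Each dX_i is taken from dOut as a sub-memory that starts at the
// running offset along the concat axis, then reordered into place.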
for (size_t i = 0; i < dx.size(); ++i) {
if (out_var_names[i] != framework::kEmptyVarName &&
dx[i]->numel() != 0UL) {
auto dx_vec_dims = framework::vectorize(dx[i]->dims());
auto slice_mem_p = reorder_handler.AcquireSubmemory(
dx_vec_dims, offset, reorder_src_memory_p);

auto reorder_dst_memory_p = reorder_handler.AcquireDstMemory(
dx[i], dx_vec_dims, dout->format(), ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(reorder_dst_memory_p, slice_mem_p);

reorder_p->execute(astream, *slice_mem_p, *reorder_dst_memory_p);

offset[axis] += dx[i]->dims()[axis];

dx[i]->set_layout(framework::DataLayout::kMKLDNN);
dx[i]->set_format(platform::GetMKLDNNFormat(*reorder_dst_memory_p));
}
}
astream.wait();
}
};

} // namespace operators
} // namespace paddle

@@ -159,3 +226,7 @@ REGISTER_OP_KERNEL(concat, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConcatMKLDNNOpKernel<paddle::platform::bfloat16>,
ops::ConcatMKLDNNOpKernel<int8_t>,
ops::ConcatMKLDNNOpKernel<uint8_t>);

REGISTER_OP_KERNEL(concat_grad, MKLDNN, ::paddle::platform::CPUPlace,
ops::ConcatGradMKLDNNOpKernel<float>,
ops::ConcatGradMKLDNNOpKernel<paddle::platform::bfloat16>);
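Conceptually, concat's backward is pure slicing: each dX_i is the sub-tensor of dOut beginning at the running offset along the concat axis, with X_i's extent. A minimal NumPy model of the loop in ConcatGradMKLDNNOpKernel (an illustration of the math, not the oneDNN reorder path):

import numpy as np

axis = 1
xs = [np.random.rand(2, n, 4).astype(np.float32) for n in (3, 5, 7)]
dout = np.random.rand(2, 3 + 5 + 7, 4).astype(np.float32)

dxs, offset = [], 0
for x in xs:
    extent = x.shape[axis]
    index = [slice(None)] * dout.ndim
    index[axis] = slice(offset, offset + extent)  # the sub-memory view
    dxs.append(dout[tuple(index)].copy())
    offset += extent  # mirrors offset[axis] += dx[i]->dims()[axis]

assert [d.shape for d in dxs] == [x.shape for x in xs]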
python/paddle/fluid/tests/unittests/mkldnn/test_concat_bf16_mkldnn_op.py

@@ -40,13 +40,28 @@ def setUp(self):
'mkldnn_data_type': self.mkldnn_data_type
}

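# Split points for np.split along the concat axis: [s0, s0 + s1],
# where s_i is input x_i's size on that axis.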
self.sections = [self.x0.shape[self.axis]] * 2
self.sections[1] += self.x1.shape[self.axis]

self.output = np.concatenate(
(self.x0, self.x1, self.x2), axis=self.axis).astype(np.uint16)
self.outputs = {'Out': self.output}

def calculate_grads(self):
self.dout = self.outputs['Out']
self.dxs = np.split(self.dout, self.sections, self.axis)

def test_check_output(self):
self.check_output_with_place(core.CPUPlace())

def test_check_grad(self):
self.calculate_grads()
self.check_grad_with_place(
core.CPUPlace(), ["x0", "x1", "x2"],
"Out",
user_defined_grads=[self.dxs[0], self.dxs[1], self.dxs[2]],
user_defined_grad_outputs=[self.dout])
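The analytic gradients are supplied through user_defined_grads, presumably because numeric differentiation is unreliable at bf16 precision; splitting dOut at the precomputed section boundaries recovers exactly one slice per input. A self-contained sketch of that bookkeeping (shapes are illustrative):

import numpy as np

axis = 0
x0, x1, x2 = np.ones((6, 2)), np.ones((7, 2)), np.ones((8, 2))
sections = [x0.shape[axis], x0.shape[axis] + x1.shape[axis]]  # split at [6, 13]
dout = np.concatenate((x0, x1, x2), axis=axis)
dxs = np.split(dout, sections, axis)
assert [d.shape for d in dxs] == [x0.shape, x1.shape, x2.shape]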

# --------------------test concat bf16 with axis 0--------------------

def init_test_data(self):
@@ -61,9 +76,9 @@ def init_axis(self):
self.axis = 0

def init_shape(self):
self.x0_shape = [2, 2, 1, 2]
self.x1_shape = [1, 2, 1, 2]
self.x2_shape = [3, 2, 1, 2]
self.x0_shape = [6, 2, 4, 3]
self.x1_shape = [7, 2, 4, 3]
self.x2_shape = [8, 2, 4, 3]


# --------------------test concat bf16 with axis 1--------------------
@@ -74,9 +89,9 @@ def init_axis(self):
self.axis = 1

def init_shape(self):
self.x0_shape = [1, 1, 5, 5]
self.x1_shape = [1, 2, 5, 5]
self.x2_shape = [1, 3, 5, 5]
self.x0_shape = [1, 4, 5, 5]
self.x1_shape = [1, 8, 5, 5]
self.x2_shape = [1, 6, 5, 5]


# --------------------test concat bf16 with axis 2--------------------
python/paddle/fluid/tests/unittests/mkldnn/test_concat_mkldnn_op.py (63 additions, 51 deletions)

@@ -15,78 +15,90 @@
from __future__ import print_function

import unittest
from paddle.fluid.tests.unittests.test_concat_op import TestConcatOp, TestConcatOp2, TestConcatOp3, TestConcatOp4
import numpy as np
import struct


class TestMKLDNNConcatOp(TestConcatOp):
def setUp(self):
super(TestMKLDNNConcatOp, self).setUp()
self.attrs["use_mkldnn"] = True
self._cpu_only = True

def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))

def test_check_grad(self):
pass

def init_kernel_type(self):
self.use_mkldnn = True
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle import enable_static


class TestMKLDNNConcatOp2(TestConcatOp2):
class TestConcatAxis0OneDNNOp(OpTest):
def setUp(self):
super(TestMKLDNNConcatOp2, self).setUp()
self.attrs["use_mkldnn"] = True
self._cpu_only = True
self.op_type = "concat"
self.mkldnn_data_type = "float32"
self.init_axis()
self.init_shape()
self.init_test_data()
self.configure_datatype()
self.inputs = {'X': [('x0', self.x0), ('x1', self.x1), ('x2', self.x2)]}
self.attrs = {
'axis': self.axis,
'use_mkldnn': True,
'mkldnn_data_type': self.mkldnn_data_type
}

self.output = np.concatenate(
(self.x0, self.x1, self.x2), axis=self.axis).astype(self.dtype)

self.outputs = {'Out': self.output}

def configure_datatype(self):
self.mkldnn_data_type = "float32"
self.dtype = np.float32

def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
self.check_output_with_place(core.CPUPlace())

def test_check_grad(self):
pass
self.check_grad(['x0'], 'Out')
self.check_grad(['x1'], 'Out')
self.check_grad(['x2'], 'Out')

def init_kernel_type(self):
self.use_mkldnn = True
def init_test_data(self):
self.x0 = np.random.random(self.x0_shape).astype(np.float32)
self.x1 = np.random.random(self.x1_shape).astype(np.float32)
self.x2 = np.random.random(self.x2_shape).astype(np.float32)

def init_axis(self):
self.axis = 0

class TestMKLDNNConcatOp3(TestConcatOp3):
def setUp(self):
super(TestMKLDNNConcatOp3, self).setUp()
self.attrs["use_mkldnn"] = True
self._cpu_only = True
def init_shape(self):
self.x0_shape = [2, 2, 1, 50]
self.x1_shape = [1, 2, 1, 50]
self.x2_shape = [3, 2, 1, 50]
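Note that only the concat-axis sizes differ across the three inputs; every other dimension must match, or the concat is ill-formed. A quick NumPy check of that constraint for the axis-0 shapes above:

import numpy as np

a, b, c = np.zeros((2, 2, 1, 50)), np.zeros((1, 2, 1, 50)), np.zeros((3, 2, 1, 50))
assert np.concatenate((a, b, c), axis=0).shape == (6, 2, 1, 50)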

def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))

def test_check_grad(self):
pass
class TestConcatAxis1OneDNNOp(TestConcatAxis0OneDNNOp):
def init_axis(self):
self.axis = 1

def init_kernel_type(self):
self.use_mkldnn = True
def init_shape(self):
self.x0_shape = [1, 1, 5, 50]
self.x1_shape = [1, 2, 5, 50]
self.x2_shape = [1, 3, 5, 50]


class TestMKLDNNConcatOp4(TestConcatOp4):
def setUp(self):
super(TestMKLDNNConcatOp4, self).setUp()
self.attrs["use_mkldnn"] = True
self._cpu_only = True
class TestConcatAxis2OneDNNOp(TestConcatAxis0OneDNNOp):
def init_axis(self):
self.axis = 2

def test_check_output(self):
# TODO(wangzhongpu): support mkldnn op in dygraph mode
self.check_output(check_dygraph=(self.attrs["use_mkldnn"] == False))
def init_shape(self):
self.x0_shape = [2, 3, 4, 50]
self.x1_shape = [2, 3, 5, 50]
self.x2_shape = [2, 3, 6, 50]

def test_check_grad(self):
pass

def init_kernel_type(self):
self.use_mkldnn = True
class TestConcatAxis3OneDNNOp(TestConcatAxis0OneDNNOp):
def init_axis(self):
self.axis = 3

def init_shape(self):
self.x0_shape = [5, 3, 5, 5]
self.x1_shape = [5, 3, 5, 6]
self.x2_shape = [5, 3, 5, 7]


if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()
python/paddle/fluid/tests/unittests/test_concat_op.py (1 addition, 1 deletion)

@@ -16,7 +16,7 @@

import unittest
import numpy as np
from op_test import OpTest, skip_check_grad_ci
from paddle.fluid.tests.unittests.op_test import OpTest, skip_check_grad_ci
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard, core
import paddle