diff --git a/lite/api/mobilenetv2_test.cc b/lite/api/mobilenetv2_test.cc index b523d5951b3..754994097c6 100644 --- a/lite/api/mobilenetv2_test.cc +++ b/lite/api/mobilenetv2_test.cc @@ -95,17 +95,28 @@ void TestModel(const std::vector& valid_places, } auto first_target = valid_places[0].target; + float relative_err_max = 0.f; if (first_target == TARGET(kOpenCL) || first_target == TARGET(kNPU)) { ASSERT_EQ(out->dims().production(), 1000); double eps = first_target == TARGET(kOpenCL) ? 0.25 : 0.1; for (int i = 0; i < ref.size(); ++i) { for (int j = 0; j < ref[i].size(); ++j) { - auto result = pdata[j * step + (out->dims()[1] * i)]; - auto diff = std::fabs((result - ref[i][j]) / ref[i][j]); - VLOG(3) << diff; - EXPECT_LT(diff, eps); + auto idx = j * step + (out->dims()[1] * i); + auto result = pdata[idx]; + auto relative_err = std::fabs((result - ref[i][j]) / ref[i][j]); + VLOG(3) << lite::string_format( + "relative_err[%d]: %f \tresult: %f \tref: %f", + idx, + relative_err, + result, + ref[i][j]); + if (relative_err > relative_err_max) { + relative_err_max = relative_err; + } } } + VLOG(3) << lite::string_format("max relative err: %f", relative_err_max); + EXPECT_LT(relative_err_max, eps); } else { ASSERT_EQ(out->dims().size(), 2); ASSERT_EQ(out->dims()[0], 1); diff --git a/lite/backends/opencl/cl_kernel/image/argmax_kernel.cl b/lite/backends/opencl/cl_kernel/image/argmax_kernel.cl new file mode 100644 index 00000000000..5796dc66279 --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/argmax_kernel.cl @@ -0,0 +1,169 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +__kernel void argmax_n(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + CL_DTYPE4 cur_data = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 cur_idx = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 max_idx = (CL_DTYPE4)(DATAINIT); + + FLAG_TYPE4 flag_v = (FLAG_TYPE4)(0); + + for (unsigned short i = 0; i < in_nchw.x; i++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw, in_nchw.z * i + bh)); + + cur_idx = (CL_DTYPE4)(i); + flag_v = isgreaterequal(cur_data, max_data); + max_data = select(max_data, cur_data, flag_v); + max_idx = select(max_idx, cur_idx, flag_v); + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_idx); +} + +__kernel void argmax_c(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + CL_DTYPE4 cur_data = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 cur_idx = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 max_idx = (CL_DTYPE4)(DATAINIT); + + FLAG_TYPE4 flag_v = (FLAG_TYPE4)(0); + + for (unsigned short i = 0; i < c4_n; i++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(in_nchw.w * i + cw, bh)); + cur_idx = (CL_DTYPE4)(i << 2, (i << 2) + 1, (i << 2) + 2, (i << 2) + 3); + flag_v = isgreaterequal(cur_data, max_data); + max_data = select(max_data, cur_data, flag_v); + max_idx = select(max_idx, cur_idx, flag_v); + } + + if (c4_r != 0) { + cur_data = + READ_IMG_TYPE(CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw4 + cw, bh)); + } + + if (c4_r >= 1) { + cur_idx.x = (c4_n << 2); + flag_v.x = isgreaterequal(cur_data.x, max_data.x); + max_data.x = select(max_data.x, cur_data.x, flag_v.x); + max_idx.x = select(max_idx.x, cur_idx.x, flag_v.x); + } + if (c4_r >= 2) { + cur_idx.y = (c4_n << 2) + 1; + flag_v.y = isgreaterequal(cur_data.y, max_data.y); + max_data.y = select(max_data.y, cur_data.y, flag_v.y); + max_idx.y = select(max_idx.y, cur_idx.y, flag_v.y); + } + if (c4_r == 3) { + cur_idx.z = (c4_n << 2) + 2; + flag_v.z = isgreaterequal(cur_data.z, max_data.z); + max_data.z = select(max_data.z, cur_data.z, flag_v.z); + max_idx.z = select(max_idx.z, cur_idx.z, flag_v.z); + } + + if (max_data.y > max_data.x) { + max_data.x = max_data.y; + max_idx.x = max_idx.y; + } + + if (max_data.z > max_data.x) { + max_data.x = max_data.z; + max_idx.x = max_idx.z; + } + + if (max_data.w > max_data.x) { + max_data.x = max_data.w; + max_idx.x = max_idx.w; + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_idx); +} + +__kernel void argmax_h(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + CL_DTYPE4 cur_data = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 cur_idx = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 max_idx = (CL_DTYPE4)(DATAINIT); + + FLAG_TYPE4 flag_v = (FLAG_TYPE4)(0); + + for (unsigned short i = 0; i < in_nchw.z; i++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw, in_nchw.z * bh + i)); + + cur_idx = (CL_DTYPE4)(i); + flag_v = isgreaterequal(cur_data, max_data); + max_data = select(max_data, cur_data, flag_v); + max_idx = select(max_idx, cur_idx, flag_v); + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_idx); +} + +__kernel void argmax_w(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + CL_DTYPE4 cur_data = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 cur_idx = (CL_DTYPE4)(DATAINIT); + CL_DTYPE4 max_idx = (CL_DTYPE4)(DATAINIT); + + FLAG_TYPE4 flag_v = (FLAG_TYPE4)(0); + + for (unsigned short i = 0; i < in_nchw.w; i++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw * in_nchw.w + i, bh)); + + cur_idx = (CL_DTYPE4)(i); + flag_v = isgreaterequal(cur_data, max_data); + max_data = select(max_data, cur_data, flag_v); + max_idx = select(max_idx, cur_idx, flag_v); + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_idx); +} diff --git a/lite/backends/opencl/cl_kernel/image/max_kernel.cl b/lite/backends/opencl/cl_kernel/image/max_kernel.cl new file mode 100644 index 00000000000..1668e2d85cc --- /dev/null +++ b/lite/backends/opencl/cl_kernel/image/max_kernel.cl @@ -0,0 +1,187 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +__kernel void max_n(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4, + __private const int axis_n) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + CL_DTYPE4 cur_data; + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + + for (unsigned short i = 0; i < in_nchw.x; i++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw, in_nchw.z * i + bh)); + max_data = fmax(max_data, cur_data); + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_data); +} + +__kernel void max_c(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4, + __private const int axis_n) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + CL_DTYPE4 cur_data; + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + + for (unsigned short i = 0; i < c4_n; i++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(in_nchw.w * i + cw, bh)); + max_data = fmax(max_data, cur_data); + } + + if (c4_r != 0) { + cur_data = + READ_IMG_TYPE(CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw4 + cw, bh)); + } + + if (c4_r >= 1) { + max_data.x = fmax(max_data.x, cur_data.x); + } + if (c4_r >= 2) { + max_data.y = fmax(max_data.y, cur_data.y); + } + if (c4_r == 3) { + max_data.z = fmax(max_data.z, cur_data.z); + } + + max_data.x = fmax(max_data.x, max_data.y); + max_data.x = fmax(max_data.x, max_data.z); + max_data.x = fmax(max_data.x, max_data.w); + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_data); +} + +__kernel void max_h(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4, + __private const int axis_n) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + CL_DTYPE4 cur_data; + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + + for (unsigned short i = 0; i < in_nchw.z; i++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw, in_nchw.z * bh + i)); + + max_data = fmax(max_data, cur_data); + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_data); +} + +__kernel void max_w(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4, + __private const int axis_n) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + CL_DTYPE4 cur_data; + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + + for (unsigned short i = 0; i < in_nchw.w; i++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw * in_nchw.w + i, bh)); + + max_data = fmax(max_data, cur_data); + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_data); +} + +__kernel void max_multi_axis(__read_only image2d_t input, + __write_only image2d_t output, + __private const int4 in_nchw, + __private const int c4_n, + __private const int c4_r, + __private const int cw4, + __private const int axis_n, + __private const int4 axis_nhwc) { + const int cw = get_global_id(0); + const int bh = get_global_id(1); + + int n_reduce_len = select(1, in_nchw.x, axis_nhwc.x); + int h_reduce_len = select(1, in_nchw.z, axis_nhwc.y); + int w_reduce_len = select(1, in_nchw.w, axis_nhwc.z); + + CL_DTYPE4 cur_data; + CL_DTYPE4 max_data = (CL_DTYPE4)(DATAINIT); + + for (unsigned short n = 0; n < n_reduce_len; n++) { + for (unsigned short h = 0; h < h_reduce_len; h++) { + int img_h_idx = in_nchw.z * n + h + bh * h_reduce_len; + for (unsigned short w = 0; w < w_reduce_len; w++) { + for (unsigned short c_4 = 0; c_4 < select(1, c4_n, axis_nhwc.w); + c_4++) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, + input, + SAMPLER, + (int2)(in_nchw.w * c_4 + w + cw * w_reduce_len, img_h_idx)); + max_data = fmax(max_data, cur_data); + } + + if (axis_nhwc.w) { + if (c4_r == 1) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw4 + w + cw, img_h_idx)); + max_data.x = fmax(max_data.x, cur_data.x); + } else if (c4_r == 2) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw4 + w + cw, img_h_idx)); + max_data.x = fmax(max_data.x, cur_data.x); + max_data.y = fmax(max_data.y, cur_data.y); + } else if (c4_r == 3) { + cur_data = READ_IMG_TYPE( + CL_DTYPE_CHAR, input, SAMPLER, (int2)(cw4 + w + cw, img_h_idx)); + max_data.x = fmax(max_data.x, cur_data.x); + max_data.y = fmax(max_data.y, cur_data.y); + max_data.z = fmax(max_data.z, cur_data.z); + } + } + } + } + } + + if (axis_nhwc.w) { + max_data.x = fmax(max_data.x, max_data.y); + max_data.x = fmax(max_data.x, max_data.z); + max_data.x = fmax(max_data.x, max_data.w); + } + + WRITE_IMG_TYPE(CL_DTYPE_CHAR, output, (int2)(cw, bh), max_data); +} diff --git a/lite/kernels/opencl/CMakeLists.txt b/lite/kernels/opencl/CMakeLists.txt index 072be176d04..157daedfe97 100644 --- a/lite/kernels/opencl/CMakeLists.txt +++ b/lite/kernels/opencl/CMakeLists.txt @@ -48,6 +48,8 @@ add_kernel(clip_opencl_image OPENCL basic SRCS clip_image_compute.cc) add_kernel(softmax_opencl_image OPENCL basic SRCS softmax_image_compute.cc) add_kernel(greater_than_opencl_image OPENCL basic SRCS greater_than_image_compute.cc) add_kernel(fc_opencl_image OPENCL basic SRCS fc_image_compute.cc) +add_kernel(argmax_opencl_image OPENCL basic SRCS argmax_image_compute.cc) +add_kernel(max_opencl_image OPENCL basic SRCS max_image_compute.cc) # extra # wait to add ... @@ -136,6 +138,12 @@ lite_cc_test(test_fc_image_opencl SRCS fc_image_compute_test.cc lite_cc_test(test_softmax_image_opencl SRCS softmax_image_compute_test.cc DEPS kernels core) +lite_cc_test(test_argmax_image_opencl SRCS argmax_image_compute_test.cc + DEPS kernels core) + + lite_cc_test(test_max_image_opencl SRCS max_image_compute_test.cc + DEPS kernels core) + ###################### # buffer kernel # ###################### diff --git a/lite/kernels/opencl/argmax_image_compute.cc b/lite/kernels/opencl/argmax_image_compute.cc new file mode 100644 index 00000000000..ee4ecd5fa2e --- /dev/null +++ b/lite/kernels/opencl/argmax_image_compute.cc @@ -0,0 +1,188 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +class ArgmaxComputeImage2D : public KernelLite { + public: + using param_t = operators::ArgmaxParam; + + std::string doc() const override { return "Argmax using cl::Image2D, kFP16"; } + + void PrepareForRun() override { + auto& context = ctx_->As(); + argmax_param_ = param_.get_mutable(); + auto& x_dims = argmax_param_->X->dims(); + + // padding to 4-dims + in_nchw_ = x_dims.Vectorize(); + while (in_nchw_.size() < 4) { + in_nchw_.insert(in_nchw_.cbegin(), 1); + } + + int padding_axis = argmax_param_->Axis + (4 - x_dims.size()); + switch (padding_axis) { + case 0: + kernel_func_name_ = "argmax_n"; + break; + case 1: + kernel_func_name_ = "argmax_c"; + break; + case 2: + kernel_func_name_ = "argmax_h"; + break; + case 3: + kernel_func_name_ = "argmax_w"; + break; + default: + LOG(FATAL) << "invalid axis: " << argmax_param_->Axis; + } + + create_build_options(); + + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + context.cl_context()->AddKernel(kernel_func_name_, + "image/argmax_kernel.cl", + build_options_, + time_stamp_); + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + } + + void ReInitWhenNeeded() override { + argmax_param_ = param_.get_mutable(); + auto& x_dims = argmax_param_->X->dims(); + + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + + // compute global work size + // padding out_dims to 4-dims + out_nchw_ = in_nchw_; + out_nchw_[argmax_param_->Axis + (4 - x_dims.size())] = 1; + + int hb = out_nchw_[0] * out_nchw_[2]; + int cw = + out_nchw_[3] * + maptofactor(out_nchw_[1], 4); // return (i + factor - 1) / factor; + gws_ = cl::NDRange{static_cast(cw), + static_cast(hb), + static_cast(1)}; + } + } + + void Run() override { + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + const auto* x_img = GET_DATA_GPU(argmax_param_->X); + auto out_image_shape = InitImageDimInfoWith(DDim(out_nchw_)); + auto* out_img = MUTABLE_DATA_GPU(argmax_param_->Out, + out_image_shape["width"], + out_image_shape["height"], + nullptr); + int c4_n = in_nchw_[1] / 4; + int c4_r = in_nchw_[1] % 4; + int cw4 = in_nchw_[3] * c4_n; + + int in_dims[] = {static_cast(in_nchw_[0]), + static_cast(in_nchw_[1]), + static_cast(in_nchw_[2]), + static_cast(in_nchw_[3])}; + + cl_int status; + int arg_idx = 0; + status = kernel_.setArg(arg_idx++, *x_img); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, *out_img); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, in_dims); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, c4_n); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, c4_r); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, cw4); + CL_CHECK_FATAL(status); + + status = EnqueueNDRangeKernel( + context, kernel_, cl::NullRange, gws_, cl::NullRange, nullptr, event_); + CL_CHECK_FATAL(status); + } + + void create_build_options() { + const bool fp16_support = + CLRuntime::Global()->get_precision() == lite_api::CL_PRECISION_FP16; + std::string init_max = + fp16_support ? " -DDATAINIT=-HALF_MAX " : " -DDATAINIT=-FLT_MAX "; + std::string flag_type = + fp16_support ? " -DFLAG_TYPE4=short4 " : " -DFLAG_TYPE4=int4 "; + + build_options_ = init_max + flag_type; + } + +#ifdef LITE_WITH_PROFILE + void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; + ch->global_work_size = ch->NDRangeToStr(gws_); + ch->cl_event = + event_; // `event_` defined in `kernel.h`, valid after kernel::Run + } +#endif + + private: + param_t* argmax_param_{nullptr}; + bool first_epoch_for_reinit_{true}; + DDim last_x_dims_; + std::vector in_nchw_{}; + std::vector out_nchw_{}; + std::string kernel_func_name_{}; + std::string build_options_{}; + std::string time_stamp_{GetTimeStamp()}; + cl::Kernel kernel_; + cl::NDRange gws_; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(arg_max, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::ArgmaxComputeImage2D, + def) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/kernels/opencl/argmax_image_compute_test.cc b/lite/kernels/opencl/argmax_image_compute_test.cc new file mode 100644 index 00000000000..a44dab4acfa --- /dev/null +++ b/lite/kernels/opencl/argmax_image_compute_test.cc @@ -0,0 +1,237 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/opencl/cl_image_converter.h" +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" +#include "lite/tests/utils/fill_data.h" + +#define FP32_ABS_DIFF (1e-7) +#define FP32_RELATIVE_DIFF (1e-6) +#define FP16_ABS_DIFF (1e-3) +#define FP16_RELATIVE_DIFF (1e-3) + +namespace paddle { +namespace lite { + +template +void argmax_baseline(const indtype* x_data, + outdtype* out_data, + const DDim input_dims, + const DDim output_dims, + int axis) { + const int size = input_dims[axis]; + const int in_channel = input_dims.count(axis, input_dims.size()); + const int out_channel = output_dims.count(axis, output_dims.size()); + const int in_stride = input_dims.count(axis + 1, input_dims.size()); + const int out_stride = input_dims.count(0, axis); + + for (int n = 0; n < out_stride; n++) { + for (int k = 0; k < in_stride; k++) { + const indtype* in_ptr = x_data + n * in_channel + k; + std::vector> vec; + vec.resize(size); + for (int i = 0; i < size; i++) { + vec[i] = std::make_pair(in_ptr[i * in_stride], i); + } + // sort + std::partial_sort(vec.begin(), + vec.begin() + 1, + vec.end(), + std::greater>()); + + // out + auto* out_ptr = out_data + n * out_channel + k; + *out_ptr = vec[0].second; + } + } +} + +void test(const lite_api::CLPrecisionType p, + bool keepdims, + const int axis, + DDim x_dim) { + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + CLRuntime::Global()->set_precision(p); + const bool fp16_flag = (p == lite_api::CLPrecisionType::CL_PRECISION_FP16); + LOG(INFO) << "\n\t[ START ] Test Precision=" + << lite_api::CLPrecisionTypeToStr(p) << " axis=" << axis; + + auto kernels = KernelRegistry::Global().Create( + "arg_max", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + auto kernel = std::move(kernels.front()); + + lite::Tensor x, out; + operators::ArgmaxParam param; + param.X = &x; + param.Out = &out; + param.Axis = axis; + + kernel->SetParam(param); + kernel->SetContext(std::move(context)); + + std::vector output_shape; + for (size_t i = 0; i < x_dim.size(); i++) { + output_shape.push_back(x_dim[i]); + } + output_shape[axis] = 1L; + DDim out_dim(output_shape); + + x.Resize(x_dim); + out.Resize(out_dim); + + std::vector x_cpu(x_dim.production()); + std::vector out_from_cpu(out_dim.production()); + std::vector out_from_gpu(out_dim.production()); + fill_data_rand(x_cpu.data(), -100.f, 100.f, x_dim.production()); + + CLImageConverterDefault* default_converter = new CLImageConverterDefault(); + DDim x_image_shape = default_converter->InitImageDimInfoWith(x_dim); + DDim out_image_shape = default_converter->InitImageDimInfoWith(out_dim); + VLOG(4) << "x_image_shape = " << x_image_shape[0] << " " << x_image_shape[1]; + VLOG(4) << "out_image_shape = " << out_image_shape[0] << " " + << out_image_shape[1]; + + const size_t dtype_size = fp16_flag ? sizeof(half_t) : sizeof(float); + std::vector x_image_data(x_image_shape.production() * 4 * dtype_size); + default_converter->NCHWToImage(x_cpu.data(), x_image_data.data(), x_dim); + MUTABLE_DATA_GPU(&x, x_image_shape[0], x_image_shape[1], x_image_data.data()); + auto* out_image = + MUTABLE_DATA_GPU(&out, out_image_shape[0], out_image_shape[1], nullptr); + + // run opencl kernel + kernel->Launch(); + CLRuntime::Global()->command_queue().finish(); + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + std::vector out_image_data(out_image_shape.production() * 4 * + dtype_size); // 4 : RGBA + TargetWrapperCL::ImgcpySync(out_image_data.data(), + out_image, + out_image_shape[0], + out_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + default_converter->ImageToNCHW( + out_image_data.data(), out_from_gpu.data(), out_image_shape, out_dim); + + // run cpu ref + if (fp16_flag) { + argmax_baseline( + x_cpu.data(), out_from_cpu.data(), x_dim, out_dim, axis); + } else { + argmax_baseline( + x_cpu.data(), out_from_cpu.data(), x_dim, out_dim, axis); + } + + VLOG(4) << "output_data vs output_ref_data"; + auto relative_diff_thres = + fp16_flag ? FP16_RELATIVE_DIFF : FP32_RELATIVE_DIFF; + auto abs_diff_thres = fp16_flag ? FP16_ABS_DIFF : FP32_ABS_DIFF; + uint32_t diff_cnt = 0; + for (int i = 0; i < out_dim.production(); i++) { + auto relative_diff = + COMPUTE_RELATIVE_DIFF(out_from_gpu[i], out_from_cpu[i]); + auto abs_diff = COMPUTE_ABS_DIFF(out_from_gpu[i], out_from_cpu[i]); + EXPECT_FALSE(relative_diff > relative_diff_thres && + abs_diff > abs_diff_thres); + if (relative_diff > relative_diff_thres && abs_diff > abs_diff_thres) { + LOG(WARNING) << lite_api::CLPrecisionTypeToStr(p) << " err idx: " << i + << " abs_diff: " << abs_diff + << "\t relative_diff: " << relative_diff + << "\t out_ins: " << out_from_gpu[i] + << "\t out_ref: " << out_from_cpu[i]; + diff_cnt++; + } + } + if (diff_cnt != 0) { + LOG(FATAL) << "Err num " << diff_cnt << "/" << out_dim.production(); + } + + LOG(INFO) << "\n\t[ PASSED ] " + << " Test Precision=" << lite_api::CLPrecisionTypeToStr(p) + << " x_dim=" << x_dim << " axis=" << axis; +} + +void test_argmax_opencl_4d() { + for (bool keepdims : {true}) { + for (int axis : {0, 1, 2, 3}) { + for (int n : {2, 3}) { + for (int c : {5, 6}) { + for (int h : {2, 3, 4, 5, 6}) { + for (int w : {2, 3, 4, 5, 6}) { + for (const auto precision_type : + {lite_api::CLPrecisionType::CL_PRECISION_FP32}) { + auto x_dims = DDim(std::vector({n, c, h, w})); + test(precision_type, keepdims, axis, x_dims); + } + } + } + } + } + } + } +} + +void test_argmax_opencl_3d() { + for (bool keepdims : {true}) { + for (int axis : {0, 1, 2}) { + for (int c : {4, 4}) { + for (int h : {2, 10}) { + for (int w : {2, 17}) { + for (const auto precision_type : + {lite_api::CLPrecisionType::CL_PRECISION_FP32}) { + auto x_dims = DDim(std::vector({c, h, w})); + test(precision_type, keepdims, axis, x_dims); + } + } + } + } + } + } +} + +void test_argmax_opencl_2d() { + for (bool keepdims : {true}) { + for (int axis : {0, 1}) { + for (int h : {2, 10}) { + for (int w : {2, 17}) { + for (const auto precision_type : + {lite_api::CLPrecisionType::CL_PRECISION_FP32}) { + auto x_dims = DDim(std::vector({h, w})); + test(precision_type, keepdims, axis, x_dims); + } + } + } + } + } +} + +TEST(argmax, compute_basic) { + test_argmax_opencl_4d(); + test_argmax_opencl_3d(); + test_argmax_opencl_2d(); +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(arg_max, kOpenCL, kFP16, kImageDefault, def); diff --git a/lite/kernels/opencl/max_image_compute.cc b/lite/kernels/opencl/max_image_compute.cc new file mode 100644 index 00000000000..6c025ccfcd0 --- /dev/null +++ b/lite/kernels/opencl/max_image_compute.cc @@ -0,0 +1,280 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include "lite/backends/opencl/cl_half.h" +#include "lite/backends/opencl/cl_include.h" +#include "lite/core/kernel.h" +#include "lite/core/op_registry.h" +#include "lite/kernels/opencl/image_helper.h" +#include "lite/operators/op_params.h" +#include "lite/utils/replace_stl/stream.h" +#include "lite/utils/string.h" +#ifdef LITE_WITH_PROFILE +#include "lite/core/profile/profiler.h" +#endif +#include "lite/backends/opencl/cl_utility.h" + +namespace paddle { +namespace lite { +namespace kernels { +namespace opencl { + +class MaxComputeImage2D : public KernelLite { + public: + using param_t = operators::ReduceParam; + + std::string doc() const override { return "Max using cl::Image2D, kFP16"; } + + void PrepareForRun() override { + auto& context = ctx_->As(); + max_param_ = param_.get_mutable(); + auto& x_dims = max_param_->X->dims(); + auto& dim = max_param_->dim; + + // padding to 4-dims + in_nchw_ = x_dims.Vectorize(); + while (in_nchw_.size() < 4) { + in_nchw_.insert(in_nchw_.cbegin(), 1); + } + + // format axis + int offset = 4 - x_dims.size(); + for (auto i = 0; i < dim.size(); i++) { + axis_.push_back(dim[i] >= 0 ? dim[i] + offset + : dim[i] + x_dims.size() + offset); + } + + if (dim.size() == 1) { + switch (axis_[0]) { + case 0: + kernel_func_name_ = "max_n"; + break; + case 1: + kernel_func_name_ = "max_c"; + break; + case 2: + kernel_func_name_ = "max_h"; + break; + case 3: + kernel_func_name_ = "max_w"; + break; + default: + LOG(FATAL) << "invalid dim: " << dim[0]; + } + } else { + kernel_func_name_ = "max_multi_axis"; + } + + create_build_options(); + + VLOG(1) << "kernel_func_name_:" << kernel_func_name_; + context.cl_context()->AddKernel( + kernel_func_name_, "image/max_kernel.cl", build_options_, time_stamp_); + + STL::stringstream kernel_key; + kernel_key << kernel_func_name_ << build_options_ << time_stamp_; + kernel_ = context.cl_context()->GetKernel(kernel_key.str()); + } + + void ReInitWhenNeeded() override { + max_param_ = param_.get_mutable(); + auto& x_dims = max_param_->X->dims(); + + if ((!first_epoch_for_reinit_ && x_dims != last_x_dims_) || + first_epoch_for_reinit_) { + last_x_dims_ = x_dims; + first_epoch_for_reinit_ = false; + + // compute global work size + // padding out_dims to 4-dims + out_nchw_ = in_nchw_; + for (auto k = 0; k < axis_.size(); k++) { + out_nchw_[axis_[k]] = 1; + } + + int hb = out_nchw_[0] * out_nchw_[2]; + int cw = + out_nchw_[3] * + maptofactor(out_nchw_[1], 4); // return (i + factor - 1) / factor; + gws_ = cl::NDRange{static_cast(cw), + static_cast(hb), + static_cast(1)}; + } + } + + void Run() override { + auto& context = ctx_->As(); + CHECK(context.cl_context() != nullptr); + + const auto* x_img = GET_DATA_GPU(max_param_->X); + auto out_image_shape = InitImageDimInfoWith(DDim(out_nchw_)); + auto* out_img = MUTABLE_DATA_GPU(max_param_->Out, + out_image_shape["width"], + out_image_shape["height"], + nullptr); + int c4_n = in_nchw_[1] / 4; + int c4_r = in_nchw_[1] % 4; + int cw4 = in_nchw_[3] * c4_n; + + int axis_n = 0; + int axis_nhwc[] = {0, 0, 0, 0}; + auto dimsize = max_param_->dim.size(); + + if (dimsize == 0) { + axis_n = std::accumulate( + in_nchw_.cbegin(), in_nchw_.cend(), 0, std::multiplies()); + axis_nhwc[0] = 1; + axis_nhwc[1] = 1; + axis_nhwc[2] = 1; + axis_nhwc[3] = 1; + } else if (dimsize == 1) { + axis_n = in_nchw_[axis_[0]]; + } else { + // multi axies + axis_n = 1; + for (auto i = 0; i < max_param_->dim.size(); i++) { + int axis = axis_[i]; + switch (axis) { + case 0: // n + if (!axis_nhwc[0]) { + axis_n *= in_nchw_[axis]; + axis_nhwc[0] = 1; + } + break; + case 1: // c + if (!axis_nhwc[3]) { + axis_n *= in_nchw_[axis]; + axis_nhwc[3] = 1; + } + break; + case 2: // h + if (!axis_nhwc[1]) { + axis_n *= in_nchw_[axis]; + axis_nhwc[1] = 1; + } + break; + case 3: // w + if (!axis_nhwc[2]) { + axis_n *= in_nchw_[axis]; + axis_nhwc[2] = 1; + } + break; + default: + LOG(FATAL) << "invalid axis: " << axis; + } + } + } + + int in_dims[] = {static_cast(in_nchw_[0]), + static_cast(in_nchw_[1]), + static_cast(in_nchw_[2]), + static_cast(in_nchw_[3])}; + + cl_int status; + int arg_idx = 0; + status = kernel_.setArg(arg_idx++, *x_img); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, *out_img); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, in_dims); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, c4_n); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, c4_r); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, cw4); + CL_CHECK_FATAL(status); + status = kernel_.setArg(arg_idx++, axis_n); + CL_CHECK_FATAL(status); + if (dimsize != 1) { + status = kernel_.setArg(arg_idx++, axis_nhwc); + CL_CHECK_FATAL(status); + } + + status = EnqueueNDRangeKernel( + context, kernel_, cl::NullRange, gws_, cl::NullRange, nullptr, event_); + CL_CHECK_FATAL(status); + } + + void create_build_options() { + std::string init_fp32 = " -DDATAINIT=-FLT_MAX "; + std::string init_fp16 = " -DDATAINIT=-HALF_MAX "; + build_options_ = + (CLRuntime::Global()->get_precision() == lite_api::CL_PRECISION_FP16) + ? init_fp16 + : init_fp32; + } + +#ifdef LITE_WITH_PROFILE + void SetProfileRuntimeKernelInfo(paddle::lite::profile::OpCharacter* ch) { + ch->kernel_func_name = kernel_func_name_; + ch->global_work_size = ch->NDRangeToStr(gws_); + ch->cl_event = + event_; // `event_` defined in `kernel.h`, valid after kernel::Run + } +#endif + + private: + param_t* max_param_{nullptr}; + bool first_epoch_for_reinit_{true}; + DDim last_x_dims_; + std::vector in_nchw_{}; + std::vector out_nchw_{}; + std::vector axis_{}; + std::string kernel_func_name_{}; + std::string build_options_{}; + std::string time_stamp_{GetTimeStamp()}; + cl::Kernel kernel_; + cl::NDRange gws_; +}; + +} // namespace opencl +} // namespace kernels +} // namespace lite +} // namespace paddle + +REGISTER_LITE_KERNEL(reduce_max, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::MaxComputeImage2D, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); + +REGISTER_LITE_KERNEL(max, + kOpenCL, + kFP16, + kImageDefault, + paddle::lite::kernels::opencl::MaxComputeImage2D, + image2d) + .BindInput("X", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .BindOutput("Out", + {LiteType::GetTensorTy(TARGET(kOpenCL), + PRECISION(kFP16), + DATALAYOUT(kImageDefault))}) + .Finalize(); diff --git a/lite/kernels/opencl/max_image_compute_test.cc b/lite/kernels/opencl/max_image_compute_test.cc new file mode 100644 index 00000000000..f63943a68f0 --- /dev/null +++ b/lite/kernels/opencl/max_image_compute_test.cc @@ -0,0 +1,314 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include +#include +#include "lite/backends/opencl/target_wrapper.h" +#include "lite/core/op_registry.h" +#include "lite/core/tensor.h" +#include "lite/kernels/opencl/test_helper.h" +#include "lite/tests/utils/fill_data.h" + +#define FP16_MAX_DIFF (5e-1) +#define FP32_ABS_DIFF (1e-7) +#define FP32_RELATIVE_DIFF (1e-6) +#define FP16_ABS_DIFF (1e-3) +#define FP16_RELATIVE_DIFF (1e-3) +namespace paddle { +namespace lite { + +template +void max_baseline_dim_single(const float* x_data, + float* out_data, + const DDim input_dims, + const DDim output_dims, + int axis) { + const int size = input_dims[axis]; + const int in_channel = input_dims.count(axis, input_dims.size()); + const int out_channel = output_dims.count(axis, output_dims.size()); + const int in_stride = input_dims.count(axis + 1, input_dims.size()); + const int out_stride = input_dims.count(0, axis); + + for (int n = 0; n < out_stride; n++) { + for (int k = 0; k < in_stride; k++) { + const indtype* in_ptr = x_data + n * in_channel + k; + std::vector vec; + vec.resize(size); + for (int i = 0; i < size; i++) { + vec[i] = in_ptr[i * in_stride]; + } + // sort + std::partial_sort( + vec.begin(), vec.begin() + 1, vec.end(), std::greater()); + + // out + auto* out_ptr = out_data + n * out_channel + k; + *out_ptr = vec[0]; + } + } +} + +template +void max_baseline(const float* x_data, + float* out_data, + const DDim input_dims, + const DDim output_dims, + std::vector dim) { + lite::Tensor tin_tmp; + lite::Tensor tout_tmp; + tin_tmp.Resize(input_dims); + tout_tmp.Resize(input_dims); + float* tmp_in = tin_tmp.mutable_data(); + float* tmp_out = tout_tmp.mutable_data(); + DDim in_dim = input_dims; + DDim out_dim = input_dims; + std::vector real_dim = dim; + if (dim.size() == 0) { + real_dim.resize(input_dims.size()); + for (int i = 0; i < real_dim.size(); ++i) { + real_dim[i] = i; + } + } + for (size_t i = 0; i < real_dim.size(); i++) { + const float* input_data = (i == 0) ? x_data : tmp_in; + float* output_data = (i == real_dim.size() - 1) ? out_data : tmp_out; + out_dim[real_dim[i]] = 1; + max_baseline_dim_single( + input_data, output_data, in_dim, out_dim, real_dim[i]); + std::swap(tmp_in, tmp_out); + in_dim = out_dim; + } +} + +void max_test(const lite_api::CLPrecisionType p, + std::vector dim, + bool keepdims, + DDim x_dims) { + std::unique_ptr context(new KernelContext); + context->As().InitOnce(); + CLRuntime::Global()->set_precision(p); + const bool fp16_flag = (p == lite_api::CLPrecisionType::CL_PRECISION_FP16); + LOG(INFO) << "\n\t[ START ] Test Precision=" + << lite_api::CLPrecisionTypeToStr(p); + + auto kernels = KernelRegistry::Global().Create( + "max", TARGET(kOpenCL), PRECISION(kFP16), DATALAYOUT(kImageDefault)); + ASSERT_FALSE(kernels.empty()); + auto kernel = std::move(kernels.front()); + + bool reduce_all = false; + auto x_rank = x_dims.size(); + lite::Tensor x, out; + x.Resize(x_dims); + if (!dim.empty()) { + for (size_t i = 0; i < dim.size(); i++) { + if (dim[i] < 0) { + dim[i] += x_rank; + } + } + } + + std::stable_sort(dim.begin(), dim.end()); + if (dim.size() == 0) { + reduce_all = true; + } + + std::vector out_dims_shape; + // DDim out_dims; + if (reduce_all) { + if (keepdims) { + for (size_t i = 0; i < x_dims.size(); i++) { + out_dims_shape.push_back(1); + } + } else { + out_dims_shape.push_back(1); + } + } else { + for (size_t i = 0; i < x_dims.size(); i++) { + out_dims_shape.push_back(x_dims[i]); + } + if (keepdims) { + for (size_t i = 0; i < dim.size(); ++i) { + out_dims_shape[dim[i]] = 1L; + } + } else { + int64_t kDelFlag = -2; + for (size_t i = 0; i < dim.size(); ++i) { + out_dims_shape[dim[i]] = kDelFlag; + } + out_dims_shape.erase( + remove(out_dims_shape.begin(), out_dims_shape.end(), kDelFlag), + out_dims_shape.end()); + } + if (!keepdims && out_dims_shape.empty()) { + out_dims_shape.push_back(1); + } + out.Resize(DDim(out_dims_shape)); + } + + DDim out_dims = DDim(out_dims_shape); + operators::ReduceParam param; + param.X = &x; + param.Out = &out; + param.dim = dim; + param.keep_dim = keepdims; + param.reduce_all = reduce_all; + + kernel->SetParam(param); + kernel->SetContext(std::move(context)); + + std::vector x_cpu(x_dims.production()); + std::vector out_from_cpu(out_dims.production()); + std::vector out_from_gpu(out_dims.production()); + fill_data_rand(x_cpu.data(), -1.f, 1.f, x_dims.production()); + + CLImageConverterDefault* default_converter = new CLImageConverterDefault(); + DDim x_image_shape = default_converter->InitImageDimInfoWith(x_dims); + DDim out_image_shape = default_converter->InitImageDimInfoWith(out_dims); + VLOG(4) << "x_image_shape = " << x_image_shape[0] << " " << x_image_shape[1]; + VLOG(4) << "out_image_shape = " << out_image_shape[0] << " " + << out_image_shape[1]; + + const size_t dtype_size = fp16_flag ? sizeof(half_t) : sizeof(float); + std::vector x_image_data(x_image_shape.production() * 4 * dtype_size); + default_converter->NCHWToImage(x_cpu.data(), x_image_data.data(), x_dims); + MUTABLE_DATA_GPU(&x, x_image_shape[0], x_image_shape[1], x_image_data.data()); + auto* out_image = + MUTABLE_DATA_GPU(&out, out_image_shape[0], out_image_shape[1], nullptr); + + // run opencl kernel + kernel->Launch(); + CLRuntime::Global()->command_queue().finish(); + + const size_t cl_image2d_row_pitch{0}; + const size_t cl_image2d_slice_pitch{0}; + std::vector out_image_data(out_image_shape.production() * 4 * + dtype_size); // 4 : RGBA + TargetWrapperCL::ImgcpySync(out_image_data.data(), + out_image, + out_image_shape[0], + out_image_shape[1], + cl_image2d_row_pitch, + cl_image2d_slice_pitch, + IoDirection::DtoH); + default_converter->ImageToNCHW( + out_image_data.data(), out_from_gpu.data(), out_image_shape, out_dims); + + // run cpu ref + max_baseline(x_cpu.data(), out_from_cpu.data(), x_dims, out_dims, dim); + + VLOG(4) << "output_data vs output_ref_data"; + auto relative_diff_thres = + fp16_flag ? FP16_RELATIVE_DIFF : FP32_RELATIVE_DIFF; + auto abs_diff_thres = fp16_flag ? FP16_ABS_DIFF : FP32_ABS_DIFF; + uint32_t diff_cnt = 0; + for (int i = 0; i < out_dims.production(); i++) { + auto relative_diff = + COMPUTE_RELATIVE_DIFF(out_from_gpu[i], out_from_cpu[i]); + auto abs_diff = COMPUTE_ABS_DIFF(out_from_gpu[i], out_from_cpu[i]); + EXPECT_FALSE(relative_diff > relative_diff_thres && + abs_diff > abs_diff_thres); + if (relative_diff > relative_diff_thres && abs_diff > abs_diff_thres) { + LOG(WARNING) << lite_api::CLPrecisionTypeToStr(p) << " err idx: " << i + << " abs_diff: " << abs_diff + << "\t relative_diff: " << relative_diff + << "\t out_ins: " << out_from_gpu[i] + << "\t out_ref: " << out_from_cpu[i]; + diff_cnt++; + } + } + if (diff_cnt != 0) { + LOG(FATAL) << "Err num " << diff_cnt << "/" << out_dims.production(); + } + + LOG(INFO) << "\n\t[ PASSED ] " + << " Test Precision=" << lite_api::CLPrecisionTypeToStr(p) + << " x_dim=" << x_dims; +} + +void test_max_opencl_4d() { + for (const auto precision_type : + {lite_api::CLPrecisionType::CL_PRECISION_FP32, + lite_api::CLPrecisionType::CL_PRECISION_FP16}) { + std::vector> reduce_dim{{}, + {0}, + {1}, + {2}, + {3}, + {0, 1}, + {1, 2}, + {2, 3}, + {-2, -1}, + {0, 1, 2}, + {1, 2, 3}}; + for (int n : {1, 3}) { + for (int c : {1, 3}) { + for (int h : {1, 3}) { + for (int w : {1, 3}) { + for (auto dim : reduce_dim) { + auto x_dims = DDim(std::vector({n, c, h, w})); + max_test(precision_type, dim, true, x_dims); + } + } + } + } + } + } +} + +void test_max_opencl_3d() { + for (const auto precision_type : + {lite_api::CLPrecisionType::CL_PRECISION_FP32, + lite_api::CLPrecisionType::CL_PRECISION_FP16}) { + std::vector> reduce_dim{{}, {0}, {1}, {2}, {0, 1}, {1, 2}}; + for (int c : {1, 3}) { + for (int h : {1, 3}) { + for (int w : {1, 3}) { + for (auto dim : reduce_dim) { + auto x_dims = DDim(std::vector({c, h, w})); + max_test(precision_type, dim, true, x_dims); + } + } + } + } + } +} + +void test_max_opencl_2d() { + for (const auto precision_type : + {lite_api::CLPrecisionType::CL_PRECISION_FP32, + lite_api::CLPrecisionType::CL_PRECISION_FP16}) { + std::vector> reduce_dim{{}, {0}, {1}}; + for (int h : {2, 3}) { + for (int w : {2, 3}) { + for (auto dim : reduce_dim) { + auto x_dims = DDim(std::vector({h, w})); + max_test(precision_type, dim, true, x_dims); + } + } + } + } +} + +TEST(max, compute_basic) { + test_max_opencl_4d(); + test_max_opencl_3d(); + test_max_opencl_2d(); +} + +} // namespace lite +} // namespace paddle + +USE_LITE_KERNEL(reduce_max, kOpenCL, kFP16, kImageDefault, image2d); diff --git a/tools/ci_tools/ci_android_opencl_unit_test.sh b/tools/ci_tools/ci_android_opencl_unit_test.sh index fe6de2c57b5..8518cef90c6 100755 --- a/tools/ci_tools/ci_android_opencl_unit_test.sh +++ b/tools/ci_tools/ci_android_opencl_unit_test.sh @@ -8,7 +8,7 @@ WORKSPACE=${SHELL_FOLDER%tools/ci_tools*} TESTS_FILE="./lite_tests.txt" NUM_PROC=4 -skip_list=("test_model_parser" "test_mobilenetv2" \ +skip_list=("test_model_parser" \ "test_resnet50" "test_inceptionv4" "test_light_api" "test_apis" \ "test_paddle_api" "test_cxx_api" "test_gen_code" \ "test_mobilenetv1_int8" "test_subgraph_pass" \