diff --git a/paddle/fluid/operators/lstsq_op.cc b/paddle/fluid/operators/lstsq_op.cc
new file mode 100644
index 00000000000000..65fe99e2ead2eb
--- /dev/null
+++ b/paddle/fluid/operators/lstsq_op.cc
@@ -0,0 +1,142 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/lstsq_op.h"
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class LstsqOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LstsqOp");
+    OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "LstsqOp");
+
+    OP_INOUT_CHECK(ctx->HasOutput("Solution"), "Output", "Solution", "LstsqOp");
+    OP_INOUT_CHECK(ctx->HasOutput("Rank"), "Output", "Rank", "LstsqOp");
+    OP_INOUT_CHECK(ctx->HasOutput("SingularValues"), "Output", "SingularValues",
+                   "LstsqOp");
+
+    auto x_dims = ctx->GetInputDim("X");
+    auto y_dims = ctx->GetInputDim("Y");
+    int x_rank = x_dims.size();
+    int y_rank = y_dims.size();
+
+    PADDLE_ENFORCE_GE(x_rank, 2,
+                      platform::errors::InvalidArgument(
+                          "Expects input tensor x to have at least "
+                          "2 dimensions, but got dimension %d",
+                          x_rank));
+    PADDLE_ENFORCE_GE(y_rank, 2,
+                      platform::errors::InvalidArgument(
+                          "Expects input tensor y to have at least "
+                          "2 dimensions, but got dimension %d",
+                          y_rank));
+
+    PADDLE_ENFORCE_EQ(
+        x_rank, y_rank,
+        platform::errors::InvalidArgument(
+            "Expects input tensor x and y to have the same dimension, "
+            "but got x's dimension [%d] and y's dimension [%d]",
+            x_rank, y_rank));
+
+    std::vector<int> batch_dims_vec{};
+    for (int i = 0; i < x_rank - 2; ++i) {
+      PADDLE_ENFORCE_EQ(
+          x_dims[i], y_dims[i],
+          platform::errors::InvalidArgument(
+              "Expects input tensor x and y to have the same batch "
+              "dimension, but got x's batch dimension [%d] and "
+              "y's batch dimension [%d] in %d-th dim",
+              x_dims[i], y_dims[i], i));
+      batch_dims_vec.emplace_back(x_dims[i]);
+    }
+
+    PADDLE_ENFORCE_EQ(
+        x_dims[x_rank - 2], y_dims[y_rank - 2],
+        platform::errors::InvalidArgument(
+            "Expects input tensor x and y to have the same row dimension "
+            "of the inner-most 2-dims matrix, "
+            "but got x's row dimension [%d] and y's row dimension [%d]",
+            x_dims[x_rank - 2], y_dims[y_rank - 2]));
+
+    ctx->SetOutputDim("Rank", framework::make_ddim(batch_dims_vec));
+
+    batch_dims_vec.emplace_back(
+        std::min(x_dims[x_rank - 2], x_dims[x_rank - 1]));
+    ctx->SetOutputDim("SingularValues", framework::make_ddim(batch_dims_vec));
+
+    batch_dims_vec[x_rank - 2] = x_dims[x_rank - 1];
+    batch_dims_vec.emplace_back(y_dims[x_rank - 1]);
+    ctx->SetOutputDim("Solution", framework::make_ddim(batch_dims_vec));
+  }
+
+ protected:
+  // Only float32 and float64 inputs are supported by the CPU kernel for now
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    if (dtype != framework::proto::VarType::FP32 &&
+        dtype != framework::proto::VarType::FP64) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "unsupported data type: %s!", dtype));
+    }
+    return framework::OpKernelType(dtype, ctx.GetPlace());
+  }
+};
+
+class LstsqOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("X",
+             "(Tensor), A real-valued tensor with shape (*, m, n). "
+             "The accepted datatype is one of float32, float64");
+    AddInput("Y",
+             "(Tensor), A real-valued tensor with shape (*, m, k). "
+             "The accepted datatype is one of float32, float64");
+    AddAttr<float>(
+        "rcond",
+        "(float, default 0.0), A float value used to determine the effective "
+        "rank of A.")
+        .SetDefault(0.0f);
+    AddAttr<std::string>("driver",
+                         "(string, default \"gels\"). "
+                         "Name of the LAPACK method to be used.")
+        .SetDefault("gels");
+    AddOutput("Solution",
+              "(Tensor), The output Solution tensor with shape (*, n, k).");
+    AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*).");
+    AddOutput(
+        "SingularValues",
+        "(Tensor), The output SingularValues tensor with shape (*, min(m,n)).");
+    AddComment(R"DOC(
+      Lstsq Operator.
+This API computes the least-squares solution of a linear system of equations
+for general matrices.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(lstsq, ops::LstsqOp, ops::LstsqOpMaker)
+
+REGISTER_OP_CPU_KERNEL(
+    lstsq, ops::LstsqCPUKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::LstsqCPUKernel<paddle::platform::CPUDeviceContext, double>);
\ No newline at end of file
diff --git a/paddle/fluid/operators/lstsq_op.h b/paddle/fluid/operators/lstsq_op.h
new file mode 100644
index 00000000000000..b9c5c87a6a376a
--- /dev/null
+++ b/paddle/fluid/operators/lstsq_op.h
@@ -0,0 +1,229 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+#include <string>
+#include <unordered_map>
+#include "paddle/fluid/operators/eig_op.h"
+#include "paddle/fluid/operators/math/complex_functors.h"
+#include "paddle/fluid/operators/math/eigen_values_vectors.h"
+#include "paddle/fluid/operators/math/lapack_function.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/math/matrix_solve.h"
+#include "paddle/fluid/operators/svd_helper.h"
+#include "paddle/fluid/operators/transpose_op.h"
+#include "paddle/fluid/operators/triangular_solve_op.h"
+#include "paddle/fluid/platform/for_range.h"
+
+#define EPSILON 1e-6
+
+namespace paddle {
+namespace operators {
+
+using paddle::framework::Tensor;
+enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss };
+
+using DDim = framework::DDim;
+static DDim UDDim(const DDim& x_dim) {
+  auto x_vec = vectorize(x_dim);
+  return framework::make_ddim(x_vec);
+}
+
+template <typename DeviceContext, typename T>
+class LstsqCPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    using ValueType = math::Real<T>;
+
+    const Tensor& x = *context.Input<Tensor>("X");
+    const Tensor& y = *context.Input<Tensor>("Y");
+    auto rcond = context.Attr<float>("rcond");
+    auto driver_string = context.Attr<std::string>("driver");
+
+    static auto driver_type = std::unordered_map<std::string, LapackDriverType>(
+        {{"gels", LapackDriverType::Gels},
+         {"gelsy", LapackDriverType::Gelsy},
+         {"gelsd", LapackDriverType::Gelsd},
+         {"gelss", LapackDriverType::Gelss}});
+    auto driver = driver_type[driver_string];
+
+    auto solution = context.Output<Tensor>("Solution");
+    auto* rank = context.Output<Tensor>("Rank");
+    auto* singular_values = context.Output<Tensor>("SingularValues");
+
+    auto dito =
+        math::DeviceIndependenceTensorOperations<DeviceContext, T>(context);
+
+    auto x_dims = x.dims();
+    auto y_dims = y.dims();
+    int dim_size = x_dims.size();
+    int x_stride = MatrixStride(x);
+    int y_stride = MatrixStride(y);
+    int batch_count = BatchCount(x);
+    auto ori_solution_dim = solution->dims();
+    int ori_solu_stride = MatrixStride(*solution);
+
+    // LAPACK assumes column-major storage; transposing the inputs gives them
+    // a contiguous memory layout in that convention.
+    int info = 0;
+    int m = x_dims[dim_size - 2];
+    int n = x_dims[dim_size - 1];
+    int nrhs = y_dims[dim_size - 1];
+    int lda = std::max(m, 1);
+    int ldb = std::max(1, std::max(m, n));
+
+    Tensor new_x;
+    new_x.mutable_data<T>(context.GetPlace(),
+                          size_t(batch_count * m * n * sizeof(T)));
+    solution->mutable_data<T>(
+        context.GetPlace(),
+        size_t(batch_count * std::max(m, n) * nrhs * sizeof(T)));
+    framework::TensorCopy(x, context.GetPlace(), &new_x);
+    framework::TensorCopy(y, context.GetPlace(), solution);
+
+    if (m < n) solution->Resize(UDDim(ori_solution_dim));
+
+    Tensor input_x_trans = dito.Transpose(new_x);
+    Tensor input_y_trans = dito.Transpose(*solution);
+    framework::TensorCopy(input_x_trans, new_x.place(), &new_x);
+    framework::TensorCopy(input_y_trans, solution->place(), solution);
+
+    auto* x_vector = new_x.data<T>();
+    auto* y_vector = solution->data<T>();
+
+    // The "gels" driver does not need to compute rank
+    int rank_32 = 0;
+    int* rank_data = nullptr;
+    int* rank_working_ptr = nullptr;
+    if (driver != LapackDriverType::Gels) {
+      rank_data = rank->mutable_data<int>(context.GetPlace());
+      rank_working_ptr = rank_data;
+    }
+
+    // The "gelsd" and "gelss" drivers need to compute singular values
+    ValueType* s_data = nullptr;
+    ValueType* s_working_ptr = nullptr;
+    int s_stride = 0;
+    if (driver == LapackDriverType::Gelsd ||
+        driver == LapackDriverType::Gelss) {
+      s_data = singular_values->mutable_data<ValueType>(context.GetPlace());
+      s_working_ptr = s_data;
+      auto s_dims = singular_values->dims();
+      s_stride = s_dims[s_dims.size() - 1];
+    }
+
+    // "jpvt" is only used by the "gelsy" driver
+    Tensor jpvt;
+    int* jpvt_data = nullptr;
+    if (driver == LapackDriverType::Gelsy) {
+      jpvt.Resize(framework::make_ddim({std::max(1, n)}));
+      jpvt_data = jpvt.mutable_data<int>(context.GetPlace());
+    }
+
+    // Run the driver once with lwork = -1 to query the optimal workspace size
+    int lwork = -1;
+    T wkopt;
+    ValueType rwkopt;
+    int iwkopt = 0;
+
+    if (driver == LapackDriverType::Gels) {
+      math::lapackGels('N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt,
+                       lwork, &info);
+    } else if (driver == LapackDriverType::Gelsd) {
+      math::lapackGelsd(m, n, nrhs, x_vector, lda, y_vector, ldb, s_working_ptr,
+                        static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
+                        &rwkopt, &iwkopt, &info);
+    } else if (driver == LapackDriverType::Gelsy) {
+      math::lapackGelsy(m, n, nrhs, x_vector, lda, y_vector, ldb, jpvt_data,
+                        static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
+                        &rwkopt, &info);
+    } else if (driver == LapackDriverType::Gelss) {
+      math::lapackGelss(m, n, nrhs, x_vector, lda, y_vector, ldb, s_working_ptr,
+                        static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
+                        &rwkopt, &info);
+    }
+
+    lwork = std::max(1, static_cast<int>(math::Real<T>(wkopt)));
+    Tensor work;
+    work.Resize(framework::make_ddim({lwork}));
+    T* work_data = work.mutable_data<T>(context.GetPlace());
+
+    // "rwork" is only used for complex inputs with the "gelsy/gelsd/gelss"
+    // drivers
+    Tensor rwork;
+    ValueType* rwork_data = nullptr;
+    if (framework::IsComplexType(x.type()) &&
+        driver != LapackDriverType::Gels) {
+      int rwork_len = 0;
+      if (driver == LapackDriverType::Gelsy) {
+        rwork_len = std::max(1, 2 * n);
+      } else if (driver == LapackDriverType::Gelss) {
+        rwork_len = std::max(1, 5 * std::min(m, n));
+      } else if (driver == LapackDriverType::Gelsd) {
+        rwork_len = std::max<int>(1, rwkopt);
+      }
+      rwork.Resize(framework::make_ddim({rwork_len}));
+      rwork_data = rwork.mutable_data<ValueType>(context.GetPlace());
+    }
+
+    // The "iwork" workspace array is relevant only for the "gelsd" driver
+    Tensor iwork;
+    int* iwork_data = nullptr;
+    if (driver == LapackDriverType::Gelsd) {
+      iwork.Resize(framework::make_ddim({std::max(1, iwkopt)}));
+      iwork_data = iwork.mutable_data<int>(context.GetPlace());
+    }
+
+    int solu_stride = std::max(y_stride, ori_solu_stride);
+    for (auto i = 0; i < batch_count; ++i) {
+      auto* x_input = &x_vector[i * x_stride];
+      auto* y_input = &y_vector[i * solu_stride];
+      rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr;
+      s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr;
+
+      if (driver == LapackDriverType::Gels) {
+        math::lapackGels('N', m, n, nrhs, x_input, lda, y_input, ldb, work_data,
+                         lwork, &info);
+      } else if (driver == LapackDriverType::Gelsd) {
+        math::lapackGelsd(m, n, nrhs, x_input, lda, y_input, ldb, s_working_ptr,
+                          static_cast<ValueType>(rcond), &rank_32, work_data,
+                          lwork, rwork_data, iwork_data, &info);
+      } else if (driver == LapackDriverType::Gelsy) {
+        math::lapackGelsy(m, n, nrhs, x_input, lda, y_input, ldb, jpvt_data,
+                          static_cast<ValueType>(rcond), &rank_32, work_data,
+                          lwork, rwork_data, &info);
+      } else if (driver == LapackDriverType::Gelss) {
+        math::lapackGelss(m, n, nrhs, x_input, lda, y_input, ldb, s_working_ptr,
+                          static_cast<ValueType>(rcond), &rank_32, work_data,
+                          lwork, rwork_data, &info);
+      }
+
+      PADDLE_ENFORCE_EQ(
+          info, 0,
+          platform::errors::PreconditionNotMet(
+              "For batch [%d]: Lapack info is not zero but [%d]", i, info));
+
+      if (rank_working_ptr) *rank_working_ptr = static_cast<int>(rank_32);
+    }
+
+    Tensor tmp_s = dito.Transpose(*solution);
+    framework::TensorCopy(tmp_s, solution->place(), solution);
+
+    if (m >= n) solution->Resize(UDDim(ori_solution_dim));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/math/lapack_function.cc b/paddle/fluid/operators/math/lapack_function.cc
index 450400e35cdd8a..33fa2efb12c1bb 100644
--- a/paddle/fluid/operators/math/lapack_function.cc
+++ b/paddle/fluid/operators/math/lapack_function.cc
@@ -125,6 +125,70 @@ void lapackEig<platform::complex<float>, float>(
       reinterpret_cast<platform::complex<float> *>(work), &lwork, rwork, info);
 }
 
+template <>
+void lapackGels<double>(char trans, int m, int n, int nrhs, double *a, int lda,
+                        double *b, int ldb, double *work, int lwork,
+                        int *info) {
+  platform::dynload::dgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work,
+                            &lwork, info);
+}
+
+template <>
+void lapackGels<float>(char trans, int m, int n, int nrhs, float *a, int lda,
+                       float *b, int ldb, float *work, int lwork, int *info) {
+  platform::dynload::sgels_(&trans, &m, &n, &nrhs, a, &lda, b, &ldb, work,
+                            &lwork, info);
+}
+
+template <>
+void lapackGelsd<double>(int m, int n, int nrhs, double *a, int lda, double *b,
+                         int ldb, double *s, double rcond, int *rank,
+                         double *work, int lwork, double *rwork, int *iwork,
+                         int *info) {
+  platform::dynload::dgelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank,
+                             work, &lwork, iwork, info);
+}
+
+template <>
+void lapackGelsd<float>(int m, int n, int nrhs, float *a, int lda, float *b,
+                        int ldb, float *s, float rcond, int *rank, float *work,
+                        int lwork, float *rwork, int *iwork, int *info) {
+  platform::dynload::sgelsd_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank,
+                             work, &lwork, iwork, info);
+}
+
+template <>
+void lapackGelsy<double>(int m, int n, int nrhs, double *a, int lda, double *b,
+                         int ldb, int *jpvt, double rcond, int *rank,
+                         double *work, int lwork, double *rwork, int *info) {
+  platform::dynload::dgelsy_(&m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond,
+                             rank, work, &lwork, info);
+}
+
+template <>
+void lapackGelsy<float>(int m, int n, int nrhs, float *a, int lda, float *b,
+                        int ldb, int *jpvt, float rcond, int *rank, float *work,
+                        int lwork, float *rwork, int *info) {
+  platform::dynload::sgelsy_(&m, &n, &nrhs, a, &lda, b, &ldb, jpvt, &rcond,
+                             rank, work, &lwork, info);
+}
+
+template <>
+void lapackGelss<double>(int m, int n, int nrhs, double *a, int lda, double *b,
+                         int ldb, double *s, double rcond, int *rank,
+                         double *work, int lwork, double *rwork, int *info) {
+  platform::dynload::dgelss_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank,
+                             work, &lwork, info);
+}
+
+template <>
+void lapackGelss<float>(int m, int n, int nrhs, float *a, int lda, float *b,
+                        int ldb, float *s, float rcond, int *rank, float *work,
+                        int lwork, float *rwork, int *info) {
+  platform::dynload::sgelss_(&m, &n, &nrhs, a, &lda, b, &ldb, s, &rcond, rank,
+                             work, &lwork, info);
+}
+
 template <>
 void lapackCholeskySolve<platform::complex<double>>(
     char uplo, int n, int nrhs, platform::complex<double> *a, int lda,
diff --git a/paddle/fluid/operators/math/lapack_function.h b/paddle/fluid/operators/math/lapack_function.h
index b3275d2ced614b..488b225ef570e5 100644
--- a/paddle/fluid/operators/math/lapack_function.h
+++ b/paddle/fluid/operators/math/lapack_function.h
@@ -20,21 +20,46 @@ namespace math {
 
 // LU (for example)
 template <typename T>
-void lapackLu(int m, int n, T* a, int lda, int* ipiv, int* info);
+void lapackLu(int m, int n, T *a, int lda, int *ipiv, int *info);
 
+// Eigh
 template <typename T, typename ValueType>
-void lapackEigh(char jobz, char uplo, int n, T* a, int lda, ValueType* w,
-                T* work, int lwork, ValueType* rwork, int lrwork, int* iwork,
-                int liwork, int* info);
+void lapackEigh(char jobz, char uplo, int n, T *a, int lda, ValueType *w,
+                T *work, int lwork, ValueType *rwork, int lrwork, int *iwork,
+                int liwork, int *info);
 
+// Eig
 template <typename T1, typename T2>
-void lapackEig(char jobvl, char jobvr, int n, T1* a, int lda, T1* w, T1* vl,
-               int ldvl, T1* vr, int ldvr, T1* work, int lwork, T2* rwork,
-               int* info);
+void lapackEig(char jobvl, char jobvr, int n, T1 *a, int lda, T1 *w, T1 *vl,
+               int ldvl, T1 *vr, int ldvr, T1 *work, int lwork, T2 *rwork,
+               int *info);
 
+// Gels
 template <typename T>
-void lapackCholeskySolve(char uplo, int n, int nrhs, T* a, int lda, T* b,
-                         int ldb, int* info);
+void lapackGels(char trans, int m, int n, int nrhs, T *a, int lda, T *b,
+                int ldb, T *work, int lwork, int *info);
+
+// Gelsd
+template <typename T1, typename T2>
+void lapackGelsd(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb, T2 *s,
+                 T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork,
+                 int *iwork, int *info);
+
+// Gelsy
+template <typename T1, typename T2>
+void lapackGelsy(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb,
+                 int *jpvt, T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork,
+                 int *info);
+
+// Gelss
+template <typename T1, typename T2>
+void lapackGelss(int m, int n, int nrhs, T1 *a, int lda, T1 *b, int ldb, T2 *s,
+                 T2 rcond, int *rank, T1 *work, int lwork, T2 *rwork,
+                 int *info);
+
+template <typename T>
+void lapackCholeskySolve(char uplo, int n, int nrhs, T *a, int lda, T *b,
+                         int ldb, int *info);
 
 }  // namespace math
 }  // namespace operators
diff --git a/paddle/fluid/platform/dynload/lapack.h b/paddle/fluid/platform/dynload/lapack.h
index 32d7461f42d262..ce24b98defbe99 100644
--- a/paddle/fluid/platform/dynload/lapack.h
+++ b/paddle/fluid/platform/dynload/lapack.h
@@ -66,6 +66,39 @@ extern "C" void cgeev_(char *jobvl, char *jobvr, int *n,
                        std::complex<float> *a,
                        std::complex<float> *work, int *lwork, float *rwork,
                        int *info);
 
+// gels
+extern "C" void dgels_(char *trans, int *m, int *n, int *nrhs, double *a,
+                       int *lda, double *b, int *ldb, double *work, int *lwork,
+                       int *info);
+extern "C" void sgels_(char *trans, int *m, int *n, int *nrhs, float *a,
+                       int *lda, float *b, int *ldb, float *work, int *lwork,
+                       int *info);
+
+// gelsd
+extern "C" void dgelsd_(int *m, int *n, int *nrhs, double *a, int *lda,
+                        double *b, int *ldb, double *s, double *rcond,
+                        int *rank, double *work, int *lwork, int *iwork,
+                        int *info);
+extern "C" void sgelsd_(int *m, int *n, int *nrhs, float *a, int *lda, float *b,
+                        int *ldb, float *s, float *rcond, int *rank,
+                        float *work, int *lwork, int *iwork, int *info);
+
+// gelsy
+extern "C" void dgelsy_(int *m, int *n, int *nrhs, double *a, int *lda,
+                        double *b, int *ldb, int *jpvt, double *rcond,
+                        int *rank, double *work, int *lwork, int *info);
+extern "C" void sgelsy_(int *m, int *n, int *nrhs, float *a, int *lda, float *b,
+                        int *ldb, int *jpvt, float *rcond, int *rank,
+                        float *work, int *lwork, int *info);
+
+// gelss
+extern "C" void dgelss_(int *m, int *n, int *nrhs, double *a, int *lda,
+                        double *b, int *ldb, double *s, double *rcond,
+                        int *rank, double *work, int *lwork, int *info);
+extern "C" void sgelss_(int *m, int *n, int *nrhs, float *a, int *lda, float *b,
+                        int *ldb, float *s, float *rcond, int *rank,
+                        float *work, int *lwork, int *info);
+
 extern "C" void zpotrs_(char *uplo, int *n, int *nrhs, std::complex<double> *a,
                         int *lda, std::complex<double> *b, int *ldb, int *info);
 extern "C" void cpotrs_(char *uplo, int *n, int *nrhs, std::complex<float> *a,
@@ -115,6 +148,14 @@ extern void *lapack_dso_handle;
   __macro(sgeev_);            \
   __macro(zgeev_);            \
   __macro(cgeev_);            \
+  __macro(dgels_);            \
+  __macro(sgels_);            \
+  __macro(dgelsd_);           \
+  __macro(sgelsd_);           \
+  __macro(dgelsy_);           \
+  __macro(sgelsy_);           \
+  __macro(dgelss_);           \
+  __macro(sgelss_);           \
   __macro(zpotrs_);           \
   __macro(cpotrs_);           \
   __macro(dpotrs_);           \
diff --git a/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py
new file mode 100644
index 00000000000000..4c0325a35f32e9
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_linalg_lstsq_op.py
@@ -0,0 +1,212 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+
+
+class LinalgLstsqTestCase(unittest.TestCase):
+    def setUp(self):
+        self.init_config()
+        self.generate_input()
+        self.generate_output()
+
+    def init_config(self):
+        self.dtype = 'float64'
+        self.rcond = 1e-15
+        self.driver = "gelsd"
+        self._input_shape_1 = (5, 4)
+        self._input_shape_2 = (5, 3)
+
+    def generate_input(self):
+        self._input_data_1 = np.random.random(self._input_shape_1).astype(
+            self.dtype)
+        self._input_data_2 = np.random.random(self._input_shape_2).astype(
+            self.dtype)
+
+    def generate_output(self):
+        if len(self._input_shape_1) == 2:
+            out = np.linalg.lstsq(
+                self._input_data_1, self._input_data_2, rcond=self.rcond)
+        elif len(self._input_shape_1) == 3:
+            out = np.linalg.lstsq(
+                self._input_data_1[0], self._input_data_2[0], rcond=self.rcond)
+
+        self._output_solution = out[0]
+        self._output_residuals = out[1]
+        self._output_rank = out[2]
+        self._output_sg_values = out[3]
+
+    def test_dygraph(self):
+        paddle.disable_static()
+        paddle.device.set_device("cpu")
+        place = paddle.CPUPlace()
+        x = paddle.to_tensor(self._input_data_1, place=place, dtype=self.dtype)
+        y = paddle.to_tensor(self._input_data_2, place=place, dtype=self.dtype)
+        results = paddle.linalg.lstsq(
+            x, y, rcond=self.rcond, driver=self.driver)
+
+        res_solution = results[0].numpy()
+        res_residuals = results[1].numpy()
+        res_rank = results[2].numpy()
+        res_singular_values = results[3].numpy()
+
+        if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]:
+            if (np.abs(res_residuals - self._output_residuals) < 1e-6).any():
+                pass
+            else:
+                raise RuntimeError("Check LSTSQ residuals dygraph Failed")
+
+        if self.driver in ("gelsy", "gelsd", "gelss"):
+            if (np.abs(res_rank - self._output_rank) < 1e-6).any():
+                pass
+            else:
+                raise RuntimeError("Check LSTSQ rank dygraph Failed")
+
+        if self.driver in ("gelsd", "gelss"):
+            if (np.abs(res_singular_values - self._output_sg_values) < 1e-6
+                ).any():
+                pass
+            else:
+                raise RuntimeError("Check LSTSQ singular values dygraph Failed")
+
+    def test_static(self):
+        paddle.enable_static()
+        paddle.device.set_device("cpu")
+        place = fluid.CPUPlace()
+        with fluid.program_guard(fluid.Program(), fluid.Program()):
+            x = paddle.fluid.data(
+                name="x",
+                shape=self._input_shape_1,
+                dtype=self._input_data_1.dtype)
+            y = paddle.fluid.data(
+                name="y",
+                shape=self._input_shape_2,
+                dtype=self._input_data_2.dtype)
+            results = paddle.linalg.lstsq(
+                x, y, rcond=self.rcond, driver=self.driver)
+            exe = fluid.Executor(place)
+            fetches = exe.run(
+                fluid.default_main_program(),
+                feed={"x": self._input_data_1,
+                      "y": self._input_data_2},
+                fetch_list=[results])
+
+            if x.shape[-2] > x.shape[-1] and self._output_rank == x.shape[-1]:
+                if (np.abs(fetches[1] - self._output_residuals) < 1e-6).any():
+                    pass
+                else:
+                    raise RuntimeError("Check LSTSQ residuals static Failed")
+
+            if self.driver in ("gelsy", "gelsd", "gelss"):
+                if (np.abs(fetches[2] - self._output_rank) < 1e-6).any():
+                    pass
+                else:
+                    raise RuntimeError("Check LSTSQ rank static Failed")
+
+            if self.driver in ("gelsd", "gelss"):
+                if (np.abs(fetches[3] - self._output_sg_values) < 1e-6).any():
+                    pass
+                else:
+                    raise RuntimeError(
+                        "Check LSTSQ singular values static Failed")
+
+
+class LinalgLstsqTestCaseGels(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float64'
+        self.rcond = 1e-15
+        self.driver = "gels"
+        self._input_shape_1 = (5, 10)
+        self._input_shape_2 = (5, 5)
+
+
+class LinalgLstsqTestCaseRcond(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float64'
+        self.rcond = 0.1
+        self.driver = "gels"
+        self._input_shape_1 = (3, 2)
+        self._input_shape_2 = (3, 3)
+
+
+class LinalgLstsqTestCaseGelsFloat32(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float32'
+        self.rcond = 1e-15
+        self.driver = "gels"
+        self._input_shape_1 = (10, 5)
+        self._input_shape_2 = (10, 2)
+
+
+class LinalgLstsqTestCaseGelssFloat64(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float64'
+        self.rcond = 1e-15
+        self.driver = "gelss"
+        self._input_shape_1 = (5, 5)
+        self._input_shape_2 = (5, 1)
+
+
+class LinalgLstsqTestCaseGelsyFloat32(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float32'
+        self.rcond = 1e-15
+        self.driver = "gelsy"
+        self._input_shape_1 = (8, 2)
+        self._input_shape_2 = (8, 10)
+
+
+class LinalgLstsqTestCaseBatch1(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float32'
+        self.rcond = 1e-15
+        self.driver = None
+        self._input_shape_1 = (2, 3, 10)
+        self._input_shape_2 = (2, 3, 4)
+
+
+class LinalgLstsqTestCaseBatch2(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float64'
+        self.rcond = 1e-15
+        self.driver = "gelss"
+        self._input_shape_1 = (2, 8, 6)
+        self._input_shape_2 = (2, 8, 2)
+
+
+class LinalgLstsqTestCaseLarge1(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float64'
+        self.rcond = 1e-15
+        self.driver = "gelsd"
+        self._input_shape_1 = (200, 100)
+        self._input_shape_2 = (200, 50)
+
+
+class LinalgLstsqTestCaseLarge2(LinalgLstsqTestCase):
+    def init_config(self):
+        self.dtype = 'float32'
+        self.rcond = 1e-15
+        self.driver = "gelss"
+        self._input_shape_1 = (50, 600)
+        self._input_shape_2 = (50, 300)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/linalg.py b/python/paddle/linalg.py
index 6b83448d0bf49a..c540d46d024978 100644
--- a/python/paddle/linalg.py
+++ b/python/paddle/linalg.py
@@ -31,6 +31,7 @@
 from .tensor.linalg import slogdet  # noqa: F401
 from .tensor.linalg import pinv  # noqa: F401
 from .tensor.linalg import triangular_solve  # noqa: F401
+from .tensor.linalg import lstsq
 
 __all__ = [
     'cholesky',  #noqa
@@ -52,4 +53,5 @@
     'solve',
     'cholesky_solve',
     'triangular_solve',
+    'lstsq'
 ]
diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py
index ce6c3e5350f6ab..b8575319058229 100755
--- a/python/paddle/tensor/__init__.py
+++ b/python/paddle/tensor/__init__.py
@@ -42,6 +42,7 @@
 from .linalg import norm  # noqa: F401
 from .linalg import cond  # noqa: F401
 from .linalg import transpose  # noqa: F401
+from .linalg import lstsq  # noqa: F401
 from .linalg import dist  # noqa: F401
 from .linalg import t  # noqa: F401
 from .linalg import cross  # noqa: F401
@@ -263,6 +264,7 @@
     'norm',
     'cond',
     'transpose',
+    'lstsq',
     'dist',
     't',
     'cross',
diff --git a/python/paddle/tensor/linalg.py b/python/paddle/tensor/linalg.py
index a8c565f336a35f..6c2af7e767c587 100644
--- a/python/paddle/tensor/linalg.py
+++ b/python/paddle/tensor/linalg.py
@@ -23,7 +23,6 @@
 from paddle.common_ops_import import core
 from paddle.common_ops_import import VarDesc
 from paddle import _C_ops
-import paddle
 
 __all__ = []
 
@@ -2503,3 +2502,107 @@ def __check_input(x, UPLO):
         attrs={'UPLO': UPLO,
                'is_test': is_test})
     return out_value
+
+
+def lstsq(x, y, rcond=1e-15, driver=None, name=None):
+    device = paddle.device.get_device()
+    if device == "cpu":
+        if driver not in (None, "gels", "gelss", "gelsd", "gelsy"):
+            raise ValueError(
+                "Only support valid driver is 'gels', 'gelss', 'gelsd', 'gelsy' or None for CPU inputs. But got {}".
+                format(driver))
+        driver = "gelsy" if driver is None else driver
+    elif "gpu" in device:
+        if driver not in (None, "gels"):
+            raise ValueError(
+                "Only support valid driver is 'gels' or None for CUDA inputs. But got {}".
+                format(driver))
+        driver = "gels" if driver is None else driver
+    else:
+        raise RuntimeError("Only support lstsq api for CPU or CUDA device.")
+
+    if in_dygraph_mode():
+        solution, rank, singular_values = _C_ops.lstsq(x, y, "rcond", rcond,
+                                                       "driver", driver)
+        if x.shape[-2] > x.shape[-1]:
+            matmul_out = _varbase_creator(dtype=x.dtype)
+            _C_ops.matmul(x, solution, matmul_out, 'trans_x', False, 'trans_y',
+                          False)
+            minus_out = _C_ops.elementwise_sub(matmul_out, y)
+            pow_out = _C_ops.pow(minus_out, 'factor', 2)
+            residuals = _C_ops.reduce_sum(pow_out, 'dim', [-2], 'keepdim',
+                                          False, 'reduce_all', False)
+        else:
+            residuals = paddle.empty(shape=[0], dtype=x.dtype)
+
+        if driver == "gels":
+            rank = paddle.empty(shape=[0], dtype=paddle.int32)
+            singular_values = paddle.empty(shape=[0], dtype=x.dtype)
+        elif driver == "gelsy":
+            singular_values = paddle.empty(shape=[0], dtype=x.dtype)
+
+        return solution, residuals, rank, singular_values
+
+    helper = LayerHelper('lstsq', **locals())
+    check_variable_and_dtype(
+        x, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], 'lstsq')
+    check_variable_and_dtype(
+        y, 'dtype', ['float32', 'float64', 'complex64', 'complex128'], 'lstsq')
+
+    solution = helper.create_variable_for_type_inference(dtype=x.dtype)
+    residuals = helper.create_variable_for_type_inference(dtype=x.dtype)
+    rank = helper.create_variable_for_type_inference(dtype=paddle.int32)
+    singular_values = helper.create_variable_for_type_inference(dtype=x.dtype)
+
+    helper.append_op(
+        type='lstsq',
+        inputs={'X': x,
+                'Y': y},
+        outputs={
+            'Solution': solution,
+            'Rank': rank,
+            'SingularValues': singular_values
+        },
+        attrs={'rcond': rcond,
+               'driver': driver})
+
+    matmul_out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    minus_out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    pow_out = helper.create_variable_for_type_inference(dtype=x.dtype)
+    helper.append_op(
+        type='matmul_v2',
+        inputs={'X': x,
+                'Y': solution},
+        outputs={'Out': matmul_out},
+        attrs={
+            'trans_x': False,
+            'trans_y': False,
+        })
+
+    helper.append_op(
+        type='elementwise_sub',
+        inputs={'X': matmul_out,
+                'Y': y},
+        outputs={'Out': minus_out})
+
+    helper.append_op(
+        type='pow',
+        inputs={'X': minus_out},
+        outputs={'Out': pow_out},
+        attrs={'factor': 2})
+
+    helper.append_op(
+        type='reduce_sum',
+        inputs={'X': pow_out},
+        outputs={'Out': residuals},
+        attrs={'dim': [-2],
+               'keep_dim': False,
+               'reduce_all': False})
+
+    if driver == "gels":
+        rank = paddle.static.data(name='rank', shape=[0])
+        singular_values = paddle.static.data(name='singular_values', shape=[0])
+    elif driver == "gelsy":
+        singular_values = paddle.static.data(name='singular_values', shape=[0])
+
+    return solution, residuals, rank, singular_values
diff --git a/tools/parallel_UT_rule.py b/tools/parallel_UT_rule.py
index 55d1dcf005b507..4df27bfe4e9238 100755
--- a/tools/parallel_UT_rule.py
+++ b/tools/parallel_UT_rule.py
@@ -1088,6 +1088,7 @@
     'test_fill_any_op',
     'test_frame_op',
     'test_linalg_pinv_op',
+    'test_linalg_lstsq_op',
     'test_gumbel_softmax_op',
     'test_matrix_power_op',
     'test_multi_dot_op',
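
For reference, here is a minimal usage sketch of the `paddle.linalg.lstsq` API introduced by this patch, mirroring what the unit tests above exercise. It is not part of the diff; it assumes a CPU build of Paddle with LAPACK available, and the shapes and tolerances are illustrative only.

```python
import numpy as np
import paddle

paddle.device.set_device("cpu")

# Overdetermined system: 5 equations, 4 unknowns, 3 right-hand sides.
x_np = np.random.random((5, 4)).astype("float64")
y_np = np.random.random((5, 3)).astype("float64")

x = paddle.to_tensor(x_np)
y = paddle.to_tensor(y_np)

# On CPU, driver may be "gels", "gelsy", "gelsd" or "gelss";
# passing None falls back to "gelsy".
solution, residuals, rank, singular_values = paddle.linalg.lstsq(
    x, y, rcond=1e-15, driver="gelsd")

# Cross-check the solution against NumPy's reference implementation.
expected = np.linalg.lstsq(x_np, y_np, rcond=1e-15)[0]
print(np.allclose(solution.numpy(), expected, atol=1e-6))
```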