Add cpu kernel of new api : lstsq #38585

Merged (5 commits, Dec 30, 2021)
142 changes: 142 additions & 0 deletions paddle/fluid/operators/lstsq_op.cc
@@ -0,0 +1,142 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/lstsq_op.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

class LstsqOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "LstsqOp");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "LstsqOp");

OP_INOUT_CHECK(ctx->HasOutput("Solution"), "Output", "Solution", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("Rank"), "Output", "Rank", "LstsqOp");
OP_INOUT_CHECK(ctx->HasOutput("SingularValues"), "Output", "SingularValues",
"LstsqOp");

auto x_dims = ctx->GetInputDim("X");
auto y_dims = ctx->GetInputDim("Y");
int x_rank = x_dims.size();
int y_rank = y_dims.size();

PADDLE_ENFORCE_GE(x_rank, 2,
platform::errors::InvalidArgument(
"Expects input tensor x to have at least "
"2 dimensions, but got dimension %d",
x_rank));
PADDLE_ENFORCE_GE(y_rank, 2,
platform::errors::InvalidArgument(
"Expects input tensor y to have at least "
"2 dimensions, but got dimension %d",
y_rank));

PADDLE_ENFORCE_EQ(
x_rank, y_rank,
platform::errors::InvalidArgument(
"Expects input tensors x and y to have the same number of dimensions, "
"but got x's dimension [%d] and y's dimension [%d]",
x_rank, y_rank));

std::vector<int> batch_dims_vec{};
for (int i = 0; i < x_rank - 2; ++i) {
PADDLE_ENFORCE_EQ(
x_dims[i], y_dims[i],
platform::errors::InvalidArgument(
"Expects input tensors x and y to have the same batch "
"dimensions, but got x's batch dimension [%d] and "
"y's batch dimension [%d] in the %d-th dim",
x_dims[i], y_dims[i], i));
batch_dims_vec.emplace_back(x_dims[i]);
}

PADDLE_ENFORCE_EQ(
x_dims[x_rank - 2], y_dims[y_rank - 2],
platform::errors::InvalidArgument(
"Expects input tensors x and y to have the same row dimension "
"in the innermost 2-D matrices, "
"but got x's row dimension [%d] and y's row dimension [%d]",
x_dims[x_rank - 2], y_dims[y_rank - 2]));

ctx->SetOutputDim("Rank", framework::make_ddim(batch_dims_vec));

batch_dims_vec.emplace_back(
std::min(x_dims[x_rank - 2], x_dims[x_rank - 1]));
ctx->SetOutputDim("SingularValues", framework::make_ddim(batch_dims_vec));

batch_dims_vec[x_rank - 2] = x_dims[x_rank - 1];
batch_dims_vec.emplace_back(y_dims[x_rank - 1]);
ctx->SetOutputDim("Solution", framework::make_ddim(batch_dims_vec));
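// e.g. X: [2, 5, 3] and Y: [2, 5, 4] (batch = 2, m = 5, n = 3, k = 4) give
// Rank: [2], SingularValues: [2, 3] and Solution: [2, 3, 4]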
}

protected:
// Only float32 and float64 inputs are supported; the kernel dtype follows input X
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
auto dtype = OperatorWithKernel::IndicateVarDataType(ctx, "X");
if (dtype != framework::proto::VarType::FP32 &&
dtype != framework::proto::VarType::FP64) {
PADDLE_THROW(platform::errors::InvalidArgument(
"unsupported data type: %s!", dtype));
}
return framework::OpKernelType(dtype, ctx.GetPlace());
}
};

class LstsqOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor), A real-valued tensor with shape (*, m, n). "
"The accepted datatype is one of float32, float64");
AddInput("Y",
"(Tensor), A real-valued tensor with shape (*, m, k). "
"The accepted datatype is one of float32, float64");
AddAttr<float>(
"rcond",
"(float, default 0.0), A float value used to determine the effective "
"rank of the input matrix X.")
.SetDefault(0.0f);
AddAttr<std::string>("driver",
"(string, default \"gels\"), Name of the LAPACK driver to be used; "
"one of \"gels\", \"gelsy\", \"gelsd\" or \"gelss\".")
.SetDefault("gels");
AddOutput("Solution",
"(Tensor), The output Solution tensor with shape (*, n, k).");
AddOutput("Rank", "(Tensor), The output Rank tensor with shape (*).");
AddOutput(
"SingularValues",
"(Tensor), The output SingularValues tensor with shape (*, min(m,n)).");
AddComment(R"DOC(
Lstsq Operator.
Computes a least-squares solution of the linear system X * Solution = Y for general matrices.
)DOC");
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(lstsq, ops::LstsqOp, ops::LstsqOpMaker)

REGISTER_OP_CPU_KERNEL(
lstsq, ops::LstsqCPUKernel<paddle::platform::CPUDeviceContext, float>,
ops::LstsqCPUKernel<paddle::platform::CPUDeviceContext, double>);
229 changes: 229 additions & 0 deletions paddle/fluid/operators/lstsq_op.h
@@ -0,0 +1,229 @@
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <math.h>
#include <algorithm>
#include <complex>
#include "paddle/fluid/operators/eig_op.h"
#include "paddle/fluid/operators/math/complex_functors.h"
#include "paddle/fluid/operators/math/eigen_values_vectors.h"
#include "paddle/fluid/operators/math/lapack_function.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/math/matrix_solve.h"
#include "paddle/fluid/operators/svd_helper.h"
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/operators/triangular_solve_op.h"
#include "paddle/fluid/platform/for_range.h"

#define EPSILON 1e-6

namespace paddle {
namespace operators {

using paddle::framework::Tensor;
enum class LapackDriverType : int { Gels, Gelsd, Gelsy, Gelss };

using DDim = framework::DDim;
static DDim UDDim(const DDim& x_dim) {
auto x_vec = vectorize(x_dim);
return framework::make_ddim(x_vec);
}

template <typename DeviceContext, typename T>
class LstsqCPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
using ValueType = math::Real<T>;

const Tensor& x = *context.Input<Tensor>("X");
const Tensor& y = *context.Input<Tensor>("Y");
auto rcond = context.Attr<float>("rcond");
auto driver_string = context.Attr<std::string>("driver");

static auto driver_type = std::unordered_map<std::string, LapackDriverType>(
{{"gels", LapackDriverType::Gels},
{"gelsy", LapackDriverType::Gelsy},
{"gelsd", LapackDriverType::Gelsd},
{"gelss", LapackDriverType::Gelss}});
auto driver = driver_type[driver_string];

auto solution = context.Output<Tensor>("Solution");
auto* rank = context.Output<Tensor>("Rank");
auto* singular_values = context.Output<Tensor>("SingularValues");

auto dito =
math::DeviceIndependenceTensorOperations<DeviceContext, T>(context);

auto x_dims = x.dims();
auto y_dims = y.dims();
int dim_size = x_dims.size();
int x_stride = MatrixStride(x);
int y_stride = MatrixStride(y);
int batch_count = BatchCount(x);
auto ori_solution_dim = solution->dims();
int ori_solu_stride = MatrixStride(*solution);

// LAPACK uses column-major storage; transposing gives the inputs a
// contiguous layout in that order
int info = 0;
int m = x_dims[dim_size - 2];
int n = x_dims[dim_size - 1];
int nrhs = y_dims[dim_size - 1];
int lda = std::max<int>(m, 1);
int ldb = std::max<int>(1, std::max(m, n));
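// LAPACK overwrites the right-hand-side buffer with the solution in place,
// so the buffer (and ldb) must have room for max(m, n) rows per matrix;
// the Solution tensor below is allocated with that padded size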

Tensor new_x;
new_x.mutable_data<T>(context.GetPlace(),
size_t(batch_count * m * n * sizeof(T)));
solution->mutable_data<T>(
context.GetPlace(),
size_t(batch_count * std::max(m, n) * nrhs * sizeof(T)));
framework::TensorCopy(x, context.GetPlace(), &new_x);
framework::TensorCopy(y, context.GetPlace(), solution);

if (m < n) solution->Resize(UDDim(ori_solution_dim));
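// when m < n, max(m, n) == n, so the padded buffer already matches the
// final (*, n, k) shape and can take it here; for m >= n the Resize to the
// original shape happens after the solve at the end of Compute()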

Tensor input_x_trans = dito.Transpose(new_x);
Tensor input_y_trans = dito.Transpose(*solution);
framework::TensorCopy(input_x_trans, new_x.place(), &new_x);
framework::TensorCopy(input_y_trans, solution->place(), solution);

auto* x_vector = new_x.data<T>();
auto* y_vector = solution->data<T>();

// the "gels" driver does not need to compute rank
int rank_32 = 0;
int* rank_data = nullptr;
int* rank_working_ptr = nullptr;
if (driver != LapackDriverType::Gels) {
rank_data = rank->mutable_data<int>(context.GetPlace());
rank_working_ptr = rank_data;
}

// the "gelsd" and "gelss" drivers need to compute singular values
ValueType* s_data = nullptr;
ValueType* s_working_ptr = nullptr;
int s_stride = 0;
if (driver == LapackDriverType::Gelsd ||
driver == LapackDriverType::Gelss) {
s_data = singular_values->mutable_data<ValueType>(context.GetPlace());
s_working_ptr = s_data;
auto s_dims = singular_values->dims();
s_stride = s_dims[s_dims.size() - 1];
}

// "jpvt" is only used by the "gelsy" driver
Tensor jpvt;
int* jpvt_data = nullptr;
if (driver == LapackDriverType::Gelsy) {
jpvt.Resize(framework::make_ddim({std::max<int>(1, n)}));
jpvt_data = jpvt.mutable_data<int>(context.GetPlace());
}

// run the driver once first to query the optimal workspace size
int lwork = -1;
T wkopt;
ValueType rwkopt;
int iwkopt = 0;
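// passing lwork = -1 makes LAPACK perform a workspace query only: the
// optimal sizes are written to wkopt (and rwkopt / iwkopt where applicable)
// without performing the solve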

if (driver == LapackDriverType::Gels) {
math::lapackGels('N', m, n, nrhs, x_vector, lda, y_vector, ldb, &wkopt,
lwork, &info);
} else if (driver == LapackDriverType::Gelsd) {
math::lapackGelsd(m, n, nrhs, x_vector, lda, y_vector, ldb, s_working_ptr,
static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
&rwkopt, &iwkopt, &info);
} else if (driver == LapackDriverType::Gelsy) {
math::lapackGelsy(m, n, nrhs, x_vector, lda, y_vector, ldb, jpvt_data,
static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
&rwkopt, &info);
} else if (driver == LapackDriverType::Gelss) {
math::lapackGelss(m, n, nrhs, x_vector, lda, y_vector, ldb, s_working_ptr,
static_cast<ValueType>(rcond), &rank_32, &wkopt, lwork,
&rwkopt, &info);
}

lwork = std::max<int>(1, static_cast<int>(math::Real<T>(wkopt)));
Tensor work;
work.Resize(framework::make_ddim({lwork}));
T* work_data = work.mutable_data<T>(context.GetPlace());

// "rwork" is only used for complex inputs with the "gelsy/gelsd/gelss" drivers
Tensor rwork;
ValueType* rwork_data = nullptr;
if (framework::IsComplexType(x.type()) &&
driver != LapackDriverType::Gels) {
int rwork_len = 0;
if (driver == LapackDriverType::Gelsy) {
rwork_len = std::max<int>(1, 2 * n);
} else if (driver == LapackDriverType::Gelss) {
rwork_len = std::max<int>(1, 5 * std::min(m, n));
} else if (driver == LapackDriverType::Gelsd) {
rwork_len = std::max<int>(1, rwkopt);
}
rwork.Resize(framework::make_ddim({rwork_len}));
rwork_data = rwork.mutable_data<ValueType>(context.GetPlace());
}

// the "iwork" workspace array is relevant only for the "gelsd" driver
Tensor iwork;
int* iwork_data = nullptr;
if (driver == LapackDriverType::Gelsd) {
iwork.Resize(framework::make_ddim({std::max<int>(1, iwkopt)}));
iwork_data = iwork.mutable_data<int>(context.GetPlace());
}

int solu_stride = std::max(y_stride, ori_solu_stride);
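// solve each matrix in the batch independently, advancing the input and the
// RHS/solution buffers by their per-matrix strides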
for (auto i = 0; i < batch_count; ++i) {
auto* x_input = &x_vector[i * x_stride];
auto* y_input = &y_vector[i * solu_stride];
rank_working_ptr = rank_working_ptr ? &rank_data[i] : nullptr;
s_working_ptr = s_working_ptr ? &s_data[i * s_stride] : nullptr;

if (driver == LapackDriverType::Gels) {
math::lapackGels('N', m, n, nrhs, x_input, lda, y_input, ldb, work_data,
lwork, &info);
} else if (driver == LapackDriverType::Gelsd) {
math::lapackGelsd(m, n, nrhs, x_input, lda, y_input, ldb, s_working_ptr,
static_cast<ValueType>(rcond), &rank_32, work_data,
lwork, rwork_data, iwork_data, &info);
} else if (driver == LapackDriverType::Gelsy) {
math::lapackGelsy(m, n, nrhs, x_input, lda, y_input, ldb, jpvt_data,
static_cast<ValueType>(rcond), &rank_32, work_data,
lwork, rwork_data, &info);
} else if (driver == LapackDriverType::Gelss) {
math::lapackGelss(m, n, nrhs, x_input, lda, y_input, ldb, s_working_ptr,
static_cast<ValueType>(rcond), &rank_32, work_data,
lwork, rwork_data, &info);
}

PADDLE_ENFORCE_EQ(
info, 0,
platform::errors::PreconditionNotMet(
"For batch [%d]: Lapack info is not zero but [%d]", i, info));

if (rank_working_ptr) *rank_working_ptr = static_cast<int>(rank_32);
}

Tensor tmp_s = dito.Transpose(*solution);
framework::TensorCopy(tmp_s, solution->place(), solution);

if (m >= n) solution->Resize(UDDim(ori_solution_dim));
}
};

} // namespace operators
} // namespace paddle
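
For reference, here is a minimal standalone sketch of the two-phase LAPACK call pattern the kernel relies on: a workspace query with lwork = -1 followed by the actual solve. It calls the raw dgels_ Fortran symbol directly (so it must be linked against a LAPACK implementation) rather than the math::lapackGels wrappers used above, and the sample matrix is purely illustrative.

#include <algorithm>
#include <vector>

// raw Fortran LAPACK symbol for double-precision least squares
extern "C" void dgels_(const char* trans, const int* m, const int* n,
                       const int* nrhs, double* a, const int* lda, double* b,
                       const int* ldb, double* work, const int* lwork,
                       int* info);

int main() {
  // solve min ||A * x - b|| for a 3x2 column-major A and one right-hand side
  int m = 3, n = 2, nrhs = 1;
  int lda = m;
  int ldb = std::max(m, n);  // the RHS buffer must hold max(m, n) rows
  std::vector<double> a = {1, 1, 1,   // column 0
                           1, 2, 3};  // column 1
  std::vector<double> b = {2, 4, 6};  // overwritten with the solution

  int info = 0;
  int lwork = -1;
  double wkopt = 0.0;
  // workspace query: with lwork == -1, dgels_ only reports the optimal
  // workspace size in wkopt and does not solve anything
  dgels_("N", &m, &n, &nrhs, a.data(), &lda, b.data(), &ldb, &wkopt, &lwork,
         &info);

  lwork = std::max(1, static_cast<int>(wkopt));
  std::vector<double> work(lwork);
  // actual solve: the first n entries of b now hold the least-squares
  // solution (here x = [0, 2])
  dgels_("N", &m, &n, &nrhs, a.data(), &lda, b.data(), &ldb, work.data(),
         &lwork, &info);
  return info;
}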