Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CINN] Add InferSymbolicShape of fpn ops #63947

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions paddle/cinn/ir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,7 @@ const std::string &Store::name() const {
}

// A Store node's type is simply the type of the value being stored.
Type Store::type() const { return value.type(); }

std::vector<Expr *> Store::expr_fields() {
std::vector<Expr *> exprs({&tensor, &value});
for (auto &idx : indices) exprs.push_back(&idx);
Expand Down Expand Up @@ -776,6 +777,9 @@ Expr Reduce::Make(Reduce::ReduceType reduce_type,
n->set_type(body.type());
return Expr(n);
}

// A Reduce node's type is the scalar element type of its body expression.
Type Reduce::type() const { return body.type().ElementOf(); }

std::vector<Expr *> Reduce::expr_fields() {
std::vector<Expr *> res;
if (init.defined()) {
Expand All @@ -802,6 +806,12 @@ void Reduce::Verify() const {
CHECK_EQ(init.type(), body.type());
}

// The type of a Select node is the common type of its two branches; the
// branches are required to agree.
Type Select::type() const {
  PADDLE_ENFORCE_EQ(
      true_value.type(),
      false_value.type(),
      // Use a typed error instead of a bare string message, matching the
      // ::common::errors style used by the rest of the codebase.
      ::common::errors::InvalidArgument("Type of Select must be same"));
  return true_value.type();
}

void Select::Verify() const {
CHECK(condition.defined());
CHECK(true_value.defined());
Expand Down Expand Up @@ -862,12 +872,16 @@ void MultiOperandVerify(llvm::ArrayRef<Expr> operands) {
}
}

// A Product node takes its type from the first operand; Verify() enforces the
// operand invariants, so the operands are expected to be type-consistent.
Type Product::type() const { return operands().front().type(); }

void Product::Verify() const {
CHECK_GT(operands().size(), 1UL)
<< "Product node should have more than 1 operands";
MultiOperandVerify(operands());
}

// A Sum node takes its type from the first operand; Verify() enforces the
// operand invariants, so the operands are expected to be type-consistent.
Type Sum::type() const { return operands().front().type(); }

void Sum::Verify() const {
CHECK_GT(operands().size(), 1UL)
<< "Sum node should have more than 1 operands";
Expand Down
12 changes: 5 additions & 7 deletions paddle/cinn/ir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -476,7 +476,7 @@ struct Reduce : public ExprNode<Reduce> {
Expr body,
const std::vector<Var>& reduce_axis);

Type type() const override { return body.type().ElementOf(); }
Type type() const override;

std::vector<Expr*> expr_fields() override;
std::vector<const Expr*> expr_fields() const override;
Expand Down Expand Up @@ -509,10 +509,7 @@ struct Select : public ExprNode<Select> {
return Expr(node);
}

Type type() const override {
CHECK_EQ(true_value.type(), false_value.type());
return true_value.type();
}
Type type() const override;

void Verify() const override;

Expand Down Expand Up @@ -573,6 +570,7 @@ struct Store : public ExprNode<Store>, public LoadStoreAddrMnger {
const std::string& name() const;

Type type() const override;

Expr index() const;

static const IrNodeTy _node_type_ = IrNodeTy::Store;
Expand Down Expand Up @@ -932,7 +930,7 @@ struct Product : public ExprNode<Product> {

using ExprNode<Product>::operand;

Type type() const override { return operands().front().type(); }
Type type() const override;

void Verify() const override;

Expand All @@ -944,7 +942,7 @@ struct Sum : public ExprNode<Sum> {

using ExprNode<Sum>::operand;

Type type() const override { return operands().front().type(); }
Type type() const override;

void Verify() const override;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -504,6 +504,25 @@ bool MemoryEfficientAttentionOpInferSymbolicShape(
return true;
}

// Infers the symbolic output shape of roi_align:
//   result(0) = [num_boxes, channels, pooled_height, pooled_width]
// where num_boxes comes from the boxes input (operand 1) and channels from
// the feature-map input (operand 0).
bool RoiAlignOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  const auto &feature_map = op->operand_source(0);
  const auto &rois = op->operand_source(1);

  const symbol::DimExpr &box_count =
      infer_context->GetShapeOrDataForValue(rois).shape()[0];
  // NOTE(review): channels are read from dim 1, which assumes an NCHW
  // feature-map layout — confirm against the op definition.
  symbol::DimExpr channels =
      infer_context->GetShapeOrDataForValue(feature_map).shape()[1];

  const int32_t pooled_h =
      op->attribute<pir::Int32Attribute>("pooled_height").data();
  const int32_t pooled_w =
      op->attribute<pir::Int32Attribute>("pooled_width").data();

  std::vector<symbol::DimExpr> out_shape = {
      box_count, channels, pooled_h, pooled_w};
  infer_context->SetShapeOrDataForValue(
      op->result(0), symbol::TensorShapeOrDataDimExprs(out_shape));
  return true;
}

bool MeshgridOpInferSymbolicShape(
pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
const symbol::TensorListShapeOrDataDimExprs &shape_data_list =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logspace)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(MemoryEfficientAttention)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Meshgrid)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(NearestInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(RoiAlign)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Stack)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(TrilinearInterp)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Where)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,84 @@ bool DiagonalOpInferSymbolicShape(
return true;
}

// Infers symbolic output shapes for distribute_fpn_proposals:
//   result(0): multi_rois           — per FPN level, a [rois_in_level, 4] tensor
//   result(1): multi_level_rois_num — per FPN level, a [batch_size] tensor
//   result(2): restore_ind          — [total_rois, 1]
bool DistributeFpnProposalsOpInferSymbolicShape(
    pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) {
  const auto &attributes = op->attributes();
  int32_t min_level =
      attributes.at("min_level").dyn_cast<pir::Int32Attribute>().data();
  int32_t max_level =
      attributes.at("max_level").dyn_cast<pir::Int32Attribute>().data();
  int32_t num_levels = max_level - min_level + 1;
  int64_t batch_size = 1;

  // Total number of input rois: the sum of the per-image counts held in the
  // rois_num input (operand 1).
  symbol::DimExpr num_rois = [&]() {
    pir::Value rois_num = op->operand_source(1);
    const auto &rois_num_shape_or_data =
        infer_context->GetShapeOrDataForValue(rois_num);

    // NOTE(review): assumes the batch dimension of rois_num is statically
    // known; Get<int64_t> cannot handle a symbolic dim — confirm upstream.
    batch_size = rois_num_shape_or_data.shape()[0].Get<int64_t>();
    PADDLE_ENFORCE_EQ(rois_num_shape_or_data.data().has_value(),
                      true,
                      ::common::errors::InvalidArgument(
                          "InferSymbolicShape of DistributeFpnProposalsOp "
                          "only support input with rois_num."));

    symbol::DimExpr rois_total_num = 0;
    for (int i = 0; i < batch_size; i++) {
      const auto &rois_num_value = rois_num_shape_or_data.data().value()[i];
      CHECK(rois_num_value.isa<int64_t>() || rois_num_value.isa<std::string>())
          << "rois_num must be int64 or SymName.";
      // BUGFIX: the original returned unconditionally on the first element
      // (both branches of its if/else returned), so this accumulation was
      // unreachable and num_rois covered only the first image. Sum the
      // whole batch instead.
      rois_total_num = rois_total_num + rois_num_value;
    }
    return rois_total_num;
  }();

  // Per-level rois shapes: every level but the last gets a fresh symbol; the
  // last level takes whatever remains of num_rois.
  const auto &multi_rois_out_shape = [&]() {
    symbol::TensorListShapeOrDataDimExprs multi_rois_out_shape;
    if (num_levels == 1) {
      multi_rois_out_shape.emplace_back(
          symbol::TensorShapeOrDataDimExprs({num_rois, 4}));
    } else {
      symbol::DimExpr last_dim = num_rois;
      for (int i = 0; i < num_levels - 1; i++) {
        const auto &next_sym_name = infer_context->GetNextSymName();
        std::vector<symbol::DimExpr> level_dim = {next_sym_name, 4};
        multi_rois_out_shape.emplace_back(
            symbol::TensorShapeOrDataDimExprs(level_dim));
        last_dim = last_dim - level_dim[0];
      }
      multi_rois_out_shape.emplace_back(
          symbol::TensorShapeOrDataDimExprs({last_dim, 4}));
    }

    return multi_rois_out_shape;
  }();

  // Each level reports one rois count per image in the batch.
  const auto &rois_num_per_level_out_shape = [&]() {
    symbol::TensorListShapeOrDataDimExprs rois_num_per_level_out_shape;
    rois_num_per_level_out_shape.resize(
        num_levels, symbol::TensorShapeOrDataDimExprs({batch_size}));
    return rois_num_per_level_out_shape;
  }();

  infer_context->SetShapeOrDataForValue(op->result(0), multi_rois_out_shape);
  infer_context->SetShapeOrDataForValue(op->result(1),
                                        rois_num_per_level_out_shape);
  // restore_ind maps every concatenated roi back to its original position.
  infer_context->SetShapeOrDataForValue(
      op->result(2), symbol::TensorShapeOrDataDimExprs({num_rois, 1}));
  return true;
}

bool EinsumOpInferSymbolicShape(pir::Operation *op,
pir::InferSymbolicShapeContext *infer_context) {
PADDLE_THROW(phi::errors::Unimplemented(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumsum)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Cumsum_)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(DiagEmbed)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Diagonal)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(DistributeFpnProposals)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Einsum)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Kthvalue)
OP_DECLARE_INFER_SYMBOLIC_SHAPE(Logcumsumexp)
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/dialect/operator/ir/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,7 @@
func : distribute_fpn_proposals
data_type : fpn_rois
optional : rois_num, multi_level_rois_num
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : distributed_fused_lamb
args : (Tensor[] param, Tensor[] grad, Tensor fp32_fused_param, Tensor fp32_fused_grad, Tensor fp16_fused_param, Tensor fp16_fused_grad, Tensor moment1, Tensor moment2, Tensor beta1pow, Tensor beta2pow, Tensor fused_param_offsets, Tensor fp32_shard_fused_param_offsets, Tensor fp16_shard_fused_param_offsets, Tensor param_info, Tensor param_order, Tensor learning_rate, Tensor global_scale, float beta1, float beta2, float epsilon, float max_global_grad_norm, float weight_decay, bool clip_after_allreduce, int[] ring_ids= {}, int acc_steps = 1, bool use_master_param_norm = true, bool use_master_acc_grad = true, bool is_grad_scaled_by_nranks = true, int64_t nranks = 1, bool use_hierarchical_allreduce = false)
Expand Down
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2500,6 +2500,7 @@
data_type : x
optional : boxes_num
backward : roi_align_grad
interfaces : paddle::dialect::InferSymbolicShapeInterface

- op : roi_pool
args : (Tensor x, Tensor boxes, Tensor boxes_num, int pooled_height=1, int pooled_width=1, float spatial_scale=1.0)
Expand Down