From 43f3b2d962e9ba0976cff98277b513dd893fec18 Mon Sep 17 00:00:00 2001 From: yulangz <1301481108@qq.com> Date: Sun, 28 Apr 2024 09:06:43 +0000 Subject: [PATCH 1/4] Add DistributeFpnProposalsOpInferSymbolicShape; Add RoiAlignOpInferSymbolicShape; fix GatherOpInferSymbolicShape when axis input is tensor; fix Expr set_type implementation of derived class; --- paddle/cinn/ir/ir.cc | 57 ++++++++++++++ paddle/cinn/ir/ir.h | 26 +++++-- paddle/cinn/ir/ir_base.cc | 7 +- paddle/cinn/ir/ir_base.h | 2 +- .../infer_symbolic_shape/binary_infer_sym.cc | 13 +++- .../multiary_infer_sym.cc | 19 +++++ .../infer_symbolic_shape/multiary_infer_sym.h | 1 + .../infer_symbolic_shape/unary_infer_sym.cc | 75 +++++++++++++++++++ .../infer_symbolic_shape/unary_infer_sym.h | 1 + paddle/fluid/pir/dialect/operator/ir/ops.yaml | 1 + paddle/phi/api/yaml/ops.yaml | 1 + 11 files changed, 188 insertions(+), 15 deletions(-) diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc index a121806e6f3bf8..8ff16c37faa45c 100644 --- a/paddle/cinn/ir/ir.cc +++ b/paddle/cinn/ir/ir.cc @@ -209,6 +209,11 @@ void Let::Verify() const { Type Let::type() const { return symbol.type(); } +void Let::set_type(Type t) { + IrNode::set_type(t); + symbol->set_type(t); +} + Expr _Var_::Make(const std::string &name, const Type &type) { auto node = new _Var_(name, type); return Expr(node); @@ -405,6 +410,11 @@ const std::string &Store::name() const { } Type Store::type() const { return value.type(); } + +void Store::set_type(Type t) { + IrNode::set_type(t); + value->set_type(t); +} std::vector Store::expr_fields() { std::vector exprs({&tensor, &value}); for (auto &idx : indices) exprs.push_back(&idx); @@ -610,6 +620,12 @@ Type Load::type() const { return type; } +void Load::set_type(Type t) { + CHECK(tensor.defined()); + IrNode::set_type(t); + tensor->set_type(t); +} + std::vector Load::expr_fields() { std::vector exprs({&tensor}); for (auto &idx : indices) exprs.push_back(&idx); @@ -696,6 +712,11 @@ Type Broadcast::type() const { return value.type().ElementOf().with_lanes(lanes); } +void Broadcast::set_type(Type t) { + IrNode::set_type(t); + value->set_type(t); +} + Expr Sum::Make(const std::vector &vs) { CHECK(!vs.empty()); if (vs.size() == 1) return vs.front(); @@ -772,6 +793,14 @@ Expr Reduce::Make(Reduce::ReduceType reduce_type, n->set_type(body.type()); return Expr(n); } + +Type Reduce::type() const { return body.type().ElementOf(); } + +void Reduce::set_type(Type t) { + IrNode::set_type(t); + body->set_type(t); +} + std::vector Reduce::expr_fields() { std::vector res; if (init.defined()) { @@ -798,6 +827,16 @@ void Reduce::Verify() const { CHECK_EQ(init.type(), body.type()); } +Type Select::type() const { + CHECK_EQ(true_value.type(), false_value.type()); + return true_value.type(); +} + +void Select::set_type(Type t) { + IrNode::set_type(t); + true_value->set_type(t); +} + void Select::Verify() const { CHECK(condition.defined()); CHECK(true_value.defined()); @@ -858,12 +897,30 @@ void MultiOperandVerify(llvm::ArrayRef operands) { } } +Type Product::type() const { return operands().front().type(); } + +void Product::set_type(Type t) { + IrNode::set_type(t); + for (auto &operand : operands()) { + operand->set_type(t); + } +} + void Product::Verify() const { CHECK_GT(operands().size(), 1UL) << "Product node should have more than 1 operands"; MultiOperandVerify(operands()); } +Type Sum::type() const { return operands().front().type(); } + +void Sum::set_type(Type t) { + IrNode::set_type(t); + for (auto &operand : operands()) { + operand->set_type(t); + } +} + void Sum::Verify() const { CHECK_GT(operands().size(), 1UL) << "Sum node should have more than 1 operands"; diff --git a/paddle/cinn/ir/ir.h b/paddle/cinn/ir/ir.h index d711e93ce61abb..a3e42a5a4d9a5d 100644 --- a/paddle/cinn/ir/ir.h +++ b/paddle/cinn/ir/ir.h @@ -305,6 +305,8 @@ struct Let : public ExprNode { Type type() const override; + void set_type(Type t) override; + void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Let; @@ -476,7 +478,9 @@ struct Reduce : public ExprNode { Expr body, const std::vector& reduce_axis); - Type type() const override { return body.type().ElementOf(); } + Type type() const override; + + void set_type(Type t) override; std::vector expr_fields() override; std::vector expr_fields() const override; @@ -509,10 +513,9 @@ struct Select : public ExprNode { Type type() const override; - void set_type(Type t) override; - void Verify() const override; std::vector expr_fields() override { @@ -556,8 +550,6 @@ struct Load : public ExprNode, public LoadStoreAddrMnger { Type type() const override; - void set_type(Type t) override; - static const IrNodeTy _node_type_ = IrNodeTy::Load; }; @@ -578,7 +570,7 @@ struct Store : public ExprNode, public LoadStoreAddrMnger { const std::string& name() const; Type type() const override; - void set_type(Type type) override; + Expr index() const; static const IrNodeTy _node_type_ = IrNodeTy::Store; @@ -905,8 +897,6 @@ struct Broadcast : public ExprNode { Type type() const override; - void set_type(Type type) override; - void Verify() const override; std::vector expr_fields() override { return {&value}; } @@ -942,8 +932,6 @@ struct Product : public ExprNode { Type type() const override; - void set_type(Type t) override; - void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Product; @@ -956,8 +944,6 @@ struct Sum : public ExprNode { Type type() const override; - void set_type(Type t) override; - void Verify() const override; static const IrNodeTy _node_type_ = IrNodeTy::Sum; diff --git a/paddle/cinn/ir/ir_base.cc b/paddle/cinn/ir/ir_base.cc index da67b4e6e14c65..c1b0580d16562e 100644 --- a/paddle/cinn/ir/ir_base.cc +++ b/paddle/cinn/ir/ir_base.cc @@ -239,11 +239,10 @@ const Expr &IrNode::operand(int i) { void IrNode::set_type(Type type) { type_ = type; } void IrNode::convert_int32_to_int64() { - common::Type node_type = type(); - CHECK(node_type == Int(64) || node_type == Int(32) || node_type.is_unk()) + CHECK(type_ == Int(64) || type_ == Int(32) || type_.is_unk()) << "Current only support convert int32_t to int64_t, but get type is " - << node_type; - set_type(Int(64)); + << type_; + type_ = Int(64); for (Expr &operand : operands) { operand->convert_int32_to_int64(); } diff --git a/paddle/cinn/ir/ir_base.h b/paddle/cinn/ir/ir_base.h index e691038d3c83dd..236e8afb67fe86 100644 --- a/paddle/cinn/ir/ir_base.h +++ b/paddle/cinn/ir/ir_base.h @@ -162,7 +162,7 @@ class IrNode : public cinn::common::Object { virtual IrNodeTy node_type() const { return IrNodeTy::kUnk; } virtual Type type() const { return type_; } - virtual void set_type(Type t); + void set_type(Type type); //! Elevate int32 to int64 if needed void convert_int32_to_int64(); diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 4604f262669340..d30cbdb735c396 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -312,7 +312,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape( }(); const auto &restore_ind = [&]() { - return symbol::TensorShapeOrDataDimExprs({num_levels, 1}); + return symbol::TensorShapeOrDataDimExprs({num_rois, 1}); }(); shape_analysis->SetShapeOrDataForValue(op->result(0), multi_rois_out_shape); From 1f5dfc9d7501c0225e021182c2bea98b48f18b02 Mon Sep 17 00:00:00 2001 From: yulangz <1301481108@qq.com> Date: Wed, 8 May 2024 06:55:16 +0000 Subject: [PATCH 3/4] change shape_analysis to infer_context --- .../infer_symbolic_shape/multiary_infer_sym.cc | 8 ++++---- .../infer_symbolic_shape/unary_infer_sym.cc | 18 +++++++++--------- 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc index 170ad3d15f00ac..4a8b0c1000cc0b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/multiary_infer_sym.cc @@ -505,20 +505,20 @@ bool MemoryEfficientAttentionOpInferSymbolicShape( } bool RoiAlignOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &x = op->operand_source(0); const auto &boxes = op->operand_source(1); const auto &num_boxes = - shape_analysis->GetShapeOrDataForValue(boxes).shape()[0]; + infer_context->GetShapeOrDataForValue(boxes).shape()[0]; symbol::DimExpr channel_num = - shape_analysis->GetShapeOrDataForValue(x).shape()[1]; + infer_context->GetShapeOrDataForValue(x).shape()[1]; int32_t out_h = op->attribute("pooled_height").data(); int32_t out_w = op->attribute("pooled_width").data(); std::vector out_dim = {num_boxes, channel_num, out_h, out_w}; - shape_analysis->SetShapeOrDataForValue( + infer_context->SetShapeOrDataForValue( op->result(0), symbol::TensorShapeOrDataDimExprs(out_dim)); return true; } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index e5da9be8add0a8..90f46d0c6940ac 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -247,7 +247,7 @@ bool DiagonalOpInferSymbolicShape( } bool DistributeFpnProposalsOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { + pir::Operation *op, pir::InferSymbolicShapeContext *infer_context) { const auto &attributes = op->attributes(); int32_t min_level = attributes.at("min_level").dyn_cast().data(); @@ -258,7 +258,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape( symbol::DimExpr num_rois = [&]() { pir::Value rois_num = op->operand_source(1); const auto &rois_num_shape_or_data = - shape_analysis->GetShapeOrDataForValue(rois_num); + infer_context->GetShapeOrDataForValue(rois_num); PADDLE_ENFORCE_EQ( rois_num_shape_or_data.shape()[0], @@ -290,7 +290,7 @@ bool DistributeFpnProposalsOpInferSymbolicShape( } else { symbol::DimExpr last_dim = num_rois; for (int i = 0; i < num_levels - 1; i++) { - const auto &next_sym_name = shape_analysis->GetNextSymName(); + const auto &next_sym_name = infer_context->GetNextSymName(); std::vector level_dim = {next_sym_name, 4}; multi_rois_out_shape.emplace_back( symbol::TensorShapeOrDataDimExprs(level_dim)); @@ -314,15 +314,15 @@ bool DistributeFpnProposalsOpInferSymbolicShape( return symbol::TensorShapeOrDataDimExprs({num_rois, 1}); }(); - shape_analysis->SetShapeOrDataForValue(op->result(0), multi_rois_out_shape); - shape_analysis->SetShapeOrDataForValue(op->result(1), - rois_num_per_level_out_shape); - shape_analysis->SetShapeOrDataForValue(op->result(2), restore_ind); + infer_context->SetShapeOrDataForValue(op->result(0), multi_rois_out_shape); + infer_context->SetShapeOrDataForValue(op->result(1), + rois_num_per_level_out_shape); + infer_context->SetShapeOrDataForValue(op->result(2), restore_ind); return true; } -bool EinsumOpInferSymbolicShape( - pir::Operation *op, pir::ShapeConstraintIRAnalysis *shape_analysis) { +bool EinsumOpInferSymbolicShape(pir::Operation *op, + pir::InferSymbolicShapeContext *infer_context) { PADDLE_THROW(phi::errors::Unimplemented( op->name() + " 's InferSymbolicShape interface is NOT implemented now.")); return true; From f1d2dee98d764c36f1a3b6846b71e2866bbae782 Mon Sep 17 00:00:00 2001 From: yulangz <1301481108@qq.com> Date: Wed, 8 May 2024 13:15:55 +0000 Subject: [PATCH 4/4] fix --- paddle/cinn/ir/ir.cc | 3 +- .../infer_symbolic_shape/unary_infer_sym.cc | 35 ++++++++++--------- 2 files changed, 21 insertions(+), 17 deletions(-) diff --git a/paddle/cinn/ir/ir.cc b/paddle/cinn/ir/ir.cc index 268538cbaf6803..7157fcf5c15ce3 100644 --- a/paddle/cinn/ir/ir.cc +++ b/paddle/cinn/ir/ir.cc @@ -807,7 +807,8 @@ void Reduce::Verify() const { } Type Select::type() const { - CHECK_EQ(true_value.type(), false_value.type()); + PADDLE_ENFORCE_EQ( + true_value.type(), false_value.type(), "Type of Select must be same"); return true_value.type(); } diff --git a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc index 90f46d0c6940ac..f84ca545188bfa 100644 --- a/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc +++ b/paddle/fluid/pir/dialect/operator/interface/infer_symbolic_shape/unary_infer_sym.cc @@ -254,32 +254,35 @@ bool DistributeFpnProposalsOpInferSymbolicShape( int32_t max_level = attributes.at("max_level").dyn_cast().data(); int32_t num_levels = max_level - min_level + 1; + int64_t batch_size = 1; symbol::DimExpr num_rois = [&]() { pir::Value rois_num = op->operand_source(1); const auto &rois_num_shape_or_data = infer_context->GetShapeOrDataForValue(rois_num); - PADDLE_ENFORCE_EQ( - rois_num_shape_or_data.shape()[0], - 1, - phi::errors::InvalidArgument("DistributeFpnProposalsOp in pir model " - "only support batch_size=1 now.")); - + batch_size = rois_num_shape_or_data.shape()[0].Get(); PADDLE_ENFORCE_EQ(rois_num_shape_or_data.data().has_value(), true, - phi::errors::InvalidArgument( + ::common::errors::InvalidArgument( "InferSymbolicShape of DistributeFpnProposalsdOp " "only support input with rois_num.")); - const auto &rois_num_value = rois_num_shape_or_data.data().value()[0]; - CHECK(rois_num_value.isa() || rois_num_value.isa()) - << "rois_num must be int64 or SymName."; - if (rois_num_value.isa()) { - return symbol::DimExpr(rois_num_value.Get()); - } else { - return symbol::DimExpr(rois_num_value.Get()); + symbol::DimExpr rois_total_num = 0; + for (int i = 0; i < batch_size; i++) { + const auto &rois_num_value = rois_num_shape_or_data.data().value()[i]; + + CHECK(rois_num_value.isa() || rois_num_value.isa()) + << "rois_num must be int64 or SymName."; + if (rois_num_value.isa()) { + return symbol::DimExpr(rois_num_value.Get()); + } else { + return symbol::DimExpr(rois_num_value.Get()); + } + rois_total_num = rois_total_num + rois_num_value; } + + return rois_total_num; }(); const auto &multi_rois_out_shape = [&]() { @@ -305,8 +308,8 @@ bool DistributeFpnProposalsOpInferSymbolicShape( const auto &rois_num_per_level_out_shape = [&]() { symbol::TensorListShapeOrDataDimExprs rois_num_per_level_out_shape; - rois_num_per_level_out_shape.resize(num_levels, - symbol::TensorShapeOrDataDimExprs({1})); + rois_num_per_level_out_shape.resize( + num_levels, symbol::TensorShapeOrDataDimExprs({batch_size})); return rois_num_per_level_out_shape; }();