diff --git a/cmake/generic.cmake b/cmake/generic.cmake index 13aaf0d760f160..9ed3d53ccdc2b9 100644 --- a/cmake/generic.cmake +++ b/cmake/generic.cmake @@ -88,7 +88,7 @@ # To build a unit test binary, which is an executable binary with libpaddle.so # automatically linked: # -# paddle_test(example SHARED) +# paddle_test(example SRCS example_test.cc) # # including binary directory for generated headers. diff --git a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc index 5d958d72665058..8c89800dd2d95a 100644 --- a/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc +++ b/paddle/fluid/framework/new_executor/instruction/cond_instruction.cc @@ -34,6 +34,7 @@ #include "paddle/pir/core/value.h" #include "paddle/fluid/framework/new_executor/instruction/instruction_util.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" namespace paddle { namespace framework { diff --git a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc index 2789c7b62bff53..f8400b1c289a5a 100644 --- a/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc +++ b/paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc @@ -43,6 +43,7 @@ #include "glog/logging.h" #include "paddle/fluid/framework/op_info.h" #include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" namespace paddle { diff --git a/paddle/fluid/ir_adaptor/translator/program_translator.cc b/paddle/fluid/ir_adaptor/translator/program_translator.cc index 2ebece4fbfef7d..313a78da1aab95 100644 --- a/paddle/fluid/ir_adaptor/translator/program_translator.cc +++ b/paddle/fluid/ir_adaptor/translator/program_translator.cc @@ -23,6 +23,7 @@ #include "paddle/fluid/ir_adaptor/translator/op_translator.h" #include "paddle/fluid/ir_adaptor/translator/type_translator.h" #include "paddle/fluid/ir_adaptor/translator/utils.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/phi/core/enforce.h" #include "paddle/pir/core/attribute.h" diff --git a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt index befbb84a7117df..3026da6200254c 100644 --- a/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt +++ b/paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt @@ -190,7 +190,7 @@ cc_library( DEPS phi pd_interface pd_trait type_info) cc_library( pd_op_dialect_op - SRCS ${op_source_file} manual_op.cc + SRCS ${op_source_file} manual_op.cc control_flow_op.cc DEPS pd_op_dialect_core) cc_library( api_builder diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc new file mode 100644 index 00000000000000..94ba9a2e2e37f8 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc @@ -0,0 +1,113 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#ifdef GET_OP_LIST +#undef GET_OP_LIST +paddle::dialect::IfOp, paddle::dialect::WhileOp +#else +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" + +#include "paddle/pir/core/builder.h" +#include "paddle/pir/core/ir_printer.h" +#include "paddle/pir/core/operation_utils.h" + +namespace paddle { +namespace dialect { + +void IfOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value cond, + std::vector &&output_types) { + VLOG(4) << "Start build IfOp"; + argument.AddRegions(2u); + argument.AddInput(cond); + argument.output_types.swap(output_types); +} +pir::Block *IfOp::true_block() { + pir::Region &true_region = (*this)->region(0); + if (true_region.empty()) true_region.emplace_back(); + return true_region.front(); +} +pir::Block *IfOp::false_block() { + pir::Region &false_region = (*this)->region(1); + if (false_region.empty()) false_region.emplace_back(); + return false_region.front(); +} +void IfOp::Print(pir::IrPrinter &printer) { + auto &os = printer.os; + auto op = operation(); + printer.PrintOpResult(op); + os << " = pd_op.if"; + printer.PrintOpOperands(op); + os << " -> "; + printer.PrintOpReturnType(op); + os << "{"; + for (auto item : *true_block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n } else {"; + for (auto item : *false_block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n }"; +} +void IfOp::Verify() {} + +void WhileOp::Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + const std::vector &inputs, + const std::vector &output_types) { + argument.AddInputs(inputs); + argument.AddOutputs(output_types); + argument.AddRegions(2u); +} +pir::Block *WhileOp::cond_block() { + pir::Region &cond_region = (*this)->region(0); + if (cond_region.empty()) cond_region.emplace_back(); + return cond_region.front(); +} +pir::Block *WhileOp::body_block() { + pir::Region &body_region = (*this)->region(1); + if (body_region.empty()) body_region.emplace_back(); + return body_region.front(); +} + +void WhileOp::Print(pir::IrPrinter &printer) { + auto &os = printer.os; + auto op = operation(); + printer.PrintOpResult(op); + os << " \"" << name() << "\""; + printer.PrintOpOperands(op); + os << " -> "; + printer.PrintOpReturnType(op); + os << "{"; + for (auto item : *cond_block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n } do {"; + for (auto item : *body_block()) { + os << "\n "; + printer.PrintOperation(item); + } + os << "\n }"; +} +} // namespace dialect +} // namespace paddle + +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp) + +#endif diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h new file mode 100644 index 00000000000000..3f93c51a534e90 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h @@ -0,0 +1,68 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include + +#include "paddle/pir/core/op_base.h" + +namespace paddle { +namespace dialect { + +class IfOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.if"; } + static constexpr const char **attributes_name = nullptr; + static constexpr uint32_t attributes_num = 0; + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + pir::Value cond, + std::vector &&output_types); + + // static void Build(pir::Builder &builder, // NOLINT + // pir::OperationArgument &argument, // NOLINT + // pir::Value cond, + // std::unique_ptr&& true_block, + // std::unique_ptr&& false_block); + + pir::Value cond() { return operand_source(0); } + pir::Block *true_block(); + pir::Block *false_block(); + void Print(pir::IrPrinter &printer); // NOLINT + void Verify(); +}; + +class WhileOp : public pir::Op { + public: + using Op::Op; + static const char *name() { return "pd_op.while"; } + static constexpr uint32_t attributes_num = 0; + static constexpr const char **attributes_name = nullptr; + + static void Build(pir::Builder &builder, // NOLINT + pir::OperationArgument &argument, // NOLINT + const std::vector &inputs, + const std::vector &output_types); + pir::Block *cond_block(); + pir::Block *body_block(); + void Print(pir::IrPrinter &printer); // NOLINT + void Verify() {} +}; + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc index 8a4e4cda9f50bc..eb5f1f5a536703 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.cc +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.cc @@ -1040,65 +1040,6 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) { fn(infer_meta); } -void IfOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - pir::Value cond, - std::vector &&output_types) { - VLOG(4) << "Start build IfOp"; - argument.AddRegions(2u); - argument.AddInput(cond); - argument.output_types.swap(output_types); -} -pir::Block *IfOp::true_block() { - pir::Region &true_region = (*this)->region(0); - if (true_region.empty()) true_region.emplace_back(); - return true_region.front(); -} -pir::Block *IfOp::false_block() { - pir::Region &false_region = (*this)->region(1); - if (false_region.empty()) false_region.emplace_back(); - return false_region.front(); -} -void IfOp::Print(pir::IrPrinter &printer) { - auto &os = printer.os; - auto op = operation(); - printer.PrintOpResult(op); - os << " = pd_op.if"; - printer.PrintOpOperands(op); - os << " -> "; - printer.PrintOpReturnType(op); - os << "{"; - for (auto item : *true_block()) { - os << "\n "; - printer.PrintOperation(item); - } - os << "\n } else {"; - for (auto item : *false_block()) { - os << "\n "; - printer.PrintOperation(item); - } - os << "\n }"; -} -void IfOp::Verify() {} - -void WhileOp::Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - const std::vector &inputs, - const std::vector &output_types) { - argument.AddInputs(inputs); - argument.AddOutputs(output_types); - argument.AddRegions(2u); -} -pir::Block *WhileOp::cond_block() { - pir::Region &cond_region = (*this)->region(0); - if (cond_region.empty()) cond_region.emplace_back(); - return cond_region.front(); -} -pir::Block *WhileOp::body_block() { - pir::Region &body_region = (*this)->region(1); - if (body_region.empty()) body_region.emplace_back(); - return body_region.front(); -} } // namespace dialect } // namespace paddle @@ -1108,5 +1049,3 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) -IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) -IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/manual_op.h b/paddle/fluid/pir/dialect/operator/ir/manual_op.h index 93f24e80cb5248..c6fc7cb32b3165 100644 --- a/paddle/fluid/pir/dialect/operator/ir/manual_op.h +++ b/paddle/fluid/pir/dialect/operator/ir/manual_op.h @@ -176,39 +176,6 @@ class SplitGradOp : public pir::Op { static void InferMeta(phi::InferMetaContext *infer_meta); }; -class IfOp : public pir::Op { - public: - using Op::Op; - static const char *name() { return "pd_op.if"; } - static constexpr const char **attributes_name = nullptr; - static constexpr uint32_t attributes_num = 0; - static void Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - pir::Value cond, - std::vector &&output_types); - pir::Value cond() { return operand_source(0); } - pir::Block *true_block(); - pir::Block *false_block(); - void Print(pir::IrPrinter &printer); // NOLINT - void Verify(); -}; - -class WhileOp : public pir::Op { - public: - using Op::Op; - static const char *name() { return "pd.while"; } - static constexpr uint32_t attributes_num = 0; - static constexpr const char **attributes_name = nullptr; - - static void Build(pir::Builder &builder, // NOLINT - pir::OperationArgument &argument, // NOLINT - const std::vector &inputs, - const std::vector &output_types); - void Verify() {} - pir::Block *cond_block(); - pir::Block *body_block(); -}; - } // namespace dialect } // namespace paddle @@ -218,5 +185,3 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp) IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp) -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp) -IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp) diff --git a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc index ac62747026ed06..9a7c6b9de2ea26 100644 --- a/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc +++ b/paddle/fluid/pir/dialect/operator/ir/op_dialect.cc @@ -16,6 +16,7 @@ #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" // NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in // paddle/fluid/pir/dialect/CMakeLists.txt. +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/dialect/operator/ir/type_storage.h" @@ -50,14 +51,16 @@ void OperatorDialect::initialize() { #define GET_OP_LIST #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" // NOLINT >(); + RegisterOps< +#define GET_OP_LIST +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc" // NOLINT + >(); RegisterOps(); + paddle::dialect::SplitGradOp>(); RegisterInterfaces(); } @@ -163,6 +166,8 @@ void OperatorDialect::PrintOperation(pir::Operation *op, pir::IrPrinter &printer) const { if (auto if_op = op->dyn_cast()) { if_op.Print(printer); + } else if (auto while_op = op->dyn_cast()) { + while_op.Print(printer); } else { printer.PrintGeneralOperation(op); } diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 00597318091084..c322f71893ff77 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -19,6 +19,7 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" diff --git a/paddle/pir/core/op_base.h b/paddle/pir/core/op_base.h index f9de8dfc6cf8d0..8e67a392c51cf3 100644 --- a/paddle/pir/core/op_base.h +++ b/paddle/pir/core/op_base.h @@ -22,6 +22,8 @@ #include "paddle/pir/core/utils.h" namespace pir { +class Builder; +class IrPrinter; class IR_API OpBase { public: diff --git a/test/cpp/new_executor/standalone_executor_new_ir_test.cc b/test/cpp/new_executor/standalone_executor_new_ir_test.cc index 02ca49d180baaf..9bdc5c3d3c718d 100644 --- a/test/cpp/new_executor/standalone_executor_new_ir_test.cc +++ b/test/cpp/new_executor/standalone_executor_new_ir_test.cc @@ -23,6 +23,7 @@ #include "paddle/phi/core/kernel_registry.h" #include "paddle/fluid/framework/new_executor/new_ir_interpreter.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/fluid/pir/transforms/pd_op_to_kernel_pass.h" diff --git a/test/cpp/pir/control_flow_dialect/if_op_test.cc b/test/cpp/pir/control_flow_dialect/if_op_test.cc index f2e49b150b7bc7..218a67e1acc5be 100644 --- a/test/cpp/pir/control_flow_dialect/if_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/if_op_test.cc @@ -14,7 +14,7 @@ #include #include -#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/core/builder.h" diff --git a/test/cpp/pir/control_flow_dialect/while_op_test.cc b/test/cpp/pir/control_flow_dialect/while_op_test.cc index 6c558cc9829267..609f1f8eb8d2e9 100644 --- a/test/cpp/pir/control_flow_dialect/while_op_test.cc +++ b/test/cpp/pir/control_flow_dialect/while_op_test.cc @@ -14,7 +14,7 @@ #include #include -#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" #include "paddle/pir/core/builder.h" diff --git a/test/cpp/pir/core/program_translator_test.cc b/test/cpp/pir/core/program_translator_test.cc index c95d5952577baf..483299c206129e 100644 --- a/test/cpp/pir/core/program_translator_test.cc +++ b/test/cpp/pir/core/program_translator_test.cc @@ -27,6 +27,7 @@ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/ir_adaptor/translator/translate.h" #include "paddle/fluid/ir_adaptor/translator/utils.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" diff --git a/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc b/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc index bb99e86dfc21cd..6812e7a9ed1946 100644 --- a/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc +++ b/test/cpp/pir/kernel_dialect/ir_kernel_dialect_pass_test.cc @@ -23,6 +23,7 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_dialect.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" #include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h" #include "paddle/fluid/pir/dialect/operator/ir/op_type.h"