Skip to content

Commit

Permalink
[PIR] add print function for pd_op.while (#57917)
Browse files Browse the repository at this point in the history
  • Loading branch information
winter-wang authored Oct 8, 2023
1 parent a498e0b commit e5bdde1
Show file tree
Hide file tree
Showing 17 changed files with 202 additions and 103 deletions.
2 changes: 1 addition & 1 deletion cmake/generic.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@
# To build a unit test binary, which is an executable binary with libpaddle.so
# automatically linked:
#
# paddle_test(example SHARED)
# paddle_test(example SRCS example_test.cc)
#

# including binary directory for generated headers.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "paddle/pir/core/value.h"

#include "paddle/fluid/framework/new_executor/instruction/instruction_util.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
namespace paddle {
namespace framework {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include "glog/logging.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"

namespace paddle {
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/ir_adaptor/translator/program_translator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "paddle/fluid/ir_adaptor/translator/op_translator.h"
#include "paddle/fluid/ir_adaptor/translator/type_translator.h"
#include "paddle/fluid/ir_adaptor/translator/utils.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/pir/core/attribute.h"
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ cc_library(
DEPS phi pd_interface pd_trait type_info)
cc_library(
pd_op_dialect_op
SRCS ${op_source_file} manual_op.cc
SRCS ${op_source_file} manual_op.cc control_flow_op.cc
DEPS pd_op_dialect_core)
cc_library(
api_builder
Expand Down
113 changes: 113 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef GET_OP_LIST
#undef GET_OP_LIST
paddle::dialect::IfOp, paddle::dialect::WhileOp
#else
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"

#include "paddle/pir/core/builder.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/operation_utils.h"

namespace paddle {
namespace dialect {

void IfOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
std::vector<pir::Type> &&output_types) {
VLOG(4) << "Start build IfOp";
argument.AddRegions(2u);
argument.AddInput(cond);
argument.output_types.swap(output_types);
}
pir::Block *IfOp::true_block() {
pir::Region &true_region = (*this)->region(0);
if (true_region.empty()) true_region.emplace_back();
return true_region.front();
}
pir::Block *IfOp::false_block() {
pir::Region &false_region = (*this)->region(1);
if (false_region.empty()) false_region.emplace_back();
return false_region.front();
}
void IfOp::Print(pir::IrPrinter &printer) {
auto &os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " = pd_op.if";
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
os << "{";
for (auto item : *true_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n } else {";
for (auto item : *false_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n }";
}
void IfOp::Verify() {}

void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types) {
argument.AddInputs(inputs);
argument.AddOutputs(output_types);
argument.AddRegions(2u);
}
pir::Block *WhileOp::cond_block() {
pir::Region &cond_region = (*this)->region(0);
if (cond_region.empty()) cond_region.emplace_back();
return cond_region.front();
}
pir::Block *WhileOp::body_block() {
pir::Region &body_region = (*this)->region(1);
if (body_region.empty()) body_region.emplace_back();
return body_region.front();
}

void WhileOp::Print(pir::IrPrinter &printer) {
auto &os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " \"" << name() << "\"";
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
os << "{";
for (auto item : *cond_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n } do {";
for (auto item : *body_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n }";
}
} // namespace dialect
} // namespace paddle

IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp)

#endif
68 changes: 68 additions & 0 deletions paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <vector>

#include "paddle/pir/core/op_base.h"

namespace paddle {
namespace dialect {

class IfOp : public pir::Op<IfOp> {
public:
using Op::Op;
static const char *name() { return "pd_op.if"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
std::vector<pir::Type> &&output_types);

// static void Build(pir::Builder &builder, // NOLINT
// pir::OperationArgument &argument, // NOLINT
// pir::Value cond,
// std::unique_ptr<pir::Block>&& true_block,
// std::unique_ptr<pir::Block>&& false_block);

pir::Value cond() { return operand_source(0); }
pir::Block *true_block();
pir::Block *false_block();
void Print(pir::IrPrinter &printer); // NOLINT
void Verify();
};

class WhileOp : public pir::Op<WhileOp> {
public:
using Op::Op;
static const char *name() { return "pd_op.while"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;

static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types);
pir::Block *cond_block();
pir::Block *body_block();
void Print(pir::IrPrinter &printer); // NOLINT
void Verify() {}
};

} // namespace dialect
} // namespace paddle

IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp)
61 changes: 0 additions & 61 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1040,65 +1040,6 @@ void SplitGradOp::InferMeta(phi::InferMetaContext *infer_meta) {
fn(infer_meta);
}

void IfOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
std::vector<pir::Type> &&output_types) {
VLOG(4) << "Start build IfOp";
argument.AddRegions(2u);
argument.AddInput(cond);
argument.output_types.swap(output_types);
}
pir::Block *IfOp::true_block() {
pir::Region &true_region = (*this)->region(0);
if (true_region.empty()) true_region.emplace_back();
return true_region.front();
}
pir::Block *IfOp::false_block() {
pir::Region &false_region = (*this)->region(1);
if (false_region.empty()) false_region.emplace_back();
return false_region.front();
}
void IfOp::Print(pir::IrPrinter &printer) {
auto &os = printer.os;
auto op = operation();
printer.PrintOpResult(op);
os << " = pd_op.if";
printer.PrintOpOperands(op);
os << " -> ";
printer.PrintOpReturnType(op);
os << "{";
for (auto item : *true_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n } else {";
for (auto item : *false_block()) {
os << "\n ";
printer.PrintOperation(item);
}
os << "\n }";
}
void IfOp::Verify() {}

void WhileOp::Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types) {
argument.AddInputs(inputs);
argument.AddOutputs(output_types);
argument.AddRegions(2u);
}
pir::Block *WhileOp::cond_block() {
pir::Region &cond_region = (*this)->region(0);
if (cond_region.empty()) cond_region.emplace_back();
return cond_region.front();
}
pir::Block *WhileOp::body_block() {
pir::Region &body_region = (*this)->region(1);
if (body_region.empty()) body_region.emplace_back();
return body_region.front();
}
} // namespace dialect
} // namespace paddle

Expand All @@ -1108,5 +1049,3 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp)
35 changes: 0 additions & 35 deletions paddle/fluid/pir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,39 +176,6 @@ class SplitGradOp : public pir::Op<SplitGradOp, OpYamlInfoInterface> {
static void InferMeta(phi::InferMetaContext *infer_meta);
};

class IfOp : public pir::Op<IfOp> {
public:
using Op::Op;
static const char *name() { return "pd_op.if"; }
static constexpr const char **attributes_name = nullptr;
static constexpr uint32_t attributes_num = 0;
static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
pir::Value cond,
std::vector<pir::Type> &&output_types);
pir::Value cond() { return operand_source(0); }
pir::Block *true_block();
pir::Block *false_block();
void Print(pir::IrPrinter &printer); // NOLINT
void Verify();
};

class WhileOp : public pir::Op<WhileOp> {
public:
using Op::Op;
static const char *name() { return "pd.while"; }
static constexpr uint32_t attributes_num = 0;
static constexpr const char **attributes_name = nullptr;

static void Build(pir::Builder &builder, // NOLINT
pir::OperationArgument &argument, // NOLINT
const std::vector<pir::Value> &inputs,
const std::vector<pir::Type> &output_types);
void Verify() {}
pir::Block *cond_block();
pir::Block *body_block();
};

} // namespace dialect
} // namespace paddle

Expand All @@ -218,5 +185,3 @@ IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddN_Op)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::AddNWithKernelOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::FusedGemmEpilogueGradOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::IfOp)
IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::WhileOp)
11 changes: 8 additions & 3 deletions paddle/fluid/pir/dialect/operator/ir/op_dialect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
// NOTE(zhangbo9674): File pd_op.h is generated by op_gen.py, see details in
// paddle/fluid/pir/dialect/CMakeLists.txt.
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_type.h"
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/type_storage.h"
Expand Down Expand Up @@ -50,14 +51,16 @@ void OperatorDialect::initialize() {
#define GET_OP_LIST
#include "paddle/fluid/pir/dialect/operator/ir/pd_op.h" // NOLINT
>();
RegisterOps<
#define GET_OP_LIST
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc" // NOLINT
>();
RegisterOps<paddle::dialect::AddNOp,
paddle::dialect::AddN_Op,
paddle::dialect::AddNWithKernelOp,
paddle::dialect::FusedGemmEpilogueOp,
paddle::dialect::FusedGemmEpilogueGradOp,
paddle::dialect::SplitGradOp,
paddle::dialect::IfOp,
paddle::dialect::WhileOp>();
paddle::dialect::SplitGradOp>();

RegisterInterfaces<ParameterConvertInterface>();
}
Expand Down Expand Up @@ -163,6 +166,8 @@ void OperatorDialect::PrintOperation(pir::Operation *op,
pir::IrPrinter &printer) const {
if (auto if_op = op->dyn_cast<IfOp>()) {
if_op.Print(printer);
} else if (auto while_op = op->dyn_cast<WhileOp>()) {
while_op.Print(printer);
} else {
printer.PrintGeneralOperation(op);
}
Expand Down
1 change: 1 addition & 0 deletions paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h"
#include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h"
#include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h"
#include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/manual_op.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h"
#include "paddle/fluid/pir/dialect/operator/ir/op_dialect.h"
Expand Down
2 changes: 2 additions & 0 deletions paddle/pir/core/op_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
#include "paddle/pir/core/utils.h"

namespace pir {
class Builder;
class IrPrinter;

class IR_API OpBase {
public:
Expand Down
Loading

0 comments on commit e5bdde1

Please sign in to comment.