Skip to content

Commit

Permalink
Pir cinn support multi group (#59037)
Browse files Browse the repository at this point in the history
* pir cinn support multi group

* update

* update
  • Loading branch information
phlrain authored Nov 16, 2023
1 parent 557499b commit 7161b06
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,10 +140,13 @@ std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
auto group_list =
cinn::dialect::ir::GeneralFusionMergePassInternal(op_fusion);

PADDLE_ENFORCE_EQ(group_list.size(),
1u,
phi::errors::Unimplemented(
"Only support one group after group fusion"));
// using yield op to sort
std::unordered_map<::pir::Value, size_t> value2id;
auto yeild_op = group_op.ops().back();
for (size_t i = 0; i < yeild_op->num_operands(); ++i) {
value2id[yeild_op->operand_source(i)] = i;
}

for (auto group : group_list) {
auto ir_compiler = std::make_shared<cinn::hlir::framework::PirCompiler>(
*program, target, scope);
Expand All @@ -162,26 +165,23 @@ std::unique_ptr<pir::Program> CINNGroupLoweringPass(::pir::Program* program) {
vec_new_ins.push_back(value_map.at(vec_ins[i]));
}

// using yield op to sort
std::unordered_map<::pir::Value, size_t> value2id;
auto yeild_op = group_op.ops().back();
for (size_t i = 0; i < yeild_op->num_operands(); ++i) {
value2id[yeild_op->operand_source(i)] = i;
}

std::unordered_map<size_t, size_t> codegen2orig;

std::vector<pir::Type> vec_types;
for (size_t i = 0; i < group->output_values.size(); ++i) {
vec_types.push_back(group->output_values[i].type());
codegen2orig[value2id.at(group->output_values[i])] = i;
}

::pir::Operation* cinn_op =
::pir::Operation::Create(vec_new_ins, op_attrs, vec_types, op_info);

for (size_t i = 0; i < group_op.num_results(); ++i) {
value_map[group_op.result(i)] = cinn_op->result(codegen2orig.at(i));
for (size_t i = 0; i < cinn_op->num_results(); ++i) {
auto find_it = value2id.find(group->output_values[i]);
if (find_it == value2id.end()) {
value_map[group->output_values[i]] = cinn_op->result(i);
} else {
value_map[group_op.result(find_it->second)] = cinn_op->result(i);
}
}

ir_program->block()->push_back(cinn_op);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <unordered_map>

#include "paddle/cinn/hlir/dialect/operator/transforms/op_group.h"
#include "paddle/pir/core/ir_printer.h"
#include "paddle/pir/core/value.h"

#include "paddle/cinn/hlir/dialect/operator/transforms/group_with_group_merge_pass_utils.h"
Expand Down
74 changes: 73 additions & 1 deletion test/cpp/pir/cinn/pir_all_path_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ TEST(GroupOp, TestBuildLayerNorm) {
// executor.Run({}, true);

// auto out_tensor =
// executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>();
// executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>();
}

std::shared_ptr<::pir::Program> BuildDropOutProgram() {
Expand Down Expand Up @@ -495,3 +495,75 @@ TEST(GroupOp, TestBuildPower) {
bool res0 = simple_cmp(out_tensor.data<float>()[0], 4.0);
EXPECT_EQ(res0, true);
}

std::shared_ptr<::pir::Program> BuildSum2GroupProgram() {
::pir::IrContext* ctx = ::pir::IrContext::Instance();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
auto program = std::make_shared<::pir::Program>(ctx);
::pir::Builder builder = ::pir::Builder(ctx, program->block());

auto x = builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({16, 16}),
0.0,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);

auto cos = builder.Build<paddle::dialect::CosOp>(x).result(0);

auto y = builder
.Build<paddle::dialect::FullOp>(std::vector<int64_t>({8, 8}),
0.0,
phi::DataType::FLOAT32,
phi::GPUPlace())
.result(0);

auto sin = builder.Build<paddle::dialect::SinOp>(y).result(0);

builder.Build<paddle::dialect::FetchOp>(cos, "out", 0);
builder.Build<paddle::dialect::FetchOp>(sin, "out2", 0);
return program;
}

TEST(GroupOp, TestBuildSum2Group) {
// Step 1: Construct pir::Program
::pir::IrContext* ctx = ::pir::IrContext::Instance();
std::shared_ptr<::pir::Program> program = BuildSum2GroupProgram();
ctx->GetOrRegisterDialect<paddle::dialect::OperatorDialect>();
ctx->GetOrRegisterDialect<cinn::dialect::OperatorDialect>();

cinn::dialect::ir::PdOp2CinnOpConverter(program.get());

pir::PassManager pm(ctx);
pm.AddPass(
std::make_unique<cinn::dialect::ir::AddBroadcastToElementwisePass>());
pm.AddPass(pir::CreateBuildCinnPass());
CHECK_EQ(pm.Run(program.get()), true);

auto res = cinn::dialect::ir::CINNGroupLoweringPass(program.get());

paddle::platform::Place place = paddle::platform::CUDAPlace(0);

res->Print(std::cout);
auto kernel_program =
paddle::dialect::PdOpLowerToKernelPass(res.get(), place);

paddle::framework::Scope exe_scope;

paddle::framework::InterpreterCore executor(
place, {"out@fetch"}, kernel_program->block(), &exe_scope);

executor.Run({}, true);

auto out_tensor =
executor.local_scope()->FindVar("out@fetch")->Get<phi::DenseTensor>();

auto out_tensor2 =
executor.local_scope()->FindVar("out2@fetch")->Get<phi::DenseTensor>();

bool res0 = simple_cmp(out_tensor.data<float>()[0], 1.0);
EXPECT_EQ(res0, true);

bool res1 = (out_tensor2.data<float>()[0] == 0.0);
EXPECT_EQ(res1, true);
}

0 comments on commit 7161b06

Please sign in to comment.