Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PIR] Fix bug of BuildScope for IfOp #58109

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -257,8 +257,7 @@ void InstructionBase::InitInputsOutputsIds(

std::string InstructionBase::DebugStringEx(
const paddle::framework::Scope* scope,
const std::unordered_map<::pir::Value, std::string>& value_2_var_name)
const {
ValueExecutionInfo* value_exe_info) const {
std::stringstream ss;
ss << "Op(" << Name() << "), inputs:{";

Expand All @@ -268,7 +267,7 @@ std::string InstructionBase::DebugStringEx(
auto& input = *it;
bool is_no_need_buffer_var = (!no_need_buffer_vars.empty() &&
no_need_buffer_vars.count(input.first) > 0);
auto var_name = value_2_var_name.at(input.first);
auto var_name = value_exe_info->GetVarName(input.first);
ss << var_name;
if (scope) {
if (!VarInited(*scope, var_name)) {
Expand Down Expand Up @@ -296,7 +295,7 @@ std::string InstructionBase::DebugStringEx(
ss << "}, outputs:{";
for (auto it = Outputs().begin(); it != Outputs().end();) {
auto& output = *it;
auto var_name = value_2_var_name.at(output.first);
auto var_name = value_exe_info->GetVarName(output.first);
ss << var_name;
if (scope) {
if (!VarInited(*scope, var_name)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,10 +144,8 @@ class InstructionBase {
const ValueExecutionInfo& value_exec_info);

// if scope is not null, also show dimensions of arguments
virtual std::string DebugStringEx(
const paddle::framework::Scope* scope,
const std::unordered_map<::pir::Value, std::string>& value_2_var_name)
const;
virtual std::string DebugStringEx(const paddle::framework::Scope* scope,
ValueExecutionInfo* value_exe_info) const;

protected:
size_t id_;
Expand Down
24 changes: 15 additions & 9 deletions paddle/fluid/framework/new_executor/new_ir_interpreter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ Scope* NewIRInterpreter::InnerScope() const {
}

std::string NewIRInterpreter::GetNameByValue(::pir::Value value) const {
return value_exe_info_->GetValue2VarName().at(value);
return value_exe_info_->GetVarName(value);
}

void NewIRInterpreter::UpdateSyncOpNum() {
Expand Down Expand Up @@ -627,7 +627,7 @@ std::string NewIRInterpreter::DebugValueInfo() {
PADDLE_ENFORCE((bool)kv.first,
platform::errors::PreconditionNotMet(
"vlaue(%s) should not be nullptr", kv.second));
PADDLE_ENFORCE(value_exe_info_->GetVarName2Id().count(kv.second) > 0,
PADDLE_ENFORCE(value_exe_info_->HasVar(kv.second),
platform::errors::PreconditionNotMet(
"var(%s) should exist in var_name_2_id_", kv.second));
auto* var = InnerScope()->FindVar(kv.second);
Expand All @@ -636,8 +636,7 @@ std::string NewIRInterpreter::DebugValueInfo() {
platform::errors::PreconditionNotMet(
"var(%s) should exist in scope (%p)", kv.second, InnerScope()));
os << kv.first.impl() << " -> " << kv.second << " -> "
<< value_exe_info_->GetVarName2Id().at(kv.second) << " -> " << var
<< "\n";
<< value_exe_info_->GetVarId(kv.first) << " -> " << var << "\n";
}
return os.str();
}
Expand Down Expand Up @@ -857,6 +856,7 @@ void NewIRInterpreter::CheckGC(InstructionBase* instr) {
}

void NewIRInterpreter::CalculateLastLiveOps() {
VLOG(4) << "NewIRInterpreter(): " << this << " start CalculateLastLiveOps";
// calculate last_live_ops_
for (size_t op_idx = 0; op_idx < vec_instruction_base_.size(); ++op_idx) {
InstructionBase* instr = vec_instruction_base_[op_idx].get();
Expand All @@ -882,11 +882,16 @@ void NewIRInterpreter::CalculateLastLiveOps() {
gc_check_vars.insert(var_id);
}
}
VLOG(4) << "get gc check vars for: " << instr->Name();

for (auto var_id : gc_check_vars) {
Scope* inner_scope = InnerScope();
paddle::framework::Variable* var = inner_scope->FindVar(
value_exe_info_->GetNameById(static_cast<int>(var_id)));
PADDLE_ENFORCE_NOT_NULL(
var,
platform::errors::NotFound("Var(id=%d) should not be nullptr.",
static_cast<int>(var_id)));
if (var->IsType<phi::DenseTensor>() || var->IsType<phi::SelectedRows>() ||
var->IsType<LoDTensorArray>() ||
var->IsType<phi::SparseCooTensor>() ||
Expand All @@ -899,6 +904,7 @@ void NewIRInterpreter::CalculateLastLiveOps() {
<< framework::ToTypeName(var->Type());
}
}
VLOG(4) << "update last_live_ops for: " << instr->Name();
}
// clear the last_live_ops list for all vars in skip_gc_vars
for (const std::string& skip_gc_var : execution_config_.skip_gc_vars) {
Expand All @@ -908,7 +914,7 @@ void NewIRInterpreter::CalculateLastLiveOps() {
VLOG(8) << "Skip gc for var: " << skip_gc_var;
}
}
VLOG(4) << "calculate last_live_ops_";
VLOG(4) << "clear the last_live_ops list for all vars in skip_gc_vars";

// shrink, find the downstream op that has no other op in the
// downstream list happens before it
Expand Down Expand Up @@ -949,6 +955,7 @@ void NewIRInterpreter::CalculateLastLiveOps() {
last_live_ops_[i] = minumum_last_live_ops;
var_ref_count_[i] = static_cast<int>(last_live_ops_[i].size());
}
VLOG(4) << "shrink the last_live_ops list for all vars in skip_gc_vars";

for (auto& dep : *dependecy_count_) {
deps_.emplace_back(std::make_shared<interpreter::OpDepInfo>(dep));
Expand All @@ -957,6 +964,7 @@ void NewIRInterpreter::CalculateLastLiveOps() {
refs_.emplace_back(std::make_shared<interpreter::VarRefInfo>(
var_ref_count_[i], value_exe_info_->GetVarList()[i]));
}
VLOG(4) << "done CalculateLastLiveOps";
}

void NewIRInterpreter::ConstructEventForJitInput() {
Expand Down Expand Up @@ -1410,8 +1418,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) {
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();
VLOG(4) << place_ << " "
<< instr_node->DebugStringEx(scope_,
value_exe_info_->GetValue2VarName());
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());
if (!instr_node->IsArtificial()) {
instr_node->Run();

Expand All @@ -1433,8 +1440,7 @@ void NewIRInterpreter::RunInstructionBase(InstructionBase* instr_node) {
: "kGpuAsync"))
<< " runs on " << platform::GetCurrentThreadName();
VLOG(4) << place_ << " "
<< instr_node->DebugStringEx(scope_,
value_exe_info_->GetValue2VarName());
<< instr_node->DebugStringEx(scope_, value_exe_info_.get());
CheckGC(instr_node);
VLOG(4) << "done CheckGC";
interpreter::LogDeviceMemoryStats(place_);
Expand Down
104 changes: 18 additions & 86 deletions paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,11 @@ std::shared_ptr<ValueExecutionInfo> ValueExecutionInfo::NewChild(Scope* scope) {
std::shared_ptr<ValueExecutionInfo> info =
std::make_shared<ValueExecutionInfo>(scope);
info->parent_ = this;
info->value_2_var_name_ = this->value_2_var_name_;
info->var_2_var_name_ = this->var_2_var_name_;
info->var_name_2_id_ = this->var_name_2_id_;
info->id_2_var_name_ = this->id_2_var_name_;
info->var_list_ = this->var_list_;
return info;
}

Expand Down Expand Up @@ -157,120 +162,49 @@ void ValueExecutionInfo::ResetVarList(int id, Variable* var) {
var_list_[id] = var;
}

bool ValueExecutionInfo::HasValue(::pir::Value value) const {
return HasValueInternal(value);
}

bool ValueExecutionInfo::HasLocalValue(::pir::Value value) const {
return HasValueLocally(value);
}

std::string ValueExecutionInfo::GetVarName(::pir::Value value) const {
return GetVarNameInternal(value);
}

std::string ValueExecutionInfo::GetVarName(const Variable* var) const {
return GetVarNameInternal(var);
}

std::string ValueExecutionInfo::GetLocalVarName(::pir::Value value) const {
return GetVarNameLocally(value);
}

std::string ValueExecutionInfo::GetLocalVarName(const Variable* var) const {
return GetVarNameLocally(var);
}

int ValueExecutionInfo::GetVarId(::pir::Value value) const {
return GetVarIdInternal(value);
}

int ValueExecutionInfo::GetVarId(const Variable* var) const {
return GetVarIdInternal(var);
}

int ValueExecutionInfo::GetLocalVarId(::pir::Value value) const {
return GetVarIdLocally(value);
}

int ValueExecutionInfo::GetLocalVarId(const Variable* var) const {
return GetVarIdLocally(var);
}

bool ValueExecutionInfo::HasValueInternal(::pir::Value value) const {
if (HasValueLocally(value)) {
bool ValueExecutionInfo::HasVar(const std::string& var_name) const {
auto it = var_name_2_id_.find(var_name);
if (it != var_name_2_id_.end()) {
return true;
}
return (parent_ == nullptr) ? false : parent_->HasValueInternal(value);
return false;
}

bool ValueExecutionInfo::HasValueLocally(::pir::Value value) const {
bool ValueExecutionInfo::HasValue(::pir::Value value) const {
auto it = value_2_var_name_.find(value);
if (it != value_2_var_name_.end()) {
return true;
}
return false;
}

std::string ValueExecutionInfo::GetVarNameInternal(::pir::Value value) const {
auto name = GetVarNameLocally(value);
if (name != "") {
return name;
}
return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(value);
}

std::string ValueExecutionInfo::GetVarNameLocally(::pir::Value value) const {
std::string ValueExecutionInfo::GetVarName(::pir::Value value) const {
auto it = value_2_var_name_.find(value);
if (it != value_2_var_name_.end()) {
return it->second;
}
return "";
}

std::string ValueExecutionInfo::GetVarNameInternal(const Variable* var) const {
auto name = GetVarNameLocally(var);
if (name != "") {
return name;
}
return (parent_ == nullptr) ? "" : parent_->GetVarNameInternal(var);
}

std::string ValueExecutionInfo::GetVarNameLocally(const Variable* var) const {
std::string ValueExecutionInfo::GetVarName(const Variable* var) const {
auto it = var_2_var_name_.find(var);
if (it != var_2_var_name_.end()) {
return it->second;
}
return "";
}

int ValueExecutionInfo::GetVarIdInternal(::pir::Value value) const {
auto id = GetVarIdLocally(value);
if (id != -1) {
return id;
}
return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(value);
}

int ValueExecutionInfo::GetVarIdLocally(::pir::Value value) const {
auto var_name = GetVarNameLocally(value);
int ValueExecutionInfo::GetVarId(::pir::Value value) const {
auto var_name = GetVarName(value);
auto it = var_name_2_id_.find(var_name);
if (it != var_name_2_id_.end()) {
return it->second;
}
return -1;
}

int ValueExecutionInfo::GetVarIdInternal(const Variable* var) const {
auto id = GetVarIdLocally(var);
if (id != -1) {
return id;
}
return (parent_ == nullptr) ? -1 : parent_->GetVarIdInternal(var);
}

int ValueExecutionInfo::GetVarIdLocally(const Variable* var) const {
auto var_name = GetVarNameLocally(var);
int ValueExecutionInfo::GetVarId(const Variable* var) const {
auto var_name = GetVarName(var);
auto it = var_name_2_id_.find(var_name);
if (it != var_name_2_id_.end()) {
return it->second;
Expand Down Expand Up @@ -608,8 +542,7 @@ void HandleForInplaceOp(pir::Operation* op,
const std::string& inplace_name = yaml_parser.InplaceName(value_name);
pir::Value inplace_value =
op->operand_source(yaml_parser.InputName2Id().at(inplace_name));
std::string var_name =
value_exe_info->GetValue2VarName().at(inplace_value);
std::string var_name = value_exe_info->GetVarName(inplace_value);
VLOG(4) << "inplace: " << value_name << " -> " << inplace_name
<< " (var: " << var_name << ")";
value_exe_info->AddValue2VarName(value, var_name);
Expand All @@ -618,8 +551,7 @@ void HandleForInplaceOp(pir::Operation* op,
pir::Value view_value =
op->operand_source(yaml_parser.InputName2Id().at(view_name));
// const std::string& var_name = value_2_var_name->at(view_value);
const std::string& var_name =
value_exe_info->GetValue2VarName().at(view_value);
std::string var_name = value_exe_info->GetVarName(view_value);
VLOG(4) << "view: " << value_name << " -> " << view_name
<< " (var: " << var_name << ")";
value_exe_info->AddValue2VarName(value, var_name);
Expand Down
34 changes: 2 additions & 32 deletions paddle/fluid/framework/new_executor/pir_adaptor/pir_adaptor_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,49 +79,19 @@ class ValueExecutionInfo {

void ResetVarList(int id, Variable* var);

/// Check a value exist in the ValueExecutionInfo or any of its ancestors.
bool HasValue(::pir::Value value) const;
bool HasVar(const std::string& var_name) const;

/// Check a value exist in the ValueExecutionInfo.
bool HasLocalValue(::pir::Value value) const;
bool HasValue(::pir::Value value) const;

std::string GetVarName(::pir::Value value) const;

std::string GetVarName(const Variable* var) const;

std::string GetLocalVarName(::pir::Value value) const;

std::string GetLocalVarName(const Variable* var) const;

int GetVarId(::pir::Value value) const;

int GetVarId(const Variable* var) const;

int GetLocalVarId(::pir::Value value) const;

int GetLocalVarId(const Variable* var) const;

private:
bool HasValueInternal(::pir::Value value) const;

bool HasValueLocally(::pir::Value value) const;

std::string GetVarNameInternal(::pir::Value value) const;

std::string GetVarNameLocally(::pir::Value value) const;

std::string GetVarNameInternal(const Variable* var) const;

std::string GetVarNameLocally(const Variable* var) const;

int GetVarIdInternal(::pir::Value value) const;

int GetVarIdLocally(::pir::Value value) const;

int GetVarIdInternal(const Variable* var) const;

int GetVarIdLocally(const Variable* var) const;

std::shared_ptr<ValueExecutionInfo> NewChild(Scope* scope);

ValueExecutionInfo* parent_{nullptr}; // not owned
Expand Down