support inplace in dygraph eager fluid state
pangyoki committed Mar 10, 2022
1 parent 9aa6bfc commit 527752b
Showing 10 changed files with 331 additions and 36 deletions.
3 changes: 2 additions & 1 deletion paddle/fluid/eager/api/utils/tensor_utils.cc
@@ -30,7 +30,8 @@ namespace egr_utils_api {

bool IsLeafTensor(const paddle::experimental::Tensor& target) {
std::shared_ptr<GradNodeBase> grad_node = EagerUtils::grad_node(target);
if (std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node)) {
if (!grad_node ||
std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node)) {
return true;
}

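This first change makes `IsLeafTensor` treat a tensor with no grad node at all as a leaf, rather than only tensors whose grad node is a `GradNodeAccumulation`. A minimal, runnable sketch of the new predicate — the struct definitions are stand-ins for illustration, not Paddle's real types:

```cpp
#include <iostream>
#include <memory>

struct GradNodeBase { virtual ~GradNodeBase() = default; };
struct GradNodeAccumulation : GradNodeBase {};

bool IsLeaf(const std::shared_ptr<GradNodeBase>& grad_node) {
  // New behavior: a null grad node also counts as a leaf.
  return !grad_node ||
         std::dynamic_pointer_cast<GradNodeAccumulation>(grad_node) != nullptr;
}

int main() {
  std::shared_ptr<GradNodeBase> none;                     // tensor without a grad node
  auto accum = std::make_shared<GradNodeAccumulation>();  // accumulation node: leaf
  auto inner = std::make_shared<GradNodeBase>();          // intermediate node: not a leaf
  std::cout << IsLeaf(none) << IsLeaf(accum) << IsLeaf(inner) << "\n";  // prints 110
}
```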
162 changes: 132 additions & 30 deletions paddle/fluid/eager/auto_code_generator/eager_generator.cc
@@ -976,7 +976,9 @@ static bool CollectGradInformationFromOpInfo(
/* --------------------------------------------------- */
static std::string GenerateGradNodeCreationContent(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
const GradNodeGenerationInfo& bwd_info,
const std::string& trace_op_body_str,
const std::map<std::string, std::string>& inplace_map = {}) {
VLOG(6) << "Generating GradNode Creation codes";

const std::string& op_type = fwd_info.GetOpType();
@@ -995,7 +997,9 @@ static std::string GenerateGradNodeCreationContent(
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"
std::string get_autograd_meta_str = " // Prepare Autograd Meta \n";
std::string get_input_autograd_meta_str = " // Prepare Autograd Meta \n";
// Output autograd_meta should be obtained after running TraceOp.
std::string get_output_autograd_meta_str = "";
// If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
@@ -1007,13 +1011,13 @@ static std::string GenerateGradNodeCreationContent(
const char* GET_MULTI_AUTOGRAD_META_TEMPLATE =
" std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::autograd_meta(&%s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
get_output_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name);
} else {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = "
"egr::EagerUtils::autograd_meta(&%s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
get_output_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, output_autograd_name, output_name);
}
}
@@ -1027,28 +1031,42 @@ static std::string GenerateGradNodeCreationContent(
const char* GET_MULTI_AUTOGRAD_META_TEMPLATE =
" std::vector<egr::AutogradMeta*> %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
get_input_autograd_meta_str += paddle::string::Sprintf(
GET_MULTI_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);

} else if (input.dispensable()) {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
get_input_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);

} else {
const char* GET_SINGLE_AUTOGRAD_META_TEMPLATE =
" egr::AutogradMeta* %s = "
"egr::EagerUtils::nullable_autograd_meta(%s);\n";
get_autograd_meta_str += paddle::string::Sprintf(
get_input_autograd_meta_str += paddle::string::Sprintf(
GET_SINGLE_AUTOGRAD_META_TEMPLATE, input_autograd_name, input_name);
}
}
VLOG(6) << "Generated inputs autograd_meta";

std::string check_inplace_str = "";
if (!inplace_map.empty()) {
const char* CHECKING_INPLACE_TEMPLATE =
" // Check Inplace and Bump Inplace Version\n"
" egr::EagerUtils::CheckInplace(%s, p_autograd_%s, "
"require_any_grad);\n";
for (auto& inplace_pair : inplace_map) {
std::string inplace_name = inplace_pair.second;
check_inplace_str += paddle::string::Sprintf(CHECKING_INPLACE_TEMPLATE,
inplace_name, inplace_name);
}
VLOG(6) << "Check Inplace Input and Bump Version";
}

std::string prepare_autograd_meta_str = "";
prepare_autograd_meta_str += get_autograd_meta_str;
prepare_autograd_meta_str += get_input_autograd_meta_str;
prepare_autograd_meta_str += "\n";

// [GradOpNode] GetTraceBackward
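The new `check_inplace_str` logic emits one `egr::EagerUtils::CheckInplace(...)` call per inplace input, ahead of the traced op, so unsafe in-place writes can be caught and the input's inplace version bumped. A runnable sketch of the string assembly, using a `snprintf`-based stand-in for `paddle::string::Sprintf` and an assumed `{"Out" -> "X"}` map (as for `relu_`):

```cpp
#include <cstdio>
#include <map>
#include <string>

// Stand-in for paddle::string::Sprintf, assumed here for illustration.
template <typename... Args>
std::string Sprintf(const char* fmt, Args... args) {
  char buf[512];
  std::snprintf(buf, sizeof(buf), fmt, args...);
  return std::string(buf);
}

int main() {
  // Maps output name -> input name; {"Out" -> "X"} is an assumed relu_-style pair.
  std::map<std::string, std::string> inplace_map = {{"Out", "X"}};
  const char* kCheckingInplaceTemplate =
      "  // Check Inplace and Bump Inplace Version\n"
      "  egr::EagerUtils::CheckInplace(%s, p_autograd_%s, require_any_grad);\n";
  std::string check_inplace_str;
  for (auto& inplace_pair : inplace_map) {
    const std::string& inplace_name = inplace_pair.second;
    check_inplace_str += Sprintf(kCheckingInplaceTemplate, inplace_name.c_str(),
                                 inplace_name.c_str());
  }
  std::printf("%s", check_inplace_str.c_str());
}
```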
@@ -1200,14 +1218,18 @@ static std::string GenerateGradNodeCreationContent(
// [Generation] GradNode Creation
const char* GRAD_NODE_CREATION_TEMPLATE =
" %s"
" bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(%s);\n"
" bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(%s);\n\n"
"%s\n"
"%s"
"%s"
" if(require_any_grad) {\n"
" VLOG(6) << \" Construct Grad for %s \"; \n"
" egr::EagerUtils::PassStopGradient(%s);\n"
"%s\n }";
std::string grad_node_creation_body_str = paddle::string::Sprintf(
GRAD_NODE_CREATION_TEMPLATE, prepare_autograd_meta_str,
compute_require_grad_args, op_type, pass_stop_gradient_args,
compute_require_grad_args, check_inplace_str, trace_op_body_str,
get_output_autograd_meta_str, op_type, pass_stop_gradient_args,
grad_node_creation_str);

return grad_node_creation_body_str;
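Splicing `trace_op_body_str` into `GRAD_NODE_CREATION_TEMPLATE` reorders the generated forward body: input autograd metas and `require_any_grad` come first, the inplace check runs before `TraceOp`, and output autograd metas are fetched only after the op has produced its outputs. An illustrative skeleton of the result, with `<...>` placeholders rather than real generated identifiers:

```cpp
// <prepare input autograd meta>                 // get_input_autograd_meta_str
// bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(<inputs>);
// <check inplace inputs, bump versions>         // check_inplace_str (inplace variants only)
// <TraceOp call and output conversion>          // trace_op_body_str
// <prepare output autograd meta>                // get_output_autograd_meta_str, after TraceOp
// if (require_any_grad) {
//   VLOG(6) << " Construct Grad for <op> ";
//   egr::EagerUtils::PassStopGradient(<outputs>);
//   <create grad node, set edges, retain grad>  // grad_node_creation_str
// }
```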
@@ -1218,7 +1240,8 @@ static std::string GenerateGradNodeCreationContent(
/* -------------------------------- */
static std::pair<std::string, std::string> GenerateForwardFunctionContents(
const ForwardGenerationInfo& fwd_info,
const GradNodeGenerationInfo& bwd_info) {
const GradNodeGenerationInfo& bwd_info,
std::map<std::string, std::string> inplace_map = {}) {
/* --- Process Forward Info ---*/
const std::string& op_type = fwd_info.GetOpType();
const std::unordered_map<std::string, size_t>& fwd_inputs_name_pos_map =
@@ -1298,8 +1321,20 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(

core_ops_args_type_info[op_type][input_position] = "list";
} else {
const char* FWD_INS_ARG_TEMPLATE =
"const paddle::experimental::Tensor& %s";
const char* FWD_INS_ARG_TEMPLATE;
bool flag_find_input_name = false;
if (!inplace_map.empty()) {
for (auto& inplace_pair : inplace_map) {
if (inplace_pair.second == input_name) {
flag_find_input_name = true;
FWD_INS_ARG_TEMPLATE = "paddle::experimental::Tensor& %s";
break;
}
}
}
if (!flag_find_input_name) {
FWD_INS_ARG_TEMPLATE = "const paddle::experimental::Tensor& %s";
}
input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
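Since an inplaced input is written through by the op, the generated inplace variant must accept it as a mutable reference; every other input keeps the `const` reference. A toy sketch of the selection under the assumed `{"Out" -> "X"}` map, with a second input `Y` added purely for contrast:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Assumed inplace pair: output "Out" reuses input "X" (relu_-style).
  std::map<std::string, std::string> inplace_map = {{"Out", "X"}};
  std::vector<std::string> input_names = {"X", "Y"};
  for (const auto& input_name : input_names) {
    bool is_inplace_input = false;
    for (const auto& pair : inplace_map) {
      if (pair.second == input_name) { is_inplace_input = true; break; }
    }
    // Prints "paddle::experimental::Tensor& X" and
    // "const paddle::experimental::Tensor& Y".
    std::cout << (is_inplace_input ? "paddle::experimental::Tensor& "
                                   : "const paddle::experimental::Tensor& ")
              << input_name << "\n";
  }
}
```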

@@ -1359,6 +1394,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(

// [Generation] Get Outs Map
std::string outs_contents_str = "";
std::string inplace_mapping_str = "";
for (const proto::OpProto::Var& output : out_vars) {
const std::string& output_name = output.name();
std::string outnum = "1";
Expand Down Expand Up @@ -1401,6 +1437,20 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
}
core_ops_args_info[op_type].push_back(output_var_name);

} else if (!inplace_map.empty() && inplace_map.count(output_name)) {
PADDLE_ENFORCE_NE(
inplace_map[output_name], "",
paddle::platform::errors::InvalidArgument(
"Inplace op %s has no input corresponding to output %s.", op_type,
output_name));
const char* FWD_OUTS_CONTENT_TEMPLATE = "{ \"%s\", ins[\"%s\"] },";
auto inplace_input_name = inplace_map[output_name];
outs_contents_str += paddle::string::Sprintf(
FWD_OUTS_CONTENT_TEMPLATE, output_name, inplace_input_name);

const char* INPLACE_MAPPING_TEMPLATE = R"({"%s", "%s"},)";
inplace_mapping_str += paddle::string::Sprintf(
INPLACE_MAPPING_TEMPLATE, inplace_input_name, output_name);
} else {
if (output.duplicable()) {
outnum = output_name + "Num";
@@ -1427,6 +1477,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
}
if (outs_contents_str.size() > 0)
outs_contents_str.pop_back(); // Remove trailing ","
if (inplace_mapping_str.size() > 0)
inplace_mapping_str.pop_back(); // Remove trailing ","
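For an inplace op the generated `outs` map aliases the output to the input's variable, and `inplace_mapping_str` collects the `{"input", "output"}` pairs later handed to `TraceOp`. A runnable sketch of both strings being built and their trailing commas trimmed, again under an assumed `{"Out" -> "X"}` map:

```cpp
#include <iostream>
#include <map>
#include <string>

int main() {
  std::map<std::string, std::string> inplace_map = {{"Out", "X"}};  // output -> input
  std::string outs_contents_str, inplace_mapping_str;
  for (auto& pair : inplace_map) {
    const std::string& output_name = pair.first;
    const std::string& input_name = pair.second;
    // Output shares the input's variable: { "Out", ins["X"] }
    outs_contents_str += "{ \"" + output_name + "\", ins[\"" + input_name + "\"] },";
    // Mapping passed on to TraceOp: {"X", "Out"}
    inplace_mapping_str += "{\"" + input_name + "\", \"" + output_name + "\"},";
  }
  if (!outs_contents_str.empty()) outs_contents_str.pop_back();      // drop trailing ','
  if (!inplace_mapping_str.empty()) inplace_mapping_str.pop_back();  // drop trailing ','
  std::cout << outs_contents_str << "\n" << inplace_mapping_str << "\n";
}
```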

const char* FWD_OUTS_MAP_TEMPLATE =
" std::map<std::string, "
@@ -1460,18 +1512,20 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
dygraph_function_args_str +=
", const paddle::framework::AttributeMap& attr_map";

std::string trace_op_body_str = "";
// [Generation] Get TraceOp
const char* FWD_TRACE_OP_TEMPLATE =
" paddle::framework::AttributeMap attrs = attr_map;\n"
" paddle::framework::AttributeMap default_attrs;\n"
" egr::Controller::Instance().GetCurrentTracer()->TraceOp(\"%s\", ins, "
"outs, attrs, \n"
" egr::Controller::Instance().GetExpectedPlace(),\n"
" &default_attrs, true, {});\n";
std::string trace_op_str =
paddle::string::Sprintf(FWD_TRACE_OP_TEMPLATE, op_type);
generated_function_body += trace_op_str;
generated_function_body += "\n";
" &default_attrs, true, {%s});\n";
std::string trace_op_str = paddle::string::Sprintf(
FWD_TRACE_OP_TEMPLATE, op_type, inplace_mapping_str);

trace_op_body_str += trace_op_str;
trace_op_body_str += "\n";

VLOG(6) << "Generated AttrMap & TraceOp";

Expand Down Expand Up @@ -1536,34 +1590,56 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
output_varname, output_var_args_name);
}
} else {
const char* FWD_OUT_TENSOR_TEMPLATE =
" paddle::experimental::Tensor %s;\n"
" egr::EagerUtils::GetOutput(outs[\"%s\"][0], &%s);\n";
out_tensor_str =
paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, output_varname,
output_name, output_varname);
if (!inplace_map.empty() && inplace_map.count(output_name)) {
auto inplace_input_name = inplace_map[output_name];
const char* FWD_OUT_TENSOR_TEMPLATE =
" paddle::experimental::Tensor %s;\n"
" egr::EagerUtils::GetOutput(outs[\"%s\"][0], &%s);\n"
" //egr::EagerUtils::ModifyInplaceInput(outs[\"%s\"][0], &%s);\n"
" %s.set_inplace_version(%s.current_inplace_version());\n"
" %s = std::move(%s);\n"
" VLOG(6) << \"Modify Inplace input tensor (\" << %s.name() << "
"\").\";\n"
" %s.bump_inplace_version();\n"
" VLOG(3) << \"Tensor(\" << %s.name() << \") uses Inplace "
"Strategy.\";\n";
out_tensor_str = paddle::string::Sprintf(
FWD_OUT_TENSOR_TEMPLATE, output_varname, output_name,
output_varname, output_name, inplace_input_name, output_varname,
inplace_input_name, inplace_input_name, output_varname,
inplace_input_name, inplace_input_name, inplace_input_name);
} else {
const char* FWD_OUT_TENSOR_TEMPLATE =
" paddle::experimental::Tensor %s;\n"
" egr::EagerUtils::GetOutput(outs[\"%s\"][0], &%s);\n";
out_tensor_str =
paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, output_varname,
output_name, output_varname);
}
}
return_types[return_position] = "paddle::experimental::Tensor";
}

return_contents[return_position] = output_varname;
generated_function_body += out_tensor_str;
trace_op_body_str += out_tensor_str;
}
generated_function_body += "\n";
trace_op_body_str += "\n";
VLOG(6) << "Converted Output VarBase to EagerVariable(s)";

// [Generation] Handle core_ops_returns_info
core_ops_returns_info[op_type] = return_contents;

// [Generation] ComputeRequireGrad -> GradNodeCreation
if (!bwd_info.GenerateForwardOnly()) {
std::string grad_node_creation_body_str =
GenerateGradNodeCreationContent(fwd_info, bwd_info);
std::string grad_node_creation_body_str = GenerateGradNodeCreationContent(
fwd_info, bwd_info, trace_op_body_str, inplace_map);
generated_function_body += grad_node_creation_body_str;
generated_function_body += "\n";

// [Generation] Call RetainGradForTensor
VLOG(6) << "Generated GradNode Creation codes";
} else {
generated_function_body += trace_op_body_str;
}

// [Generation] Handle return: Tuple/Vector/Tensor
@@ -1610,7 +1686,12 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
VLOG(6) << "Generated return codes";

// [Generation] Get Full Function
std::string function_name = op_type + "_dygraph_function";
std::string function_name;
if (inplace_map.empty()) {
function_name = op_type + "_dygraph_function";
} else {
function_name = op_type + "__dygraph_function";
}
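The extra underscore is deliberate: `op_type + "__dygraph_function"` turns, e.g., `relu` into `relu__dygraph_function`, matching the Python-side convention that `relu_` names the inplace form of `relu` (example op assumed):

```cpp
#include <iostream>
#include <string>

int main() {
  std::string op_type = "relu";  // assumed example op
  std::cout << op_type + "_dygraph_function" << "\n";   // regular: relu_dygraph_function
  std::cout << op_type + "__dygraph_function" << "\n";  // inplace: relu__dygraph_function
}
```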

if (dygraph_function_args_str.size() > 0) {
auto iter = dygraph_function_args_str.begin();
@@ -2379,14 +2460,35 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
/* --------------------------- */
VLOG(6) << "-------- GenerateForwardFunctionContents -------";
std::pair<std::string, std::string> body_and_declaration =
GenerateForwardFunctionContents(fwd_info, bwd_info);
GenerateForwardFunctionContents(fwd_info, bwd_info, {});

fwd_function_str += body_and_declaration.first + "\n";

VLOG(6) << "-------- GenerateDygraphForwardAPIContents -------";
std::string fwd_function_declare_str = body_and_declaration.second;
dygraph_forward_api_str += fwd_function_declare_str;

auto& infer_inplace =
paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_;
std::map<std::string, std::string> inplace_map;
if (op_type != "sum" && infer_inplace) {
auto in_to_outs = infer_inplace(true);
for (auto& inplace_pair : in_to_outs) {
inplace_map[inplace_pair.second] = inplace_pair.first;
}

VLOG(6) << "-------- GenerateInplaceForwardFunctionContents -------";
std::pair<std::string, std::string> inplace_body_and_declaration =
GenerateForwardFunctionContents(fwd_info, bwd_info, inplace_map);

fwd_function_str += inplace_body_and_declaration.first + "\n";

VLOG(6) << "-------- GenerateInplaceDygraphForwardAPIContents -------";
std::string inplace_fwd_function_declare_str =
inplace_body_and_declaration.second;
dygraph_forward_api_str += inplace_fwd_function_declare_str;
}
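The driver asks each op's `infer_inplace` for its input->output pairs and flips them into the output->input `inplace_map` the generators expect, skipping `sum` explicitly. A runnable sketch of the inversion under an assumed `{"X" -> "Out"}` result:

```cpp
#include <iostream>
#include <map>
#include <string>

int main() {
  // What infer_inplace(true) might return for a relu-style op (assumed).
  std::map<std::string, std::string> in_to_outs = {{"X", "Out"}};
  std::map<std::string, std::string> inplace_map;
  for (auto& inplace_pair : in_to_outs) {
    inplace_map[inplace_pair.second] = inplace_pair.first;  // key: output, value: input
  }
  std::cout << "Out -> " << inplace_map["Out"] << "\n";  // prints: Out -> X
}
```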

if (bwd_info.GenerateForwardOnly()) continue;

VLOG(6) << "-------- GenerateGradNodeHeaderContents -------";
