-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
support inplace in dygraph eager_fluid state #40400
Changes from all commits
9fc70fe
ba8d79e
137db9d
1a18aa2
d09ec3b
f84f2be
733672e
68b1991
7665d63
86393f5
c653ec0
9156cea
6fd613d
58731e9
a88f9b1
778719b
65cf9e3
af7b919
4d3b57d
415ff65
bb283ce
e548c22
519c9a6
1fbc61b
c0a2b8b
536a28b
2417858
34fa7c0
1b89072
af7f058
f397b8f
e3f9826
27830a9
7ede919
f9adf49
2fe3b9f
90e97d6
8c27961
daff8bd
db573fe
f0d8f65
d18697a
1b5eac2
f4e42e2
58a03b5
ea41a1c
b04e9a9
c7bd6fc
7bb3cbd
ac85d81
945d282
8312d2d
441bc81
326eee5
3562b64
a0ec433
df99eea
cecc6e1
9491a06
894791c
41dec57
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -718,6 +718,15 @@ static PyObject* set_grad_type(TensorObject* self, PyObject* args, | |
EAGER_CATCH_AND_THROW_RETURN_NULL | ||
} | ||
|
||
static PyObject* tensor__inplace_version(TensorObject* self, PyObject* args, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. single underscore "_" in function name? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. its ok if this method indicate |
||
PyObject* kwargs) { | ||
EAGER_TRY | ||
uint32_t inplace_version = self->tensor.current_inplace_version(); | ||
|
||
return ToPyObject(inplace_version); | ||
EAGER_CATCH_AND_THROW_RETURN_NULL | ||
} | ||
|
||
PyMethodDef variable_methods[] = { | ||
{"numpy", (PyCFunction)(void (*)(void))tensor_method_numpy, | ||
METH_VARARGS | METH_KEYWORDS, NULL}, | ||
|
@@ -766,6 +775,8 @@ PyMethodDef variable_methods[] = { | |
METH_VARARGS | METH_KEYWORDS, NULL}, | ||
{"_set_grad_type", (PyCFunction)(void (*)(void))set_grad_type, | ||
METH_VARARGS | METH_KEYWORDS, NULL}, | ||
{"_inplace_version", (PyCFunction)(void (*)(void))tensor__inplace_version, | ||
METH_VARARGS | METH_KEYWORDS, NULL}, | ||
{NULL, NULL, 0, NULL}}; | ||
|
||
} // namespace pybind | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -162,17 +162,22 @@ static inline std::string TempName(const std::string& name) { | |
|
||
std::string GenerateOpFunctionsBody( | ||
const paddle::framework::proto::OpProto* op_proto, std::string func_name, | ||
bool use_inplace_strategy = false, | ||
std::map<std::string, std::string> inplace_map = {}) { | ||
auto& op_type = op_proto->type(); | ||
std::string input_args = ""; | ||
std::string call_api_str = "auto out = " + op_type + "_dygraph_function("; | ||
std::string call_api_str = ""; | ||
std::string ins_initializer_with_null = ""; | ||
std::string py_arg = ""; | ||
int arg_idx = 0; | ||
int input_args_num = 0; | ||
std::string ins_cast_str = ""; | ||
std::string view_strategy_str = ""; | ||
if (!inplace_map.empty()) { | ||
// change call_api_str for inplace op | ||
call_api_str = "auto out = " + op_type + "__dygraph_function("; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better add "" at the very end of the function name, like "scale_dygraph_function" for inplaced scale |
||
} else { | ||
call_api_str = "auto out = " + op_type + "_dygraph_function("; | ||
} | ||
for (auto& input : op_proto->inputs()) { | ||
auto& in_name = input.name(); | ||
// skip those dispensable inputs, like ResidualData in conv2d | ||
|
@@ -288,8 +293,31 @@ std::string GenerateOpFunctionsBody( | |
HANDLE_VIEW_BETWEEN_INPUT_AND_OUTPUT, viwe_input_name, viwe_output_name, | ||
viwe_input_name, viwe_output_name); | ||
} | ||
|
||
return_str = "return ToPyObject(out);"; | ||
if (!inplace_map.empty()) { | ||
// For inplace op, Use the input PyObject directly. | ||
for (auto& inplace_pair : inplace_map) { | ||
// Find index of inplace tensor, and directly use input PyObject. | ||
std::string inplace_arg_name = inplace_pair.second; | ||
std::string inplace_return_name = inplace_pair.first; | ||
const char* RETURN_INPLACE_TENSOR_TEMPLATE = | ||
"ssize_t arg_id = GetIdxFromCoreOpsInfoMap(core_ops_args_info, " | ||
"\"%s\", \"%s\");\n" | ||
" ssize_t return_id = " | ||
"GetIdxFromCoreOpsInfoMap(core_ops_returns_info, \"%s\", \"%s\");\n" | ||
" return ToPyObject(out, return_id, args, arg_id);"; | ||
return_str = paddle::string::Sprintf(RETURN_INPLACE_TENSOR_TEMPLATE, | ||
op_type, inplace_arg_name, op_type, | ||
inplace_return_name); | ||
// only support one inplace_var in temporary. | ||
PADDLE_ENFORCE_EQ( | ||
inplace_map.size(), 1, | ||
paddle::platform::errors::InvalidArgument( | ||
"size of inplace_map must be 1, but got %d", inplace_map.size())); | ||
break; | ||
} | ||
} else { | ||
return_str = "return ToPyObject(out);"; | ||
} | ||
|
||
std::string function_args = ""; | ||
if (input_args == "") { | ||
|
@@ -383,14 +411,49 @@ GenerateOpFunctions() { | |
continue; | ||
} | ||
std::string func_name = "eager_api_" + op_type; | ||
std::string op_function_str = GenerateOpFunctionsBody(op_proto, func_name); | ||
std::string op_function_str = | ||
GenerateOpFunctionsBody(op_proto, func_name, {}); | ||
|
||
// generate pybind item | ||
auto bind_function_str = paddle::string::Sprintf( | ||
PYBIND_ITEM_TEMPLATE, op_type, func_name, op_type); | ||
|
||
op_function_list.emplace_back(std::move(op_function_str)); | ||
bind_function_list.emplace_back(std::move(bind_function_str)); | ||
|
||
// NOTE(pangyoki): Inplace Strategy. | ||
// In this case, output will reuse input varbase. | ||
// Dygraph mode needs to be aligned with the in-place strategy in static | ||
// mode, and the mapping relationships between output and input that have | ||
// been defined in static mode should be used in dygraph mode. | ||
// Find which ops need to use Inplace strategy in static mode, and get the | ||
// mapping relationship between Inplace output and input. | ||
auto& infer_inplace = | ||
paddle::framework::OpInfoMap::Instance().Get(op_type).infer_inplace_; | ||
std::map<std::string, std::string> inplace_map; | ||
// `sum` op has duplicate input. Don't consider adding inplace strategy | ||
// for `sum` in temporary. | ||
if (op_type != "sum" && infer_inplace) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Better store hard-coded op name in a static set There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done in PR #41118 |
||
// Inplace OP: op_type_. | ||
// The inplace OP needs a new implementation method. | ||
auto in_to_outs = infer_inplace(true); | ||
for (auto& inplace_pair : in_to_outs) { | ||
inplace_map[inplace_pair.second] = inplace_pair.first; | ||
} | ||
|
||
std::string inplace_op_type = op_type + "_"; | ||
std::string inplace_func_name = "eager_api_" + inplace_op_type; | ||
std::string inplace_op_function_str = | ||
GenerateOpFunctionsBody(op_proto, inplace_func_name, inplace_map); | ||
|
||
// generate pybind item | ||
auto inplace_bind_function_str = | ||
paddle::string::Sprintf(PYBIND_ITEM_TEMPLATE, inplace_op_type, | ||
inplace_func_name, inplace_op_type); | ||
|
||
op_function_list.emplace_back(std::move(inplace_op_function_str)); | ||
bind_function_list.emplace_back(std::move(inplace_bind_function_str)); | ||
} | ||
} | ||
if (append_custom_head_file) { | ||
op_function_list.emplace_back(CUSTOM_HANDWRITE_OP_FUNC_FILE); | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks like we're gonna check inplace version anyway, let's move this function "check_inplace_version" out.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done in PR #41118