change vjp interface gen list to black list (PaddlePaddle#58145)
changeyoung98 authored and wentaoyu committed Oct 24, 2023
1 parent 1d4cc71 commit 7bf1680
Showing 6 changed files with 95 additions and 265 deletions.
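In short: the generators previously consulted an opt-in list (vjp_interface_gen_op_list) and emitted VJP interfaces only for ops named there; after this commit they consult an opt-out list (vjp_interface_black_list) and emit VJP interfaces for every op that has a backward definition, unless the op is named in the black list. A toy sketch of the flipped predicate (list contents here are illustrative, not the real lists):

    # Toy sketch of the flipped gating predicate; list contents are made up.
    whitelist = ['add', 'matmul']        # old scheme: opt-in
    black_list = ['fused_attention']     # new scheme: opt-out
    op, backward = 'tanh', 'tanh_grad'
    old_rule = backward and op in whitelist       # False: 'tanh' had to be listed
    new_rule = backward and op not in black_list  # True: generated by default
    print(old_rule, new_rule)  # False True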
9 changes: 3 additions & 6 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
@@ -27,10 +27,7 @@
 )
 from op_member_func_gen import gen_op_get_inputs_outputs_str
 from op_verify_gen import gen_verify_func_str
-from vjp_interface_gen_op_list import (
-    vjp_interface_declare_gen_op_list,
-    vjp_interface_implementation_gen_op_list,
-)
+from vjp_interface_black_list import vjp_interface_black_list
 
 # import from paddle/fluid/primitive/code_gen/gen.py
 sys.path.append(
@@ -1036,7 +1033,7 @@ def OpGenerator(
 
         if (
             op_info.backward_name
-            and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list
+            and op_info.op_phi_name[0] not in vjp_interface_black_list
         ):
             op_interfaces += ["paddle::dialect::VjpInterface"]
         exclusive_interface_str = gen_exclusive_interface_str(
@@ -1444,7 +1441,7 @@ def OpGenerator(
                 if (
                     op_info.backward_name
                     and op_info.op_phi_name[0]
-                    in vjp_interface_implementation_gen_op_list
+                    not in vjp_interface_black_list
                 ):
                     op_vjp_str = gen_op_vjp_str(
                         op_class_name,
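All three op_gen.py hunks flip the same membership test, once for declaring the interface and once for emitting the implementation. A minimal mirror of the guard above, with a hypothetical op record (the real generator reads these fields from the parsed op YAML definitions):

    vjp_interface_black_list = ['fused_attention', 'set_value']  # excerpt
    backward_name, phi_name = 'tanh_grad', 'tanh'  # hypothetical op record
    op_interfaces = []
    if backward_name and phi_name not in vjp_interface_black_list:
        op_interfaces += ["paddle::dialect::VjpInterface"]
    print(op_interfaces)  # ['paddle::dialect::VjpInterface']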
79 changes: 54 additions & 25 deletions paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 # generator interfaces
-from vjp_interface_gen_op_list import vjp_interface_declare_gen_op_list
+from vjp_interface_black_list import vjp_interface_black_list
 
 OP_INFER_SHAPE_TEMPLATE = """
 void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{
@@ -26,12 +26,12 @@
     {input_type} {input_name}(std::make_shared<primitive::LazyTensor>(op_obj.{input_name}()));"""
 
 OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """
-    pir::CombineOp combine_op_obj =
+    pir::CombineOp combine_op_obj_{input_name} =
         op_obj.{input_name}().dyn_cast<pir::OpResult>().owner()->dyn_cast<pir::CombineOp>();
     std::vector<Tensor> {input_name};
-    for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{
+    for (size_t idx = 0; idx < combine_op_obj_{input_name}.inputs().size(); idx++) {{
         {input_name}.emplace_back(
-            std::make_shared<primitive::LazyTensor>(combine_op_obj.inputs()[idx]));
+            std::make_shared<primitive::LazyTensor>(combine_op_obj_{input_name}.inputs()[idx]));
     }}"""
 
 OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE = """
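The combine_op_obj rename in this hunk is a compile-error fix riding along with the migration: the template is instantiated once per vector<Tensor> input, so an op with two such inputs used to emit two C++ variables both named combine_op_obj. Suffixing the name with {input_name} keeps the generated variables unique. A small sketch with hypothetical input names:

    # Formatting the renamed template once per vector input now yields
    # distinct C++ variables instead of redefining `combine_op_obj`.
    TEMPLATE = "pir::CombineOp combine_op_obj_{input_name} = ...;"
    for input_name in ['xs', 'params']:  # hypothetical vector inputs of one op
        print(TEMPLATE.format(input_name=input_name))
    # pir::CombineOp combine_op_obj_xs = ...;
    # pir::CombineOp combine_op_obj_params = ...;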
@@ -63,6 +63,23 @@
             std::make_shared<primitive::LazyTensor>(out_grads[{index}][idx]));
     }}"""
 
+OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE = """
+    paddle::optional<Tensor> {output_grad_name};
+    if (!IsEmptyValue(out_grads[{idx1}][{idx2}])){{
+      {output_grad_name} = paddle::make_optional<Tensor>(Tensor(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}])));
+    }}"""
+
+OP_VJP_FORWARD_OPTIONAL_VECTOR_OUTPUT_GRAD_TEMPLATE = """
+    paddle::optional<std::vector<Tensor>> {output_grad_name};
+    std::vector<Tensor> optional_{output_grad_name};
+    if (!IsEmptyValue(out_grads[{index}])){{
+        for (size_t idx = 0; idx < out_grads[{index}].size(); idx++) {{
+            optional_{output_grad_name}.emplace_back(
+                std::make_shared<primitive::LazyTensor>(out_grads[{index}][idx]));
+        }}
+        {output_grad_name} = paddle::make_optional<std::vector<Tensor>>(optional_{output_grad_name});
+    }}"""
+
 OP_VJP_ATTRIBUTE_TEMPLATE = """
 {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();"""
 
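The two new templates cover backward ops whose output gradients are optional: the generated C++ leaves the paddle::optional empty when the incoming grad slot is empty, rather than wrapping it unconditionally. A sketch of what the Tensor variant expands to, formatting the template string from the diff with a hypothetical grad name:

    # Instantiating the new Tensor-variant template (copied from the hunk
    # above) for a hypothetical backward input 'bias_grad' at out_grads[1][0].
    OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE = """
    paddle::optional<Tensor> {output_grad_name};
    if (!IsEmptyValue(out_grads[{idx1}][{idx2}])){{
      {output_grad_name} = paddle::make_optional<Tensor>(Tensor(std::make_shared<primitive::LazyTensor>(out_grads[{idx1}][{idx2}])));
    }}"""
    print(OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE.format(
        output_grad_name='bias_grad', idx1=1, idx2=0))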
@@ -131,26 +148,25 @@ def gen_op_vjp_str(
     grad_idx = -1
     for idx in range(len(bw_input_list)):
         build_args_str += bw_input_list[idx] + ", "
-        if op_grad_info.input_optional_list[idx] == 'true':
-            input_type = input_types_map[op_grad_info.input_type_list[idx]]
-            if input_type == 'Tensor':
-                forward_input_output_code += (
-                    OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format(
-                        input_name=bw_input_list[idx],
-                    )
-                )
-            else:
-                forward_input_output_code += (
-                    OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format(
-                        input_name=bw_input_list[idx],
-                    )
-                )
-        else:
-            if (
-                bw_input_list[idx] in op_info.input_name_list
-                or bw_input_list[idx] in op_info.output_name_list
-            ):
-                input_type = input_types_map[op_grad_info.input_type_list[idx]]
+        input_type = input_types_map[op_grad_info.input_type_list[idx]]
+        if (
+            bw_input_list[idx] in op_info.input_name_list
+            or bw_input_list[idx] in op_info.output_name_list
+        ):
+            if op_grad_info.input_optional_list[idx] == 'true':
+                if input_type == 'Tensor':
+                    forward_input_output_code += (
+                        OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format(
+                            input_name=bw_input_list[idx],
+                        )
+                    )
+                else:
+                    forward_input_output_code += (
+                        OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format(
+                            input_name=bw_input_list[idx],
+                        )
+                    )
+            else:
                 if input_type == 'Tensor':
                     forward_input_output_code += (
                         OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format(
@@ -164,9 +180,22 @@ def gen_op_vjp_str(
                             input_name=bw_input_list[idx],
                         )
                     )
+        else:
+            grad_idx += 1
+            if op_grad_info.input_optional_list[idx] == 'true':
+                if input_type == 'Tensor':
+                    forward_input_output_code += (
+                        OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE.format(
+                            output_grad_name=bw_input_list[idx],
+                            idx1=grad_idx,
+                            idx2=0,
+                        )
+                    )
+                else:
+                    forward_input_output_code += OP_VJP_FORWARD_OPTIONAL_VECTOR_OUTPUT_GRAD_TEMPLATE.format(
+                        output_grad_name=bw_input_list[idx], index=grad_idx
+                    )
             else:
-                grad_idx += 1
-                input_type = input_types_map[op_grad_info.input_type_list[idx]]
                 if input_type == 'Tensor':
                     forward_output_grad_code += (
                         OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format(
@@ -285,6 +314,6 @@ def gen_exclusive_interface_str(op_info, op_info_items):
         exclusive_interface_str += (
             "  static void InferMeta( phi::InferMetaContext *infer_meta );"
         )
-    if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list:
+    if op_info.op_phi_name[0] not in vjp_interface_black_list:
         exclusive_interface_str += "\n  static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
     return exclusive_interface_str
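The rewritten gen_op_vjp_str loop classifies each backward input first by kind (forward input/output versus output gradient) and only then by optionality, which is what lets the new optional output-grad templates slot in. The decision tree as a standalone sketch, simplified from the generator with illustrative names:

    def pick_template(name, input_type, optional, fwd_inputs, fwd_outputs):
        # Simplified mirror of the new branch order in gen_op_vjp_str.
        if name in fwd_inputs or name in fwd_outputs:  # forward value
            if optional:
                return ('OPTIONAL_INPUT' if input_type == 'Tensor'
                        else 'OPTIONAL_VECTOR_INPUT')
            return ('INPUT_OR_OUTPUT' if input_type == 'Tensor'
                    else 'MULTI_INPUT')
        # Otherwise it is an output gradient; grad_idx advances here.
        if optional:
            return ('OPTIONAL_OUTPUT_GRAD' if input_type == 'Tensor'
                    else 'OPTIONAL_VECTOR_OUTPUT_GRAD')
        return 'OUTPUT_GRAD'

    print(pick_template('x', 'Tensor', False, ['x'], ['out']))        # INPUT_OR_OUTPUT
    print(pick_template('out_grad', 'Tensor', True, ['x'], ['out']))  # OPTIONAL_OUTPUT_GRAD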
36 changes: 36 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py
@@ -0,0 +1,36 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# =====================================
# VjpInterface gen op black list
# =====================================
# Vjp function code gen is not yet
# supported for the ops below, so we use
# this black list to skip generating
# their Vjp methods.
# TODO(wanghao107):
# remove this file once Vjp code gen is supported for all ops.


vjp_interface_black_list = [
'frobenius_norm',
'write_to_array',
'fused_attention',
'fused_feedforward',
'set_value',
'set_value_with_tensor',
'silu_grad',
'fused_dropout_add',
'fused_rotary_position_embedding',
]
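With declaration and implementation generation now gated by this one module, the two can no longer drift apart the way the old pair of lists (vjp_interface_declare_gen_op_list, vjp_interface_implementation_gen_op_list) could. A hypothetical helper, not part of the commit, mirroring how op_gen.py and op_interface_gen.py consult the list:

    # Hypothetical helper mirroring the generators' checks.
    from vjp_interface_black_list import vjp_interface_black_list

    def should_gen_vjp(op_phi_name, backward_name):
        return bool(backward_name) and op_phi_name not in vjp_interface_black_list

    assert should_gen_vjp('add', 'add_grad')                              # generated
    assert not should_gen_vjp('fused_attention', 'fused_attention_grad')  # skipped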
(Diffs for the remaining three changed files did not load and are not shown.)
