From 7bf1680b6ecc5178cf13d868d7a9b0a509f4d33c Mon Sep 17 00:00:00 2001 From: Chen Zhiyang <1792266893@qq.com> Date: Tue, 17 Oct 2023 17:42:09 +0800 Subject: [PATCH] change vjp interface gen list to black list (#58145) --- .../fluid/pir/dialect/op_generator/op_gen.py | 9 +- .../dialect/op_generator/op_interface_gen.py | 79 ++++-- .../op_generator/vjp_interface_black_list.py | 36 +++ .../op_generator/vjp_interface_gen_op_list.py | 229 ------------------ paddle/fluid/pir/dialect/operator/ir/ops.yaml | 3 - paddle/phi/api/yaml/legacy_backward.yaml | 4 +- 6 files changed, 95 insertions(+), 265 deletions(-) create mode 100644 paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py delete mode 100644 paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 64caafc5448924..15d2e4139e753d 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -27,10 +27,7 @@ ) from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str -from vjp_interface_gen_op_list import ( - vjp_interface_declare_gen_op_list, - vjp_interface_implementation_gen_op_list, -) +from vjp_interface_black_list import vjp_interface_black_list # import from paddle/fluid/primitive/code_gen/gen.py sys.path.append( @@ -1036,7 +1033,7 @@ def OpGenerator( if ( op_info.backward_name - and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list + and op_info.op_phi_name[0] not in vjp_interface_black_list ): op_interfaces += ["paddle::dialect::VjpInterface"] exclusive_interface_str = gen_exclusive_interface_str( @@ -1444,7 +1441,7 @@ def OpGenerator( if ( op_info.backward_name and op_info.op_phi_name[0] - in vjp_interface_implementation_gen_op_list + not in vjp_interface_black_list ): op_vjp_str = gen_op_vjp_str( op_class_name, diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 9c8ff889f2b219..3eb0179ca48e42 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -13,7 +13,7 @@ # limitations under the License. 
# generator interfaces -from vjp_interface_gen_op_list import vjp_interface_declare_gen_op_list +from vjp_interface_black_list import vjp_interface_black_list OP_INFER_SHAPE_TEMPLATE = """ void {op_name}::InferMeta( phi::InferMetaContext *infer_meta ) {{ @@ -26,12 +26,12 @@ {input_type} {input_name}(std::make_shared(op_obj.{input_name}()));""" OP_VJP_FORWARD_MULTI_INPUT_TEMPLATE = """ - pir::CombineOp combine_op_obj = + pir::CombineOp combine_op_obj_{input_name} = op_obj.{input_name}().dyn_cast().owner()->dyn_cast(); std::vector {input_name}; - for (size_t idx = 0; idx < combine_op_obj.inputs().size(); idx++) {{ + for (size_t idx = 0; idx < combine_op_obj_{input_name}.inputs().size(); idx++) {{ {input_name}.emplace_back( - std::make_shared(combine_op_obj.inputs()[idx])); + std::make_shared(combine_op_obj_{input_name}.inputs()[idx])); }}""" OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE = """ @@ -63,6 +63,23 @@ std::make_shared(out_grads[{index}][idx])); }}""" +OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE = """ + paddle::optional {output_grad_name}; + if (!IsEmptyValue(out_grads[{idx1}][{idx2}])){{ + {output_grad_name} = paddle::make_optional(Tensor(std::make_shared(out_grads[{idx1}][{idx2}]))); + }}""" + +OP_VJP_FORWARD_OPTIONAL_VECTOR_OUTPUT_GRAD_TEMPLATE = """ + paddle::optional> {output_grad_name}; + std::vector optional_{output_grad_name}; + if (!IsEmptyValue(out_grads[{index}])){{ + for (size_t idx = 0; idx < out_grads[{index}].size(); idx++) {{ + optional_{output_grad_name}.emplace_back( + std::make_shared(out_grads[{index}][idx])); + }} + {output_grad_name} = paddle::make_optional>(optional_{output_grad_name}); + }}""" + OP_VJP_ATTRIBUTE_TEMPLATE = """ {attr_type} {attr_name} = op->attribute("{attr_name}").dyn_cast<{attr_parse_type}>().{func}();""" @@ -131,26 +148,25 @@ def gen_op_vjp_str( grad_idx = -1 for idx in range(len(bw_input_list)): build_args_str += bw_input_list[idx] + ", " - if op_grad_info.input_optional_list[idx] == 'true': - input_type = input_types_map[op_grad_info.input_type_list[idx]] - if input_type == 'Tensor': - forward_input_output_code += ( - OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], + input_type = input_types_map[op_grad_info.input_type_list[idx]] + if ( + bw_input_list[idx] in op_info.input_name_list + or bw_input_list[idx] in op_info.output_name_list + ): + if op_grad_info.input_optional_list[idx] == 'true': + if input_type == 'Tensor': + forward_input_output_code += ( + OP_VJP_FORWARD_OPTIONAL_INPUT_TEMPLATE.format( + input_name=bw_input_list[idx], + ) ) - ) - else: - forward_input_output_code += ( - OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format( - input_name=bw_input_list[idx], + else: + forward_input_output_code += ( + OP_VJP_FORWARD_OPTIONAL_VECTOR_INPUT_TEMPLATE.format( + input_name=bw_input_list[idx], + ) ) - ) - else: - if ( - bw_input_list[idx] in op_info.input_name_list - or bw_input_list[idx] in op_info.output_name_list - ): - input_type = input_types_map[op_grad_info.input_type_list[idx]] + else: if input_type == 'Tensor': forward_input_output_code += ( OP_VJP_FORWARD_INPUT_OR_OUTPUT_TEMPLATE.format( @@ -164,9 +180,22 @@ def gen_op_vjp_str( input_name=bw_input_list[idx], ) ) + else: + grad_idx += 1 + if op_grad_info.input_optional_list[idx] == 'true': + if input_type == 'Tensor': + forward_input_output_code += ( + OP_VJP_FORWARD_OPTIONAL_OUTPUT_GRAD_TEMPLATE.format( + output_grad_name=bw_input_list[idx], + idx1=grad_idx, + idx2=0, + ) + ) + else: + forward_input_output_code += 
OP_VJP_FORWARD_OPTIONAL_VECTOR_OUTPUT_GRAD_TEMPLATE.format( + output_grad_name=bw_input_list[idx], index=grad_idx + ) else: - grad_idx += 1 - input_type = input_types_map[op_grad_info.input_type_list[idx]] if input_type == 'Tensor': forward_output_grad_code += ( OP_VJP_FORWARD_OUTPUT_GRAD_TEMPLATE.format( @@ -285,6 +314,6 @@ def gen_exclusive_interface_str(op_info, op_info_items): exclusive_interface_str += ( " static void InferMeta( phi::InferMetaContext *infer_meta );" ) - if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list: + if op_info.op_phi_name[0] not in vjp_interface_black_list: exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py new file mode 100644 index 00000000000000..c63e0c4e418338 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_black_list.py @@ -0,0 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# ===================================== +# VjpInterface gen op list +# ===================================== +# we don't support vjp function code +# gen now, so we use a whitelist to +# control the generation of Vjp methods. +# TODO(wanghao107) +# remove this file and support Vjp methods +# code gen. + + +vjp_interface_black_list = [ + 'frobenius_norm', + 'write_to_array', + 'fused_attention', + 'fused_feedforward', + 'set_value', + 'set_value_with_tensor', + 'silu_grad', + 'fused_dropout_add', + 'fused_rotary_position_embedding', +] diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py deleted file mode 100644 index 58abcbf1143b9f..00000000000000 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ /dev/null @@ -1,229 +0,0 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# ===================================== -# VjpInterface gen op list -# ===================================== -# we don't support vjp function code -# gen now, so we use a whitelist to -# control the generation of Vjp methods. -# TODO(wanghao107) -# remove this file and support Vjp methods -# code gen. 
- - -vjp_interface_declare_gen_op_list = [ - 'where', - "tanh", - "mean", - "divide", - "sum", - "add", - "concat", - "split", - "split_with_num", - "gelu", - "matmul", - "erf", - "multiply", - "pow", - "rsqrt", - "subtract", - "square", - "dropout", - 'exp', - 'expm1', - 'expand', - 'layer_norm', - 'reshape', - 'cast', - "scale", - 'softmax', - 'silu', - 'elementwise_pow', - 'embedding', - 'fused_softmax_mask_upper_triangle', - 'slice', - 'transpose', - 'slice_grad', - 'gather_nd', - 'stack', - 'poisson', - 'gumbel_softmax', - 'pad', - 'pad3d', - 'squeeze', - 'unsqueeze', - 'tril', - 'triu', - 'squeeze', - 'unsqueeze', - 'conv2d', - 'depthwise_conv2d', - 'sqrt', - 'flatten', - 'relu', - 'abs', - 'log', - 'clip', - 'ceil', - 'p_norm', - 'maximum', - 'argsort', - 'min', - 'max', - 'batch_norm', - 'max_pool2d_with_index', - 'pool2d', - 'minimum', - 'prod', - 'round', - 'sin', - 'cos', - 'dot', - 'floor', - 'topk', - 'square', - 'gather', - 'label_smooth', - 'cross_entropy_with_softmax', - 'mean_all', - 'cumsum', - 'linear_interp', - 'bilinear_interp', - 'trilinear_interp', - 'nearest_interp', - 'bicubic_interp', - 'assign', - 'assign_out_', - 'real', - 'flip', - 'softmax', - 'expand', - 'conv2d_transpose', - 'depthwise_conv2d_transpose', - 'sigmoid', - 'pad', - 'pad3d', - 'einsum', - 'leaky_relu', - 'log10', - 'conv3d', - 'solve', - 'diag', - 'trace', - 'tile', -] -vjp_interface_implementation_gen_op_list = [ - 'where', - "tanh", - "mean", - "divide", - "sum", - "add", - "concat", - "split", - "split_with_num", - "gelu", - "matmul", - "erf", - "multiply", - "subtract", - "pow", - "rsqrt", - "square", - "dropout", - 'exp', - 'expm1', - 'expand', - 'layer_norm', - 'reshape', - 'cast', - "scale", - 'softmax', - 'silu', - 'elementwise_pow', - 'embedding', - 'fused_softmax_mask_upper_triangle', - 'slice', - 'transpose', - 'slice_grad', - 'gather_nd', - 'stack', - 'poisson', - 'gumbel_softmax', - 'pad', - 'pad3d', - 'squeeze', - 'unsqueeze', - 'tril', - 'triu', - 'squeeze', - 'unsqueeze', - 'conv2d', - 'depthwise_conv2d', - 'sqrt', - 'flatten', - 'relu', - 'abs', - 'log', - 'clip', - 'ceil', - 'p_norm', - 'maximum', - 'argsort', - 'min', - 'max', - 'batch_norm', - 'max_pool2d_with_index', - 'pool2d', - 'minimum', - 'prod', - 'round', - 'sin', - 'cos', - 'dot', - 'floor', - 'topk', - 'square', - 'gather', - 'label_smooth', - 'cross_entropy_with_softmax', - 'mean_all', - 'cumsum', - 'linear_interp', - 'bilinear_interp', - 'trilinear_interp', - 'nearest_interp', - 'bicubic_interp', - 'assign', - 'assign_out_', - 'real', - 'flip', - 'softmax', - 'expand', - 'conv2d_transpose', - 'depthwise_conv2d_transpose', - 'sigmoid', - 'pad', - 'pad3d', - 'einsum', - 'leaky_relu', - 'log10', - 'conv3d', - 'solve', - 'diag', - 'trace', - 'tile', -] diff --git a/paddle/fluid/pir/dialect/operator/ir/ops.yaml b/paddle/fluid/pir/dialect/operator/ir/ops.yaml index 5a2da284142ad3..899863d58aba12 100644 --- a/paddle/fluid/pir/dialect/operator/ir/ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/ops.yaml @@ -7,7 +7,6 @@ kernel: func: add_n param: [inputs] - backward: add_n_grad - op : add_n_with_kernel args : (Tensor[] inputs) @@ -18,7 +17,6 @@ kernel: func: add_n param: [inputs] - backward: add_n_grad - op : assert args : (Tensor cond, Tensor[] data, int64_t summarize = -1) @@ -175,7 +173,6 @@ - op : write_to_array args : (Tensor i, Tensor x) output : Tensor[](out) - backward: write_to_array_grad - op: dpsgd args: (Tensor param, Tensor grad, Tensor learning_rate, float clip = 10.0f, float batch_size = 16.0f, float sigma 
= 1.0f, int seed = 0) diff --git a/paddle/phi/api/yaml/legacy_backward.yaml b/paddle/phi/api/yaml/legacy_backward.yaml index 7553f5666ee07e..af93e8ac635e9c 100755 --- a/paddle/phi/api/yaml/legacy_backward.yaml +++ b/paddle/phi/api/yaml/legacy_backward.yaml @@ -408,8 +408,8 @@ composite : minimum_grad(x, y, out_grad, axis, x_grad, y_grad) - backward_op : mish_grad - forward : mish (Tensor x, float threshold) -> Tensor(out) - args : (Tensor x, Tensor out_grad, float threshold) + forward : mish (Tensor x, float lambda) -> Tensor(out) + args : (Tensor x, Tensor out_grad, float lambda) output : Tensor(x_grad) infer_meta : func : UnchangedInferMeta
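
The net effect of this patch on the generator's gating logic can be summarized by the small standalone Python sketch below. It is an illustration only: the OpInfo dataclass and the needs_vjp_interface helper are hypothetical stand-ins, not code from op_gen.py; the membership test is the part that mirrors the actual diff (previously `op_phi_name[0] in vjp_interface_declare_gen_op_list`, now `op_phi_name[0] not in vjp_interface_black_list`).

    # Standalone sketch of the gating change in this patch. OpInfo and
    # needs_vjp_interface are illustrative stand-ins, not code from op_gen.py;
    # only the membership test mirrors the actual diff.
    from dataclasses import dataclass

    # Ops whose Vjp code generation is still unsupported
    # (same entries as vjp_interface_black_list introduced by this patch).
    VJP_INTERFACE_BLACK_LIST = [
        'frobenius_norm',
        'write_to_array',
        'fused_attention',
        'fused_feedforward',
        'set_value',
        'set_value_with_tensor',
        'silu_grad',
        'fused_dropout_add',
        'fused_rotary_position_embedding',
    ]

    @dataclass
    class OpInfo:
        op_phi_name: list        # phi op names; element 0 is the canonical name
        backward_name: str = ''  # empty string means the op has no backward op

    def needs_vjp_interface(op_info):
        # Old behaviour: generate VjpInterface only when the op name appeared
        # in vjp_interface_declare_gen_op_list (a whitelist).
        # New behaviour: generate it for every op that has a backward, unless
        # the op is explicitly black-listed.
        return bool(op_info.backward_name) and (
            op_info.op_phi_name[0] not in VJP_INTERFACE_BLACK_LIST
        )

    if __name__ == '__main__':
        assert needs_vjp_interface(OpInfo(['tanh'], 'tanh_grad'))
        assert not needs_vjp_interface(OpInfo(['fused_attention'], 'fused_attention_grad'))
        assert not needs_vjp_interface(OpInfo(['assert']))  # no backward -> no Vjp

With the list inverted, newly added ops that declare a backward get Vjp code generated by default, and the black list only has to track the shrinking set of ops (fused kernels, array and set_value ops, etc.) whose vjp codegen is not yet supported.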