diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
index af770d94ef1f0a..92baa51f261e3e 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.cc
@@ -16,6 +16,7 @@ paddle::dialect::IfOp, paddle::dialect::WhileOp
 #else
 #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h"
+#include "paddle/fluid/pir/dialect/operator/ir/api_builder.h"
 #include "paddle/phi/core/enforce.h"
 #include "paddle/pir/core/builder.h"
@@ -182,6 +183,45 @@ void IfOp::VerifyRegion() {
   }
 }
 
+std::vector<std::vector<pir::OpResult>> IfOp::Vjp(
+    pir::Operation *op,
+    const std::vector<std::vector<pir::Value>> &inputs_,
+    const std::vector<std::vector<pir::OpResult>> &outputs,
+    const std::vector<std::vector<pir::Value>> &out_grads,
+    const std::vector<std::vector<bool>> &stop_gradients) {
+  PADDLE_ENFORCE_EQ(
+      inputs_.size() == 1u && inputs_[0].size() >= 1u,
+      1u,
+      phi::errors::InvalidArgument(
+          "if op's inputs size should be 1, but now is %d.", inputs_.size()));
+
+  VLOG(6) << "Prepare inputs for if_grad";
+  auto cond_val = inputs_[0][0];
+  VLOG(6) << "Prepare attributes for if_grad";
+
+  VLOG(6) << "Prepare outputs for if_grad";
+
+  std::vector<pir::Type> output_types;
+  for (size_t i = 0; i < inputs_[0].size(); ++i) {
+    if (!stop_gradients[0][i]) {
+      output_types.push_back(inputs_[0][i].type());
+    }
+  }
+
+  auto if_grad = ApiBuilder::Instance().GetBuilder()->Build<IfOp>(
+      cond_val, std::move(output_types));
+
+  std::vector<std::vector<pir::OpResult>> res{
+      std::vector<pir::OpResult>(inputs_[0].size())};
+
+  for (size_t i = 0, j = 0; i < inputs_[0].size(); ++i) {
+    if (!stop_gradients[0][i]) {
+      res[0][i] = if_grad->result(j++);
+    }
+  }
+  return res;
+}
+
 void WhileOp::Build(pir::Builder &builder,             // NOLINT
                     pir::OperationArgument &argument,  // NOLINT
                     pir::Value cond,
diff --git a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
index 848cef6410a3aa..9e80987bab0d73 100644
--- a/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
+++ b/paddle/fluid/pir/dialect/operator/ir/control_flow_op.h
@@ -15,12 +15,13 @@
 #pragma once
 #include <vector>
 
+#include "paddle/fluid/pir/dialect/operator/interface/vjp.h"
 #include "paddle/pir/core/op_base.h"
 
 namespace paddle {
 namespace dialect {
 
-class IfOp : public pir::Op<IfOp> {
+class IfOp : public pir::Op<IfOp, VjpInterface> {
  public:
   using Op::Op;
   static const char *name() { return "pd_op.if"; }
@@ -45,6 +46,13 @@ class IfOp : public pir::Op<IfOp> {
   void Print(pir::IrPrinter &printer);  // NOLINT
   void VerifySig();
   void VerifyRegion();
+
+  static std::vector<std::vector<pir::OpResult>> Vjp(
+      pir::Operation *op,
+      const std::vector<std::vector<pir::Value>> &inputs_,
+      const std::vector<std::vector<pir::OpResult>> &outputs,
+      const std::vector<std::vector<pir::Value>> &out_grads,
+      const std::vector<std::vector<bool>> &stop_gradients);
 };
 
 ///
diff --git a/paddle/fluid/pybind/control_flow_api.cc b/paddle/fluid/pybind/control_flow_api.cc
index 113a9cf2fc68d0..1a7cfab5a6bf8a 100644
--- a/paddle/fluid/pybind/control_flow_api.cc
+++ b/paddle/fluid/pybind/control_flow_api.cc
@@ -73,6 +73,7 @@ void BindIfOp(py::module* m) {
   if_op.def("true_block", &PyIfOp::true_block, return_value_policy::reference)
       .def("false_block", &PyIfOp::false_block, return_value_policy::reference)
       .def("update_output", &PyIfOp::UpdateOutput)
+      .def("as_operation", &PyIfOp::operation, return_value_policy::reference)
       .def("results", [](PyIfOp& self) -> py::list {
         py::list op_list;
         for (uint32_t i = 0; i < self->num_results(); i++) {
diff --git a/test/ir/pir/test_if_api.py b/test/ir/pir/test_if_api.py
index 8583e5385c2ce0..eff81645e9301a 100644
--- a/test/ir/pir/test_if_api.py
+++ b/test/ir/pir/test_if_api.py
@@ -15,6 +15,7 @@ import unittest
 
 import paddle
+from paddle.base.core import call_vjp, has_vjp
 from paddle.base.libpaddle.pir import (
     build_pipe_for_block,
     get_used_external_value,
 )
@@ -36,34 +37,44 @@ def false_func():
 
 
 class TestBuildModuleWithIfOp(unittest.TestCase):
-    def test_if_with_single_output(self):
+    def construct_program_with_if(self):
         main_program = paddle.static.Program()
         startup_program = paddle.static.Program()
         with paddle.static.program_guard(main_program, startup_program):
             x = paddle.static.data(name="x", shape=[6, 1], dtype="float32")
             y = paddle.static.data(name="y", shape=[6, 1], dtype="float32")
-            out = paddle.static.nn.cond(x < y, lambda: x + y, lambda: x - y)
-            if_op = out[0].get_defining_op()
+            paddle.static.nn.cond(x < y, lambda: x + y, lambda: x - y)
+        return main_program
+
+    def test_if_with_single_output(self):
+        main_program = self.construct_program_with_if()
+        if_op = main_program.global_block().ops[-1]
         self.assertEqual(if_op.name(), "pd_op.if")
-        self.assertEqual(len(out), 1)
+        self.assertEqual(len(if_op.results()), 1)
         value_list = get_used_external_value(if_op)
-        print(value_list)
+        self.assertEqual(len(value_list), 3)
+        self.assertEqual(value_list[0], if_op.operand_source(0))
 
     def test_if_with_multiple_output(self):
-        main_program = paddle.static.Program()
-        startup_program = paddle.static.Program()
-        with paddle.static.program_guard(main_program, startup_program):
-            x = paddle.static.data(name="x", shape=[6, 1], dtype="float32")
-            y = paddle.static.data(name="y", shape=[6, 1], dtype="float32")
-            pred = paddle.less_than(x=x, y=y, name=None)
-            out = paddle.static.nn.cond(pred, true_func, false_func)
-        self.assertEqual(out[0].get_defining_op().name(), "pd_op.if")
+        main_program = self.construct_program_with_if()
+        cond_value = main_program.global_block().ops[-1].operand_source(0)
+        with paddle.pir.core.program_guard(main_program):
+            paddle.static.nn.cond(cond_value, true_func, false_func)
+        last_op = main_program.global_block().ops[-1]
+        out = last_op.results()
+        self.assertEqual(last_op.name(), "pd_op.if")
         self.assertEqual(len(out), 2)
-        if_op = out[0].get_defining_op().as_if_op()
+
+        # check Operation::as_if_op interface
+        if_op = last_op.as_if_op()
         true_block = if_op.true_block()
         self.assertEqual(len(true_block), 3)
+
+        # check build_pipe_for_block interface
         build_pipe_for_block(true_block)
         self.assertEqual(len(true_block), 4)
+
+        # check Operation::blocks interface
         block_list = []
         for block in out[0].get_defining_op().blocks():
             block_list.append(block)
@@ -71,6 +82,26 @@ def test_if_with_multiple_output(self):
         self.assertEqual(block_list[0], true_block)
         self.assertEqual(block_list[1], if_op.false_block())
 
+    def test_if_op_vjp_interface(self):
+        main_program = self.construct_program_with_if()
+        if_op = main_program.global_block().ops[-1]
+        self.assertEqual(if_op.name(), "pd_op.if")
+        with paddle.pir.core.program_guard(main_program):
+            out_grad = paddle.full(shape=[6, 1], dtype='float32', fill_value=3)
+            if_input = [get_used_external_value(if_op)]
+            if_input_stop_graditents = [[True, False, False]]
+            if_output = [if_op.results()]
+            if_output_grad = [[out_grad]]
+            self.assertEqual(has_vjp(if_op), True)
+            grad_outs = call_vjp(
+                if_op,
+                if_input,
+                if_output,
+                if_output_grad,
+                if_input_stop_graditents,
+            )
+            self.assertEqual(grad_outs[0][0], None)
+
 
 if __name__ == "__main__":
     unittest.main()