From 1fcb9f058b5ed97257ef0ef68b34d9a7f94571d8 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 2 May 2024 12:18:08 +0000 Subject: [PATCH 1/2] floor-div-dev --- .../dialect/op_generator/decomp_interface_gen_op_list.py | 2 ++ paddle/fluid/primitive/composite/composite.h | 8 ++++++++ .../legacy_test/test_elementwise_floordiv_op.py | 2 ++ 3 files changed, 12 insertions(+) diff --git a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py index 5ae7c3152e0fb..2348eec77b712 100644 --- a/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py @@ -29,6 +29,7 @@ "elu", "embedding", "flatten", + "floor_divide", "full_like", "gelu", "hardswish", @@ -67,6 +68,7 @@ "elu", "embedding", "flatten", + "floor_divide", "full_like", "gelu", "hardswish", diff --git a/paddle/fluid/primitive/composite/composite.h b/paddle/fluid/primitive/composite/composite.h index 101354241e03b..7151127804712 100644 --- a/paddle/fluid/primitive/composite/composite.h +++ b/paddle/fluid/primitive/composite/composite.h @@ -729,6 +729,14 @@ Tensor full_like_decomp(const Tensor& x, } } +template +Tensor floor_divide_decomp(const Tensor& x, const Tensor& y) { + auto x_cast = cast(x, DataType::INT64); + auto y_cast = cast(y, DataType::INT64); + auto res = x_cast / y_cast; + return cast(res, x.dtype()); +} + template std::tuple dropout_decomp( const Tensor& x, diff --git a/test/deprecated/legacy_test/test_elementwise_floordiv_op.py b/test/deprecated/legacy_test/test_elementwise_floordiv_op.py index ccfab0b9adf56..b7e95504a8853 100644 --- a/test/deprecated/legacy_test/test_elementwise_floordiv_op.py +++ b/test/deprecated/legacy_test/test_elementwise_floordiv_op.py @@ -29,7 +29,9 @@ def init_kernel_type(self): def setUp(self): self.op_type = "elementwise_floordiv" + self.prim_op_type = "comp" self.python_api = paddle.floor_divide + self.public_python_api = paddle.floor_divide self.dtype = np.int32 self.axis = -1 self.init_dtype() From 9844195f7edf10d3fac3d1aa6072fd6621778442 Mon Sep 17 00:00:00 2001 From: Eddie-Wang1120 Date: Thu, 2 May 2024 12:20:37 +0000 Subject: [PATCH 2/2] update test --- test/deprecated/legacy_test/test_elementwise_floordiv_op.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/deprecated/legacy_test/test_elementwise_floordiv_op.py b/test/deprecated/legacy_test/test_elementwise_floordiv_op.py index b7e95504a8853..e49f5687b1c9e 100644 --- a/test/deprecated/legacy_test/test_elementwise_floordiv_op.py +++ b/test/deprecated/legacy_test/test_elementwise_floordiv_op.py @@ -47,7 +47,7 @@ def setUp(self): self.outputs = {'Out': self.out} def test_check_output(self): - self.check_output(check_pir=True) + self.check_output(check_pir=True, check_prim_pir=True) def init_input_output(self): self.x = np.random.uniform(0, 10000, [10, 10]).astype(self.dtype)