From 4c01e34aea031cb79fc3602a42e0eb9e292725bb Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Wed, 3 Aug 2022 16:18:26 +0800 Subject: [PATCH 1/6] add paddle vsplit api --- python/paddle/__init__.py | 1 + .../fluid/tests/unittests/test_splits_api.py | 120 ++++++++++++++++++ python/paddle/tensor/__init__.py | 1 + python/paddle/tensor/manipulation.py | 42 ++++++ 4 files changed, 164 insertions(+) create mode 100644 python/paddle/fluid/tests/unittests/test_splits_api.py diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index b39f4161eee978..56267ecd594be1 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -164,6 +164,7 @@ from .tensor.manipulation import slice # noqa: F401 from .tensor.manipulation import crop # noqa: F401 from .tensor.manipulation import split # noqa: F401 +from .tensor.manipulation import vsplit # noqa: F401 from .tensor.manipulation import squeeze # noqa: F401 from .tensor.manipulation import squeeze_ # noqa: F401 from .tensor.manipulation import stack # noqa: F401 diff --git a/python/paddle/fluid/tests/unittests/test_splits_api.py b/python/paddle/fluid/tests/unittests/test_splits_api.py new file mode 100644 index 00000000000000..b87fd16e7b0214 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_splits_api.py @@ -0,0 +1,120 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core + + +def func_ref(func, x, num_or_sections): + # Do not support -1 + if isinstance(num_or_sections, int): + indices_or_sections = num_or_sections + else: + indices_or_sections = np.cumsum(num_or_sections)[:-1] + return func(x, indices_or_sections) + + +test_list = [ + (paddle.vsplit, np.vsplit), +] + + +class TestSplitsAPI(unittest.TestCase): + + def setUp(self): + self.shape = [4, 5, 2] + self.num_or_sections = 2 + self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_api(self): + paddle.enable_static() + for func, func_type in test_list: + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', self.x_np.shape, self.x_np.dtype) + out = func(x, self.num_or_sections) + exe = paddle.static.Executor(self.place) + res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) + out_ref = func_ref(func_type, self.x_np, self.num_or_sections) + for n, p in zip(out_ref, res): + self.assertTrue(np.allclose(n, p)) + + def test_dygraph_api(self): + paddle.disable_static(self.place) + x = paddle.to_tensor(self.x_np) + for func, func_type in test_list: + out = func(x, self.num_or_sections) + out_ref = func_ref(func_type, self.x_np, self.num_or_sections) + for n, p in zip(out_ref, out): + self.assertTrue(np.allclose(n, p.numpy())) + paddle.enable_static() + + +class TestSplitsSections(TestSplitsAPI): + + def setUp(self): + self.shape = [6, 2, 4] + self.num_or_sections = [2, 1, 3] + self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + +class TestSplitsFloat32(TestSplitsAPI): + + def setUp(self): + self.shape = [2, 3, 4] + self.num_or_sections = 2 + self.x_np = np.random.uniform(-1, 1, 
self.shape).astype('float32') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + +class TestSplitsCPU(TestSplitsAPI): + + def setUp(self): + self.shape = [8, 2, 3, 5] + self.num_or_sections = (2, 3, 3) + self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') + self.place = paddle.CPUPlace() + + +class TestSplitsError(unittest.TestCase): + + def setUp(self): + self.num_or_sections = 1 + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + def test_static_error(self): + paddle.enable_static() + for func, _ in test_list: + with paddle.static.program_guard(paddle.static.Program()): + x = paddle.fluid.data('X', [5], 'float32') + self.assertRaises(ValueError, func, x, self.num_or_sections) + + def test_dygraph_error(self): + paddle.disable_static(self.place) + for func, _ in test_list: + x_np = np.random.randn(2) + x = paddle.to_tensor(x_np, dtype='float64') + self.assertRaises(ValueError, func, x, self.num_or_sections) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 58dfa26cfe377e..77ec8ef9451b70 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -106,6 +106,7 @@ from .manipulation import shard_index # noqa: F401 from .manipulation import slice # noqa: F401 from .manipulation import split # noqa: F401 +from .manipulation import vsplit # noqa: F401 from .manipulation import squeeze # noqa: F401 from .manipulation import squeeze_ # noqa: F401 from .manipulation import stack # noqa: F401 diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 42e3bc9039f08a..f49099e8c2e96e 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1926,6 +1926,48 @@ def _get_SectionsTensorList(one_list): return outs +def vsplit(x, num_or_sections, name=None): + """ + Split the input tensor 
into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``. + + Args: + x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, float16, float32, float64, int32 or int64. + num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` + indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. + If ``num_or_sections`` is a list or tuple, the length of it indicates the number of + sub-Tensors and the elements in it indicate the sizes of sub-Tensors' dimension orderly. + The length of the list must not be larger than the ``x`` 's size of axis 0. + name (str, optional): The default value is None. Normally there is no need for user to set this property. + For more information, please refer to :ref:`api_guide_Name` . + Returns: + list(Tensor): The list of segmented Tensors. + + Example: + .. code-block:: python + + import paddle + + # x is a Tensor of shape [8, 6, 7] + x = paddle.rand([8, 6, 7]) + out0, out1 = paddle.vsplit(x, num_or_sections=2) + print(out0.shape) # [4, 6, 7] + print(out1.shape) # [4, 6, 7] + out0, out1, out2 = paddle.vsplit(x, num_or_sections=[1, 3, 4]) + print(out0.shape) # [1, 6, 7] + print(out1.shape) # [3, 6, 7] + print(out2.shape) # [4, 6, 7] + out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, -1]) + print(out0.shape) # [2, 6, 7] + print(out1.shape) # [3, 6, 7] + print(out2.shape) # [3, 6, 7] + """ + if x.ndim < 2: + raise ValueError( + "The input tensor's dimension must be greater than 1, but got {}". + format(x.ndim)) + return split(x, num_or_sections, axis=0, name=name) + + def squeeze(x, axis=None, name=None): """ Squeeze the dimension(s) of size 1 of input tensor x's shape. 
From f5e762b6ae9b5bfc3036b144a9e5e539de5edd4d Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 15 Aug 2022 15:01:14 +0800 Subject: [PATCH 2/6] update unittest and fix a typo --- .../fluid/tests/unittests/test_splits_api.py | 40 +++++++++++++++++++ python/paddle/tensor/manipulation.py | 2 +- 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_splits_api.py b/python/paddle/fluid/tests/unittests/test_splits_api.py index b87fd16e7b0214..4d98ac14bf0ad0 100644 --- a/python/paddle/fluid/tests/unittests/test_splits_api.py +++ b/python/paddle/fluid/tests/unittests/test_splits_api.py @@ -20,6 +20,7 @@ def func_ref(func, x, num_or_sections): + # Convert the num_or_sections in paddle to indices_or_sections in numpy # Do not support -1 if isinstance(num_or_sections, int): indices_or_sections = num_or_sections @@ -28,6 +29,7 @@ def func_ref(func, x, num_or_sections): return func(x, indices_or_sections) +# TODO: add other split API, such as dsplit、hsplit test_list = [ (paddle.vsplit, np.vsplit), ] @@ -66,6 +68,9 @@ def test_dygraph_api(self): class TestSplitsSections(TestSplitsAPI): + """ + Test num_or_sections which is a list and date type is float64. + """ def setUp(self): self.shape = [6, 2, 4] @@ -76,6 +81,9 @@ def setUp(self): class TestSplitsFloat32(TestSplitsAPI): + """ + Test num_or_sections which is an integer and data type is float32. + """ def setUp(self): self.shape = [2, 3, 4] @@ -85,7 +93,36 @@ def setUp(self): else paddle.CPUPlace() +class TestSplitsInt32(TestSplitsAPI): + """ + Test data type int32. + """ + + def setUp(self): + self.shape = [5, 1, 2] + self.num_or_sections = 1 + self.x_np = np.random.uniform(-1, 1, self.shape).astype('int32') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + +class TestSplitsInt64(TestSplitsAPI): + """ + Test data type int64. 
+ """ + + def setUp(self): + self.shape = [4, 3, 2] + self.num_or_sections = 2 + self.x_np = np.random.uniform(-1, 1, self.shape).astype('int64') + self.place = paddle.CUDAPlace(0) if core.is_compiled_with_cuda() \ + else paddle.CPUPlace() + + class TestSplitsCPU(TestSplitsAPI): + """ + Test cpu place and num_or_sections which is a tuple. + """ def setUp(self): self.shape = [8, 2, 3, 5] @@ -95,6 +132,9 @@ def setUp(self): class TestSplitsError(unittest.TestCase): + """ + Test the situation that input shape less than 2. + """ def setUp(self): self.num_or_sections = 1 diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index f49099e8c2e96e..fa0172cf3d2e18 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1956,7 +1956,7 @@ def vsplit(x, num_or_sections, name=None): print(out0.shape) # [1, 6, 7] print(out1.shape) # [3, 6, 7] print(out2.shape) # [4, 6, 7] - out0, out1, out2 = paddle.split(x, num_or_sections=[2, 3, -1]) + out0, out1, out2 = paddle.vplit(x, num_or_sections=[2, 3, -1]) print(out0.shape) # [2, 6, 7] print(out1.shape) # [3, 6, 7] print(out2.shape) # [3, 6, 7] From cab2c2888983f83210f127d7211202f489c847f7 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 15 Aug 2022 15:56:34 +0800 Subject: [PATCH 3/6] update --- python/paddle/fluid/tests/unittests/test_splits_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/fluid/tests/unittests/test_splits_api.py b/python/paddle/fluid/tests/unittests/test_splits_api.py index 4d98ac14bf0ad0..6b45770f223196 100644 --- a/python/paddle/fluid/tests/unittests/test_splits_api.py +++ b/python/paddle/fluid/tests/unittests/test_splits_api.py @@ -100,7 +100,7 @@ class TestSplitsInt32(TestSplitsAPI): def setUp(self): self.shape = [5, 1, 2] - self.num_or_sections = 1 + self.num_or_sections = 5 self.x_np = np.random.uniform(-1, 1, self.shape).astype('int32') self.place = paddle.CUDAPlace(0) 
if core.is_compiled_with_cuda() \ else paddle.CPUPlace() From 9727511870f330b832a88e4a0dee427dc22c4cbd Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Fri, 19 Aug 2022 13:09:33 +0800 Subject: [PATCH 4/6] add vsplit to __all__ --- python/paddle/__init__.py | 1 + python/paddle/tensor/__init__.py | 1 + 2 files changed, 2 insertions(+) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 56267ecd594be1..708ef4514fb83c 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -456,6 +456,7 @@ 'searchsorted', 'bucketize', 'split', + 'vsplit' 'logical_and', 'full_like', 'less_than', diff --git a/python/paddle/tensor/__init__.py b/python/paddle/tensor/__init__.py index 77ec8ef9451b70..ba7dd5d0cec529 100755 --- a/python/paddle/tensor/__init__.py +++ b/python/paddle/tensor/__init__.py @@ -429,6 +429,7 @@ 'shard_index', 'slice', 'split', + 'vsplit', 'chunk', 'tensordot', 'squeeze', From 091a913b68ab4b98f7b2db9856cbdf7af13774f9 Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Mon, 22 Aug 2022 11:36:58 +0800 Subject: [PATCH 5/6] update unit test and description of x --- python/paddle/__init__.py | 2 +- .../fluid/tests/unittests/test_splits_api.py | 22 +++++++++++++------ python/paddle/tensor/manipulation.py | 4 ++-- 3 files changed, 18 insertions(+), 10 deletions(-) diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 708ef4514fb83c..8f75eb97057491 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -456,7 +456,7 @@ 'searchsorted', 'bucketize', 'split', - 'vsplit' + 'vsplit', 'logical_and', 'full_like', 'less_than', diff --git a/python/paddle/fluid/tests/unittests/test_splits_api.py b/python/paddle/fluid/tests/unittests/test_splits_api.py index 6b45770f223196..4b6254e266bc1f 100644 --- a/python/paddle/fluid/tests/unittests/test_splits_api.py +++ b/python/paddle/fluid/tests/unittests/test_splits_api.py @@ -38,6 +38,11 @@ def func_ref(func, x, 
num_or_sections): class TestSplitsAPI(unittest.TestCase): def setUp(self): + self.rtol = 1e-5 + self.atol = 1e-8 + self.set_input() + + def set_input(self): self.shape = [4, 5, 2] self.num_or_sections = 2 self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') @@ -54,7 +59,7 @@ def test_static_api(self): res = exe.run(feed={'X': self.x_np}, fetch_list=[out]) out_ref = func_ref(func_type, self.x_np, self.num_or_sections) for n, p in zip(out_ref, res): - self.assertTrue(np.allclose(n, p)) + np.testing.assert_allclose(n, p, rtol=self.rtol, atol=self.atol) def test_dygraph_api(self): paddle.disable_static(self.place) @@ -63,7 +68,10 @@ def test_dygraph_api(self): out = func(x, self.num_or_sections) out_ref = func_ref(func_type, self.x_np, self.num_or_sections) for n, p in zip(out_ref, out): - self.assertTrue(np.allclose(n, p.numpy())) + np.testing.assert_allclose(n, + p.numpy(), + rtol=self.rtol, + atol=self.atol) paddle.enable_static() @@ -72,7 +80,7 @@ class TestSplitsSections(TestSplitsAPI): Test num_or_sections which is a list and date type is float64. """ - def setUp(self): + def set_input(self): self.shape = [6, 2, 4] self.num_or_sections = [2, 1, 3] self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') @@ -85,7 +93,7 @@ class TestSplitsFloat32(TestSplitsAPI): Test num_or_sections which is an integer and data type is float32. """ - def setUp(self): + def set_input(self): self.shape = [2, 3, 4] self.num_or_sections = 2 self.x_np = np.random.uniform(-1, 1, self.shape).astype('float32') @@ -98,7 +106,7 @@ class TestSplitsInt32(TestSplitsAPI): Test data type int32. """ - def setUp(self): + def set_input(self): self.shape = [5, 1, 2] self.num_or_sections = 5 self.x_np = np.random.uniform(-1, 1, self.shape).astype('int32') @@ -111,7 +119,7 @@ class TestSplitsInt64(TestSplitsAPI): Test data type int64. 
""" - def setUp(self): + def set_input(self): self.shape = [4, 3, 2] self.num_or_sections = 2 self.x_np = np.random.uniform(-1, 1, self.shape).astype('int64') @@ -124,7 +132,7 @@ class TestSplitsCPU(TestSplitsAPI): Test cpu place and num_or_sections which is a tuple. """ - def setUp(self): + def set_input(self): self.shape = [8, 2, 3, 5] self.num_or_sections = (2, 3, 3) self.x_np = np.random.uniform(-1, 1, self.shape).astype('float64') diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index fa0172cf3d2e18..57a83408fc6ccb 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1931,7 +1931,7 @@ def vsplit(x, num_or_sections, name=None): Split the input tensor into multiple sub-Tensors along the vertical axis, which is equivalent to ``paddle.split`` with ``axis=0``. Args: - x (Tensor): A Tensor whose dimension must be greater than 2. The data type is bool, float16, float32, float64, int32 or int64. + x (Tensor): A Tensor whose dimension must be greater than 1. The data type is bool, float16, float32, float64, uint8, int8, int32 or int64. num_or_sections (int|list|tuple): If ``num_or_sections`` is an int, then ``num_or_sections`` indicates the number of equal sized sub-Tensors that the ``x`` will be divided into. If ``num_or_sections`` is a list or tuple, the length of it indicates the number of @@ -1940,7 +1940,7 @@ def vsplit(x, num_or_sections, name=None): name (str, optional): The default value is None. Normally there is no need for user to set this property. For more information, please refer to :ref:`api_guide_Name` . Returns: - list(Tensor): The list of segmented Tensors. + list[Tensor], The list of segmented Tensors. Example: .. 
code-block:: python From 02ced4bf84c7fce473aabd989d46bb162ac69b0e Mon Sep 17 00:00:00 2001 From: Asthestarsfalll <1186454801@qq.com> Date: Tue, 6 Sep 2022 11:38:01 +0800 Subject: [PATCH 6/6] fix typo --- python/paddle/tensor/manipulation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 57a83408fc6ccb..9f4e9803e4cd1a 100755 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -1956,7 +1956,7 @@ def vsplit(x, num_or_sections, name=None): print(out0.shape) # [1, 6, 7] print(out1.shape) # [3, 6, 7] print(out2.shape) # [4, 6, 7] - out0, out1, out2 = paddle.vplit(x, num_or_sections=[2, 3, -1]) + out0, out1, out2 = paddle.vsplit(x, num_or_sections=[2, 3, -1]) print(out0.shape) # [2, 6, 7] print(out1.shape) # [3, 6, 7] print(out2.shape) # [3, 6, 7]