From b5789b8dbabfe92087cdae0ed34027d5ed4d07b3 Mon Sep 17 00:00:00 2001
From: wanghuancoder
Date: Mon, 30 Oct 2023 14:46:04 +0800
Subject: [PATCH] [AutoParallel] Addn support AutoParallel (#58434)

* phi add_n support disttensor
---
 paddle/phi/api/lib/api_custom_impl.cc   | 76 +++++++++++++++-
 paddle/phi/api/lib/api_gen_utils.cc     | 10 +++
 paddle/phi/api/lib/api_gen_utils.h      |  3 +
 paddle/phi/api/lib/data_transform.cc    | 58 +++++++++++++
 paddle/phi/api/lib/data_transform.h     | 13 +++
 paddle/phi/infermeta/spmd_rules/utils.h |  6 ++
 .../semi_auto_parallel_for_add_n.py     | 87 +++++++++++++++++++
 .../test_semi_auto_parallel_basic.py    | 10 +++
 8 files changed, 261 insertions(+), 2 deletions(-)
 create mode 100644 test/auto_parallel/semi_auto_parallel_for_add_n.py

diff --git a/paddle/phi/api/lib/api_custom_impl.cc b/paddle/phi/api/lib/api_custom_impl.cc
index 840f761482684a..acb45d058038e0 100644
--- a/paddle/phi/api/lib/api_custom_impl.cc
+++ b/paddle/phi/api/lib/api_custom_impl.cc
@@ -28,7 +28,10 @@ limitations under the License. */
 #include "paddle/phi/infermeta/multiary.h"
 #include "paddle/phi/infermeta/nullary.h"
 #include "paddle/phi/infermeta/unary.h"
-
+#ifdef PADDLE_WITH_DISTRIBUTE
+#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
+#include "paddle/phi/infermeta/spmd_rules/rules.h"
+#endif
 namespace paddle {
 namespace experimental {
 
@@ -57,7 +60,8 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
 
   bool is_sr_kernel = true;
   for (auto& input : x) {
-    if (phi::DenseTensor::classof(input.impl().get())) {
+    if (phi::DenseTensor::classof(input.impl().get()) ||
+        phi::distributed::DistTensor::classof(input.impl().get())) {
       is_sr_kernel = false;
       break;
     }
@@ -98,6 +102,74 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {
 
     (*kernel_fn)(*dev_ctx, input_x, kernel_out);
   } else {
+#ifdef PADDLE_WITH_DISTRIBUTE
+    bool run_auto_parallel = AllInputsAreDistTensor(x);
+    bool rank_is_in_current_mesh = true;
+    if (run_auto_parallel) {
+      auto mesh =
+          std::static_pointer_cast<phi::distributed::DistTensor>(x[0].impl())
+              ->dist_attr()
+              .process_mesh();
+      rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh);
+
+      std::vector<const phi::TensorBase*> input_x(x.size());
+      for (size_t i = 0; i < input_x.size(); ++i) {
+        input_x[i] = x[i].impl().get();
+      }
+
+      auto meta_dist_input_x = MakeDistMetaTensor(input_x);
+      auto spmd_info =
+          phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x);
+
+      auto dist_out = SetKernelDistOutput(&api_output);
+      auto dense_out = dist_out->unsafe_mutable_value();
+      if (!rank_is_in_current_mesh) {
+        *dense_out = phi::DenseTensor(
+            std::make_shared<phi::Allocation>(
+                nullptr, 0, phi::distributed::GetDefaultPlace()),
+            phi::DenseTensorMeta());
+      }
+
+      phi::MetaTensor meta_dist_out(dist_out);
+      auto x_meta_vec = MakeMetaTensor(input_x);
+      std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
+      for (size_t i = 0; i < x_meta_vec.size(); ++i) {
+        x_metas[i] = &x_meta_vec[i];
+      }
+      phi::AddNInferMeta(x_metas, &meta_dist_out);
+      if (rank_is_in_current_mesh) {
+        auto dist_input_x =
+            ReshardApiInputToReplicatedKernelInput(dev_ctx, x, spmd_info.first);
+        dist_input_x = PrepareDataForDistTensor(
+            dist_input_x,
+            GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
+            {},
+            kernel_result.is_stride_kernel);
+        std::vector<const phi::DenseTensor*> input_x(dist_input_x.size());
+        for (size_t i = 0; i < dist_input_x.size(); ++i) {
+          input_x[i] = dist_input_x[i]->unsafe_mutable_value();
+        }
+
+        auto x_meta_vec = MakeMetaTensor(input_x);
+        std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
+        for (size_t i = 0; i < x_meta_vec.size(); ++i) {
+          x_metas[i] = &x_meta_vec[i];
+        }
+        phi::MetaTensor meta_dense_out(dense_out);
+        phi::AddNInferMeta(x_metas, &meta_dense_out);
+
+        using kernel_signature =
+            void (*)(const phi::DeviceContext&,
+                     const std::vector<const phi::DenseTensor*>&,
+                     phi::DenseTensor*);
+        auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
+        (*kernel_fn)(*dev_ctx, input_x, dense_out);
+      }
+      auto current_process_mesh = spmd_info.first[0].process_mesh();
+      SetReplicatedDistAttrForOutput(dist_out, current_process_mesh);
+      return api_output;
+    }
+#endif
     std::vector<const phi::TensorBase*> input_x(x.size());
     std::vector<std::shared_ptr<phi::DenseTensor>> temp_dense_tensots;
     temp_dense_tensots.reserve(x.size());
diff --git a/paddle/phi/api/lib/api_gen_utils.cc b/paddle/phi/api/lib/api_gen_utils.cc
index 3f62c52eaed1c3..c020d0332c5700 100644
--- a/paddle/phi/api/lib/api_gen_utils.cc
+++ b/paddle/phi/api/lib/api_gen_utils.cc
@@ -536,6 +536,16 @@ phi::distributed::DistMetaTensor MakeDistMetaTensor(
   return phi::distributed::DistMetaTensor(tensor);
 }
 
+std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
+    const std::vector<const phi::TensorBase*>& tensors) {
+  std::vector<phi::distributed::DistMetaTensor> meta_tensors;
+  meta_tensors.reserve(tensors.size());
+  for (const auto* t : tensors) {
+    meta_tensors.emplace_back(*t);
+  }
+  return meta_tensors;
+}
+
 phi::distributed::DistTensor* SetKernelDistOutput(
     Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
   if (out) {
diff --git a/paddle/phi/api/lib/api_gen_utils.h b/paddle/phi/api/lib/api_gen_utils.h
index a57d951ce738f5..5272c14209e0de 100644
--- a/paddle/phi/api/lib/api_gen_utils.h
+++ b/paddle/phi/api/lib/api_gen_utils.h
@@ -139,6 +139,9 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx,
 phi::distributed::DistMetaTensor MakeDistMetaTensor(
     const phi::TensorBase& tensor);
 
+std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
+    const std::vector<const phi::TensorBase*>& tensors);
+
 phi::distributed::DistTensor* SetKernelDistOutput(
     Tensor* out,
     const phi::distributed::TensorDistAttr& dist_attr =
diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc
index 1a8d92c2d90406..561d8ce379b9d7 100644
--- a/paddle/phi/api/lib/data_transform.cc
+++ b/paddle/phi/api/lib/data_transform.cc
@@ -674,6 +674,20 @@ ReshardApiInputToReplicatedKernelInput(
   return paddle::none;
 }
 
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+ReshardApiInputToReplicatedKernelInput(
+    phi::DeviceContext* dev_ctx,
+    const std::vector<Tensor>& tensors,
+    const std::vector<phi::distributed::TensorDistAttr>& dist_attrs) {
+  std::vector<std::shared_ptr<phi::distributed::DistTensor>> result;
+  result.reserve(tensors.size());
+  for (size_t i = 0; i < tensors.size(); ++i) {
+    result.emplace_back(ReshardApiInputToReplicatedKernelInput(
+        dev_ctx, tensors[i], dist_attrs[i]));
+  }
+  return result;
+}
+
 void ReshardOutputPartialAxisToReplicated(
     phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) {
   if (out_tensor->dist_attr().is_partial()) {
@@ -825,6 +839,50 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
   return out;
 }
 
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+PrepareDataForDistTensor(
+    const std::vector<std::shared_ptr<phi::distributed::DistTensor>>& input,
+    const phi::TensorArgDef& target_args_def,
+    const TransformFlag& transform_flag,
+    bool is_stride_kernel) {
+  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
+  for (auto& x : input) {
+    const auto& tensor_in = x;
+    if (tensor_in) {
+      phi::distributed::DistTensor* dist_tensor =
+          static_cast<phi::distributed::DistTensor*>(tensor_in.get());
+      const phi::DenseTensor& dense_tensor = dist_tensor->value();
+      if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
+          (!NeedTransformPlace(
+               dense_tensor.place(), target_args_def.backend, transform_flag) &&
+           !NeedTransformDataType(
+               dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
+           !NeedTransformLayout(dense_tensor.layout(),
+                                target_args_def.layout,
+                                dense_tensor.place(),
+                                transform_flag) &&
+           !NeedTransform2Contiguous(is_stride_kernel,
+                                     dense_tensor.meta().is_contiguous()))) {
+        out.push_back(
+            std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
+      } else {
+        phi::DenseTensor trans_in_tensor = TransformData(
+            dense_tensor, target_args_def, transform_flag, is_stride_kernel);
+        // TODO(GhostScreaming): The global meta in DistTensor is not changed,
+        // but the local meta in DenseTensor maybe changed, such as layout
+        // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
+        VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
+        out.push_back(std::make_shared<phi::distributed::DistTensor>(
+            std::make_shared<phi::DenseTensor>(trans_in_tensor),
+            dist_tensor->dist_attr()));
+      }
+    } else {
+      out.push_back(nullptr);
+    }
+  }
+  return out;
+}
+
 paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
     const paddle::optional<Tensor>& input,
     const phi::TensorArgDef& target_args_def,
diff --git a/paddle/phi/api/lib/data_transform.h b/paddle/phi/api/lib/data_transform.h
index 712f568479d2e8..8df013860a5ab5 100644
--- a/paddle/phi/api/lib/data_transform.h
+++ b/paddle/phi/api/lib/data_transform.h
@@ -192,6 +192,12 @@ ReshardApiInputToReplicatedKernelInput(
     const paddle::optional<Tensor>& tensor,
     const phi::distributed::TensorDistAttr& dist_attr);
 
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+ReshardApiInputToReplicatedKernelInput(
+    phi::DeviceContext* dev_ctx,
+    const std::vector<Tensor>& tensors,
+    const std::vector<phi::distributed::TensorDistAttr>& dist_attrs);
+
 void ReshardOutputPartialAxisToReplicated(
     phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor);
 
@@ -226,6 +232,13 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
                          const TransformFlag& transform_flag,
                          bool is_stride_kernel);
 
+std::vector<std::shared_ptr<phi::distributed::DistTensor>>
+PrepareDataForDistTensor(
+    const std::vector<std::shared_ptr<phi::distributed::DistTensor>>& input,
+    const phi::TensorArgDef& target_args_def,
+    const TransformFlag& transform_flag,
+    bool is_stride_kernel);
+
 paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
     const paddle::optional<Tensor>& input,
     const phi::TensorArgDef& target_args_def,
diff --git a/paddle/phi/infermeta/spmd_rules/utils.h b/paddle/phi/infermeta/spmd_rules/utils.h
index cd16a95bceac74..cd140c68fc8ac8 100644
--- a/paddle/phi/infermeta/spmd_rules/utils.h
+++ b/paddle/phi/infermeta/spmd_rules/utils.h
@@ -112,6 +112,12 @@ struct VariadicSpmdRuleArgumentParser
     }
   }
 
+  void operator()(const std::vector<DistMetaTensor>& x) {
+    for (auto& t : x) {
+      inputs.emplace_back(&t);
+    }
+  }
+
   // deal with outputs
   void operator()(DistMetaTensor* out) { outputs.emplace_back(out); }
 
diff --git a/test/auto_parallel/semi_auto_parallel_for_add_n.py b/test/auto_parallel/semi_auto_parallel_for_add_n.py
new file mode 100644
index 00000000000000..9d7786eeaaf080
--- /dev/null
+++ b/test/auto_parallel/semi_auto_parallel_for_add_n.py
@@ -0,0 +1,87 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+ +import os + +import numpy as np + +import paddle +import paddle.distributed as dist + + +class TestAddNApiForSemiAutoParallel: + def __init__(self): + self._dtype = os.getenv("dtype") + self._backend = os.getenv("backend") + self._seed = eval(os.getenv("seed")) + self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"]) + + def check_tensor_eq(self, a, b): + np1 = a.numpy() + np2 = b.numpy() + np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True) + + def test_body( + self, x_shape, y_shape, x_specs, y_specs, trans_x=False, trans_y=False + ): + paddle.seed(self._seed) + np.random.seed(self._seed) + + x_np = np.random.random(size=x_shape).astype(self._dtype) + y_np = np.random.random(size=y_shape).astype(self._dtype) + x = paddle.to_tensor(x_np) + y = paddle.to_tensor(y_np) + x.stop_gradient = False + y.stop_gradient = False + + x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs) + y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs) + + dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr) + dist_y = dist.shard_tensor(y_np, dist_attr=y_dist_attr) + dist_x.stop_gradient = False + dist_y.stop_gradient = False + + out = paddle.add_n([x, y]) + dist_out = paddle.add_n([dist_x, dist_y]) + self.check_tensor_eq(out, dist_out) + + out.backward() + dist_out.backward() + self.check_tensor_eq(x.grad, dist_x.grad) + self.check_tensor_eq(y.grad, dist_y.grad) + + return dist_out, dist_x.grad, dist_y.grad + + def test_add_n(self): + self.test_body( + x_shape=[64, 32], + y_shape=[64, 32], + x_specs=[None, None], + y_specs=[None, None], + ) + + def run_test_case(self): + if self._backend == "cpu": + paddle.set_device("cpu") + elif self._backend == "gpu": + paddle.set_device("gpu:" + str(dist.get_rank())) + else: + raise ValueError("Only support cpu or gpu backend.") + + self.test_add_n() + + +if __name__ == '__main__': + TestAddNApiForSemiAutoParallel().run_test_case() diff --git a/test/auto_parallel/test_semi_auto_parallel_basic.py b/test/auto_parallel/test_semi_auto_parallel_basic.py index 047b769f12f75f..56cabcb318f3dd 100644 --- a/test/auto_parallel/test_semi_auto_parallel_basic.py +++ b/test/auto_parallel/test_semi_auto_parallel_basic.py @@ -66,6 +66,16 @@ def test_several_replicated_spmd_api(self): user_defined_envs=envs, ) + def test_add_n_api(self): + envs_list = test_base.gen_product_envs_list( + self._default_envs, self._changeable_envs + ) + for envs in envs_list: + self.run_test_case( + "semi_auto_parallel_for_add_n.py", + user_defined_envs=envs, + ) + if __name__ == "__main__": unittest.main()
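
Usage sketch (not part of the patch): with this change, paddle.add_n accepts DistTensor inputs in semi-auto parallel mode. The snippet below mirrors what test/auto_parallel/semi_auto_parallel_for_add_n.py exercises; it assumes two visible devices, the file name demo_add_n.py and the launch command are only illustrative.

# demo_add_n.py -- minimal sketch, assumes 2 ranks, e.g. started with
# `python -m paddle.distributed.launch --devices=0,1 demo_add_n.py`.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])
attr = dist.DistAttr(mesh=mesh, sharding_specs=[None, None])  # fully replicated

a = dist.shard_tensor(paddle.rand([64, 32]), dist_attr=attr)
b = dist.shard_tensor(paddle.rand([64, 32]), dist_attr=attr)

# Dispatches into the PADDLE_WITH_DISTRIBUTE branch added to add_n_impl above:
# inputs are resharded to replicated, the dense add_n kernel runs on each rank,
# and the output DistTensor carries a replicated dist_attr.
out = paddle.add_n([a, b])
print(out.shape)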