[AutoParallel] Addn support AutoParallel (#58434)
* phi add_n support disttensor
wanghuancoder authored Oct 30, 2023
1 parent c633e52 commit b5789b8
Showing 8 changed files with 261 additions and 2 deletions.
76 changes: 74 additions & 2 deletions paddle/phi/api/lib/api_custom_impl.cc
@@ -28,7 +28,10 @@ limitations under the License. */
#include "paddle/phi/infermeta/multiary.h"
#include "paddle/phi/infermeta/nullary.h"
#include "paddle/phi/infermeta/unary.h"

#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/reshard_utils.h"
#include "paddle/phi/infermeta/spmd_rules/rules.h"
#endif
namespace paddle {
namespace experimental {

@@ -57,7 +60,8 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {

  bool is_sr_kernel = true;
  for (auto& input : x) {
-    if (phi::DenseTensor::classof(input.impl().get())) {
+    if (phi::DenseTensor::classof(input.impl().get()) ||
+        phi::distributed::DistTensor::classof(input.impl().get())) {
      is_sr_kernel = false;
      break;
    }
@@ -98,6 +102,74 @@ Tensor add_n_impl(const std::vector<Tensor>& x) {

    (*kernel_fn)(*dev_ctx, input_x, kernel_out);
  } else {
#ifdef PADDLE_WITH_DISTRIBUTE
    bool run_auto_parallel = AllInputsAreDistTensor(x);
    bool rank_is_in_current_mesh = true;
    if (run_auto_parallel) {
      auto mesh =
          std::static_pointer_cast<phi::distributed::DistTensor>(x[0].impl())
              ->dist_attr()
              .process_mesh();
      rank_is_in_current_mesh = phi::distributed::IsCurRankInMesh(mesh);

      std::vector<const phi::TensorBase*> input_x(x.size());
      for (size_t i = 0; i < input_x.size(); ++i) {
        input_x[i] = x[i].impl().get();
      }

      auto meta_dist_input_x = MakeDistMetaTensor(input_x);
      auto spmd_info =
          phi::distributed::VariadicReplicatedInferSpmd(meta_dist_input_x);

      auto dist_out = SetKernelDistOutput(&api_output);
      auto dense_out = dist_out->unsafe_mutable_value();
      if (!rank_is_in_current_mesh) {
        *dense_out = phi::DenseTensor(
            std::make_shared<phi::Allocation>(
                nullptr, 0, phi::distributed::GetDefaultPlace()),
            phi::DenseTensorMeta());
      }

      phi::MetaTensor meta_dist_out(dist_out);
      auto x_meta_vec = MakeMetaTensor(input_x);
      std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
      for (size_t i = 0; i < x_meta_vec.size(); ++i) {
        x_metas[i] = &x_meta_vec[i];
      }
      phi::AddNInferMeta(x_metas, &meta_dist_out);
      if (rank_is_in_current_mesh) {
        auto dist_input_x =
            ReshardApiInputToReplicatedKernelInput(dev_ctx, x, spmd_info.first);
        dist_input_x = PrepareDataForDistTensor(
            dist_input_x,
            GetKernelInputArgDef(kernel.InputAt(0), kernel_backend),
            {},
            kernel_result.is_stride_kernel);
        std::vector<const phi::TensorBase*> input_x(dist_input_x.size());
        for (size_t i = 0; i < dist_input_x.size(); ++i) {
          input_x[i] = dist_input_x[i]->unsafe_mutable_value();
        }

        auto x_meta_vec = MakeMetaTensor(input_x);
        std::vector<const phi::MetaTensor*> x_metas(x_meta_vec.size());
        for (size_t i = 0; i < x_meta_vec.size(); ++i) {
          x_metas[i] = &x_meta_vec[i];
        }
        phi::MetaTensor meta_dense_out(dense_out);
        phi::AddNInferMeta(x_metas, &meta_dense_out);

        using kernel_signature =
            void (*)(const phi::DeviceContext&,
                     const std::vector<const phi::TensorBase*>&,
                     phi::DenseTensor*);
        auto* kernel_fn = kernel.GetVariadicKernelFn<kernel_signature>();
        (*kernel_fn)(*dev_ctx, input_x, dense_out);
      }
      auto current_process_mesh = spmd_info.first[0].process_mesh();
      SetReplicatedDistAttrForOutput(dist_out, current_process_mesh);
      return api_output;
    }
#endif
    std::vector<const phi::TensorBase*> input_x(x.size());
    std::vector<std::shared_ptr<phi::DenseTensor>> temp_dense_tensots;
    temp_dense_tensots.reserve(x.size());
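Note on the branch above: AddNInferMeta runs twice — first against the DistTensor output to record the global (unsharded) shape, then against its local DenseTensor to record the shape of the shard this rank actually computes — and ranks outside the process mesh skip the kernel entirely, keeping only an empty placeholder allocation. A self-contained illustration of the global-vs-local distinction (all names and types below are illustrative, not phi's; under the replicated SPMD rule used here the two shapes coincide):

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative only: global vs. local output meta for an elementwise op such
// as add_n. With a replicated placement every rank holds the full tensor, so
// local == global; a rule that shards a dimension across a mesh of size n
// would divide that dimension instead.
struct TensorMeta {
  std::vector<int64_t> shape;
};

TensorMeta InferGlobalAddNMeta(const std::vector<TensorMeta>& inputs) {
  // add_n sums same-shaped inputs elementwise, so the output shape is the
  // common input shape.
  return TensorMeta{inputs.front().shape};
}

TensorMeta InferLocalMeta(const TensorMeta& global, int64_t mesh_size,
                          int sharded_dim /* -1 means replicated */) {
  TensorMeta local = global;
  if (sharded_dim >= 0) {
    local.shape[sharded_dim] /= mesh_size;
  }
  return local;
}

int main() {
  std::vector<TensorMeta> inputs = {{{64, 32}}, {{64, 32}}};
  TensorMeta global = InferGlobalAddNMeta(inputs);
  TensorMeta local = InferLocalMeta(global, /*mesh_size=*/2, /*sharded_dim=*/-1);
  std::cout << "global: " << global.shape[0] << "x" << global.shape[1]
            << ", local: " << local.shape[0] << "x" << local.shape[1] << "\n";
  // Prints "global: 64x32, local: 64x32" -- replicated, so they match.
  return 0;
}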
10 changes: 10 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.cc
@@ -536,6 +536,16 @@ phi::distributed::DistMetaTensor MakeDistMetaTensor(
  return phi::distributed::DistMetaTensor(tensor);
}

std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
    const std::vector<const phi::TensorBase*>& tensors) {
  std::vector<phi::distributed::DistMetaTensor> meta_tensors;
  meta_tensors.reserve(tensors.size());
  for (const auto* t : tensors) {
    meta_tensors.emplace_back(*t);
  }
  return meta_tensors;
}

phi::distributed::DistTensor* SetKernelDistOutput(
    Tensor* out, const phi::distributed::TensorDistAttr& dist_attr) {
  if (out) {
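The vector overload above exists so the hand-written add_n path can hand its whole input list to the variadic SPMD machinery in one call. A sketch of the intended call pattern (the helper name below is hypothetical and assumes the same translation unit and headers as add_n_impl; the individual calls are the ones used in this commit):

// Hypothetical helper, for illustration only.
auto InferReplicatedSpmdForInputs(const std::vector<Tensor>& x) {
  // Collect the raw TensorBase pointers backing each API-level Tensor.
  std::vector<const phi::TensorBase*> raw_inputs(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    raw_inputs[i] = x[i].impl().get();
  }
  // New overload: vector of TensorBase* -> vector of DistMetaTensor.
  auto meta_inputs = MakeDistMetaTensor(raw_inputs);
  // Infer a replicated placement for all inputs at once.
  return phi::distributed::VariadicReplicatedInferSpmd(meta_inputs);
}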
3 changes: 3 additions & 0 deletions paddle/phi/api/lib/api_gen_utils.h
@@ -139,6 +139,9 @@ void TransStrideLegacy(phi::DeviceContext* dev_ctx,
phi::distributed::DistMetaTensor MakeDistMetaTensor(
    const phi::TensorBase& tensor);

std::vector<phi::distributed::DistMetaTensor> MakeDistMetaTensor(
    const std::vector<const phi::TensorBase*>& tensors);

phi::distributed::DistTensor* SetKernelDistOutput(
    Tensor* out,
    const phi::distributed::TensorDistAttr& dist_attr =
58 changes: 58 additions & 0 deletions paddle/phi/api/lib/data_transform.cc
@@ -674,6 +674,20 @@ ReshardApiInputToReplicatedKernelInput(
  return paddle::none;
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
    phi::DeviceContext* dev_ctx,
    const std::vector<Tensor>& tensors,
    const std::vector<phi::distributed::TensorDistAttr>& dist_attrs) {
  std::vector<std::shared_ptr<phi::distributed::DistTensor>> result;
  result.reserve(tensors.size());
  for (size_t i = 0; i < tensors.size(); ++i) {
    result.emplace_back(ReshardApiInputToReplicatedKernelInput(
        dev_ctx, tensors[i], dist_attrs[i]));
  }
  return result;
}

void ReshardOutputPartialAxisToReplicated(
    phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor) {
  if (out_tensor->dist_attr().is_partial()) {
@@ -825,6 +839,50 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
  return out;
}

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
PrepareDataForDistTensor(
    const std::vector<std::shared_ptr<phi::distributed::DistTensor>>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel) {
  std::vector<std::shared_ptr<phi::distributed::DistTensor>> out;
  for (auto& x : input) {
    const auto& tensor_in = x;
    if (tensor_in) {
      phi::distributed::DistTensor* dist_tensor =
          static_cast<phi::distributed::DistTensor*>(tensor_in.get());
      const phi::DenseTensor& dense_tensor = dist_tensor->value();
      if (!transform_flag.NeedTransform() || !dense_tensor.initialized() ||
          (!NeedTransformPlace(
               dense_tensor.place(), target_args_def.backend, transform_flag) &&
           !NeedTransformDataType(
               dense_tensor.dtype(), target_args_def.dtype, transform_flag) &&
           !NeedTransformLayout(dense_tensor.layout(),
                                target_args_def.layout,
                                dense_tensor.place(),
                                transform_flag) &&
           !NeedTransform2Contiguous(is_stride_kernel,
                                     dense_tensor.meta().is_contiguous()))) {
        out.push_back(
            std::static_pointer_cast<phi::distributed::DistTensor>(tensor_in));
      } else {
        phi::DenseTensor trans_in_tensor = TransformData(
            dense_tensor, target_args_def, transform_flag, is_stride_kernel);
        // TODO(GhostScreaming): The global meta in DistTensor is not changed,
        // but the local meta in DenseTensor maybe changed, such as layout
        // change(NCHW->NHWC), so the new DistTensor's meta maybe not unified.
        VLOG(6) << "PrepareDataForDistTensor return transformed dist tensor";
        out.push_back(std::make_shared<phi::distributed::DistTensor>(
            std::make_shared<phi::DenseTensor>(trans_in_tensor),
            dist_tensor->dist_attr()));
      }
    } else {
      out.push_back(nullptr);
    }
  }
  return out;
}

paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
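Both additions are list-at-once versions of existing per-tensor helpers: the first reshards every API input to a replicated placement, the second runs the usual place/dtype/layout/contiguity transforms on each resulting local shard, or reuses the shard untouched when nothing needs to change. A sketch of the chain they enable, mirroring the add_n branch (the wrapper name and parameter list are hypothetical; kernel-selection plumbing is omitted):

// Hypothetical wrapper, for illustration only; assumes the same headers as
// api_custom_impl.cc.
std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardAndPrepareInputs(
    phi::DeviceContext* dev_ctx,
    const std::vector<Tensor>& x,
    const std::vector<phi::distributed::TensorDistAttr>& input_dist_attrs,
    const phi::TensorArgDef& arg_def,
    bool is_stride_kernel) {
  // Step 1: reshard each DistTensor input to the replicated dist_attr chosen
  // by the SPMD rule (spmd_info.first in the add_n branch).
  auto dist_inputs =
      ReshardApiInputToReplicatedKernelInput(dev_ctx, x, input_dist_attrs);
  // Step 2: per-shard data transforms; a default TransformFlag ({}) is
  // passed, exactly as add_n_impl does.
  return PrepareDataForDistTensor(dist_inputs, arg_def, {}, is_stride_kernel);
}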
13 changes: 13 additions & 0 deletions paddle/phi/api/lib/data_transform.h
@@ -192,6 +192,12 @@ ReshardApiInputToReplicatedKernelInput(
    const paddle::optional<Tensor>& tensor,
    const phi::distributed::TensorDistAttr& dist_attr);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
ReshardApiInputToReplicatedKernelInput(
    phi::DeviceContext* dev_ctx,
    const std::vector<Tensor>& tensors,
    const std::vector<phi::distributed::TensorDistAttr>& dist_attrs);

void ReshardOutputPartialAxisToReplicated(
    phi::DeviceContext* dev_ctx, phi::distributed::DistTensor* out_tensor);

@@ -226,6 +232,13 @@ PrepareDataForDistTensor(const std::vector<Tensor>& input,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

std::vector<std::shared_ptr<phi::distributed::DistTensor>>
PrepareDataForDistTensor(
    const std::vector<std::shared_ptr<phi::distributed::DistTensor>>& input,
    const phi::TensorArgDef& target_args_def,
    const TransformFlag& transform_flag,
    bool is_stride_kernel);

paddle::optional<phi::distributed::DistTensor> PrepareDataForDistTensor(
    const paddle::optional<Tensor>& input,
    const phi::TensorArgDef& target_args_def,
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/spmd_rules/utils.h
@@ -112,6 +112,12 @@ struct VariadicSpmdRuleArgumentParser
    }
  }

  void operator()(const std::vector<DistMetaTensor>& x) {
    for (auto& t : x) {
      inputs.emplace_back(&t);
    }
  }

  // deal with outputs
  void operator()(DistMetaTensor* out) { outputs.emplace_back(out); }

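The parser gathers raw pointers into its inputs member through overloaded operator(), one overload per accepted argument kind; the new overload flattens a std::vector<DistMetaTensor>, which is what lets VariadicReplicatedInferSpmd be called with add_n's whole input list as a single argument. The pattern in isolation (a standalone toy, not phi's actual template machinery):

#include <iostream>
#include <string>
#include <vector>

// Toy argument parser: each operator() overload appends to `inputs`; the
// vector overload mirrors, in spirit, the one added in this diff.
struct MetaTensor {
  std::string name;
};

struct ArgumentParser {
  std::vector<const MetaTensor*> inputs;

  void operator()(const MetaTensor& t) { inputs.emplace_back(&t); }

  void operator()(const std::vector<MetaTensor>& xs) {
    for (auto& t : xs) {
      inputs.emplace_back(&t);
    }
  }

  template <typename... Args>
  void Apply(const Args&... args) {
    ((*this)(args), ...);  // visit every variadic argument in order
  }
};

int main() {
  std::vector<MetaTensor> xs{{"x0"}, {"x1"}, {"x2"}};
  MetaTensor bias{"bias"};
  ArgumentParser parser;
  parser.Apply(xs, bias);                     // the vector is flattened
  std::cout << parser.inputs.size() << "\n";  // prints 4
  return 0;
}

The output overload visible above (DistMetaTensor*) works the same way on the outputs side; only the vector case is new in this commit.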
87 changes: 87 additions & 0 deletions test/auto_parallel/semi_auto_parallel_for_add_n.py
@@ -0,0 +1,87 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import numpy as np

import paddle
import paddle.distributed as dist


class TestAddNApiForSemiAutoParallel:
    def __init__(self):
        self._dtype = os.getenv("dtype")
        self._backend = os.getenv("backend")
        self._seed = eval(os.getenv("seed"))
        self._mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

    def check_tensor_eq(self, a, b):
        np1 = a.numpy()
        np2 = b.numpy()
        np.testing.assert_allclose(np1, np2, rtol=1e-05, verbose=True)

    def test_body(
        self, x_shape, y_shape, x_specs, y_specs, trans_x=False, trans_y=False
    ):
        paddle.seed(self._seed)
        np.random.seed(self._seed)

        x_np = np.random.random(size=x_shape).astype(self._dtype)
        y_np = np.random.random(size=y_shape).astype(self._dtype)
        x = paddle.to_tensor(x_np)
        y = paddle.to_tensor(y_np)
        x.stop_gradient = False
        y.stop_gradient = False

        x_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=x_specs)
        y_dist_attr = dist.DistAttr(mesh=self._mesh, sharding_specs=y_specs)

        dist_x = dist.shard_tensor(x_np, dist_attr=x_dist_attr)
        dist_y = dist.shard_tensor(y_np, dist_attr=y_dist_attr)
        dist_x.stop_gradient = False
        dist_y.stop_gradient = False

        out = paddle.add_n([x, y])
        dist_out = paddle.add_n([dist_x, dist_y])
        self.check_tensor_eq(out, dist_out)

        out.backward()
        dist_out.backward()
        self.check_tensor_eq(x.grad, dist_x.grad)
        self.check_tensor_eq(y.grad, dist_y.grad)

        return dist_out, dist_x.grad, dist_y.grad

    def test_add_n(self):
        self.test_body(
            x_shape=[64, 32],
            y_shape=[64, 32],
            x_specs=[None, None],
            y_specs=[None, None],
        )

    def run_test_case(self):
        if self._backend == "cpu":
            paddle.set_device("cpu")
        elif self._backend == "gpu":
            paddle.set_device("gpu:" + str(dist.get_rank()))
        else:
            raise ValueError("Only support cpu or gpu backend.")

        self.test_add_n()


if __name__ == '__main__':
    TestAddNApiForSemiAutoParallel().run_test_case()
10 changes: 10 additions & 0 deletions test/auto_parallel/test_semi_auto_parallel_basic.py
@@ -66,6 +66,16 @@ def test_several_replicated_spmd_api(self):
                user_defined_envs=envs,
            )

    def test_add_n_api(self):
        envs_list = test_base.gen_product_envs_list(
            self._default_envs, self._changeable_envs
        )
        for envs in envs_list:
            self.run_test_case(
                "semi_auto_parallel_for_add_n.py",
                user_defined_envs=envs,
            )


if __name__ == "__main__":
    unittest.main()
