[AutoParallel] add chunk id for vpp in TensorDistAttr and OperatorDistAttr #59416

Merged: 6 commits, Nov 29, 2023
10 changes: 10 additions & 0 deletions paddle/fluid/distributed/auto_parallel/dist_attr.cc
@@ -45,6 +45,7 @@ std::vector<int64_t> get_tensor_shape(const VarDesc* tensor) {
std::vector<std::string> OperatorDistAttr::fields_{"process_mesh",
"impl_type",
"impl_idx",
"chunk_id",
"is_recompute",
"execution_stream",
"stream_priority",
@@ -66,6 +67,7 @@ OperatorDistAttr& OperatorDistAttr::operator=(
std::swap(this->op_type_, tmp.op_type_);
std::swap(this->impl_type_, tmp.impl_type_);
std::swap(this->impl_idx_, tmp.impl_idx_);
std::swap(this->chunk_id_, tmp.chunk_id_);
std::swap(this->is_recompute_, tmp.is_recompute_);
std::swap(this->execution_stream_, tmp.execution_stream_);
std::swap(this->stream_priority_, tmp.stream_priority_);
@@ -100,6 +102,7 @@ void OperatorDistAttr::initialize(const OpDesc* op) {
op_type_ = op->Type();
impl_type_ = kDefault;
impl_idx_ = 0;
chunk_id_ = 0;
is_recompute_ = false;
execution_stream_ = kDefault;
stream_priority_ = 0;
@@ -113,6 +116,7 @@ void OperatorDistAttr::copy_from(const OperatorDistAttr& dist_attr) {
set_op_type(dist_attr.op_type());
set_impl_type(dist_attr.impl_type());
set_impl_idx(dist_attr.impl_idx());
set_chunk_id(dist_attr.chunk_id());
set_is_recompute(dist_attr.is_recompute());
set_execution_stream(dist_attr.execution_stream());
set_stream_priority(dist_attr.stream_priority());
@@ -359,6 +363,7 @@ std::string OperatorDistAttr::to_string() const {
std::string str;
str += "{impl_type: " + impl_type_ + ", ";
str += "impl_idx: " + std::to_string(impl_idx_) + ", ";
str += "chunk_id: " + std::to_string(chunk_id_) + ", ";
str += "execution_stream: " + execution_stream_ + ", ";
str += "stream_priority: " + std::to_string(stream_priority_) + ", ";
str += "scheduling_priority: " + std::to_string(scheduling_priority_) + ", ";
@@ -393,6 +398,7 @@ void OperatorDistAttr::from_proto(const OperatorDistAttrProto& proto) {
process_mesh_ = ProcessMesh::from_proto(proto.process_mesh());
impl_type_ = proto.impl_type();
impl_idx_ = proto.impl_idx();
chunk_id_ = proto.chunk_id();
}

OperatorDistAttrProto OperatorDistAttr::to_proto() const {
@@ -410,6 +416,7 @@ OperatorDistAttrProto OperatorDistAttr::to_proto() const {
proto.mutable_process_mesh()->CopyFrom(process_mesh_.to_proto());
proto.set_impl_type(impl_type_);
proto.set_impl_idx(impl_idx_);
proto.set_chunk_id(chunk_id_);
return proto;
}

@@ -443,6 +450,9 @@ bool operator==(const OperatorDistAttr& lhs, const OperatorDistAttr& rhs) {
if (lhs.impl_idx() != rhs.impl_idx()) {
return false;
}
if (lhs.chunk_id() != rhs.chunk_id()) {
return false;
}
if (lhs.execution_stream() != rhs.execution_stream()) {
return false;
}
5 changes: 5 additions & 0 deletions paddle/fluid/distributed/auto_parallel/dist_attr.h
@@ -131,6 +131,10 @@ class OperatorDistAttr {

void set_impl_idx(const int64_t& impl_idx) { impl_idx_ = impl_idx; }

int64_t chunk_id() const { return chunk_id_; }

void set_chunk_id(const int64_t& chunk_id) { chunk_id_ = chunk_id; }

bool is_recompute() const { return is_recompute_; }

void set_is_recompute(bool is_recompute) { is_recompute_ = is_recompute; }
@@ -240,6 +244,7 @@ class OperatorDistAttr {
std::string op_type_;
std::string impl_type_ = kDefault;
int64_t impl_idx_ = 0;
int64_t chunk_id_ = 0;
bool is_recompute_ = false;
std::string execution_stream_ = kDefault;
bool force_record_event_ = false;
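Together with the .cc changes above, chunk_id on OperatorDistAttr behaves exactly like impl_idx: it defaults to 0 (and is reset to 0 by initialize), is swapped in operator=, copied by copy_from, printed by to_string, carried through the proto round trip, and compared in operator==.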
6 changes: 6 additions & 0 deletions paddle/fluid/pybind/auto_parallel_py.cc
@@ -149,6 +149,7 @@ static inline void reset_operator_dist_attr(OperatorDistAttr *dist_attr) {
}
dist_attr->set_impl_type(kDefault);
dist_attr->set_impl_idx(0);
dist_attr->set_chunk_id(0);
dist_attr->clear_annotated();
}

@@ -435,6 +436,8 @@ void BindAutoParallel(py::module *m) {
.def_property("batch_dim",
&TensorDistAttr::batch_dim,
&TensorDistAttr::set_batch_dim)
.def_property(
"chunk_id", &TensorDistAttr::chunk_id, &TensorDistAttr::set_chunk_id)
.def_property("dynamic_dims",
&TensorDistAttr::dynamic_dims,
&TensorDistAttr::set_dynamic_dims)
@@ -531,6 +534,9 @@ void BindAutoParallel(py::module *m) {
.def_property("impl_idx",
&OperatorDistAttr::impl_idx,
&OperatorDistAttr::set_impl_idx)
.def_property("chunk_id",
&OperatorDistAttr::chunk_id,
&OperatorDistAttr::set_chunk_id)
.def_property("is_recompute",
&OperatorDistAttr::is_recompute,
&OperatorDistAttr::set_is_recompute)
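The two .def_property entries above expose chunk_id as an ordinary Python attribute on both TensorDistAttr and OperatorDistAttr, and reset_operator_dist_attr now clears it back to 0. A minimal, hedged sketch of what this enables (the dist_op and tensor_dist_attr objects are assumed to come from an already-built distributed program, as in the interface test at the end of this diff; they are not constructed here):

    # Operator side: tag the op as vpp chunk 1 and read the value back
    # through the new pybind property.
    dist_op.dist_attr.chunk_id = 1
    assert dist_op.dist_attr.chunk_id == 1

    # Tensor side: the property exists on TensorDistAttr as well and
    # defaults to 0, matching the C++ member initializer.
    assert tensor_dist_attr.chunk_id == 0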
6 changes: 6 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/auto_parallel.proto
@@ -49,6 +49,9 @@ message TensorDistAttrProto {
// If dynamic_dims[i] is True, the i-th dimension of the corresponding tensor
// changes dynamically. Otherwise, the i-th dimension of the tensor is statically determined.
repeated bool dynamic_dims = 4;

// This field is used to distinguish vars that are in the same process_mesh but in different vpp chunks
optional int64 chunk_id = 5;
}

// This distributed attribute describes how to distribute the corresponding operator,
@@ -81,6 +84,9 @@ message OperatorDistAttrProto {
// This field tells which distributed implementations of this corresponding operator
// will be selected for the actual computation.
optional int64 impl_idx = 5;

// This field is used to distinguish ops that are in the same process_mesh but in different vpp chunks
optional int64 chunk_id = 6;
}

// This proto describes the capability of one device such as the computation and memory.
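Under virtual pipeline parallelism a single pipeline stage hosts several model chunks, so their ops and vars share one process_mesh; chunk_id (0 for the first chunk, 1 for the next, and so on) is what tells them apart, which is exactly what the updated interface test at the end of this diff asserts.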
14 changes: 13 additions & 1 deletion paddle/phi/core/distributed/auto_parallel/dist_attr.cc
@@ -27,7 +27,7 @@ using phi::distributed::auto_parallel::TensorDistAttrProto;

// partial is not allowed to be annotated by the user for now.
std::vector<std::string> TensorDistAttr::fields_{
"process_mesh", "dims_mapping", "batch_dim", "dynamic_dims"};
"process_mesh", "dims_mapping", "batch_dim", "chunk_id", "dynamic_dims"};

TensorDistAttr::TensorDistAttr(const std::vector<int64_t>& tensor_shape) {
set_default_dims_mapping(tensor_shape);
@@ -44,6 +44,7 @@ TensorDistAttr& TensorDistAttr::operator=(const TensorDistAttr& dist_attr) {
std::swap(this->process_mesh_, tmp.process_mesh_);
std::swap(this->dims_mapping_, tmp.dims_mapping_);
std::swap(this->batch_dim_, tmp.batch_dim_);
std::swap(this->chunk_id_, tmp.chunk_id_);
std::swap(this->dynamic_dims_, tmp.dynamic_dims_);
std::swap(this->annotated_, tmp.annotated_);
std::swap(this->partial_status_, tmp.partial_status_);
@@ -54,6 +55,7 @@ void TensorDistAttr::copy_from(const TensorDistAttr& dist_attr) {
set_process_mesh(dist_attr.process_mesh());
set_dims_mapping(dist_attr.dims_mapping());
set_batch_dim(dist_attr.batch_dim());
set_chunk_id(dist_attr.chunk_id());
set_dynamic_dims(dist_attr.dynamic_dims());
set_annotated(dist_attr.annotated());
set_partial_status(dist_attr.partial_status());
@@ -72,6 +74,10 @@ void TensorDistAttr::set_batch_dim(int64_t batch_dim) {
batch_dim_ = batch_dim;
}

void TensorDistAttr::set_chunk_id(const int64_t& chunk_id) {
chunk_id_ = chunk_id;
}

void TensorDistAttr::set_dynamic_dims(const std::vector<bool>& dynamic_dims) {
dynamic_dims_ = dynamic_dims;
}
@@ -265,6 +271,7 @@ std::string TensorDistAttr::to_string() const {
dist_str += "{process_mesh: " + process_mesh_.to_string() + ", ";
dist_str += "dims_mappings: [" + str_join(dims_mapping_) + "], ";
dist_str += "batch_dim: " + std::to_string(batch_dim_) + ", ";
dist_str += "chunk_id: " + std::to_string(chunk_id_) + ", ";
dist_str += "dynamic_dims: [" + str_join(dynamic_dims_) + "], ";
dist_str += "annotated: [" + str_join(annotated_) + "], ";
dist_str += "partial: " + partial_status_string() + ".}";
@@ -278,6 +285,7 @@ void TensorDistAttr::from_proto(const TensorDistAttrProto& proto) {
dims_mapping_[i] = proto.dims_mapping(i);
}
batch_dim_ = proto.batch_dim();
chunk_id_ = proto.chunk_id();
dynamic_dims_.resize(proto.dynamic_dims_size());
for (int i = 0; i < proto.dynamic_dims_size(); ++i) {
dynamic_dims_[i] = proto.dynamic_dims(i);
@@ -291,6 +299,7 @@ TensorDistAttrProto TensorDistAttr::to_proto() const {
proto.add_dims_mapping(i);
}
proto.set_batch_dim(batch_dim_);
proto.set_chunk_id(chunk_id_);
for (const auto& i : dynamic_dims_) {
proto.add_dynamic_dims(i);
}
@@ -328,6 +337,9 @@ bool operator==(const TensorDistAttr& lhs, const TensorDistAttr& rhs) {
if (lhs.batch_dim() != rhs.batch_dim()) {
return false;
}
if (lhs.chunk_id() != rhs.chunk_id()) {
return false;
}
if (lhs.dynamic_dims() != rhs.dynamic_dims()) {
return false;
}
5 changes: 5 additions & 0 deletions paddle/phi/core/distributed/auto_parallel/dist_attr.h
@@ -120,6 +120,10 @@ class TEST_API TensorDistAttr {

void set_batch_dim(int64_t batch_dim);

const int64_t& chunk_id() const { return chunk_id_; }

void set_chunk_id(const int64_t& chunk_id);

const std::vector<bool>& dynamic_dims() const { return dynamic_dims_; }

void set_dynamic_dims(const std::vector<bool>& dynamic_dims);
@@ -195,6 +199,7 @@ class TEST_API TensorDistAttr {
int64_t batch_dim_{0};
std::vector<bool> dynamic_dims_;
std::map<std::string, bool> annotated_;
int64_t chunk_id_{0};
// The partial map is expected to be small (less than mesh.size);
// iteration (copy and comparison) happens more frequently than random
// element access. <key: dim on mesh, value: reduce type>
6 changes: 4 additions & 2 deletions python/paddle/distributed/auto_parallel/interface.py
@@ -119,7 +119,9 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
return x


def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
def shard_op(
op, process_mesh=None, in_shard_specs=None, out_shard_specs=None, **kwargs
):
"""
Shard an operation on a process mesh according to its input and output shard specification.

@@ -199,7 +201,7 @@ def shard_op(op, process_mesh=None, in_shard_specs=None, out_shard_specs=None):
else:
out_dims_mappings.append(None)
op = DistributedOperatorHelper(
op, process_mesh, in_dims_mappings, out_dims_mappings
op, process_mesh, in_dims_mappings, out_dims_mappings, kwargs
)
return op

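The keyword is purely additive: existing shard_op calls are unchanged, and extra keywords such as chunk_id ride along in **kwargs. A hedged usage sketch, mirroring the updated interface test later in this diff (process_mesh1, self.linear1, linear0_out, and gelu_out are that test's fixtures and are assumed here rather than rebuilt):

    # Ops created under this wrapper run on process_mesh1 and are tagged
    # as vpp chunk 0; chunk_id travels through **kwargs into the helper.
    gelu = auto.shard_op(
        F.gelu, process_mesh1, [["y", "x", None], None], chunk_id=0
    )
    gelu_out = gelu(linear0_out, approximate=True)

    # A later layer, placed on a sub-mesh of the same process mesh, is
    # tagged as chunk 1 so that it stays distinguishable from chunk 0.
    linear1 = auto.shard_op(
        self.linear1,
        process_mesh1[1],
        out_shard_specs=[["y", None, None]],
        chunk_id=1,
    )
    linear1_out = linear1(gelu_out)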
17 changes: 14 additions & 3 deletions python/paddle/distributed/auto_parallel/static/dist_op.py
@@ -196,8 +196,12 @@ def __str__(self):
partial_dims,
)

str += ", dist_impl idx: {} , dist_impl type {} }}".format(
self.dist_attr.impl_idx, self.dist_attr.impl_type
str += (
", dist_impl idx: {} , dist_impl type: {}, chunk_id: {} }}".format(
self.dist_attr.impl_idx,
self.dist_attr.impl_type,
self.dist_attr.chunk_id,
)
)

return str
@@ -220,12 +224,18 @@ def __deepcopy__(self, memo):

class DistributedOperatorHelper:
def __init__(
self, serial_op, process_mesh, in_dims_mappings, out_dims_mappings
self,
serial_op,
process_mesh,
in_dims_mappings,
out_dims_mappings,
kwargs,
):
self._serial_op = serial_op
self._process_mesh = process_mesh
self._in_dims_mappings = in_dims_mappings
self._out_dims_mappings = out_dims_mappings
self._chunk_id = kwargs["chunk_id"] if "chunk_id" in kwargs else 0

def __call__(self, *args, **kwargs):
tensor_to_dims_mapping = {}
@@ -327,6 +337,7 @@ def __call__(self, *args, **kwargs):
tensor_dist_attr.dims_mapping = dims_mapping
tensor_dist_attr.mark_annotated("dims_mapping")
dist_op.dist_attr.process_mesh = self._process_mesh
dist_op.dist_attr.chunk_id = self._chunk_id
if self._process_mesh is not None:
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
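End to end, a chunk_id=k keyword passed to shard_op is forwarded to DistributedOperatorHelper, stored with a default of 0 when the keyword is absent, and written onto each produced op's dist_attr inside __call__; the extended __str__ output above and the interface test below both surface that value.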
@@ -402,6 +402,7 @@ def __str__(self):
)

str += f", is_parameter: {self.serial_tensor.is_parameter}"
str += f", chunk_id: {self.dist_attr.chunk_id}"

if self.dist_attr.is_annotated("dims_mapping"):
annotated_str = "annotated"
3 changes: 3 additions & 0 deletions test/auto_parallel/test_dist_tensor.py
@@ -71,6 +71,7 @@ def run_dtensor_from_fn(self):
self.assertEqual(result.placements, placements)
else:
dist_attr.dynamic_dims = [0]
dist_attr.chunk_id = 0
self.assertIsInstance(result, paddle.static.Variable)
self.assertEqual(result.shape, (16,))
self.assertEqual(result.dist_attr, dist_attr)
@@ -85,6 +86,7 @@ def run_dtensor_from_fn(self):
self.assertEqual(result.placements, placements)
else:
dist_attr.dynamic_dims = [0]
dist_attr.chunk_id = 0
self.assertIsInstance(result, paddle.static.Variable)
self.assertEqual(result.shape, (16,))
self.assertEqual(result.dist_attr, dist_attr)
@@ -99,6 +101,7 @@ def run_dtensor_from_fn(self):
self.assertEqual(result.placements, placements)
else:
dist_attr.dynamic_dims = [0]
dist_attr.chunk_id = 0
self.assertIsInstance(result, paddle.static.Variable)
self.assertEqual(result.shape, (16,))
self.assertEqual(result.dist_attr, dist_attr)
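The explicit dist_attr.chunk_id = 0 assignments mirror the framework-side default and make the expectation explicit, since TensorDistAttr::operator== now compares chunk_id as well.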
15 changes: 13 additions & 2 deletions test/auto_parallel/test_interface.py
@@ -72,15 +72,21 @@ def forward(self, input):
process_mesh1,
[["y", None, None]],
[[None, "x", None]],
chunk_id=0,
)
linear0_out = linear0(input)

gelu = auto.shard_op(F.gelu, process_mesh1, [["y", "x", None], None])
gelu = auto.shard_op(
F.gelu, process_mesh1, [["y", "x", None], None], chunk_id=0
)
gelu_out = gelu(linear0_out, approximate=True)

auto.shard_tensor(self.linear1.weight, shard_spec=["y", None])
linear1 = auto.shard_op(
self.linear1, process_mesh1[1], out_shard_specs=[["y", None, None]]
self.linear1,
process_mesh1[1],
out_shard_specs=[["y", None, None]],
chunk_id=1,
)
linear1_out = linear1(gelu_out)

@@ -181,6 +187,7 @@ def test_api(self):
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertEqual(dist_op.dist_attr.chunk_id, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(input.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1)
@@ -192,6 +199,7 @@ def test_api(self):
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertEqual(dist_op.dist_attr.chunk_id, 0)
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
linear0_out.name
)
@@ -205,6 +213,7 @@ def test_api(self):
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1)
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertEqual(dist_op.dist_attr.chunk_id, 0)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(
linear0_out.name
@@ -223,6 +232,7 @@ def test_api(self):
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertEqual(dist_op.dist_attr.chunk_id, 1)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_input_dist_attr(gelu_out.name)
self.assertEqual(tensor_dist_attr.process_mesh, process_mesh1[1])
@@ -234,6 +244,7 @@ def test_api(self):
self.assertEqual(dist_op.dist_attr.process_mesh, process_mesh1[1])
self.assertEqual(dist_op.dist_attr.impl_type, "default")
self.assertEqual(dist_op.dist_attr.impl_idx, 0)
self.assertEqual(dist_op.dist_attr.chunk_id, 1)
self.assertTrue(dist_op.dist_attr.is_annotated("process_mesh"))
tensor_dist_attr = dist_op.dist_attr.get_output_dist_attr(
linear1_out.name