[Auto Parallel] Improve the dist op interface and the compatible computation #39014

Merged (11 commits) on Jan 20, 2022
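
At a glance, the core interface change in this PR (inferred from the completion.py diff below): find_best_compatible_distributed_operator_impl now takes the DistributedOperator itself and returns a single implementation object that carries its own type and idx, instead of an (impl, idx) pair selected per op type. A self-contained toy mirroring that shape, with hypothetical stand-ins (ToyImpl, find_impl) rather than Paddle APIs:

    # Hypothetical stand-ins mirroring the new single-object interface; not Paddle code.
    class ToyImpl:
        type = "matmul"  # implementation type, e.g. "matmul", "elementwise", "default"
        idx = 0          # index of this implementation within its container

        def update_dims_mapping(self, dist_op):
            return False  # True would signal that the dims mapping was changed

        def is_auto_compatible(self, dist_op):
            return True

    def find_impl(dist_op, fwd=True):
        # Stand-in for find_best_compatible_distributed_operator_impl(dist_op, fwd=...).
        return ToyImpl()

    impl = find_impl(dist_op=None)
    assert impl is not None, "Cannot find the dist op implementation."
    # As in the diff below, elementwise impls are recorded under the "default" impl_type.
    impl_type = "default" if impl.type == "elementwise" else impl.type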
72 changes: 24 additions & 48 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -353,30 +353,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), dist_op, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
@@ -399,30 +387,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), dist_op, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=False)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
return changed


92 changes: 90 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -61,6 +61,8 @@ def __init__(self, program=None):
# Other data members
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
self._serial_ordered_nodes = []
self._tensor_id_to_tensor_node_ids = {}

# Distributed programs
self._dist_main_programs = {}
@@ -80,6 +82,10 @@ def serial_program(self, program):
"This distributed context has already been realted to a serial program"
self._serial_program = program

@property
def serial_ordered_nodes(self):
return self._serial_ordered_nodes

@property
def process_meshes(self):
return self._process_meshes
@@ -186,6 +192,18 @@ def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
else:
return None

# def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
# assert serial_tensor_node.is_var() and \
# serial_tensor_node.var() is not None
# serial_tensor_id = serial_tensor_node.node.original_desc_id()
# dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
# assert dist_tensor is not None, \
# "The distributed tensor of the program has not been added to this context."
# serial_tensor_node_id = serial_tensor_node.id()
# new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
# dist_attr)
# self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor

def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
@@ -218,6 +236,35 @@ def get_op_dist_attr_for_graph(self, serial_op_node):
else:
return None

# def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
# assert serial_op_node.is_op() and \
# serial_op_node.op() is not None
# serial_op_id = serial_op_node.node.original_desc_id()
# dist_op = self._dist_ops_for_program.get(serial_op_id, None)
# assert dist_op is not None, \
# "The distributed operator of the program has not been added to this context."
# serial_op_node_id = serial_op_node.id()
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None

def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
@@ -248,6 +295,44 @@ def init_dist_attr_for_program(self):
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True

def order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if node.id() == target_node.id():
return True
return False

ordered_tensor_nodes = []
ordered_op_nodes = []
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
ordered_tensor_nodes.append(node)
if node.is_op() and node.op() is not None:
ordered_op_nodes.append(node)
ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id())
for op_node in ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
self._serial_ordered_nodes.extend(tensor_nodes)
num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes)
assert len(self._serial_ordered_nodes) == num_nodes_before, \
"The number of nodes before ordering is not the same after ordering."

def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
@@ -257,7 +342,8 @@ def init_dist_attr_for_graph(self):
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
self.order_nodes_by_program_order()
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
tensor_id = node.node.original_desc_id()
@@ -397,7 +483,9 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs":
if k == "_serial_program" or k == "_serial_graph" \
or k == "_dist_main_programs" or k == "_dist_startup_programs" \
or k == "_serial_ordered_nodes":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
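
The __deepcopy__ override above shallow-copies a fixed set of attributes (the serial program, the graph, the ordered node list, and the per-rank programs, which presumably should stay shared references) and deep-copies everything else. A generic, self-contained sketch of this selective-copy pattern, using a hypothetical Context class rather than DistributedContext itself:

    import copy

    class Context:
        # Attribute names that are shared rather than duplicated on deepcopy.
        _SHALLOW_KEYS = {"_serial_program", "_serial_graph", "_serial_ordered_nodes",
                         "_dist_main_programs", "_dist_startup_programs"}

        def __init__(self, program=None):
            self._serial_program = program   # shared on deepcopy
            self._process_meshes = []        # deep-copied

        def __deepcopy__(self, memo):
            cls = self.__class__
            result = cls.__new__(cls)
            memo[id(self)] = result
            for k, v in self.__dict__.items():
                if k in self._SHALLOW_KEYS:
                    setattr(result, k, v)                   # keep the same reference
                else:
                    setattr(result, k, copy.deepcopy(v, memo))
            return result

    ctx = Context(program=object())
    clone = copy.deepcopy(ctx)
    assert clone._serial_program is ctx._serial_program
    assert clone._process_meshes is not ctx._process_meshes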
5 changes: 3 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_op.py
@@ -98,7 +98,7 @@ def _init_default_dist_attr(self):
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
self._dist_attr.impl_idx = 0
if self._dist_attr.is_recompute is None:
self._dist_attr.is_recompute = False

@@ -217,7 +217,8 @@ def __str__(self):

str += ", pipeline stage: {}".format(None)

str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
str += ", dist_impl idx: {} , dist_impl type {} }}".format(
self.dist_attr._impl_idx, self.dist_attr._impl_type)

return str

@@ -23,5 +23,6 @@
from . import dist_softmax
from . import dist_transpose
from . import dist_default
from . import dist_eltwise
from . import dist_check_finite_and_unscale
from . import dist_update_loss_scaling