[auto parallel] shard optimizer enhance (#59575)
FeixLiu authored Dec 1, 2023
1 parent 39fda14 commit 91fa5ff
Showing 5 changed files with 54 additions and 2 deletions.
1 change: 1 addition & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -2304,6 +2304,7 @@
  output : Tensor(param_out), Tensor(master_param_out)
  infer_meta :
    func : SgdInferMeta
    spmd_rule : SgdInferSpmd
  kernel :
    func : sgd {dense, dense, dense, dense -> dense, dense},
           sgd_dense_param_sparse_grad {dense, dense, selected_rows, dense -> dense, dense},
44 changes: 44 additions & 0 deletions paddle/phi/infermeta/spmd_rules/optimizer.cc
@@ -217,5 +217,49 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param,
use_global_beta_pow);
}

SpmdInfo SgdInferSpmd(const DistMetaTensor& param,
                      const DistMetaTensor& learning_rate,
                      const DistMetaTensor& grad,
                      const DistMetaTensor& master_param,
                      bool multi_precision) {
  SpmdInfo param_grad_spmd = ElementwiseBinaryInferSpmd(param, grad);
  TensorDistAttr param_dist_attr_spmd =
      PADDLE_GET(TensorDistAttr, param_grad_spmd.first[0]);
  TensorDistAttr grad_dist_attr_spmd =
      PADDLE_GET(TensorDistAttr, param_grad_spmd.first[1]);

  VLOG(3) << "The source dims mapping for param is: "
          << auto_parallel::str_join(param.dist_attr().dims_mapping());
  VLOG(3) << "The source dims mapping for grad is: "
          << auto_parallel::str_join(grad.dist_attr().dims_mapping());
  VLOG(3) << "The inter dims mapping for param (master param if available) "
          << "after elementwise spmd is: "
          << auto_parallel::str_join(param_dist_attr_spmd.dims_mapping());
  VLOG(3) << "The inter dims mapping for grad after elementwise spmd is: "
          << auto_parallel::str_join(grad_dist_attr_spmd.dims_mapping());

  TensorDistAttr param_dist_attr =
      CopyTensorDistAttrForOutput(param_dist_attr_spmd);
  TensorDistAttr grad_dist_attr =
      CopyTensorDistAttrForOutput(grad_dist_attr_spmd);
  TensorDistAttr lr_dist_attr =
      CopyTensorDistAttrForOutput(learning_rate.dist_attr());
  TensorDistAttr master_param_dist_attr =
      master_param.initialized()
          ? CopyTensorDistAttrForOutput(master_param.dist_attr())
          : TensorDistAttr();
  param_dist_attr.set_dims_mapping(param_dist_attr_spmd.dims_mapping());
  grad_dist_attr.set_dims_mapping(grad_dist_attr_spmd.dims_mapping());
  if (master_param.initialized()) {
    master_param_dist_attr.set_dims_mapping(
        param_dist_attr_spmd.dims_mapping());
  }
  lr_dist_attr.set_dims_mapping(learning_rate.dist_attr().dims_mapping());

  return {
      {param_dist_attr, lr_dist_attr, grad_dist_attr, master_param_dist_attr},
      {param_dist_attr, master_param_dist_attr}};
}

} // namespace distributed
} // namespace phi
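For intuition, here is a minimal pure-Python sketch (not Paddle code, not part of this commit) of the dims-mapping propagation that SgdInferSpmd performs: the param and grad mappings are unified by the elementwise rule, the master param (when present) and both outputs follow the unified mapping, and the learning rate keeps its own mapping. The names merge_dims_mapping and sgd_infer_spmd are illustrative assumptions, and the conflict handling is simplified to "fall back to replicated".

def merge_dims_mapping(a, b):
    # -1 means "replicated"; keep the sharded choice when only one side shards a dim.
    assert len(a) == len(b), "param and grad are assumed to have the same rank"
    merged = []
    for x, y in zip(a, b):
        if x == y:
            merged.append(x)
        elif x == -1 or y == -1:
            merged.append(x if y == -1 else y)
        else:
            merged.append(-1)  # conflicting shardings: simplified to replicated here
    return merged

def sgd_infer_spmd(param_mapping, grad_mapping, lr_mapping, has_master_param):
    unified = merge_dims_mapping(param_mapping, grad_mapping)
    master = unified if has_master_param else None
    return {
        "inputs": {
            "param": unified,
            "learning_rate": lr_mapping,  # lr keeps its own (typically scalar) mapping
            "grad": unified,
            "master_param": master,
        },
        "outputs": {"param_out": unified, "master_param_out": master},
    }

# Example: param sharded along mesh dim 0 on its first axis, grad replicated.
print(sgd_infer_spmd([0, -1], [-1, -1], [], True))
# -> param, grad, master_param and both outputs all end up with dims mapping [0, -1]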
6 changes: 6 additions & 0 deletions paddle/phi/infermeta/spmd_rules/optimizer.h
@@ -60,5 +60,11 @@ SpmdInfo AdamwInferSpmdDynamic(const DistMetaTensor& param,
                               bool multi_precision,
                               bool use_global_beta_pow);

SpmdInfo SgdInferSpmd(const DistMetaTensor& param,
                      const DistMetaTensor& learning_rate,
                      const DistMetaTensor& grad,
                      const DistMetaTensor& master_param,
                      bool multi_precision);

} // namespace distributed
} // namespace phi
4 changes: 2 additions & 2 deletions python/paddle/distributed/auto_parallel/api.py
@@ -414,8 +414,8 @@ def __init__(self, optimizer, shard_fn=None):
            optimizer is not None
        ), "The argument `optimizer` cannot be empty."
        assert isinstance(
            optimizer, paddle.optimizer.AdamW
        ), "`paddle.distributed.ShardOptimizer` only supports AdamW optimizer for now."
            optimizer, (paddle.optimizer.AdamW, paddle.optimizer.SGD)
        ), "`paddle.distributed.ShardOptimizer` only supports AdamW and SGD optimizer for now."

        self.target_block = (
            paddle.base.framework.default_main_program().global_block()
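As a quick, hedged illustration of what the relaxed check enables (mirroring the pattern in the test file below, not part of the commit itself), a plain SGD optimizer can now be wrapped with dist.shard_optimizer. The layer and data are placeholders, and the snippet assumes it runs under the usual semi-auto-parallel test harness or launcher.

import paddle
import paddle.distributed as dist

layer = paddle.nn.Linear(16, 16)
opt = paddle.optimizer.SGD(learning_rate=0.1, parameters=layer.parameters())
opt = dist.shard_optimizer(opt)  # previously only AdamW passed the isinstance assert

for _ in range(2):
    image = paddle.randn([4, 16])
    loss = layer(image).mean()
    loss.backward()
    opt.step()
    opt.clear_grad()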
1 change: 1 addition & 0 deletions test/auto_parallel/semi_auto_parallel_simple_net.py
@@ -130,6 +130,7 @@ def run_dynamic(self, layer, shard_input=False, is_pp=False):
        opt = paddle.optimizer.SGD(
            learning_rate=0.1, parameters=layer.parameters()
        )
        opt = dist.shard_optimizer(opt)
        for _ in range(5):
            image, label = self.init_input_data()
            if shard_input:
