From 044ec8b5ee45a60e912966c8de0f085d30a86f08 Mon Sep 17 00:00:00 2001
From: WangXi
Date: Sat, 18 Sep 2021 20:03:22 +0800
Subject: [PATCH] fix param persistable

---
 .../sharding/offload_helper.py                |  4 +-
 .../meta_optimizers/sharding_optimizer.py     | 41 +++++++++++++++++--
 2 files changed, 39 insertions(+), 6 deletions(-)

diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
index 1011a495fdc879..524c2e1d7d9a97 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding/offload_helper.py
@@ -346,9 +346,9 @@ def opt_sharding_cast_fp32param(self,
         recompute_to_fp16 = dict()
 
         def remove_param(input_name):
-            global_params.pop(input_name)
+            global_params.remove(input_name)
             if input_name in local_params:
-                local_params.pop(input_name)
+                local_params.remove(input_name)
             if input_name in param_to_fp16:
                 fp16_param = param_to_fp16.pop(input_name)
                 if fp16_param in fp16_param_to_recompute:
diff --git a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
index ded804b28452c9..1bf3a51aa055be 100755
--- a/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
+++ b/python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
@@ -537,7 +537,7 @@ def minimize_impl(self,
         self._insert_loss_grad_scale_op()
 
         # apply optimize offload or optimize cast
-        self._apply_optimize_offload_pass()
+        self._apply_optimize_offload_pass(params_grads)
 
         # step6: (optional) sharding gradient merge
         self._sharding_gradient_merge()
@@ -1394,11 +1394,44 @@ def _initialization_broadcast(self):
         startup_block = self._startup_program.global_block()
         params = startup_block.all_parameters()
 
+        # NOTE(wangxi): if a param is not persistable, program.clone() will
+        # fail, so remove the non-persistable param and re-add it as a var
+        for param in params:
+            if not param.persistable:
+                name = param.name
+                shape = param.shape
+                dtype = param.dtype
+                type = param.type
+                lod_level = param.lod_level
+                stop_gradient = param.stop_gradient
+                trainable = param.trainable
+                optimize_attr = param.optimize_attr
+                regularizer = param.regularizer
+
+                have_dist_attr = False
+                is_distributed = False
+                if hasattr(param, 'is_distributed'):
+                    have_dist_attr = True
+                    is_distributed = param.is_distributed
+
+                startup_block._remove_var(name, sync=False)
+                var = startup_block.create_var(
+                    name=name,
+                    shape=shape,
+                    dtype=dtype,
+                    type=type,
+                    lod_level=lod_level,
+                    stop_gradient=stop_gradient,
+                    trainable=trainable,
+                    persistable=False)
+                if have_dist_attr:
+                    var.is_distributed = is_distributed
+
         # offload and optimize_cast will insert broadcast op
         broadcast_params = set()
         for op in startup_block.ops:
             if op.type == 'c_broadcast':
-                broadcast_params.add(op.desc.output_arg_names[0])
+                broadcast_params.add(op.desc.output_arg_names()[0])
         for param in params:
             if param.name in broadcast_params:
                 continue
@@ -1415,8 +1448,8 @@ def _initialization_broadcast(self):
 
         startup_block.append_op(
             type='c_sync_comm_stream',
-            inputs={'X': broadcast_params},
-            outputs={'Out': broadcast_params},
+            inputs={'X': params},
+            outputs={'Out': params},
             attrs={'ring_id': self.dp_ring_id,
                    OP_ROLE_KEY: OpRole.Forward})
 
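
Note (editorial, not part of the patch): the pop -> remove change in offload_helper.py makes sense if global_params and local_params are Python sets, which the switch itself implies. set.pop() takes no argument and removes an arbitrary element, so set.pop(input_name) raises TypeError, while set.remove(input_name) drops exactly the named entry. A minimal standalone sketch with hypothetical parameter names:

# Minimal sketch, assuming global_params/local_params are Python sets,
# as the pop -> remove change implies. Parameter names are hypothetical.
global_params = {"fc_0.w_0", "fc_0.b_0"}
local_params = {"fc_0.w_0"}

def remove_param(input_name):
    # set.pop() accepts no argument (it removes an arbitrary element),
    # so set.remove(input_name) is the correct call for a named element.
    global_params.remove(input_name)
    if input_name in local_params:
        local_params.remove(input_name)

remove_param("fc_0.w_0")
assert "fc_0.w_0" not in global_params
assert "fc_0.w_0" not in local_params

The larger change in sharding_optimizer.py follows the NOTE in the diff: a non-persistable parameter makes program.clone() fail, so _initialization_broadcast removes it from the startup block and re-creates it as a plain non-persistable variable before the broadcast and c_sync_comm_stream ops are appended.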