Commit

fix param persistable
wangxicoding committed Sep 18, 2021
1 parent b51cd14 commit 044ec8b
Showing 2 changed files with 39 additions and 6 deletions.
@@ -346,9 +346,9 @@ def opt_sharding_cast_fp32param(self,
         recompute_to_fp16 = dict()
 
         def remove_param(input_name):
-            global_params.pop(input_name)
+            global_params.remove(input_name)
             if input_name in local_params:
-                local_params.pop(input_name)
+                local_params.remove(input_name)
             if input_name in param_to_fp16:
                 fp16_param = param_to_fp16.pop(input_name)
                 if fp16_param in fp16_param_to_recompute:
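The first hunk swaps dict-style pop() for remove(). A minimal illustration of the difference, assuming global_params and local_params are plain Python lists or sets of parameter names rather than dicts (the values below are made up):

# Illustrative snippet, not Paddle code: by-value deletion needs remove().
global_params = ["fc_0.w_0", "fc_0.b_0"]

# list.pop expects an integer index, so global_params.pop("fc_0.w_0") raises
# TypeError, and set.pop takes no argument at all; remove(value) deletes an
# element by value from either container.
global_params.remove("fc_0.w_0")
assert global_params == ["fc_0.b_0"]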
@@ -537,7 +537,7 @@ def minimize_impl(self,
         self._insert_loss_grad_scale_op()
 
         # apply optimize offload or optimize cast
-        self._apply_optimize_offload_pass()
+        self._apply_optimize_offload_pass(params_grads)
 
         # step6: (optional) sharding gradient merge
         self._sharding_gradient_merge()
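The minimize_impl change threads the optimizer's (param, grad) pairs into the offload/cast pass. A hedged sketch of why such a pass wants that list; apply_optimize_offload_pass and _Var here are stand-ins, not Paddle's actual implementation:

class _Var:
    # Stand-in for a framework variable; only the name matters here.
    def __init__(self, name):
        self.name = name

def apply_optimize_offload_pass(params_grads):
    # A cast/offload pass should only touch parameters that actually receive
    # gradients, which is exactly what the (param, grad) pairs encode.
    return [p.name for p, g in params_grads if g is not None]

params_grads = [(_Var("fc_0.w_0"), _Var("fc_0.w_0@GRAD")),
                (_Var("frozen.w_0"), None)]
print(apply_optimize_offload_pass(params_grads))  # ['fc_0.w_0']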
@@ -1394,11 +1394,44 @@ def _initialization_broadcast(self):
         startup_block = self._startup_program.global_block()
         params = startup_block.all_parameters()
 
+        # NOTE(wangxi): if a param is not persistable, program.clone will
+        # fail, so we remove the non-persistable param and re-add it as a var
+        for param in params:
+            if not param.persistable:
+                name = param.name
+                shape = param.shape
+                dtype = param.dtype
+                type = param.type
+                lod_level = param.lod_level
+                stop_gradient = param.stop_gradient
+                trainable = param.trainable
+                optimize_attr = param.optimize_attr
+                regularizer = param.regularizer
+
+                have_dist_attr = False
+                is_distributed = False
+                if hasattr(param, 'is_distributed'):
+                    have_dist_attr = True
+                    is_distributed = param.is_distributed
+
+                startup_block._remove_var(name, sync=False)
+                var = startup_block.create_var(
+                    name=name,
+                    shape=shape,
+                    dtype=dtype,
+                    type=type,
+                    lod_level=lod_level,
+                    stop_gradient=stop_gradient,
+                    trainable=trainable,
+                    persistable=False)
+                if have_dist_attr:
+                    var.is_distributed = is_distributed
+
         # offload and optimize_cast will insert broadcast op
         broadcast_params = set()
         for op in startup_block.ops:
             if op.type == 'c_broadcast':
-                broadcast_params.add(op.desc.output_arg_names[0])
+                broadcast_params.add(op.desc.output_arg_names()[0])
 
         for param in params:
             if param.name in broadcast_params: continue
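The block added above works around program.clone failing on non-persistable parameters: it snapshots the parameter's attributes, removes it from the startup block, and re-creates it as an ordinary variable. The same hunk also calls op.desc.output_arg_names() as a method instead of indexing the bound method object. Below is a self-contained sketch of the snapshot-and-recreate pattern using toy classes; ToyBlock, ToyVar and ToyParameter are hypothetical stand-ins for Paddle's Block/Variable/Parameter:

class ToyVar:
    # Plain variable: just a name, a shape and a persistable flag.
    def __init__(self, name, shape, persistable=False):
        self.name, self.shape, self.persistable = name, shape, persistable

class ToyParameter(ToyVar):
    # Parameter-like variable carrying extra training attributes.
    def __init__(self, name, shape, persistable=True, trainable=True):
        super().__init__(name, shape, persistable)
        self.trainable = trainable

class ToyBlock:
    def __init__(self):
        self.vars = {}
    def create_parameter(self, name, shape, persistable=True):
        self.vars[name] = ToyParameter(name, shape, persistable)
        return self.vars[name]
    def create_var(self, name, shape, persistable=False):
        self.vars[name] = ToyVar(name, shape, persistable)
        return self.vars[name]
    def remove_var(self, name):
        del self.vars[name]

block = ToyBlock()
block.create_parameter("fc_0.w_0", (128, 128), persistable=False)

# Snapshot attributes, drop the parameter, re-add it as a plain variable,
# mirroring the diff's _remove_var / create_var sequence.
for var in [v for v in block.vars.values() if isinstance(v, ToyParameter)]:
    if not var.persistable:
        name, shape = var.name, var.shape
        block.remove_var(name)
        block.create_var(name, shape, persistable=False)

print(type(block.vars["fc_0.w_0"]).__name__)  # ToyVar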
@@ -1415,8 +1448,8 @@ def _initialization_broadcast(self):

         startup_block.append_op(
             type='c_sync_comm_stream',
-            inputs={'X': broadcast_params},
-            outputs={'Out': broadcast_params},
+            inputs={'X': params},
+            outputs={'Out': params},
             attrs={'ring_id': self.dp_ring_id,
                    OP_ROLE_KEY: OpRole.Forward})
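The last hunk makes c_sync_comm_stream cover the full params list instead of the broadcast_params name set, presumably because the loop above inserts broadcasts for the parameters that were not already covered, so synchronizing only on the pre-collected set would miss them. A small illustration of how the two collections can diverge (stand-in objects, not Paddle types):

class _Param:
    def __init__(self, name):
        self.name = name

params = [_Param("fc_0.w_0"), _Param("fc_0.b_0")]

# Names harvested from existing c_broadcast outputs; offload/optimize_cast
# inserted a broadcast for only one of the two parameters.
broadcast_params = {"fc_0.w_0"}

not_covered = [p.name for p in params if p.name not in broadcast_params]
print(not_covered)  # ['fc_0.b_0']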

