Skip to content

Commit

Permalink
[hybrid] fix pipeline section program Parameter (#35847)
Browse files Browse the repository at this point in the history
  • Loading branch information
wangxicoding authored Sep 18, 2021
1 parent 5ba9fe6 commit 67c6363
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def remove_param(input_name):

if out_name in param_name_to_offload_name:
var_name = out_name
# FIXME(wangxi): the offload op should be inserted after the param broadcast
if offload:
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1380,10 +1380,18 @@ def _initialization_broadcast(self):
return

startup_block = self._startup_program.global_block()

params = []
for param in startup_block.iter_parameters():
params.append(param)
params = startup_block.all_parameters()

broadcast_params = []
for param in params:
broadcast_params.append(param)
# optimize_cast also needs the fp16 copy of the param to be broadcast
fp16_param_name = param.name + '.cast_fp16'
if startup_block.has_var(fp16_param_name):
fp16_param = startup_block.var(fp16_param_name)
broadcast_params.append(fp16_param)

for param in broadcast_params:
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
Expand All @@ -1395,8 +1403,8 @@ def _initialization_broadcast(self):
})
startup_block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
inputs={'X': broadcast_params},
outputs={'Out': broadcast_params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})

Expand Down
12 changes: 12 additions & 0 deletions python/paddle/fluid/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4381,6 +4381,18 @@ def _create_vars(self, block, ori_block):
name=var,
type=core.VarDesc.VarType.READER,
persistable=source_var.persistable)
elif isinstance(source_var, Parameter):
dest_var = block.create_parameter(
name=source_var.name,
shape=source_var.shape,
dtype=source_var.dtype,
type=source_var.type,
lod_level=source_var.lod_level,
stop_gradient=source_var.stop_gradient,
trainable=source_var.trainable,
optimize_attr=source_var.optimize_attr,
regularizer=source_var.regularizer,
error_clip=source_var.error_clip)
else:
dest_var = block._clone_variable(source_var, False)
self._clone_var_attr(dest_var, source_var)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def test_opt_sharding_with_pp(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream'
])

Expand Down Expand Up @@ -152,6 +154,8 @@ def test_opt_sharding_with_pp_with_allreduce_fuse(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_sync_comm_stream'
])

Expand Down Expand Up @@ -212,7 +216,9 @@ def test_opt_sharding_with_pp_amp_gclip(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down Expand Up @@ -284,7 +290,9 @@ def test_opt_sharding_with_pp_amp_gclip_fuse_gm(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down Expand Up @@ -376,7 +384,7 @@ def test_opt_sharding_with_pp_amp_gclip_boundary(self):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id',
'c_comm_init', 'c_sync_comm_stream'
'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down Expand Up @@ -427,7 +435,7 @@ def test_opt_sharding_with_pp_amp_gclip_boundary_card1(self):
'uniform_random', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -762,7 +762,9 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down Expand Up @@ -928,7 +930,10 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down Expand Up @@ -1023,7 +1028,11 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_offload(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down Expand Up @@ -1121,7 +1130,10 @@ def test_hybrid_with_pp_dp_amp_fp16allreduce_optimize_cast_with_gradient_fuse(
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_sync_comm_stream'
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down Expand Up @@ -1211,7 +1223,9 @@ def test_hybrid_with_pp_dp_amp_with_gradient_fuse(self):
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_gen_nccl_id', 'c_comm_init',
'c_gen_nccl_id', 'c_comm_init', 'c_sync_comm_stream'
'c_gen_nccl_id', 'c_comm_init', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream'
])

self.assertEqual(main_prog_op_types, [
Expand Down

0 comments on commit 67c6363

Please sign in to comment.