optimizer sharding add optimize_cast
wangxicoding committed Sep 18, 2021
1 parent 67c6363 commit b51cd14
Showing 2 changed files with 232 additions and 24 deletions.
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
from paddle.fluid import core, unique_name
from .shard import Shard

__all__ = []

@@ -23,11 +25,8 @@ class OffloadHelper(object):
cuda_place_type = 1
cuda_pinned_place_type = 2

def __init__(self):
pass
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
def __init__(self, ring_id=None):
self.ring_id = ring_id

def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
@@ -50,6 +49,20 @@ def _insert_cast_op(self, block, idx, src_name, dst_name):
OP_ROLE_KEY: OpRole.Optimize
})

def _insert_broadcast_op(self, block, idx, param):
if self.ring_id is None: return
block._insert_op_without_sync(
idx,
type="c_broadcast",
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
})

def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
@@ -206,20 +219,25 @@ def remove_param(input_name):

# step5: startup_block add offload
visited_vars = set()
# FIXME(wangxi): should insert at idx; need to move comm init to the head.
insert_idx = len(startup_block.ops)
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue

if out_name in param_name_to_offload_name:
var_name = out_name
# FIXME(wangxi): offload should insert after broadcast param
if offload:
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1,
self._insert_offload_op(startup_block, insert_idx,
var_name, offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])
# NOTE(wangxi): cast and offload should be inserted after the param broadcast.
# The op insertion order is: broadcast, cast, offload
self._insert_broadcast_op(startup_block, insert_idx,
var_name)

visited_vars.add(out_name)

@@ -303,3 +321,181 @@ def offload(self, block, startup_block):

block._sync_with_cpp()
startup_block._sync_with_cpp()

def opt_sharding_cast_fp32param(self,
block,
startup_block,
params,
offload=False):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(pout,) = adam(p)
(p_fp16) = cast(p)
broadcast(p_fp16)
"""
global_params = set()
local_params = set()
param_to_fp16 = dict()
# recompute_var that needs to be renamed to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()

def remove_param(input_name):
global_params.remove(input_name)
if input_name in local_params:
local_params.remove(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)

# step1: record param
global_params = set(params)
for idx, op in reversed(list(enumerate(block.ops))):
if is_update_op(op):
param = op.desc.input("Param")[0]
local_params.add(param)

# step2: remove params which can't be offloaded, and
# record param->fp16param, fp16param->recompute_var
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
# TODO (Yuang Liu): tmp solution for fuse_grad_merge + optimize_cast
if op.type == 'coalesce_tensor':
continue
for input_name in op.desc.input_arg_names():
if input_name not in global_params:
continue

# param which will be used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue

# param is only used by a cast op,
# which casts fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue

if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param

param_name_to_offload_name = dict()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if is_update_op(op):
param = op.desc.input("Param")[0]
if param not in global_params: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
if offload:
self._create_offload_var(param, offload_var_name,
[block, startup_block])

# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param,
offload_var_name)

assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])

if offload:
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue

# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in global_params:
block._remove_op(idx, sync=False)
continue

# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])

# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)

# step5: remove fp32 params which are not needed
for idx, op in enumerate(block.ops):
if op.type not in ['coalesce_tensor', 'c_broadcast']:
continue
for input_name in op.desc.input_arg_names():
if input_name in param_to_fp16:
op._rename_input(input_name, param_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in param_to_fp16:
op._rename_output(output_name, param_to_fp16[output_name])

for param in global_params:
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True

if param not in local_params:
block._remove_var(param, sync=False)

# step6: startup_block add offload
visited_vars = set()
insert_idx = len(startup_block.ops)
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue

if out_name in param_to_fp16:
var_name = out_name
# FIXME(wangxi): offload should insert after broadcast param
if offload:
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1,
var_name, offload_var_name)

self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])

self._insert_broadcast_op(startup_block, insert_idx,
var_name)

if var_name not in local_params:
param = startup_block.var(out_name)
param.persistable = False

visited_vars.add(out_name)

block._sync_with_cpp()
startup_block._sync_with_cpp()
@@ -329,6 +329,7 @@ def _insert_allreduce_for_pp(self, params_grads):
if self.pp_degree == 1: return

strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs

main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()
@@ -399,18 +400,21 @@ def _insert_allreduce_for_pp(self, params_grads):
first_optimize_op_index += (len(main_block.ops) - len_of_ops)
len_of_ops = len(main_block.ops)

optimize_cast = sharding_configs['optimize_cast']
optimizer_param = utils.insert_broadcast_param_ops(
main_block,
len_of_ops,
self.dp_ring_id, [x[0].name for x in params_grads],
self.dp_ring_id,
[x[0].name for x in params_grads],
self._shard,
OpRole.Optimize,
use_calc_stream=True,
rank=self.dp_rank,
strategy=strategy)
# fuse should be disabled when optimize_cast is enabled
strategy=None if optimize_cast else strategy)
logger.info("Optimizer param in this rank {}".format(
optimizer_param))
if not strategy.fuse_grad_merge:
if not strategy.fuse_grad_merge and not optimize_cast:
assert len(accumulated_grad_names) == len(optimizer_param)
elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
insert_allreduce_ops(
@@ -458,27 +462,35 @@ def _insert_loss_grad_scale_op(self):

main_block._sync_with_cpp()

def _apply_optimize_offload_pass(self):
def _apply_optimize_offload_pass(self, params_grads):
strategy = self.user_defined_strategy
sharding_configs = strategy.sharding_configs
main_block = self._main_program.global_block()
startup_block = self._startup_program.global_block()

dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None

# optimize offload should only be enabled when gradient merge is enabled and
# acc_step is quite large (e.g. >> 100), since its memcpy cannot be
# overlapped with compute; otherwise it will severely slow down training.
if sharding_configs["optimize_offload"]:
logger.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper = OffloadHelper(ring_id=dp_ring_id)
offload_helper.offload(main_block, startup_block)
# The optimize_cast is already included in offload_fp32param
offload_helper.offload_fp32param(main_block, startup_block)
elif sharding_configs['optimize_cast']:
logger.info("Sharding with optimize cast !")
# NOTE(wangxi): optimize_cast will persist fp16 param, it
# will take more memory, but will be faster. Trade space for time.
offload_helper = OffloadHelper()
offload_helper.cast_fp32param_in_optimize(main_block, startup_block)
offload_helper = OffloadHelper(ring_id=dp_ring_id)
if self._optimizer_sharding:
offload_helper.opt_sharding_cast_fp32param(
main_block, startup_block,
[x[0].name for x in params_grads])
else:
offload_helper.cast_fp32param_in_optimize(main_block,
startup_block)

def _dump_program_for_debug(self):
main_block = self._main_program.global_block()
@@ -1382,25 +1394,25 @@ def _initialization_broadcast(self):
startup_block = self._startup_program.global_block()
params = startup_block.all_parameters()

broadcast_params = []
# offload and optimize_cast will insert broadcast op
broadcast_params = set()
for op in startup_block.ops:
if op.type == 'c_broadcast':
broadcast_params.add(op.desc.output_arg_names[0])

for param in params:
broadcast_params.append(param)
# optimize_cast need broadcast fp16 param
fp16_param_name = param.name + '.cast_fp16'
if startup_block.has_var(fp16_param_name):
fp16_param = startup_block.var(fp16_param_name)
broadcast_params.append(fp16_param)

for param in broadcast_params:
if param.name in broadcast_params: continue
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.dp_ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})

startup_block.append_op(
type='c_sync_comm_stream',
inputs={'X': broadcast_params},
