[Dygraph]Integration sharding stage2 function #38151

Merged 2 commits on Dec 19, 2021
@@ -16,21 +16,19 @@
#Commit: 8acbec718f3c70a6b9785470bb9e05cd84fc3f8e

import copy
import time
import logging
import numpy as np
from math import inf
from itertools import chain
from functools import reduce
from collections import OrderedDict

import paddle
import paddle.fluid as fluid
from paddle import framework
from paddle.fluid import core
import paddle.distributed as dist
from paddle.optimizer import Optimizer
from paddle.fluid.clip import ClipGradByGlobalNorm
from paddle.distributed.collective import _get_global_group

from ...utils.internal_storage import ParamStorage
from ...meta_parallel.sharding.sharding_utils import Type, device_guard, ShardingClipGrad
@@ -59,14 +57,14 @@ class ShardingOptimizerStage2(Optimizer):
# Feature Notes:
# 1. Unified memory for parameters and parameters.grad to InternalStorage.
# 2. Support the segmentation of optimizer parameters and partial updating of parameters.
# 3. Dynamically adjust training parameters and models
# 3. Dynamically adjust training parameters and models.
# 4. Support offload function.
# 5. Support the establishment of independent communication groups.
# 6. Broadcast_fp16 is not supported now.
def __init__(self,
params,
optim,
group,
group=None,
broadcast_fp16=False,
offload=False,
device="gpu",
@@ -78,13 +76,16 @@ def __init__(self,
self._dtype_rank_params = OrderedDict(
) # {dtype:[param1,param2]} device, rank, params
self._param2rank = {}
self._segment_params = []
self.__segment_params = []
self._rank_buffer_size = {} # {dtype: {rank: numel+alignment}}
self._param2align = {} # {param.name: align}

# Default information
self._optim_defaults = kw
self._optim = optim
self._ori_parameter_list = self._optim._parameter_list
self._ori_param_groups = self._optim._param_groups

assert hasattr(self._optim, "_master_weights"
), "Must use optimizer with _master_weights attribute"
self._local_params = params
@@ -94,8 +95,8 @@ def __init__(self,
filter(lambda x: x.trainable and x.dtype == Type.fp16.value,
self._local_params))) > 0

assert group is not None, "Distributed communication group is must be gived"
self.group = group
group = _get_global_group() if group is None else group
self.world_size = group.nranks
self.rank = group.rank

@@ -119,7 +120,7 @@ def __init__(self,
self._master_params = {}

# Update optimizer parameters and adjust parameter storage and use according to rank.
self.update_opt_status()
self._update_opt_status()

def _generate_master_params(self, trainable_params):
if self.offload:
@@ -137,7 +138,7 @@ def _generate_master_params(self, trainable_params):
self._optim._master_weights[param.name] = paddle.cast(
param, Type.fp32.value)

def update_opt_status(self):
def _update_opt_status(self):
"""Update optimizer status and parameter storage information, and special functions to be developed.
"""
# func 1
@@ -147,12 +148,12 @@ def update_opt_status(self):

# Segement helpers

def segment_params(self):
def _segment_params(self):
"""
Divide all optimizer parameters equally into rank.
"""
if len(self._segment_params) == 0:
self._segment_params, param_lists = [
if len(self.__segment_params) == 0:
self.__segment_params, param_lists = [
[] for _ in range(self.world_size)
], [[] for _ in range(self.world_size)]
sizes = [0] * self.world_size
@@ -165,9 +166,8 @@ def segment_params(self):
sizes[rank] += np.prod(param.shape) if param.trainable else 0

for rank, params in enumerate(param_lists):
# param_group_rank = copy.copy(params)
self._segment_params[rank].extend(params)
return self._segment_params
self.__segment_params[rank].extend(params)
return self.__segment_params

@property
def local_params(self):
@@ -177,7 +177,7 @@ def local_params(self):
def param2rank(self):
"""Map the params to the rank which owns them"""
if len(self._param2rank) == 0:
for rank, params in enumerate(self.segment_params()):
for rank, params in enumerate(self._segment_params()):
for param in params:
self._param2rank[param.name] = rank
return self._param2rank
@@ -271,32 +271,31 @@ def step(self):
"""

if self.offload:
self._optim._parameter_list = [
param for name, param in self._master_params.items()
]
params_list = list(self._master_params.values())
else:
# Synchronize optimizer parameters for the current rank
if len(self.dtype_rank_params.keys(
)) == 1 and Type.fp32.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp32.value][self.rank]
elif len(self.dtype_rank_params.keys(
)) == 1 and Type.fp16.value in self.dtype_rank_params.keys():
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank]
else:
self._optim._parameter_list = self.dtype_rank_params[
Type.fp16.value][self.rank] + self.dtype_rank_params[
Type.fp32.value][self.rank]
params_list = []
for dtype in self.dtype_rank_params.keys():
params_list.extend(self.dtype_rank_params[dtype][self.rank])

params_name_list = list(map(lambda p: p.name, params_list))
if not isinstance(self._optim._param_groups[0], dict):
self._optim._parameter_list = params_list
self._optim._param_groups = params_list
else:
for param_group in self._optim._param_groups:
p_group = []
for param in param_group['params']:
if param.name in params_name_list:
p_group.append(params_list[params_name_list.index(
param.name)])
param_group['params'] = p_group

# Run the optimizer of the current rank step
if self.offload:
with device_guard(self.rank, self.offload_device):
with device_guard(device=self.offload_device):
self._optim.step()

for param in self._optim._parameter_list:
self._master_params[param.name].set_value(param)

dev_id = 0 if paddle.get_device() == "cpu" else int(
paddle.get_device().split(":")[1])

@@ -312,10 +311,11 @@ def step(self):
self._broadcast_params()

# Return full parameters to optimizer parameters
self._optim._parameter_list = self._local_params
self._optim._parameter_list = self._ori_parameter_list
self._optim._param_groups = self._ori_param_groups

def clear_cache(self):
self._segment_params.clear()
def _clear_cache(self):
self.__segment_params.clear()
self._dtype_rank_params.clear()
self._param2rank.clear()

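Note on the step() change above: instead of reassembling _parameter_list by dtype, the optimizer now builds one rank-local params_list, temporarily points the wrapped optimizer at it (covering both plain parameter lists and param-group dicts), runs the inner step, and then restores the original _parameter_list and _param_groups captured in __init__. A minimal standalone sketch of that swap-step-restore pattern follows; it is simplified and rebuilds the group dicts instead of mutating them in place, so it is not the PR's exact code.

def sharded_step(inner_optim, rank_params):
    # Sketch only: run inner_optim.step() as if just this rank's shard existed.
    ori_parameter_list = inner_optim._parameter_list
    ori_param_groups = inner_optim._param_groups
    rank_names = {p.name for p in rank_params}

    if not isinstance(inner_optim._param_groups[0], dict):
        # Plain parameter list: point the optimizer at the local shard only.
        inner_optim._parameter_list = rank_params
        inner_optim._param_groups = rank_params
    else:
        # Grouped parameters: rebuild each group with the locally owned params.
        inner_optim._param_groups = [
            dict(g, params=[p for p in g['params'] if p.name in rank_names])
            for g in ori_param_groups
        ]

    inner_optim.step()

    # Hand the full parameter view back so later calls see every parameter.
    inner_optim._parameter_list = ori_parameter_list
    inner_optim._param_groups = ori_param_groups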
@@ -24,10 +24,12 @@
from itertools import chain
from functools import reduce
from collections import deque
from types import MethodType

import paddle
from paddle import nn
import paddle.distributed as dist
from paddle.distributed.collective import _get_global_group

from ...utils.internal_storage import GradStorage
from ...meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2
@@ -57,7 +59,7 @@ def __init__(
self,
layer,
sharding_optimizer,
group,
group=None,
sync_buffers=False,
pertrain_sync_models=True,
buffer_max_size=2**23, #8MB
@@ -83,13 +85,12 @@ def __init__(
self._accumulate_grads = accumulate_grads

# Communication related attributes
assert group is not None, "Distributed communication group is must be gived"
self._group = group
self._world_size_scaling = 1.0 / self._group.nranks
assert self._group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
self._rank = self._group.rank
group = _get_global_group() if group is None else group
self._world_size_scaling = 1.0 / group.nranks
assert group.nranks > 1, "Training must be distributed, ranks must be greater than 1"
self._rank = group.rank
self._global_root_rank = 0 # picking rank 0 as the reference
self._global_ranks = self._group.ranks
self._default_device = device

# Global statistical parameters
@@ -112,8 +113,8 @@ def __init__(
self._has_grad_storage = []
self._grad_storage_list = []

# offload
# TODO(haohongxiang): Now it's not supported for multi-optimizers using Offload strategy
# Offload
# TODO(haohongxiang): Now it's not be supported for multi-optimizers using Offload strategy
self._offload_optims = list(
filter(lambda optim: optim.offload, self._sharding_optimizers))
if len(self._offload_optims) > 0:
@@ -134,6 +135,11 @@ def __init__(
# Set tasks flow
self._tasks_flow = deque()

# Define optimizer step and clear_grad
if self._accumulate_grads:
self._redefine_opt_step()
self._redefine_opt_clear()

def forward(self, *inputs, **kwargs):
"""
A wrapper for Sharding Stage2 layer.
@@ -161,7 +167,7 @@ def forward(self, *inputs, **kwargs):

return fw

def clear_gradients(self):
def _clear_gradients(self):
"""
Set zero to the gradient of the optimizer's current rank trainable parameters.
"""
@@ -176,7 +182,7 @@ def clear_gradients(self):
if param.name in self._param_grads and param.grad is not None:
param.clear_gradient()

def grad_scale(self):
def _grad_scale(self):
"""
Before the gradient accumulation, scale the gradient.
"""
@@ -287,9 +293,6 @@ def _clear_counters(self):
for grad_storage in self._grad_storage_list:
grad_storage.reset_checked_in()

if not self._accumulate_grads:
self._grads_flipped = False

def _get_reduce_fn(self, index, param, dst_rank):
"""
There are two ways to reduce gradient.
@@ -412,7 +415,6 @@ def _setup_backward_hooks(self):
self._bw_hooks.pop().remove()

# Go through the parameters, attach the hook
self._grad_accs = []
if not self.training:
return

@@ -500,9 +502,6 @@ def _detect_train_change(self):
# Whether parameters trainability changed
trainability_changed = trainable_mask != self._trainable_mask

# The whole model is not trainable but we still have grad hooks
trainability_changed |= not self.training and len(self._bw_hooks) > 0

if trainability_changed:
logging.warning(
"Trainable params changed, because of eval/train mode or parameter freezing/unfreeze."
@@ -548,3 +547,25 @@ def _rank_buffer_size(self, buffer_max_size, model_size):
format(rank_buffer_size[Type.fp32.value] / 2**18, model_size / 2
**18))
return rank_buffer_size

def _redefine_opt_step(self):
if not self._accumulate_grads:
return
grad_func = self._grad_scale
for opt in self._sharding_optimizers:
opt_step = opt.step

def _opt_step(self):
grad_func()
opt_step()

opt.step = MethodType(_opt_step, opt)

def _redefine_opt_clear(self):
clear_func = self._clear_gradients

def _opt_clear(self):
clear_func()

for opt in self._sharding_optimizers:
opt.clear_grad = MethodType(_opt_clear, opt)
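Note on _redefine_opt_step and _redefine_opt_clear above: when gradient accumulation is enabled, step and clear_grad on each wrapped optimizer are rebound with types.MethodType so that gradient scaling and sharded gradient clearing run transparently before and in place of the original calls. A runnable toy example of the rebinding pattern; TinyOptimizer and the printed messages are illustrative only, not Paddle API.

from types import MethodType

class TinyOptimizer:
    def step(self):
        print("inner optimizer step")

def wrap_step_with_grad_scale(opt, grad_scale_fn):
    original_step = opt.step  # capture the bound method before rebinding

    def _opt_step(self):
        grad_scale_fn()        # scale the accumulated gradients first
        original_step()        # then run the original update

    opt.step = MethodType(_opt_step, opt)

opt = TinyOptimizer()
wrap_step_with_grad_scale(opt, lambda: print("scale accumulated grads"))
opt.step()  # -> scale accumulated grads / inner optimizer step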
@@ -131,7 +131,7 @@ def __call__(self, params_grads):


@contextlib.contextmanager
def device_guard(dev_id, device="cpu"):
def device_guard(dev_id=0, device="cpu"):
origin_device = paddle.device.get_device()
if device == "cpu":
paddle.set_device(device)
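Note on device_guard above: dev_id now defaults to 0, so the offload path in step() can call it with only device=self.offload_device. The diff truncates the function body; the following is a hedged sketch of a complete context manager of this shape, with the gpu branch and the restore logic assumed rather than copied from the PR.

import contextlib
import paddle

@contextlib.contextmanager
def device_guard(dev_id=0, device="cpu"):
    # Remember the current device, switch for the duration of the block,
    # then switch back. The gpu branch below is an assumption.
    origin_device = paddle.device.get_device()
    if device == "cpu":
        paddle.set_device(device)
    elif device == "gpu":
        paddle.set_device("gpu:{}".format(dev_id))
    try:
        yield
    finally:
        paddle.set_device(origin_device)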
@@ -125,7 +125,7 @@ def train_mlp():
oss_optimizer.step()

# oss_optimizer clear cache
oss_optimizer.clear_cache()
oss_optimizer._clear_cache()


if __name__ == '__main__':
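For orientation, a rough usage sketch implied by the test above. The module path is inferred from the relative import in the stage-2 wrapper, and the model, AdamW choice, and launch setup are assumptions, so treat this as illustrative rather than verified API.

import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.sharding_optimizer_stage2 import ShardingOptimizerStage2

dist.init_parallel_env()  # run under paddle.distributed.launch, ideally with >1 card

model = paddle.nn.Linear(128, 128)
inner_opt = paddle.optimizer.AdamW(
    learning_rate=1e-3, parameters=model.parameters())

# group=None now falls back to the global communication group (this PR).
oss_optimizer = ShardingOptimizerStage2(
    params=model.parameters(), optim=inner_opt, group=None)

for _ in range(2):
    loss = model(paddle.randn([16, 128])).mean()
    loss.backward()
    oss_optimizer.step()      # each rank updates only the parameters it owns
    model.clear_gradients()

oss_optimizer._clear_cache()  # renamed from clear_cache() in this PR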