From 39fda14e76184fad1f34994cd4f68e8f7b856dd0 Mon Sep 17 00:00:00 2001 From: Yiqun Liu Date: Fri, 1 Dec 2023 17:02:14 +0800 Subject: [PATCH] Polish bfloat16 main_grad unittest for data parallel and sharding stage1. (#58842) * Polish bfloat16 main_grad unittest for data parallel. * Optimize unittest of sharding stage1. * Polish codes and add check of weights. * Polish unittest for sharding stage1. * Revert some minor changes. * Polish the compare of parameters. * Compute loss in float32. --- test/collective/fleet/dist_amp_base.py | 156 ++++++++++++ .../fleet/dygraph_dataparallel_bf16.py | 207 +++++++-------- .../dygraph_group_sharded_stage1_bf16.py | 240 ++++++++---------- 3 files changed, 350 insertions(+), 253 deletions(-) create mode 100644 test/collective/fleet/dist_amp_base.py diff --git a/test/collective/fleet/dist_amp_base.py b/test/collective/fleet/dist_amp_base.py new file mode 100644 index 00000000000000..e4b972472b6d4d --- /dev/null +++ b/test/collective/fleet/dist_amp_base.py @@ -0,0 +1,156 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import re +from collections import OrderedDict + +import numpy as np + +import paddle +from paddle.distributed.fleet.utils import mix_precision_utils +from paddle.nn import Linear, ReLU + +logging.basicConfig(level="INFO", format="%(message)s") + + +class MLP(paddle.nn.Layer): + def __init__(self, linear_size=1000): + super().__init__() + + self._linear1 = Linear(linear_size, linear_size) + self._linear2 = Linear(linear_size, linear_size) + self._linear3 = Linear(linear_size, 10) + self._relu = ReLU() + + def forward(self, inputs): + y = self._linear1(inputs) + y = self._linear2(y) + y = self._linear3(y) + y = self._relu(y) + return y + + +class RandomDataset(paddle.io.Dataset): + def __init__(self, num_samples=200, linear_size=1000): + self.num_samples = num_samples + self.linear_size = linear_size + self.samples = [] + for i in range(num_samples): + img = np.random.rand(self.linear_size).astype('float32') + self.samples.append(img) + + def __getitem__(self, idx): + return self.samples[idx] + + def __len__(self): + return self.num_samples + + +def create_optimizer(model, use_pure_bf16, use_main_grad): + if use_main_grad: + assert use_pure_bf16 + model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") + optimizer = paddle.optimizer.AdamW( + parameters=model.parameters(), + learning_rate=0.00001, + weight_decay=0.00001, + grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), + multi_precision=use_pure_bf16, + ) + if use_main_grad: + optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) + + return optimizer + + +def save_model_parameters(model): + param_dict = OrderedDict() + for param in model.parameters(): + param_dict[param.name] = param + return param_dict + + +def _extract_linear_order(param_names): + # for param_names from model.state_dict, they are as like: ["_linear1.weight", "_linear1.bias"] + # for master weight names from 
optimizer.state_dict, they are like: ["linear_6.w_0", "linear_6.b_0"]
+    param_order = []
+    for name in param_names:
+        param_id = re.findall(r"\d+", name)
+        assert len(param_id) >= 1
+        param_order.append(int(param_id[0]))
+    return list(set(param_order))
+
+
+def _extract_param_order_dict(model_param_dict_o1, model_param_dict_o2):
+    param_names_o1 = list(model_param_dict_o1.keys())
+    param_order_o1 = _extract_linear_order(param_names_o1)
+    param_order_o1.sort()
+
+    param_names_o2 = list(model_param_dict_o2.keys())
+    param_order_o2 = _extract_linear_order(param_names_o2)
+    param_order_o2.sort()
+
+    assert len(param_order_o1) == len(param_order_o2)
+
+    param_order_dict = {}
+    for i in range(len(param_order_o1)):
+        param_order_dict[param_order_o2[i]] = param_order_o1[i]
+
+    logging.info(f"-- param_names_o1: {param_names_o1}")
+    logging.info(f"-- param_names_o2: {param_names_o2}")
+    logging.info(f"param_order_dict: {param_order_dict}")
+    return param_order_dict
+
+
+def compare_state_dict(
+    model_param_dict_o1, model_param_dict_o2, optimizer_state_dict_o2
+):
+    master_weights = None
+    if optimizer_state_dict_o2.get("master_weights", None) is not None:
+        master_weights = optimizer_state_dict_o2["master_weights"]
+    assert master_weights is not None
+    master_weights_names = list(master_weights.keys())
+
+    param_names = list(model_param_dict_o1.keys())
+    param_order_dict = _extract_param_order_dict(
+        model_param_dict_o1, model_param_dict_o2
+    )
+    param_master_pair = []
+
+    # We assume the order of params in param_names and master_weights_names is the same.
+    param_id = 0
+    for master_weight_name in master_weights_names:
+        master_weight_id = re.findall(r"\d+", master_weight_name)[0]
+        param_id = param_order_dict[int(master_weight_id)]
+        for param_name in param_names:
+            if (
+                master_weight_name.endswith("w_0")
+                and param_name.endswith("weight")
+            ) or (
+                master_weight_name.endswith("b_0")
+                and param_name.endswith("bias")
+            ):
+                name_prefix = "linear" + str(param_id)
+                if name_prefix in param_name:
+                    param_master_pair.append([param_name, master_weight_name])
+
+    logging.info(f"-- master_weights_names: {master_weights_names}")
+    for pair in param_master_pair:
+        param_name = pair[0]
+        master_weight_name = pair[1]
+        logging.info(f"-- compare {param_name} with {master_weight_name}")
+        param_o1 = model_param_dict_o1[param_name]
+        master_param_o2 = master_weights[master_weight_name]
+        np.testing.assert_array_equal(param_o1.numpy(), master_param_o2.numpy())
diff --git a/test/collective/fleet/dygraph_dataparallel_bf16.py b/test/collective/fleet/dygraph_dataparallel_bf16.py
index efc7b6f993d987..6ca31b82b514e7 100644
--- a/test/collective/fleet/dygraph_dataparallel_bf16.py
+++ b/test/collective/fleet/dygraph_dataparallel_bf16.py
@@ -1,5 +1,3 @@
-# -*- coding: UTF-8 -*-
-
 # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -14,74 +12,33 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import logging + import numpy as np +from dist_amp_base import ( + MLP, + RandomDataset, + compare_state_dict, + create_optimizer, + save_model_parameters, +) import paddle -from paddle.distributed.fleet.utils import mix_precision_utils from paddle.distributed.fleet.utils.hybrid_parallel_util import ( fused_allreduce_gradients, ) -from paddle.nn import Linear, ReLU - -seed = 2022 -epoch = 2 -linear_size = 1000 - -np.random.seed(seed) -paddle.seed(seed) - - -class MLP(paddle.nn.Layer): - def __init__(self, linear_size=1000): - super().__init__() - - self._linear1 = Linear(linear_size, linear_size) - self._linear2 = Linear(linear_size, linear_size) - self._linear3 = Linear(linear_size, 10) - self._relu = ReLU() - - def forward(self, inputs): - y = self._linear1(inputs) - y = self._linear2(y) - y = self._linear3(y) - y = self._relu(y) - return y - - -class RandomDataset(paddle.io.Dataset): - def __init__(self, num_samples=200, linear_size=1000): - self.num_samples = num_samples - self.linear_size = linear_size - def __getitem__(self, idx): - img = np.random.rand(self.linear_size).astype('float32') - return img - - def __len__(self): - return self.num_samples - - -def optimizer_setting(model, use_pure_bf16, use_main_grad): - if use_main_grad: - assert use_pure_bf16 - model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") - optimizer = paddle.optimizer.AdamW( - parameters=model.parameters(), - learning_rate=0.00001, - weight_decay=0.00001, - grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), - multi_precision=use_pure_bf16, - ) - if use_main_grad: - optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) - - return optimizer +logging.basicConfig(level="INFO", format="%(message)s") def train_mlp( - model, use_pure_bf16=False, use_main_grad=False, accumulate_grad=False + model, train_loader, use_pure_bf16=False, use_main_grad=False, acc_steps=1 ): - optimizer = optimizer_setting( + logging.info( + f"-- Train Info: use_pure_bf16={use_pure_bf16}, use_main_grad={use_main_grad}, acc_steps={acc_steps}" + ) + + optimizer = create_optimizer( model=model, use_pure_bf16=use_pure_bf16, use_main_grad=use_main_grad ) if use_pure_bf16: @@ -98,19 +55,9 @@ def train_mlp( "matmul_v2", "elementwise_add", "relu", - "reduce_mean", ] model = paddle.DataParallel(model) - paddle.seed(2023) - np.random.seed(2023) - train_loader = paddle.io.DataLoader( - RandomDataset(), - batch_size=100, - shuffle=False, - drop_last=True, - num_workers=0, - ) if not use_pure_bf16: for param in model.parameters(): t = paddle.cast( @@ -118,13 +65,20 @@ def train_mlp( ) param.set_value(t) + local_rank = paddle.distributed.get_rank() + losses = [] + epoch = 2 for eop in range(epoch): model.train() for batch_id, data in enumerate(train_loader()): data.stop_gradient = True + enable_stats = False # eop == 0 + if enable_stats: + logging.info("<<<<<<<<<<<< forward-backward >>>>>>>>>>>") + paddle.amp.debugging.enable_operator_stats_collection() with model.no_sync(): with paddle.amp.auto_cast( True, @@ -133,65 +87,92 @@ def train_mlp( custom_white_list=custom_white_list, ): out = model(data) - loss = paddle.mean(out) - losses.append(loss) + # compute loss in float32 + loss = paddle.mean(out.astype("float32")) - loss.backward() + # normal implementation for gradient accumulation. 
+ if acc_steps != 1: + loss = loss / acc_steps - if not accumulate_grad: + losses.append(loss.item()) + loss.backward() + logging.info( + f"-- [rank={local_rank}] epoch {eop}, batch {batch_id}, loss: {loss.astype(paddle.float32).numpy()}" + ) + if enable_stats: + paddle.amp.debugging.disable_operator_stats_collection() + + if (batch_id + 1) % acc_steps == 0: + if enable_stats: + logging.info( + "<<<<<<<<<<<< fused_allreduce_gradients >>>>>>>>>>>" + ) + paddle.amp.debugging.enable_operator_stats_collection() fused_allreduce_gradients(list(model.parameters()), None) + if enable_stats: + paddle.amp.debugging.disable_operator_stats_collection() + if enable_stats: + logging.info("<<<<<<<<<<<< optimizer >>>>>>>>>>>") + paddle.amp.debugging.enable_operator_stats_collection() optimizer.step() optimizer.clear_grad() + if enable_stats: + paddle.amp.debugging.disable_operator_stats_collection() - if accumulate_grad: - fused_allreduce_gradients(list(model.parameters()), None) - - optimizer.step() - optimizer.clear_grad() - - return losses + model_param_dict = save_model_parameters(model) + optimizer_state_dict = optimizer.state_dict() + return losses, model_param_dict, optimizer_state_dict def test_dp_bf16(): if not paddle.amp.is_bfloat16_supported(): + logging.info("BFloat16 is not supported!") return + paddle.distributed.init_parallel_env() - mlp = MLP() - state_dict = mlp.state_dict() - - # dp bf16 O1 vs dp bf16 O2 main_grad - mlp1 = MLP() - mlp2 = MLP() - mlp1.set_state_dict(state_dict) - mlp2.set_state_dict(state_dict) - losses_o1 = train_mlp(mlp1, use_pure_bf16=False) - losses_o2 = train_mlp(mlp2, use_pure_bf16=True, use_main_grad=True) - for i in range(len(losses_o2)): - loss_o2 = paddle.cast(losses_o2[i], dtype='float32').detach() - loss_o1 = paddle.cast(losses_o1[i], dtype='float32').detach() - np.testing.assert_array_equal(loss_o2, loss_o1) - - # grad accumulation test - mlp3 = MLP() - mlp4 = MLP() - mlp3.set_state_dict(state_dict) - mlp4.set_state_dict(state_dict) - losses_acc_grad_o1 = train_mlp( - mlp3, use_pure_bf16=False, accumulate_grad=True - ) - losses_acc_grad_o2 = train_mlp( - mlp4, use_pure_bf16=True, use_main_grad=True, accumulate_grad=True + local_rank = paddle.distributed.get_rank() + paddle.seed(2023 + local_rank) + np.random.seed(2023 + local_rank) + + # For DataParallel, DataLoader should feed different data for different GPUs. 
+ train_loader = paddle.io.DataLoader( + RandomDataset(), + batch_size=100, + shuffle=False, + drop_last=True, + num_workers=0, ) - for i in range(len(losses_acc_grad_o2)): - loss_acc_grad_o2 = paddle.cast( - losses_acc_grad_o2[i], dtype='float32' - ).detach() - loss_acc_grad_o1 = paddle.cast( - losses_acc_grad_o1[i], dtype='float32' - ).detach() - np.testing.assert_array_equal(loss_acc_grad_o2, loss_acc_grad_o1) + + single_mlp = MLP() + state_dict = single_mlp.state_dict() + + def _compare_bf16_o1_vs_o2(acc_steps=1): + # dp bf16 O1 vs dp bf16 O2 main_grad + mlp1 = MLP() + mlp2 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + losses_o1, model_param_dict_o1, optimizer_state_dict_o1 = train_mlp( + mlp1, train_loader, use_pure_bf16=False, acc_steps=acc_steps + ) + losses_o2, model_param_dict_o2, optimizer_state_dict_o2 = train_mlp( + mlp2, + train_loader, + use_pure_bf16=True, + use_main_grad=True, + acc_steps=acc_steps, + ) + np.testing.assert_array_equal(losses_o2, losses_o1) + compare_state_dict( + model_param_dict_o1, model_param_dict_o2, optimizer_state_dict_o2 + ) + + # no gradient accumulation + _compare_bf16_o1_vs_o2(acc_steps=1) + # gradient accumulation + _compare_bf16_o1_vs_o2(acc_steps=2) if __name__ == '__main__': diff --git a/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py b/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py index 9a69976b830cc6..65135dddbda5a7 100644 --- a/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py +++ b/test/collective/fleet/dygraph_group_sharded_stage1_bf16.py @@ -1,5 +1,3 @@ -# -*- coding: UTF-8 -*- - # Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -14,90 +12,50 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import logging + import numpy as np +from dist_amp_base import ( + MLP, + RandomDataset, + compare_state_dict, + create_optimizer, + save_model_parameters, +) import paddle from paddle.distributed import fleet -from paddle.distributed.fleet.utils import mix_precision_utils -from paddle.nn import Linear, ReLU - -seed = 2022 -epoch = 2 -linear_size = 1000 - -np.random.seed(seed) -paddle.seed(seed) - - -class MLP(paddle.nn.Layer): - def __init__(self, linear_size=1000): - super().__init__() - - self._linear1 = Linear(linear_size, linear_size) - self._linear2 = Linear(linear_size, linear_size) - self._linear3 = Linear(linear_size, 10) - self._relu = ReLU() - - def forward(self, inputs): - y = self._linear1(inputs) - y = self._linear2(y) - y = self._linear3(y) - y = self._relu(y) - return y - - -class RandomDataset(paddle.io.Dataset): - def __init__(self, num_samples=200, linear_size=1000): - self.num_samples = num_samples - self.linear_size = linear_size - - def __getitem__(self, idx): - img = np.random.rand(self.linear_size).astype('float32') - return img - - def __len__(self): - return self.num_samples - - -def optimizer_setting(model, use_pure_bf16, use_main_grad): - if use_main_grad: - assert use_pure_bf16 - model = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16") - optimizer = paddle.optimizer.AdamW( - parameters=model.parameters(), - learning_rate=0.00001, - weight_decay=0.00001, - grad_clip=paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0), - multi_precision=use_pure_bf16, - ) - if use_main_grad: - optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer) - return optimizer +logging.basicConfig(level="INFO", format="%(message)s") def train_mlp( model, sharding_stage, + train_loader, use_pure_bf16=False, - accumulate_grad=False, + acc_steps=1, use_main_grad=False, test_scaler=False, ): - # bf16 not support dynamic loss scaling - # disable dynamic_loss_scaling to coverage distributed_scaler - dynamic_loss_scaling = False + logging.info( + f"-- Train Info: use_pure_bf16={use_pure_bf16}, use_main_grad={use_main_grad}, acc_steps={acc_steps}" + ) + scaler = None - scale_loss = 1024 if test_scaler: assert sharding_stage == 1 - assert not accumulate_grad + assert acc_steps == 1 + # bf16 not support dynamic loss scaling + # disable dynamic_loss_scaling to coverage distributed_scaler + dynamic_loss_scaling = False + scale_loss = 1024 scaler = paddle.amp.GradScaler( init_loss_scaling=scale_loss, use_dynamic_loss_scaling=dynamic_loss_scaling, ) scaler = fleet.distributed_scaler(scaler) - optimizer = optimizer_setting( + optimizer = create_optimizer( model=model, use_pure_bf16=use_pure_bf16, use_main_grad=use_main_grad ) @@ -105,21 +63,13 @@ def train_mlp( if use_pure_bf16: level = 'O2' custom_white_list = None - - amp_configs = { - "init_loss_scaling": scale_loss, - "use_pure_bf16": True, - "use_dynamic_loss_scaling": dynamic_loss_scaling, - } - strategy.amp = True - strategy.amp_configs = amp_configs + model = paddle.amp.decorate(models=model, dtype="bfloat16", level=level) else: level = 'O1' custom_white_list = [ "matmul_v2", "elementwise_add", "relu", - "reduce_mean", ] if sharding_stage == 1: @@ -137,16 +87,6 @@ def train_mlp( if sharding_stage == 1: optimizer = fleet.distributed_optimizer(optimizer) - paddle.seed(2023) - np.random.seed(2023) - train_loader = paddle.io.DataLoader( - RandomDataset(), - batch_size=100, - shuffle=False, - drop_last=True, - num_workers=0, - ) - if sharding_stage == 1: model.to(device="gpu") @@ -157,13 +97,20 @@ def train_mlp( ) param.set_value(t) 
+ local_rank = paddle.distributed.get_rank() + losses = [] + epoch = 2 for eop in range(epoch): model.train() for batch_id, data in enumerate(train_loader()): data.stop_gradient = True + enable_stats = False # eop == 0 + if enable_stats: + logging.info("<<<<<<<<<<<< forward & backward >>>>>>>>>>>") + paddle.amp.debugging.enable_operator_stats_collection() with paddle.amp.auto_cast( True, level=level, @@ -171,57 +118,97 @@ def train_mlp( custom_white_list=custom_white_list, ): out = model(data) - loss = paddle.mean(out) - losses.append(loss) + # compute loss in float32 + loss = paddle.mean(out.astype("float32")) + + # normal implementation for gradient accumulation. + if acc_steps != 1: + loss = loss / acc_steps + + losses.append(loss.item()) + logging.info( + f"-- [rank={local_rank}] epoch {eop}, batch {batch_id}, loss: {loss.astype(paddle.float32).numpy()}" + ) if test_scaler: assert scaler is not None scaler.scale(loss).backward() + if enable_stats: + paddle.amp.debugging.disable_operator_stats_collection() scaler.step(optimizer) scaler.update() optimizer.clear_grad() else: loss.backward() - if not accumulate_grad: + if enable_stats: + paddle.amp.debugging.disable_operator_stats_collection() + if (batch_id + 1) % acc_steps == 0: + if enable_stats: + logging.info("<<<<<<<<<<<< optimizer >>>>>>>>>>>") + paddle.amp.debugging.enable_operator_stats_collection() optimizer.step() optimizer.clear_grad() + if enable_stats: + paddle.amp.debugging.disable_operator_stats_collection() - if accumulate_grad: - optimizer.step() - optimizer.clear_grad() - - return losses + model_param_dict = save_model_parameters(model) + optimizer_state_dict = optimizer.state_dict() + return losses, model_param_dict, optimizer_state_dict def test_stage1_bf16(): if not paddle.amp.is_bfloat16_supported(): + logging.info("BFloat16 is not supported!") return + paddle.distributed.init_parallel_env() + local_rank = paddle.distributed.get_rank() + paddle.seed(2023 + local_rank) + np.random.seed(2023 + local_rank) + + # For Sharding, DataLoader should feed different data for different GPUs. 
+ train_loader = paddle.io.DataLoader( + RandomDataset(), + batch_size=100, + shuffle=False, + drop_last=True, + num_workers=0, + ) mlp = MLP() state_dict = mlp.state_dict() - # stage1 bf16 O1 vs stage1 bf16 O2 main_grad - mlp1 = MLP() - mlp2 = MLP() - mlp1.set_state_dict(state_dict) - mlp2.set_state_dict(state_dict) - o1_losses = train_mlp( - mlp1, - sharding_stage=1, - use_pure_bf16=False, - ) - o2_losses = train_mlp( - mlp2, - sharding_stage=1, - use_pure_bf16=True, - use_main_grad=True, - ) - for i in range(len(o1_losses)): - o1_32_loss = paddle.cast(o1_losses[i], dtype='float32').detach() - o2_32_loss = paddle.cast(o2_losses[i], dtype='float32').detach() - np.testing.assert_array_equal(o1_32_loss, o2_32_loss) + def _compare_bf16_o1_vs_o2(acc_steps=1): + # stage1 bf16 O1 vs stage1 bf16 O2 main_grad + mlp1 = MLP() + mlp2 = MLP() + mlp1.set_state_dict(state_dict) + mlp2.set_state_dict(state_dict) + o1_losses, model_param_dict_o1, optimizer_state_dict_o1 = train_mlp( + mlp1, + sharding_stage=1, + train_loader=train_loader, + use_pure_bf16=False, + acc_steps=acc_steps, + ) + o2_losses, model_param_dict_o2, optimizer_state_dict_o2 = train_mlp( + mlp2, + sharding_stage=1, + train_loader=train_loader, + use_pure_bf16=True, + use_main_grad=True, + acc_steps=acc_steps, + ) + np.testing.assert_array_equal(o2_losses, o1_losses) + compare_state_dict( + model_param_dict_o1, model_param_dict_o2, optimizer_state_dict_o2 + ) + + # no gradient accumulation + _compare_bf16_o1_vs_o2(acc_steps=1) + # gradient accumulation + _compare_bf16_o1_vs_o2(acc_steps=2) # stage1 scaler test with main_grad mlp3 = MLP() @@ -229,6 +216,7 @@ def test_stage1_bf16(): train_mlp( mlp3, sharding_stage=1, + train_loader=train_loader, use_pure_bf16=True, use_main_grad=True, test_scaler=True, @@ -240,40 +228,12 @@ def test_stage1_bf16(): train_mlp( mlp4, sharding_stage=1, + train_loader=train_loader, use_pure_bf16=True, use_main_grad=False, test_scaler=True, ) - # grad accumulation test - mlp5 = MLP() - mlp6 = MLP() - mlp5.set_state_dict(state_dict) - mlp6.set_state_dict(state_dict) - o1_losses_grad_acc = train_mlp( - mlp5, - sharding_stage=1, - use_pure_bf16=False, - accumulate_grad=True, - ) - o2_losses_grad_acc = train_mlp( - mlp6, - sharding_stage=1, - use_pure_bf16=True, - use_main_grad=True, - accumulate_grad=True, - ) - for i in range(len(o2_losses_grad_acc)): - o2_loss_grad_acc = paddle.cast( - o2_losses_grad_acc[i], dtype='float32' - ).detach() - o1_loss_grad_acc = paddle.cast( - o1_losses_grad_acc[i], dtype='float32' - ).detach() - np.testing.assert_array_equal(o2_loss_grad_acc, o1_loss_grad_acc) - - return - if __name__ == '__main__': test_stage1_bf16()
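Note: the core mechanism both unittests exercise is pure-bf16 (O2) training with fp32 main_grad and fp32 master weights. Below is a minimal single-process sketch, not part of the patch; the Linear stand-in, tensor shapes, and hyper-parameters are illustrative assumptions, while the mix_precision_utils, paddle.amp, and AdamW calls are the same ones used in the patch:

    import paddle
    from paddle.distributed.fleet.utils import mix_precision_utils
    from paddle.nn import Linear

    # Stand-in model; the tests use the 3-layer MLP from dist_amp_base.py.
    model = Linear(1000, 10)
    # Accumulate bf16 gradients into fp32 main_grad buffers.
    wrapped = mix_precision_utils.MixPrecisionLayer(model, dtype="bfloat16")
    optimizer = paddle.optimizer.AdamW(
        parameters=wrapped.parameters(),
        learning_rate=0.00001,
        multi_precision=True,  # keep fp32 master weights in the optimizer
    )
    optimizer = mix_precision_utils.MixPrecisionOptimizer(optimizer)
    # Cast parameters to bf16 for pure-bf16 (O2) training.
    model = paddle.amp.decorate(models=model, dtype="bfloat16", level="O2")

    data = paddle.rand([100, 1000], dtype="float32")
    with paddle.amp.auto_cast(True, level="O2", dtype="bfloat16"):
        out = model(data)
    # Compute the loss in float32, as the patch does, so the O1 and O2 runs
    # produce losses that can be compared with assert_array_equal.
    loss = paddle.mean(out.astype("float32"))
    loss.backward()
    optimizer.step()
    optimizer.clear_grad()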
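Both test scripts call paddle.distributed.init_parallel_env() and assert on each rank that the O1 and O2 runs produce identical losses and matching parameters/master weights, so they are meant to be run under the distributed launcher on bf16-capable GPUs, for example: python -m paddle.distributed.launch --gpus "0,1" dygraph_dataparallel_bf16.py (likewise for dygraph_group_sharded_stage1_bf16.py).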