Adapt more amp uts in PIR (#62880)
0x45f authored Mar 21, 2024
1 parent 714ddbe commit 984b284
Showing 4 changed files with 307 additions and 7 deletions.
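
The new PIR tests added here all follow the same static-graph recipe: build the startup and main programs under paddle.pir_utils.IrGuard(), wrap the forward pass in paddle.amp.auto_cast, and run with a static Executor on a CUDA place. A minimal sketch of that recipe, condensed from the tests below (the single-conv network and input shape are illustrative, not taken from the diff):

import numpy as np

import paddle

if paddle.is_compiled_with_cuda():
    with paddle.pir_utils.IrGuard():
        startup = paddle.static.Program()
        main = paddle.static.Program()
        with paddle.static.program_guard(main, startup):
            conv = paddle.nn.Conv2D(3, 2, 3)
            x = paddle.static.data('x', [10, 3, 32, 32], 'float32')
            # Under O2, supported ops such as conv2d run in float16.
            with paddle.amp.auto_cast(enable=True, level='O2'):
                out = conv(x)

        exe = paddle.static.Executor(paddle.CUDAPlace(0))
        exe.run(startup)
        exe.run(
            main,
            feed={'x': np.random.random([10, 3, 32, 32]).astype('float32')},
            fetch_list=[out],
        )
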
8 changes: 4 additions & 4 deletions test/amp/amp_base_models.py
@@ -21,7 +21,7 @@
import paddle
from paddle import nn
from paddle.base import core
-from paddle.framework import in_dynamic_mode
+from paddle.framework import in_dynamic_or_pir_mode


def copy_bits_from_float_to_uint16(f):
@@ -68,7 +68,7 @@ def _build_optimizer(
grad_clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
else:
grad_clip = None
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
assert model is not None
parameters = model.parameters()
else:
@@ -82,7 +82,7 @@ def _build_optimizer(
epsilon=1e-4,
weight_decay=0.01,
)
-    if not in_dynamic_mode() and use_amp:
+    if not in_dynamic_or_pir_mode() and use_amp:
optimizer = paddle.static.amp.decorate(
optimizer,
amp_lists,
@@ -178,7 +178,7 @@ def forward(self, x):
def build_conv_model(
use_amp, amp_dtype="float16", amp_level="O1", use_promote=False
):
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
model = SimpleConvNet()
optimizer = _build_optimizer(use_amp=False, model=model)
if use_amp and amp_dtype == "float16":
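
Because _build_optimizer and build_conv_model now branch on in_dynamic_or_pir_mode(), they take the dynamic-mode path under PIR as well: the model supplies the optimizer's parameters and the static-graph amp.decorate step is skipped. A hedged sketch of how the new PIR tests consume the helper (the (model, optimizer, scaler) return triple is inferred from the call site in TestPirAmpPromoteStats below, not from this hunk):

import paddle
from amp_base_models import build_conv_model

with paddle.pir_utils.IrGuard():
    startup = paddle.static.Program()
    main = paddle.static.Program()
    with paddle.static.program_guard(main, startup):
        # in_dynamic_or_pir_mode() is True here, so the helper builds the
        # model itself and pulls parameters from it, as in dynamic mode.
        model, optimizer, scaler = build_conv_model(
            use_amp=True,
            amp_dtype="float16",
            amp_level="O2",
            use_promote=True,
        )
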
141 changes: 141 additions & 0 deletions test/amp/test_amp_promote.py
@@ -183,6 +183,100 @@ def test_o2_promote_off(self):
)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 7.0,
"run test when gpu's compute capability is at least 7.0.",
)
class TestPirAmpPromoteStats(AmpTestBase):
def check_promote_results(
self, dtype, level, use_promote, expected_op_calls, debug_info
):
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
model, optimizer, scaler = build_conv_model(
use_amp=True,
amp_dtype=dtype,
amp_level=level,
use_promote=use_promote,
)
model.train()

with paddle.amp.auto_cast(
enable=True,
dtype=dtype,
level=level,
use_promote=use_promote,
):
x = paddle.static.data(
'x', shape=[1, 1, 6, 6], dtype='float32'
)
out = model(x)
loss = paddle.mean(out)
scaled = scaler.scale(loss)
scaler.minimize(optimizer, scaled)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
paddle.amp.debugging.enable_operator_stats_collection()
exe.run(
main,
feed={
'x': np.random.random([1, 1, 6, 6]).astype('float32'),
},
fetch_list=[loss],
)
paddle.amp.debugging.disable_operator_stats_collection()
op_stats = paddle.base.core.get_low_precision_op_list()

self._check_op_calls(
op_stats,
expected_fp16_calls=expected_op_calls,
debug_info=debug_info,
)

def test_o2_promote_on(self):
paddle.set_flags({"FLAGS_pir_apply_inplace_pass": 0})
expected_fp16_calls = {
"pd_op.conv2d": 1,
"pd_op.add": 2,
"pd_op.relu": 0,
"pd_op.matmul": 1,
"pd_op.softmax": 1,
"pd_op.mean": 1,
"pd_op.adamw_": 4,
}
self.check_promote_results(
'float16',
'O2',
use_promote=True,
expected_op_calls=expected_fp16_calls,
debug_info="TestEagerAmpPromoteStats/test_o2_promote_on",
)

def test_o2_promote_off(self):
paddle.set_flags({"FLAGS_pir_apply_inplace_pass": 0})
expected_fp16_calls = {
"pd_op.conv2d": 1,
"pd_op.add": 2,
"pd_op.relu": 1,
"pd_op.matmul": 1,
"pd_op.softmax": 1,
"pd_op.mean": 1,
"pd_op.adamw_": 4,
}
self.check_promote_results(
'float16',
'O2',
use_promote=False,
expected_op_calls=expected_fp16_calls,
debug_info="TestEagerAmpPromoteStats/test_o2_promote_off",
)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 7.0,
@@ -220,5 +314,52 @@ def test_o2_use_promote_off(self):
self.assertEqual(linear_out.dtype, paddle.float16)


@unittest.skipIf(
not core.is_compiled_with_cuda()
or paddle.device.cuda.get_device_capability()[0] < 7.0,
"run test when gpu's compute capability is at least 7.0.",
)
class TestPirAmpPromoteSimple(AmpTestBase):
def init_net(self):
self._conv = paddle.nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=3, bias_attr=False
)
self._linear = paddle.nn.Linear(in_features=4, out_features=4)

def test_o2_use_promote_on(self):
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
self.init_net()
with paddle.amp.auto_cast(level='O2'):
x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
conv_out = self._conv(x)
y = paddle.rand(shape=conv_out.shape, dtype='float16')
add_out = conv_out + y
linear_out = self._linear(add_out)

self.assertEqual(conv_out.dtype, paddle.float16)
self.assertEqual(add_out.dtype, paddle.float16)
self.assertEqual(linear_out.dtype, paddle.float32)

def test_o2_use_promote_off(self):
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
self.init_net()
with paddle.amp.auto_cast(level='O2', use_promote=False):
x = paddle.rand(shape=[1, 1, 6, 6], dtype='float32')
conv_out = self._conv(x)
y = paddle.rand(shape=conv_out.shape, dtype='float16')
add_out = conv_out + y
linear_out = self._linear(add_out)

self.assertEqual(conv_out.dtype, paddle.float16)
self.assertEqual(add_out.dtype, paddle.float16)
self.assertEqual(linear_out.dtype, paddle.float16)


if __name__ == '__main__':
unittest.main()
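
The expected_fp16_calls tables above are checked against paddle.base.core.get_low_precision_op_list() by AmpTestBase._check_op_calls, which is not part of this diff. A hedged stand-in for that check, assuming each stats entry is the comma-separated counter string that test_collect_operator_stats.py parses, with the first field counting float16 calls:

def check_op_calls(op_stats, expected_fp16_calls):
    # Hypothetical sketch of AmpTestBase._check_op_calls; the real helper
    # lives outside this diff and may differ in details.
    for op_name, expected in expected_fp16_calls.items():
        # Assumed entry layout: "fp16_calls,bf16_calls,fp32_calls,other_calls".
        fp16_calls = int(op_stats[op_name].split(',')[0])
        assert fp16_calls == expected, (
            f"{op_name}: expected {expected} float16 calls, got {fp16_calls}"
        )
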
85 changes: 84 additions & 1 deletion test/amp/test_collect_operator_stats.py
@@ -14,6 +14,7 @@

import unittest

import numpy as np
from amp_base_models import build_while_model

import paddle
@@ -38,7 +39,7 @@ def _check_result(self, dtype):
self.assertTrue(conv_num == 1)
self.assertTrue(add_num == 1)

if dtype == "float16":
if dtype == paddle.float16:
self.assertTrue(int(conv2d_called[0]) == 1)
self.assertTrue(int(add_called[0]) == 1)

@@ -67,6 +68,88 @@ def test_context(self):
self._check_result(dtype=out.dtype)


class TestOpStatsPir(unittest.TestCase):
def _check_result(self, dtype):
        # get_low_precision_op_list() returns a dict of op name -> call counters.
op_list = paddle.base.core.get_low_precision_op_list()

self.assertTrue('pd_op.add' in op_list)
self.assertTrue('pd_op.conv2d' in op_list)

conv2d_called = op_list['pd_op.conv2d'].split(',')
add_called = op_list['pd_op.add'].split(',')
add_num = 0
conv_num = 0
for i in range(4):
add_num += int(add_called[i])
            conv_num += int(conv2d_called[i])

self.assertTrue(conv_num == 1)
self.assertTrue(add_num == 1)

if dtype == paddle.float16:
self.assertTrue(int(conv2d_called[0]) == 1)
self.assertTrue(int(add_called[0]) == 1)

def test_enable_disable(self):
if not paddle.is_compiled_with_cuda():
return
paddle.set_flags({"FLAGS_pir_apply_inplace_pass": 0})
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.static.data('x', [10, 3, 32, 32], 'float32')

with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
paddle.amp.debugging.enable_operator_stats_collection()
exe.run(
main,
feed={
'x': np.random.random([10, 3, 32, 32]).astype(
'float32'
),
},
fetch_list=[out],
)
paddle.amp.debugging.disable_operator_stats_collection()
self._check_result(dtype=out.dtype)

def test_context(self):
if not paddle.is_compiled_with_cuda():
return
paddle.set_flags({"FLAGS_pir_apply_inplace_pass": 0})
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
conv = paddle.nn.Conv2D(3, 2, 3)
x = paddle.static.data('x', [10, 3, 32, 32], 'float32')
with paddle.amp.auto_cast(enable=True, level='O2'):
out = conv(x)

place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
with paddle.amp.debugging.collect_operator_stats():
exe.run(
main,
feed={
'x': np.random.random([10, 3, 32, 32]).astype(
'float32'
),
},
fetch_list=[out],
)
self._check_result(dtype=out.dtype)


class TestOpStatsStatic(unittest.TestCase):
def test_while_op(self):
paddle.enable_static()
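
For comparison with the PIR executor runs above, the same collect_operator_stats context manager also works in dynamic mode; a minimal dygraph sketch (the layer and input shape are illustrative):

import paddle

if paddle.is_compiled_with_cuda():
    conv = paddle.nn.Conv2D(3, 2, 3)
    x = paddle.rand([10, 3, 32, 32])
    # Collection is enabled on entry and the gathered per-op precision
    # stats are reported when the context exits.
    with paddle.amp.debugging.collect_operator_stats():
        with paddle.amp.auto_cast(enable=True, level='O2'):
            out = conv(x)
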
80 changes: 78 additions & 2 deletions test/amp/test_compare_accuracy_api.py
@@ -14,14 +14,17 @@

import unittest

import numpy as np

import paddle
from paddle.base import core


@unittest.skipIf(
not core.is_compiled_with_cuda(), "not support cpu TestCompareAccuracyApi"
not core.is_compiled_with_cuda(),
"not support cpu TestEagerCompareAccuracyApi",
)
-class TestCompareAccuracyApi(unittest.TestCase):
+class TestEagerCompareAccuracyApi(unittest.TestCase):
def calc(self, path, dtype):
paddle.base.core.set_nan_inf_debug_path(path)
x = paddle.to_tensor(
@@ -67,5 +70,78 @@ def test2(self):
)


@unittest.skipIf(
not core.is_compiled_with_cuda(),
"not support cpu TestPirCompareAccuracyApi",
)
class TestPirCompareAccuracyApi(unittest.TestCase):
def calc(self, path, dtype):
paddle.base.core.set_nan_inf_debug_path(path)
with paddle.pir_utils.IrGuard():
startup = paddle.static.Program()
main = paddle.static.Program()
with paddle.static.program_guard(main, startup):
x = paddle.static.data(
'x',
[
4,
],
dtype,
)
y = paddle.static.data(
'y',
[
4,
],
dtype,
)
# normal
z1 = x + y
# inf
z2 = x * y
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
exe.run(startup)
exe.run(
main,
feed={
'x': np.array([2000, 3000, 4, 0]).astype(dtype),
'y': np.array([100, 500, 2, 10000]).astype(dtype),
},
fetch_list=[z2],
)

def test(self):
paddle.set_flags(
{"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 3}
)
fp32_path = "workerlog_fp32_log_dir"
fp16_path = "workerlog_fp16_log_dir"
self.calc(fp32_path, "float32")
self.calc(fp16_path, "float16")

out_excel = "compare_accuracy_out_excel.csv"
paddle.amp.debugging.compare_accuracy(
fp32_path,
fp16_path,
out_excel,
loss_scale=1,
dump_all_tensors=False,
)

def test2(self):
fp32_path = "workerlog_fp32_log_dir"
fp16_path = "workerlog_fp16_null_log_dir"
self.calc(fp32_path, "float32")
out_excel = "compare_accuracy_out_excel_2.csv"
paddle.amp.debugging.compare_accuracy(
fp32_path,
fp16_path,
out_excel,
loss_scale=1,
dump_all_tensors=False,
)


if __name__ == '__main__':
unittest.main()
