Skip to content

Commit bee5166

Browse files
authored
fix a bug when scaling images on cpu (#1405)
1 parent d9ac194 commit bee5166

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

paddleseg/core/infer.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -99,13 +99,15 @@ def get_reverse_list(ori_shape, transforms):
9999
def reverse_transform(pred, ori_shape, transforms, mode='nearest'):
100100
"""recover pred to origin shape"""
101101
reverse_list = get_reverse_list(ori_shape, transforms)
102+
intTypeList = [paddle.int8, paddle.int16, paddle.int32, paddle.int64]
103+
dtype = pred.dtype
102104
for item in reverse_list[::-1]:
103105
if item[0] == 'resize':
104106
h, w = item[1][0], item[1][1]
105-
if paddle.get_device() == 'cpu':
106-
pred = paddle.cast(pred, 'uint8')
107+
if paddle.get_device() == 'cpu' and dtype in intTypeList:
108+
pred = paddle.cast(pred, 'float32')
107109
pred = F.interpolate(pred, (h, w), mode=mode)
108-
pred = paddle.cast(pred, 'int32')
110+
pred = paddle.cast(pred, dtype)
109111
else:
110112
pred = F.interpolate(pred, (h, w), mode=mode)
111113
elif item[0] == 'padding':

0 commit comments

Comments
 (0)