File tree 1 file changed +5
-3
lines changed
1 file changed +5
-3
lines changed Original file line number Diff line number Diff line change @@ -99,13 +99,15 @@ def get_reverse_list(ori_shape, transforms):
99
99
def reverse_transform (pred , ori_shape , transforms , mode = 'nearest' ):
100
100
"""recover pred to origin shape"""
101
101
reverse_list = get_reverse_list (ori_shape , transforms )
102
+ intTypeList = [paddle .int8 , paddle .int16 , paddle .int32 , paddle .int64 ]
103
+ dtype = pred .dtype
102
104
for item in reverse_list [::- 1 ]:
103
105
if item [0 ] == 'resize' :
104
106
h , w = item [1 ][0 ], item [1 ][1 ]
105
- if paddle .get_device () == 'cpu' :
106
- pred = paddle .cast (pred , 'uint8 ' )
107
+ if paddle .get_device () == 'cpu' and dtype in intTypeList :
108
+ pred = paddle .cast (pred , 'float32 ' )
107
109
pred = F .interpolate (pred , (h , w ), mode = mode )
108
- pred = paddle .cast (pred , 'int32' )
110
+ pred = paddle .cast (pred , dtype )
109
111
else :
110
112
pred = F .interpolate (pred , (h , w ), mode = mode )
111
113
elif item [0 ] == 'padding' :
You can’t perform that action at this time.
0 commit comments