
Commit 7f0b174

Fix CrossEntropyLoss bug (#1100)
1 parent 1e4bba5

1 file changed (+22 -16 lines)

paddleseg/models/losses/cross_entropy_loss.py
@@ -68,34 +68,40 @@ def forward(self, logit, label, semantic_weights=None):
             raise ValueError(
                 'The number of weights = {} must be the same as the number of classes = {}.'
                 .format(len(self.weight), logit.shape[1]))
-        if channel_axis == 1:
-            logit = paddle.transpose(logit, [0, 2, 3, 1])
-        if self.weight is None:
-            loss = F.cross_entropy(
-                logit, label, ignore_index=self.ignore_index, reduction='none')
-        else:
-            label_one_hot = F.one_hot(label, logit.shape[-1])
-            loss = F.cross_entropy(
-                logit,
-                label_one_hot * self.weight,
-                soft_label=True,
-                reduction='none')
-            loss = loss.squeeze(-1)
+
+        logit = paddle.transpose(logit, [0, 2, 3, 1])
+        loss = F.cross_entropy(
+            logit,
+            label,
+            ignore_index=self.ignore_index,
+            reduction='none',
+            weight=self.weight)
 
         mask = label != self.ignore_index
         mask = paddle.cast(mask, 'float32')
+
         loss = loss * mask
         if semantic_weights is not None:
             loss = loss * semantic_weights
 
+        if self.weight is not None:
+            _one_hot = F.one_hot(label, logit.shape[-1])
+            coef = paddle.sum(_one_hot * self.weight, axis=-1)
+        else:
+            coef = paddle.ones_like(label)
+
         label.stop_gradient = True
         mask.stop_gradient = True
+
         if self.top_k_percent_pixels == 1.0:
-            avg_loss = paddle.mean(loss) / (paddle.mean(mask) + self.EPS)
+            avg_loss = paddle.mean(loss) / (paddle.mean(mask * coef) + self.EPS)
             return avg_loss
 
         loss = loss.reshape((-1, ))
         top_k_pixels = int(self.top_k_percent_pixels * loss.numel())
-        loss, _ = paddle.topk(loss, top_k_pixels)
+        loss, indices = paddle.topk(loss, top_k_pixels)
+        coef = coef.reshape((-1, ))
+        coef = paddle.gather(coef, indices)
+        coef.stop_gradient = True
 
-        return loss.mean()
+        return loss.mean() / (paddle.mean(coef) + self.EPS)
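The patch replaces the manual one-hot soft-label weighting with F.cross_entropy's weight argument and then normalizes the reduced loss by the mean per-pixel coefficient coef = weight[label], in both the plain-mean and top-k branches, so the result is a true weighted average, sum(w_i * nll_i) / sum(w_i), rather than a mean that grows when the class weights are scaled up. Below is a minimal sketch (not part of the commit) of both patched paths, assuming Paddle 2.x semantics where F.cross_entropy with weight and reduction='none' returns a label-shaped tensor of weight[label] * nll values; the shapes, the example weights, and the 50% top-k ratio are illustrative, and the ignore_index mask is omitted for brevity.

import paddle
import paddle.nn.functional as F

paddle.seed(0)
EPS = 1e-8
num_classes = 3
logit = paddle.randn([2, 4, 4, num_classes])       # NHWC, as after the transpose
label = paddle.randint(0, num_classes, [2, 4, 4])
weight = paddle.to_tensor([1.0, 2.0, 0.5])

# With `weight` passed, the 'none' reduction yields weight[label] * nll per pixel.
loss = F.cross_entropy(logit, label, weight=weight, reduction='none')

# Per-pixel coefficient weight[label], computed as in the patch.
coef = paddle.sum(F.one_hot(label, num_classes) * weight, axis=-1)

# Mean path: dividing by mean(coef) turns the biased mean into a weighted
# average that is invariant to rescaling all class weights.
avg_loss = paddle.mean(loss) / (paddle.mean(coef) + EPS)

# The bug being fixed: doubling every class weight doubles both loss and
# coef, so the normalized average stays the same.
loss2 = F.cross_entropy(logit, label, weight=2 * weight, reduction='none')
avg_loss2 = paddle.mean(loss2) / (paddle.mean(2 * coef) + EPS)
print(float(avg_loss), float(avg_loss2))  # the two values agree closely

# Top-k path: gather the coefficients of the kept pixels so the same
# normalization applies to the truncated loss.
top_k_pixels = int(0.5 * loss.numel())
loss_flat, indices = paddle.topk(loss.reshape((-1, )), top_k_pixels)
coef_flat = paddle.gather(coef.reshape((-1, )), indices)
topk_loss = loss_flat.mean() / (paddle.mean(coef_flat) + EPS)

The gather in the top-k branch matters because topk selects pixels by weighted loss, so heavily weighted classes are over-represented among the kept pixels; without dividing by the mean of their coefficients, loss.mean() would over-count them.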
