@@ -68,34 +68,40 @@ def forward(self, logit, label, semantic_weights=None):
68
68
raise ValueError (
69
69
'The number of weights = {} must be the same as the number of classes = {}.'
70
70
.format (len (self .weight ), logit .shape [1 ]))
71
- if channel_axis == 1 :
72
- logit = paddle .transpose (logit , [0 , 2 , 3 , 1 ])
73
- if self .weight is None :
74
- loss = F .cross_entropy (
75
- logit , label , ignore_index = self .ignore_index , reduction = 'none' )
76
- else :
77
- label_one_hot = F .one_hot (label , logit .shape [- 1 ])
78
- loss = F .cross_entropy (
79
- logit ,
80
- label_one_hot * self .weight ,
81
- soft_label = True ,
82
- reduction = 'none' )
83
- loss = loss .squeeze (- 1 )
71
+
72
+ logit = paddle .transpose (logit , [0 , 2 , 3 , 1 ])
73
+ loss = F .cross_entropy (
74
+ logit ,
75
+ label ,
76
+ ignore_index = self .ignore_index ,
77
+ reduction = 'none' ,
78
+ weight = self .weight )
84
79
85
80
mask = label != self .ignore_index
86
81
mask = paddle .cast (mask , 'float32' )
82
+
87
83
loss = loss * mask
88
84
if semantic_weights is not None :
89
85
loss = loss * semantic_weights
90
86
87
+ if self .weight is not None :
88
+ _one_hot = F .one_hot (label , logit .shape [- 1 ])
89
+ coef = paddle .sum (_one_hot * self .weight , axis = - 1 )
90
+ else :
91
+ coef = paddle .ones_like (label )
92
+
91
93
label .stop_gradient = True
92
94
mask .stop_gradient = True
95
+
93
96
if self .top_k_percent_pixels == 1.0 :
94
- avg_loss = paddle .mean (loss ) / (paddle .mean (mask ) + self .EPS )
97
+ avg_loss = paddle .mean (loss ) / (paddle .mean (mask * coef ) + self .EPS )
95
98
return avg_loss
96
99
97
100
loss = loss .reshape ((- 1 , ))
98
101
top_k_pixels = int (self .top_k_percent_pixels * loss .numel ())
99
- loss , _ = paddle .topk (loss , top_k_pixels )
102
+ loss , indices = paddle .topk (loss , top_k_pixels )
103
+ coef = coef .reshape ((- 1 , ))
104
+ coef = paddle .gather (coef , indices )
105
+ coef .stop_gradient = True
100
106
101
- return loss .mean ()
107
+ return loss .mean () / ( paddle . mean ( coef ) + self . EPS )
0 commit comments