
Commit 8cef847

fix the bug in sigmoid_cross_entropy_with_logits_grad_kernel (#64253)
* fix the bug in sigmoid_cross_entropy_with_logits_grad_kernel
* add the test for sigmoid_cross_entropy_with_logits_grad
* modify the test
* modify the grad kernel
* modify the kernel
1 parent 78be56a · commit 8cef847

4 files changed: +111 −6 lines

paddle/phi/kernels/cpu/sigmoid_cross_entropy_with_logits_grad_kernel.cc

+9-2
@@ -47,8 +47,15 @@ void SigmoidCrossEntropyWithLogitsGradKernel(
     if (static_cast<int>(label) == ignore_index) {
       dx_data[idx] = static_cast<T>(0.);
     } else {
-      T simoid_x = static_cast<T>(1) / (static_cast<T>(1) + std::exp(-x));
-      T diff = simoid_x * pos_weight_idx - label;
+      T term1 = (x > 0) ? static_cast<T>(1) : static_cast<T>(0);
+
+      T e_x = std::exp(-std::abs(x));
+      T down = 1 + e_x;
+      T abs_grad = (x >= 0) ? static_cast<T>(1) : static_cast<T>(-1);
+      T up = -e_x * abs_grad * pos_weight_idx;
+      T term3 = up / down;
+
+      T diff = term1 - label + term3;
       dx_data[idx] = dout * diff;
     }
   }
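As context (not part of the commit), here is a minimal NumPy sketch of what the rewritten kernel body computes, assuming the composite forward used by the new test below, out = max(x, 0) - x * label + pos_weight * log(1 + exp(-|x|)). The new expression is the derivative of that forward and only ever evaluates exp(-|x|), which cannot overflow, while the old sigmoid-based form disagrees with a finite-difference check once pos_weight != 1:

import numpy as np

def forward(x, label, pos_weight):
    # composite forward from the new test's fn_comp (ignore_index handling omitted)
    return np.maximum(x, 0) - x * label + pos_weight * np.log1p(np.exp(-np.abs(x)))

def grad_old(x, label, pos_weight):
    # previous kernel body: sigmoid(x) * pos_weight - label
    return pos_weight / (1.0 + np.exp(-x)) - label

def grad_new(x, label, pos_weight):
    # new kernel body: term1 - label + term3, evaluating only exp(-|x|)
    term1 = np.where(x > 0, 1.0, 0.0)
    e_x = np.exp(-np.abs(x))
    abs_grad = np.where(x >= 0, 1.0, -1.0)   # derivative of |x|
    term3 = -e_x * abs_grad * pos_weight / (1.0 + e_x)
    return term1 - label + term3

x, label, w, eps = 2.0, 0.0, 0.5, 1e-6
fd = (forward(x + eps, label, w) - forward(x - eps, label, w)) / (2 * eps)
print(fd, grad_new(x, label, w), grad_old(x, label, w))
# ~0.9404   ~0.9404 (matches)   ~0.4404 (old form is off when pos_weight != 1)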

paddle/phi/kernels/gpu/sigmoid_cross_entropy_with_logits_grad_kernel.cu

+8-3
@@ -73,9 +73,14 @@ struct SigmoidBwdPosWeightFunctor {
       dx_data = static_cast<T>(0.);
       counts = 0;
     } else {
-      T simoid_x =
-          static_cast<T>(1) / (static_cast<T>(1) + phi::funcs::real_exp(-x));
-      T diff = simoid_x * pos_weight - label;
+      T term1 = (x > 0) ? static_cast<T>(1) : static_cast<T>(0);
+      T e_x = phi::funcs::real_exp(-abs(x));
+      T down = 1 + e_x;
+      T abs_grad = (x >= 0) ? static_cast<T>(1) : static_cast<T>(-1);
+      T up = -e_x * abs_grad * pos_weight;
+      T term3 = up / down;
+
+      T diff = term1 - label + term3;
       dx_data = dout * diff;
       counts = 1;
     }
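The GPU functor mirrors the CPU change. The key numerical point, illustrated below with NumPy (an assumption: float32 arithmetic comparable to the kernel's), is that exp(-x) overflows single precision once x drops below roughly -88, while exp(-|x|) always stays in (0, 1]:

import numpy as np

x = np.float32(-100.0)
print(np.exp(-x))            # exp(100) overflows float32 -> inf (RuntimeWarning)
print(np.exp(-np.abs(x)))    # exp(-100): tiny but finite, no overflow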
@@ -0,0 +1,92 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from scipy.special import logit
+
+import paddle
+from paddle import base
+
+
+class TestSigmoidCrossEntropyWithLogitsOpGradWithAutoGrad(unittest.TestCase):
+    def setUp(self):
+        np.random.seed(2023)
+        paddle.seed(2023)
+        self.places = [base.CPUPlace()]
+        if base.core.is_compiled_with_cuda():
+            self.places.append(base.CUDAPlace(0))
+        self.batch_size = 64
+        self.num_classes = 20
+
+        self.x = logit(
+            np.random.uniform(0, 1, (self.batch_size, self.num_classes)).astype(
+                "float32"
+            )
+        )
+
+        self.lable = np.random.uniform(
+            0, 1, (self.batch_size, self.num_classes)
+        ).astype("float32")
+
+        self.pos_weight = np.random.uniform(
+            0, 1, (self.batch_size, self.num_classes)
+        ).astype("float32")
+
+    def test_check_grad_with_auto_grad(self):
+        def fn_ref(x, label, weight):
+            out = paddle._C_ops.sigmoid_cross_entropy_with_logits(
+                x, label, weight, False, -100
+            )
+            loss = out.sum()
+            loss.backward()
+            return out, x.grad
+
+        def fn_comp(x, label, weight):
+            zeros = paddle.full((self.batch_size, self.num_classes), 0.0)
+            t1 = paddle.where(x > 0, x, zeros)
+            t2 = x * label
+            t3 = paddle.log(1 + paddle.exp(-paddle.abs(x)))
+            t4 = t1 - t2 + t3 * weight
+            t5 = paddle.full((self.batch_size, self.num_classes), -100.0)
+            out = paddle.where(label == t5, zeros, t4)
+            loss = out.sum()
+            loss.backward()
+            return out, x.grad
+
+        def cal(fn, place):
+            x1 = paddle.to_tensor(self.x, stop_gradient=False, place=place)
+            label1 = paddle.to_tensor(self.lable)
+            pos_weight1 = paddle.to_tensor(self.pos_weight, place=place)
+            res = fn(x1, label1, pos_weight1)
+            return res
+
+        for idx, p in enumerate(self.places):
+            if idx == 0:
+                paddle.set_device('cpu')
+            else:
+                paddle.set_device('gpu')
+
+            ref = cal(fn_ref, p)
+            actual = cal(fn_comp, p)
+
+            for idx in range(len(ref)):
+                np.testing.assert_allclose(
+                    ref[idx].numpy(), actual[idx].numpy(), atol=1e-6, rtol=1e-6
+                )
+
+
+if __name__ == '__main__':
+    unittest.main()
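A side note on fn_comp above (illustrative, not part of the commit): by the identity max(x, 0) + log(1 + exp(-|x|)) = log(1 + exp(x)), the composite form reduces to the usual binary cross entropy with logits when pos_weight == 1, which is why its autograd gradient is a valid reference for the fused op. The identity is easy to check numerically:

import numpy as np

x = np.linspace(-30.0, 30.0, 13)
lhs = np.maximum(x, 0) + np.log1p(np.exp(-np.abs(x)))   # stable form used by fn_comp
rhs = np.log1p(np.exp(x))                               # naive softplus, can overflow for large x
print(np.allclose(lhs, rhs))                            # True on this range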

test/deprecated/legacy_test/test_sigmoid_cross_entropy_with_logits_op.py

+2-1
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+
 import unittest
 
 import numpy as np
@@ -173,7 +174,7 @@ def test_check_output(self):
         self.check_output(check_pir=True)
 
     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_pir=True)
+        self.check_grad(['X'], 'Out', max_relative_error=0.0005, check_pir=True)
 
 
 class TestSigmoidCrossEntropyWithNorm(OpTest):
