Skip to content

Commit 8e45c42

Browse files
update square
1 parent a246d2c commit 8e45c42

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

paddle/fluid/primitive/codegen/gen.py

+1
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
'sin_grad',
6868
'cos_grad',
6969
'tanh_grad',
70+
'square_grad',
7071
]
7172

7273
# prim op with two inputs and one output, with no attribute

paddle/fluid/primitive/rule/vjp/details.h

+8
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,14 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
795795
}
796796
}
797797

798+
template <typename T>
799+
void square_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
800+
if (x_grad) {
801+
auto x_grad_tmp = 2 * x * out_grad;
802+
set_output<T>(x_grad_tmp, x_grad);
803+
}
804+
}
805+
798806
template <typename T>
799807
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
800808
if (x_grad) {

test/legacy_test/test_activation_op.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -4297,6 +4297,7 @@ def test_check_grad(self):
42974297
'Out',
42984298
max_relative_error=0.007,
42994299
check_pir=True,
4300+
check_prim_pir=True,
43004301
check_pir_onednn=self.check_pir_onednn,
43014302
)
43024303

@@ -4315,6 +4316,17 @@ def init_dtype(self):
43154316
def test_check_output(self):
43164317
self.check_output(check_pir=True)
43174318

4319+
def test_check_grad(self):
4320+
if self.dtype == np.float16:
4321+
return
4322+
self.check_grad(
4323+
['X'],
4324+
'Out',
4325+
max_relative_error=0.007,
4326+
check_pir=True,
4327+
check_pir_onednn=self.check_pir_onednn,
4328+
)
4329+
43184330

43194331
class TestSquare_Complex128(TestSquare):
43204332
def init_dtype(self):
@@ -4323,6 +4335,17 @@ def init_dtype(self):
43234335
def test_check_output(self):
43244336
self.check_output(check_pir=True)
43254337

4338+
def test_check_grad(self):
4339+
if self.dtype == np.float16:
4340+
return
4341+
self.check_grad(
4342+
['X'],
4343+
'Out',
4344+
max_relative_error=0.007,
4345+
check_pir=True,
4346+
check_pir_onednn=self.check_pir_onednn,
4347+
)
4348+
43264349

43274350
class TestSquare_ZeroDim(TestSquare):
43284351
def init_shape(self):
@@ -4365,7 +4388,12 @@ def test_check_output(self):
43654388
def test_check_grad(self):
43664389
place = core.CUDAPlace(0)
43674390
self.check_grad_with_place(
4368-
place, ['X'], 'Out', numeric_grad_delta=0.5, check_pir=True
4391+
place,
4392+
['X'],
4393+
'Out',
4394+
numeric_grad_delta=0.5,
4395+
check_pir=True,
4396+
check_prim_pir=True,
43694397
)
43704398

43714399

0 commit comments

Comments
 (0)