Skip to content

Commit 059299c

Browse files
Eddie-Wang1120 authored and co63oc committed
[Prim][PIR] support square op backward in prim pir (PaddlePaddle#64381)
* update square * update
1 parent c238eb9 commit 059299c

File tree

3 files changed

+38
-1
lines changed

3 files changed

+38
-1
lines changed

paddle/fluid/primitive/codegen/gen.py

+1
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
'sin_grad',
7676
'cos_grad',
7777
'tanh_grad',
78+
'square_grad',
7879
]
7980

8081
# prim op with two inputs and one output, with no attribute

paddle/fluid/primitive/rule/vjp/details.h

+8
Original file line numberDiff line numberDiff line change
@@ -905,6 +905,14 @@ void log_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
905905
}
906906
}
907907

908+
template <typename T>
909+
void square_grad(const Tensor& x, const Tensor& out_grad, Tensor* x_grad) {
910+
if (x_grad) {
911+
Tensor x_grad_tmp = 2 * x * out_grad;
912+
set_output<T>(x_grad_tmp, x_grad);
913+
}
914+
}
915+
908916
template <typename T>
909917
void exp_grad(const Tensor& out, const Tensor& out_grad, Tensor* x_grad) {
910918
if (x_grad) {

test/legacy_test/test_activation_op.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -4305,6 +4305,7 @@ def test_check_grad(self):
43054305
'Out',
43064306
max_relative_error=0.007,
43074307
check_pir=True,
4308+
check_prim_pir=True,
43084309
check_pir_onednn=self.check_pir_onednn,
43094310
)
43104311

@@ -4323,6 +4324,17 @@ def init_dtype(self):
43234324
def test_check_output(self):
43244325
self.check_output(check_pir=True)
43254326

4327+
def test_check_grad(self):
    """Numerically verify dOut/dX for the square op under PIR.

    Skipped for float16, whose numeric gradients are too imprecise
    for the 0.007 relative-error tolerance.
    """
    if self.dtype != np.float16:
        self.check_grad(
            ['X'],
            'Out',
            max_relative_error=0.007,
            check_pir=True,
            check_pir_onednn=self.check_pir_onednn,
        )
4337+
43264338

43274339
class TestSquare_Complex128(TestSquare):
43284340
def init_dtype(self):
@@ -4331,6 +4343,17 @@ def init_dtype(self):
43314343
def test_check_output(self):
43324344
self.check_output(check_pir=True)
43334345

4346+
def test_check_grad(self):
    """Numerically verify dOut/dX for the square op under PIR.

    Skipped for float16, whose numeric gradients are too imprecise
    for the 0.007 relative-error tolerance.
    """
    if self.dtype != np.float16:
        self.check_grad(
            ['X'],
            'Out',
            max_relative_error=0.007,
            check_pir=True,
            check_pir_onednn=self.check_pir_onednn,
        )
4356+
43344357

43354358
class TestSquare_ZeroDim(TestSquare):
43364359
def init_shape(self):
@@ -4373,7 +4396,12 @@ def test_check_output(self):
43734396
def test_check_grad(self):
    """Gradient check on CUDA place with the prim PIR pass enabled.

    numeric_grad_delta is loosened to 0.5 because this test variant
    runs in a low-precision dtype.
    """
    self.check_grad_with_place(
        core.CUDAPlace(0),
        ['X'],
        'Out',
        numeric_grad_delta=0.5,
        check_pir=True,
        check_prim_pir=True,
    )
43784406

43794407

0 commit comments

Comments
 (0)