@@ -29,16 +29,18 @@ def huber_loss_forward(val, delta):
29
29
return delta * (abs_val - 0.5 * delta )
30
30
31
31
32
- def huber_loss_wraper (x , y , delta ):
32
+ def huber_loss_wrapper (x , y , delta ):
33
33
a = paddle ._C_ops .huber_loss (x , y , delta )
34
34
return a
35
35
36
36
37
37
class TestHuberLossOp (OpTest ):
38
38
def setUp (self ):
39
39
self .op_type = 'huber_loss'
40
+ self .prim_op_type = "comp"
40
41
self .python_out_sig = ["Out" ]
41
- self .python_api = huber_loss_wraper
42
+ self .python_api = huber_loss_wrapper
43
+ self .public_python_api = huber_loss_wrapper
42
44
43
45
self .delta = 1.0
44
46
self .init_dtype ()
@@ -65,7 +67,7 @@ def set_shape(self):
65
67
return (100 , 1 )
66
68
67
69
def test_check_output (self ):
68
- self .check_output ()
70
+ self .check_output (check_prim_pir = True )
69
71
70
72
def test_check_grad_normal (self ):
71
73
self .check_grad (['X' , 'Y' ], 'Out' )
@@ -105,8 +107,10 @@ def init_dtype(self):
105
107
class TestHuberLossBF16Op (OpTest ):
106
108
def setUp (self ):
107
109
self .op_type = 'huber_loss'
110
+ self .prim_op_type = "comp"
108
111
self .python_out_sig = ["Out" ]
109
- self .python_api = huber_loss_wraper
112
+ self .python_api = huber_loss_wrapper
113
+ self .public_python_api = huber_loss_wrapper
110
114
111
115
self .delta = 1.0
112
116
self .init_dtype ()
@@ -142,7 +146,7 @@ def set_shape(self):
142
146
return (100 , 1 )
143
147
144
148
def test_check_output (self ):
145
- self .check_output_with_place (self .place )
149
+ self .check_output_with_place (self .place , check_prim_pir = True )
146
150
147
151
def test_check_grad_normal (self ):
148
152
self .check_grad_with_place (self .place , ['X' , 'Y' ], 'Out' )
0 commit comments