@@ -167,6 +167,7 @@ def setAxis(self):
167
167
168
168
def setUp (self ):
169
169
self ._set_op_type ()
170
+ self .prim_op_type = "comp"
170
171
self .dtype = self .get_dtype ()
171
172
self .axis = 0
172
173
self .num = 3
@@ -186,6 +187,7 @@ def setUp(self):
186
187
'Out' : [('out%d' % i , self .out [i ]) for i in range (len (self .out ))]
187
188
}
188
189
self .python_api = paddle .unbind
190
+ self .public_python_api = paddle .unbind
189
191
self .python_out_sig = ['out%d' % i for i in range (len (self .out ))]
190
192
191
193
def get_dtype (self ):
@@ -195,10 +197,12 @@ def _set_op_type(self):
195
197
self .op_type = "unbind"
196
198
197
199
def test_check_output (self ):
198
- self .check_output (check_pir = True )
200
+ self .check_output (check_pir = True , check_prim_pir = True )
199
201
200
202
def test_check_grad (self ):
201
- self .check_grad (['X' ], ['out0' , 'out1' , 'out2' ], check_pir = True )
203
+ self .check_grad (
204
+ ['X' ], ['out0' , 'out1' , 'out2' ], check_pir = True , check_prim_pir = True
205
+ )
202
206
203
207
204
208
class TestUnbindOp1 (TestUnbindOp ):
@@ -263,47 +267,73 @@ class TestUnbindOp1_Complex64(TestUnbindOp1):
263
267
def get_dtype (self ):
264
268
return np .complex64
265
269
270
+ def test_check_output (self ):
271
+ self .check_output (check_pir = True )
272
+
266
273
267
274
class TestUnbindOp2_Complex64 (TestUnbindOp2 ):
268
275
def get_dtype (self ):
269
276
return np .complex64
270
277
278
+ def test_check_output (self ):
279
+ self .check_output (check_pir = True )
280
+
271
281
272
282
class TestUnbindOp3_Complex64 (TestUnbindOp3 ):
273
283
def get_dtype (self ):
274
284
return np .complex64
275
285
286
+ def test_check_output (self ):
287
+ self .check_output (check_pir = True )
288
+
276
289
277
290
class TestUnbindOp4_Complex64 (TestUnbindOp4 ):
278
291
def get_dtype (self ):
279
292
return np .complex64
280
293
294
+ def test_check_output (self ):
295
+ self .check_output (check_pir = True )
296
+
281
297
282
298
class TestUnbindOp1_Complex128 (TestUnbindOp1 ):
283
299
def get_dtype (self ):
284
300
return np .complex128
285
301
302
+ def test_check_output (self ):
303
+ self .check_output (check_pir = True )
304
+
286
305
287
306
class TestUnbindOp2_Complex128 (TestUnbindOp2 ):
288
307
def get_dtype (self ):
289
308
return np .complex128
290
309
310
+ def test_check_output (self ):
311
+ self .check_output (check_pir = True )
312
+
291
313
292
314
class TestUnbindOp3_Complex128 (TestUnbindOp3 ):
293
315
def get_dtype (self ):
294
316
return np .complex128
295
317
318
+ def test_check_output (self ):
319
+ self .check_output (check_pir = True )
320
+
296
321
297
322
class TestUnbindOp4_Complex128 (TestUnbindOp4 ):
298
323
def get_dtype (self ):
299
324
return np .complex128
300
325
326
+ def test_check_output (self ):
327
+ self .check_output (check_pir = True )
328
+
301
329
302
330
class TestUnbindFP16Op (OpTest ):
303
331
def setUp (self ):
304
332
paddle .disable_static ()
305
333
self .op_type = "unbind"
334
+ self .prim_op_type = "comp"
306
335
self .python_api = paddle .unbind
336
+ self .public_python_api = paddle .unbind
307
337
self .dtype = self .get_dtype ()
308
338
self .axis = 0
309
339
self .num = 3
@@ -326,14 +356,16 @@ def get_dtype(self):
326
356
return np .float16
327
357
328
358
def test_check_output (self ):
329
- self .check_output (check_pir = True )
359
+ self .check_output (check_pir = True , check_prim_pir = True )
330
360
331
361
332
362
class TestUnbindBF16Op (OpTest ):
333
363
def setUp (self ):
334
364
paddle .disable_static ()
335
365
self ._set_op_type ()
366
+ self .prim_op_type = "comp"
336
367
self .python_api = paddle .unbind
368
+ self .public_python_api = paddle .unbind
337
369
self .dtype = self .get_dtype ()
338
370
self .axis = 0
339
371
self .num = 3
@@ -362,7 +394,7 @@ def _set_op_type(self):
362
394
self .op_type = "unbind"
363
395
364
396
def test_check_output (self ):
365
- self .check_output (check_pir = True )
397
+ self .check_output (check_pir = True , check_prim_pir = True )
366
398
367
399
def test_check_grad (self ):
368
400
pass
0 commit comments