Skip to content

Commit b739ff0

Browse files
[Prim][PIR] support unbind op forward in prim pir (#64430)
* update unbind * fix size_t * update dynamic test * update unbind * add assert * Update test_unbind_op.py * prim test change * fix code
1 parent 7725a4d commit b739ff0

File tree

5 files changed

+75
-4
lines changed

5 files changed

+75
-4
lines changed

paddle/fluid/pir/dialect/op_generator/decomp_interface_gen_op_list.py

+2
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
"squeeze",
6161
"stack",
6262
"unsqueeze",
63+
"unbind",
6364
"huber_loss",
6465
]
6566

@@ -103,6 +104,7 @@
103104
"squeeze",
104105
"stack",
105106
"unsqueeze",
107+
"unbind",
106108
"huber_loss",
107109
]
108110

paddle/fluid/primitive/composite/composite.h

+17
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,23 @@ std::vector<Tensor> meshgrid_decomp(const std::vector<Tensor>& x) {
628628
return res;
629629
}
630630

631+
template <typename T>
632+
std::vector<Tensor> unbind_decomp(const Tensor x, int axis) {
633+
std::vector<Tensor> res;
634+
if (axis < 0) {
635+
axis = x.shape().size() + axis;
636+
}
637+
if (x.shape()[axis] == -1) {
638+
PADDLE_THROW(phi::errors::Unimplemented("unbind axis must not be dynamic"));
639+
}
640+
size_t num = x.shape()[axis];
641+
std::vector<Tensor> tmp = backend::split_with_num<T>(x, num, axis);
642+
for (size_t i = 0; i < tmp.size(); i++) {
643+
res.push_back(squeeze<T>(tmp[i], {axis}));
644+
}
645+
return res;
646+
}
647+
631648
template <typename T>
632649
std::tuple<Tensor, Tensor, Tensor> layer_norm_decomp(
633650
const Tensor& x,

test/deprecated/legacy_test/test_unbind_op.py

+36-4
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,7 @@ def setAxis(self):
167167

168168
def setUp(self):
169169
self._set_op_type()
170+
self.prim_op_type = "comp"
170171
self.dtype = self.get_dtype()
171172
self.axis = 0
172173
self.num = 3
@@ -186,6 +187,7 @@ def setUp(self):
186187
'Out': [('out%d' % i, self.out[i]) for i in range(len(self.out))]
187188
}
188189
self.python_api = paddle.unbind
190+
self.public_python_api = paddle.unbind
189191
self.python_out_sig = ['out%d' % i for i in range(len(self.out))]
190192

191193
def get_dtype(self):
@@ -195,10 +197,12 @@ def _set_op_type(self):
195197
self.op_type = "unbind"
196198

197199
def test_check_output(self):
198-
self.check_output(check_pir=True)
200+
self.check_output(check_pir=True, check_prim_pir=True)
199201

200202
def test_check_grad(self):
201-
self.check_grad(['X'], ['out0', 'out1', 'out2'], check_pir=True)
203+
self.check_grad(
204+
['X'], ['out0', 'out1', 'out2'], check_pir=True, check_prim_pir=True
205+
)
202206

203207

204208
class TestUnbindOp1(TestUnbindOp):
@@ -263,47 +267,73 @@ class TestUnbindOp1_Complex64(TestUnbindOp1):
263267
def get_dtype(self):
264268
return np.complex64
265269

270+
def test_check_output(self):
271+
self.check_output(check_pir=True)
272+
266273

267274
class TestUnbindOp2_Complex64(TestUnbindOp2):
268275
def get_dtype(self):
269276
return np.complex64
270277

278+
def test_check_output(self):
279+
self.check_output(check_pir=True)
280+
271281

272282
class TestUnbindOp3_Complex64(TestUnbindOp3):
273283
def get_dtype(self):
274284
return np.complex64
275285

286+
def test_check_output(self):
287+
self.check_output(check_pir=True)
288+
276289

277290
class TestUnbindOp4_Complex64(TestUnbindOp4):
278291
def get_dtype(self):
279292
return np.complex64
280293

294+
def test_check_output(self):
295+
self.check_output(check_pir=True)
296+
281297

282298
class TestUnbindOp1_Complex128(TestUnbindOp1):
283299
def get_dtype(self):
284300
return np.complex128
285301

302+
def test_check_output(self):
303+
self.check_output(check_pir=True)
304+
286305

287306
class TestUnbindOp2_Complex128(TestUnbindOp2):
288307
def get_dtype(self):
289308
return np.complex128
290309

310+
def test_check_output(self):
311+
self.check_output(check_pir=True)
312+
291313

292314
class TestUnbindOp3_Complex128(TestUnbindOp3):
293315
def get_dtype(self):
294316
return np.complex128
295317

318+
def test_check_output(self):
319+
self.check_output(check_pir=True)
320+
296321

297322
class TestUnbindOp4_Complex128(TestUnbindOp4):
298323
def get_dtype(self):
299324
return np.complex128
300325

326+
def test_check_output(self):
327+
self.check_output(check_pir=True)
328+
301329

302330
class TestUnbindFP16Op(OpTest):
303331
def setUp(self):
304332
paddle.disable_static()
305333
self.op_type = "unbind"
334+
self.prim_op_type = "comp"
306335
self.python_api = paddle.unbind
336+
self.public_python_api = paddle.unbind
307337
self.dtype = self.get_dtype()
308338
self.axis = 0
309339
self.num = 3
@@ -326,14 +356,16 @@ def get_dtype(self):
326356
return np.float16
327357

328358
def test_check_output(self):
329-
self.check_output(check_pir=True)
359+
self.check_output(check_pir=True, check_prim_pir=True)
330360

331361

332362
class TestUnbindBF16Op(OpTest):
333363
def setUp(self):
334364
paddle.disable_static()
335365
self._set_op_type()
366+
self.prim_op_type = "comp"
336367
self.python_api = paddle.unbind
368+
self.public_python_api = paddle.unbind
337369
self.dtype = self.get_dtype()
338370
self.axis = 0
339371
self.num = 3
@@ -362,7 +394,7 @@ def _set_op_type(self):
362394
self.op_type = "unbind"
363395

364396
def test_check_output(self):
365-
self.check_output(check_pir=True)
397+
self.check_output(check_pir=True, check_prim_pir=True)
366398

367399
def test_check_grad(self):
368400
pass

test/ir/pir/cinn/symbolic/test_infer_sym_shape_unary_op.py

+3
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import paddle
2626
import paddle.nn.functional as F
27+
from paddle.framework import core
2728
from paddle.static import InputSpec
2829

2930
sys.path.append(dirname(dirname(__file__)))
@@ -810,6 +811,8 @@ def prepare_data(self):
810811
]
811812

812813
def test_eval_symbolic(self):
814+
core._set_prim_forward_blacklist("pd_op.unbind")
815+
813816
net = UnbindNet()
814817

815818
for i in range(len(self.cases)):

test/prim/pir_prim/test_prim_sub_graph_dynamic_shape.py

+17
Original file line numberDiff line numberDiff line change
@@ -172,6 +172,10 @@ def meshgrid_net(x, y):
172172
return paddle.meshgrid(x, y)
173173

174174

175+
def unbind_net(x):
176+
return paddle.unbind(x)
177+
178+
175179
class TestPrimBase(unittest.TestCase):
176180
def setUp(self):
177181
np.random.seed(2023)
@@ -247,6 +251,19 @@ def setUp(self):
247251
self.tol = 1e-6
248252

249253

254+
class TestUnbind(TestPrimBase):
255+
def setUp(self):
256+
np.random.seed(2023)
257+
self.dtype = "float32"
258+
self.x_shape = [4, 5, 6]
259+
self.init_x_shape = [4, 5, None]
260+
self.x = np.random.random(self.x_shape).astype(self.dtype)
261+
self.net = unbind_net
262+
self.necessary_ops = "pd_op.unbind"
263+
self.enable_cinn = False
264+
self.tol = 1e-6
265+
266+
250267
class TestPrimFullLike(TestPrimBase):
251268
def setUp(self):
252269
np.random.seed(2023)

0 commit comments

Comments
 (0)