@@ -51,3 +51,55 @@ def test_pack_basic(self, params, ie_device, precision, ir_version, temp_dir,
51
51
self ._test (* self .create_pack_net (** params ),
52
52
ie_device , precision , ir_version , temp_dir = temp_dir ,
53
53
use_legacy_frontend = use_legacy_frontend )
54
+
55
+
56
+ class TestComplexPack (CommonTFLayerTest ):
57
+ def _prepare_input (self , inputs_info ):
58
+ inputs_data = {}
59
+ for input_name , input_shape in inputs_info .items ():
60
+ inputs_data [input_name ] = np .random .randint (- 5 , 5 , input_shape ).astype (self .input_type )
61
+ return inputs_data
62
+
63
+ def create_pack_net (self , input_shape , input_num , axis , input_type ):
64
+ self .input_type = input_type
65
+ tf .compat .v1 .reset_default_graph ()
66
+ # Create the graph and model
67
+ with tf .compat .v1 .Session () as sess :
68
+ inputs = []
69
+ type_map = {
70
+ np .float32 : tf .float32 ,
71
+ np .int32 : tf .int32 ,
72
+ }
73
+ assert input_type in type_map , "Test error: need to update type_map"
74
+ tf_type = type_map [input_type ]
75
+ complex_inputs = []
76
+ for ind in range (input_num ):
77
+ input_real = tf .compat .v1 .placeholder (tf_type , input_shape , 'input' + str (ind ) + '_real' )
78
+ input_imag = tf .compat .v1 .placeholder (tf_type , input_shape , 'input' + str (ind ) + '_imag' )
79
+ inputs .append (input_real )
80
+ inputs .append (input_imag )
81
+
82
+ complex_inputs .append (tf .raw_ops .Complex (real = input_real , imag = input_imag ))
83
+ if axis is not None :
84
+ tf .raw_ops .Pack (values = complex_inputs , axis = axis )
85
+ else :
86
+ tf .raw_ops .Pack (values = complex_inputs )
87
+ tf .compat .v1 .global_variables_initializer ()
88
+
89
+ tf_net = sess .graph_def
90
+
91
+ return tf_net , None
92
+
93
+ test_data_basic = [
94
+ dict (input_shape = [2 , 4 ], input_num = 2 , axis = None , input_type = np .float32 ),
95
+ dict (input_shape = [3 , 1 , 2 ], input_num = 3 , axis = 1 , input_type = np .int32 ),
96
+ ]
97
+
98
+ @pytest .mark .parametrize ("params" , test_data_basic )
99
+ @pytest .mark .precommit_tf_fe
100
+ @pytest .mark .nightly
101
+ def test_pack_basic (self , params , ie_device , precision , ir_version , temp_dir ,
102
+ use_legacy_frontend ):
103
+ self ._test (* self .create_pack_net (** params ),
104
+ ie_device , precision , ir_version , temp_dir = temp_dir ,
105
+ use_legacy_frontend = use_legacy_frontend )
0 commit comments