Skip to content

Commit 2517c8a

Browse files
committed
Support complex tensors for Pack operation
1 parent c208a88 commit 2517c8a

File tree

2 files changed

+66
-2
lines changed

2 files changed

+66
-2
lines changed

src/frontends/tensorflow_common/src/op/pack.cpp

+14-2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "openvino/op/concat.hpp"
77
#include "openvino/op/constant.hpp"
88
#include "openvino/op/unsqueeze.hpp"
9+
#include "helper_ops/complex_type_mark.hpp"
910

1011
using namespace std;
1112
using namespace ov::op;
@@ -16,20 +17,31 @@ namespace tensorflow {
1617
namespace op {
1718

1819
OutputVector translate_pack_op(const NodeContext& node) {
19-
default_op_checks(node, 1, {"Pack", "PACK"});
20-
auto num_size = static_cast<int>(node.get_input_size());
20+
default_op_checks(node, 1, {"Pack", "PACK"}, true);
2121

2222
auto axis = node.get_attribute<int64_t>("axis", 0);
23+
24+
auto num_size = static_cast<int>(node.get_input_size());
25+
auto complex_type_mark = as_type_ptr<ComplexTypeMark>(node.get_input(0).get_node_shared_ptr());
2326
auto axis_const = make_shared<v0::Constant>(element::i64, Shape{}, axis);
2427

28+
if (complex_type_mark) {
29+
axis_const = make_shared<v0::Constant>(element::i64, Shape{}, axis + 1);
30+
}
31+
2532
OutputVector concat_inputs;
2633
for (int ind = 0; ind < num_size; ++ind) {
2734
auto in = node.get_input(ind);
2835
concat_inputs.push_back(make_shared<v0::Unsqueeze>(in, axis_const));
2936
}
3037

3138
auto pack = make_shared<v0::Concat>(concat_inputs, axis);
39+
3240
set_node_name(node.get_name(), pack);
41+
if (complex_type_mark) {
42+
auto complex_result = make_shared<ComplexTypeMark>(pack, complex_type_mark->get_complex_part_type());
43+
return {complex_result};
44+
}
3345
return {pack};
3446
}
3547
} // namespace op

tests/layer_tests/tensorflow_tests/test_tf_Pack.py

+52
Original file line numberDiff line numberDiff line change
@@ -51,3 +51,55 @@ def test_pack_basic(self, params, ie_device, precision, ir_version, temp_dir,
5151
self._test(*self.create_pack_net(**params),
5252
ie_device, precision, ir_version, temp_dir=temp_dir,
5353
use_legacy_frontend=use_legacy_frontend)
54+
55+
56+
class TestComplexPack(CommonTFLayerTest):
57+
def _prepare_input(self, inputs_info):
58+
inputs_data = {}
59+
for input_name, input_shape in inputs_info.items():
60+
inputs_data[input_name] = np.random.randint(-5, 5, input_shape).astype(self.input_type)
61+
return inputs_data
62+
63+
def create_pack_net(self, input_shape, input_num, axis, input_type):
64+
self.input_type = input_type
65+
tf.compat.v1.reset_default_graph()
66+
# Create the graph and model
67+
with tf.compat.v1.Session() as sess:
68+
inputs = []
69+
type_map = {
70+
np.float32: tf.float32,
71+
np.int32: tf.int32,
72+
}
73+
assert input_type in type_map, "Test error: need to update type_map"
74+
tf_type = type_map[input_type]
75+
complex_inputs = []
76+
for ind in range(input_num):
77+
input_real = tf.compat.v1.placeholder(tf_type, input_shape, 'input' + str(ind) + '_real')
78+
input_imag = tf.compat.v1.placeholder(tf_type, input_shape, 'input' + str(ind) + '_imag')
79+
inputs.append(input_real)
80+
inputs.append(input_imag)
81+
82+
complex_inputs.append(tf.raw_ops.Complex(real=input_real, imag=input_imag))
83+
if axis is not None:
84+
tf.raw_ops.Pack(values=complex_inputs, axis=axis)
85+
else:
86+
tf.raw_ops.Pack(values=complex_inputs)
87+
tf.compat.v1.global_variables_initializer()
88+
89+
tf_net = sess.graph_def
90+
91+
return tf_net, None
92+
93+
test_data_basic = [
94+
dict(input_shape=[2, 4], input_num=2, axis=None, input_type=np.float32),
95+
dict(input_shape=[3, 1, 2], input_num=3, axis=1, input_type=np.int32),
96+
]
97+
98+
@pytest.mark.parametrize("params", test_data_basic)
99+
@pytest.mark.precommit_tf_fe
100+
@pytest.mark.nightly
101+
def test_pack_basic(self, params, ie_device, precision, ir_version, temp_dir,
102+
use_legacy_frontend):
103+
self._test(*self.create_pack_net(**params),
104+
ie_device, precision, ir_version, temp_dir=temp_dir,
105+
use_legacy_frontend=use_legacy_frontend)

0 commit comments

Comments
 (0)