Skip to content

Commit f730b55

Browse files
committed
Fix AddBias Tests and NHCW logic
1 parent d7ed7d6 commit f730b55

File tree

2 files changed

+25
-20
lines changed

2 files changed

+25
-20
lines changed

src/frontends/tensorflow_common/src/op/bias_add.cpp

+5-2
Original file line numberDiff line numberDiff line change
@@ -47,11 +47,14 @@ OutputVector translate_bias_add_op(const NodeContext& node) {
4747
TENSORFLOW_OP_VALIDATION(node,
4848
value_shape.rank().is_static(),
4949
"Value of dynamic rank for BiasAdd in NCHW layout is not supported.");
50-
auto value_rank = complex_type_inputs ? value_shape.rank().get_length() - 1 : value_shape.rank().get_length();
50+
auto value_rank = complex_type_inputs ? value_shape.rank().get_length() : value_shape.rank().get_length();
5151

5252
std::vector<int64_t> axes_unsqueeze;
5353
for (int64_t dim_ind = 0; dim_ind < value_rank; ++dim_ind) {
54-
if (dim_ind != 1) {
54+
if (!complex_type_inputs && dim_ind != 1) {
55+
axes_unsqueeze.push_back(dim_ind);
56+
}
57+
if (complex_type_inputs && dim_ind != 2){
5558
axes_unsqueeze.push_back(dim_ind);
5659
}
5760
}

tests/layer_tests/tensorflow_tests/test_tf_BiasAdd.py

+20-18
Original file line numberDiff line numberDiff line change
@@ -161,27 +161,29 @@ def _prepare_input(self, inputs_info):
161161
rng = np.random.default_rng()
162162
assert 'x_real:0' in inputs_info
163163
assert 'x_imag:0' in inputs_info
164-
x_real_shape = inputs_info['x_real:0']
165-
x_imag_shape = inputs_info['x_imag:0']
164+
assert 'y_real:0' in inputs_info
165+
assert 'y_imag:0' in inputs_info
166+
x_shape = inputs_info['x_real:0']
167+
y_shape = inputs_info['y_real:0']
166168
inputs_data = {}
167-
inputs_data['x_real:0'] = 4 * rng.random(x_real_shape).astype(np.float64) - 2
168-
inputs_data['x_imag:0'] = 4 * rng.random(x_imag_shape).astype(np.float64) - 2
169+
170+
inputs_data['x_real:0'] = 4 * rng.random(x_shape).astype(np.float64) - 2
171+
inputs_data['x_imag:0'] = 4 * rng.random(x_shape).astype(np.float64) - 2
172+
173+
inputs_data['y_real:0'] = 4 * rng.random(y_shape).astype(np.float64) - 2
174+
inputs_data['y_imag:0'] = 4 * rng.random(y_shape).astype(np.float64) - 2
175+
169176
return inputs_data
170177

171-
def create_complex_bias_add_net(self, shape, data_format, ir_version, use_legacy_frontend, output_type=tf.float64):
178+
def create_complex_bias_add_net(self, input_shape, bias_shape, data_format, ir_version, use_legacy_frontend, output_type=tf.float64):
172179
tf.compat.v1.reset_default_graph()
173180

174181
with tf.compat.v1.Session() as sess:
175-
x_real_shape = shape.copy()
176-
x_imag_shape = shape.copy()
177-
178-
x_real = tf.compat.v1.placeholder(output_type, x_real_shape, 'x_real')
179-
x_imag = tf.compat.v1.placeholder(output_type, x_imag_shape, 'x_imag')
182+
x_real = tf.compat.v1.placeholder(output_type, input_shape, 'x_real')
183+
x_imag = tf.compat.v1.placeholder(output_type, input_shape, 'x_imag')
180184

181-
constant_value_real = np.random.randint(-256, 256, x_real_shape[-1]).astype(output_type.as_numpy_dtype())
182-
constant_value_imag = np.random.randint(-256, 256, x_imag_shape[-1]).astype(output_type.as_numpy_dtype())
183-
y_real = tf.constant(constant_value_real)
184-
y_imag = tf.constant(constant_value_imag)
185+
y_real = tf.compat.v1.placeholder(output_type, bias_shape, 'y_real')
186+
y_imag = tf.compat.v1.placeholder(output_type, bias_shape, 'y_imag')
185187

186188
complex_input = tf.complex(x_real, x_imag)
187189
complex_bias = tf.complex(y_real, y_imag)
@@ -195,10 +197,10 @@ def create_complex_bias_add_net(self, shape, data_format, ir_version, use_legacy
195197
return tf_net, None
196198

197199
test_data_2D = [
198-
dict(shape=[1, 1], data_format="NHWC"),
199-
dict(shape=[1, 224], data_format="NHWC"),
200-
dict(shape=[1, 1], data_format="NCHW"),
201-
dict(shape=[1, 224], data_format="NCHW"),
200+
dict(shape=[1, 1], bias_shape=[1], data_format="NHWC"),
201+
dict(shape=[3, 2, 7], bias_shape=[7], data_format="NHWC"),
202+
dict(shape=[3, 2, 7, 10], bias_shape=[2], data_format="NCHW"),
203+
dict(shape=[7, 6, 4, 5], bias_shape=[6], data_format="NCHW"),
202204
]
203205

204206
@pytest.mark.parametrize("params", test_data_2D)

0 commit comments

Comments
 (0)