Skip to content

Commit fef142d

Browse files
committed
Fix failing test
1 parent 0ae6e24 commit fef142d

File tree

1 file changed

+31
-25
lines changed

1 file changed

+31
-25
lines changed

tests/layer_tests/tensorflow2_keras_tests/test_tf2_map_fn.py

+31-25
Original file line numberDiff line numberDiff line change
@@ -6,38 +6,44 @@
66

77
from common.tf2_layer_test_class import CommonTF2LayerTest
88

9-
10-
class MapFNLayer(tf.keras.layers.Layer):
11-
def __init__(self, fn, input_type, fn_output_signature, back_prop):
12-
super(MapFNLayer, self).__init__()
13-
self.fn = fn
14-
self.input_type = input_type
15-
self.fn_output_signature = fn_output_signature
16-
self.back_prop = back_prop
17-
18-
def call(self, x):
19-
return tf.map_fn(self.fn, x, dtype=self.input_type,
20-
fn_output_signature=self.fn_output_signature,
21-
back_prop=self.back_prop)
22-
239
def fn_1(x):
2410
return (x[0] * x[1] + x[2])
2511

2612
def fn_2(x):
2713
return (x[0] + x[1] + x[2], x[0] - x[2] + x[1], 2 + x[2])
2814

2915
def fn_3(x):
30-
return (x[0] * x[1], x[0] + x[1])
16+
return (x[0] * x[1])
3117

3218
def fn_4(x):
33-
return (x[0] * x[1] + x[2], x[0] + x[2] * x[1], 2 * x[2])
19+
return (x[0] * x[1] + 2 * x[2])
3420

3521
def fn_5(x):
36-
return (x[0] * x[1] + x[2])
22+
return (x[0] * x[1], x[0] + x[1])
3723

3824
def fn_6(x):
25+
return (x[0] * x[1] + x[2], x[0] + x[2] * x[1], 2 * x[2])
26+
27+
def fn_7(x):
28+
return (x[0] * x[1] + x[2])
29+
30+
def fn_8(x):
3931
return (x[0] + x[1] + x[2], x[0] - x[2] + x[1], 2 + x[2])
4032

33+
list_fns = [fn_1, fn_2, fn_3, fn_4, fn_5, fn_6, fn_7, fn_8]
34+
35+
class MapFNLayer(tf.keras.layers.Layer):
36+
def __init__(self, fn, input_type, fn_output_signature, back_prop):
37+
super(MapFNLayer, self).__init__()
38+
self.fn = list_fns[fn-1]
39+
self.input_type = input_type
40+
self.fn_output_signature = fn_output_signature
41+
self.back_prop = back_prop
42+
43+
def call(self, x):
44+
return tf.map_fn(self.fn, x, dtype=self.input_type,
45+
fn_output_signature=self.fn_output_signature,
46+
back_prop=self.back_prop)
4147

4248
class TestMapFN(CommonTF2LayerTest):
4349
def create_map_fn_net(self, fn, input_type, fn_output_signature, back_prop,
@@ -57,10 +63,10 @@ def create_map_fn_net(self, fn, input_type, fn_output_signature, back_prop,
5763
return tf2_net, ref_net
5864

5965
test_basic = [
60-
dict(fn=fn_1, input_type=tf.float32,
66+
dict(fn=1, input_type=tf.float32,
6167
fn_output_signature=tf.float32, back_prop=False,
6268
input_names=["x1", "x2", "x3"], input_shapes=[[2, 3, 4], [2, 3, 4], [2, 3, 4]]),
63-
pytest.param(dict(fn=fn_2,
69+
pytest.param(dict(fn=2,
6470
input_type=tf.float32,
6571
fn_output_signature=(tf.float32, tf.float32, tf.float32), back_prop=True,
6672
input_names=["x1", "x2", "x3"],
@@ -77,10 +83,10 @@ def test_basic(self, params, ie_device, precision, ir_version, temp_dir, use_leg
7783
**params)
7884

7985
test_multiple_inputs = [
80-
dict(fn=lambda x: x[0] * x[1], input_type=tf.float32,
86+
dict(fn=3, input_type=tf.float32,
8187
fn_output_signature=tf.float32, back_prop=True,
8288
input_names=["x1", "x2"], input_shapes=[[2, 4], [2, 4]]),
83-
dict(fn=lambda x: x[0] * x[1] + 2 * x[2], input_type=tf.float32,
89+
dict(fn=4, input_type=tf.float32,
8490
fn_output_signature=tf.float32, back_prop=False,
8591
input_names=["x1", "x2", "x3"], input_shapes=[[2, 1, 3, 4],
8692
[2, 1, 3, 4],
@@ -95,11 +101,11 @@ def test_multiple_inputs(self, params, ie_device, precision, ir_version, temp_di
95101
**params)
96102

97103
test_multiple_outputs = [
98-
pytest.param(dict(fn=fn_3, input_type=tf.float32,
104+
pytest.param(dict(fn=5, input_type=tf.float32,
99105
fn_output_signature=(tf.float32, tf.float32), back_prop=True,
100106
input_names=["x1", "x2"], input_shapes=[[2, 4], [2, 4]]),
101107
marks=pytest.mark.xfail(reason="61587")),
102-
pytest.param(dict(fn=fn_4,
108+
pytest.param(dict(fn=6,
103109
input_type=tf.float32,
104110
fn_output_signature=(tf.float32, tf.float32, tf.float32), back_prop=True,
105111
input_names=["x1", "x2", "x3"],
@@ -115,12 +121,12 @@ def test_multiple_outputs(self, params, ie_device, precision, ir_version, temp_d
115121
**params)
116122

117123
test_multiple_inputs_outputs_int32 = [
118-
dict(fn=fn_5,
124+
dict(fn=7,
119125
input_type=tf.int32,
120126
fn_output_signature=tf.int32, back_prop=True,
121127
input_names=["x1", "x2", "x3"],
122128
input_shapes=[[2, 1, 3], [2, 1, 3], [2, 1, 3]]),
123-
pytest.param(dict(fn=fn_6,
129+
pytest.param(dict(fn=8,
124130
input_type=tf.int32,
125131
fn_output_signature=(tf.int32, tf.int32, tf.int32), back_prop=True,
126132
input_names=["x1", "x2", "x3"],

0 commit comments

Comments
 (0)