6
6
7
7
from common .tf2_layer_test_class import CommonTF2LayerTest
8
8
9
-
10
- class MapFNLayer (tf .keras .layers .Layer ):
11
- def __init__ (self , fn , input_type , fn_output_signature , back_prop ):
12
- super (MapFNLayer , self ).__init__ ()
13
- self .fn = fn
14
- self .input_type = input_type
15
- self .fn_output_signature = fn_output_signature
16
- self .back_prop = back_prop
17
-
18
- def call (self , x ):
19
- return tf .map_fn (self .fn , x , dtype = self .input_type ,
20
- fn_output_signature = self .fn_output_signature ,
21
- back_prop = self .back_prop )
22
-
23
9
def fn_1 (x ):
24
10
return (x [0 ] * x [1 ] + x [2 ])
25
11
26
12
def fn_2 (x ):
27
13
return (x [0 ] + x [1 ] + x [2 ], x [0 ] - x [2 ] + x [1 ], 2 + x [2 ])
28
14
29
15
def fn_3 (x ):
30
- return (x [0 ] * x [1 ], x [ 0 ] + x [ 1 ] )
16
+ return (x [0 ] * x [1 ])
31
17
32
18
def fn_4 (x ):
33
- return (x [0 ] * x [1 ] + x [ 2 ], x [ 0 ] + x [ 2 ] * x [ 1 ], 2 * x [2 ])
19
+ return (x [0 ] * x [1 ] + 2 * x [2 ])
34
20
35
21
def fn_5 (x ):
36
- return (x [0 ] * x [1 ] + x [2 ])
22
+ return (x [0 ] * x [1 ], x [ 0 ] + x [1 ])
37
23
38
24
def fn_6 (x ):
25
+ return (x [0 ] * x [1 ] + x [2 ], x [0 ] + x [2 ] * x [1 ], 2 * x [2 ])
26
+
27
+ def fn_7 (x ):
28
+ return (x [0 ] * x [1 ] + x [2 ])
29
+
30
+ def fn_8 (x ):
39
31
return (x [0 ] + x [1 ] + x [2 ], x [0 ] - x [2 ] + x [1 ], 2 + x [2 ])
40
32
33
+ list_fns = [fn_1 , fn_2 , fn_3 , fn_4 , fn_5 , fn_6 , fn_7 , fn_8 ]
34
+
35
+ class MapFNLayer (tf .keras .layers .Layer ):
36
+ def __init__ (self , fn , input_type , fn_output_signature , back_prop ):
37
+ super (MapFNLayer , self ).__init__ ()
38
+ self .fn = list_fns [fn - 1 ]
39
+ self .input_type = input_type
40
+ self .fn_output_signature = fn_output_signature
41
+ self .back_prop = back_prop
42
+
43
+ def call (self , x ):
44
+ return tf .map_fn (self .fn , x , dtype = self .input_type ,
45
+ fn_output_signature = self .fn_output_signature ,
46
+ back_prop = self .back_prop )
41
47
42
48
class TestMapFN (CommonTF2LayerTest ):
43
49
def create_map_fn_net (self , fn , input_type , fn_output_signature , back_prop ,
@@ -57,10 +63,10 @@ def create_map_fn_net(self, fn, input_type, fn_output_signature, back_prop,
57
63
return tf2_net , ref_net
58
64
59
65
test_basic = [
60
- dict (fn = fn_1 , input_type = tf .float32 ,
66
+ dict (fn = 1 , input_type = tf .float32 ,
61
67
fn_output_signature = tf .float32 , back_prop = False ,
62
68
input_names = ["x1" , "x2" , "x3" ], input_shapes = [[2 , 3 , 4 ], [2 , 3 , 4 ], [2 , 3 , 4 ]]),
63
- pytest .param (dict (fn = fn_2 ,
69
+ pytest .param (dict (fn = 2 ,
64
70
input_type = tf .float32 ,
65
71
fn_output_signature = (tf .float32 , tf .float32 , tf .float32 ), back_prop = True ,
66
72
input_names = ["x1" , "x2" , "x3" ],
@@ -77,10 +83,10 @@ def test_basic(self, params, ie_device, precision, ir_version, temp_dir, use_leg
77
83
** params )
78
84
79
85
test_multiple_inputs = [
80
- dict (fn = lambda x : x [ 0 ] * x [ 1 ] , input_type = tf .float32 ,
86
+ dict (fn = 3 , input_type = tf .float32 ,
81
87
fn_output_signature = tf .float32 , back_prop = True ,
82
88
input_names = ["x1" , "x2" ], input_shapes = [[2 , 4 ], [2 , 4 ]]),
83
- dict (fn = lambda x : x [ 0 ] * x [ 1 ] + 2 * x [ 2 ] , input_type = tf .float32 ,
89
+ dict (fn = 4 , input_type = tf .float32 ,
84
90
fn_output_signature = tf .float32 , back_prop = False ,
85
91
input_names = ["x1" , "x2" , "x3" ], input_shapes = [[2 , 1 , 3 , 4 ],
86
92
[2 , 1 , 3 , 4 ],
@@ -95,11 +101,11 @@ def test_multiple_inputs(self, params, ie_device, precision, ir_version, temp_di
95
101
** params )
96
102
97
103
test_multiple_outputs = [
98
- pytest .param (dict (fn = fn_3 , input_type = tf .float32 ,
104
+ pytest .param (dict (fn = 5 , input_type = tf .float32 ,
99
105
fn_output_signature = (tf .float32 , tf .float32 ), back_prop = True ,
100
106
input_names = ["x1" , "x2" ], input_shapes = [[2 , 4 ], [2 , 4 ]]),
101
107
marks = pytest .mark .xfail (reason = "61587" )),
102
- pytest .param (dict (fn = fn_4 ,
108
+ pytest .param (dict (fn = 6 ,
103
109
input_type = tf .float32 ,
104
110
fn_output_signature = (tf .float32 , tf .float32 , tf .float32 ), back_prop = True ,
105
111
input_names = ["x1" , "x2" , "x3" ],
@@ -115,12 +121,12 @@ def test_multiple_outputs(self, params, ie_device, precision, ir_version, temp_d
115
121
** params )
116
122
117
123
test_multiple_inputs_outputs_int32 = [
118
- dict (fn = fn_5 ,
124
+ dict (fn = 7 ,
119
125
input_type = tf .int32 ,
120
126
fn_output_signature = tf .int32 , back_prop = True ,
121
127
input_names = ["x1" , "x2" , "x3" ],
122
128
input_shapes = [[2 , 1 , 3 ], [2 , 1 , 3 ], [2 , 1 , 3 ]]),
123
- pytest .param (dict (fn = fn_6 ,
129
+ pytest .param (dict (fn = 8 ,
124
130
input_type = tf .int32 ,
125
131
fn_output_signature = (tf .int32 , tf .int32 , tf .int32 ), back_prop = True ,
126
132
input_names = ["x1" , "x2" , "x3" ],
0 commit comments