@@ -54,3 +54,67 @@ def test_cumsum_basic(self, params, exclusive, reverse, ie_device, precision, ir
54
54
self ._test (* self .create_cumsum_net (** params , exclusive = exclusive , reverse = reverse ),
55
55
ie_device , precision , ir_version , temp_dir = temp_dir ,
56
56
use_legacy_frontend = use_legacy_frontend )
57
+
58
+
59
+ class TestComplexCumsum (CommonTFLayerTest ):
60
+ # input_shape - should be an array
61
+ # axis - array which points on axis for the operation
62
+ # exclusive - enables exclusive Cumsum
63
+ # reverse - enables reverse order of Cumsum
64
+ def _prepare_input (self , inputs_info ):
65
+ rng = np .random .default_rng ()
66
+ assert 'x_real:0' in inputs_info
67
+ assert 'x_imag:0' in inputs_info
68
+ x_shape = inputs_info ['x_real:0' ]
69
+ inputs_data = {}
70
+
71
+ inputs_data ['x_real:0' ] = 4 * rng .random (x_shape ).astype (np .float64 ) - 2
72
+ inputs_data ['x_imag:0' ] = 4 * rng .random (x_shape ).astype (np .float64 ) - 2
73
+
74
+ return inputs_data
75
+
76
+ def create_cumsum_net (self , input_shape , axis , exclusive , reverse ):
77
+ import tensorflow as tf
78
+
79
+ tf .compat .v1 .reset_default_graph ()
80
+
81
+ # Create the graph and model
82
+ with tf .compat .v1 .Session () as sess :
83
+ x_real = tf .compat .v1 .placeholder (tf .float32 , input_shape , 'x_real' )
84
+ x_imag = tf .compat .v1 .placeholder (tf .float32 , input_shape , 'x_imag' )
85
+
86
+ complex_input = tf .complex (x_real , x_imag )
87
+
88
+ tf_axis = tf .constant (axis , dtype = tf .int32 )
89
+ result = tf .raw_ops .Cumsum (x = complex_input , axis = tf_axis , exclusive = exclusive , reverse = reverse )
90
+
91
+ tf .compat .v1 .global_variables_initializer ()
92
+ real = tf .raw_ops .Real (input = result )
93
+ img = tf .raw_ops .Imag (input = result )
94
+
95
+ tf_net = sess .graph_def
96
+
97
+ ref_net = None
98
+
99
+ return tf_net , ref_net
100
+
101
+ test_data = [
102
+ dict (input_shape = [2 ], axis = - 1 ),
103
+ dict (input_shape = [2 , 3 ], axis = 0 ),
104
+ dict (input_shape = [2 , 3 ], axis = 1 ),
105
+ dict (input_shape = [2 , 3 ], axis = - 2 ),
106
+ dict (input_shape = [2 , 3 , 3 , 4 ], axis = 2 ),
107
+ dict (input_shape = [2 , 3 , 3 , 4 ], axis = - 3 ),
108
+ ]
109
+
110
+ @pytest .mark .parametrize ("params" , test_data )
111
+ @pytest .mark .parametrize ("exclusive" , [False , True , None ])
112
+ @pytest .mark .parametrize ("reverse" , [False , True , None ])
113
+ @pytest .mark .precommit
114
+ @pytest .mark .precommit_tf_fe
115
+ @pytest .mark .nightly
116
+ def test_cumsum_basic (self , params , exclusive , reverse , ie_device , precision , ir_version , temp_dir ,
117
+ use_legacy_frontend ):
118
+ self ._test (* self .create_cumsum_net (** params , exclusive = exclusive , reverse = reverse ),
119
+ ie_device , precision , ir_version , temp_dir = temp_dir ,
120
+ use_legacy_frontend = use_legacy_frontend )
0 commit comments