Commit c22b7d9 1 parent 9340a3e commit c22b7d9 Copy full SHA for c22b7d9
File tree 1 file changed +13
-0
lines changed
1 file changed +13
-0
lines changed Original file line number Diff line number Diff line change @@ -76,6 +76,13 @@ def get_config(self):
76
76
config .update (self .group_transforms .get_config ())
77
77
return config
78
78
79
+ def compute_output_shape (self , input_shape ):
80
+ ga = self .group_transforms .group_axis
81
+ input_shape_normal = input_shape [:ga ] + input_shape [ga + 1 :]
82
+ output_shape = self .pooling .compute_output_shape (input_shape_normal )
83
+ go = self .group .order
84
+ return output_shape [:ga ] + (go ,) + output_shape [ga :]
85
+
79
86
80
87
class GlobalGroupPooling (Layer ):
81
88
"""
@@ -127,6 +134,12 @@ def get_config(self):
127
134
config = self .pooling .get_config ()
128
135
return config
129
136
137
+ def compute_output_shape (self , input_shape ):
138
+ ga = self .group_transforms .group_axis
139
+ input_shape_normal = input_shape [:ga ] + input_shape [ga + 1 :]
140
+ output_shape = self .pooling .compute_output_shape (input_shape_normal )
141
+ return output_shape [:ga ] + output_shape [ga :]
142
+
130
143
131
144
class GroupMaxPooling1D (GroupPooling ):
132
145
"""
You can’t perform that action at this time.
0 commit comments