Skip to content

Commit c22b7d9

Browse files
committed
Add output shape to pooling layers
1 parent 9340a3e commit c22b7d9

File tree

1 file changed

+13
-0
lines changed

1 file changed

+13
-0
lines changed

groco/layers/pooling.py

+13
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,13 @@ def get_config(self):
7676
config.update(self.group_transforms.get_config())
7777
return config
7878

79+
def compute_output_shape(self, input_shape):
80+
ga = self.group_transforms.group_axis
81+
input_shape_normal = input_shape[:ga] + input_shape[ga + 1 :]
82+
output_shape = self.pooling.compute_output_shape(input_shape_normal)
83+
go = self.group.order
84+
return output_shape[:ga] + (go,) + output_shape[ga:]
85+
7986

8087
class GlobalGroupPooling(Layer):
8188
"""
@@ -127,6 +134,12 @@ def get_config(self):
127134
config = self.pooling.get_config()
128135
return config
129136

137+
def compute_output_shape(self, input_shape):
138+
ga = self.group_transforms.group_axis
139+
input_shape_normal = input_shape[:ga] + input_shape[ga + 1 :]
140+
output_shape = self.pooling.compute_output_shape(input_shape_normal)
141+
return output_shape[:ga] + output_shape[ga:]
142+
130143

131144
class GroupMaxPooling1D(GroupPooling):
132145
"""

0 commit comments

Comments
 (0)