Skip to content

Commit babe266

Browse files
committed
decouple group transforms from pooling
1 parent 40aaea8 commit babe266

File tree

1 file changed

+2
-8
lines changed

1 file changed

+2
-8
lines changed

groco/layers/pooling.py

+2-8
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,6 @@ def __init__(
4646
**kwargs
4747
)
4848
kwargs["padding"] = self.group_transforms.built_in_padding_option
49-
self.group = self.group_transforms.group
50-
self.subgroup = self.group_transforms.subgroup
5149

5250
self.pool_type = pool_type
5351
if self.dimensions == 1:
@@ -58,8 +56,6 @@ def __init__(
5856
pool_layer = MaxPooling3D if self.pool_type == "max" else AveragePooling3D
5957
self.pooling = pool_layer(pool_size=pool_size, **kwargs)
6058

61-
self.pooling_indices = None # created during build
62-
6359
def call(self, inputs):
6460
inputs = self.group_transforms.subgroup_pooling(inputs, self.pool_type)
6561
inputs = self.group_transforms.merge_group_axis_and_pad(inputs)
@@ -69,7 +65,7 @@ def call(self, inputs):
6965
def build(self, input_shape):
7066
self.group_transforms.build(input_shape)
7167
self.pooling.build(self.group_transforms.reshaped_input)
72-
self.pooling_indices = self.group_transforms.build_pool()
68+
self.group_transforms.build_pool()
7369

7470
def get_config(self):
7571
config = self.pooling.get_config()
@@ -99,8 +95,6 @@ def __init__(self, dimensions: int, pool_type: str, **kwargs):
9995
else:
10096
self.group_axis = self.dimensions + 1
10197

102-
self.pooling_indices = None # created during build
103-
10498
def call(self, inputs):
10599
inputs = self.pool_group(inputs)
106100
outputs = self.pooling(inputs)
@@ -121,7 +115,7 @@ def restore_group_axis(self, outputs):
121115
def build(self, input_shape):
122116
reshaped_input = self.group_transforms.build(input_shape)
123117
self.pooling.build(reshaped_input)
124-
self.pooling_indices = self.group_transforms.build_pool()
118+
self.group_transforms.build_pool()
125119

126120
def get_config(self):
127121
config = self.pooling.get_config()

0 commit comments

Comments
 (0)