Skip to content

Commit b2e7395

Browse files
committed
refactor prepare_call
1 parent 5c0fff7 commit b2e7395

File tree

5 files changed

+7
-7
lines changed

5 files changed

+7
-7
lines changed

groco/layers/conv2d.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,11 @@ def __init__(
6060
self.subgroup = self.group_transforms.subgroup
6161

6262
self.group_valued_input = None
63-
self.group_order = self.group.order if transpose else self.subgroup.order
6463

6564
@backup_and_restore(("kernel", "bias", "filters"))
6665
def call(self, inputs):
6766
self.kernel, self.bias, self.filters, inputs = self.group_transforms.prepare_call(
68-
self.kernel, self.bias, self.filters, inputs, self.use_bias, self.group_order
67+
self.kernel, self.bias, self.filters, inputs, self.use_bias
6968
)
7069
outputs = super().call(inputs)
7170
return self.group_transforms.restore_group_axis(outputs)

groco/layers/conv2d_transpose.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ def __init__(
6262
@backup_and_restore(("kernel", "bias", "filters"))
6363
def call(self, inputs):
6464
self.kernel, self.bias, self.filters, inputs = self.group_transforms.prepare_call(
65-
self.kernel, self.bias, self.filters, inputs, self.use_bias, self.group_order
65+
self.kernel, self.bias, self.filters, inputs, self.use_bias
6666
)
6767
outputs = super().call(inputs)
6868
return self.group_transforms.restore_group_axis(outputs)

groco/layers/conv3d.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __init__(
6565
@backup_and_restore(("kernel", "bias", "filters"))
6666
def call(self, inputs):
6767
self.kernel, self.bias, self.filters, inputs = self.group_transforms.prepare_call(
68-
self.kernel, self.bias, self.filters, inputs, self.use_bias, self.group_order
68+
self.kernel, self.bias, self.filters, inputs, self.use_bias
6969
)
7070
outputs = super().call(inputs)
7171
return self.group_transforms.restore_group_axis(outputs)

groco/layers/conv3d_transpose.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def __init__(
6868
@backup_and_restore(("kernel", "bias", "filters"))
6969
def call(self, inputs):
7070
self.kernel, self.bias, self.filters, inputs = self.group_transforms.prepare_call(
71-
self.kernel, self.bias, self.filters, inputs, self.use_bias, self.group_order
71+
self.kernel, self.bias, self.filters, inputs, self.use_bias
7272
)
7373
outputs = super().call(inputs)
7474
return self.group_transforms.restore_group_axis(outputs)

groco/layers/group_transforms.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -243,10 +243,11 @@ def get_config(self):
243243
config.update(self.equivariant_padding.get_config())
244244
return config
245245

246-
def prepare_call(self, kernel, bias, filters, inputs, use_bias, group_order):
246+
def prepare_call(self, kernel, bias, filters, inputs, use_bias):
247247
inputs = self.merge_group_axis_and_pad(inputs)
248248
kernel = self.transform_kernel(kernel)
249249
if use_bias:
250250
bias = self.repeat_bias(bias)
251-
filters *= group_order
251+
factor = len(self.group.subgroup[self.acting_group])
252+
filters *= factor
252253
return kernel, bias, filters, inputs

0 commit comments

Comments
 (0)