Skip to content

Commit 1f85195

Browse files
committed
Factor out shape generation
1 parent b917c83 commit 1f85195

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

tests/conv_test_base.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def test_lift_shape(self):
4545
for group in self.group_dict.values():
4646
conv_layer = self.generate_layer(group)
4747
signal_on_group = conv_layer(signal_on_grid)
48-
self.assertEqual(signal_on_group.shape, self.shape[:-1] + (group.order, self.filters))
48+
self.assertEqual(signal_on_group.shape, self.generate_shape(group))
4949

5050
def test_lift_shape_subgroup(self):
5151
signal_on_grid = self.generate_signal(None)
@@ -55,15 +55,15 @@ def test_lift_shape_subgroup(self):
5555
conv_layer = self.generate_layer(group, subgroup=subgroup_name)
5656

5757
signal_on_group = conv_layer(signal_on_grid)
58-
out_size = group.order if self.transpose else subgroup.order
59-
self.assertEqual(signal_on_group.shape, self.shape[:-1] + (out_size, self.filters))
58+
out_group = group if self.transpose else subgroup
59+
self.assertEqual(signal_on_group.shape, self.generate_shape(out_group))
6060

6161
def test_gc_shape(self):
6262
for group in self.group_dict.values():
6363
signal_on_group = self.generate_signal(group)
6464
conv_layer = self.generate_layer(group)
6565
new_signal = conv_layer(signal_on_group)
66-
self.assertEqual(new_signal.shape, signal_on_group.shape[:-1] + (self.filters,))
66+
self.assertEqual(new_signal.shape, self.generate_shape(group))
6767

6868
def test_gc_shape_subgroup(self):
6969
for group in self.group_dict.values():
@@ -73,8 +73,8 @@ def test_gc_shape_subgroup(self):
7373
conv_layer = self.generate_layer(group, subgroup=subgroup_name)
7474

7575
new_signal = conv_layer(signal)
76-
out_size = group.order if self.transpose else self.group_dict[subgroup_name].order
77-
self.assertEqual(new_signal.shape, self.shape[:-1] + (out_size, self.filters))
76+
out_group = group if self.transpose else self.group_dict[subgroup_name]
77+
self.assertEqual(new_signal.shape, self.generate_shape(out_group))
7878

7979
def test_lift_equiv(self):
8080
signal_on_grid = self.generate_signal(None)
@@ -127,14 +127,20 @@ def test_padding_equiv(self):
127127
conv_layer = self.generate_layer(group, padding=padding, strides=strides)
128128
self.check_equivariance(conv_layer, signal_on_group)
129129

130-
def generate_signal(self, group):
130+
def generate_shape(self, group, output=True):
131131
if type(group) == str:
132132
group = self.group_dict[group]
133133
if group == None:
134134
shape = self.shape
135135
else:
136136
shape = self.shape[:-1] + (group.order, self.shape[-1])
137-
return keras.random.normal(shape=shape, seed=42)
137+
138+
if output:
139+
shape = shape[:-1] + (self.filters,)
140+
return shape
141+
142+
def generate_signal(self, group):
143+
return keras.random.normal(shape=self.generate_shape(group, output=False), seed=42)
138144

139145
def generate_layer(self, group, padding="same_equiv", strides=1, subgroup=""):
140146
return self.conv(

0 commit comments

Comments
 (0)