@@ -45,7 +45,7 @@ def test_lift_shape(self):
45
45
for group in self .group_dict .values ():
46
46
conv_layer = self .generate_layer (group )
47
47
signal_on_group = conv_layer (signal_on_grid )
48
- self .assertEqual (signal_on_group .shape , self .shape [: - 1 ] + (group . order , self . filters ))
48
+ self .assertEqual (signal_on_group .shape , self .generate_shape (group ))
49
49
50
50
def test_lift_shape_subgroup (self ):
51
51
signal_on_grid = self .generate_signal (None )
@@ -55,15 +55,15 @@ def test_lift_shape_subgroup(self):
55
55
conv_layer = self .generate_layer (group , subgroup = subgroup_name )
56
56
57
57
signal_on_group = conv_layer (signal_on_grid )
58
- out_size = group . order if self .transpose else subgroup . order
59
- self .assertEqual (signal_on_group .shape , self .shape [: - 1 ] + ( out_size , self . filters ))
58
+ out_group = group if self .transpose else subgroup
59
+ self .assertEqual (signal_on_group .shape , self .generate_shape ( out_group ))
60
60
61
61
def test_gc_shape (self ):
62
62
for group in self .group_dict .values ():
63
63
signal_on_group = self .generate_signal (group )
64
64
conv_layer = self .generate_layer (group )
65
65
new_signal = conv_layer (signal_on_group )
66
- self .assertEqual (new_signal .shape , signal_on_group . shape [: - 1 ] + ( self .filters , ))
66
+ self .assertEqual (new_signal .shape , self .generate_shape ( group ))
67
67
68
68
def test_gc_shape_subgroup (self ):
69
69
for group in self .group_dict .values ():
@@ -73,8 +73,8 @@ def test_gc_shape_subgroup(self):
73
73
conv_layer = self .generate_layer (group , subgroup = subgroup_name )
74
74
75
75
new_signal = conv_layer (signal )
76
- out_size = group . order if self .transpose else self .group_dict [subgroup_name ]. order
77
- self .assertEqual (new_signal .shape , self .shape [: - 1 ] + ( out_size , self . filters ))
76
+ out_group = group if self .transpose else self .group_dict [subgroup_name ]
77
+ self .assertEqual (new_signal .shape , self .generate_shape ( out_group ))
78
78
79
79
def test_lift_equiv (self ):
80
80
signal_on_grid = self .generate_signal (None )
@@ -127,14 +127,20 @@ def test_padding_equiv(self):
127
127
conv_layer = self .generate_layer (group , padding = padding , strides = strides )
128
128
self .check_equivariance (conv_layer , signal_on_group )
129
129
130
- def generate_signal (self , group ):
130
+ def generate_shape (self , group , output = True ):
131
131
if type (group ) == str :
132
132
group = self .group_dict [group ]
133
133
if group == None :
134
134
shape = self .shape
135
135
else :
136
136
shape = self .shape [:- 1 ] + (group .order , self .shape [- 1 ])
137
- return keras .random .normal (shape = shape , seed = 42 )
137
+
138
+ if output :
139
+ shape = shape [:- 1 ] + (self .filters ,)
140
+ return shape
141
+
142
+ def generate_signal (self , group ):
143
+ return keras .random .normal (shape = self .generate_shape (group , output = False ), seed = 42 )
138
144
139
145
def generate_layer (self , group , padding = "same_equiv" , strides = 1 , subgroup = "" ):
140
146
return self .conv (
0 commit comments