@@ -46,15 +46,7 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
         raise TypeError(f"Sampling not implemented for {type(self)}")
 
 
-class TorchProductLayer(TorchInnerLayer, ABC):
-    ...
-
-
-class TorchSumLayer(TorchInnerLayer, ABC):
-    ...
-
-
-class TorchHadamardLayer(TorchProductLayer):
+class TorchHadamardLayer(TorchInnerLayer):
     """The Hadamard product layer."""
 
     def __init__(
@@ -110,7 +102,7 @@ def sample(self, x: Tensor) -> tuple[Tensor, None]:
         return x, None
 
 
-class TorchKroneckerLayer(TorchProductLayer):
+class TorchKroneckerLayer(TorchInnerLayer):
     """The Kronecker product layer."""
 
     def __init__(
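
Note: the empty TorchProductLayer and TorchSumLayer ABCs (both had bare ... bodies) are removed, so the concrete product layers now inherit from TorchInnerLayer directly; the TorchSumLayer name is reused below for the merged concrete sum layer. A stand-in sketch of the flattened hierarchy, with placeholder bodies rather than the real class definitions:

    from abc import ABC

    class TorchInnerLayer(ABC):  # stand-in for the real base class in this file
        ...

    class TorchHadamardLayer(TorchInnerLayer): ...   # previously subclassed TorchProductLayer
    class TorchKroneckerLayer(TorchInnerLayer): ...  # previously subclassed TorchProductLayer
    class TorchSumLayer(TorchInnerLayer): ...        # concrete; replaces TorchDenseLayer and TorchMixingLayer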
@@ -171,13 +163,14 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor | None]:
         return torch.flatten(x, start_dim=2, end_dim=3), None
 
 
-class TorchDenseLayer(TorchSumLayer):
-    """The sum layer for dense sum within a layer."""
+class TorchSumLayer(TorchInnerLayer):
+    """The sum layer."""
 
     def __init__(
         self,
         num_input_units: int,
         num_output_units: int,
+        arity: int = 1,
         *,
         weight: TorchParameter,
         semiring: Semiring | None = None,
@@ -192,91 +185,7 @@ def __init__(
             num_folds (int): The number of channels. Defaults to 1.
         """
         assert weight.num_folds == num_folds
-        assert weight.shape == (num_output_units, num_input_units)
-        super().__init__(
-            num_input_units, num_output_units, arity=1, semiring=semiring, num_folds=num_folds
-        )
-        self.weight = weight
-
-    @property
-    def config(self) -> Mapping[str, Any]:
-        return {"num_input_units": self.num_input_units, "num_output_units": self.num_output_units}
-
-    @property
-    def params(self) -> Mapping[str, TorchParameter]:
-        return {"weight": self.weight}
-
-    def forward(self, x: Tensor) -> Tensor:
-        """Run forward pass.
-
-        Args:
-            x (Tensor): The input to this layer, shape (F, H, B, Ki).
-
-        Returns:
-            Tensor: The output of this layer, shape (F, B, Ko).
-        """
-        x = x.squeeze(dim=1)  # shape (F, H=1, B, Ki) -> (F, B, Ki).
-        weight = self.weight()
-        return self.semiring.einsum(
-            "fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
-        )  # shape (F, B, Ko).
-
-    def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
-        weight = self.weight()
-        negative = torch.any(weight < 0.0)
-        if negative:
-            raise ValueError("Sampling only works with positive weights")
-        normalized = torch.allclose(torch.sum(weight, dim=-1), torch.ones(1, device=weight.device))
-        if not normalized:
-            raise ValueError("Sampling only works with a normalized parametrization")
-
-        # x: (F, H, C, K, num_samples, D)
-        c = x.shape[2]
-        d = x.shape[-1]
-        num_samples = x.shape[-2]
-
-        # mixing_distribution: (F, O, K)
-        mixing_distribution = torch.distributions.Categorical(probs=weight)
-
-        mixing_samples = mixing_distribution.sample((num_samples,))
-        mixing_samples = E.rearrange(mixing_samples, "n f o -> f o n")
-        mixing_indices = E.repeat(mixing_samples, "f o n -> f a c o n d", a=self.arity, c=c, d=d)
-
-        x = torch.gather(x, dim=-3, index=mixing_indices)
-        x = x[:, 0]
-        return x, mixing_samples
-
-
-class TorchMixingLayer(TorchSumLayer):
-    """The sum layer for mixture among layers.
-
-    It can also be used as a sparse sum within a layer when arity=1.
-    """
-
-    def __init__(
-        self,
-        num_input_units: int,
-        num_output_units: int,
-        arity: int = 2,
-        *,
-        weight: TorchParameter,
-        semiring: Semiring | None = None,
-        num_folds: int = 1,
-    ) -> None:
-        """Init class.
-
-        Args:
-            num_input_units (int): The number of input units.
-            num_output_units (int): The number of output units, must be the same as input.
-            arity (int, optional): The arity of the layer. Defaults to 2.
-            weight (TorchParameter): The reparameterization for layer parameters.
-            num_folds (int): The number of channels. Defaults to 1.
-        """
-        assert (
-            num_output_units == num_input_units
-        ), "The number of input and output units must be the same for MixingLayer."
-        assert weight.num_folds == num_folds
-        assert weight.shape == (num_output_units, arity)
+        assert weight.shape == (num_output_units, arity * num_input_units)
         super().__init__(
             num_input_units, num_output_units, arity=arity, semiring=semiring, num_folds=num_folds
         )
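
The merged TorchSumLayer subsumes both deleted classes: TorchDenseLayer is the arity=1 case with weight shape (Ko, Ki), and TorchMixingLayer corresponds to a weight of shape (K, H * K) that is nonzero only where an output unit selects the matching unit of each input. A self-contained check of that correspondence, with made-up shapes and plain torch.einsum standing in for the semiring-generic self.semiring.einsum:

    import torch

    F, H, B, K = 1, 2, 3, 4  # one fold, arity 2, Ki == Ko == K (illustrative sizes)
    w_mix = torch.softmax(torch.randn(F, K, H), dim=-1)  # old mixing weight, shape (F, K, H)

    # Equivalent merged weight, shape (F, K, H * K): row k picks unit k of each input.
    w_sum = torch.zeros(F, K, H * K)
    for h in range(H):
        for k in range(K):
            w_sum[:, k, h * K + k] = w_mix[:, k, h]

    x = torch.randn(F, H, B, K)
    y_mix = torch.einsum("fhbk,fkh->fbk", x, w_mix)  # old TorchMixingLayer forward
    y_sum = torch.einsum("fbi,foi->fbo", x.permute(0, 2, 1, 3).flatten(2), w_sum)
    assert torch.allclose(y_mix, y_sum, atol=1e-6)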
@@ -303,11 +212,13 @@ def forward(self, x: Tensor) -> Tensor:
         Returns:
             Tensor: The output of this layer, shape (F, B, Ko).
         """
-        # shape (F, H, B, K) -> (F, B, K).
+        # x: (F, H, B, Ki) -> (F, B, H * Ki)
+        # weight: (F, Ko, H * Ki)
+        x = x.permute(0, 2, 1, 3).flatten(start_dim=2)
         weight = self.weight()
         return self.semiring.einsum(
-            "fhbk,fkh->fbk", inputs=(x,), operands=(weight,), dim=1, keepdim=False
-        )
+            "fbi,foi->fbo", inputs=(x,), operands=(weight,), dim=-1, keepdim=True
+        )  # shape (F, B, Ko).
 
     def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
         weight = self.weight()
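
The rewritten forward folds the arity dimension into the unit dimension before a single contraction. A shape walk-through with invented sizes, again using plain torch.einsum in place of self.semiring.einsum:

    import torch

    F, H, B, Ki, Ko = 2, 3, 5, 4, 6  # folds, arity, batch, input/output units (made up)
    x = torch.randn(F, H, B, Ki)
    weight = torch.softmax(torch.randn(F, Ko, H * Ki), dim=-1)  # row-normalized, (F, Ko, H*Ki)

    x = x.permute(0, 2, 1, 3).flatten(start_dim=2)  # (F, H, B, Ki) -> (F, B, H*Ki)
    y = torch.einsum("fbi,foi->fbo", x, weight)     # contract over i = H*Ki
    assert y.shape == (F, B, Ko)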
@@ -318,18 +229,22 @@ def sample(self, x: Tensor) -> tuple[Tensor, Tensor]:
         if not normalized:
             raise ValueError("Sampling only works with a normalized parametrization")
 
-        # x: (F, H, C, K, num_samples, D)
-        c = x.shape[2]
-        k = x.shape[-3]
-        d = x.shape[-1]
-        num_samples = x.shape[-2]
+        # x: (F, H, C, Ki, num_samples, D) -> (F, C, H * Ki, num_samples, D)
+        x = x.permute(0, 2, 1, 3, 4, 5).flatten(2, 3)
+        c = x.shape[1]
+        num_samples = x.shape[3]
+        d = x.shape[4]
 
-        # mixing_distribution: (F, O, K)
+        # mixing_distribution: (F, Ko, H * Ki)
         mixing_distribution = torch.distributions.Categorical(probs=weight)
 
+        # mixing_samples: (num_samples, F, Ko) -> (F, Ko, num_samples)
         mixing_samples = mixing_distribution.sample((num_samples,))
         mixing_samples = E.rearrange(mixing_samples, "n f k -> f k n")
-        mixing_indices = E.repeat(mixing_samples, "f k n -> f 1 c k n d", c=c, k=k, d=d)
 
-        x = torch.gather(x, 1, mixing_indices)[:, 0]
+        # mixing_indices: (F, C, Ko, num_samples, D)
+        mixing_indices = E.repeat(mixing_samples, "f k n -> f c k n d", c=c, d=d)
+
+        # x: (F, C, Ko, num_samples, D)
+        x = torch.gather(x, dim=2, index=mixing_indices)
         return x, mixing_samples
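
In the new sample, the categorical distribution ranges over the flattened H * Ki axis, so each output unit draws one (input, unit) pair per sample, and torch.gather selects it along dim 2. A minimal sketch with assumed sizes (einops imported as E, matching the diff):

    import einops as E
    import torch

    F, H, C, Ki, Ko, N, D = 1, 2, 1, 3, 4, 5, 6  # folds, arity, channels, units, samples, vars
    weight = torch.softmax(torch.randn(F, Ko, H * Ki), dim=-1)

    x = torch.randn(F, H, C, Ki, N, D)
    x = x.permute(0, 2, 1, 3, 4, 5).flatten(2, 3)  # (F, C, H*Ki, N, D)

    mixing_samples = torch.distributions.Categorical(probs=weight).sample((N,))
    mixing_samples = E.rearrange(mixing_samples, "n f k -> f k n")             # (F, Ko, N)
    mixing_indices = E.repeat(mixing_samples, "f k n -> f c k n d", c=C, d=D)  # (F, C, Ko, N, D)

    # For every output unit and sample, keep exactly one of the H*Ki input slices.
    x = torch.gather(x, dim=2, index=mixing_indices)
    assert x.shape == (F, C, Ko, N, D)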