@@ -258,7 +258,7 @@ class CLSToken(nn.Module):
         x = torch.randn(batch_size, n_tokens, d_token)
         x = cls_token(x)
         assert x.shape == (batch_size, n_tokens + 1, d_token)
-        assert (x[:, -1:, :] == cls_token.repeat_as(x)).all()
+        assert (x[:, -1, :] == cls_token.expand(len(x))).all()
     """
 
     def __init__(self, d_token: int, initialization: str) -> None:
@@ -268,22 +268,32 @@ def __init__(self, d_token: int, initialization: str) -> None:
         self.weight = nn.Parameter(Tensor(d_token))
         initialization_.apply(self.weight, d_token)
 
-    def repeat_as(self, x: Tensor) -> Tensor:
-        """Repeat self to match the given batch of token-based objects.
+    def expand(self, *leading_dimensions: int) -> Tensor:
+        """Expand (repeat) the underlying [CLS]-token to a tensor with the given leading dimensions.
+
+        A possible use case is building a batch of [CLS]-tokens. See `CLSToken` for
+        examples of usage.
+
+        Note:
+
+            Under the hood, the `torch.Tensor.expand` method is applied to the
+            underlying :code:`weight` parameter, so gradients will be propagated as
+            expected.
 
         Args:
-            x: tensor of a shape `(batch_size, n_tokens, d_token)`.
+            leading_dimensions: the additional new dimensions
 
         Returns:
-            tensor of a shape `(len(x), 1, d_token)`
+            tensor of the shape :code:`(*leading_dimensions, len(self.weight))`
         """
-        assert x.ndim == 3
-        assert len(self.weight) == x.shape[-1]
-        return self.weight.view(1, 1, -1).repeat(len(x), 1, 1)
+        if not leading_dimensions:
+            return self.weight
+        new_dims = (1,) * (len(leading_dimensions) - 1)
+        return self.weight.view(*new_dims, -1).expand(*leading_dimensions, -1)
 
     def forward(self, x: Tensor) -> Tensor:
         """Append self **to the end** of each item in the batch (see `CLSToken`)."""
-        return torch.cat([x, self.repeat_as(x)], dim=1)
+        return torch.cat([x, self.expand(len(x), 1)], dim=1)
 
 
 def _make_nn_module(module_type: ModuleType, *args) -> nn.Module:
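As a sanity check of the new behavior, here is a minimal, self-contained sketch. `SimpleCLSToken` is a hypothetical stand-in for the patched `CLSToken` (it skips the library's initialization machinery and just draws a random `weight`), but its `expand` and `forward` bodies match the `+` lines above.

```python
import torch
import torch.nn as nn
from torch import Tensor


class SimpleCLSToken(nn.Module):
    """Hypothetical stand-in for CLSToken; skips _TokenInitialization."""

    def __init__(self, d_token: int) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.randn(d_token))

    def expand(self, *leading_dimensions: int) -> Tensor:
        # With no arguments, return the raw 1-d parameter.
        if not leading_dimensions:
            return self.weight
        # Insert singleton dims, then broadcast via torch.Tensor.expand
        # (a view, not a copy), so gradients still reach `weight`.
        new_dims = (1,) * (len(leading_dimensions) - 1)
        return self.weight.view(*new_dims, -1).expand(*leading_dimensions, -1)

    def forward(self, x: Tensor) -> Tensor:
        # Append the token to the end of every item in the batch.
        return torch.cat([x, self.expand(len(x), 1)], dim=1)


batch_size, n_tokens, d_token = 4, 2, 3
cls_token = SimpleCLSToken(d_token)

# Shape behavior of `expand`.
assert cls_token.expand().shape == (d_token,)
assert cls_token.expand(batch_size).shape == (batch_size, d_token)
assert cls_token.expand(batch_size, 1).shape == (batch_size, 1, d_token)

# `forward` appends the token as the last position of each item.
x = torch.randn(batch_size, n_tokens, d_token)
out = cls_token(x)
assert out.shape == (batch_size, n_tokens + 1, d_token)
assert (out[:, -1, :] == cls_token.expand(batch_size)).all()

# Gradients propagate through the expanded view back to the parameter.
out.sum().backward()
assert cls_token.weight.grad is not None
```

Compared to the removed `repeat_as`, which materialized copies via `torch.Tensor.repeat`, `expand` returns a broadcast view, and the caller no longer has to pass a reference batch just to communicate its size.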