Skip to content

Commit 1e1157f

Browse files
committed
CLSToken | add expand method instead of repeat_as
1 parent 898f303 commit 1e1157f

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

rtdl/modules.py

+19-9
Original file line numberDiff line numberDiff line change
@@ -258,7 +258,7 @@ class CLSToken(nn.Module):
258258
x = torch.randn(batch_size, n_tokens, d_token)
259259
x = cls_token(x)
260260
assert x.shape == (batch_size, n_tokens + 1, d_token)
261-
assert (x[:, -1:, :] == cls_token.repeat_as(x)).all()
261+
assert (x[:, -1, :] == cls_token.expand(len(x))).all()
262262
"""
263263

264264
def __init__(self, d_token: int, initialization: str) -> None:
@@ -268,22 +268,32 @@ def __init__(self, d_token: int, initialization: str) -> None:
268268
self.weight = nn.Parameter(Tensor(d_token))
269269
initialization_.apply(self.weight, d_token)
270270

271-
def expand(self, *leading_dimensions: int) -> Tensor:
    """Expand (repeat) the underlying [CLS]-token to a tensor with the given leading dimensions.

    A possible use case is building a batch of [CLS]-tokens. See `CLSToken` for
    examples of usage.

    Note:
        Under the hood, the `torch.Tensor.expand` method is applied to the
        underlying :code:`weight` parameter, so gradients will be propagated as
        expected.

    Args:
        leading_dimensions: the additional new dimensions

    Returns:
        tensor of the shape :code:`(*leading_dimensions, len(self.weight))`
    """
    # Nothing to expand: hand back the raw 1-d weight.
    if not leading_dimensions:
        return self.weight
    # `Tensor.expand` prepends the requested leading dimensions to the 1-d
    # weight; -1 keeps the trailing `d_token` size unchanged. The result is
    # a view, so no data is copied and gradients flow through it.
    return self.weight.expand(*leading_dimensions, -1)
283293

284294
def forward(self, x: Tensor) -> Tensor:
    """Append self **to the end** of each item in the batch (see `CLSToken`)."""
    # One [CLS]-token per batch item, shaped (batch_size, 1, d_token),
    # concatenated after the existing tokens along the token axis.
    cls_tokens = self.expand(len(x), 1)
    return torch.cat([x, cls_tokens], dim=1)
287297

288298

289299
def _make_nn_module(module_type: ModuleType, *args) -> nn.Module:

0 commit comments

Comments
 (0)