
Commit 1e3ddcb

ModernBERT bug fixes (#35404)
* bug fixes
* organize imports
* wrap cpu warning in reference_compile
* Avoid needing repad_logits_with_grad, always repad with grads when training
  I'm not 100% sure that the conditional with "or labels is None" makes sense, though - not sure what the intention is there. Perhaps we can remove that?
* Revert "Avoid needing repad_logits_with_grad, always repad with grads when training"
  This reverts commit cedcb4e.
* Fix grammar: keep -> keeps
* Propagate grammar fix with modular_model_converter

---------

Co-authored-by: Tom Aarsen <Cubiegamedev@gmail.com>
Co-authored-by: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
1 parent e97d7a5 commit 1e3ddcb

File tree

5 files changed: 53 additions, 19 deletions

- docs/source/en/_toctree.yml (+1, -1)
- docs/source/en/model_doc/modernbert.md (+2, -2)
- src/transformers/models/modernbert/configuration_modernbert.py (+5)
- src/transformers/models/modernbert/modeling_modernbert.py (+20, -8)
- src/transformers/models/modernbert/modular_modernbert.py (+25, -8)


docs/source/en/_toctree.yml (+1, -1)

@@ -505,7 +505,7 @@
      - local: model_doc/mobilebert
        title: MobileBERT
      - local: model_doc/modernbert
-       title: ModernBert
+       title: ModernBERT
      - local: model_doc/mpnet
        title: MPNet
      - local: model_doc/mpt

docs/source/en/model_doc/modernbert.md (+2, -2)

@@ -14,7 +14,7 @@ rendered properly in your Markdown viewer.

-->

-# ModernBert
+# ModernBERT

<div class="flex flex-wrap space-x-1">
  <a href="https://huggingface.co/models?filter=modernbert">
@@ -27,7 +27,7 @@ rendered properly in your Markdown viewer.

## Overview

-The ModernBert model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli.
+The ModernBERT model was proposed in [Smarter, Better, Faster, Longer: A Modern Bidirectional Encoder for Fast, Memory Efficient, and Long Context Finetuning and Inference](https://arxiv.org/abs/2412.13663) by Benjamin Warner, Antoine Chaffin, Benjamin Clavié, Orion Weller, Oskar Hallström, Said Taghadouini, Alexis Galalgher, Raja Bisas, Faisal Ladhak, Tom Aarsen, Nathan Cooper, Grifin Adams, Jeremy Howard and Iacopo Poli.

It is a refresh of the traditional encoder architecture, as used in previous models such as [BERT](https://huggingface.co/docs/transformers/en/model_doc/bert) and [RoBERTa](https://huggingface.co/docs/transformers/en/model_doc/roberta).
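Since the documentation page touched above describes a masked-language encoder, a quick fill-mask sketch helps place the change. This is illustrative only and not part of the diff; the checkpoint name "answerdotai/ModernBERT-base" and the "[MASK]" mask token are assumptions here rather than anything this commit specifies.

```python
# Illustrative fill-mask sketch (not part of this commit).
# Assumption: "answerdotai/ModernBERT-base" is a published ModernBERT checkpoint
# whose tokenizer uses "[MASK]" as the mask token.
import torch
from transformers import AutoModelForMaskedLM, AutoTokenizer

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForMaskedLM.from_pretrained(model_id)

inputs = tokenizer("The capital of France is [MASK].", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits

# Decode the highest-scoring token at the masked position.
mask_pos = (inputs.input_ids == tokenizer.mask_token_id).nonzero(as_tuple=True)[1]
print(tokenizer.decode(logits[0, mask_pos].argmax(dim=-1)))
```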

src/transformers/models/modernbert/configuration_modernbert.py (+5)

@@ -109,6 +109,9 @@ class ModernBertConfig(PretrainedConfig):
            the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
            shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
            be faster in some scenarios.
+        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
+            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
+            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.

    Examples:

@@ -164,6 +167,7 @@ def __init__(
        sparse_prediction=False,
        sparse_pred_ignore_index=-100,
        reference_compile=None,
+        repad_logits_with_grad=False,
        **kwargs,
    ):
        super().__init__(
@@ -203,6 +207,7 @@ def __init__(
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
+        self.repad_logits_with_grad = repad_logits_with_grad

        if self.classifier_pooling not in ["cls", "mean"]:
            raise ValueError(
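To show where the new flag surfaces for users, here is a minimal configuration sketch. It is not taken from the commit; only the `repad_logits_with_grad` and `reference_compile` names come from the diff above, and everything else is whatever `ModernBertConfig` ships as defaults.

```python
# Sketch: opt in to gradient-carrying logit repadding. The flag only matters when
# the model runs with Flash Attention 2 and labels are passed to forward().
from transformers import ModernBertConfig, ModernBertForMaskedLM

config = ModernBertConfig(
    repad_logits_with_grad=True,  # new flag from this commit; defaults to False
    reference_compile=False,      # optionally skip torch.compile (e.g. when debugging on CPU)
)
model = ModernBertForMaskedLM(config)
print(model.config.repad_logits_with_grad)  # True
```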

src/transformers/models/modernbert/modeling_modernbert.py (+20, -8)

@@ -20,6 +20,7 @@
# limitations under the License.

import math
+from contextlib import nullcontext
from typing import Dict, Optional, Tuple, Union

import torch
@@ -632,12 +633,14 @@ def _autoset_attn_implementation(
    ):
        # If the user didn't specify anything, try to use flash_attention_2 if available.
        # Otherwise we fall back to the default SDPA -> Eager from the super() method.
+        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
+        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
        if config._attn_implementation_internal is None:
            config._attn_implementation_internal = "flash_attention_2"
            try:
                return cls._check_and_enable_flash_attn_2(
                    config,
-                    torch_dtype=torch_dtype,
+                    torch_dtype=torch.float16,
                    device_map=device_map,
                    hard_check_only=False,
                    check_device_map=check_device_map,
@@ -647,7 +650,7 @@ def _autoset_attn_implementation(
        return super()._autoset_attn_implementation(
            config,
            use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch_dtype,
+            torch_dtype=torch.float16,
            device_map=device_map,
            check_device_map=check_device_map,
        )
@@ -672,6 +675,14 @@ def _maybe_set_compile(self):
            )
            self.config.reference_compile = False

+        if self.device.type == "cpu":
+            if self.config.reference_compile:
+                logger.warning_once(
+                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
+                    "Falling back to non-compiled mode."
+                )
+            self.config.reference_compile = False
+
        if self.config.reference_compile is None:
            self.config.reference_compile = is_triton_available()

@@ -763,8 +774,8 @@ def _pad_modernbert_output(
MODERNBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-            it.
+            Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
+            by default should you provide it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
@@ -790,7 +801,7 @@ def _pad_modernbert_output(
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers.
+            far-away tokens in the local attention layers when not using Flash Attention.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.
@@ -805,11 +816,11 @@ def _pad_modernbert_output(
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch. Used to pad the output tensors.
+            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
-            Sequence length of the input sequences. Used to pad the output tensors.
+            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
@@ -1128,8 +1139,9 @@ def forward(
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

        if self.config._attn_implementation == "flash_attention_2":
-            with torch.no_grad():
+            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
+
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
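The `with nullcontext() ... else torch.no_grad():` line above is the heart of the fix: repadding the logits tracks gradients only when `repad_logits_with_grad` is set or no labels were supplied. A standalone sketch of that idiom follows; it uses made-up shapes and is not the model's real repadding helper.

```python
# Standalone sketch of the conditional-gradient idiom (not the model code).
from contextlib import nullcontext

import torch


def repad(logits: torch.Tensor, keep_grad: bool) -> torch.Tensor:
    """Copy (batch, short_len, vocab) logits into a zero-padded (batch, 8, vocab) tensor."""
    # nullcontext() is a no-op, so autograd keeps tracking the copy;
    # torch.no_grad() detaches the padded copy and saves memory during training.
    ctx = nullcontext() if keep_grad else torch.no_grad()
    with ctx:
        padded = torch.zeros(logits.shape[0], 8, logits.shape[-1], dtype=logits.dtype)
        padded[:, : logits.shape[1]] = logits
    return padded


x = torch.randn(4, 5, 16, requires_grad=True)
print(repad(x, keep_grad=True).requires_grad)   # True
print(repad(x, keep_grad=False).requires_grad)  # False
```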

src/transformers/models/modernbert/modular_modernbert.py (+25, -8)

@@ -14,6 +14,7 @@
# limitations under the License.

import math
+from contextlib import nullcontext
from typing import Dict, Literal, Optional, Tuple, Union

import torch
@@ -141,6 +142,9 @@ class ModernBertConfig(PretrainedConfig):
            the model will be compiled if 1) `triton` is installed, 2) the model is not on MPS, 3) the model is not
            shared between devices, and 4) the model is not resized after initialization. If `True`, then the model may
            be faster in some scenarios.
+        repad_logits_with_grad (`bool`, *optional*, defaults to `False`):
+            When True, ModernBertForMaskedLM keeps track of the logits' gradient when repadding for output. This only
+            applies when using Flash Attention 2 with passed labels. Otherwise output logits always have a gradient.

    Examples:

@@ -196,6 +200,7 @@ def __init__(
        sparse_prediction=False,
        sparse_pred_ignore_index=-100,
        reference_compile=None,
+        repad_logits_with_grad=False,
        **kwargs,
    ):
        super().__init__(
@@ -235,6 +240,7 @@ def __init__(
        self.sparse_prediction = sparse_prediction
        self.sparse_pred_ignore_index = sparse_pred_ignore_index
        self.reference_compile = reference_compile
+        self.repad_logits_with_grad = repad_logits_with_grad

        if self.classifier_pooling not in ["cls", "mean"]:
            raise ValueError(
@@ -857,12 +863,14 @@ def _autoset_attn_implementation(
    ):
        # If the user didn't specify anything, try to use flash_attention_2 if available.
        # Otherwise we fall back to the default SDPA -> Eager from the super() method.
+        # ModernBert's FA2 implementation correctly handles non-fp16/bf16 dtypes, we don't
+        # need the FA2 warning for non-fp16/bf16 dtypes so we set fp16 for the FA2 check.
        if config._attn_implementation_internal is None:
            config._attn_implementation_internal = "flash_attention_2"
            try:
                return cls._check_and_enable_flash_attn_2(
                    config,
-                    torch_dtype=torch_dtype,
+                    torch_dtype=torch.float16,
                    device_map=device_map,
                    hard_check_only=False,
                    check_device_map=check_device_map,
@@ -872,7 +880,7 @@ def _autoset_attn_implementation(
        return super()._autoset_attn_implementation(
            config,
            use_flash_attention_2=use_flash_attention_2,
-            torch_dtype=torch_dtype,
+            torch_dtype=torch.float16,
            device_map=device_map,
            check_device_map=check_device_map,
        )
@@ -897,6 +905,14 @@ def _maybe_set_compile(self):
            )
            self.config.reference_compile = False

+        if self.device.type == "cpu":
+            if self.config.reference_compile:
+                logger.warning_once(
+                    "Compiling the model with `torch.compile` and using a `torch.cpu` device is not supported. "
+                    "Falling back to non-compiled mode."
+                )
+            self.config.reference_compile = False
+
        if self.config.reference_compile is None:
            self.config.reference_compile = is_triton_available()

@@ -916,8 +932,8 @@ def resize_token_embeddings(self, *args, **kwargs):
MODERNBERT_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
-            it.
+            Indices of input sequence tokens in the vocabulary. With Flash Attention 2.0, padding will be ignored
+            by default should you provide it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.
@@ -943,7 +959,7 @@ def resize_token_embeddings(self, *args, **kwargs):
        sliding_window_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding or far-away tokens. In ModernBert, only every few layers
            perform global attention, while the rest perform local attention. This mask is used to avoid attending to
-            far-away tokens in the local attention layers.
+            far-away tokens in the local attention layers when not using Flash Attention.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.
@@ -958,11 +974,11 @@ def resize_token_embeddings(self, *args, **kwargs):
        cu_seqlens (`torch.Tensor` of shape `(batch + 1,)`, *optional*):
            Cumulative sequence lengths of the input sequences. Used to index the unpadded tensors.
        max_seqlen (`int`, *optional*):
-            Maximum sequence length in the batch. Used to pad the output tensors.
+            Maximum sequence length in the batch excluding padding tokens. Used to unpad input_ids and pad output tensors.
        batch_size (`int`, *optional*):
            Batch size of the input sequences. Used to pad the output tensors.
        seq_len (`int`, *optional*):
-            Sequence length of the input sequences. Used to pad the output tensors.
+            Sequence length of the input sequences including padding tokens. Used to pad the output tensors.
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
@@ -1281,8 +1297,9 @@ def forward(
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size)

        if self.config._attn_implementation == "flash_attention_2":
-            with torch.no_grad():
+            with nullcontext() if self.config.repad_logits_with_grad or labels is None else torch.no_grad():
                logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
+
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
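The docstring edits in both files pin down how the unpadded Flash Attention 2 path is described: `max_seqlen` excludes padding while `seq_len` includes it, and `indices`/`cu_seqlens` address the flattened token stream. The hand-rolled illustration below sketches that bookkeeping; it is not the library's own unpadding helper, and the tensor names are chosen here for clarity.

```python
# Sketch of the unpadded layout used on the Flash Attention 2 path (illustrative only).
import torch

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # 2 sequences padded to length 4
batch_size, seq_len = attention_mask.shape                    # seq_len includes padding -> 4

indices = attention_mask.flatten().nonzero(as_tuple=False).flatten()  # positions of real tokens
seqlens = attention_mask.sum(dim=1)                                   # tensor([3, 2])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), seqlens.cumsum(dim=0)])  # tensor([0, 3, 5])
max_seqlen = int(seqlens.max())                                       # 3, excludes padding

hidden = torch.randn(batch_size, seq_len, 8)
unpadded = hidden.reshape(-1, 8)[indices]      # (total_tokens, hidden) = (5, 8)

# "Repadding" scatters the flat rows back into a (batch_size, seq_len, hidden) tensor.
repadded = torch.zeros(batch_size * seq_len, 8)
repadded[indices] = unpadded
repadded = repadded.reshape(batch_size, seq_len, 8)
```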
