Commit 42b8e79

tomaarsen authored and ArthurZucker committed
ModernBert: reuse GemmaRotaryEmbedding via modular + Integration tests (#35459)
* Introduce 5 integration tests for the 4 model classes + torch export
* ModernBert: reuse GemmaRotaryEmbedding via modular
* Revert #35589, keep rope_kwargs; rely on them in modular_modernbert
* Revert "Revert #35589, keep rope_kwargs; rely on them in modular_modernbert"
  This reverts commit 11b44b9.
* Don't set rope_kwargs; override 'self.rope_init_fn' call instead
1 parent e39c9f7 commit 42b8e79

File tree: 3 files changed, +178 -48 lines changed


src/transformers/models/modernbert/modeling_modernbert.py (+42 -14)
@@ -31,6 +31,7 @@
 from ...activations import ACT2FN
 from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
 from ...modeling_outputs import BaseModelOutput, MaskedLMOutput, SequenceClassifierOutput, TokenClassifierOutput
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
 from ...modeling_utils import PreTrainedModel
 from ...utils import (
     add_code_sample_docstrings,
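
For orientation (not part of the commit): the newly imported ROPE_INIT_FUNCTIONS is the shared registry of RoPE initialization functions that the rewritten embedding class dispatches on. A quick, hedged peek at it follows; the exact set of keys depends on the transformers version.

```python
# Hedged peek at the shared RoPE registry the new code path dispatches on.
# The available keys vary across transformers versions.
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS

print(sorted(ROPE_INIT_FUNCTIONS))  # e.g. ['default', 'dynamic', 'linear', 'llama3', 'longrope', 'yarn']

# Same call pattern as the new ModernBertRotaryEmbedding.__init__ below:
# no config, explicit dim/base keyword arguments.
inv_freq, attention_scaling = ROPE_INIT_FUNCTIONS["default"](None, None, dim=64, base=10000.0)
print(inv_freq.shape, attention_scaling)  # torch.Size([32]) 1.0
```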
@@ -241,30 +242,59 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
 
 
 class ModernBertRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
+    def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
         super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and config.rope_scaling is not None:
+            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+        inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    def _dynamic_frequency_update(self, position_ids, device):
+        """
+        dynamic RoPE layers should recompute `inv_freq` in the following situations:
+        1 - growing beyond the cached sequence length (allow scaling)
+        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
+        """
+        seq_len = torch.max(position_ids) + 1
+        if seq_len > self.max_seq_len_cached:  # growth
+            inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, seq_len=seq_len)
+            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
+            self.max_seq_len_cached = seq_len
 
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
+        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
+            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
+            self.max_seq_len_cached = self.original_max_seq_len
 
     @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        self.inv_freq.to(x.device)
+    def forward(self, x, position_ids):
+        if "dynamic" in self.rope_type:
+            self._dynamic_frequency_update(position_ids, device=x.device)
+
+        # Core RoPE block
         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
         position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
+        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
         device_type = x.device.type
         device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
         with torch.autocast(device_type=device_type, enabled=False):
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos()
             sin = emb.sin()
+
+        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
+        cos = cos * self.attention_scaling
+        sin = sin * self.attention_scaling
+
         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
 
 
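To make the rewritten forward easier to follow, here is a minimal, self-contained sketch of the quantity it produces in the plain case (rope_type "default", attention_scaling == 1.0, no dynamic update). The helper name and toy sizes are illustrative and not part of the commit.

```python
# Minimal sketch of the default RoPE cos/sin computation in the new forward.
# Helper name and toy sizes are illustrative only.
import torch

def rope_cos_sin(position_ids: torch.Tensor, dim: int = 64, base: float = 10000.0):
    # Inverse frequencies for the even channels: shape [dim // 2].
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    # Outer product of positions and inverse frequencies: [batch, seq_len, dim // 2].
    freqs = position_ids[..., None].float() * inv_freq
    # Duplicate to the full head dimension, as the module does with torch.cat.
    emb = torch.cat((freqs, freqs), dim=-1)  # [batch, seq_len, dim]
    return emb.cos(), emb.sin()

cos, sin = rope_cos_sin(torch.arange(8)[None, :])  # one sequence of 8 positions
print(cos.shape, sin.shape)  # torch.Size([1, 8, 64]) torch.Size([1, 8, 64])
```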
@@ -468,9 +498,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(
-                dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
-            )
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)
 
         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
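
The call site above now passes the config instead of max_position_embeddings. As a hedged usage sketch (internal API, import path and signature as in this commit, toy shapes chosen here), the updated module can be exercised on its own like this:

```python
# Hedged usage sketch of the updated embedding module (internal API; toy shapes).
import torch
from transformers import ModernBertConfig
from transformers.models.modernbert.modeling_modernbert import ModernBertRotaryEmbedding

config = ModernBertConfig()  # no rope_scaling set, so rope_type resolves to "default"
rotary = ModernBertRotaryEmbedding(config=config, dim=64, base=10000.0)

x = torch.randn(1, 8, 64)                # only the dtype/device of `x` are used
position_ids = torch.arange(8)[None, :]  # [batch, seq_len]
cos, sin = rotary(x, position_ids)
print(cos.shape, sin.shape)              # torch.Size([1, 8, 64]) torch.Size([1, 8, 64])
```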

src/transformers/models/modernbert/modular_modernbert.py (+6 -30)
@@ -41,7 +41,7 @@
     logging,
 )
 from ...utils.import_utils import is_triton_available
-from ..gemma.modeling_gemma import apply_rotary_pos_emb
+from ..gemma.modeling_gemma import GemmaRotaryEmbedding, apply_rotary_pos_emb
 
 
 if is_flash_attn_2_available():
@@ -504,32 +504,10 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return self.Wo(self.drop(self.act(input) * gate))
 
 
-class ModernBertRotaryEmbedding(nn.Module):
-    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
-        super().__init__()
-
-        self.dim = dim
-        self.max_position_embeddings = max_position_embeddings
-        self.base = base
-        inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float() / self.dim))
-        self.register_buffer("inv_freq", tensor=inv_freq, persistent=False)
-
-    @torch.no_grad()
-    def forward(self, x, position_ids, seq_len=None):
-        # x: [bs, num_attention_heads, seq_len, head_size]
-        self.inv_freq.to(x.device)
-        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-        position_ids_expanded = position_ids[:, None, :].float()
-        # Force float32 since bfloat16 loses precision on long contexts
-        # See https://github.com/huggingface/transformers/pull/29285
-        device_type = x.device.type
-        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):
-            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-            emb = torch.cat((freqs, freqs), dim=-1)
-            cos = emb.cos()
-            sin = emb.sin()
-        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+class ModernBertRotaryEmbedding(GemmaRotaryEmbedding):
+    def __init__(self, config: ModernBertConfig, dim: int, base: float, device: Optional[torch.device] = None):
+        super().__init__(self, config=config, device=device)
+        inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base)
 
 
 def eager_attention_forward(
@@ -698,9 +676,7 @@ def __init__(self, config: ModernBertConfig, layer_id: Optional[int] = None):
                 dim=self.head_dim, max_seqlen=max_position_embeddings, base=rope_theta
             )
         else:
-            self.rotary_emb = ModernBertRotaryEmbedding(
-                dim=self.head_dim, max_position_embeddings=max_position_embeddings, base=rope_theta
-            )
+            self.rotary_emb = ModernBertRotaryEmbedding(config=config, dim=self.head_dim, base=rope_theta)
 
         self.Wo = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
         self.out_drop = nn.Dropout(config.attention_dropout) if config.attention_dropout > 0.0 else nn.Identity()
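
For context: in the modular-transformers workflow, modular_modernbert.py is the hand-edited source and modeling_modernbert.py is regenerated from it, which is why the short subclass above expands into the full rotary-embedding class shown in the first file. The pattern it relies on, a base class that selects an init function from a registry and a child that re-invokes that function with explicit keyword arguments ("override 'self.rope_init_fn' call" from the commit message), is sketched below with made-up names; this is not the transformers implementation.

```python
# Generic sketch of the "override the rope_init_fn call" pattern. All names are
# made up for illustration; they are not the transformers classes.
from typing import Callable, Dict, Tuple
import torch
from torch import nn

def default_init(config, device=None, **kwargs) -> Tuple[torch.Tensor, float]:
    dim, base = kwargs["dim"], kwargs["base"]
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.float32, device=device) / dim))
    return inv_freq, 1.0  # (inverse frequencies, attention scaling)

INIT_REGISTRY: Dict[str, Callable] = {"default": default_init}

class BaseRotary(nn.Module):
    def __init__(self, rope_type: str = "default", device=None):
        super().__init__()
        self.rope_init_fn = INIT_REGISTRY[rope_type]  # the base class only picks the function

class ChildRotary(BaseRotary):
    def __init__(self, dim: int, base: float, device=None):
        super().__init__(device=device)
        # The child overrides the *call*: dim/base are passed explicitly instead of
        # being read from a config, mirroring `self.rope_init_fn(None, device, dim=dim, base=base)`.
        inv_freq, self.attention_scaling = self.rope_init_fn(None, device, dim=dim, base=base)
        self.register_buffer("inv_freq", inv_freq, persistent=False)

print(ChildRotary(dim=64, base=10000.0).inv_freq.shape)  # torch.Size([32])
```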

tests/models/modernbert/test_modeling_modernbert.py (+130 -4)
@@ -16,8 +16,9 @@
 import unittest
 
 import pytest
+from packaging import version
 
-from transformers import ModernBertConfig, is_torch_available
+from transformers import AutoTokenizer, ModernBertConfig, is_torch_available
 from transformers.models.auto import get_values
 from transformers.testing_utils import (
     CaptureLogger,
@@ -362,6 +363,131 @@ def test_flash_attn_2_conversion(self):
 
 @require_torch
 class ModernBertModelIntegrationTest(unittest.TestCase):
-    """
-    These still need to be written, once public models are available.
-    """
+    @slow
+    def test_inference_masked_lm(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        model = ModernBertForMaskedLM.from_pretrained(
+            "answerdotai/ModernBERT-base", reference_compile=False, attn_implementation="sdpa"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
+        inputs = tokenizer("Hello World!", return_tensors="pt")
+        with torch.no_grad():
+            output = model(**inputs)[0]
+        expected_shape = torch.Size((1, 5, 50368))
+        self.assertEqual(output.shape, expected_shape)
+
+        # compare the actual values for a slice.
+        expected_slice = torch.tensor(
+            [[[3.8387, -0.2017, 12.2839], [3.6300, 0.6869, 14.7123], [-5.1137, -3.8122, 11.9874]]]
+        )
+        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_no_head(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        model = ModernBertModel.from_pretrained(
+            "answerdotai/ModernBERT-base", reference_compile=False, attn_implementation="sdpa"
+        )
+        tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
+
+        inputs = tokenizer("Hello World!", return_tensors="pt")
+        with torch.no_grad():
+            output = model(**inputs)[0]
+        expected_shape = torch.Size((1, 5, 768))
+        self.assertEqual(output.shape, expected_shape)
+
+        # compare the actual values for a slice.
+        expected_slice = torch.tensor(
+            [[[0.3151, -0.6417, -0.7027], [-0.7834, -1.5810, 0.4576], [1.0614, -0.7268, -0.0871]]]
+        )
+        self.assertTrue(torch.allclose(output[:, :3, :3], expected_slice, atol=1e-4))
+
+    @slow
+    def test_inference_token_classification(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        model = ModernBertForTokenClassification.from_pretrained(
+            "hf-internal-testing/tiny-random-ModernBertForTokenClassification",
+            reference_compile=False,
+            attn_implementation="sdpa",
+        )
+        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-ModernBertForTokenClassification")
+
+        inputs = tokenizer("Hello World!", return_tensors="pt")
+        with torch.no_grad():
+            output = model(**inputs)[0]
+        expected_shape = torch.Size((1, 5, 2))
+        self.assertEqual(output.shape, expected_shape)
+
+        expected = torch.tensor(
+            [[[2.0159, 4.6569], [-0.9430, 3.1595], [-3.8770, 3.2653], [1.5752, 4.5167], [-1.6939, 1.2524]]]
+        )
+        self.assertTrue(torch.allclose(output, expected, atol=1e-4))
+
+    @slow
+    def test_inference_sequence_classification(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        model = ModernBertForSequenceClassification.from_pretrained(
+            "hf-internal-testing/tiny-random-ModernBertForSequenceClassification",
+            reference_compile=False,
+            attn_implementation="sdpa",
+        )
+        tokenizer = AutoTokenizer.from_pretrained(
+            "hf-internal-testing/tiny-random-ModernBertForSequenceClassification"
+        )
+
+        inputs = tokenizer("Hello World!", return_tensors="pt")
+        with torch.no_grad():
+            output = model(**inputs)[0]
+        expected_shape = torch.Size((1, 2))
+        self.assertEqual(output.shape, expected_shape)
+
+        expected = torch.tensor([[1.6466, 4.5662]])
+        self.assertTrue(torch.allclose(output, expected, atol=1e-4))
+
+    @slow
+    def test_export(self):
+        if version.parse(torch.__version__) < version.parse("2.4.0"):
+            self.skipTest(reason="This test requires torch >= 2.4 to run.")
+
+        bert_model = "answerdotai/ModernBERT-base"
+        device = "cpu"
+        attn_implementation = "sdpa"
+        max_length = 512
+
+        tokenizer = AutoTokenizer.from_pretrained(bert_model)
+        inputs = tokenizer(
+            "the man worked as a [MASK].",
+            return_tensors="pt",
+            padding="max_length",
+            max_length=max_length,
+        )
+
+        model = ModernBertForMaskedLM.from_pretrained(
+            bert_model,
+            device_map=device,
+            attn_implementation=attn_implementation,
+        )
+
+        logits = model(**inputs).logits
+        eg_predicted_mask = tokenizer.decode(logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask.split(), ["lawyer", "mechanic", "teacher", "doctor", "waiter"])
+
+        exported_program = torch.export.export(
+            model,
+            args=(inputs["input_ids"],),
+            kwargs={"attention_mask": inputs["attention_mask"]},
+            strict=True,
+        )
+
+        result = exported_program.module().forward(inputs["input_ids"], inputs["attention_mask"])
+        ep_predicted_mask = tokenizer.decode(result.logits[0, 6].topk(5).indices)
+        self.assertEqual(eg_predicted_mask, ep_predicted_mask)
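
Beyond the test suite, test_export demonstrates a reusable pattern: export the eager model with torch.export and check the exported program against eager outputs. A trimmed, hedged sketch of that round trip follows (same checkpoint as the test; it assumes torch >= 2.4, network access, and the SDPA/non-compiled settings the test uses).

```python
# Hedged sketch of the eager-vs-exported round trip performed by test_export.
# Assumes torch >= 2.4 and that the checkpoint can be downloaded.
import torch
from transformers import AutoTokenizer, ModernBertForMaskedLM

name = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(name)
model = ModernBertForMaskedLM.from_pretrained(name, reference_compile=False, attn_implementation="sdpa")

inputs = tokenizer(
    "the man worked as a [MASK].", return_tensors="pt", padding="max_length", max_length=512
)

# Eager top-5 predictions for the [MASK] position (index 6 in this tokenization).
eager_logits = model(**inputs).logits
print(tokenizer.decode(eager_logits[0, 6].topk(5).indices))

# Export with static example inputs, then run the exported module the same way.
exported = torch.export.export(
    model,
    args=(inputs["input_ids"],),
    kwargs={"attention_mask": inputs["attention_mask"]},
    strict=True,
)
exported_logits = exported.module()(inputs["input_ids"], inputs["attention_mask"]).logits
print(tokenizer.decode(exported_logits[0, 6].topk(5).indices))  # should match the eager output
```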
