
Commit 59e28c3

Cyrilvallez authored and ArthurZucker committed
Fix flex_attention in training mode (#35605)
* fix flex
* add test
* style
1 parent 7cf6230 commit 59e28c3

File tree

src/transformers/integrations/flex_attention.py
tests/test_modeling_common.py

2 files changed: +14 −1 lines changed


src/transformers/integrations/flex_attention.py (+1 −1)

@@ -27,7 +27,7 @@ def causal_mod(score, b, h, q_idx, kv_idx):
         if softcap is not None:
             score = softcap * torch.tanh(score / softcap)
         if causal_mask is not None:
-            score += causal_mask[b][0][q_idx][kv_idx]
+            score = score + causal_mask[b][0][q_idx][kv_idx]
         return score
 
     attn_output, attention_weights = flex_attention(
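Why the one-line change matters: the old score += causal_mask[b][0][q_idx][kv_idx] mutates the score tensor in place inside the score_mod callback, and once the model is in training mode (tensors tracked by autograd, callback evaluated under vmap) that in-place update can raise, whereas the out-of-place score = score + ... works in both inference and training. Below is a minimal sketch of the same pattern, not taken from the repository: the shapes, the bias tensor, and the device handling are illustrative assumptions, and it presumes a PyTorch release that ships torch.nn.attention.flex_attention (2.5 or later).

import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative sizes; the repository test exercises real model configs instead.
B, H, S, D = 1, 2, 16, 8
device = "cuda" if torch.cuda.is_available() else "cpu"

q = torch.randn(B, H, S, D, device=device, requires_grad=True)
k = torch.randn(B, H, S, D, device=device, requires_grad=True)
v = torch.randn(B, H, S, D, device=device, requires_grad=True)

# Additive mask indexed per (batch, head, query, key), mirroring causal_mask above.
bias = torch.zeros(B, 1, S, S, device=device)

def score_mod(score, b, h, q_idx, kv_idx):
    # Out-of-place add, as in the fixed causal_mod; an in-place `score += ...`
    # here is the pattern this commit removes.
    return score + bias[b][0][q_idx][kv_idx]

out = flex_attention(q, k, v, score_mod=score_mod)
out.sum().backward()  # gradients flow back to q, k and v
print(q.grad.shape)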

tests/test_modeling_common.py (+13 −0)

@@ -4790,6 +4790,19 @@ def test_forward_with_num_logits_to_keep(self):
         # Assert the last tokens are actually the same (except for the natural fluctuation due to order of FP ops)
         self.assertTrue(torch.allclose(all_logits[:, -1:, :], last_token_logits, atol=1e-5))
 
+    @require_torch_gpu
+    def test_flex_attention_with_grads(self):
+        for model_class in self.all_model_classes:
+            if not model_class._supports_flex_attn:
+                self.skipTest(reason="This model does not support flex attention")
+            config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
+            config._attn_implementation = "flex_attention"
+            model = model_class(config).to(device=torch_device, dtype=torch.float16)
+            self.assertTrue(model.config._attn_implementation == "flex_attention")
+
+            # If this does not raise an error, the test passes (see https://github.com/huggingface/transformers/pull/35605)
+            _ = model(inputs_dict["input_ids"].to(torch_device))
+
 
 global_rng = random.Random()
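A usage note, not part of the commit: tests/test_modeling_common.py holds the shared ModelTesterMixin, so the new test is collected through each model-specific suite whose model class sets _supports_flex_attn. On a GPU machine it can be selected with, for example, pytest tests/models/llama/test_modeling_llama.py -k test_flex_attention_with_grads; the Llama suite is only an illustrative choice.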
