
Commit df3ff5a

Merge branch 'main' into cb-tp2
2 parents eee5471 + 2d6815e commit df3ff5a

21 files changed

Lines changed: 20 additions & 40 deletions

docs/source/en/attention_interface.md

Lines changed: 10 additions & 3 deletions
@@ -130,18 +130,23 @@ model = AutoModelForImageTextToText.from_pretrained(

 Customize or create new attention functions by adding them to the attention registry with [`AttentionInterface.register`]. Models use these functions through the `attn_implementation` argument.

-This example customizes the attention function to print a statement for each layer.
+> [!WARNING]
+> Register a matching attention mask function when you register a custom attention function. If the custom `attn_implementation` name is not registered in [`AttentionMaskInterface`], Transformers skips mask creation and passes `attention_mask=None` to the attention layers. Your attention function must handle causal, padding, packing, or sliding-window constraints itself, or those constraints can be silently dropped.
+
+This example customizes the attention function to print a statement for each layer. It keeps the mask in the original implementation by registering `masking_utils.sdpa_mask` as the attention mask function.

 ```python
 import torch
-from transformers import AutoModelForCausalLM, AttentionInterface
+from transformers import AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface
 from transformers.integrations.sdpa_attention import sdpa_attention_forward
+from transformers.masking_utils import sdpa_mask

 def my_new_sdpa(*args, **kwargs):
     print("I just entered the attention computation")
     return sdpa_attention_forward(*args, **kwargs)

 AttentionInterface.register("my_new_sdpa", my_new_sdpa)
+AttentionMaskInterface.register("my_new_sdpa", sdpa_mask)  # must have the same name as the registered attention function

 model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="my_new_sdpa")
 model(torch.ones(1, 5, dtype=int))
@@ -151,8 +156,9 @@ You can also add new arguments to the attention function. Models supporting [`At

 ```python
 import torch
-from transformers import AutoModelForCausalLM, AttentionInterface
+from transformers import AutoModelForCausalLM, AttentionInterface, AttentionMaskInterface
 from transformers.integrations.sdpa_attention import sdpa_attention_forward
+from transformers.masking_utils import sdpa_mask

 def custom_attention(
     module: torch.nn.Module,  # required arg
@@ -168,6 +174,7 @@ def custom_attention(
     return attn_output, attn_weights  # attn_weights are optional here

 AttentionInterface.register("custom", custom_attention)
+AttentionMaskInterface.register("custom", sdpa_mask)  # to leave the existing mask untouched

 model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation="custom")
 model(torch.ones(1, 5, dtype=int), a_new_kwargs=..., another_new_kwargs=...)
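
Note on the warning added above: if no mask function is registered for a custom `attn_implementation`, the attention function receives `attention_mask=None` and must enforce masking on its own. The sketch below is an illustrative fallback, not part of this commit; the name `causal_only_sdpa` and the mask construction are assumptions, and it only restores causality (padding and packing are still ignored).

```python
import torch
from transformers import AttentionInterface
from transformers.integrations.sdpa_attention import sdpa_attention_forward


def causal_only_sdpa(module, query, key, value, attention_mask, **kwargs):
    # Without a matching AttentionMaskInterface entry, attention_mask arrives as None,
    # so rebuild a plain causal mask here instead of silently dropping causality.
    if attention_mask is None:
        q_len, kv_len = query.shape[-2], key.shape[-2]
        offset = kv_len - q_len  # accounts for previously cached tokens when decoding
        idx_q = torch.arange(q_len, device=query.device)[:, None]
        idx_k = torch.arange(kv_len, device=query.device)[None, :]
        visible = idx_k <= idx_q + offset  # True where the key position may be attended
        attention_mask = torch.zeros(q_len, kv_len, dtype=query.dtype, device=query.device)
        attention_mask = attention_mask.masked_fill(~visible, torch.finfo(query.dtype).min)
        attention_mask = attention_mask[None, None]  # broadcasts over batch and heads
    return sdpa_attention_forward(module, query, key, value, attention_mask, **kwargs)


AttentionInterface.register("causal_only_sdpa", causal_only_sdpa)
```

Registering `sdpa_mask` under the same name, as the documentation change above does, is the simpler and more complete option.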

src/transformers/generation/utils.py

Lines changed: 1 addition & 1 deletion
@@ -3778,7 +3778,7 @@ def _prefill(
             use_inputs_embeds = True
         if (cache := model_kwargs.get("past_key_values")) is not None:
             past_length = cache.get_seq_length()
-            # It will be sliced as input_embeds = inputs_embeds[:, -next_sequence_length:, :] in `prepare_inputs_for_generation`
+            # It will be sliced as inputs_embeds = inputs_embeds[:, -next_sequence_length:, :] in `prepare_inputs_for_generation`
             if use_inputs_embeds:
                 next_sequence_length = model_kwargs["inputs_embeds"].shape[1] - past_length
             else:

src/transformers/masking_utils.py

Lines changed: 0 additions & 8 deletions
@@ -20,7 +20,6 @@
 from .cache_utils import Cache
 from .configuration_utils import PreTrainedConfig
 from .utils import is_torch_xpu_available, logging
-from .utils.deprecation import deprecate_kwarg
 from .utils.generic import GeneralInterface, is_flash_attention_requested
 from .utils.import_utils import is_torch_flex_attn_available, is_torch_greater_or_equal, is_tracing

@@ -788,7 +787,6 @@ def find_packed_sequence_indices(position_ids: torch.Tensor) -> torch.Tensor | N
     return packed_sequence_mask


-@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
 def _preprocess_mask_arguments(
     config: PreTrainedConfig,
     inputs_embeds: torch.Tensor,
@@ -893,7 +891,6 @@ def _preprocess_mask_arguments(
     return False, attention_mask, packed_sequence_mask, q_length, kv_length, q_offset, kv_offset


-@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
 def create_causal_mask(
     config: PreTrainedConfig,
     inputs_embeds: torch.Tensor,
@@ -1019,7 +1016,6 @@ def create_causal_mask(
     return causal_mask


-@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
 def create_bidirectional_mask(
     config: PreTrainedConfig,
     inputs_embeds: torch.Tensor,
@@ -1110,7 +1106,6 @@ def create_bidirectional_mask(
     return attention_mask


-@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
 def create_sliding_window_causal_mask(
     config: PreTrainedConfig,
     inputs_embeds: torch.Tensor,
@@ -1237,7 +1232,6 @@ def create_sliding_window_causal_mask(
     return causal_mask


-@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
 def create_bidirectional_sliding_window_mask(
     config: PreTrainedConfig,
     inputs_embeds: torch.Tensor,
@@ -1324,7 +1318,6 @@ def create_bidirectional_sliding_window_mask(
     return attention_mask


-@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
 def create_chunked_causal_mask(
     config: PreTrainedConfig,
     inputs_embeds: torch.Tensor,
@@ -1453,7 +1446,6 @@
 }


-@deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
 def create_masks_for_generate(
     config: PreTrainedConfig,
     inputs_embeds: torch.Tensor,
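
For reference, the `deprecate_kwarg` decorator stripped throughout this file remapped the legacy `input_embeds` keyword to `inputs_embeds` and emitted a deprecation warning. Below is a simplified, illustrative sketch of that pattern, not the actual `transformers.utils.deprecation` implementation; with the decorator removed, callers must pass `inputs_embeds` directly, as the model updates below show.

```python
import functools
import warnings


def deprecate_kwarg(old_name, *, version, new_name):
    """Illustrative sketch: remap a deprecated keyword argument to its new name."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            if old_name in kwargs:
                warnings.warn(
                    f"`{old_name}` is deprecated and will be removed in v{version}; "
                    f"use `{new_name}` instead.",
                    FutureWarning,
                )
                # Forward the value under the new name and drop the old one.
                kwargs.setdefault(new_name, kwargs.pop(old_name))
            return func(*args, **kwargs)

        return wrapper

    return decorator
```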

src/transformers/models/bark/modeling_bark.py

Lines changed: 0 additions & 3 deletions
@@ -39,7 +39,6 @@
     is_torch_accelerator_available,
     logging,
 )
-from ...utils.deprecation import deprecate_kwarg
 from ..auto import AutoModel
 from .configuration_bark import (
     BarkCoarseConfig,
@@ -392,7 +391,6 @@ def get_input_embeddings(self):
     def set_input_embeddings(self, new_embeddings):
         self.input_embeds_layer = new_embeddings

-    @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
     @auto_docstring
     def forward(
         self,
@@ -990,7 +988,6 @@ def resize_token_embeddings(

         return model_embeds

-    @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
     @auto_docstring
     def forward(
         self,

src/transformers/models/biogpt/modeling_biogpt.py

Lines changed: 1 addition & 1 deletion
@@ -375,7 +375,7 @@ def forward(

         causal_mask = create_causal_mask(
             config=self.config,
-            input_embeds=inputs_embeds,
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             past_key_values=self_attn_cache,
         )

src/transformers/models/biogpt/modular_biogpt.py

Lines changed: 1 addition & 1 deletion
@@ -207,7 +207,7 @@ def forward(

         causal_mask = create_causal_mask(
             config=self.config,
-            input_embeds=inputs_embeds,
+            inputs_embeds=inputs_embeds,
             attention_mask=attention_mask,
             past_key_values=self_attn_cache,
         )

src/transformers/models/blt/modeling_blt.py

Lines changed: 0 additions & 2 deletions
@@ -38,7 +38,6 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import maybe_autocast, merge_with_config_defaults
 from ...utils.output_capturing import OutputRecorder, capture_outputs
 from .configuration_blt import (
@@ -806,7 +805,6 @@ def __init__(self, config: BltGlobalTransformerConfig):

         self.post_init()

-    @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
     def forward(
         self,
         inputs_embeds: torch.Tensor,

src/transformers/models/blt/modular_blt.py

Lines changed: 0 additions & 2 deletions
@@ -29,7 +29,6 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
-from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import maybe_autocast, merge_with_config_defaults
 from ...utils.output_capturing import OutputRecorder, capture_outputs
 from ..cohere2.modeling_cohere2 import rotate_half  # noqa: F401
@@ -740,7 +739,6 @@ def __init__(self, config: BltGlobalTransformerConfig):

         self.post_init()

-    @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
     def forward(
         self,
         inputs_embeds: torch.Tensor,

src/transformers/models/distilbert/modeling_distilbert.py

Lines changed: 0 additions & 2 deletions
@@ -48,7 +48,6 @@
     auto_docstring,
     logging,
 )
-from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import can_return_tuple, merge_with_config_defaults
 from ...utils.output_capturing import capture_outputs
 from .configuration_distilbert import DistilBertConfig
@@ -92,7 +91,6 @@ def __init__(self, config: PreTrainedConfig):
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )

-    @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
     def forward(
         self,
         input_ids: torch.Tensor,

src/transformers/models/gemma3/modeling_gemma3.py

Lines changed: 0 additions & 2 deletions
@@ -43,7 +43,6 @@
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check
-from ...utils.deprecation import deprecate_kwarg
 from ...utils.generic import maybe_autocast, merge_with_config_defaults
 from ...utils.output_capturing import capture_outputs
 from ..auto import AutoModel
@@ -1052,7 +1051,6 @@ def prepare_inputs_for_generation(
         return model_inputs

     @staticmethod
-    @deprecate_kwarg("input_embeds", version="5.6.0", new_name="inputs_embeds")
     def create_masks_for_generate(
         config: PreTrainedConfig,
         inputs_embeds: torch.Tensor,
