diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index e7937fed254f..a120e15ac355 100755 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -86,6 +86,13 @@ "vipllava": "llava", "mistral3": "llava", "pp_chart2table": "llava", + "voxtral": "qwen2_audio", + "voxtral_realtime": "qwen2_audio", + "audioflamingo3": "qwen2_audio", + "glmasr": "qwen2_audio", + "musicflamingo": "qwen2_audio", + "gemma3n_text": "qwen3_5_text", + "qwen3_5_moe_text": "qwen3_5_text", "llava_next_video": "llava_next", "llava_onevision": "llava_next", # class-based mappings @@ -401,6 +408,29 @@ def _build_checkpoint_conversion_mapping(): WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"), WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), ], + "qwen2_audio": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], + "granite_speech": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming(source_patterns=r"^encoder", target_patterns="model.encoder"), + WeightRenaming(source_patterns=r"^projector", target_patterns="model.projector"), + ], + "vibevoice_asr": [ + WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"), + WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), + WeightRenaming( + source_patterns=r"^acoustic_tokenizer_encoder", target_patterns="model.acoustic_tokenizer_encoder" + ), + WeightRenaming( + source_patterns=r"^semantic_tokenizer_encoder", target_patterns="model.semantic_tokenizer_encoder" + ), + WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"), + ], "llava_next": [ WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"), WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"), diff --git a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py index 096e263d856d..cd81ef805205 100644 --- a/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/configuration_audioflamingo3.py @@ -100,6 +100,7 @@ class AudioFlamingo3Config(PreTrainedConfig): audio_token_id: int = 151669 projector_hidden_act: str = "gelu" projector_bias: bool = True + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py index 6f18fcc437ad..f4b8f79bca3c 100644 --- a/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modeling_audioflamingo3.py @@ -21,6 +21,7 @@ import math from collections.abc import Callable +from dataclasses import dataclass import torch from torch import nn @@ -31,13 +32,14 
@@ from ...masking_utils import create_bidirectional_mask from ...modeling_flash_attention_utils import FlashAttentionKwargs from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig @@ -256,6 +258,42 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel): _supports_sdpa = True +@dataclass +class AudioFlamingo3ModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for AudioFlamingo3 causal language model (or autoregressive) outputs. + """ +) +class AudioFlamingo3CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The audio model from AudioFlamingo3 without any head or projection on top. @@ -403,23 +441,21 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. 
""" ) -class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None +class AudioFlamingo3Model(AudioFlamingo3PreTrainedModel): _supports_attention_backend = True _tp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): @@ -428,18 +464,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -452,11 +476,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -509,77 +529,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | AudioFlamingo3ModelOutputWithPast: r""" input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
- - Example: - - ```python - >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/audio-flamingo-3-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversations = [ - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> {"type": "text", "text": "Transcribe the input speech."}, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav", - >>> }, - >>> ], - >>> } - >>> ], - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?", - >>> }, - >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"}, - >>> ], - >>> } - >>> ], - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversations, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device) - - >>> outputs = model.generate(**inputs, max_new_tokens=500) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."] - ```""" - + Mask to avoid performing attention on padding feature indices. + """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output @@ -589,17 +549,118 @@ def forward( ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return AudioFlamingo3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = None + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + _tp_plan = None + _pp_plan = None + + def __init__(self, config): + super().__init__(config) + self.model = AudioFlamingo3Model(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | AudioFlamingo3CausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
+ + Example: + + ```python + >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/audio-flamingo-3-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return AudioFlamingo3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -616,4 +677,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"] +__all__ = [ + "AudioFlamingo3ForConditionalGeneration", + "AudioFlamingo3PreTrainedModel", + "AudioFlamingo3Encoder", + "AudioFlamingo3Model", +] diff --git a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py index dfb2c1f54d35..f823bf929321 100644 --- a/src/transformers/models/audioflamingo3/modular_audioflamingo3.py +++ b/src/transformers/models/audioflamingo3/modular_audioflamingo3.py @@ -13,22 +13,30 @@ # See the License for the specific language governing permissions and # limitations under the License. 
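# Note on the new `qwen2_audio` conversion entries (added near the top of this diff and shared, via the
# model-type aliases, by voxtral, audioflamingo3, glmasr and musicflamingo): they exist so that legacy
# checkpoints saved with the old `ForConditionalGeneration` layout still load into the refactored
# `model.*` layout introduced above. The sketch below only approximates their effect with plain `re.sub`;
# the actual loading path uses `WeightRenaming` objects, not this code.
```python
import re

# Renaming patterns taken from the "qwen2_audio" block added to conversion_mapping.py in this diff.
RENAMES = [
    (r"^language_model.model", "model.language_model"),
    (r"^language_model.lm_head", "lm_head"),
    (r"^audio_tower", "model.audio_tower"),
    (r"^multi_modal_projector", "model.multi_modal_projector"),
]

legacy_keys = [
    "language_model.model.embed_tokens.weight",
    "language_model.lm_head.weight",
    "audio_tower.conv1.weight",
    "multi_modal_projector.linear_1.weight",  # hypothetical projector sub-module name, for illustration only
]

for key in legacy_keys:
    new_key = key
    for pattern, repl in RENAMES:
        new_key = re.sub(pattern, repl, new_key)
    print(f"{key} -> {new_key}")
# language_model.model.embed_tokens.weight -> model.language_model.embed_tokens.weight
# language_model.lm_head.weight           -> lm_head.weight
# audio_tower.conv1.weight                -> model.audio_tower.conv1.weight
# multi_modal_projector.linear_1.weight   -> model.multi_modal_projector.linear_1.weight
```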
+from dataclasses import dataclass + import torch from torch import nn from ...activations import ACT2FN from ...cache_utils import Cache from ...masking_utils import create_bidirectional_mask -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPooling, ModelOutput from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..qwen2_audio.modeling_qwen2_audio import ( Qwen2AudioEncoder, Qwen2AudioPreTrainedModel, ) -from ..voxtral.modeling_voxtral import VoxtralForConditionalGeneration, VoxtralMultiModalProjector +from ..voxtral.modeling_voxtral import ( + VoxtralForConditionalGeneration, + VoxtralModel, + VoxtralModelOutputWithPast, + VoxtralMultiModalProjector, +) from ..whisper.modeling_whisper import WhisperAttention, WhisperEncoderLayer from .configuration_audioflamingo3 import AudioFlamingo3Config @@ -48,6 +56,37 @@ class AudioFlamingo3PreTrainedModel(Qwen2AudioPreTrainedModel): pass +@dataclass +class AudioFlamingo3ModelOutputWithPast(VoxtralModelOutputWithPast): + pass + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for AudioFlamingo3 causal language model (or autoregressive) outputs. + """ +) +class AudioFlamingo3CausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The audio model from AudioFlamingo3 without any head or projection on top. @@ -138,10 +177,11 @@ def __init__(self, config: AudioFlamingo3Config): @auto_docstring( custom_intro=""" - The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. """ ) -class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): +class AudioFlamingo3Model(VoxtralModel): _supports_attention_backend = True _tp_plan = None _pp_plan = None @@ -162,11 +202,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). 
To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -195,77 +231,17 @@ def forward( position_ids: torch.LongTensor | None = None, past_key_values: Cache | None = None, inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ): r""" input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): - Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: - - - 1 for tokens that are **not masked**, - - 0 for tokens that are **masked**. - labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. - - Example: - - ```python - >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor - - >>> model_id = "nvidia/audio-flamingo-3-hf" - >>> processor = AutoProcessor.from_pretrained(model_id) - >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") - - >>> conversations = [ - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> {"type": "text", "text": "Transcribe the input speech."}, - >>> { - >>> "type": "audio", - >>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav", - >>> }, - >>> ], - >>> } - >>> ], - >>> [ - >>> { - >>> "role": "user", - >>> "content": [ - >>> { - >>> "type": "text", - >>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?", - >>> }, - >>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"}, - >>> ], - >>> } - >>> ], - >>> ] - - >>> inputs = processor.apply_chat_template( - >>> conversations, - >>> tokenize=True, - >>> add_generation_prompt=True, - >>> return_dict=True, - >>> ).to(model.device) - - >>> outputs = model.generate(**inputs, max_new_tokens=500) - - >>> decoded_outputs = processor.batch_decode( - >>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True - >>> ) - >>> print(decoded_outputs) - ["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."] - ```""" - + Mask to avoid performing attention on padding feature indices. 
+ """ if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_embeds = None if input_features is not None and input_ids is not None: audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output @@ -275,17 +251,104 @@ def forward( ) inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - outputs: CausalLMOutputWithPast = self.language_model( + outputs = self.language_model( inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + return AudioFlamingo3ModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model. + """ +) +@forward_base_model_attrs(version="5.7") +class AudioFlamingo3ForConditionalGeneration(VoxtralForConditionalGeneration): + _tp_plan = None + _pp_plan = None + _keep_in_fp32_modules_strict = None + + def __init__(self, config): + super().__init__(config) + self.model = AudioFlamingo3Model(config) + self.post_init() + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | AudioFlamingo3CausalLMOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. 
+ + Example: + + ```python + >>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor + + >>> model_id = "nvidia/audio-flamingo-3-hf" + >>> processor = AutoProcessor.from_pretrained(model_id) + >>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto") + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return AudioFlamingo3CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -302,4 +365,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"] +__all__ = [ + "AudioFlamingo3ForConditionalGeneration", + "AudioFlamingo3PreTrainedModel", + "AudioFlamingo3Encoder", + "AudioFlamingo3Model", +] diff --git a/src/transformers/models/auto/modeling_auto.py b/src/transformers/models/auto/modeling_auto.py index 2202cc773db0..7bced9ad55eb 100644 --- a/src/transformers/models/auto/modeling_auto.py +++ b/src/transformers/models/auto/modeling_auto.py @@ -52,7 +52,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("aria", "AriaModel"), ("aria_text", "AriaTextModel"), ("audio-spectrogram-transformer", "ASTModel"), - ("audioflamingo3", "AudioFlamingo3ForConditionalGeneration"), + ("audioflamingo3", "AudioFlamingo3Model"), ("audioflamingo3_encoder", "AudioFlamingo3Encoder"), ("autoformer", "AutoformerModel"), ("aya_vision", "AyaVisionModel"), @@ -200,7 +200,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("glm_ocr", "GlmOcrModel"), ("glm_ocr_text", "GlmOcrTextModel"), ("glm_ocr_vision", "GlmOcrVisionModel"), - ("glmasr", "GlmAsrForConditionalGeneration"), + ("glmasr", "GlmAsrModel"), ("glmasr_encoder", "GlmAsrEncoder"), ("glpn", "GLPNModel"), ("got_ocr2", "GotOcr2Model"), @@ -214,6 +214,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("gptj", "GPTJModel"), ("granite", "GraniteModel"), ("granite4_vision", "Granite4VisionModel"), + ("granite_speech", "GraniteSpeechModel"), ("granite_speech", "GraniteSpeechForConditionalGeneration"), ("granitemoe", "GraniteMoeModel"), ("granitemoehybrid", "GraniteMoeHybridModel"), @@ -318,7 +319,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("mpt", "MptModel"), ("mra", "MraModel"), ("mt5", "MT5Model"), - ("musicflamingo", "MusicFlamingoForConditionalGeneration"), + ("musicflamingo", "MusicFlamingoModel"), ("musicgen", "MusicgenModel"), ("musicgen_melody", "MusicgenMelodyModel"), ("mvp", "MvpModel"), @@ -377,6 +378,7 @@ class 
_BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("qwen2", "Qwen2Model"), ("qwen2_5_vl", "Qwen2_5_VLModel"), ("qwen2_5_vl_text", "Qwen2_5_VLTextModel"), + ("qwen2_audio", "Qwen2AudioModel"), ("qwen2_audio_encoder", "Qwen2AudioEncoder"), ("qwen2_moe", "Qwen2MoeModel"), ("qwen2_vl", "Qwen2VLModel"), @@ -472,7 +474,7 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("vibevoice_acoustic_tokenizer", "VibeVoiceAcousticTokenizerModel"), ("vibevoice_acoustic_tokenizer_decoder", "VibeVoiceAcousticTokenizerDecoderModel"), ("vibevoice_acoustic_tokenizer_encoder", "VibeVoiceAcousticTokenizerEncoderModel"), - ("vibevoice_asr", "VibeVoiceAsrForConditionalGeneration"), + ("vibevoice_asr", "VibeVoiceAsrModel"), ("video_llama_3", "VideoLlama3Model"), ("video_llama_3_vision", "VideoLlama3VisionModel"), ("video_llava", "VideoLlavaModel"), @@ -488,9 +490,9 @@ class _BaseModelWithGenerate(PreTrainedModel, GenerationMixin): ("vits", "VitsModel"), ("vivit", "VivitModel"), ("vjepa2", "VJEPA2Model"), - ("voxtral", "VoxtralForConditionalGeneration"), + ("voxtral", "VoxtralModel"), ("voxtral_encoder", "VoxtralEncoder"), - ("voxtral_realtime", "VoxtralRealtimeForConditionalGeneration"), + ("voxtral_realtime", "VoxtralRealtimeModel"), ("voxtral_realtime_encoder", "VoxtralRealtimeEncoder"), ("voxtral_realtime_text", "VoxtralRealtimeTextModel"), ("wav2vec2", "Wav2Vec2Model"), diff --git a/src/transformers/models/glmasr/configuration_glmasr.py b/src/transformers/models/glmasr/configuration_glmasr.py index c3d320bb1db4..c89379ead3e7 100644 --- a/src/transformers/models/glmasr/configuration_glmasr.py +++ b/src/transformers/models/glmasr/configuration_glmasr.py @@ -101,6 +101,7 @@ class GlmAsrConfig(PreTrainedConfig): text_config: dict | PreTrainedConfig | None = None audio_token_id: int = 59260 projector_hidden_act: str = "gelu" + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/glmasr/modeling_glmasr.py b/src/transformers/models/glmasr/modeling_glmasr.py index f2c68e56df71..fa4a50e97555 100644 --- a/src/transformers/models/glmasr/modeling_glmasr.py +++ b/src/transformers/models/glmasr/modeling_glmasr.py @@ -19,6 +19,7 @@ # limitations under the License. 
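# A minimal usage sketch of the refactored AudioFlamingo3 classes defined above: the head class now wraps an
# `AudioFlamingo3Model` (`self.model`) plus an explicit `lm_head`, `AutoModel` resolves "audioflamingo3" to the
# new base model, and the forward output additionally exposes the projected audio embeddings. The checkpoint id,
# chat-template call and audio URL are reused from the docstring example trimmed earlier in this diff; treat the
# snippet as an illustration of the new class layout under those assumptions, not as added test coverage.
```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Transcribe the input speech."},
            {
                "type": "audio",
                "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
            },
        ],
    }
]
inputs = processor.apply_chat_template(
    [conversation], tokenize=True, add_generation_prompt=True, return_dict=True
).to(model.device)

outputs = model(**inputs)                      # AudioFlamingo3CausalLMOutputWithPast
print(type(model.model).__name__)              # AudioFlamingo3Model: audio tower + projector + LM backbone
print(outputs.logits.shape)                    # lm_head applied to the base model's last_hidden_state
print(outputs.audio_hidden_states.shape)       # projected audio embeddings, newly surfaced by this refactor

generated = model.generate(**inputs, max_new_tokens=128)
print(processor.batch_decode(generated[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```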
from collections.abc import Callable +from dataclasses import dataclass from typing import Optional from ...activations import ACT2FN @@ -26,14 +27,20 @@ from ...generation import GenerationMixin from ...integrations import use_kernelized_func from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + CausalLMOutputWithPast, + ModelOutput, +) from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_glmasr import GlmAsrConfig, GlmAsrEncoderConfig @@ -349,25 +356,32 @@ def forward(self, audio_features): return hidden_states +@dataclass +class GlmAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + @auto_docstring( custom_intro=""" The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. """ ) -class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None +class GlmAsrModel(GlmAsrPreTrainedModel): _supports_attention_backend = True _tp_plan = None _pp_plan = None + _keep_in_fp32_modules_strict = None def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = GlmAsrMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): @@ -376,18 +390,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="Compute audio embeddings from log-mel input features using the audio encoder and multi-modal projector." @@ -400,11 +402,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). 
To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padded feature indices. """ @@ -450,6 +448,114 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | GlmAsrModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. + """ + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + return GlmAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for GlmAsr causal language model (or autoregressive) outputs. + """ +) +class GlmAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss (for next-token prediction). + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores of the language modeling head. + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Hidden states of the audio encoder after projection. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class GlmAsrForConditionalGeneration(GlmAsrPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = None + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + _tp_plan = None + _pp_plan = None + + def __init__(self, config): + super().__init__(config) + self.model = GlmAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, input_features, input_features_mask, **kwargs): + return self.model.get_audio_features(input_features, input_features_mask, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -494,30 +600,36 @@ def forward( >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) >>> print(decoded_outputs) ```""" - - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - - outputs: CausalLMOutputWithPast = self.language_model( - inputs_embeds=inputs_embeds, + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + input_features_mask=input_features_mask, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, - labels=labels, + inputs_embeds=inputs_embeds, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return GlmAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, + ) def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs): input_features = kwargs.pop("input_features", None) @@ -534,4 +646,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, return model_inputs -__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrPreTrainedModel"] +__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrModel", "GlmAsrPreTrainedModel"] diff --git a/src/transformers/models/glmasr/modular_glmasr.py b/src/transformers/models/glmasr/modular_glmasr.py index 2c6085eb3a18..d836d89b5625 100644 --- a/src/transformers/models/glmasr/modular_glmasr.py +++ b/src/transformers/models/glmasr/modular_glmasr.py @@ -25,10 
+25,12 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torch_available, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs from ..audioflamingo3.modeling_audioflamingo3 import ( AudioFlamingo3ForConditionalGeneration, + AudioFlamingo3Model, AudioFlamingo3MultiModalProjector, AudioFlamingo3PreTrainedModel, ) @@ -356,7 +358,7 @@ def __init__(self, config: GlmAsrConfig): The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. """ ) -class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): +class GlmAsrModel(AudioFlamingo3Model): _supports_attention_backend = True @can_return_tuple @@ -387,6 +389,19 @@ def get_audio_features( return audio_outputs + +@auto_docstring( + custom_intro=""" + The GlmAsr model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Llama language model. + """ +) +@forward_base_model_attrs(version="5.7") +class GlmAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): + def __init__(self, config): + super().__init__(config) + self.model = GlmAsrModel(config) + self.post_init() + def forward( self, input_ids: torch.LongTensor | None = None, @@ -442,4 +457,10 @@ def forward( ) -__all__ = ["GlmAsrEncoder", "GlmAsrForConditionalGeneration", "GlmAsrProcessor", "GlmAsrPreTrainedModel"] +__all__ = [ + "GlmAsrEncoder", + "GlmAsrForConditionalGeneration", + "GlmAsrModel", + "GlmAsrProcessor", + "GlmAsrPreTrainedModel", +] diff --git a/src/transformers/models/musicflamingo/configuration_musicflamingo.py b/src/transformers/models/musicflamingo/configuration_musicflamingo.py index 7eff8861558a..a733f73004c5 100644 --- a/src/transformers/models/musicflamingo/configuration_musicflamingo.py +++ b/src/transformers/models/musicflamingo/configuration_musicflamingo.py @@ -66,6 +66,7 @@ class MusicFlamingoConfig(PreTrainedConfig): audio_token_id: int = 151669 projector_hidden_act: str = "gelu" projector_bias: bool = True + tie_word_embeddings: bool = False audio_bos_token_id: int = 151670 audio_eos_token_id: int = 151671 diff --git a/src/transformers/models/musicflamingo/modeling_musicflamingo.py b/src/transformers/models/musicflamingo/modeling_musicflamingo.py index a9e05470662d..5ceed90170ed 100644 --- a/src/transformers/models/musicflamingo/modeling_musicflamingo.py +++ b/src/transformers/models/musicflamingo/modeling_musicflamingo.py @@ -20,6 +20,7 @@ # limitations under the License. 
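# Sanity-check sketch for the new explicit `lm_head` and the `tie_word_embeddings` config defaults touched in
# this diff (audioflamingo3 and glmasr default to `True`, musicflamingo to `False`). With tying enabled,
# `_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}` should make the head
# share storage with the embedding table; this is an expectation derived from the diff, not an added test.
```python
from transformers import AudioFlamingo3ForConditionalGeneration

model = AudioFlamingo3ForConditionalGeneration.from_pretrained("nvidia/audio-flamingo-3-hf")

embed = model.model.language_model.embed_tokens.weight
head = model.lm_head.weight
# Expected True here, since AudioFlamingo3Config now defaults to tie_word_embeddings=True.
print(embed.data_ptr() == head.data_ptr())
# MusicFlamingoConfig defaults to tie_word_embeddings=False, so the analogous check on a
# MusicFlamingo checkpoint is expected to report an untied (separate) lm_head.
```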
from collections.abc import Callable +from dataclasses import dataclass from math import pi from typing import Optional @@ -29,12 +30,12 @@ from ...activations import ACT2FN from ...cache_utils import Cache from ...generation import GenerationMixin -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, CausalLMOutputWithPast from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available, torch_compilable_check -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_musicflamingo import MusicFlamingoConfig @@ -150,6 +151,16 @@ def _init_weights(self, module): init.copy_(module.position_angles, buffer_value) +@dataclass +class MusicFlamingoModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + class MusicFlamingoMultiModalProjector(nn.Module): """ Audio adaptor (small MLP) that projects MusicFlamingoEncoder features @@ -173,6 +184,134 @@ def forward(self, audio_features): return hidden_states +@auto_docstring( + custom_intro=""" + The MusicFlamingo model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model), + without a language modeling head. + """ +) +class MusicFlamingoModel(MusicFlamingoPreTrainedModel): + _supports_attention_backend = True + _tp_plan = None + _pp_plan = None + _keep_in_fp32_modules_strict = None + + def __init__(self, config): + super().__init__(config) + self.audio_tower = AutoModel.from_config(config.audio_config) + self.language_model = AutoModel.from_config(config.text_config) + self.multi_modal_projector = MusicFlamingoMultiModalProjector(config) + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + + @can_return_tuple + @auto_docstring( + custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." + ) + def get_audio_features( + self, + input_features: torch.FloatTensor, + input_features_mask: torch.Tensor, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | BaseModelOutputWithPooling: + r""" + input_features (`torch.FloatTensor`): + Float values of mel features extracted from the raw speech waveform. + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padded feature indices. 
+ """ + + audio_output = self.audio_tower( + input_features, input_features_mask=input_features_mask, return_dict=True, **kwargs + ) + audio_embeds = self.multi_modal_projector(audio_output.last_hidden_state) + + # Mask according to the audio tower output lengths, accounting for both conv downsampling and final avg pooling + input_lengths = input_features_mask.sum(-1).to(torch.long) + _, post_lengths = self.audio_tower._get_feat_extract_output_lengths(input_lengths) + valid_mask = torch.arange(audio_embeds.shape[1], device=post_lengths.device)[None, :] < post_lengths[:, None] + audio_output.pooler_output = audio_embeds[valid_mask.to(audio_embeds.device)] + + return audio_output + + def get_placeholder_mask( + self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor + ): + """ + Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is + equal to the length of multimodal features. If the lengths are different, an error is raised. + """ + if input_ids is None: + special_audio_mask = inputs_embeds == self.get_input_embeddings()( + torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + ) + special_audio_mask = special_audio_mask.all(-1) + else: + special_audio_mask = input_ids == self.config.audio_token_id + + n_audio_tokens = special_audio_mask.sum() + n_audio_features = audio_features.shape[0] + special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) + torch_compilable_check( + inputs_embeds[special_audio_mask].numel() == audio_features.numel(), + f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + ) + return special_audio_mask + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + input_features_mask: torch.Tensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | MusicFlamingoModelOutputWithPast: + r""" + input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. 
+ """ + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + use_cache=use_cache, + **kwargs, + ) + + return MusicFlamingoModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + def rotate_half(x): x = x.reshape(*x.shape[:-1], -1, 2) x1, x2 = x.unbind(dim=-1) @@ -200,38 +339,28 @@ def apply_rotary_time_emb(hidden_states, cos, sin): ) class MusicFlamingoForConditionalGeneration(MusicFlamingoPreTrainedModel, GenerationMixin): _keep_in_fp32_modules_strict = None - _supports_attention_backend = True + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} _tp_plan = None _pp_plan = None def __init__(self, config: MusicFlamingoConfig): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.multi_modal_projector = MusicFlamingoMultiModalProjector(config) + self.model = MusicFlamingoModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) self.pos_emb = MusicFlamingoRotaryEmbedding(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): - return self.language_model.get_input_embeddings() + return self.model.get_input_embeddings() def set_input_embeddings(self, value): - self.language_model.set_input_embeddings(value) + self.model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() + def get_output_embeddings(self) -> nn.Module: + return self.lm_head def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() + self.lm_head = new_embeddings @can_return_tuple @auto_docstring( @@ -269,30 +398,6 @@ def get_audio_features( return audio_output - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. 
- """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - @can_return_tuple @auto_docstring def forward( @@ -442,4 +547,4 @@ def _build_audio_timestamps( return window_indices.unsqueeze(1) * max_post_length * audio_embed_frame_step + frame_offsets -__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoPreTrainedModel"] +__all__ = ["MusicFlamingoForConditionalGeneration", "MusicFlamingoModel", "MusicFlamingoPreTrainedModel"] diff --git a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py index 6aec9eace900..749f24123bb4 100644 --- a/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/configuration_qwen2_audio.py @@ -98,6 +98,7 @@ class Qwen2AudioConfig(PreTrainedConfig): audio_config: dict | PreTrainedConfig | None = None text_config: dict | PreTrainedConfig | None = None audio_token_index: int = 151646 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py index cbdf67cfd29d..f8079c133d28 100644 --- a/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py +++ b/src/transformers/models/qwen2_audio/modeling_qwen2_audio.py @@ -25,19 +25,44 @@ from ...generation import GenerationMixin from ...masking_utils import create_bidirectional_mask from ...modeling_layers import GradientCheckpointingLayer -from ...modeling_outputs import BaseModelOutput, ModelOutput +from ...modeling_outputs import BaseModelOutput, BaseModelOutputWithPast, ModelOutput from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, is_torchdynamo_compiling, logging, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import can_return_tuple, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_qwen2_audio import Qwen2AudioConfig, Qwen2AudioEncoderConfig logger = logging.get_logger(__name__) +@auto_docstring( + custom_intro=""" + Base class for Qwen2Audio outputs, with hidden states and attentions. + """ +) +class Qwen2AudioModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache). 
+ + Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see + `past_key_values` input) to speed up sequential decoding. + attention_mask (`torch.FloatTensor`, *optional*): + Attention mask, potentially updated by the audio merging logic so that audio tokens are unmasked. + labels (`torch.LongTensor`, *optional*): + Labels, potentially re-aligned by the legacy audio merging logic. Returned so the language-modeling + head can compute the loss against the expanded sequence. + """ + + attention_mask: torch.FloatTensor | None = None + labels: torch.LongTensor | None = None + + +@dataclass @auto_docstring( custom_intro=""" Base class for Qwen2Audio causal language model (or autoregressive) outputs. @@ -394,17 +419,17 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The QWEN2AUDIO model which consists of a audio backbone and a language model. + The Qwen2Audio model which consists of an audio backbone and a language model, without a language modeling head. """ ) -class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): +class Qwen2AudioModel(Qwen2AudioPreTrainedModel): def __init__(self, config: Qwen2AudioConfig): super().__init__(config) self.audio_tower = AutoModel.from_config(config.audio_config) # Usually a `Qwen2AudioEncoder` instance self.multi_modal_projector = Qwen2AudioMultiModalProjector(config) self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.pad_token_id = ( self.config.text_config.pad_token_id if self.config.text_config.pad_token_id is not None else -1 @@ -428,18 +453,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - def _merge_input_ids_with_audio_features( self, audio_features, num_audio_tokens, inputs_embeds, input_ids, attention_mask, labels ): @@ -651,7 +664,7 @@ def forward( labels: torch.LongTensor | None = None, use_cache: bool | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple | Qwen2AudioCausalLMOutputWithPast: + ) -> tuple | Qwen2AudioModelOutputWithPast: r""" feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: @@ -659,32 +672,9 @@ def forward( - 1 for tokens that are **not masked**, - 0 for tokens that are **masked**. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): - Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., - config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored - (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. 
- - Example: - - ```python - >>> from io import BytesIO - >>> from urllib.request import urlopen - >>> import librosa - >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration - - >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B") - >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B") - - >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:" - >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" - >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=self.processor.feature_extractor.sampling_rate) - - >>> inputs = processor(text=prompt, audio=audio, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(**inputs, max_length=30) - >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Generate the caption in English: Glass is breaking." - ```""" + Labels kept in the signature for the legacy merge path that may re-align them with audio tokens. + The loss is not computed here; `Qwen2AudioForConditionalGeneration` is responsible for that. + """ target_device = self.audio_tower.device @@ -767,7 +757,116 @@ def forward( **kwargs, ) - logits = outputs.logits + return Qwen2AudioModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + attention_mask=attention_mask, + labels=labels, + ) + + +@auto_docstring( + custom_intro=""" + The QWEN2AUDIO model which consists of an audio backbone and a language model. + """ +) +@forward_base_model_attrs(version="5.7") +class Qwen2AudioForConditionalGeneration(Qwen2AudioPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: Qwen2AudioConfig): + super().__init__(config) + self.model = Qwen2AudioModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + @property + def padding_side(self): + return self.model.padding_side + + @padding_side.setter + def padding_side(self, padding_side: str): + self.model.padding_side = padding_side + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + feature_attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | Qwen2AudioCausalLMOutputWithPast: + r""" + feature_attention_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`): + Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`: + + - 1 for tokens that are **not masked**, + - 0 for tokens that are **masked**. 
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*): + Labels for computing the masked language modeling loss. Indices should either be in `[0, ..., + config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored + (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`. + + Example: + + ```python + >>> from io import BytesIO + >>> from urllib.request import urlopen + >>> import librosa + >>> from transformers import AutoProcessor, Qwen2AudioForConditionalGeneration + + >>> model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B") + >>> processor = AutoProcessor.from_pretrained("Qwen/Qwen2-Audio-7B") + + >>> prompt = "<|audio_bos|><|AUDIO|><|audio_eos|>Generate the caption in English:" + >>> url = "https://qianwen-res.oss-cn-beijing.aliyuncs.com/Qwen2-Audio/audio/glass-breaking-151256.mp3" + >>> audio, _ = librosa.load(BytesIO(urlopen(url).read()), sr=self.processor.feature_extractor.sampling_rate) + + >>> inputs = processor(text=prompt, audio=audio, return_tensors="pt") + + >>> # Generate + >>> generate_ids = model.generate(**inputs, max_length=30) + >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] + "Generate the caption in English: Glass is breaking." + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + feature_attention_mask=feature_attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + labels=labels, + use_cache=use_cache, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + logits = self.lm_head(hidden_states) + attention_mask = outputs.attention_mask + labels = outputs.labels if outputs.labels is not None else labels loss = None if labels is not None: @@ -809,4 +908,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder"] +__all__ = ["Qwen2AudioForConditionalGeneration", "Qwen2AudioPreTrainedModel", "Qwen2AudioEncoder", "Qwen2AudioModel"] diff --git a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py index a673a5845871..4d56a948eda1 100644 --- a/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/configuration_vibevoice_asr.py @@ -75,6 +75,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig): audio_bos_token_id: int = 151646 audio_eos_token_id: int = 151647 acoustic_tokenizer_chunk_size: int = 1440000 + tie_word_embeddings: bool = False def __post_init__(self, **kwargs): if isinstance(self.acoustic_tokenizer_encoder_config, dict): diff --git a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py index b66dd15b2cb1..0a412957819b 100644 --- a/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modeling_vibevoice_asr.py @@ -17,6 +17,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
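Before moving on to the VibeVoice ASR modeling changes, a quick illustration of what the Qwen2Audio split above means from user code. This is a hedged sketch, not part of the diff: it assumes the refactored classes behave like the library's other `*Model` / `*ForConditionalGeneration` pairs, reuses the `Qwen/Qwen2-Audio-7B` checkpoint already cited in the docstring example, and assumes `tie_word_embeddings=True` is honoured when that checkpoint is loaded.

```python
from transformers import Qwen2AudioForConditionalGeneration

model = Qwen2AudioForConditionalGeneration.from_pretrained("Qwen/Qwen2-Audio-7B")

# The generation class now wraps the multimodal backbone under `.model` and owns its
# own `lm_head`, instead of delegating to an inner AutoModelForCausalLM.
print(type(model.model).__name__)   # expected: Qwen2AudioModel
print(model.lm_head)                # nn.Linear mapping hidden_size -> vocab_size, bias=False

# With the new `tie_word_embeddings=True` default, `_tied_weights_keys` ties the head
# to `model.language_model.embed_tokens`, so both should share the same storage.
tied = model.lm_head.weight.data_ptr() == model.model.language_model.embed_tokens.weight.data_ptr()
print("lm_head tied to embed_tokens:", tied)
```

The base `Qwen2AudioModel` can also be used on its own when only `last_hidden_state` (rather than logits) is needed, mirroring how the other audio models in this diff are split.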
+from dataclasses import dataclass + import torch from torch import nn @@ -25,17 +27,12 @@ from ...cache_utils import Cache from ...generation import GenerationMixin from ...integrations import use_kernel_forward_from_hub -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput from ...modeling_utils import PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - torch_compilable_check, -) -from ..auto import AutoModel, AutoModelForCausalLM +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling +from ...utils.deprecation import forward_base_model_attrs +from ..auto import AutoModel from .configuration_vibevoice_asr import VibeVoiceAsrConfig @@ -255,27 +252,71 @@ def _init_weights(self, module): init.constant_(module.ffn_gamma, self.config.layer_scale_init_value) +@dataclass @auto_docstring( custom_intro=""" - The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + Base class for VibeVoice ASR outputs, with hidden states and attentions. """ ) -class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None +class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VibeVoice ASR causal language model outputs. + """ +) +class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores. + past_key_values (`Cache`, *optional*): + Cache instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model), + without a language modeling head. 
+ """ +) +class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): _supports_attention_backend = True - _tp_plan = None - _pp_plan = None def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.language_model = AutoModelForCausalLM.from_config(config.text_config) - self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config) - - # Initialize weights and apply final processing + self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) self.post_init() + # Acoustic/semantic tokenizers are run under no_grad in `get_audio_features`; freeze + # their parameters so grad-checkpointing and training sanity checks don't flag them. + for p in self.acoustic_tokenizer_encoder.parameters(): + p.requires_grad_(False) + for p in self.semantic_tokenizer_encoder.parameters(): + p.requires_grad_(False) def get_input_embeddings(self): return self.language_model.get_input_embeddings() @@ -283,18 +324,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.") def get_audio_features( @@ -303,17 +332,15 @@ def get_audio_features( padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> tuple | BaseModelOutputWithPooling: + ): r""" input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`): - Input audio tensor. Audio should be sampled at 24kHz. + Input audio tensor sampled at 24kHz. padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`, - but can be modified to fit the available memory. + Size of audio chunks to process at once through the tokenizers. 
""" - if acoustic_tokenizer_chunk_size is None: acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size else: @@ -358,7 +385,6 @@ def get_audio_features( combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents) if padding_mask is not None: - # Adjust padding mask according to tokenizer compression num_audio_tokens = torch.ceil( padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length ).to(torch.int64) @@ -369,29 +395,86 @@ def get_audio_features( return BaseModelOutputWithPooling(last_hidden_state=acoustic_latents, pooler_output=combined_features) - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + input_values: torch.FloatTensor | None = None, + padding_mask: torch.BoolTensor | None = None, + acoustic_tokenizer_chunk_size: int | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VibeVoiceAsrModelOutputWithPast: + r""" + padding_mask (): + + acoustic_tokenizer_chunk_size (): + """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_values is not None and input_ids is not None: + audio_embeds = self.get_audio_features( + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + ).pooler_output + + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + + return VibeVoiceAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, ) - return special_audio_mask + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: VibeVoiceAsrConfig): + super().__init__(config) + self.model = VibeVoiceAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) @can_return_tuple @auto_docstring @@ -404,14 +487,15 @@ def forward( input_values: torch.FloatTensor | None = None, padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast: r""" padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks processed by the acoustic and semantic tokenizers. Defaults to - `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory. + Size of audio chunks processed by the acoustic and semantic tokenizers. Example: @@ -421,33 +505,35 @@ def forward( >>> model_id = "microsoft/VibeVoice-ASR-HF" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto") - - >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> outputs = model.generate(**inputs) - - >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) - >>> print(decoded_outputs) ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + **kwargs, + ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_values is not None and input_ids is not None: - audio_embeds = self.get_audio_features( - input_values=input_values, - padding_mask=padding_mask, - acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, - ).pooler_output + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) - # Replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + loss = None + if labels is not 
None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) - return self.language_model( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs + return VibeVoiceAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs): @@ -468,4 +554,4 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg return model_inputs -__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrPreTrainedModel"] +__all__ = ["VibeVoiceAsrForConditionalGeneration", "VibeVoiceAsrModel", "VibeVoiceAsrPreTrainedModel"] diff --git a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py index 5fb92a1d4f1b..a0dbcb158268 100644 --- a/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py +++ b/src/transformers/models/vibevoice_asr/modular_vibevoice_asr.py @@ -11,16 +11,23 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +from dataclasses import dataclass + import torch from huggingface_hub.dataclasses import strict from torch import nn from ...cache_utils import Cache from ...configuration_utils import PreTrainedConfig -from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast +from ...generation import GenerationMixin +from ...modeling_outputs import ( + BaseModelOutputWithPast, + BaseModelOutputWithPooling, + ModelOutput, +) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging -from ..audioflamingo3.modeling_audioflamingo3 import AudioFlamingo3ForConditionalGeneration +from ...utils.deprecation import forward_base_model_attrs from ..auto import CONFIG_MAPPING, AutoConfig, AutoModel from ..qwen2.modeling_qwen2 import Qwen2RMSNorm from ..vibevoice_acoustic_tokenizer.modeling_vibevoice_acoustic_tokenizer import ( @@ -82,6 +89,7 @@ class VibeVoiceAsrConfig(PreTrainedConfig): audio_bos_token_id: int = 151646 audio_eos_token_id: int = 151647 acoustic_tokenizer_chunk_size: int = 1440000 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.acoustic_tokenizer_encoder_config, dict): @@ -161,19 +169,77 @@ class VibeVoiceAsrPreTrainedModel(VibeVoiceAcousticTokenizerPreTrainedModel): _supports_sdpa = True +@dataclass @auto_docstring( custom_intro=""" - The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + Base class for VibeVoice ASR outputs, with hidden states and attentions. + """ +) +class VibeVoiceAsrModelOutputWithPast(BaseModelOutputWithPast): + r""" + past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`): + It is a [`~cache_utils.Cache`] instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + + +@dataclass +@auto_docstring( + custom_intro=""" + Base class for VibeVoice ASR causal language model outputs. 
+ """ +) +class VibeVoiceAsrCausalLMOutputWithPast(ModelOutput): + r""" + loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided): + Language modeling loss. + logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`): + Prediction scores. + past_key_values (`Cache`, *optional*): + Cache instance. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + loss: torch.FloatTensor | None = None + logits: torch.FloatTensor | None = None + past_key_values: Cache | None = None + hidden_states: tuple[torch.FloatTensor] | None = None + attentions: tuple[torch.FloatTensor] | None = None + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model (acoustic tokenizer + semantic tokenizer + multi-modal projector + language model), + without a language modeling head. """ ) -class VibeVoiceAsrForConditionalGeneration(AudioFlamingo3ForConditionalGeneration): +class VibeVoiceAsrModel(VibeVoiceAsrPreTrainedModel): _supports_attention_backend = True def __init__(self, config: VibeVoiceAsrConfig): super().__init__(config) self.acoustic_tokenizer_encoder = AutoModel.from_config(config.acoustic_tokenizer_encoder_config) self.semantic_tokenizer_encoder = AutoModel.from_config(config.semantic_tokenizer_encoder_config) - del self.audio_tower + self.multi_modal_projector = VibeVoiceAsrMultiModalProjector(config) + self.language_model = AutoModel.from_config(config.text_config) + self.post_init() + # Acoustic/semantic tokenizers are run under no_grad in `get_audio_features`; freeze + # their parameters so grad-checkpointing and training sanity checks don't flag them. + for p in self.acoustic_tokenizer_encoder.parameters(): + p.requires_grad_(False) + for p in self.semantic_tokenizer_encoder.parameters(): + p.requires_grad_(False) + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) @can_return_tuple @auto_docstring(custom_intro="Encode audio into embeddings that can be used by the language model.") @@ -186,14 +252,12 @@ def get_audio_features( ): r""" input_values (`torch.FloatTensor` of shape `(batch_size, num_samples)`): - Input audio tensor. Audio should be sampled at 24kHz. + Input audio tensor sampled at 24kHz. padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks to process at once through the tokenizers. Defaults to `config.acoustic_tokenizer_chunk_size`, - but can be modified to fit the available memory. + Size of audio chunks to process at once through the tokenizers. 
""" - if acoustic_tokenizer_chunk_size is None: acoustic_tokenizer_chunk_size = self.config.acoustic_tokenizer_chunk_size else: @@ -238,7 +302,6 @@ def get_audio_features( combined_features = self.multi_modal_projector(acoustic_latents, semantic_latents) if padding_mask is not None: - # Adjust padding mask according to tokenizer compression num_audio_tokens = torch.ceil( padding_mask.sum(dim=-1) / self.config.acoustic_tokenizer_encoder_config.hop_length ).to(torch.int64) @@ -261,13 +324,89 @@ def forward( padding_mask: torch.BoolTensor | None = None, acoustic_tokenizer_chunk_size: int | None = None, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | VibeVoiceAsrModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_values is not None and input_ids is not None: + audio_embeds = self.get_audio_features( + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + ).pooler_output + + audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) + inputs_embeds = inputs_embeds.masked_scatter( + audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + ) + + outputs = self.language_model( + inputs_embeds=inputs_embeds, + attention_mask=attention_mask, + past_key_values=past_key_values, + **kwargs, + ) + + return VibeVoiceAsrModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The VibeVoice ASR model with pre-trained acoustic tokenizers and a language model. + """ +) +@forward_base_model_attrs(version="5.7") +class VibeVoiceAsrForConditionalGeneration(VibeVoiceAsrPreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config: VibeVoiceAsrConfig): + super().__init__(config) + self.model = VibeVoiceAsrModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + attention_mask: torch.Tensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + input_values: torch.FloatTensor | None = None, + padding_mask: torch.BoolTensor | None = None, + acoustic_tokenizer_chunk_size: int | None = None, + labels: torch.LongTensor | None = None, + logits_to_keep: int | torch.Tensor = 0, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VibeVoiceAsrCausalLMOutputWithPast: r""" padding_mask (`torch.BoolTensor` of shape `(batch_size, sequence_length)`, *optional*): Mask to avoid performing operations on padding feature indices. acoustic_tokenizer_chunk_size (`int`, *optional*): - Size of audio chunks processed by the acoustic and semantic tokenizers. 
Defaults to - `config.acoustic_tokenizer_chunk_size`, but can be modified to fit the available memory. + Size of audio chunks processed by the acoustic and semantic tokenizers. Example: @@ -277,33 +416,35 @@ def forward( >>> model_id = "microsoft/VibeVoice-ASR-HF" >>> processor = AutoProcessor.from_pretrained(model_id) >>> model = VibeVoiceAsrForConditionalGeneration.from_pretrained(model_id, dtype="auto", device_map="auto") - - >>> inputs = processor.apply_transcription_request("https://huggingface.co/datasets/hf-internal-testing/dummy-audio-samples/resolve/main/bcn_weather.mp3") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - >>> outputs = model.generate(**inputs) - - >>> decoded_outputs = processor.batch_decode(outputs[:, inputs.input_ids.shape[1] :], skip_special_tokens=True) - >>> print(decoded_outputs) ```""" + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + input_values=input_values, + padding_mask=padding_mask, + acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, + **kwargs, + ) - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_values is not None and input_ids is not None: - audio_embeds = self.get_audio_features( - input_values=input_values, - padding_mask=padding_mask, - acoustic_tokenizer_chunk_size=acoustic_tokenizer_chunk_size, - ).pooler_output + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) - # Replace text-audio token placeholders with audio embeddings - audio_token_mask = (input_ids == self.config.audio_token_id).unsqueeze(-1) - inputs_embeds = inputs_embeds.masked_scatter( - audio_token_mask.to(inputs_embeds.device), audio_embeds.to(inputs_embeds.device) + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs ) - return self.language_model( - inputs_embeds=inputs_embeds, attention_mask=attention_mask, past_key_values=past_key_values, **kwargs + return VibeVoiceAsrCausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=outputs.audio_hidden_states, ) def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwargs): @@ -327,5 +468,6 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration=False, **kwarg __all__ = [ "VibeVoiceAsrConfig", "VibeVoiceAsrForConditionalGeneration", + "VibeVoiceAsrModel", "VibeVoiceAsrPreTrainedModel", ] diff --git a/src/transformers/models/voxtral/configuration_voxtral.py b/src/transformers/models/voxtral/configuration_voxtral.py index 2ecbedfc1a9e..b476d80dd976 100644 --- a/src/transformers/models/voxtral/configuration_voxtral.py +++ b/src/transformers/models/voxtral/configuration_voxtral.py @@ -110,6 +110,7 @@ class VoxtralConfig(PreTrainedConfig): text_config: dict | PreTrainedConfig | None = None audio_token_id: int | None = None projector_hidden_act: str = "gelu" + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/voxtral/modeling_voxtral.py b/src/transformers/models/voxtral/modeling_voxtral.py index 54466321b79e..15260a14936a 100644 --- 
a/src/transformers/models/voxtral/modeling_voxtral.py +++ b/src/transformers/models/voxtral/modeling_voxtral.py @@ -21,6 +21,7 @@ import math from collections.abc import Callable +from dataclasses import dataclass import torch from torch import nn @@ -33,9 +34,10 @@ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from .configuration_voxtral import VoxtralConfig, VoxtralEncoderConfig @@ -359,22 +361,33 @@ def forward(self, audio_features): return hidden_states +@dataclass @auto_docstring( custom_intro=""" - The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + Base class for Voxtral outputs, with hidden states and attentions. """ ) -class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = ["embed_positions"] +class VoxtralModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + audio_hidden_states: torch.FloatTensor | None = None + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model, + without a language modeling head. + """ +) +class VoxtralModel(VoxtralPreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = VoxtralMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): @@ -383,18 +396,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." 
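The Voxtral hunks that follow mirror the VibeVoice ASR changes above: the base model scatters projected audio embeddings over the audio placeholder positions, and the new head class only projects the trailing `logits_to_keep` positions through `lm_head`. The toy sketch below isolates those two recurring pieces with made-up ids and shapes; it is illustrative only and does not use any model class from this diff.

```python
import torch
from torch import nn

hidden_size, vocab_size, audio_token_id = 4, 11, 7

# 1) Placeholder replacement: positions holding the audio token id are overwritten
#    with the projected audio embeddings via `masked_scatter`.
input_ids = torch.tensor([[1, 7, 7, 2]])            # two audio placeholders
inputs_embeds = torch.zeros(1, 4, hidden_size)      # (batch, seq_len, hidden)
audio_embeds = torch.ones(2, hidden_size)           # one projected embedding per placeholder

audio_token_mask = (input_ids == audio_token_id).unsqueeze(-1)
inputs_embeds = inputs_embeds.masked_scatter(audio_token_mask, audio_embeds)

# 2) `logits_to_keep` slicing: only the trailing positions are pushed through the head,
#    which is all that cached decoding needs (0 keeps every position).
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
hidden_states = inputs_embeds                       # stand-in for `outputs.last_hidden_state`
logits_to_keep = 1
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)                                 # torch.Size([1, 1, 11])
```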
@@ -442,6 +443,81 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs: BaseModelOutputWithPast = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return VoxtralModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. + """ +) +@forward_base_model_attrs(version="5.7") +class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -456,7 +532,7 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | CausalLMOutputWithPast: r""" Example: @@ -490,29 +566,34 @@ def forward( >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, 
inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - - outputs: BaseModelOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) def prepare_inputs_for_generation(self, *args, **kwargs): # Overwritten -- we should not pass input_features when we are in cached decoding stage @@ -529,4 +610,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"] +__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"] diff --git a/src/transformers/models/voxtral/modular_voxtral.py b/src/transformers/models/voxtral/modular_voxtral.py index 02e8e2806a0f..31c7193d71f8 100644 --- a/src/transformers/models/voxtral/modular_voxtral.py +++ b/src/transformers/models/voxtral/modular_voxtral.py @@ -13,6 +13,8 @@ # limitations under the License. +from dataclasses import dataclass + import torch from torch import nn @@ -26,9 +28,10 @@ ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, torch_compilable_check +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel, AutoModelForCausalLM +from ..auto import AutoModel from ..qwen2_audio.modeling_qwen2_audio import ( Qwen2AudioAttention, Qwen2AudioEncoder, @@ -128,22 +131,33 @@ def forward(self, audio_features): return hidden_states +@dataclass @auto_docstring( custom_intro=""" - The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. + Base class for Voxtral outputs, with hidden states and attentions. """ ) -class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = ["embed_positions"] +class VoxtralModelOutputWithPast(BaseModelOutputWithPast): + r""" + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states. + """ + + audio_hidden_states: torch.FloatTensor | None = None + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model, + without a language modeling head. 
+ """ +) +class VoxtralModel(VoxtralPreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = AutoModelForCausalLM.from_config(config.text_config) + self.language_model = AutoModel.from_config(config.text_config) self.multi_modal_projector = VoxtralMultiModalProjector(config) - - # Initialize weights and apply final processing self.post_init() def get_input_embeddings(self): @@ -152,18 +166,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -211,6 +213,81 @@ def get_placeholder_mask( ) return special_audio_mask + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + use_cache: bool | None = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralModelOutputWithPast: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings()(input_ids) + + audio_embeds = None + if input_features is not None and input_ids is not None: + audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output + + # replace text-audio token placeholders with audio embeddings + special_audio_mask = self.get_placeholder_mask( + input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds + ) + inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) + + outputs: BaseModelOutputWithPast = self.language_model( + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + **kwargs, + ) + + return VoxtralModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + audio_hidden_states=audio_embeds, + ) + + +@auto_docstring( + custom_intro=""" + The Voxtral model, which consists of Whisper encoder, a multi-modal projector and a Llama language model. 
+ """ +) +@forward_base_model_attrs(version="5.7") +class VoxtralForConditionalGeneration(VoxtralPreTrainedModel, GenerationMixin): + _keep_in_fp32_modules_strict = ["embed_positions"] + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + @can_return_tuple @auto_docstring def forward( @@ -225,7 +302,7 @@ def forward( use_cache: bool | None = None, logits_to_keep: int | torch.Tensor = 0, **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: + ) -> tuple | CausalLMOutputWithPast: r""" Example: @@ -259,29 +336,34 @@ def forward( >>> processor.batch_decode(outputs[:, inputs.input_ids.shape[1]:], skip_special_tokens=True) ["This audio is a humorous conversation between two friends, likely in English, where one of them is trying to figure out what the other's tattoo says."] ```""" - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings()(input_ids) - - if input_features is not None and input_ids is not None: - audio_embeds = self.get_audio_features(input_features, return_dict=True).pooler_output - - # replace text-audio token placeholders with audio embeddings - special_audio_mask = self.get_placeholder_mask( - input_ids, inputs_embeds=inputs_embeds, audio_features=audio_embeds - ) - inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device)) - - outputs: BaseModelOutputWithPast = self.language_model( + outputs = self.model( + input_ids=input_ids, + input_features=input_features, attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, **kwargs, ) - return outputs + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) def prepare_inputs_for_generation(self, *args, **kwargs): # Overwritten -- we should not pass input_features when we are in cached decoding stage @@ -298,4 +380,4 @@ def prepare_inputs_for_generation(self, *args, **kwargs): return model_inputs -__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralForConditionalGeneration"] +__all__ = ["VoxtralPreTrainedModel", "VoxtralEncoder", "VoxtralModel", "VoxtralForConditionalGeneration"] diff --git a/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py index 
b0227b418771..a1593dfbcdca 100644 --- a/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/configuration_voxtral_realtime.py @@ -170,6 +170,7 @@ class VoxtralRealtimeConfig(PreTrainedConfig): audio_length_per_tok: int = 8 default_num_delay_tokens: int = 6 downsample_factor: int = 4 + tie_word_embeddings: bool = True def __post_init__(self, **kwargs): if isinstance(self.audio_config, dict): diff --git a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py index dbecd9a6f530..31b3850c591e 100644 --- a/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modeling_voxtral_realtime.py @@ -39,17 +39,10 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel from ...processing_utils import Unpack -from ...utils import ( - TransformersKwargs, - auto_docstring, - can_return_tuple, - is_torchdynamo_compiling, - logging, - torch_compilable_check, -) +from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import maybe_autocast, merge_with_config_defaults from ...utils.output_capturing import capture_outputs -from ..auto import AutoModel from .configuration_voxtral_realtime import ( VoxtralRealtimeConfig, VoxtralRealtimeEncoderConfig, @@ -125,6 +118,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast): padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None +@dataclass +class VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast): + r""" + Args: + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder + that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states before they are added to the text embeddings. 
+ """ + + encoder_past_key_values: Cache | None = None + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @dataclass class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast): r""" @@ -487,6 +498,7 @@ class VoxtralRealtimePreTrainedModel(PreTrainedModel): _supports_attention_backend = True # TODO: @eustlb, this should be enabled soon _can_compile_fullgraph = False + _keep_in_fp32_modules_strict = None @torch.no_grad() def _init_weights(self, module): @@ -827,80 +839,6 @@ def forward( ) -@auto_docstring -class VoxtralRealtimeTextForCausalLM(VoxtralRealtimeTextPreTrainedModel, GenerationMixin): - _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"} - _tp_plan = {"lm_head": "colwise_gather_output"} - _pp_plan = {"lm_head": (["hidden_states"], ["logits"])} - - def __init__(self, config): - super().__init__(config) - self.model = VoxtralRealtimeTextModel(config) - self.vocab_size = config.vocab_size - self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) - - # Initialize weights and apply final processing - self.post_init() - - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM - - >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." - ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - class VoxtralRealtimeTimeEmbedding(nn.Module): """Sinusoidal Embedding for encoding time""" @@ -935,17 +873,15 @@ def forward(self, audio_features): @auto_docstring( custom_intro=""" - The VoxtralRealtime model, which consists of Whisper encoder, a multi-modal projector and a LLama language model. 
+ The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector, + a Mistral-based language model and a time embedding, without a language modeling head. """ ) -class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): - _keep_in_fp32_modules_strict = None - +class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel): def __init__(self, config): super().__init__(config) - self.vocab_size = config.text_config.vocab_size - self.audio_tower = AutoModel.from_config(config.audio_config) - self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config) + self.audio_tower = VoxtralRealtimeEncoder(config.audio_config) + self.language_model = VoxtralRealtimeTextModel(config.text_config) self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config) self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size) @@ -958,18 +894,6 @@ def get_input_embeddings(self): def set_input_embeddings(self, value): self.language_model.set_input_embeddings(value) - def get_output_embeddings(self): - return self.language_model.get_output_embeddings() - - def set_output_embeddings(self, new_embeddings): - self.language_model.set_output_embeddings(new_embeddings) - - def set_decoder(self, decoder): - self.language_model.set_decoder(decoder) - - def get_decoder(self): - return self.language_model.get_decoder() - @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -985,11 +909,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): @@ -1014,30 +934,6 @@ def get_audio_features( return audio_outputs - def get_placeholder_mask( - self, input_ids: torch.LongTensor, inputs_embeds: torch.FloatTensor, audio_features: torch.FloatTensor - ): - """ - Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is - equal to the length of multimodal features. If the lengths are different, an error is raised. 
- """ - if input_ids is None: - special_audio_mask = inputs_embeds == self.get_input_embeddings()( - torch.tensor(self.config.audio_token_id, dtype=torch.long, device=inputs_embeds.device) - ) - special_audio_mask = special_audio_mask.all(-1) - else: - special_audio_mask = input_ids == self.config.audio_token_id - - n_audio_tokens = special_audio_mask.sum() - n_audio_features = audio_features.shape[0] - special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device) - torch_compilable_check( - inputs_embeds[special_audio_mask].numel() == audio_features.numel(), - f"Audio features and audio tokens do not match, tokens: {n_audio_tokens}, features: {n_audio_features}", - ) - return special_audio_mask - @can_return_tuple @auto_docstring def forward( @@ -1051,43 +947,20 @@ def forward( padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, inputs_embeds: torch.FloatTensor | None = None, encoder_inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, num_delay_tokens: int | torch.Tensor = None, **kwargs: Unpack[TransformersKwargs], - ) -> VoxtralRealtimeCausalLMOutputWithPast: + ) -> tuple | VoxtralRealtimeModelOutputWithPast: r""" encoder_past_key_values (`Cache`, *optional*): - Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. num_delay_tokens (`int` or `torch.Tensor`, *optional*): - Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. - - Example: - - ```python - >>> import torch - >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor - >>> from datasets import load_dataset - - >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" - - >>> processor = AutoProcessor.from_pretrained(repo_id) - >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> audio = ds[0]["audio"]["array"] - - >>> inputs = processor(audio, return_tensors="pt") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - - >>> outputs = model.generate(**inputs) - >>> processor.batch_decode(outputs, skip_special_tokens=True) - ```""" + Number of delay tokens used when preparing inputs. 
+ """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -1097,6 +970,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_outputs = None + audio_embeds = None if input_features is not None or encoder_inputs_embeds is not None: audio_outputs = self.get_audio_features( input_features=input_features, @@ -1106,7 +981,8 @@ def forward( use_cache=use_cache, return_dict=True, ) - inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device) + audio_embeds = audio_outputs.pooler_output + inputs_embeds = inputs_embeds + audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens @@ -1125,25 +1001,141 @@ def forward( t_cond = self.time_embedding(time_tensor) t_cond = t_cond[None, ...] # broadcastable to batch size - outputs: CausalLMOutputWithPast = self.language_model( + outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, t_cond=t_cond, **kwargs, ) + + return VoxtralRealtimeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + encoder_past_key_values=audio_outputs.past_key_values + if (audio_outputs is not None and use_cache) + else None, + padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None, + audio_hidden_states=audio_embeds, + ) + + +@forward_base_model_attrs(version="5.7") +class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralRealtimeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @property + def audio_tower(self): + return self.model.audio_tower + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + encoder_past_key_values: Cache | None = None, + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + encoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + num_delay_tokens: int | torch.Tensor = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast: + r""" + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the 
self-attention blocks) for the encoder that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. + num_delay_tokens (`int` or `torch.Tensor`, *optional*): + Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. + + Example: + + ```python + >>> import torch + >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset + + >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" + + >>> processor = AutoProcessor.from_pretrained(repo_id) + >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> audio = ds[0]["audio"]["array"] + + >>> inputs = processor(audio, return_tensors="pt") + >>> inputs = inputs.to(model.device, dtype=model.dtype) + + >>> outputs = model.generate(**inputs) + >>> processor.batch_decode(outputs, skip_special_tokens=True) + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + inputs_embeds=inputs_embeds, + encoder_inputs_embeds=encoder_inputs_embeds, + use_cache=use_cache, + num_delay_tokens=num_delay_tokens, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + return VoxtralRealtimeCausalLMOutputWithPast( - loss=outputs.loss, - logits=outputs.logits, + loss=loss, + logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - encoder_past_key_values=audio_outputs.past_key_values if use_cache else None, - padding_cache=audio_outputs.padding_cache if use_cache else None, + encoder_past_key_values=outputs.encoder_past_key_values, + padding_cache=outputs.padding_cache, ) def prepare_inputs_for_generation( @@ -1339,4 +1331,9 @@ def _prepare_generated_length( return generation_config -__all__ = ["VoxtralRealtimeForConditionalGeneration", "VoxtralRealtimeEncoder", "VoxtralRealtimePreTrainedModel"] +__all__ = [ + "VoxtralRealtimeForConditionalGeneration", + "VoxtralRealtimeEncoder", + "VoxtralRealtimePreTrainedModel", + "VoxtralRealtimeModel", +] diff --git a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py index edad37679927..d82c6417cb20 100644 --- a/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py +++ b/src/transformers/models/voxtral_realtime/modular_voxtral_realtime.py @@ -31,18 +31,17 @@ from ...models.mistral.modeling_mistral import ( MistralAttention, MistralDecoderLayer, - MistralForCausalLM, MistralMLP, MistralModel, MistralRMSNorm, ) from 
...models.voxtral.modeling_voxtral import ( - VoxtralForConditionalGeneration, VoxtralMultiModalProjector, VoxtralPreTrainedModel, ) from ...processing_utils import Unpack from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging +from ...utils.deprecation import forward_base_model_attrs from ...utils.generic import merge_with_config_defaults from ...utils.output_capturing import capture_outputs from .configuration_voxtral_realtime import VoxtralRealtimeEncoderConfig @@ -116,6 +115,24 @@ class VoxtralRealtimeEncoderOutput(BaseModelOutputWithPast): padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None +@dataclass +class VoxtralRealtimeModelOutputWithPast(BaseModelOutputWithPast): + r""" + Args: + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the audio encoder + that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + audio_hidden_states (`torch.FloatTensor`, *optional*): + Projected audio hidden states before they are added to the text embeddings. + """ + + encoder_past_key_values: Cache | None = None + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None + audio_hidden_states: torch.FloatTensor | None = None + + @dataclass class VoxtralRealtimeCausalLMOutputWithPast(CausalLMOutputWithPast): r""" @@ -255,6 +272,7 @@ def forward( class VoxtralRealtimePreTrainedModel(VoxtralPreTrainedModel, PreTrainedModel): # TODO: @eustlb, this should be enabled soon _can_compile_fullgraph = False + _keep_in_fp32_modules_strict = None @torch.no_grad() def _init_weights(self, module): @@ -436,66 +454,6 @@ def __init__(self, config): self.rotary_emb = VoxtralRealtimeRotaryEmbedding(config=config) -class VoxtralRealtimeTextForCausalLM(MistralForCausalLM): - @can_return_tuple - @auto_docstring - def forward( - self, - input_ids: torch.LongTensor | None = None, - attention_mask: torch.Tensor | None = None, - position_ids: torch.LongTensor | None = None, - past_key_values: Cache | None = None, - inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, - use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, - **kwargs: Unpack[TransformersKwargs], - ) -> CausalLMOutputWithPast: - r""" - Example: - - ```python - >>> from transformers import AutoTokenizer, VoxtralRealtimeTextForCausalLM - - >>> model = VoxtralRealtimeTextForCausalLM.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - >>> tokenizer = AutoTokenizer.from_pretrained("mistralai/Voxtral-Mini-4B-Realtime-2602") - - >>> prompt = "Hey, are you conscious? Can you talk to me?" - >>> inputs = tokenizer(prompt, return_tensors="pt") - - >>> # Generate - >>> generate_ids = model.generate(inputs.input_ids, max_length=30) - >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0] - "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you." 
- ```""" - outputs: BaseModelOutputWithPast = self.model( - input_ids=input_ids, - attention_mask=attention_mask, - position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=inputs_embeds, - use_cache=use_cache, - **kwargs, - ) - - hidden_states = outputs.last_hidden_state - # Only compute necessary logits, and do not upcast them to float if we are not computing the loss - slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep - logits = self.lm_head(hidden_states[:, slice_indices, :]) - - loss = None - if labels is not None: - loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.vocab_size, **kwargs) - - return CausalLMOutputWithPast( - loss=loss, - logits=logits, - past_key_values=outputs.past_key_values, - hidden_states=outputs.hidden_states, - attentions=outputs.attentions, - ) - - class VoxtralRealtimeTimeEmbedding(nn.Module): """Sinusoidal Embedding for encoding time""" @@ -520,14 +478,29 @@ def __init__(self, config): ) -class VoxtralRealtimeForConditionalGeneration(VoxtralForConditionalGeneration, GenerationMixin): - _keep_in_fp32_modules_strict = None - +@auto_docstring( + custom_intro=""" + The VoxtralRealtime model, which consists of a streaming Whisper-style encoder, a multi-modal projector, + a Mistral-based language model and a time embedding, without a language modeling head. + """ +) +class VoxtralRealtimeModel(VoxtralRealtimePreTrainedModel): def __init__(self, config): super().__init__(config) - self.language_model = VoxtralRealtimeTextForCausalLM(config.text_config) + self.audio_tower = VoxtralRealtimeEncoder(config.audio_config) + self.language_model = VoxtralRealtimeTextModel(config.text_config) + self.multi_modal_projector = VoxtralRealtimeMultiModalProjector(config) self.time_embedding = VoxtralRealtimeTimeEmbedding(config.text_config.hidden_size) + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.language_model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.language_model.set_input_embeddings(value) + @can_return_tuple @auto_docstring( custom_intro="This method is used to get the audio embeddings from input features (a log mel spectrogram), meaning inferring the audio encoder and the multi-modal projector." @@ -543,11 +516,7 @@ def get_audio_features( ) -> tuple | BaseModelOutputWithPooling: r""" input_features (`torch.FloatTensor`): - Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be - obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a - `numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into - `input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding - and conversion into a tensor of type `torch.FloatTensor`. See [`~VoxtralRealtimeFeatureExtractor.__call__`] + Float values of mel features extracted from the raw speech waveform. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. 
encoder_inputs_embeds (`torch.FloatTensor`, *optional*): @@ -585,43 +554,20 @@ def forward( padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, inputs_embeds: torch.FloatTensor | None = None, encoder_inputs_embeds: torch.FloatTensor | None = None, - labels: torch.LongTensor | None = None, use_cache: bool | None = None, - logits_to_keep: int | torch.Tensor = 0, num_delay_tokens: int | torch.Tensor = None, **kwargs: Unpack[TransformersKwargs], - ) -> VoxtralRealtimeCausalLMOutputWithPast: + ) -> tuple | VoxtralRealtimeModelOutputWithPast: r""" encoder_past_key_values (`Cache`, *optional*): - Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder. padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): Cache for padding in convolutional layers to maintain state across streaming chunks. encoder_inputs_embeds (`torch.FloatTensor`, *optional*): Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. num_delay_tokens (`int` or `torch.Tensor`, *optional*): - Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. - - Example: - - ```python - >>> import torch - >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor - >>> from datasets import load_dataset - - >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" - - >>> processor = AutoProcessor.from_pretrained(repo_id) - >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") - - >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") - >>> audio = ds[0]["audio"]["array"] - - >>> inputs = processor(audio, return_tensors="pt") - >>> inputs = inputs.to(model.device, dtype=model.dtype) - - >>> outputs = model.generate(**inputs) - >>> processor.batch_decode(outputs, skip_special_tokens=True) - ```""" + Number of delay tokens used when preparing inputs. + """ if (input_ids is None) ^ (inputs_embeds is not None): raise ValueError("You must specify exactly one of input_ids or inputs_embeds") @@ -631,6 +577,8 @@ def forward( if inputs_embeds is None: inputs_embeds = self.get_input_embeddings()(input_ids) + audio_outputs = None + audio_embeds = None if input_features is not None or encoder_inputs_embeds is not None: audio_outputs = self.get_audio_features( input_features=input_features, @@ -640,7 +588,8 @@ def forward( use_cache=use_cache, return_dict=True, ) - inputs_embeds += audio_outputs.pooler_output.to(inputs_embeds.device) + audio_embeds = audio_outputs.pooler_output + inputs_embeds = inputs_embeds + audio_embeds.to(inputs_embeds.device) if num_delay_tokens is None: num_delay_tokens = self.config.default_num_delay_tokens @@ -659,25 +608,141 @@ def forward( t_cond = self.time_embedding(time_tensor) t_cond = t_cond[None, ...] 
# broadcastable to batch size - outputs: CausalLMOutputWithPast = self.language_model( + outputs: BaseModelOutputWithPast = self.language_model( attention_mask=attention_mask, position_ids=position_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, - labels=labels, use_cache=use_cache, - logits_to_keep=logits_to_keep, t_cond=t_cond, **kwargs, ) + + return VoxtralRealtimeModelOutputWithPast( + last_hidden_state=outputs.last_hidden_state, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + encoder_past_key_values=audio_outputs.past_key_values + if (audio_outputs is not None and use_cache) + else None, + padding_cache=audio_outputs.padding_cache if (audio_outputs is not None and use_cache) else None, + audio_hidden_states=audio_embeds, + ) + + +@forward_base_model_attrs(version="5.7") +class VoxtralRealtimeForConditionalGeneration(VoxtralRealtimePreTrainedModel, GenerationMixin): + _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"} + + def __init__(self, config): + super().__init__(config) + self.model = VoxtralRealtimeModel(config) + self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False) + self.post_init() + + def get_input_embeddings(self): + return self.model.get_input_embeddings() + + def set_input_embeddings(self, value): + self.model.set_input_embeddings(value) + + def get_output_embeddings(self) -> nn.Module: + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def get_audio_features(self, *args, **kwargs): + return self.model.get_audio_features(*args, **kwargs) + + @property + def audio_tower(self): + return self.model.audio_tower + + @can_return_tuple + @auto_docstring + def forward( + self, + input_ids: torch.LongTensor | None = None, + input_features: torch.FloatTensor | None = None, + attention_mask: torch.Tensor | None = None, + position_ids: torch.LongTensor | None = None, + past_key_values: Cache | None = None, + encoder_past_key_values: Cache | None = None, + padding_cache: VoxtralRealtimeConv1dPaddingCache | None = None, + inputs_embeds: torch.FloatTensor | None = None, + encoder_inputs_embeds: torch.FloatTensor | None = None, + labels: torch.LongTensor | None = None, + use_cache: bool | None = None, + logits_to_keep: int | torch.Tensor = 0, + num_delay_tokens: int | torch.Tensor = None, + **kwargs: Unpack[TransformersKwargs], + ) -> tuple | VoxtralRealtimeCausalLMOutputWithPast: + r""" + encoder_past_key_values (`Cache`, *optional*): + Pre-computed hidden-states (key and value in the self-attention blocks) for the encoder that can be used to speed up sequential decoding. + padding_cache (`VoxtralRealtimeConv1dPaddingCache`, *optional*): + Cache for padding in convolutional layers to maintain state across streaming chunks. + encoder_inputs_embeds (`torch.FloatTensor`, *optional*): + Optionally, instead of passing `input_features` you can choose to directly pass an embedded representation for the encoder. + num_delay_tokens (`int` or `torch.Tensor`, *optional*): + Number of delay tokens used when preparing inputs, see [`~VoxtralRealtimeProcessor`] for more details. 
+ + Example: + + ```python + >>> import torch + >>> from transformers import VoxtralRealtimeForConditionalGeneration, AutoProcessor + >>> from datasets import load_dataset + + >>> repo_id = "mistralai/Voxtral-Mini-4B-Realtime-2602" + + >>> processor = AutoProcessor.from_pretrained(repo_id) + >>> model = VoxtralRealtimeForConditionalGeneration.from_pretrained(repo_id, dtype=torch.bfloat16, device_map="auto") + + >>> ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation") + >>> audio = ds[0]["audio"]["array"] + + >>> inputs = processor(audio, return_tensors="pt") + >>> inputs = inputs.to(model.device, dtype=model.dtype) + + >>> outputs = model.generate(**inputs) + >>> processor.batch_decode(outputs, skip_special_tokens=True) + ```""" + outputs = self.model( + input_ids=input_ids, + input_features=input_features, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_values=past_key_values, + encoder_past_key_values=encoder_past_key_values, + padding_cache=padding_cache, + inputs_embeds=inputs_embeds, + encoder_inputs_embeds=encoder_inputs_embeds, + use_cache=use_cache, + num_delay_tokens=num_delay_tokens, + **kwargs, + ) + + hidden_states = outputs.last_hidden_state + slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep + logits = self.lm_head(hidden_states[:, slice_indices, :]) + + loss = None + if labels is not None: + loss = self.loss_function( + logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs + ) + return VoxtralRealtimeCausalLMOutputWithPast( - loss=outputs.loss, - logits=outputs.logits, + loss=loss, + logits=logits, past_key_values=outputs.past_key_values, hidden_states=outputs.hidden_states, attentions=outputs.attentions, - encoder_past_key_values=audio_outputs.past_key_values if use_cache else None, - padding_cache=audio_outputs.padding_cache if use_cache else None, + encoder_past_key_values=outputs.encoder_past_key_values, + padding_cache=outputs.padding_cache, ) def prepare_inputs_for_generation( @@ -705,7 +770,7 @@ def _prepare_model_inputs( bos_token_id: torch.Tensor | None = None, model_kwargs: dict[str, torch.Tensor] | None = None, ) -> tuple[torch.Tensor, str | None, dict[str, torch.Tensor]]: - inputs, input_name, model_kwargs = GenerationMixin._prepare_model_inputs(inputs, bos_token_id, model_kwargs) + inputs, input_name, model_kwargs = super()._prepare_model_inputs(inputs, bos_token_id, model_kwargs) input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): @@ -725,7 +790,7 @@ def _has_unfinished_sequences(self, this_peer_finished: bool, synced_gpus: bool, if getattr(self, "_stream_exhausted", False): self._stream_exhausted = False return False - return GenerationMixin._has_unfinished_sequences(this_peer_finished, synced_gpus, device) + return super()._has_unfinished_sequences(this_peer_finished, synced_gpus, device) def _update_model_kwargs_for_generation( self, @@ -734,7 +799,7 @@ def _update_model_kwargs_for_generation( is_encoder_decoder: bool = False, num_new_tokens: int = 1, ): - model_kwargs = GenerationMixin._update_model_kwargs_for_generation( + model_kwargs = super()._update_model_kwargs_for_generation( outputs, model_kwargs, is_encoder_decoder, num_new_tokens ) @@ -761,7 +826,7 @@ def _prepare_cache_for_generation( batch_size: int, max_cache_length: int, ): - GenerationMixin._prepare_cache_for_generation( + super()._prepare_cache_for_generation( 
generation_config, model_kwargs, generation_mode, batch_size, max_cache_length ) @@ -815,7 +880,7 @@ def _prepare_generation_config( generation_config is not None and generation_config.max_new_tokens is not None ) - generation_config, model_kwargs = GenerationMixin._prepare_generation_config(generation_config, **kwargs) + generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs) input_features = model_kwargs.get("input_features") if input_features is not None and not isinstance(input_features, GeneratorType): @@ -854,7 +919,7 @@ def _prepare_generated_length( if getattr(generation_config, "_voxtral_set_max_length", False): has_default_max_length = False - generation_config = GenerationMixin._prepare_generated_length( + generation_config = super()._prepare_generated_length( generation_config, has_default_max_length, has_default_min_length, @@ -877,4 +942,5 @@ def _prepare_generated_length( "VoxtralRealtimeForConditionalGeneration", "VoxtralRealtimeEncoder", "VoxtralRealtimePreTrainedModel", + "VoxtralRealtimeModel", ] diff --git a/src/transformers/utils/deprecation.py b/src/transformers/utils/deprecation.py index db0e67325d78..7091bfa6759a 100644 --- a/src/transformers/utils/deprecation.py +++ b/src/transformers/utils/deprecation.py @@ -173,3 +173,59 @@ def wrapped_func(*args, **kwargs): return wrapped_func return wrapper + + +def forward_base_model_attrs(version: str): + """ + Class decorator that forwards attribute access to the base model (`self.`) + when the attribute is not found on the instance directly, and warns that direct access on the + outer class is deprecated. + + Intended for backward compatibility during refactors that move submodules from the outer + `*ForConditionalGeneration` class down to the inner base model — e.g. `model.language_model` + becoming `model.model.language_model`. + + Apply only to the outer wrapper class (the `*ForConditionalGeneration`), not to the inner + base model itself. The decorator relies on `base_model_prefix` being set on the class (which + `PreTrainedModel` subclasses always do). + + Args: + version (`str`): + The Transformers version in which direct access will be removed (e.g. `"5.7"`). + """ + + def decorator(cls): + # Resolve the inherited __getattr__ (typically nn.Module's, which looks up + # submodules/parameters/buffers) so we can delegate to it without recursing. + inherited_getattr = cls.__getattr__ + + def __getattr__(self, name): + # First, the normal nn.Module lookup (submodules, parameters, buffers). + try: + return inherited_getattr(self, name) + except AttributeError: + pass + # Only forward public attributes to the base model — private names are + # framework internals (e.g. `_is_hf_initialized`) and shouldn't warn. + if name.startswith("_"): + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + prefix = type(self).base_model_prefix + try: + base = inherited_getattr(self, prefix) + except AttributeError: + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + if hasattr(base, name): + if not is_torchdynamo_compiling(): + warnings.warn( + f"Accessing `{name}` directly on `{type(self).__name__}` is deprecated and " + f"will be removed in Transformers v{version}. 
Use `.{prefix}.{name}` instead.", + FutureWarning, + stacklevel=2, + ) + return getattr(base, name) + raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}") + + cls.__getattr__ = __getattr__ + return cls + + return decorator diff --git a/tests/alm_tester.py b/tests/alm_tester.py index c34d4d45524c..96d527aef5aa 100644 --- a/tests/alm_tester.py +++ b/tests/alm_tester.py @@ -13,7 +13,6 @@ # limitations under the License. import copy -import unittest from inspect import signature from .multimodal_tester import MultiModalModelTest, MultiModalModelTester @@ -153,11 +152,6 @@ class ALMModelTest(MultiModalModelTest): - `pipeline_model_mapping`: Override if not using default from model_tester """ - # TODO: @eustlb, remove this once #45534 is merged - @unittest.skip("Audio-LMs have no separate base model without a head.") - def test_model_base_model_prefix(self): - pass - def test_mismatching_num_audio_tokens(self): """ Tests that ALMs throw an error with explicit message saying what is wrong
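
---

A note on the refactor above: the single `VoxtralRealtimeForConditionalGeneration` is split into a headless `VoxtralRealtimeModel` and a thin causal-LM wrapper that owns `lm_head` and ties it to the embedding matrix (`_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}`). The sketch below is only a toy illustration of that wiring in plain `torch.nn` — the class names are hypothetical and none of the real audio/text machinery is reproduced: the base model returns hidden states, the wrapper projects the trailing `logits_to_keep` positions and shares the output weight with the input embedding.

```python
# Simplified sketch of the base-model / LM-head split, not the real Voxtral classes.
import torch
from torch import nn


class ToyBaseModel(nn.Module):
    """Stands in for VoxtralRealtimeModel: embeds tokens and returns hidden states."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, hidden_size)
        self.layer = nn.Linear(hidden_size, hidden_size)

    def forward(self, input_ids: torch.LongTensor) -> torch.Tensor:
        return self.layer(self.embed_tokens(input_ids))


class ToyForCausalLM(nn.Module):
    """Stands in for the *ForConditionalGeneration wrapper: base model + lm_head."""

    def __init__(self, vocab_size: int, hidden_size: int):
        super().__init__()
        self.model = ToyBaseModel(vocab_size, hidden_size)
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)
        # Tie the output projection to the input embeddings, mirroring the
        # lm_head.weight -> embed_tokens.weight tie declared in the patch.
        self.lm_head.weight = self.model.embed_tokens.weight

    def forward(self, input_ids: torch.LongTensor, logits_to_keep: int = 0) -> torch.Tensor:
        hidden_states = self.model(input_ids)
        # Only project the trailing positions, as the patch does with `logits_to_keep`.
        slice_indices = slice(-logits_to_keep, None) if logits_to_keep else slice(None)
        return self.lm_head(hidden_states[:, slice_indices, :])


model = ToyForCausalLM(vocab_size=32, hidden_size=16)
logits = model(torch.randint(0, 32, (1, 5)), logits_to_keep=1)
print(logits.shape)  # torch.Size([1, 1, 32])
```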
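The streaming path threads two caches through `forward`: `encoder_past_key_values` for the encoder's attention and `padding_cache` for its convolutions, and the new `VoxtralRealtimeModelOutputWithPast` hands both back so the caller can feed them into the next chunk. The snippet below is a self-contained toy (it does not use the real `VoxtralRealtimeConv1dPaddingCache`) showing why a conv padding cache is needed at all: carrying the last `kernel_size - 1` input frames between chunks makes chunked convolution reproduce the full-sequence result.

```python
# Toy demonstration of a Conv1d padding cache for streaming, independent of Voxtral code.
import torch
from torch import nn

torch.manual_seed(0)
conv = nn.Conv1d(in_channels=4, out_channels=4, kernel_size=3, padding=0)
kernel_size = conv.kernel_size[0]

x = torch.randn(1, 4, 12)  # (batch, channels, time)

# Reference: causal left-padding, one pass over the whole sequence.
full = conv(torch.nn.functional.pad(x, (kernel_size - 1, 0)))

# Streaming: process 4-frame chunks, carrying the last (kernel_size - 1) frames as a cache.
cache = torch.zeros(1, 4, kernel_size - 1)
chunks = []
for start in range(0, x.shape[-1], 4):
    chunk = x[..., start : start + 4]
    padded = torch.cat([cache, chunk], dim=-1)
    chunks.append(conv(padded))
    cache = padded[..., -(kernel_size - 1) :]

streamed = torch.cat(chunks, dim=-1)
print(torch.allclose(full, streamed, atol=1e-6))  # True: chunked == full-sequence output
```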
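Finally, `forward_base_model_attrs` keeps old call sites such as `model.audio_tower` or `model.language_model` working on the outer class while those submodules now live under `model.model`. The sketch below is a stripped-down, self-contained version of the same `__getattr__` fallback — it omits the `nn.Module` and `base_model_prefix` integration of the real decorator and uses a hypothetical `forward_to_base` helper purely for illustration.

```python
# Minimal sketch of attribute forwarding with a deprecation warning; not the real decorator.
import warnings


def forward_to_base(prefix: str, version: str):
    def decorator(cls):
        def __getattr__(self, name):
            # Private names stay on the outer object; only public access is forwarded.
            if name.startswith("_"):
                raise AttributeError(name)
            try:
                base = object.__getattribute__(self, prefix)
            except AttributeError:
                raise AttributeError(name) from None
            if hasattr(base, name):
                warnings.warn(
                    f"Accessing `{name}` directly on `{type(self).__name__}` is deprecated and "
                    f"will be removed in v{version}. Use `.{prefix}.{name}` instead.",
                    FutureWarning,
                    stacklevel=2,
                )
                return getattr(base, name)
            raise AttributeError(name)

        cls.__getattr__ = __getattr__
        return cls

    return decorator


class Base:
    def __init__(self):
        self.language_model = "inner language model"


@forward_to_base(prefix="model", version="5.7")
class Wrapper:
    def __init__(self):
        self.model = Base()


w = Wrapper()
print(w.language_model)  # emits a FutureWarning, then prints "inner language model"
```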