src/transformers/conversion_mapping.py (30 additions, 0 deletions)
@@ -86,6 +86,13 @@
"vipllava": "llava",
"mistral3": "llava",
"pp_chart2table": "llava",
"voxtral": "qwen2_audio",
"voxtral_realtime": "qwen2_audio",
"audioflamingo3": "qwen2_audio",
"glmasr": "qwen2_audio",
"musicflamingo": "qwen2_audio",
"gemma3n_text": "qwen3_5_text",
"qwen3_5_moe_text": "qwen3_5_text",
"llava_next_video": "llava_next",
"llava_onevision": "llava_next",
# class-based mappings
@@ -401,6 +408,29 @@ def _build_checkpoint_conversion_mapping():
WeightRenaming(source_patterns=r"^vision_tower", target_patterns="model.vision_tower"),
WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
],
"qwen2_audio": [
WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
WeightRenaming(source_patterns=r"^audio_tower", target_patterns="model.audio_tower"),
WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
],
"granite_speech": [
WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
WeightRenaming(source_patterns=r"^encoder", target_patterns="model.encoder"),
WeightRenaming(source_patterns=r"^projector", target_patterns="model.projector"),
],
"vibevoice_asr": [
WeightRenaming(source_patterns=r"^language_model.model", target_patterns="model.language_model"),
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
WeightRenaming(
source_patterns=r"^acoustic_tokenizer_encoder", target_patterns="model.acoustic_tokenizer_encoder"
),
WeightRenaming(
source_patterns=r"^semantic_tokenizer_encoder", target_patterns="model.semantic_tokenizer_encoder"
),
WeightRenaming(source_patterns=r"^multi_modal_projector", target_patterns="model.multi_modal_projector"),
],
"llava_next": [
WeightRenaming(source_patterns=r"^language_model.lm_head", target_patterns="lm_head"),
WeightRenaming(source_patterns=r"^language_model", target_patterns="model.language_model"),
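The `WeightRenaming` entries above map legacy checkpoint keys (for example `language_model.model.*`) onto the nested layout used by the composite classes (`model.language_model.*`). As a minimal sketch of the idea, assuming plain regex substitution and a hypothetical `rename_keys` helper rather than the library's actual checkpoint-loading path:

```python
import re

# Illustrative-only: mirrors the intent of the qwen2_audio WeightRenaming entries,
# not the real conversion machinery inside transformers.
RENAMES = [
    (r"^language_model\.model", "model.language_model"),
    (r"^language_model\.lm_head", "lm_head"),
    (r"^audio_tower", "model.audio_tower"),
    (r"^multi_modal_projector", "model.multi_modal_projector"),
]

def rename_keys(state_dict):
    renamed = {}
    for key, tensor in state_dict.items():
        for pattern, target in RENAMES:
            key = re.sub(pattern, target, key)
        renamed[key] = tensor
    return renamed

legacy = {"language_model.model.embed_tokens.weight": None, "audio_tower.conv1.weight": None}
print(list(rename_keys(legacy)))
# ['model.language_model.embed_tokens.weight', 'model.audio_tower.conv1.weight']
```

Models such as voxtral, audioflamingo3, glmasr and musicflamingo reuse this `qwen2_audio` mapping through the aliases added in the first hunk.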
src/transformers/models/audioflamingo3/configuration_audioflamingo3.py
@@ -100,6 +100,7 @@ class AudioFlamingo3Config(PreTrainedConfig):
audio_token_id: int = 151669
projector_hidden_act: str = "gelu"
projector_bias: bool = True
tie_word_embeddings: bool = True

def __post_init__(self, **kwargs):
if isinstance(self.audio_config, dict):
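The new `tie_word_embeddings: bool = True` default goes hand in hand with the `_tied_weights_keys` mapping added to `AudioFlamingo3ForConditionalGeneration` below: the language-model head reuses the token-embedding matrix instead of storing a separate projection. A minimal sketch of what tying means, with placeholder sizes:

```python
from torch import nn

# Placeholder sizes for illustration; the real values come from the text config.
vocab_size, hidden_size = 151_936, 1_024
embed_tokens = nn.Embedding(vocab_size, hidden_size)
lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

# Tying: the head shares the embedding's Parameter, so no extra storage is used.
lm_head.weight = embed_tokens.weight
assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
```

In the modeling file this is declared via `_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}`.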
src/transformers/models/audioflamingo3/modeling_audioflamingo3.py (161 additions, 95 deletions)
@@ -21,6 +21,7 @@

import math
from collections.abc import Callable
from dataclasses import dataclass

import torch
from torch import nn
@@ -31,13 +32,14 @@
from ...masking_utils import create_bidirectional_mask
from ...modeling_flash_attention_utils import FlashAttentionKwargs
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import BaseModelOutputWithPooling, CausalLMOutputWithPast
from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
from ...processing_utils import Unpack
from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging, torch_compilable_check
from ...utils.deprecation import forward_base_model_attrs
from ...utils.generic import merge_with_config_defaults
from ...utils.output_capturing import capture_outputs
from ..auto import AutoModel, AutoModelForCausalLM
from ..auto import AutoModel
from .configuration_audioflamingo3 import AudioFlamingo3Config, AudioFlamingo3EncoderConfig


@@ -256,6 +258,42 @@ class AudioFlamingo3PreTrainedModel(PreTrainedModel):
_supports_sdpa = True


@dataclass
class AudioFlamingo3ModelOutputWithPast(BaseModelOutputWithPast):
r"""
audio_hidden_states (`torch.FloatTensor`, *optional*):
Projected audio hidden states.
"""

audio_hidden_states: torch.FloatTensor | None = None


@dataclass
@auto_docstring(
custom_intro="""
Base class for AudioFlamingo3 causal language model (or autoregressive) outputs.
"""
)
class AudioFlamingo3CausalLMOutputWithPast(ModelOutput):
r"""
loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
Language modeling loss (for next-token prediction).
logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
Prediction scores of the language modeling head.
past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
It is a [`~cache_utils.Cache`] instance.
audio_hidden_states (`torch.FloatTensor`, *optional*):
Hidden states of the audio encoder after projection.
"""

loss: torch.FloatTensor | None = None
logits: torch.FloatTensor | None = None
past_key_values: Cache | None = None
hidden_states: tuple[torch.FloatTensor] | None = None
attentions: tuple[torch.FloatTensor] | None = None
audio_hidden_states: torch.FloatTensor | None = None


@auto_docstring(
custom_intro="""
The audio model from AudioFlamingo3 without any head or projection on top.
@@ -403,23 +441,21 @@ def forward(self, audio_features):

@auto_docstring(
custom_intro="""
The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
The AudioFlamingo3 model (fine-tuned Whisper encoder, multi-modal projector, Qwen2 language model),
without a language modeling head.
"""
)
class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
_keep_in_fp32_modules_strict = None
class AudioFlamingo3Model(AudioFlamingo3PreTrainedModel):
_supports_attention_backend = True
_tp_plan = None
_pp_plan = None
_keep_in_fp32_modules_strict = None

def __init__(self, config):
super().__init__(config)
self.vocab_size = config.text_config.vocab_size
self.audio_tower = AutoModel.from_config(config.audio_config)
self.language_model = AutoModelForCausalLM.from_config(config.text_config)
self.language_model = AutoModel.from_config(config.text_config)
self.multi_modal_projector = AudioFlamingo3MultiModalProjector(config)

# Initialize weights and apply final processing
self.post_init()

def get_input_embeddings(self):
@@ -428,18 +464,6 @@ def get_input_embeddings(self):
def set_input_embeddings(self, value):
self.language_model.set_input_embeddings(value)

def get_output_embeddings(self):
return self.language_model.get_output_embeddings()

def set_output_embeddings(self, new_embeddings):
self.language_model.set_output_embeddings(new_embeddings)

def set_decoder(self, decoder):
self.language_model.set_decoder(decoder)

def get_decoder(self):
return self.language_model.get_decoder()

@can_return_tuple
@auto_docstring(
custom_intro="This method computes the audio embeddings from the input features (a log mel spectrogram) by running the audio encoder and the multi-modal projector."
@@ -452,11 +476,7 @@ def get_audio_features(
) -> tuple | BaseModelOutputWithPooling:
r"""
input_features (`torch.FloatTensor`):
Float values of mel features extracted from the raw speech waveform. Raw speech waveform can be
obtained by loading a `.flac` or `.wav` audio file into an array of type `list[float]` or a
`numpy.ndarray`, *e.g.* via the soundfile library (`pip install soundfile`). To prepare the array into
`input_features`, the [`AutoFeatureExtractor`] should be used for extracting the mel features, padding
and conversion into a tensor of type `torch.FloatTensor`. See [`~WhisperFeatureExtractor.__call__`]
Float values of mel features extracted from the raw speech waveform.
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padded feature indices.
"""
@@ -509,77 +529,17 @@ def forward(
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
) -> CausalLMOutputWithPast:
) -> tuple | AudioFlamingo3ModelOutputWithPast:
r"""
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:

- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
(masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

Example:

```python
>>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

>>> model_id = "nvidia/audio-flamingo-3-hf"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")

>>> conversations = [
>>> [
>>> {
>>> "role": "user",
>>> "content": [
>>> {"type": "text", "text": "Transcribe the input speech."},
>>> {
>>> "type": "audio",
>>> "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
>>> },
>>> ],
>>> }
>>> ],
>>> [
>>> {
>>> "role": "user",
>>> "content": [
>>> {
>>> "type": "text",
>>> "text": "This track feels really peaceful and introspective. What elements make it feel so calming and meditative?",
>>> },
>>> {"type": "audio", "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/FPSbCAANfbJLVSwD.mp3"},
>>> ],
>>> }
>>> ],
>>> ]

>>> inputs = processor.apply_chat_template(
>>> conversations,
>>> tokenize=True,
>>> add_generation_prompt=True,
>>> return_dict=True,
>>> ).to(model.device)

>>> outputs = model.generate(**inputs, max_new_tokens=500)

>>> decoded_outputs = processor.batch_decode(
>>> outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True
>>> )
>>> print(decoded_outputs)
["The spoken content of the audio is...", "The track's calming and meditative feel can be attributed to..."]
```"""

Mask to avoid performing attention on padding feature indices.
"""
if inputs_embeds is None:
inputs_embeds = self.get_input_embeddings()(input_ids)

audio_embeds = None
if input_features is not None and input_ids is not None:
audio_embeds = self.get_audio_features(input_features, input_features_mask, return_dict=True).pooler_output

@@ -589,17 +549,118 @@
)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds.to(inputs_embeds.device))

outputs: CausalLMOutputWithPast = self.language_model(
outputs = self.language_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
labels=labels,
use_cache=use_cache,
logits_to_keep=logits_to_keep,
**kwargs,
)
return outputs

return AudioFlamingo3ModelOutputWithPast(
last_hidden_state=outputs.last_hidden_state,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
audio_hidden_states=audio_embeds,
)
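`AudioFlamingo3Model.forward` splices the projected audio embeddings into the text sequence at the positions of the audio placeholder tokens via `masked_scatter`. A toy illustration of that merge, with made-up shapes (only `audio_token_id` comes from the config):

```python
import torch

audio_token_id = 151669                      # default from AudioFlamingo3Config
input_ids = torch.tensor([[101, audio_token_id, audio_token_id, 102]])
inputs_embeds = torch.zeros(1, 4, 8)         # (batch, seq_len, hidden), toy sizes
audio_embeds = torch.ones(2, 8)              # one projected embedding per audio token

# Every hidden dimension of an audio-token position is marked for replacement.
special_audio_mask = (input_ids == audio_token_id).unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(special_audio_mask, audio_embeds)
print(inputs_embeds[0, :, 0])                # tensor([0., 1., 1., 0.])
```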


@auto_docstring(
custom_intro="""
The AudioFlamingo3 model which consists of a fine-tuned Whisper encoder, a multi-modal projector and a Qwen2 language model.
"""
)
@forward_base_model_attrs(version="5.7")
class AudioFlamingo3ForConditionalGeneration(AudioFlamingo3PreTrainedModel, GenerationMixin):
_keep_in_fp32_modules_strict = None
_tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
_tp_plan = None
_pp_plan = None

def __init__(self, config):
super().__init__(config)
self.model = AudioFlamingo3Model(config)
self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
self.post_init()

def get_input_embeddings(self):
return self.model.get_input_embeddings()

def set_input_embeddings(self, value):
self.model.set_input_embeddings(value)

def get_output_embeddings(self) -> nn.Module:
return self.lm_head

def set_output_embeddings(self, new_embeddings):
self.lm_head = new_embeddings

def get_audio_features(self, input_features, input_features_mask, **kwargs):
return self.model.get_audio_features(input_features, input_features_mask, **kwargs)

@can_return_tuple
@auto_docstring
def forward(
self,
input_ids: torch.LongTensor | None = None,
input_features: torch.FloatTensor | None = None,
input_features_mask: torch.Tensor | None = None,
attention_mask: torch.Tensor | None = None,
position_ids: torch.LongTensor | None = None,
past_key_values: Cache | None = None,
inputs_embeds: torch.FloatTensor | None = None,
labels: torch.LongTensor | None = None,
use_cache: bool | None = None,
logits_to_keep: int | torch.Tensor = 0,
**kwargs: Unpack[TransformersKwargs],
) -> tuple | AudioFlamingo3CausalLMOutputWithPast:
r"""
input_features_mask (`torch.Tensor` of shape `(batch_size, feature_sequence_length)`):
Mask to avoid performing attention on padding feature indices.
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
Labels for computing the language modeling loss. Tokens with index `-100` are ignored (masked); the loss is only computed for tokens with labels in `[0, ..., config.vocab_size]`.

Example:

```python
>>> from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

>>> model_id = "nvidia/audio-flamingo-3-hf"
>>> processor = AutoProcessor.from_pretrained(model_id)
>>> model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")
```"""
outputs = self.model(
input_ids=input_ids,
input_features=input_features,
input_features_mask=input_features_mask,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=inputs_embeds,
use_cache=use_cache,
**kwargs,
)

hidden_states = outputs.last_hidden_state
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = self.lm_head(hidden_states[:, slice_indices, :])

loss = None
if labels is not None:
loss = self.loss_function(
logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size, **kwargs
)

return AudioFlamingo3CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
audio_hidden_states=outputs.audio_hidden_states,
)
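`logits_to_keep` limits the `lm_head` projection to the last positions of the sequence, which saves memory during generation where only the newest token's logits matter. A small illustration of the slicing used above, with toy shapes:

```python
import torch

hidden_states = torch.randn(2, 10, 16)       # (batch, seq_len, hidden), toy values
logits_to_keep = 1                            # generation typically keeps only the last step

slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)   # torch.Size([2, 1, 16])
# With the default logits_to_keep=0, slice(-0, None) == slice(0, None), so all positions are kept.
```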

def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False, **kwargs):
input_features = kwargs.pop("input_features", None)
@@ -616,4 +677,9 @@ def prepare_inputs_for_generation(self, *args, is_first_iteration: bool = False,
return model_inputs


__all__ = ["AudioFlamingo3ForConditionalGeneration", "AudioFlamingo3PreTrainedModel", "AudioFlamingo3Encoder"]
__all__ = [
"AudioFlamingo3ForConditionalGeneration",
"AudioFlamingo3PreTrainedModel",
"AudioFlamingo3Encoder",
"AudioFlamingo3Model",
]
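For reference, an end-to-end usage sketch of the refactored classes, adapted from the chat-template example that previously lived in the `forward` docstring. Generation runs through `AudioFlamingo3ForConditionalGeneration`; the new backbone that carries `audio_hidden_states` is reachable as `model.model`.

```python
from transformers import AudioFlamingo3ForConditionalGeneration, AutoProcessor

model_id = "nvidia/audio-flamingo-3-hf"
processor = AutoProcessor.from_pretrained(model_id)
model = AudioFlamingo3ForConditionalGeneration.from_pretrained(model_id, device_map="auto")

conversation = [
    {
        "role": "user",
        "content": [
            {"type": "text", "text": "Transcribe the input speech."},
            {
                "type": "audio",
                "path": "https://huggingface.co/datasets/nvidia/AudioSkills/resolve/main/assets/t_837b89f2-26aa-4ee2-bdf6-f73f0dd59b26.wav",
            },
        ],
    }
]

inputs = processor.apply_chat_template(
    [conversation], tokenize=True, add_generation_prompt=True, return_dict=True
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=500)
print(processor.batch_decode(outputs[:, inputs["input_ids"].shape[1]:], skip_special_tokens=True))
```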