Skip to content

Commit ac51a36

Browse files
Resolve code review items:
- Revert the attention_sink assertion for multimodal models
- Parameterize the audio encoder input shape
- Fix comment typos
1 parent 2e4c9bc commit ac51a36

6 files changed

Lines changed: 24 additions & 7 deletions

File tree

examples/qualcomm/oss_scripts/llama/dataset.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,19 @@ def _build_audio_dataset(
9090
wav, sr = soundfile.read(audio_path, always_2d=False)
9191
wav = torch.from_numpy(wav).float().unsqueeze(0) # [1, T]
9292

93+
# Pad to fixed length so input_features has shape [1, n_bins, input_dim]
94+
hop_length = processor.audio_processor.melspec_kwargs["hop_length"]
95+
target_raw_length = (config.n_bins * 2 - 1) * hop_length
96+
pad_size = target_raw_length - wav.shape[-1]
97+
if pad_size > 0:
98+
wav = torch.nn.functional.pad(wav, (0, pad_size))
99+
elif pad_size < 0:
100+
suggested_n_bins = (wav.shape[-1] // hop_length + 1) // 2
101+
raise ValueError(
102+
f"Audio length ({wav.shape[-1]} samples) exceeds target ({target_raw_length} samples) "
103+
f"derived from n_bins={config.n_bins}. Set n_bins >= {suggested_n_bins} in the config to avoid information loss."
104+
)
105+
93106
# Process audio with text prompt using HuggingFace processor
94107
input_features = processor(prompt, wav, return_tensors="pt").input_features
95108
dataset.append((input_features,))

examples/qualcomm/oss_scripts/llama/encoder/encoder_config.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,11 @@ class AudioModalityConfig(MultiModalityConfig):
5353
"""
5454

5555
audio_seq_len: int
56+
n_bins: int
5657
audio_url: str
5758

5859
def create_encoder(self, config):
59-
return self.encoder_class(config)
60+
return self.encoder_class(config, n_bins=self.n_bins)
6061

6162

6263
@dataclass(init=False, frozen=True)
@@ -92,6 +93,7 @@ class GraniteSpeechEncoder(AudioModalityConfig):
9293

9394
encoder_class = GraniteSpeechCTCEncoderWrapper
9495
audio_seq_len = 171
96+
n_bins = 844
9597
audio_url = "https://huggingface.co/ibm-granite/granite-speech-3.3-2b/resolve/main/10226_10111_000000.wav?download=true"
9698
quant_recipe = GraniteSpeechEncoderQuantRecipe
9799
num_sharding = 8

examples/qualcomm/oss_scripts/llama/llama.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ def export_llama(args) -> None:
635635
)
636636
# TODO: Implement attention sink support for multimodal models (vision/audio).
637637
assert (
638-
is_multimodal or args.use_attention_sink is None
638+
not is_multimodal or args.use_attention_sink is None
639639
), "Multimodal models currently do not support attention sink feature."
640640

641641
if args.pre_gen_pte:

examples/qualcomm/oss_scripts/llama/model/audio_encoder.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def forward(
123123

124124

125125
class GraniteSpeechCTCEncoderWrapper(nn.Module):
126-
def __init__(self, config: GraniteSpeechConfig):
126+
def __init__(self, config: GraniteSpeechConfig, n_bins: int):
127127
super().__init__()
128128
self.encoder = GraniteSpeechCTCEncoder(config.encoder_config)
129129
self.projector = GraniteSpeechEncoderProjector(config)
@@ -145,9 +145,11 @@ def __init__(self, config: GraniteSpeechConfig):
145145
)
146146

147147
self.config = config
148+
self.n_bins = n_bins
148149

149150
def get_example_inputs(self):
150-
return (torch.randn((1, 844, 160), dtype=torch.float32),)
151+
input_dim = self.config.encoder_config.input_dim
152+
return (torch.randn((1, self.n_bins, input_dim), dtype=torch.float32),)
151153

152154
def forward(self, hidden_states: torch.Tensor):
153155
encoder_embeds = self.encoder(hidden_states)

examples/qualcomm/oss_scripts/llama/runner/multimodal_runner/multimodal_embedding_merger.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ class MultimodalEmbeddingMerger {
4545
int32_t embedding_dim_;
4646
int32_t total_tokens_{0};
4747

48-
// merged embeddings are holded in this vector.
48+
// merged embeddings are held in this vector.
4949
std::vector<float> embeddings_;
5050
std::array<executorch::aten::TensorImpl::SizesType, 3> sizes_{};
5151
};

examples/qualcomm/oss_scripts/llama/tokenizer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -184,8 +184,8 @@ def prepare_messages(self, prompts: List[str]): # noqa: C901
184184

185185
audio_paths = self.control_args.audio_path
186186
if hasattr(self.config, AUDIO_ENCODER):
187-
# Load image from user-specified path (URL or local file)
188-
# fall back to the default image URL if no image is provided.
187+
# Load audio from user-specified path (URL or local file)
188+
# fall back to the default audio URL if no audio is provided.
189189
if not audio_paths:
190190
audio_paths = [getattr(self.config, AUDIO_ENCODER).audio_url]
191191
warnings.warn(

0 commit comments

Comments
 (0)