
Commit c66a379

DongjiGao, pzelasko, and cursoragent authored
Add vLLM support for NeMo SpeechLM (#15520)
* Add vLLM plugin for NeMo Speech LLM inference

  Register NeMo Speech LM models into vLLM via the general_plugins entry point. Supports hybrid (NemotronH) and standard transformer (Qwen3) backbones.
  - NeMoSpeechLMHybridForConditionalGeneration: hybrid Mamba+MoE models
  - NeMoSpeechLMForConditionalGeneration: standard transformer models
  - NeMoSpeechLMStdForConditionalGeneration: legacy alias for standard
  - Audio preprocessing with automatic resampling to 16kHz mono
  - Thread-safe tokenizer patch for vLLM's concurrent encoding
  - Includes unit tests

* vllm plugin: rename classes so standard owns the base name

  Swap the naming convention so it follows "unqualified base name = default variant, qualified name = specialization":
  - NeMoSpeechLMForConditionalGeneration -> standard (Qwen3, Parakeet)
  - NeMoSpeechLMHybridForConditionalGeneration -> hybrid Mamba+MoE (NemotronH)

  Previously the unqualified base name was the hybrid class, which made to_hf.py's arch auto-detection point non-hybrid checkpoints at the wrong implementation. Keep to_hf.py as the contract and rename the plugin classes to match. Legacy alias NeMoSpeechLMStdForConditionalGeneration now points at the new base-named class so checkpoints exported under the old name load.

  Made-with: Cursor

* vllm plugin: drop NeMoSpeechLMStdForConditionalGeneration legacy alias

  No checkpoints in circulation use this name -- to_hf.py is the single source of truth for exported architecture names, and it only emits the two canonical names.

  Made-with: Cursor

* vllm plugin: document that nemotron_v3 name is historical

  The package covers every SpeechLM backbone (Qwen3, NemotronH, ...); the folder name is a historical artifact from when the plugin started as a NemotronH-only experiment.

  Made-with: Cursor

* vllm plugin config: tighten validation + document the quirks

  - Fail fast in `NeMoSpeechLMConfig.__init__` when the backbone config's `architectures` list isn't length-1: mixed or missing architectures currently route silently (mixed -> hybrid-if-any-match; missing -> treated as standard). A raised ValueError catches malformed ckpts at plugin load time instead of serving wrong weights.
  - Name the magic +10 on `text_config.vocab_size`: new constant `_SPEECHLM_EMBED_EXTRA_ROWS` with a block comment explaining it must match training-time vocab additions (audio locator + padding) so the embedding matrix in model.safetensors loads without shape mismatch.
  - Document the `architectures = ["NemotronHForCausalLM"]` normalization on hybrid backends (different checkpoints list different aliases; only the canonical name is in vLLM's registry).
  - Add a docstring on `__getattr__` explaining the guard list: prevents infinite recursion when plugin-specific fields are queried before `__init__` finishes, and prevents accidental delegation to same-named attributes on the wrapped `text_config`.
  - Drop the redundant `_ATTR_ALIASES` entry from the guard tuple: it starts with `_` so `startswith("_")` already catches it.

  Made-with: Cursor

* vllm plugin: drop dead output_dim fallback in _load_nemo_perception

  Every training YAML in speechlm-2026h1/ sets perception.output_dim explicitly, so the 'if "output_dim" not in cfg' fallback never fired. Remove it and the now-unused output_dim parameter (plus the callsite's llm_hidden derivation). If a terse perception config lands here later, AudioPerceptionModule will fail on its own with a clearer error.

  Made-with: Cursor

* vllm plugin: rename _pad_vocab_tensor -> _pad_to_vocab_size

  Verb-led name spells out what the helper does ('pad [the tensor] to [vocab_size]') instead of the ambiguous 'vocab tensor'. Pure rename, no behavior change.

  Made-with: Cursor

* vllm plugin: add missing type hints per NeMo PR checklist

  Five signatures were missing hints and tripped the 'every exposed method needs Python 3 type hints' rule from the NeMo contributor checklist: _ensure_special_tokens, _init_perception, and the three Mamba-state classmethods. Uses PreTrainedTokenizerBase for the tokenizer, VllmConfig for vllm_config args, and Any for Mamba return types + _init_perception's config (NeMoSpeechLMConfig is same-package and brings import-cycle risk not worth the precision).

  Made-with: Cursor

* vllm plugin: document _estimate_audio_tokens + add drift unit test

  The hand-rolled _estimate_audio_tokens function mirrors FastConformer's preprocessing chain (STFT + 3x Conv subsampling) but in pure Python to avoid ~90x tensor-ops overhead on the scheduler hotpath (measured 0.18 us vs 16 us per call via calc_length). Added:
  - Full docstring on _estimate_audio_tokens explaining what it mirrors, why it is hand-rolled, and a pointer to the drift test.
  - tests/collections/speechlm2/test_vllm_audio_token_estimator.py that asserts the estimator equals NeMo calc_length-based reference on 9 canonical audio lengths. Breaks when FastConformer's downsampling stack changes upstream, forcing a rewrite of the hand-rolled math.

  Made-with: Cursor

* vllm plugin: hoist audios lookup out of get_replacement closure

  _get_prompt_updates's inner get_replacement closure previously re-ran mm_items.get_items('audio', ...) on every call. The lookup is O(1) and mm_items is already finalized at this point, so pulling it out once saves a redundant dict access per <|audio|> match and makes the closure body one line shorter. Pure cleanup, no behavior change.

  Made-with: Cursor

* vllm plugin: validate + clean up placeholder/audio pairing

  _call_hf_processor silently accepted mismatches between the number of <|audio|> placeholders in the prompt and the number of audios in mm_data. The old loop processed first-N of whichever was shorter and left the surplus for a shape-mismatch crash deep in get_input_embeddings at forward time. Now:
  - Pre-loop length check: raises ValueError with a clear message when counts differ, so the error surfaces at the processor stage where the caller can see it.
  - Loop iterates ph_positions zipped with audios instead of walking all split parts and skipping text chunks; no audio_idx counter, no per-iteration <|audio|> branch, same behavior.
  - Short comment documents the positional pairing invariant.

  Made-with: Cursor

* vllm plugin: fail loud when audio_signal_length is missing

  _parse_audio_input had two defensive branches that duplicated or contradicted pipeline behavior:
  - if audio_signal_length is None: [shape[-1]] * batch
  - elif not isinstance(..., Tensor): torch.tensor(...)

  The None branch was latently wrong. By the time execution reaches it, audio_signal has been zero-padded to max batch length via the list-stacking block above, so audio_signal.shape[-1] is the padded length, not the true audio length. Handing that to the perception encoder as input_signal_length means the encoder treats trailing zeros as real audio and emits extra output frames, silently breaking placeholder/feature alignment.

  In the real pipeline, _call_hf_processor always emits audio_signal_length as a 1D torch.Tensor of true per-audio lengths alongside audio_signal (both declared batched in _get_mm_fields_config), so neither branch is reachable. Replaced both with a single type check that raises ValueError when the invariant is violated.

  Made-with: Cursor

* vllm plugin: use explicit params in _parse_audio_input

  _parse_audio_input had a **kwargs-only signature and popped audio fields inside the body. It mirrored vLLM's embed_multimodal(**kwargs) style but leaked the pattern into an internal helper that has a well-defined contract: exactly two inputs from the TensorSchema. Switched to explicit params (audio_signal, audio_signal_length) with **kwargs kept for forward-compat absorbing unexpected fields. Lets type checkers catch wrong-type callers and documents the contract in the signature itself.

  Made-with: Cursor

* vllm plugin: document no-op device guard in _process_audio

  self.perception = self.perception.to(device) reads perception's own device and moves perception there -- always a no-op in the TP, PP, and single-GPU paths. Fragmented placement is the only case it would trigger, and there it silently moves all params to the first-by-iteration param's device, which is not controllable from the caller. Real device placement is established at init time by _mark_tower_model and declared structurally via get_mm_mapping. Added a short comment so future readers don't assume the line is doing real work and plan multi-GPU changes around it.

  Made-with: Cursor

* vllm plugin: lift load_weights into base class with hooks

  Both NeMoSpeechLMHybridForConditionalGeneration and NeMoSpeechLMForConditionalGeneration had near-identical load_weights methods. The only real difference was one extra step for Standard: LoRA merge before HF-name rename. Refactored:
  - _NeMoSpeechLMBase.load_weights orchestrates the full pipeline (split -> perception load -> preprocess -> rename -> vLLM load).
  - _preprocess_llm_weights on base returns identity; Standard overrides to run _merge_lora_weights. Hybrid doesn't override.
  - _nemo_to_hf_llm_weights declared on base with NotImplementedError so a future subclass that forgets to override fails loudly with a clear message instead of AttributeError deep in load_weights.

  Subclasses now only hold the bits that differ (backbone-specific name mapping, LoRA merge). Future pipeline changes go in one place.

  Made-with: Cursor

* vllm plugin: address CodeQL findings

  Four nit-level findings flagged by github-advanced-security on the latest PR push. Behavior unchanged.
  - __init__.py: add comment to the empty `except Exception: pass` around the NemotronH config patch -- best-effort patch, silently skipped when the model class isn't reachable so other backbones still load.
  - model.py: drop redundant in-function `import re` in _normalize_lora_name (already imported at module top, line 31).
  - test_vllm_plugin.py: probe vLLM via importlib.util.find_spec instead of `import vllm` (CodeQL flags it even with `# noqa: F401`); drop stale `output_dim=256` kwarg from _load_nemo_perception call (parameter removed in 105a3dd).

  Signed-off-by: Dongji Gao <dongjig@nvidia.com>
  Made-with: Cursor

* vllm plugin: rely on upstream tokenizer concurrency fix

  Remove the global HuggingFace fast-tokenizer monkey patch now that modern vLLM isolates multimodal tokenizer use, and keep the plugin compatible with the moved vLLM multimodal input type.

  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: require exported SpeechLM config fields

  Fail fast when exported SpeechLM checkpoints omit fields that define the backbone, ASR source, prompt format, audio token, or pretrained-weight contract instead of silently falling back to local defaults.

  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: simplify NemotronH architecture comment

  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: harden config and registration tests

  Make plugin tests hermetic by mocking backbone config loading, skip estimator drift checks cleanly when vLLM is unavailable, and cover request-time invariants plus no tokenizer monkey patch behavior.

  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: allow HF default config construction

  Allow HuggingFace to instantiate NeMoSpeechLMConfig without checkpoint fields for internal serialization while preserving validation for real exported SpeechLM configs.

  Signed-off-by: Dongji Gao <dongjig@nvidia.com>
  Made-with: Cursor

* Apply isort and black reformatting

  Signed-off-by: DongjiGao <DongjiGao@users.noreply.github.com>

* vllm plugin: simplify fake tokenizer callbacks

  Use the fake tokenizer class directly instead of wrapping it in no-op lambdas to satisfy CodeQL without changing test behavior.

  Signed-off-by: Dongji Gao <dongjig@nvidia.com>
  Made-with: Cursor

* vllm plugin: rename nemotron_v3 to salm, collapse to single model class with backend composition

  Address Piotr's review (PR #15520): the nemotron_v3 folder name no longer matches the scope (the plugin handles both standard transformer and hybrid Mamba+MoE backbones), and the _NeMoSpeechLMBase + two-class inheritance pattern duplicates wiring across backbones. Rename the package to salm and replace the inheritance structure with composition: a single NeMoSpeechLMForConditionalGeneration class delegates LLM-specific work (architecture name, weight rename, optional LoRA merge, mamba state passthroughs) to a TransformerBackend or HybridBackend selected by make_backend(config).

  Single-class registration is feasible because vLLM's runtime ModelConfig.is_hybrid property uses text_config.layer_types as an escape hatch (the granite-4.0-micro path): we declare IsHybrid on the model class for the NemotronH backbone, and config.py populates layer_types=['attention']*N for transformer backbones so vLLM treats them as attention-only at runtime. There is no runtime isinstance(model, IsHybrid) check anywhere in vLLM that bypasses the property, so this collapses cleanly.

  salm/ now splits into:
    config.py      NeMoSpeechLMConfig + the layer_types shim
    multimodal.py  audio helpers and vLLM processor/info/dummy-inputs trio
    backends.py    _BaseBackend, TransformerBackend, HybridBackend, make_backend
    model.py       single NeMoSpeechLMForConditionalGeneration class
    __init__.py    register() with one architecture name

  examples/speechlm2/to_hf.py emits the unified architecture name. Tests are updated for the new import paths and add coverage for layer_types shim wiring and make_backend dispatch (45 unit tests pass).

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>
  Co-authored-by: Cursor <cursoragent@cursor.com>
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>
  Co-authored-by: Cursor <cursoragent@cursor.com>
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>
  Co-authored-by: Cursor <cursoragent@cursor.com>

* vllm plugin: avoid loading backbone config during registration

  Keep SALM plugin registration side-effect light by removing the NemotronH runtime monkey patch and relying on NeMoSpeechLMConfig to normalize rms_norm_eps on the wrapped backbone config. Add a registration test that fails if register() starts loading remote backbone configs again.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: simplify salm config comments

  Clarify the transformer layer_types shim and trim stale embedding-row commentary after review.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: normalize salm audio parser channels

  Ask vLLM's multimodal parser to reduce audio inputs to mono alongside the existing 16 kHz resampling, and pin that parser contract in the plugin tests.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: simplify salm model comments

  Remove stale inline commentary from the audio parsing path after review.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: remove unused audio typing import

  Drop the stale MultiModalEmbeddings re-export from audio.py; model.py imports the type directly where it is used.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: clarify transformer backend docs

  Remove incorrect Parakeet-TDT examples from SALM transformer-backend documentation and describe the path as decoder-only LLM backbones.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: remove audio duration limit from processing info

  Stop exposing a 40s max audio length through NeMoSpeechLMProcessingInfo; keep the finite 40s length only for vLLM dummy/profiling inputs.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: seed text config before base init

  Newer Transformers may call get_text_config during PretrainedConfig initialization, before the SALM wrapper has loaded the real backbone config. Seed an inert text_config first and keep the real checkpoint path unchanged.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

* vllm plugin: skip backend tests without vllm

  Backend selection imports salm.backends, which depends on vLLM symbols. Guard those tests on vLLM availability so CPU SpeechLM2 shards without vLLM skip them like the other plugin runtime tests.

  Made-with: Cursor
  Signed-off-by: Dongji Gao <dongjig@nvidia.com>

---------

Signed-off-by: Dongji Gao <dongjig@nvidia.com>
Signed-off-by: DongjiGao <DongjiGao@users.noreply.github.com>
Co-authored-by: DongjiGao <DongjiGao@users.noreply.github.com>
Co-authored-by: Piotr Żelasko <pzelasko@nvidia.com>
Co-authored-by: Cursor <cursoragent@cursor.com>
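A minimal offline-inference sketch of how a checkpoint exported by examples/speechlm2/to_hf.py might be served once this plugin is installed. The checkpoint path, the "<|audio|>" locator literal (the actual value of _AUDIO_PLACEHOLDER lives in salm/config.py, which is not shown in this diff), and the sampling settings are illustrative assumptions, not taken from the PR:

import librosa
from vllm import LLM, SamplingParams

# The plugin's multimodal parser normalizes audio to 16 kHz mono.
waveform, sr = librosa.load("sample.wav", sr=16000, mono=True)

llm = LLM(model="/path/to/exported_speechlm_hf_checkpoint")  # hypothetical path
outputs = llm.generate(
    {
        "prompt": "Transcribe the following: <|audio|>",  # one placeholder per audio
        "multi_modal_data": {"audio": (waveform, sr)},
    },
    sampling_params=SamplingParams(temperature=0.0, max_tokens=256),
)
print(outputs[0].outputs[0].text)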
1 parent fb9a6c4 commit c66a379

11 files changed

Lines changed: 1702 additions & 12 deletions


examples/speechlm2/to_hf.py

Lines changed: 6 additions & 6 deletions
@@ -107,11 +107,13 @@ def save_hf_checkpoint(model: torch.nn.Module, state_dict: dict, cfg: HfExportCo
         json.dump(config, f, indent=2)
 
 
-_HYBRID_ARCHITECTURES = {"NemotronHForCausalLM", "NemotronHybridForCausalLM"}
-
-
 def _detect_vllm_architecture(model_cfg: dict) -> str:
-    """Determine the vLLM plugin model class from the pretrained LLM backbone.
+    """Determine the vLLM plugin model class for the checkpoint.
+
+    The SALM plugin registers a single architecture name and selects between
+    transformer and hybrid backends at instantiation time, so this function
+    just verifies the backbone config is reachable and returns the unified
+    name; the hybrid-vs-transformer split is handled inside the plugin.
 
     Raises:
         ValueError: if the HF config can't be loaded or has no 'architectures'.
@@ -131,8 +133,6 @@ def _detect_vllm_architecture(model_cfg: dict) -> str:
     if not archs:
         raise ValueError(f"HF config for {pretrained_llm!r} has empty 'architectures'.")
 
-    if set(archs) & _HYBRID_ARCHITECTURES:
-        return "NeMoSpeechLMHybridForConditionalGeneration"
     return "NeMoSpeechLMForConditionalGeneration"
 
 

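For orientation, a sketch of the architecture-related fields the exporter presumably writes into the checkpoint's config.json after this change. The "nemo_speechlm" value is inferred from the AutoConfig key registered by the plugin's register() further down; every other exported field is omitted here because it is not visible in this diff:

import json

config_fragment = {
    "model_type": "nemo_speechlm",  # presumed to match the registered AutoConfig key
    "architectures": ["NeMoSpeechLMForConditionalGeneration"],  # value returned by _detect_vllm_architecture
}
print(json.dumps(config_fragment, indent=2))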
nemo/collections/speechlm2/vllm/__init__.py

Whitespace-only changes.
Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""vLLM plugin registration for NeMo Speech LM (SALM) models.

Registers ``NeMoSpeechLMConfig`` and the single
``NeMoSpeechLMForConditionalGeneration`` model class with vLLM via the
``vllm.general_plugins`` entry point.

A single model class covers every supported backbone family (standard
decoder-only LLMs like Qwen3, hybrid Mamba+MoE like NemotronH).
Backbone-specific behavior is selected at instantiation time.
"""

_PKG = "nemo.collections.speechlm2.vllm.salm"


def register():
    """Register the NeMo Speech LM model and config with vLLM."""
    from transformers import AutoConfig

    from nemo.collections.speechlm2.vllm.salm.config import NeMoSpeechLMConfig

    AutoConfig.register("nemo_speechlm", NeMoSpeechLMConfig)

    from vllm.transformers_utils.config import _CONFIG_REGISTRY

    _CONFIG_REGISTRY["nemo_speechlm"] = NeMoSpeechLMConfig

    from vllm.model_executor.models.registry import ModelRegistry

    ModelRegistry.register_model(
        "NeMoSpeechLMForConditionalGeneration",
        f"{_PKG}.model:NeMoSpeechLMForConditionalGeneration",
    )
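register() only takes effect if vLLM discovers it through the ``vllm.general_plugins`` entry-point group named in the module docstring above. A sketch of how that hook might be declared in a setup.py; the distribution name, entry name, and module path are assumptions, since the packaging change itself is not part of this diff:

from setuptools import setup

setup(
    name="nemo-speechlm-vllm-plugin",  # hypothetical: whichever distribution ships the salm package
    entry_points={
        "vllm.general_plugins": [
            # vLLM imports and calls this function at startup, before model loading.
            "nemo_speechlm = nemo.collections.speechlm2.vllm.salm:register",
        ],
    },
)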
Lines changed: 269 additions & 0 deletions
@@ -0,0 +1,269 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Audio-side plumbing for the NeMo Speech LM (SALM) vLLM plugin.

All audio handling lives here: helpers (perception loader, tokenizer special-token
patcher, vocab-size padder), audio constants and TensorSchema, and the trio of
classes that bind to vLLM's multimodal registry to drive prompt expansion and
dummy-input generation. Backbone-agnostic; shared by both transformer and
hybrid backends.

Public surface used by the rest of the package:

* ``_AUDIO_PLACEHOLDER`` -- the audio locator string vLLM emits during prompt
  rendering and the processor expands inline.
* ``_load_nemo_perception``, ``_ensure_special_tokens``, ``_pad_to_vocab_size``
  -- small helpers reused at model init and weight load time.
* ``NeMoSpeechLMAudioInputs`` -- vLLM ``TensorSchema`` describing the parsed
  audio tensors that flow into ``embed_multimodal``.
* ``NeMoSpeechLMProcessingInfo`` / ``NeMoSpeechLMMultiModalProcessor`` /
  ``NeMoSpeechLMDummyInputsBuilder`` -- the trio that vLLM's multimodal
  registry binds to the registered model class.
"""

import re
from collections.abc import Mapping
from typing import Annotated, Literal

import torch
from torch import nn
from transformers import BatchFeature, PreTrainedTokenizerBase
from vllm.config.multimodal import BaseDummyOptions

try:
    from vllm.inputs import MultiModalDataDict
except ImportError:
    from vllm.multimodal.inputs import MultiModalDataDict

from vllm.multimodal.inputs import MultiModalFieldConfig, MultiModalKwargsItems
from vllm.multimodal.parse import AudioProcessorItems, MultiModalDataItems, MultiModalDataParser
from vllm.multimodal.processing import (
    BaseMultiModalProcessor,
    BaseProcessingInfo,
    PromptReplacement,
    PromptUpdate,
    PromptUpdateDetails,
)
from vllm.multimodal.processing.dummy_inputs import BaseDummyInputsBuilder
from vllm.utils.tensor_schema import TensorSchema, TensorShape

from nemo.collections.speechlm2.vllm.salm.config import _AUDIO_PLACEHOLDER

_SAMPLING_RATE = 16000
_AUDIO_CHANNELS = 1
_DUMMY_AUDIO_DURATION_S = 40.0


# ── Helpers ─────────────────────────────────────────────────────────


def _ensure_special_tokens(tokenizer: PreTrainedTokenizerBase) -> None:
    special = [_AUDIO_PLACEHOLDER]
    existing = set(tokenizer.get_vocab().keys())
    to_add = [t for t in special if t not in existing]
    if to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": to_add})


def _load_nemo_perception(perception_cfg: dict) -> nn.Module:
    try:
        from omegaconf import DictConfig

        from nemo.collections.speechlm2.modules import AudioPerceptionModule
    except ImportError as e:
        raise ImportError(
            "NeMo is required for the audio encoder. " "Install with: pip install nemo_toolkit[asr]"
        ) from e

    cfg = DictConfig(perception_cfg)
    perception = AudioPerceptionModule(cfg)
    perception.eval()
    return perception


def _pad_to_vocab_size(tensor: torch.Tensor, target_vocab: int) -> torch.Tensor:
    if tensor.shape[0] < target_vocab:
        pad = torch.zeros(
            target_vocab - tensor.shape[0],
            *tensor.shape[1:],
            dtype=tensor.dtype,
        )
        tensor = torch.cat([tensor, pad], dim=0)
    return tensor


# ── Multimodal contract types ───────────────────────────────────────


class NeMoSpeechLMAudioInputs(TensorSchema):
    type: Literal["audio_features"] = "audio_features"
    audio_signal: Annotated[torch.Tensor | list[torch.Tensor], TensorShape("b", "t")]
    audio_signal_length: Annotated[torch.Tensor, TensorShape("b")]


class NeMoSpeechLMProcessingInfo(BaseProcessingInfo):

    def get_data_parser(self) -> MultiModalDataParser:
        return MultiModalDataParser(
            target_sr=_SAMPLING_RATE,
            target_channels=_AUDIO_CHANNELS,
            expected_hidden_size=self._get_expected_hidden_size(),
        )

    def get_supported_mm_limits(self) -> Mapping[str, int | None]:
        return {"audio": 1}

    @staticmethod
    def _estimate_audio_tokens(audio_length_samples: int) -> int:
        """Predict the encoder's output frame count for an audio of N samples.

        Mirrors the FastConformer preprocessing chain used by
        ``AudioPerceptionModule``: STFT (n_fft=512, hop_length=160) followed
        by 3x Conv(kernel=3, stride=2) subsampling. Implemented as pure
        Python integer math instead of calling NeMo's ``calc_length`` so
        the scheduler hotpath avoids ~90x tensor-op overhead (measured
        0.18 us vs 16 us per call). If the encoder's downsampling stack
        ever changes upstream, the unit test at
        ``tests/collections/speechlm2/test_vllm_audio_token_estimator.py``
        compares this function against ``calc_length`` on a canonical set
        of lengths and will fail, forcing a rewrite here.
        """
        n_fft = 512
        hop_length = 160
        stft_pad = n_fft // 2
        fbank_len = (audio_length_samples + 2 * stft_pad - n_fft) // hop_length
        kernel, stride, repeat = 3, 2, 3
        add_pad = 1 + 1 - kernel
        length = float(fbank_len)
        for _ in range(repeat):
            length = (length + add_pad) / stride + 1.0
        return max(1, int(length))


class NeMoSpeechLMMultiModalProcessor(
    BaseMultiModalProcessor[NeMoSpeechLMProcessingInfo],
):

    def _get_mm_fields_config(
        self,
        hf_inputs: BatchFeature,
        hf_processor_mm_kwargs: Mapping[str, object],
    ) -> Mapping[str, MultiModalFieldConfig]:
        return dict(
            audio_signal=MultiModalFieldConfig.batched("audio"),
            audio_signal_length=MultiModalFieldConfig.batched("audio"),
        )

    def _hf_processor_applies_updates(
        self,
        prompt_text: str,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        tokenization_kwargs: Mapping[str, object],
    ) -> bool:
        return False

    def _get_prompt_updates(
        self,
        mm_items: MultiModalDataItems,
        hf_processor_mm_kwargs: Mapping[str, object],
        out_mm_kwargs: MultiModalKwargsItems,
    ) -> list[PromptUpdate]:
        audios = mm_items.get_items("audio", AudioProcessorItems)

        def get_replacement(item_idx: int):
            audio = audios.get(item_idx)
            n_tokens = self.info._estimate_audio_tokens(audio.shape[-1])
            repl_full = _AUDIO_PLACEHOLDER * n_tokens
            return PromptUpdateDetails.select_text(repl_full, _AUDIO_PLACEHOLDER)

        return [
            PromptReplacement(
                modality="audio",
                target=_AUDIO_PLACEHOLDER,
                replacement=get_replacement,
            )
        ]

    def _call_hf_processor(
        self,
        prompt: str,
        mm_data: Mapping[str, object],
        mm_kwargs: Mapping[str, object],
        tok_kwargs: Mapping[str, object],
    ) -> BatchFeature:
        tokenizer = self.info.get_tokenizer()
        _ensure_special_tokens(tokenizer)
        mm_data = dict(mm_data)
        audios = mm_data.pop("audios", [])

        if audios:
            audio_list: list[torch.Tensor] = []
            audio_lengths: list[int] = []
            parts = re.split(f"({re.escape(_AUDIO_PLACEHOLDER)})", prompt)
            # One placeholder is overwritten with one audio's encoder output
            # at forward time (positional pairing); counts must match or the
            # merge step in get_input_embeddings crashes / silently drops.
            ph_positions = [i for i, p in enumerate(parts) if p == _AUDIO_PLACEHOLDER]
            if len(ph_positions) != len(audios):
                raise ValueError(
                    f"Prompt has {len(ph_positions)} "
                    f"{_AUDIO_PLACEHOLDER!r} placeholders but "
                    f"{len(audios)} audios were provided; counts must match."
                )
            for i, audio in zip(ph_positions, audios):
                audio_tensor = (
                    audio if isinstance(audio, torch.Tensor) else torch.as_tensor(audio, dtype=torch.float32)
                )
                if audio_tensor.dim() > 1:
                    audio_tensor = audio_tensor.squeeze()
                n_tokens = self.info._estimate_audio_tokens(audio_tensor.shape[-1])
                parts[i] = _AUDIO_PLACEHOLDER * n_tokens
                audio_list.append(audio_tensor)
                audio_lengths.append(audio_tensor.shape[-1])

            prompt = "".join(parts)

        prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
        result = BatchFeature(dict(input_ids=[prompt_ids]), tensor_type="pt")

        if audios:
            result["audio_signal"] = audio_list
            result["audio_signal_length"] = torch.tensor(audio_lengths)
        return result


class NeMoSpeechLMDummyInputsBuilder(
    BaseDummyInputsBuilder[NeMoSpeechLMProcessingInfo],
):

    def get_dummy_mm_data(
        self,
        seq_len: int,
        mm_counts: Mapping[str, int],
        mm_options: Mapping[str, BaseDummyOptions],
    ) -> MultiModalDataDict:
        num_audios = mm_counts.get("audio", 0)
        dummy_audio_len = int(_DUMMY_AUDIO_DURATION_S * _SAMPLING_RATE)
        return {
            "audio": self._get_dummy_audios(
                length=dummy_audio_len,
                num_audios=num_audios,
            )
        }

    def get_dummy_text(self, mm_counts: Mapping[str, int]) -> str:
        num_audios = mm_counts.get("audio", 0)
        return "Transcribe the following: " + _AUDIO_PLACEHOLDER * num_audios

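A quick numeric sanity check of _estimate_audio_tokens, using a standalone copy of the same integer math from the class above. The two expected values are hand-computed from the constants in this diff (n_fft=512, hop_length=160, three stride-2 convolutions), not measured against a running encoder:

def estimate_audio_tokens(audio_length_samples: int) -> int:
    # Standalone copy of NeMoSpeechLMProcessingInfo._estimate_audio_tokens above.
    n_fft, hop_length = 512, 160
    fbank_len = (audio_length_samples + 2 * (n_fft // 2) - n_fft) // hop_length
    length = float(fbank_len)
    for _ in range(3):  # 3x Conv(kernel=3, stride=2) subsampling
        length = (length + (1 + 1 - 3)) / 2 + 1.0
    return max(1, int(length))

# 1 s of 16 kHz audio -> 100 mel frames -> 13 encoder frames (~80 ms per token).
assert estimate_audio_tokens(16_000) == 13
# The 40 s dummy/profiling audio (_DUMMY_AUDIO_DURATION_S) -> 4000 mel frames -> 500 tokens.
assert estimate_audio_tokens(640_000) == 500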
0 commit comments
