Skip to content

Commit 79441a7

Browse files
committed
PR_#1 from mmkrtchyan/qwen-asr-on-filter-pipeline, QwenASR hallucination recovery and unified _skip_me tracking
Adds an optional QwenASR (0.6B) re-transcription stage that recovers samples flagged as hallucinated by QwenOmni. Unifies the _skip_me field with source tracking (Hallucination:WhisperHallucination_omni, Recovered:QwenASR, Wrong language:FastTextLID, etc.). Routes the best prediction to downstream stages via SelectBestPredictionStage.

- nemo_curator/models/qwen_asr.py — QwenASR model wrapper (qwen_asr lib + vLLM)
- nemo_curator/stages/audio/inference/qwen_asr.py — conditional QwenASR stage
- nemo_curator/stages/audio/text_filtering/select_best_prediction.py — picker
- keep_waveform flag on QwenOmni for downstream audio access
- Skip empty text in hallucination detection
- Enabled via --asr_model_id (otherwise no-op)

Squash cherry-pick of nune-tadevosyan#1.

Conflict resolution:
- run_pipeline.py: rebuilt to integrate both QwenASR (this PR) and ITN (PR #3) stage chains. Used PR_#1's structure as base, added ITN imports, CLI flags, prompt loading, and conditional ITNRestorationStage append after PnC.
- fasttext_lid.py: kept PR_#1's source-tracked "Empty text:{self.name}" (matches the unified _skip_me convention introduced here).

#NO_PR
Signed-off-by: George Zelenfroynd <gzelenfroind@nvidia.com>
1 parent 42faf7e commit 79441a7

7 files changed

Lines changed: 631 additions & 214 deletions

File tree

examples/audio/qwen_omni_inprocess/run_pipeline.py

Lines changed: 212 additions & 205 deletions
Large diffs are not rendered by default.

nemo_curator/models/qwen_asr.py

Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Qwen3-ASR model wrapper for in-process vLLM inference.
16+
17+
Uses the ``qwen_asr`` library which wraps vLLM internally and exposes a
18+
high-level ``transcribe()`` API that accepts in-memory numpy waveforms.
19+
"""
20+
21+
from __future__ import annotations
22+
23+
import gc
24+
from typing import TYPE_CHECKING, Any
25+
26+
from loguru import logger
27+
28+
from nemo_curator.models.base import ModelInterface
29+
30+
if TYPE_CHECKING:
31+
import numpy as np
32+
33+
# Default model identifier used when ``QwenASR`` is built without an explicit ``model_id``.
_QWEN3_ASR_MODEL_ID = "Qwen/Qwen3-ASR-0.6B"
34+
35+
36+
class QwenASR(ModelInterface):
    """Qwen3-ASR model via the ``qwen_asr`` library with vLLM backend.

    Audio is accepted as in-memory numpy arrays (mono, any sample rate).
    The ``qwen_asr`` library handles resampling to 16 kHz, chunking long
    audio, and batched vLLM inference internally.

    Args:
        model_id: Model identifier forwarded to ``Qwen3ASRModel.LLM`` as ``model``.
        language: Optional language hint forwarded to ``transcribe()``; also
            used as the reported language when a result carries none.
        gpu_memory_utilization: Forwarded to ``Qwen3ASRModel.LLM``.
        max_new_tokens: Forwarded to ``Qwen3ASRModel.LLM``.
        max_inference_batch_size: Forwarded to ``Qwen3ASRModel.LLM``.
    """

    def __init__(
        self,
        model_id: str = _QWEN3_ASR_MODEL_ID,
        language: str | None = None,
        gpu_memory_utilization: float = 0.7,
        max_new_tokens: int = 4096,
        max_inference_batch_size: int = 128,
    ):
        self.model_id = model_id
        self.language = language
        self.gpu_memory_utilization = gpu_memory_utilization
        self.max_new_tokens = max_new_tokens
        self.max_inference_batch_size = max_inference_batch_size

        # Populated by setup(); None means "not loaded".
        self._model: Any = None

    @property
    def model_id_names(self) -> list[str]:
        """Return the configured model identifier as a single-element list."""
        return [self.model_id]

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    @staticmethod
    def _patch_transformers_compat() -> None:
        """Patch transformers.check_model_inputs for qwen-asr compatibility.

        Newer transformers changed check_model_inputs from a decorator factory
        (called with parentheses) to a plain decorator. The qwen-asr package
        uses the old ``@check_model_inputs()`` syntax which breaks on newer
        versions. This wraps it to accept both styles.

        The patch is best-effort: any failure is logged at debug level and
        otherwise ignored so that ``setup()`` can still attempt to load the model.
        """
        try:
            import transformers

            original = getattr(transformers, "check_model_inputs", None)
            if original is None:
                return

            import inspect

            params = list(inspect.signature(original).parameters.values())
            # Only the plain-decorator form (first parameter named ``func``) needs wrapping.
            if not params or params[0].name != "func":
                return

            def compat_check_model_inputs(*args, **kwargs):
                # ``@check_model_inputs`` -> decorate the function directly.
                if args and callable(args[0]):
                    return original(args[0])
                # ``@check_model_inputs()`` -> hand back the decorator itself.
                return original

            transformers.check_model_inputs = compat_check_model_inputs
        except Exception as exc:  # noqa: BLE001
            # Deliberately non-fatal, but leave a trace instead of failing silently.
            logger.debug(f"Skipping transformers check_model_inputs compatibility patch: {exc!r}")

    def setup(self) -> None:
        """Import ``qwen_asr`` and load the model into ``self._model``.

        Raises:
            ImportError: If the ``qwen_asr`` package cannot be imported.
        """
        self._patch_transformers_compat()

        try:
            from qwen_asr import Qwen3ASRModel
        except ImportError:
            msg = "qwen_asr is required for QwenASR. Install it: pip install qwen-asr[vllm]"
            raise ImportError(msg) from None

        logger.info(
            f"Loading QwenASR model={self.model_id} "
            f"gpu_mem={self.gpu_memory_utilization} "
            f"max_new_tokens={self.max_new_tokens} "
            f"max_batch={self.max_inference_batch_size}"
        )

        self._model = Qwen3ASRModel.LLM(
            model=self.model_id,
            gpu_memory_utilization=self.gpu_memory_utilization,
            max_inference_batch_size=self.max_inference_batch_size,
            max_new_tokens=self.max_new_tokens,
            trust_remote_code=True,
            enforce_eager=True,
        )

        logger.info("QwenASR model loaded")

    def teardown(self) -> None:
        """Drop the model reference and make a best-effort attempt to free GPU memory."""
        # Rebinding releases the reference; a separate ``del`` is unnecessary.
        self._model = None
        gc.collect()
        try:
            import torch

            torch.cuda.empty_cache()
        except Exception as exc:  # noqa: BLE001
            # Cache release is optional cleanup; never let it break teardown.
            logger.debug(f"QwenASR teardown: could not empty the CUDA cache: {exc!r}")

    # ------------------------------------------------------------------
    # Generation
    # ------------------------------------------------------------------

    def generate(
        self,
        waveforms: list[np.ndarray],
        sample_rates: list[int],
        contexts: list[str] | None = None,
    ) -> tuple[list[str], list[str]]:
        """Run batched ASR inference on in-memory audio waveforms.

        Args:
            waveforms: List of 1-D mono numpy float32 arrays.
            sample_rates: Corresponding sample rates for each waveform.
            contexts: Optional per-sample instruction strings for
                ``with_instruction`` mode.

        Returns:
            ``(texts, languages)`` -- transcribed text and detected
            language for each input. An empty batch yields ``([], [])``
            without invoking the model.

        Raises:
            RuntimeError: If ``setup()`` has not been called.
            ValueError: If ``waveforms`` and ``sample_rates`` differ in length.
        """
        if self._model is None:
            msg = "Model not initialized. Call setup() first."
            raise RuntimeError(msg)

        # strict=True turns a length mismatch into a ValueError instead of silent truncation.
        audio_inputs: list[tuple[np.ndarray, int]] = list(
            zip(waveforms, sample_rates, strict=True)
        )
        if not audio_inputs:
            # Nothing to transcribe; avoid calling the backend with an empty batch.
            return [], []

        kwargs: dict[str, Any] = {
            "audio": audio_inputs,
            "language": self.language,
        }
        if contexts is not None:
            kwargs["context"] = contexts

        results = self._model.transcribe(**kwargs)

        # Fall back to the string form of a result when it has no ``text`` attribute.
        texts = [getattr(r, "text", str(r)) for r in results]
        # Missing or empty detected language falls back to the configured hint (or "").
        languages = [getattr(r, "language", "") or (self.language or "") for r in results]

        return texts, languages
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from __future__ import annotations
16+
17+
from dataclasses import dataclass, field
18+
from typing import TYPE_CHECKING
19+
20+
from loguru import logger
21+
22+
from nemo_curator.models.qwen_asr import QwenASR
23+
from nemo_curator.stages.base import ProcessingStage
24+
from nemo_curator.stages.resources import Resources
25+
from nemo_curator.tasks import AudioTask
26+
27+
if TYPE_CHECKING:
28+
from nemo_curator.backends.base import NodeInfo, WorkerMetadata
29+
30+
31+
@dataclass
class InferenceQwenASRStage(ProcessingStage[AudioTask, AudioTask]):
    """Batch transcription stage backed by Qwen3-ASR (``qwen_asr`` library, vLLM backend).

    Each ``AudioTask.data`` mapping that is selected for inference must provide:

    - ``waveform_key``: 1-D mono numpy float32 array (any sample rate)
    - ``sample_rate_key``: int

    Selection: when ``run_only_if_key`` is set, only tasks whose
    ``task.data[run_only_if_key]`` (converted to ``str``) starts with
    ``run_only_if_prefix`` are transcribed. All other tasks keep whatever
    prediction/language values they already had, or receive empty strings.
    The waveform entry is removed from every task before the batch is returned.

    Args:
        model_id: HuggingFace model identifier or local path.
        language: Language hint (e.g. ``"English"``).
        waveform_key: Key holding the audio samples.
        sample_rate_key: Key holding the sample rate.
        pred_text_key: Key where the predicted text is stored.
        language_key: Key where the detected language is stored.
        context_key: If set, key whose value is passed as the per-sample context
            (missing values become ``""``).
        run_only_if_key: If set, only run inference on tasks where
            ``task.data[run_only_if_key]`` starts with ``run_only_if_prefix``.
        run_only_if_prefix: Prefix used for the selection above.
        gpu_memory_utilization: Fraction of GPU memory vLLM may use.
        max_new_tokens: Maximum tokens to generate per sample.
        max_inference_batch_size: Batch size for internal vLLM batching.
    """

    name: str = "QwenASR_inference"
    model_id: str = "Qwen/Qwen3-ASR-0.6B"
    language: str | None = None
    waveform_key: str = "waveform"
    sample_rate_key: str = "sample_rate"
    pred_text_key: str = "qwen3_asr_prediction"
    language_key: str = "qwen3_asr_language"
    context_key: str | None = None
    run_only_if_key: str | None = None
    run_only_if_prefix: str = "Hallucination"
    gpu_memory_utilization: float = 0.7
    max_new_tokens: int = 4096
    max_inference_batch_size: int = 128
    resources: Resources = field(default_factory=lambda: Resources(gpus=1.0))
    batch_size: int = 128

    def __post_init__(self) -> None:
        # The model wrapper is created lazily by setup_on_node() / setup().
        self._model: QwenASR | None = None

    def _create_model(self) -> QwenASR:
        """Build an unloaded ``QwenASR`` wrapper from this stage's configuration."""
        model_kwargs = {
            "model_id": self.model_id,
            "language": self.language,
            "gpu_memory_utilization": self.gpu_memory_utilization,
            "max_new_tokens": self.max_new_tokens,
            "max_inference_batch_size": self.max_inference_batch_size,
        }
        return QwenASR(**model_kwargs)

    # ------------------------------------------------------------------
    # Lifecycle
    # ------------------------------------------------------------------

    def setup_on_node(
        self,
        _node_info: NodeInfo | None = None,
        _worker_metadata: WorkerMetadata | None = None,
    ) -> None:
        """Create a fresh model wrapper, load it, and log readiness."""
        wrapper = self._create_model()
        self._model = wrapper
        wrapper.setup()
        logger.info("QwenASR model ready on node")

    def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None:
        """Create and load the model unless one is already attached to the stage."""
        if self._model is not None:
            return
        self._model = self._create_model()
        self._model.setup()

    def teardown(self) -> None:
        """Tear down the attached model, if any, and detach it from the stage."""
        if self._model is None:
            return
        self._model.teardown()
        self._model = None

    # ------------------------------------------------------------------
    # I/O contract
    # ------------------------------------------------------------------

    def inputs(self) -> tuple[list[str], list[str]]:
        """Return the input contract: an empty first list and the two audio data keys."""
        consumed_keys = [self.waveform_key, self.sample_rate_key]
        return [], consumed_keys

    def outputs(self) -> tuple[list[str], list[str]]:
        """Return the output contract: an empty first list and the two result data keys."""
        produced_keys = [self.pred_text_key, self.language_key]
        return [], produced_keys

    # ------------------------------------------------------------------
    # Processing
    # ------------------------------------------------------------------

    def process(self, task: AudioTask) -> AudioTask:
        """Unsupported single-task entry point; this stage works on batches only."""
        msg = "InferenceQwenASRStage only supports process_batch"
        raise NotImplementedError(msg)

    def _drop_waveforms(self, tasks: list[AudioTask]) -> None:
        """Remove the waveform entry (when present) from every task's data."""
        for item in tasks:
            item.data.pop(self.waveform_key, None)

    def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]:
        """Transcribe the selected tasks of a batch in place and return the batch.

        Raises:
            RuntimeError: If no model has been set up on this stage.
        """
        if not tasks:
            return []

        model = self._model
        if model is None:
            msg = "Model not initialized — setup() was not called"
            raise RuntimeError(msg)

        # Make sure both output keys exist on every task, selected or not.
        for item in tasks:
            record = item.data
            record.setdefault(self.pred_text_key, "")
            record.setdefault(self.language_key, "")

        gate_key = self.run_only_if_key
        if gate_key:
            selected: list[int] = []
            for position, item in enumerate(tasks):
                marker = str(item.data.get(gate_key, ""))
                if marker.startswith(self.run_only_if_prefix):
                    selected.append(position)
        else:
            selected = list(range(len(tasks)))

        total = len(tasks)
        if not selected:
            self._drop_waveforms(tasks)
            logger.info(f"QwenASR: skipped entire batch of {total} (none matched run_only_if_key)")
            return tasks

        chosen = [tasks[position] for position in selected]
        waveforms = [item.data[self.waveform_key] for item in chosen]
        sample_rates = [item.data[self.sample_rate_key] for item in chosen]
        contexts = None
        if self.context_key:
            contexts = [item.data.get(self.context_key, "") for item in chosen]

        pred_texts, languages = model.generate(waveforms, sample_rates, contexts)

        # strict=True surfaces a result-count mismatch instead of dropping outputs.
        for item, text, detected in zip(chosen, pred_texts, languages, strict=True):
            item.data[self.pred_text_key] = text
            item.data[self.language_key] = detected

        self._drop_waveforms(tasks)

        processed = len(selected)
        logger.info(f"QwenASR: generated {processed} predictions, skipped {total - processed}")
        return tasks

nemo_curator/stages/audio/inference/qwen_omni.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ class InferenceQwenOmniStage(ProcessingStage[AudioTask, AudioTask]):
8181
temperature: float = 0.0
8282
top_k: int = 1
8383
prep_workers: int = 8
84+
keep_waveform: bool = False
8485
resources: Resources = field(default_factory=lambda: Resources(gpus=1.0))
8586
batch_size: int = 32
8687

@@ -172,7 +173,8 @@ def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]:
172173
task.data[self.pred_text_key] = pred
173174
if self.followup_prompt:
174175
task.data[self.disfluency_text_key] = disfl
175-
task.data.pop(self.waveform_key, None)
176+
if not self.keep_waveform:
177+
task.data.pop(self.waveform_key, None)
176178

177179
logger.info(f"QwenOmni: generated {len(pred_texts)} predictions (turn2={bool(self.followup_prompt)})")
178180
return tasks

nemo_curator/stages/audio/text_filtering/fasttext_lid.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def _process_single(self, task: AudioTask) -> AudioTask:
105105
text = text.strip().replace("\n", " ")
106106
if not text:
107107
if not task.data[self.skip_me_key]:
108-
task.data[self.skip_me_key] = "Empty text"
108+
task.data[self.skip_me_key] = f"Empty text:{self.name}"
109109
return task
110110
result_str = self._lid.score_document(text)
111111
score_list = eval(result_str) # noqa: S307 — output of our own FastText model

0 commit comments

Comments
 (0)