Skip to content

Commit bddf57d

Browse files
committed
PR_#1839 from nithinraok/add-ml-prompt, multilingual verbatim transcription prompt for Qwen3-Omni
Adds language-agnostic single-turn ASR pseudo-labeling prompt for non-English audio. Unlike the English two-turn flow (transcription + disfluency followup), this prompt combines transcription and verbatim fidelity into one instruction, making the followup turn unnecessary for ML languages. - examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_disfluency_asr.md (uses {language} placeholder) - nemo_curator/models/qwen_omni.py: _resolve_prompt() helper + thread language through _build_messages, _build_turn2_messages, _prepare_single, _prepare_batch, _prepare_turn2_single, _prepare_turn2_batch, generate() - nemo_curator/stages/audio/inference/qwen_omni.py: source_lang_key field pulls per-sample language from manifest and passes to model.generate() - examples/audio/qwen_omni_inprocess/run_pipeline.py: --source_lang_key CLI Surgical squash cherry-pick of #1839 (additive bits only). Skipped FastTextLIDStage source_lang_key (would conflict with PR #1's source-tracking refactor) and initialize_fields drop (already handled). #NO_PR Signed-off-by: George Zelenfroynd <gzelenfroind@nvidia.com>
1 parent 79441a7 commit bddf57d

4 files changed

Lines changed: 61 additions & 20 deletions

File tree

examples/audio/qwen_omni_inprocess/prompts/ml_qwen3_omni_disfluency_asr.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Transcribe the {language} audio into text exactly as the speaker says it. Write numbers as spoken words.

examples/audio/qwen_omni_inprocess/run_pipeline.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,9 @@ def _build_arg_parser() -> argparse.ArgumentParser:
113113
help="Path to regex substitution rules YAML.")
114114
tf.add_argument("--target_lang", type=str, default="en",
115115
help="Expected language code for LID filtering.")
116+
tf.add_argument("--source_lang_key", type=str, default="",
117+
help="Per-sample language key in manifest for {language} prompt substitution. "
118+
"Empty (default) disables per-sample language threading.")
116119
tf.add_argument("--min_lang_prob", type=float, default=0.8,
117120
help="Minimum FastText language probability to keep an entry.")
118121
tf.add_argument("--unique_words_threshold", type=float, default=0.4,
@@ -224,6 +227,7 @@ def main() -> None:
224227
max_num_seqs=args.max_num_seqs,
225228
gpu_memory_utilization=args.gpu_memory_utilization,
226229
prep_workers=args.prep_workers,
230+
source_lang_key=args.source_lang_key,
227231
pred_text_key="qwen3_prediction_s1",
228232
disfluency_text_key="qwen3_prediction_s2",
229233
keep_waveform=bool(args.asr_model_id),

nemo_curator/models/qwen_omni.py

Lines changed: 51 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -161,49 +161,66 @@ def _resample(waveform: np.ndarray, orig_sr: int, target_sr: int = _QWEN_SAMPLE_
161161

162162
return librosa.resample(waveform, orig_sr=orig_sr, target_sr=target_sr)
163163

164-
def _build_messages(self, waveform: np.ndarray) -> list[dict[str, Any]]:
165-
"""Build Turn 1 chat messages with an in-memory waveform (numpy array at 16 kHz)."""
164+
def _resolve_prompt(self, template: str, language: str | None) -> str:
165+
"""Replace ``{language}`` placeholder if *language* is provided."""
166+
if language and template and "{language}" in template:
167+
return template.replace("{language}", language)
168+
return template
169+
170+
def _build_messages(self, waveform: np.ndarray, language: str | None = None) -> list[dict[str, Any]]:
171+
"""Build Turn 1 chat messages with an in-memory waveform (numpy array at 16 kHz).
172+
173+
Prompts may contain a ``{language}`` placeholder which is replaced
174+
with *language* (e.g., ``"French"``) when provided.
175+
"""
176+
prompt = self._resolve_prompt(self.prompt_text, language)
166177
messages: list[dict[str, Any]] = []
167178
if self.system_prompt:
168-
messages.append({"role": "system", "content": [{"type": "text", "text": self.system_prompt}]})
179+
sys_prompt = self._resolve_prompt(self.system_prompt, language)
180+
messages.append({"role": "system", "content": [{"type": "text", "text": sys_prompt}]})
169181
messages.append({
170182
"role": "user",
171183
"content": [
172-
{"type": "text", "text": self.prompt_text},
184+
{"type": "text", "text": prompt},
173185
{"type": "audio", "audio": waveform},
174186
],
175187
})
176188
return messages
177189

178-
def _build_turn2_messages(self, waveform: np.ndarray, pred_text: str) -> list[dict[str, Any]]:
179-
"""Build Turn 2 messages: full Turn 1 conversation history + follow-up promt."""
190+
def _build_turn2_messages(
191+
self, waveform: np.ndarray, pred_text: str, language: str | None = None,
192+
) -> list[dict[str, Any]]:
193+
"""Build Turn 2 messages: full Turn 1 conversation history + follow-up prompt."""
194+
prompt = self._resolve_prompt(self.prompt_text, language)
195+
followup = self._resolve_prompt(self.followup_prompt, language)
180196
messages: list[dict[str, Any]] = []
181197
if self.system_prompt:
182-
messages.append({"role": "system", "content": [{"type": "text", "text": self.system_prompt}]})
198+
sys_prompt = self._resolve_prompt(self.system_prompt, language)
199+
messages.append({"role": "system", "content": [{"type": "text", "text": sys_prompt}]})
183200
messages.append({
184201
"role": "user",
185202
"content": [
186-
{"type": "text", "text": self.prompt_text},
203+
{"type": "text", "text": prompt},
187204
{"type": "audio", "audio": waveform},
188205
],
189206
})
190207
messages.append({"role": "assistant", "content": [{"type": "text", "text": pred_text}]})
191208
messages.append({
192209
"role": "user",
193210
"content": [
194-
{"type": "text", "text": self.followup_prompt},
211+
{"type": "text", "text": followup},
195212
],
196213
})
197214
return messages
198215

199216
def _prepare_single(
200-
self, waveform: np.ndarray, sample_rate: int,
217+
self, waveform: np.ndarray, sample_rate: int, language: str | None = None,
201218
) -> tuple[dict[str, Any], np.ndarray] | None:
202219
from qwen_omni_utils import process_mm_info
203220

204221
try:
205222
waveform_16k = self._resample(waveform, sample_rate)
206-
messages = self._build_messages(waveform_16k)
223+
messages = self._build_messages(waveform_16k, language)
207224
text = self._processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
208225
audios, images, videos = process_mm_info(messages, use_audio_in_video=False)
209226
except Exception: # noqa: BLE001
@@ -227,18 +244,23 @@ def _prepare_batch(
227244
self,
228245
waveforms: list[np.ndarray],
229246
sample_rates: list[int],
247+
languages: list[str | None] | None = None,
230248
) -> list[tuple[dict[str, Any], np.ndarray] | None]:
249+
langs = languages if languages is not None else [None] * len(waveforms)
231250
if self._prep_pool is None:
232-
return [self._prepare_single(w, sr) for w, sr in zip(waveforms, sample_rates, strict=False)]
233-
return list(self._prep_pool.map(self._prepare_single, waveforms, sample_rates))
251+
return [
252+
self._prepare_single(w, sr, lang)
253+
for w, sr, lang in zip(waveforms, sample_rates, langs, strict=False)
254+
]
255+
return list(self._prep_pool.map(self._prepare_single, waveforms, sample_rates, langs))
234256

235257
def _prepare_turn2_single(
236-
self, waveform_16k: np.ndarray, pred_text: str,
258+
self, waveform_16k: np.ndarray, pred_text: str, language: str | None = None,
237259
) -> dict[str, Any] | None:
238260
from qwen_omni_utils import process_mm_info
239261

240262
try:
241-
messages = self._build_turn2_messages(waveform_16k, pred_text)
263+
messages = self._build_turn2_messages(waveform_16k, pred_text, language)
242264
text = self._processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
243265
audios, images, videos = process_mm_info(messages, use_audio_in_video=False)
244266
except Exception: # noqa: BLE001
@@ -262,13 +284,15 @@ def _prepare_turn2_batch(
262284
self,
263285
waveforms_16k: list[np.ndarray],
264286
pred_texts: list[str],
287+
languages: list[str | None] | None = None,
265288
) -> list[dict[str, Any] | None]:
289+
langs = languages if languages is not None else [None] * len(waveforms_16k)
266290
if self._prep_pool is None:
267291
return [
268-
self._prepare_turn2_single(w, pt)
269-
for w, pt in zip(waveforms_16k, pred_texts, strict=False)
292+
self._prepare_turn2_single(w, pt, lang)
293+
for w, pt, lang in zip(waveforms_16k, pred_texts, langs, strict=False)
270294
]
271-
return list(self._prep_pool.map(self._prepare_turn2_single, waveforms_16k, pred_texts))
295+
return list(self._prep_pool.map(self._prepare_turn2_single, waveforms_16k, pred_texts, langs))
272296

273297
# ------------------------------------------------------------------
274298
# Generation
@@ -278,6 +302,7 @@ def generate(
278302
self,
279303
waveforms: list[np.ndarray],
280304
sample_rates: list[int],
305+
languages: list[str | None] | None = None,
281306
) -> tuple[list[str], list[str]]:
282307
"""Run batched two-turn inference on in-memory audio waveforms.
283308
@@ -288,6 +313,9 @@ def generate(
288313
Args:
289314
waveforms: List of 1-D mono numpy float32 arrays.
290315
sample_rates: Corresponding sample rates for each waveform.
316+
languages: Optional per-sample language strings for ``{language}``
317+
placeholder substitution in prompts. Length must match
318+
``waveforms``. Pass ``None`` (default) to skip substitution.
291319
292320
Returns:
293321
``(pred_texts, disfluency_texts)`` — one string per input for
@@ -301,7 +329,7 @@ def generate(
301329
n = len(waveforms)
302330

303331
# -- Turn 1 ----------------------------------------------------------
304-
prepared = self._prepare_batch(waveforms, sample_rates)
332+
prepared = self._prepare_batch(waveforms, sample_rates, languages)
305333
valid_indices = [i for i, p in enumerate(prepared) if p is not None]
306334
valid_inputs = [prepared[i][0] for i in valid_indices]
307335
waveforms_16k: dict[int, np.ndarray] = {i: prepared[i][1] for i in valid_indices}
@@ -327,9 +355,13 @@ def generate(
327355
if not t2_indices:
328356
return pred_texts, [""] * n
329357

358+
t2_languages = (
359+
[languages[i] for i in t2_indices] if languages is not None else None
360+
)
330361
t2_prepared = self._prepare_turn2_batch(
331362
[waveforms_16k[i] for i in t2_indices],
332363
[pred_texts[i] for i in t2_indices],
364+
t2_languages,
333365
)
334366

335367
t2_valid = [(i, p) for i, p in zip(t2_indices, t2_prepared, strict=False) if p is not None]

nemo_curator/stages/audio/inference/qwen_omni.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ class InferenceQwenOmniStage(ProcessingStage[AudioTask, AudioTask]):
7171
system_prompt: str | None = None
7272
waveform_key: str = "waveform"
7373
sample_rate_key: str = "sample_rate"
74+
source_lang_key: str = ""
7475
pred_text_key: str = "qwen3_prediction_s1"
7576
disfluency_text_key: str = "qwen3_prediction_s2"
7677
max_model_len: int = 32768
@@ -166,8 +167,11 @@ def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]:
166167

167168
waveforms = [t.data[self.waveform_key] for t in tasks]
168169
sample_rates = [t.data[self.sample_rate_key] for t in tasks]
170+
languages: list[str | None] | None = None
171+
if self.source_lang_key:
172+
languages = [t.data.get(self.source_lang_key) or None for t in tasks]
169173

170-
pred_texts, disfluency_texts = self._model.generate(waveforms, sample_rates)
174+
pred_texts, disfluency_texts = self._model.generate(waveforms, sample_rates, languages)
171175

172176
for task, pred, disfl in zip(tasks, pred_texts, disfluency_texts, strict=True):
173177
task.data[self.pred_text_key] = pred

0 commit comments

Comments
 (0)