Skip to content

Commit b60dc95

Browse files
committed
PR_#integration-test from nithinraok/Curator nkoluguri/integration-test, validated improvements on top of the 4 PRs
Squash cherry-pick of integration-test's unique commits on top of #1853 + #1 + #3 + #1839: - 633acc7 FastText and Hallucination update → SelectBestPredictionStage: cross-model WER agreement. If both omni and ASR are flagged hallucinated but agree (WER ≤ 100 - min_agreement_pct, default 80%), keep omni and mark recovered — two independent models producing near-identical text is strong evidence the text is correct. → FastTextLIDStage: HuggingFace-format model loader, proper _predict() abstraction, source-tracked _skip_me ("Wrong language:{name}"). - 5fdfa0a additional notes key + skip writing keys after skip_me + pnc prompt + prefill caching → Models (qwen_omni, qwen_asr, qwen_text_llm): notes_key field for diagnostic info, vLLM enable_prefix_caching=True with xxhash. → text_filtering stages: skip writing output keys when skip_me is set. → New file: prompts/pnc_prompt.md. - 15424e3 updated prompt for ITN → Sharper ITN prompt (handles more conversion edge cases). - 0cf8e6c match max model len for ITN and PnC → Aligned ITN/PnC max_model_len (4096), max_num_seqs (16), gpu_memory_utilization (0.95). Wired ITN args through run_pipeline. - 7e32df1 add Qwen3ASR for all → Apply QwenASR recovery to all hallucination flags, not just specific patterns. WhisperHallucinationStage tweaks. - caccd37 Add min word count for FastText → Re-adds min_word_count=2 (FastText is unreliable on single-word inputs). Conflict resolution: - run_pipeline.py: kept multi-line argparse style (ours), kept --source_lang_key, adopted their ITN stage construction (with new max_model_len/num_seqs/gpu_mem args). - fasttext_lid.py: took their richer process logic (min_word_count check, per-sample expected language via source_lang_key, source-tracked _skip_me values). #NO_PR Signed-off-by: George Zelenfroynd <gzelenfroind@nvidia.com>
1 parent bddf57d commit b60dc95

15 files changed

Lines changed: 217 additions & 65 deletions

examples/audio/qwen_omni_inprocess/run_pipeline.py

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@
8282
from nemo_curator.stages.resources import Resources
8383

8484

85-
def _build_arg_parser() -> argparse.ArgumentParser:
85+
def _build_arg_parser() -> argparse.ArgumentParser: # noqa: PLR0915
8686
ap = argparse.ArgumentParser(description="QwenOmni in-process vLLM pipeline")
8787
ap.add_argument("--data_config", type=str, required=True, help="Granary YAML data config.")
8888
ap.add_argument("--corpus", type=str, nargs="*", default=None, help="Process only these corpora.")
@@ -108,7 +108,7 @@ def _build_arg_parser() -> argparse.ArgumentParser:
108108
tf.add_argument("--hall_phrases", type=str, required=True,
109109
help="Path to hallucination phrases text file.")
110110
tf.add_argument("--fasttext_model", type=str, default="lid.176.ftz",
111-
help="FastText LID model: local path or known name (lid.176.bin / lid.176.ftz).")
111+
help="FastText LID model: HuggingFace repo ID, local path, or known name (lid.176.bin / lid.176.ftz).")
112112
tf.add_argument("--regex_yaml", type=str, required=True,
113113
help="Path to regex substitution rules YAML.")
114114
tf.add_argument("--target_lang", type=str, default="en",
@@ -170,11 +170,17 @@ def _build_arg_parser() -> argparse.ArgumentParser:
170170
help="TP size for ITN model (None = auto-detect).")
171171
itn.add_argument("--itn_max_output_tokens", type=int, default=4096,
172172
help="Max tokens to generate per ITN sample.")
173+
itn.add_argument("--itn_max_model_len", type=int, default=4096,
174+
help="Max context length for ITN vLLM engine.")
175+
itn.add_argument("--itn_max_num_seqs", type=int, default=16,
176+
help="Max concurrent sequences for ITN vLLM engine.")
177+
itn.add_argument("--itn_gpu_memory_utilization", type=float, default=0.95,
178+
help="Fraction of GPU memory for ITN vLLM engine.")
173179
itn.add_argument("--itn_no_validation", action="store_true", help="Disable ITN output validation.")
174180
return ap
175181

176182

177-
def main() -> None:
183+
def main() -> None: # noqa: C901
178184
args = _build_arg_parser().parse_args()
179185

180186
prompt = args.prompt
@@ -259,7 +265,6 @@ def main() -> None:
259265
batch_size=args.asr_batch_size,
260266
gpu_memory_utilization=args.asr_gpu_memory_utilization,
261267
max_new_tokens=args.asr_max_new_tokens,
262-
run_only_if_key="_skip_me",
263268
),
264269
WhisperHallucinationStage(
265270
name="WhisperHallucination_asr",
@@ -320,18 +325,19 @@ def main() -> None:
320325
])
321326

322327
if args.enable_itn:
323-
stages.append(
324-
ITNRestorationStage(
325-
model_id=args.itn_model_id,
326-
prompt_text=itn_prompt_text,
327-
text_key=args.itn_text_key or ("pnc_text" if not args.skip_pnc else "abbreviated_text"),
328-
output_text_key=args.itn_output_key,
329-
tensor_parallel_size=args.itn_tensor_parallel_size,
330-
max_output_tokens=args.itn_max_output_tokens,
331-
batch_size=args.itn_batch_size,
332-
enable_validation=not args.itn_no_validation,
333-
)
334-
)
328+
stages.append(ITNRestorationStage(
329+
model_id=args.itn_model_id,
330+
prompt_text=itn_prompt_text,
331+
text_key=args.itn_text_key or ("pnc_text" if not args.skip_pnc else "abbreviated_text"),
332+
output_text_key=args.itn_output_key,
333+
tensor_parallel_size=args.itn_tensor_parallel_size,
334+
max_output_tokens=args.itn_max_output_tokens,
335+
max_model_len=args.itn_max_model_len,
336+
max_num_seqs=args.itn_max_num_seqs,
337+
gpu_memory_utilization=args.itn_gpu_memory_utilization,
338+
batch_size=args.itn_batch_size,
339+
enable_validation=not args.itn_no_validation,
340+
))
335341

336342
stages.append(ShardedManifestWriterStage(output_dir=args.output_dir))
337343

nemo_curator/models/qwen_asr.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,12 @@ def _patch_transformers_compat() -> None:
8383
sig = inspect.signature(original)
8484
params = list(sig.parameters.values())
8585
if params and params[0].name == "func":
86-
def compat_check_model_inputs(*args, **kwargs):
86+
def compat_check_model_inputs(*args): # noqa: ANN202
8787
if args and callable(args[0]):
8888
return original(args[0])
8989
return original
9090
transformers.check_model_inputs = compat_check_model_inputs
91-
except Exception: # noqa: BLE001
91+
except Exception: # noqa: BLE001, S110
9292
pass
9393

9494
def setup(self) -> None:
@@ -114,6 +114,8 @@ def setup(self) -> None:
114114
max_new_tokens=self.max_new_tokens,
115115
trust_remote_code=True,
116116
enforce_eager=True,
117+
enable_prefix_caching=True,
118+
prefix_caching_hash_algo="xxhash",
117119
)
118120

119121
logger.info("QwenASR model loaded")

nemo_curator/models/qwen_omni.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,8 @@ def setup(self) -> None:
116116
max_num_seqs=self.max_num_seqs,
117117
max_model_len=self.max_model_len,
118118
seed=1234,
119+
enable_prefix_caching=True,
120+
prefix_caching_hash_algo="xxhash",
119121
)
120122

121123
from transformers import Qwen3OmniMoeProcessor

nemo_curator/models/qwen_text_llm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ def setup(self) -> None:
122122
max_num_seqs=self.max_num_seqs,
123123
max_model_len=self.max_model_len,
124124
seed=1234,
125+
enable_prefix_caching=True,
126+
prefix_caching_hash_algo="xxhash",
125127
)
126128

127129
self._sampling_params = SamplingParams(

nemo_curator/stages/audio/text_filtering/abbreviation_concat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def outputs(self) -> tuple[list[str], list[str]]:
237237
def _process_single(self, task: AudioTask) -> AudioTask:
238238
skip = task.data.get(self.skip_me_key, "")
239239
if skip:
240-
task.data.setdefault(self.output_text_key, task.data.get(self.text_key, ""))
240+
task.data.setdefault(self.output_text_key, "")
241241
task.data.setdefault(self.abbreviations_key, [])
242242
return task
243243

nemo_curator/stages/audio/text_filtering/fasttext_lid.py

Lines changed: 71 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,26 @@
2929
}
3030
_DEFAULT_CACHE_DIR = os.path.expanduser("~/.cache/nemo_curator/fasttext")
3131

32+
_ISO639_3_TO_1: dict[str, str] = {
33+
"afr": "af", "amh": "am", "ara": "ar", "asm": "as", "aze": "az",
34+
"bel": "be", "ben": "bn", "bos": "bs", "bul": "bg", "cat": "ca",
35+
"ces": "cs", "cym": "cy", "dan": "da", "deu": "de", "ell": "el",
36+
"eng": "en", "est": "et", "eus": "eu", "fas": "fa", "fin": "fi",
37+
"fra": "fr", "gle": "ga", "glg": "gl", "guj": "gu", "hau": "ha",
38+
"heb": "he", "hin": "hi", "hrv": "hr", "hun": "hu", "hye": "hy",
39+
"ibo": "ig", "ind": "id", "isl": "is", "ita": "it", "jav": "jv",
40+
"jpn": "ja", "kan": "kn", "kat": "ka", "khm": "km", "kor": "ko",
41+
"lao": "lo", "lav": "lv", "lit": "lt", "mal": "ml", "mar": "mr",
42+
"mkd": "mk", "mon": "mn", "msa": "ms", "mya": "my", "nep": "ne",
43+
"nld": "nl", "nob": "nb", "nor": "no", "ori": "or", "pan": "pa",
44+
"pol": "pl", "por": "pt", "ron": "ro", "rus": "ru", "sin": "si",
45+
"slk": "sk", "slv": "sl", "som": "so", "spa": "es", "sqi": "sq",
46+
"srp": "sr", "sun": "su", "swa": "sw", "swe": "sv", "tam": "ta",
47+
"tel": "te", "tgl": "tl", "tha": "th", "tur": "tr", "ukr": "uk",
48+
"urd": "ur", "vie": "vi", "xho": "xh", "yor": "yo", "zho": "zh",
49+
"zul": "zu",
50+
}
51+
3252

3353
@dataclass
3454
class FastTextLIDStage(ProcessingStage[AudioTask, AudioTask]):
@@ -43,21 +63,29 @@ class FastTextLIDStage(ProcessingStage[AudioTask, AudioTask]):
4363
4464
An already non-empty ``skip_me`` value is never overwritten.
4565
66+
Texts with fewer than ``min_word_count`` words are passed through
67+
without LID filtering because FastText confidence is unreliable on
68+
very short inputs (especially single words).
69+
4670
``model_path`` can be:
71+
- A HuggingFace Hub repo ID (e.g.
72+
``facebook/fasttext-language-identification``), which is downloaded
73+
via ``huggingface_hub``.
4774
- An absolute path to a local ``.bin`` or ``.ftz`` file.
48-
- A known model name (``lid.176.bin`` or ``lid.176.ftz``), which is
75+
- A legacy model name (``lid.176.bin`` or ``lid.176.ftz``), which is
4976
downloaded to ``~/.cache/nemo_curator/fasttext/`` on first use.
5077
"""
5178

5279
model_path: str = ""
5380
target_lang: str = "en"
5481
min_lang_prob: float = 0.8
82+
min_word_count: int = 2
5583
text_key: str = "pred_text"
5684
skip_me_key: str = "_skip_me"
5785
name: str = "FastTextLID"
5886
resources: Resources = field(default_factory=lambda: Resources(cpus=1.0))
5987

60-
_lid: Any = field(default=None, init=False, repr=False)
88+
_model: Any = field(default=None, init=False, repr=False)
6189

6290
def __post_init__(self) -> None:
6391
if not self.model_path:
@@ -76,18 +104,39 @@ def _resolve_model_path(self) -> str:
76104
logger.info(f"FastTextLIDStage: downloading {self.model_path} from {url}")
77105
urllib.request.urlretrieve(url, cache_path) # noqa: S310
78106
return cache_path
107+
if "/" in self.model_path:
108+
try:
109+
from huggingface_hub import hf_hub_download
110+
111+
return hf_hub_download(repo_id=self.model_path, filename="model.bin")
112+
except Exception as exc:
113+
msg = f"Failed to download '{self.model_path}' from HuggingFace Hub: {exc}"
114+
raise ValueError(msg) from exc
79115
msg = (
80-
f"model_path '{self.model_path}' is not a valid file path and not a known model name. "
81-
f"Known names: {list(_FASTTEXT_MODEL_URLS)}"
116+
f"model_path '{self.model_path}' is not a valid file path, a known model name, "
117+
f"or a HuggingFace repo ID. Known names: {list(_FASTTEXT_MODEL_URLS)}"
82118
)
83119
raise ValueError(msg)
84120

121+
@staticmethod
122+
def _parse_label(raw_label: str) -> str:
123+
"""Extract a 2-letter ISO 639-1 language code from a fasttext label.
124+
125+
Handles both the legacy format (``__label__en``) and the HuggingFace
126+
``facebook/fasttext-language-identification`` format
127+
(``__label__eng_Latn``).
128+
"""
129+
lang_part = raw_label.replace("__label__", "")
130+
if "_" in lang_part:
131+
iso3 = lang_part.split("_", 1)[0]
132+
return _ISO639_3_TO_1.get(iso3, iso3).lower()
133+
return lang_part.lower()
134+
85135
def setup(self, _worker_metadata: object | None = None) -> None:
86-
from nemo_curator.stages.text.filters.fasttext.fasttext_filters import FastTextLangId
136+
import fasttext
87137

88138
resolved = self._resolve_model_path()
89-
self._lid = FastTextLangId(model_path=resolved, min_langid_score=0.0)
90-
self._lid.load_model()
139+
self._model = fasttext.load_model(resolved)
91140
logger.info(f"FastTextLIDStage: loaded model from {resolved}")
92141

93142
def inputs(self) -> tuple[list[str], list[str]]:
@@ -96,6 +145,10 @@ def inputs(self) -> tuple[list[str], list[str]]:
96145
def outputs(self) -> tuple[list[str], list[str]]:
97146
return [], [self.skip_me_key]
98147

148+
def _predict(self, text: str) -> tuple[str, float]:
149+
labels, scores = self._model.predict([text], k=1)
150+
return self._parse_label(labels[0][0]), scores[0][0].item()
151+
99152
def _process_single(self, task: AudioTask) -> AudioTask:
100153
if task.data.get(self.skip_me_key, ""):
101154
return task
@@ -107,19 +160,21 @@ def _process_single(self, task: AudioTask) -> AudioTask:
107160
if not task.data[self.skip_me_key]:
108161
task.data[self.skip_me_key] = f"Empty text:{self.name}"
109162
return task
110-
result_str = self._lid.score_document(text)
111-
score_list = eval(result_str) # noqa: S307 — output of our own FastText model
112-
prob = float(score_list[0])
113-
lang = str(score_list[1]).lower()
163+
if len(text.split()) < self.min_word_count:
164+
return task
165+
lang, prob = self._predict(text)
166+
expected = self.target_lang
167+
if self.source_lang_key and self.source_lang_key in task.data:
168+
expected = task.data[self.source_lang_key]
114169
if not task.data[self.skip_me_key]:
115-
if lang != self.target_lang.lower():
116-
task.data[self.skip_me_key] = "Wrong language"
170+
if lang != expected.lower():
171+
task.data[self.skip_me_key] = f"Wrong language:{self.name}"
117172
elif prob < self.min_lang_prob:
118-
task.data[self.skip_me_key] = "Low probability of language"
173+
task.data[self.skip_me_key] = f"Low probability of language:{self.name}"
119174
return task
120175

121176
def process(self, task: AudioTask) -> AudioTask:
122-
if self._lid is None:
177+
if self._model is None:
123178
logger.warning(
124179
f"FastTextLIDStage ({self.name}): setup() was not called before process(). "
125180
"Calling setup() now — check that your executor invokes setup() on each worker."
@@ -128,7 +183,7 @@ def process(self, task: AudioTask) -> AudioTask:
128183
return self._process_single(task)
129184

130185
def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]:
131-
if self._lid is None:
186+
if self._model is None:
132187
logger.warning(
133188
f"FastTextLIDStage ({self.name}): setup() was not called before process_batch(). "
134189
"Calling setup() now — check that your executor invokes setup() on each worker."

nemo_curator/stages/audio/text_filtering/initialize_fields.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from dataclasses import dataclass, field # noqa: I001
15+
from dataclasses import dataclass, field
1616

1717
from nemo_curator.stages.base import ProcessingStage
1818
from nemo_curator.stages.resources import Resources

nemo_curator/stages/audio/text_filtering/itn_restoration.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -196,9 +196,9 @@ class ITNRestorationStage(ProcessingStage[AudioTask, AudioTask]):
196196
itn_filtered_key: str = "itn_filtered"
197197
enable_validation: bool = True
198198
tensor_parallel_size: int | None = None
199-
max_output_tokens: int = 4096
200-
max_model_len: int = 32768
201-
max_num_seqs: int = 256
199+
max_output_tokens: int = 512
200+
max_model_len: int = 4096
201+
max_num_seqs: int = 16
202202
gpu_memory_utilization: float = 0.95
203203
kv_cache_dtype: str = "fp8"
204204
resources: Resources = field(default_factory=lambda: Resources(gpus=1.0))
@@ -354,7 +354,10 @@ def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]:
354354
for i, task in enumerate(tasks):
355355
text = task.data.get(self.text_key, "")
356356
skip = task.data.get(self.skip_me_key, "")
357-
if not text or not text.strip() or skip:
357+
if skip:
358+
task.data[self.output_text_key] = ""
359+
continue
360+
if not text or not text.strip():
358361
task.data[self.output_text_key] = text
359362
continue
360363
valid_indices.append(i)

nemo_curator/stages/audio/text_filtering/pnc_content_guard.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ def outputs(self) -> tuple[list[str], list[str]]:
6161

6262
def _process_single(self, task: AudioTask) -> AudioTask:
6363
if task.data.get(self.skip_me_key, ""):
64+
task.data.setdefault(self.pnc_text_key, "")
6465
task.data.setdefault(self.rejected_text_key, "")
6566
return task
6667
original = task.data.get(self.text_key, "")

nemo_curator/stages/audio/text_filtering/pnc_restoration.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from __future__ import annotations
1616

1717
from dataclasses import dataclass, field
18+
from pathlib import Path
1819
from typing import TYPE_CHECKING
1920

2021
from loguru import logger
@@ -27,6 +28,8 @@
2728
from nemo_curator.stages.resources import Resources
2829
from nemo_curator.tasks import AudioTask
2930

31+
_DEFAULT_PNC_PROMPT_PATH = Path(__file__).resolve().parent / "prompts" / "pnc_prompt.md"
32+
3033

3134
@dataclass
3235
class PnCRestorationStage(ProcessingStage[AudioTask, AudioTask]):
@@ -82,13 +85,11 @@ class PnCRestorationStage(ProcessingStage[AudioTask, AudioTask]):
8285
'Answer only "yes" or "no".\n\n'
8386
"Text: {text}"
8487
)
85-
pnc_prompt: str = (
86-
"Restore proper punctuation and capitalization to the following text. "
87-
"Output only the corrected text, nothing else.\n\nText: {text}"
88-
)
88+
pnc_prompt: str | None = None
89+
pnc_prompt_file: str | None = None
8990
system_prompt: str | None = None
90-
max_model_len: int = 8192
91-
max_num_seqs: int = 64
91+
max_model_len: int = 4096
92+
max_num_seqs: int = 16
9293
gpu_memory_utilization: float = 0.95
9394
tensor_parallel_size: int | None = None
9495
max_output_tokens: int = 512
@@ -98,17 +99,25 @@ class PnCRestorationStage(ProcessingStage[AudioTask, AudioTask]):
9899
batch_size: int = 64
99100
resources: Resources = field(default_factory=lambda: Resources(gpus=1.0))
100101

102+
def _resolve_pnc_prompt(self) -> str:
103+
if self.pnc_prompt:
104+
return self.pnc_prompt
105+
path = Path(self.pnc_prompt_file) if self.pnc_prompt_file else _DEFAULT_PNC_PROMPT_PATH
106+
logger.info("PnCRestoration: loading prompt from {}", path)
107+
return path.read_text(encoding="utf-8").strip()
108+
101109
def __post_init__(self) -> None:
102110
self._model: QwenTextLLM | None = None
103111
tp = self.tensor_parallel_size
104112
if tp and tp > 0:
105113
self.resources = Resources(gpus=float(tp))
106114

107115
def _create_model(self) -> QwenTextLLM:
116+
pnc_prompt = self._resolve_pnc_prompt()
108117
return QwenTextLLM(
109118
model_id=self.model_id,
110119
completeness_prompt=self.completeness_prompt,
111-
pnc_prompt=self.pnc_prompt,
120+
pnc_prompt=pnc_prompt,
112121
system_prompt=self.system_prompt,
113122
max_model_len=self.max_model_len,
114123
max_num_seqs=self.max_num_seqs,
@@ -175,7 +184,9 @@ def process_batch(self, tasks: list[AudioTask]) -> list[AudioTask]:
175184
for i, task in enumerate(tasks):
176185
skip = task.data.get(self.skip_me_key, "")
177186
text = task.data.get(self.text_key, "")
178-
if skip or not text.strip():
187+
if skip:
188+
task.data[self.output_text_key] = ""
189+
elif not text.strip():
179190
task.data[self.output_text_key] = text
180191
else:
181192
eligible_indices.append(i)

0 commit comments

Comments
 (0)