From 81e54cb6b2254d76045241a7b4e345a4d2f7675a Mon Sep 17 00:00:00 2001 From: mmkrtchyan Date: Wed, 8 Apr 2026 17:51:02 +0400 Subject: [PATCH 1/4] Fix Sortformer tutorial issues and add InferenceSortformerStage benchmark Signed-off-by: mmkrtchyan --- .../scripts/audio_sortformer_benchmark.py | 151 ++++++++++++++++++ .../stages/audio/inference/sortformer.py | 27 ++-- tutorials/audio/callhome_diar/README.md | 5 +- tutorials/audio/callhome_diar/run.py | 8 +- 4 files changed, 176 insertions(+), 15 deletions(-) create mode 100644 benchmarking/scripts/audio_sortformer_benchmark.py diff --git a/benchmarking/scripts/audio_sortformer_benchmark.py b/benchmarking/scripts/audio_sortformer_benchmark.py new file mode 100644 index 0000000000..1c9f38b7bc --- /dev/null +++ b/benchmarking/scripts/audio_sortformer_benchmark.py @@ -0,0 +1,151 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Audio Sortformer diarization benchmarking script. + +This script runs Streaming Sortformer diarization benchmarks with +comprehensive metrics collection including real-time factor (RTF), +per-file segment counts, and throughput. 
+""" + +import argparse +import time +import traceback +from pathlib import Path +from typing import Any + +from loguru import logger +from utils import setup_executor, write_benchmark_results + +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.audio.alm.alm_manifest_reader import ALMManifestReader +from nemo_curator.stages.audio.inference.sortformer import InferenceSortformerStage +from nemo_curator.stages.resources import Resources + + +def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any]: + """Extract diarization-specific metrics from output tasks.""" + num_files = len(tasks) if tasks else 0 + total_audio_duration_s = 0.0 + total_segments = 0 + + for task in tasks or []: + data = task.data if hasattr(task, "data") else {} + total_audio_duration_s += float(data.get("duration", 0)) + segments = data.get("diar_segments", []) + total_segments += len(segments) + + throughput = num_files / elapsed_s if elapsed_s > 0 else 0.0 + rtf = elapsed_s / total_audio_duration_s if total_audio_duration_s > 0 else 0.0 + + return { + "is_success": True, + "num_files_processed": num_files, + "exec_time_s": round(elapsed_s, 2), + "total_audio_duration_s": round(total_audio_duration_s, 2), + "total_segments_detected": total_segments, + "real_time_factor": round(rtf, 4), + "throughput_files_per_sec": round(throughput, 4), + } + + +def run_audio_sortformer_benchmark( # noqa: PLR0913 + benchmark_results_path: str, + manifest_path: str, + model_name: str, + gpus: int, + rttm_out_dir: str | None = None, + executor: str = "xenna", + **kwargs, # noqa: ARG001 +) -> dict[str, Any]: + """Run the audio Sortformer diarization benchmark and collect metrics.""" + benchmark_results_path = Path(benchmark_results_path) + + logger.info("Starting audio Sortformer diarization benchmark") + logger.info(f"Executor: {executor}") + logger.info(f"Model: {model_name}") + logger.info(f"Manifest: {manifest_path}") + logger.info(f"GPUs: {gpus}") + + executor_obj = 
setup_executor(executor) + pipeline = Pipeline( + name="audio_sortformer_diarization", + description="Streaming Sortformer speaker diarization inference.", + ) + + pipeline.add_stage(ALMManifestReader(manifest_path=manifest_path)) + pipeline.add_stage( + InferenceSortformerStage( + model_name=model_name, + rttm_out_dir=rttm_out_dir, + ).with_(resources=Resources(gpus=gpus)), + ) + + t0 = time.perf_counter() + results = pipeline.run(executor_obj) + elapsed_s = time.perf_counter() - t0 + + metrics = _collect_diarization_metrics(results, elapsed_s) + + logger.success( + f"Benchmark completed: {metrics['num_files_processed']} files in {elapsed_s:.1f}s " + f"(RTF={metrics['real_time_factor']:.3f}, {metrics['throughput_files_per_sec']:.2f} files/sec)" + ) + + return { + "metrics": metrics, + "tasks": results, + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="Audio Sortformer diarization benchmark for nightly benchmarking") + parser.add_argument("--benchmark-results-path", required=True, help="Path to benchmark results") + parser.add_argument("--manifest-path", required=True, help="Path to input JSONL manifest") + parser.add_argument( + "--model-name", + default="nvidia/diar_streaming_sortformer_4spk-v2", + help="HF Sortformer model id", + ) + parser.add_argument("--executor", default="xenna", choices=["xenna", "ray_data"], help="Executor to use") + parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use") + parser.add_argument("--rttm-out-dir", default=None, help="Optional directory to write RTTM output files") + + args = parser.parse_args() + + logger.info("=== Audio Sortformer Diarization Benchmark Starting ===") + logger.info(f"Arguments: {vars(args)}") + + success_code = 1 + result_dict: dict[str, Any] = { + "params": vars(args), + "metrics": { + "is_success": False, + }, + "tasks": [], + } + try: + result_dict.update(run_audio_sortformer_benchmark(**vars(args))) + success_code = 0 if 
result_dict["metrics"]["is_success"] else 1 + except Exception as e: + error_traceback = traceback.format_exc() + logger.error(f"Benchmark failed: {e}") + logger.debug(f"Full traceback:\n{error_traceback}") + finally: + write_benchmark_results(result_dict, args.benchmark_results_path) + return success_code + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/nemo_curator/stages/audio/inference/sortformer.py b/nemo_curator/stages/audio/inference/sortformer.py index 7f33c5f31a..f68a630dec 100644 --- a/nemo_curator/stages/audio/inference/sortformer.py +++ b/nemo_curator/stages/audio/inference/sortformer.py @@ -130,14 +130,20 @@ def setup_on_node( if self.model_path is not None: return try: - repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) - nemo_files = [f for f in os.listdir(repo_dir) if f.endswith(".nemo")] - if nemo_files: - self.model_path = os.path.join(repo_dir, nemo_files[0]) - else: - logger.warning(f"No .nemo file found in {repo_dir}; setup() will fail") + snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) except Exception: # noqa: BLE001 - logger.info(f"Could not pre-cache {self.model_name}; actors will download on first use") + logger.info(f"Could not pre-cache {self.model_name}; workers will download on first use") + + def _resolve_model_path(self) -> str: + """Resolve the path to the .nemo checkpoint, downloading if needed.""" + if self.model_path is not None: + return self.model_path + repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) + nemo_files = [f for f in os.listdir(repo_dir) if f.endswith(".nemo")] + if not nemo_files: + msg = f"No .nemo file found in {repo_dir} for model {self.model_name}" + raise FileNotFoundError(msg) + return os.path.join(repo_dir, nemo_files[0]) def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: """Load Sortformer model from Hugging Face or a local .nemo file.""" @@ -146,8 +152,9 @@ def setup(self, 
_worker_metadata: WorkerMetadata | None = None) -> None: self._configure_streaming() return + resolved_path = self._resolve_model_path() self.diar_model = SortformerEncLabelModel.restore_from( - restore_path=self.model_path, + restore_path=resolved_path, map_location="cuda", strict=False, ) @@ -189,9 +196,7 @@ def process(self, task: AudioTask) -> AudioTask: file_path = task.data[self.filepath_key] sess_name = task.data.get("session_name") - resolved_sess_name = ( - sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0] - ) + resolved_sess_name = sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0] all_segments = self.diarize([file_path]) segments = all_segments[0] diff --git a/tutorials/audio/callhome_diar/README.md b/tutorials/audio/callhome_diar/README.md index 7194f77601..1c68c0f8ea 100644 --- a/tutorials/audio/callhome_diar/README.md +++ b/tutorials/audio/callhome_diar/README.md @@ -8,7 +8,7 @@ Inference runs in parallel via `Pipeline` + `XennaExecutor` for high throughput. - Python 3.10+ - NeMo Curator installed (see [installation guide](https://docs.nvidia.com/nemo/curator/latest/admin/installation.html)) -- [`sox`](https://sox.sourceforge.net/) command-line tool (for stereo-to-mono conversion; install via `apt install sox`, `brew install sox`, or `conda install -c conda-forge sox`) +- [`ffmpeg`](https://ffmpeg.org/) command-line tool (for stereo-to-mono conversion; pre-installed in the NeMo Curator container) - CallHome English dataset with `.wav` files and `eng/*.cha` ground-truth annotations ### Dataset layout @@ -67,7 +67,7 @@ All values are in **80 ms frames**. Override via `--chunk-len`, `--chunk-right-c ## What the script does 1. **File discovery (`CallHomeReaderStage`)** — Scans the dataset directory for WAV files with matching `.cha` annotations, skipping already-processed files. Emits one `AudioTask` per file. -2. 
**Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `sox` so the model sees both speakers. +2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `ffmpeg` so the model sees both speakers. 3. **Diarization inference (`InferenceSortformerStage`)** — Runs Streaming Sortformer on each mono file. Also writes RTTM files to `/rttm/`. 4. **DER evaluation (`DERComputationStage`)** — Compares predicted segments against CHA ground truth. Scoring is restricted to the UEM region (min/max annotated timestamps from CHA) with a configurable collar tolerance (default 0.25 s). @@ -116,3 +116,4 @@ results = pipeline.run(executor=XennaExecutor()) - Maximum 4 speakers per recording - Trained primarily on English speech - Performance may degrade on noisy or very long recordings +- Audio must be mono 16 kHz; running on raw stereo or narrow-band (8 kHz) files without proper conversion will produce very high false-alarm rates diff --git a/tutorials/audio/callhome_diar/run.py b/tutorials/audio/callhome_diar/run.py index efb85e4569..297e7cbc93 100644 --- a/tutorials/audio/callhome_diar/run.py +++ b/tutorials/audio/callhome_diar/run.py @@ -34,6 +34,7 @@ from collections import Counter from dataclasses import dataclass from pathlib import Path +from typing import Any from loguru import logger @@ -144,6 +145,9 @@ def inputs(self) -> tuple[list[str], list[str]]: def outputs(self) -> tuple[list[str], list[str]]: return ["data"], [self.filepath_key] + def xenna_stage_spec(self) -> dict[str, Any]: + return {"num_workers_per_node": 1} + def process(self, task: _EmptyTask) -> list[AudioTask]: # noqa: ARG002 cha_path = Path(self.cha_dir) done = {p.stem for p in Path(self.rttm_out_dir).glob("*.rttm")} if self.rttm_out_dir else set() @@ -164,7 +168,7 @@ def process(self, task: _EmptyTask) -> list[AudioTask]: # noqa: ARG002 @dataclass class 
EnsureMonoStage(ProcessingStage[AudioTask, AudioTask]): - """Downmix stereo WAVs to mono 16 kHz via sox.""" + """Downmix stereo WAVs to mono 16 kHz via ffmpeg.""" mono_dir: str = "mono" filepath_key: str = "audio_filepath" @@ -182,7 +186,7 @@ def _ensure_mono(self, wav_path: str) -> str: return mono_path os.makedirs(self.mono_dir, exist_ok=True) subprocess.run( # noqa: S603 - ["sox", wav_path, "-c", "1", "-r", "16000", mono_path], # noqa: S607 + ["ffmpeg", "-i", wav_path, "-ac", "1", "-ar", "16000", "-y", mono_path], # noqa: S607 check=True, capture_output=True, ) From fd875760df764150ae54bd5dfa81727ea47edef4 Mon Sep 17 00:00:00 2001 From: mmkrtchyan Date: Tue, 21 Apr 2026 20:08:27 +0400 Subject: [PATCH 2/4] Address PR review feedback and update model to v2.1 Signed-off-by: mmkrtchyan --- benchmarking/nightly-benchmark.yaml | 29 ++++++++++++ .../scripts/audio_sortformer_benchmark.py | 11 ++++- .../stages/audio/inference/sortformer.py | 45 ++++++++++++++++--- tutorials/audio/callhome_diar/README.md | 6 +-- tutorials/audio/callhome_diar/run.py | 2 +- 5 files changed, 80 insertions(+), 13 deletions(-) diff --git a/benchmarking/nightly-benchmark.yaml b/benchmarking/nightly-benchmark.yaml index ab0b66f3e0..607719036e 100644 --- a/benchmarking/nightly-benchmark.yaml +++ b/benchmarking/nightly-benchmark.yaml @@ -631,6 +631,35 @@ entries: - metric: throughput_images_per_sec min_value: 3.0 + - name: audio_sortformer_xenna + enabled: false + script: audio_sortformer_benchmark.py + args: >- + --benchmark-results-path={session_entry_dir} + --manifest-path={datasets_path}/sortformer_diarization/manifest.jsonl + --model-name=nvidia/diar_streaming_sortformer_4spk-v2.1 + --rttm-out-dir={session_entry_dir}/scratch/rttm + --gpus=1 + timeout_s: 1800 + sink_data: + - name: slack + additional_metrics: + - num_files_processed + - throughput_files_per_sec + - real_time_factor + - total_segments_detected + ping_on_failure: + - U03C41SNADV # Aaftab V + ray: + num_cpus: 64 + num_gpus: 4 + 
enable_object_spilling: false + requirements: + - metric: is_success + exact_value: true + - metric: num_files_processed + min_value: 1 + - name: audio_fleurs_xenna enabled: true script: audio_fleurs_benchmark.py diff --git a/benchmarking/scripts/audio_sortformer_benchmark.py b/benchmarking/scripts/audio_sortformer_benchmark.py index 1c9f38b7bc..922ba38014 100644 --- a/benchmarking/scripts/audio_sortformer_benchmark.py +++ b/benchmarking/scripts/audio_sortformer_benchmark.py @@ -50,7 +50,7 @@ def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any rtf = elapsed_s / total_audio_duration_s if total_audio_duration_s > 0 else 0.0 return { - "is_success": True, + "is_success": num_files > 0, "num_files_processed": num_files, "exec_time_s": round(elapsed_s, 2), "total_audio_duration_s": round(total_audio_duration_s, 2), @@ -104,6 +104,13 @@ def run_audio_sortformer_benchmark( # noqa: PLR0913 ) return { + "params": { + "executor": executor, + "manifest_path": manifest_path, + "model_name": model_name, + "gpus": gpus, + "rttm_out_dir": rttm_out_dir, + }, "metrics": metrics, "tasks": results, } @@ -115,7 +122,7 @@ def main() -> int: parser.add_argument("--manifest-path", required=True, help="Path to input JSONL manifest") parser.add_argument( "--model-name", - default="nvidia/diar_streaming_sortformer_4spk-v2", + default="nvidia/diar_streaming_sortformer_4spk-v2.1", help="HF Sortformer model id", ) parser.add_argument("--executor", default="xenna", choices=["xenna", "ray_data"], help="Executor to use") diff --git a/nemo_curator/stages/audio/inference/sortformer.py b/nemo_curator/stages/audio/inference/sortformer.py index f68a630dec..4117ba9663 100644 --- a/nemo_curator/stages/audio/inference/sortformer.py +++ b/nemo_curator/stages/audio/inference/sortformer.py @@ -87,10 +87,10 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): Uses the NeMo SortformerEncLabelModel for end-to-end neural speaker diarization with streaming 
support. See: - https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 + https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1 Args: - model_name: Hugging Face model id. Defaults to "nvidia/diar_streaming_sortformer_4spk-v2". + model_name: Hugging Face model id. Defaults to "nvidia/diar_streaming_sortformer_4spk-v2.1". model_path: Local path to a .nemo checkpoint file; if set, takes precedence over model_name. cache_dir: Directory for caching downloaded model weights. Defaults to HF hub default. diar_model: Pre-loaded SortformerEncLabelModel; if provided, setup() is a no-op. @@ -98,15 +98,17 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): diar_segments_key: Key in output data for diarization segments list. Defaults to "diar_segments". rttm_out_dir: Optional directory to write RTTM files. Defaults to None. chunk_len: Streaming chunk size in 80 ms frames. Defaults to 340 (~30.4 s latency). + chunk_left_context: Left context frames. Defaults to 1. chunk_right_context: Right context frames. Defaults to 40. fifo_len: FIFO queue size in frames. Defaults to 40. spkcache_update_period: Speaker cache update period in frames. Defaults to 300. spkcache_len: Speaker cache size in frames. Defaults to 188. inference_batch_size: Batch size passed to diarize(). Defaults to 1. + batch_duration: Maximum total audio duration (seconds) per lhotse batch. Defaults to 100000. name: Stage name. Defaults to "Sortformer_inference". 
""" - model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2" + model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2.1" model_path: str | None = None cache_dir: str | None = None diar_model: Any | None = None @@ -114,11 +116,13 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): diar_segments_key: str = "diar_segments" rttm_out_dir: str | None = None chunk_len: int = 340 + chunk_left_context: int = 1 chunk_right_context: int = 40 fifo_len: int = 40 spkcache_update_period: int = 300 spkcache_len: int = 188 inference_batch_size: int = 1 + batch_duration: int = 100000 name: str = "Sortformer_inference" batch_size: int = 1 resources: Resources = field(default_factory=lambda: Resources(cpus=1.0, gpu_memory_gb=8.0)) @@ -130,7 +134,7 @@ def setup_on_node( if self.model_path is not None: return try: - snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) + self._cached_repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) except Exception: # noqa: BLE001 logger.info(f"Could not pre-cache {self.model_name}; workers will download on first use") @@ -138,8 +142,10 @@ def _resolve_model_path(self) -> str: """Resolve the path to the .nemo checkpoint, downloading if needed.""" if self.model_path is not None: return self.model_path - repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) - nemo_files = [f for f in os.listdir(repo_dir) if f.endswith(".nemo")] + repo_dir = getattr(self, "_cached_repo_dir", None) or snapshot_download( + repo_id=self.model_name, cache_dir=self.cache_dir + ) + nemo_files = sorted(f for f in os.listdir(repo_dir) if f.endswith(".nemo")) if not nemo_files: msg = f"No .nemo file found in {repo_dir} for model {self.model_name}" raise FileNotFoundError(msg) @@ -150,6 +156,7 @@ def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: if self.diar_model is not None: self.diar_model.eval() self._configure_streaming() + 
self._extend_pos_enc_for_long_audio()
             return
 
         resolved_path = self._resolve_model_path()
@@ -161,6 +168,28 @@ def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None:
 
         self.diar_model.eval()
         self._configure_streaming()
+        self._extend_pos_enc_for_long_audio()
+
+    def _extend_pos_enc_for_long_audio(self, max_len: int = 30000) -> None:
+        """Extend RelPositionalEncoding buffer to handle long audio files.
+
+        NeMo's streaming Sortformer initialises pos_enc sized for one chunk (~35
+        conformer frames). Files longer than a few seconds overflow it at inference
+        time. extend_pe() is a NeMo method that resizes the buffer safely — it just
+        isn't called automatically. max_len=30000 covers ~1000 s at any subsampling.
+        """
+        import torch
+
+        pos_enc = getattr(getattr(self.diar_model, "encoder", None), "pos_enc", None)
+        if pos_enc is None or not hasattr(pos_enc, "extend_pe"):
+            logger.warning("pos_enc not found or no extend_pe method — skipping extension")
+            return
+        device = next(self.diar_model.parameters()).device
+        try:
+            pos_enc.extend_pe(max_len, device, torch.float32)
+            logger.info(f"Extended encoder pos_enc to max_len={max_len} for long-form audio")
+        except Exception as e:  # noqa: BLE001
+            logger.warning(f"Could not extend pos_enc: {e}")
 
     def _configure_streaming(self) -> None:
         """Apply streaming configuration to the loaded model."""
@@ -168,7 +197,9 @@ def _configure_streaming(self) -> None:
         sm.chunk_len = self.chunk_len
         sm.chunk_right_context = self.chunk_right_context
         sm.fifo_len = self.fifo_len
-        sm.spkcache_update_period = self.spkcache_update_period
+        sm.chunk_left_context = self.chunk_left_context
+        if hasattr(sm, "spkcache_update_period"):
+            sm.spkcache_update_period = self.spkcache_update_period
         sm.spkcache_len = self.spkcache_len
 
     def inputs(self) -> tuple[list[str], list[str]]:
diff --git a/tutorials/audio/callhome_diar/README.md b/tutorials/audio/callhome_diar/README.md
index 1c68c0f8ea..0bf5665120 100644
--- 
a/tutorials/audio/callhome_diar/README.md +++ b/tutorials/audio/callhome_diar/README.md @@ -1,6 +1,6 @@ # Speaker Diarization on CallHome English with NeMo Curator -This tutorial runs [Streaming Sortformer](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2) speaker diarization on the [CallHome English](https://catalog.ldc.upenn.edu/LDC97S42) dataset using NeMo Curator's `InferenceSortformerStage`, then evaluates Diarization Error Rate (DER). +This tutorial runs [Streaming Sortformer](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1) speaker diarization on the [CallHome English](https://catalog.ldc.upenn.edu/LDC97S42) dataset using NeMo Curator's `InferenceSortformerStage`, then evaluates Diarization Error Rate (DER). Inference runs in parallel via `Pipeline` + `XennaExecutor` for high throughput. @@ -51,7 +51,7 @@ Key arguments: | `--output-dir` | `output` | Root for RTTM files, results JSON, and checkpoints | | `--collar` | `0.25` | Collar tolerance (seconds) for DER scoring | | `--clean` | off | Remove entire output directory before re-running | -| `--model` | `nvidia/diar_streaming_sortformer_4spk-v2` | Hugging Face model id | +| `--model` | `nvidia/diar_streaming_sortformer_4spk-v2.1` | Hugging Face model id | ### Streaming configuration @@ -102,7 +102,7 @@ pipeline = Pipeline( stages=[ MyAudioReaderStage(data_dir="/path/to/audio"), # your reader stage InferenceSortformerStage( - model_name="nvidia/diar_streaming_sortformer_4spk-v2", + model_name="nvidia/diar_streaming_sortformer_4spk-v2.1", rttm_out_dir="./rttm", ), ], diff --git a/tutorials/audio/callhome_diar/run.py b/tutorials/audio/callhome_diar/run.py index 297e7cbc93..ed19d8dfbc 100644 --- a/tutorials/audio/callhome_diar/run.py +++ b/tutorials/audio/callhome_diar/run.py @@ -107,7 +107,7 @@ def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Sortformer diarization on CallHome English + DER evaluation.") p.add_argument("--data-dir", type=Path, 
required=True, help="CallHome-eng0 dataset root.") p.add_argument("--output-dir", type=Path, default=Path("output"), help="Root directory for all outputs.") - p.add_argument("--model", default="nvidia/diar_streaming_sortformer_4spk-v2", help="HF Sortformer model id.") + p.add_argument("--model", default="nvidia/diar_streaming_sortformer_4spk-v2.1", help="HF Sortformer model id.") p.add_argument("--collar", type=float, default=COLLAR, help="Collar tolerance (seconds).") p.add_argument("--clean", action="store_true", help="Remove entire output directory before running.") p.add_argument("--chunk-len", type=int, default=340, help="Streaming chunk size in 80ms frames.") From 5c3be571d5667cae7a150c37045f7cbfa1ae2f86 Mon Sep 17 00:00:00 2001 From: mmkrtchyan Date: Thu, 23 Apr 2026 14:54:47 +0400 Subject: [PATCH 3/4] Address second round of review feedback Signed-off-by: mmkrtchyan --- benchmarking/nightly-benchmark.yaml | 1 - .../scripts/audio_sortformer_benchmark.py | 9 ++------- .../stages/audio/inference/sortformer.py | 18 +++++------------- 3 files changed, 7 insertions(+), 21 deletions(-) diff --git a/benchmarking/nightly-benchmark.yaml b/benchmarking/nightly-benchmark.yaml index 607719036e..b924248e7c 100644 --- a/benchmarking/nightly-benchmark.yaml +++ b/benchmarking/nightly-benchmark.yaml @@ -639,7 +639,6 @@ entries: --manifest-path={datasets_path}/sortformer_diarization/manifest.jsonl --model-name=nvidia/diar_streaming_sortformer_4spk-v2.1 --rttm-out-dir={session_entry_dir}/scratch/rttm - --gpus=1 timeout_s: 1800 sink_data: - name: slack diff --git a/benchmarking/scripts/audio_sortformer_benchmark.py b/benchmarking/scripts/audio_sortformer_benchmark.py index 922ba38014..873e7ac5c4 100644 --- a/benchmarking/scripts/audio_sortformer_benchmark.py +++ b/benchmarking/scripts/audio_sortformer_benchmark.py @@ -31,7 +31,6 @@ from nemo_curator.pipeline import Pipeline from nemo_curator.stages.audio.alm.alm_manifest_reader import ALMManifestReader from 
nemo_curator.stages.audio.inference.sortformer import InferenceSortformerStage -from nemo_curator.stages.resources import Resources def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any]: @@ -60,11 +59,10 @@ def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any } -def run_audio_sortformer_benchmark( # noqa: PLR0913 +def run_audio_sortformer_benchmark( benchmark_results_path: str, manifest_path: str, model_name: str, - gpus: int, rttm_out_dir: str | None = None, executor: str = "xenna", **kwargs, # noqa: ARG001 @@ -76,7 +74,6 @@ def run_audio_sortformer_benchmark( # noqa: PLR0913 logger.info(f"Executor: {executor}") logger.info(f"Model: {model_name}") logger.info(f"Manifest: {manifest_path}") - logger.info(f"GPUs: {gpus}") executor_obj = setup_executor(executor) pipeline = Pipeline( @@ -89,7 +86,7 @@ def run_audio_sortformer_benchmark( # noqa: PLR0913 InferenceSortformerStage( model_name=model_name, rttm_out_dir=rttm_out_dir, - ).with_(resources=Resources(gpus=gpus)), + ), ) t0 = time.perf_counter() @@ -108,7 +105,6 @@ def run_audio_sortformer_benchmark( # noqa: PLR0913 "executor": executor, "manifest_path": manifest_path, "model_name": model_name, - "gpus": gpus, "rttm_out_dir": rttm_out_dir, }, "metrics": metrics, @@ -126,7 +122,6 @@ def main() -> int: help="HF Sortformer model id", ) parser.add_argument("--executor", default="xenna", choices=["xenna", "ray_data"], help="Executor to use") - parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use") parser.add_argument("--rttm-out-dir", default=None, help="Optional directory to write RTTM output files") args = parser.parse_args() diff --git a/nemo_curator/stages/audio/inference/sortformer.py b/nemo_curator/stages/audio/inference/sortformer.py index 4117ba9663..45ae4eee9f 100644 --- a/nemo_curator/stages/audio/inference/sortformer.py +++ b/nemo_curator/stages/audio/inference/sortformer.py @@ -18,6 +18,7 @@ from dataclasses import dataclass, 
field from typing import TYPE_CHECKING, Any +import torch from huggingface_hub import snapshot_download from loguru import logger from nemo.collections.asr.models import SortformerEncLabelModel @@ -104,7 +105,6 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): spkcache_update_period: Speaker cache update period in frames. Defaults to 300. spkcache_len: Speaker cache size in frames. Defaults to 188. inference_batch_size: Batch size passed to diarize(). Defaults to 1. - batch_duration: Maximum total audio duration (seconds) per lhotse batch. Defaults to 100000. name: Stage name. Defaults to "Sortformer_inference". """ @@ -122,7 +122,6 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): spkcache_update_period: int = 300 spkcache_len: int = 188 inference_batch_size: int = 1 - batch_duration: int = 100000 name: str = "Sortformer_inference" batch_size: int = 1 resources: Resources = field(default_factory=lambda: Resources(cpus=1.0, gpu_memory_gb=8.0)) @@ -130,21 +129,16 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): def setup_on_node( self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None ) -> None: - """Pre-download model weights on the node so actors load from cache.""" + """Pre-download model weights on the node so workers load from cache.""" if self.model_path is not None: return - try: - self._cached_repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) - except Exception: # noqa: BLE001 - logger.info(f"Could not pre-cache {self.model_name}; workers will download on first use") + snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) def _resolve_model_path(self) -> str: - """Resolve the path to the .nemo checkpoint, downloading if needed.""" + """Resolve the path to the .nemo checkpoint from the HF cache.""" if self.model_path is not None: return self.model_path - repo_dir = getattr(self, "_cached_repo_dir", None) or 
snapshot_download( - repo_id=self.model_name, cache_dir=self.cache_dir - ) + repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) nemo_files = sorted(f for f in os.listdir(repo_dir) if f.endswith(".nemo")) if not nemo_files: msg = f"No .nemo file found in {repo_dir} for model {self.model_name}" @@ -178,8 +172,6 @@ def _extend_pos_enc_for_long_audio(self, max_len: int = 30000) -> None: time. extend_pe() is a NeMo method that resizes the buffer safely — it just isn't called automatically. max_len=30000 covers ~1000 s at any subsampling. """ - import torch - pos_enc = getattr(getattr(self.diar_model, "encoder", None), "pos_enc", None) if pos_enc is None or not hasattr(pos_enc, "extend_pe"): logger.warning("pos_enc not found or no extend_pe method — skipping extension") From 246067a49c763d659697986b6bf2577a017fb525 Mon Sep 17 00:00:00 2001 From: mmkrtchyan Date: Thu, 23 Apr 2026 21:35:31 +0400 Subject: [PATCH 4/4] fixing some more issues Signed-off-by: mmkrtchyan --- benchmarking/scripts/audio_sortformer_benchmark.py | 4 ---- nemo_curator/stages/audio/inference/sortformer.py | 5 ++--- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/benchmarking/scripts/audio_sortformer_benchmark.py b/benchmarking/scripts/audio_sortformer_benchmark.py index 873e7ac5c4..11edfedbf0 100644 --- a/benchmarking/scripts/audio_sortformer_benchmark.py +++ b/benchmarking/scripts/audio_sortformer_benchmark.py @@ -22,7 +22,6 @@ import argparse import time import traceback -from pathlib import Path from typing import Any from loguru import logger @@ -60,7 +59,6 @@ def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any def run_audio_sortformer_benchmark( - benchmark_results_path: str, manifest_path: str, model_name: str, rttm_out_dir: str | None = None, @@ -68,8 +66,6 @@ def run_audio_sortformer_benchmark( **kwargs, # noqa: ARG001 ) -> dict[str, Any]: """Run the audio Sortformer diarization benchmark and collect metrics.""" - 
benchmark_results_path = Path(benchmark_results_path)
-
     logger.info("Starting audio Sortformer diarization benchmark")
     logger.info(f"Executor: {executor}")
     logger.info(f"Model: {model_name}")
diff --git a/nemo_curator/stages/audio/inference/sortformer.py b/nemo_curator/stages/audio/inference/sortformer.py
index 45ae4eee9f..3f88b64e1b 100644
--- a/nemo_curator/stages/audio/inference/sortformer.py
+++ b/nemo_curator/stages/audio/inference/sortformer.py
@@ -18,7 +18,6 @@
 from dataclasses import dataclass, field
 from typing import TYPE_CHECKING, Any
 
-import torch
 from huggingface_hub import snapshot_download
 from loguru import logger
 from nemo.collections.asr.models import SortformerEncLabelModel
@@ -176,9 +175,9 @@ def _extend_pos_enc_for_long_audio(self, max_len: int = 30000) -> None:
         if pos_enc is None or not hasattr(pos_enc, "extend_pe"):
             logger.warning("pos_enc not found or no extend_pe method — skipping extension")
             return
-        device = next(self.diar_model.parameters()).device
+        params = next(self.diar_model.parameters())
         try:
-            pos_enc.extend_pe(max_len, device, torch.float32)
+            pos_enc.extend_pe(max_len, params.device, params.dtype)
             logger.info(f"Extended encoder pos_enc to max_len={max_len} for long-form audio")
         except Exception as e:  # noqa: BLE001
             logger.warning(f"Could not extend pos_enc: {e}")