diff --git a/benchmarking/nightly-benchmark.yaml b/benchmarking/nightly-benchmark.yaml index 39f918e8c2..8e0636a1c7 100644 --- a/benchmarking/nightly-benchmark.yaml +++ b/benchmarking/nightly-benchmark.yaml @@ -631,6 +631,34 @@ entries: - metric: throughput_images_per_sec min_value: 3.0 + - name: audio_sortformer_xenna + enabled: false + script: audio_sortformer_benchmark.py + args: >- + --benchmark-results-path={session_entry_dir} + --manifest-path={datasets_path}/sortformer_diarization/manifest.jsonl + --model-name=nvidia/diar_streaming_sortformer_4spk-v2.1 + --rttm-out-dir={session_entry_dir}/scratch/rttm + timeout_s: 1800 + sink_data: + - name: slack + additional_metrics: + - num_files_processed + - throughput_files_per_sec + - real_time_factor + - total_segments_detected + ping_on_failure: + - U03C41SNADV # Aaftab V + ray: + num_cpus: 64 + num_gpus: 4 + enable_object_spilling: false + requirements: + - metric: is_success + exact_value: true + - metric: num_files_processed + min_value: 1 + - name: audio_fleurs_xenna enabled: true script: audio_fleurs_benchmark.py diff --git a/benchmarking/scripts/audio_sortformer_benchmark.py b/benchmarking/scripts/audio_sortformer_benchmark.py new file mode 100644 index 0000000000..11edfedbf0 --- /dev/null +++ b/benchmarking/scripts/audio_sortformer_benchmark.py @@ -0,0 +1,149 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Audio Sortformer diarization benchmarking script. + +This script runs Streaming Sortformer diarization benchmarks with +comprehensive metrics collection including real-time factor (RTF), +per-file segment counts, and throughput. +""" + +import argparse +import time +import traceback +from typing import Any + +from loguru import logger +from utils import setup_executor, write_benchmark_results + +from nemo_curator.pipeline import Pipeline +from nemo_curator.stages.audio.alm.alm_manifest_reader import ALMManifestReader +from nemo_curator.stages.audio.inference.sortformer import InferenceSortformerStage + + +def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any]: + """Extract diarization-specific metrics from output tasks.""" + num_files = len(tasks) if tasks else 0 + total_audio_duration_s = 0.0 + total_segments = 0 + + for task in tasks or []: + data = task.data if hasattr(task, "data") else {} + total_audio_duration_s += float(data.get("duration", 0)) + segments = data.get("diar_segments", []) + total_segments += len(segments) + + throughput = num_files / elapsed_s if elapsed_s > 0 else 0.0 + rtf = elapsed_s / total_audio_duration_s if total_audio_duration_s > 0 else 0.0 + + return { + "is_success": num_files > 0, + "num_files_processed": num_files, + "exec_time_s": round(elapsed_s, 2), + "total_audio_duration_s": round(total_audio_duration_s, 2), + "total_segments_detected": total_segments, + "real_time_factor": round(rtf, 4), + "throughput_files_per_sec": round(throughput, 4), + } + + +def run_audio_sortformer_benchmark( + manifest_path: str, + model_name: str, + rttm_out_dir: str | None = None, + executor: str = "xenna", + **kwargs, # noqa: ARG001 +) -> dict[str, Any]: + """Run the audio Sortformer diarization benchmark and collect metrics.""" + logger.info("Starting audio Sortformer diarization benchmark") + logger.info(f"Executor: {executor}") + logger.info(f"Model: {model_name}") + logger.info(f"Manifest: {manifest_path}") + + executor_obj = setup_executor(executor) + pipeline = Pipeline( + name="audio_sortformer_diarization", + description="Streaming Sortformer speaker diarization inference.", + ) + + pipeline.add_stage(ALMManifestReader(manifest_path=manifest_path)) + pipeline.add_stage( + InferenceSortformerStage( + model_name=model_name, + rttm_out_dir=rttm_out_dir, + ), + ) + + t0 = time.perf_counter() + results = pipeline.run(executor_obj) + elapsed_s = time.perf_counter() - t0 + + metrics = _collect_diarization_metrics(results, elapsed_s) + + logger.success( + f"Benchmark completed: {metrics['num_files_processed']} files in {elapsed_s:.1f}s " + f"(RTF={metrics['real_time_factor']:.3f}, {metrics['throughput_files_per_sec']:.2f} files/sec)" + ) + + return { + "params": { + "executor": executor, + "manifest_path": manifest_path, + "model_name": model_name, + "rttm_out_dir": rttm_out_dir, + }, + "metrics": metrics, + "tasks": results, + } + + +def main() -> int: + parser = argparse.ArgumentParser(description="Audio Sortformer diarization benchmark for nightly benchmarking") + parser.add_argument("--benchmark-results-path", required=True, help="Path to benchmark results") + parser.add_argument("--manifest-path", required=True, help="Path to input JSONL manifest") + parser.add_argument( + "--model-name", + default="nvidia/diar_streaming_sortformer_4spk-v2.1", + help="HF Sortformer model id", + ) + parser.add_argument("--executor", default="xenna", choices=["xenna", "ray_data"], help="Executor to use") + parser.add_argument("--rttm-out-dir", default=None, help="Optional directory to write RTTM output files") + + args = parser.parse_args() + + logger.info("=== Audio Sortformer Diarization Benchmark Starting ===") + logger.info(f"Arguments: {vars(args)}") + + success_code = 1 + result_dict: dict[str, Any] = { + "params": vars(args), + "metrics": { + "is_success": False, + }, + "tasks": [], + } + try: + result_dict.update(run_audio_sortformer_benchmark(**vars(args))) + success_code = 0 if result_dict["metrics"]["is_success"] else 1 + except Exception as e: + error_traceback = traceback.format_exc() + logger.error(f"Benchmark failed: {e}") + logger.debug(f"Full traceback:\n{error_traceback}") + finally: + write_benchmark_results(result_dict, args.benchmark_results_path) + return success_code + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/nemo_curator/stages/audio/inference/sortformer.py b/nemo_curator/stages/audio/inference/sortformer.py index 7f33c5f31a..3f88b64e1b 100644 --- a/nemo_curator/stages/audio/inference/sortformer.py +++ b/nemo_curator/stages/audio/inference/sortformer.py @@ -87,10 +87,10 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): Uses the NeMo SortformerEncLabelModel for end-to-end neural speaker diarization with streaming support. See: - https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2 + https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1 Args: - model_name: Hugging Face model id. Defaults to "nvidia/diar_streaming_sortformer_4spk-v2". + model_name: Hugging Face model id. Defaults to "nvidia/diar_streaming_sortformer_4spk-v2.1". model_path: Local path to a .nemo checkpoint file; if set, takes precedence over model_name. cache_dir: Directory for caching downloaded model weights. Defaults to HF hub default. diar_model: Pre-loaded SortformerEncLabelModel; if provided, setup() is a no-op. @@ -98,6 +98,7 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): diar_segments_key: Key in output data for diarization segments list. Defaults to "diar_segments". rttm_out_dir: Optional directory to write RTTM files. Defaults to None. chunk_len: Streaming chunk size in 80 ms frames. Defaults to 340 (~30.4 s latency). + chunk_left_context: Left context frames. Defaults to 1. chunk_right_context: Right context frames. Defaults to 40. fifo_len: FIFO queue size in frames. Defaults to 40. spkcache_update_period: Speaker cache update period in frames. Defaults to 300. @@ -106,7 +107,7 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): name: Stage name. Defaults to "Sortformer_inference". """ - model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2" + model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2.1" model_path: str | None = None cache_dir: str | None = None diar_model: Any | None = None @@ -114,6 +115,7 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): diar_segments_key: str = "diar_segments" rttm_out_dir: str | None = None chunk_len: int = 340 + chunk_left_context: int = 1 chunk_right_context: int = 40 fifo_len: int = 40 spkcache_update_period: int = 300 @@ -126,34 +128,59 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]): def setup_on_node( self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None ) -> None: - """Pre-download model weights on the node so actors load from cache.""" + """Pre-download model weights on the node so workers load from cache.""" if self.model_path is not None: return - try: - repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) - nemo_files = [f for f in os.listdir(repo_dir) if f.endswith(".nemo")] - if nemo_files: - self.model_path = os.path.join(repo_dir, nemo_files[0]) - else: - logger.warning(f"No .nemo file found in {repo_dir}; setup() will fail") - except Exception: # noqa: BLE001 - logger.info(f"Could not pre-cache {self.model_name}; actors will download on first use") + snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) + + def _resolve_model_path(self) -> str: + """Resolve the path to the .nemo checkpoint from the HF cache.""" + if self.model_path is not None: + return self.model_path + repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir) + nemo_files = sorted(f for f in os.listdir(repo_dir) if f.endswith(".nemo")) + if not nemo_files: + msg = f"No .nemo file found in {repo_dir} for model {self.model_name}" + raise FileNotFoundError(msg) + return os.path.join(repo_dir, nemo_files[0]) def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None: """Load Sortformer model from Hugging Face or a local .nemo file.""" if self.diar_model is not None: self.diar_model.eval() self._configure_streaming() + self._extend_pos_enc_for_long_audio() return + resolved_path = self._resolve_model_path() self.diar_model = SortformerEncLabelModel.restore_from( - restore_path=self.model_path, + restore_path=resolved_path, map_location="cuda", strict=False, ) self.diar_model.eval() self._configure_streaming() + self._extend_pos_enc_for_long_audio() + + def _extend_pos_enc_for_long_audio(self, max_len: int = 30000) -> None: + """Extend RelPositionalEncoding buffer to handle long audio files. + + NeMo's streaming Sortformer initialises pos_enc sized for one chunk (~35 + conformer frames). Files longer than a few seconds overflow it at inference + time. extend_pe() is a NeMo method that resizes the buffer safely — it just + isn't called automatically. max_len=30000 covers ~1000 s at any subsampling. + """ + pos_enc = getattr(getattr(self.diar_model, "encoder", None), "pos_enc", None) + if pos_enc is None or not hasattr(pos_enc, "extend_pe"): + logger.warning("pos_enc not found or no extend_pe method — skipping extension") + return + params = next(self.diar_model.parameters()) + try: + pos_enc.extend_pe(max_len, params.device, params.dtype) + logger.info(f"Extended encoder pos_enc to max_len={max_len} for long-form audio") + except Exception as e: # noqa: BLE001 + logger.warning(f"Could not extend pos_enc: {e}") def _configure_streaming(self) -> None: """Apply streaming configuration to the loaded model.""" @@ -161,7 +188,9 @@ def _configure_streaming(self) -> None: sm.chunk_len = self.chunk_len sm.chunk_right_context = self.chunk_right_context sm.fifo_len = self.fifo_len - sm.spkcache_update_period = self.spkcache_update_period + sm.chunk_left_context = self.chunk_left_context + if hasattr(sm, "spkcache_update_period"): + sm.spkcache_update_period = self.spkcache_update_period sm.spkcache_len = self.spkcache_len def inputs(self) -> tuple[list[str], list[str]]: @@ -189,9 +218,7 @@ def process(self, task: AudioTask) -> AudioTask: file_path = task.data[self.filepath_key] sess_name = task.data.get("session_name") - resolved_sess_name = ( - sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0] - ) + resolved_sess_name = sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0] all_segments = self.diarize([file_path]) segments = all_segments[0] diff --git a/tutorials/audio/callhome_diar/README.md b/tutorials/audio/callhome_diar/README.md index bc6782cbb9..fc841de49c 100644 --- a/tutorials/audio/callhome_diar/README.md +++ b/tutorials/audio/callhome_diar/README.md @@ -1,6 +1,6 @@ # Speaker Diarization on CallHome English with NeMo Curator -This tutorial runs [Streaming Sortformer](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2) speaker diarization on the [CallHome English](https://catalog.ldc.upenn.edu/LDC97S42) dataset using NeMo Curator's `InferenceSortformerStage`, then evaluates Diarization Error Rate (DER). +This tutorial runs [Streaming Sortformer](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1) speaker diarization on the [CallHome English](https://catalog.ldc.upenn.edu/LDC97S42) dataset using NeMo Curator's `InferenceSortformerStage`, then evaluates Diarization Error Rate (DER). Inference runs in parallel via `Pipeline` + `XennaExecutor` for high throughput. @@ -8,7 +8,7 @@ Inference runs in parallel via `Pipeline` + `XennaExecutor` for high throughput. - Python 3.10+ - NeMo Curator installed (see [installation guide](https://docs.nvidia.com/nemo/curator/latest/admin/installation.html)) -- [`sox`](https://sox.sourceforge.net/) command-line tool (for stereo-to-mono conversion; install via `apt install sox`, `brew install sox`, or `conda install -c conda-forge sox`) +- [`ffmpeg`](https://ffmpeg.org/) command-line tool (for stereo-to-mono conversion; pre-installed in the NeMo Curator container) - CallHome English dataset with `.wav` files and `eng/*.cha` ground-truth annotations ### Dataset layout @@ -51,7 +51,7 @@ Key arguments: | `--output-dir` | `output` | Root for RTTM files, results JSON, and checkpoints | | `--collar` | `0.25` | Collar tolerance (seconds) for DER scoring | | `--clean` | off | Remove entire output directory before re-running | -| `--model` | `nvidia/diar_streaming_sortformer_4spk-v2` | Hugging Face model id | +| `--model` | `nvidia/diar_streaming_sortformer_4spk-v2.1` | Hugging Face model id | ### Streaming configuration @@ -67,7 +67,7 @@ All values are in **80 ms frames**. Override via `--chunk-len`, `--chunk-right-c ## What the script does 1. **File discovery (`CallHomeReaderStage`)** — Scans the dataset directory for WAV files with matching `.cha` annotations, skipping already-processed files. Emits one `AudioTask` per file. -2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `sox` so the model sees both speakers. +2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `ffmpeg` so the model sees both speakers. 3. **Diarization inference (`InferenceSortformerStage`)** — Runs Streaming Sortformer on each mono file. Also writes RTTM files to `/rttm/`. 4. **DER evaluation (`DERComputationStage`)** — Compares predicted segments against CHA ground truth. Scoring is restricted to the UEM region (min/max annotated timestamps from CHA) with a configurable collar tolerance (default 0.25 s). @@ -102,7 +102,7 @@ pipeline = Pipeline( stages=[ MyAudioReaderStage(data_dir="/path/to/audio"), # your reader stage InferenceSortformerStage( - model_name="nvidia/diar_streaming_sortformer_4spk-v2", + model_name="nvidia/diar_streaming_sortformer_4spk-v2.1", rttm_out_dir="./rttm", ), ], @@ -122,3 +122,4 @@ results = pipeline.run(executor=XennaExecutor()) - Maximum 4 speakers per recording - Trained primarily on English speech - Performance may degrade on noisy or very long recordings +- Audio must be mono 16 kHz; running on raw stereo or narrow-band (8 kHz) files without proper conversion will produce very high false-alarm rates diff --git a/tutorials/audio/callhome_diar/run.py b/tutorials/audio/callhome_diar/run.py index efb85e4569..ed19d8dfbc 100644 --- a/tutorials/audio/callhome_diar/run.py +++ b/tutorials/audio/callhome_diar/run.py @@ -34,6 +34,7 @@ from collections import Counter from dataclasses import dataclass from pathlib import Path +from typing import Any from loguru import logger @@ -106,7 +107,7 @@ def parse_args() -> argparse.Namespace: p = argparse.ArgumentParser(description="Sortformer diarization on CallHome English + DER evaluation.") p.add_argument("--data-dir", type=Path, required=True, help="CallHome-eng0 dataset root.") p.add_argument("--output-dir", type=Path, default=Path("output"), help="Root directory for all outputs.") - p.add_argument("--model", default="nvidia/diar_streaming_sortformer_4spk-v2", help="HF Sortformer model id.") + p.add_argument("--model", default="nvidia/diar_streaming_sortformer_4spk-v2.1", help="HF Sortformer model id.") p.add_argument("--collar", type=float, default=COLLAR, help="Collar tolerance (seconds).") p.add_argument("--clean", action="store_true", help="Remove entire output directory before running.") p.add_argument("--chunk-len", type=int, default=340, help="Streaming chunk size in 80ms frames.") @@ -144,6 +145,9 @@ def inputs(self) -> tuple[list[str], list[str]]: def outputs(self) -> tuple[list[str], list[str]]: return ["data"], [self.filepath_key] + def xenna_stage_spec(self) -> dict[str, Any]: + return {"num_workers_per_node": 1} + def process(self, task: _EmptyTask) -> list[AudioTask]: # noqa: ARG002 cha_path = Path(self.cha_dir) done = {p.stem for p in Path(self.rttm_out_dir).glob("*.rttm")} if self.rttm_out_dir else set() @@ -164,7 +168,7 @@ def process(self, task: _EmptyTask) -> list[AudioTask]: # noqa: ARG002 @dataclass class EnsureMonoStage(ProcessingStage[AudioTask, AudioTask]): - """Downmix stereo WAVs to mono 16 kHz via sox.""" + """Downmix stereo WAVs to mono 16 kHz via ffmpeg.""" mono_dir: str = "mono" filepath_key: str = "audio_filepath" @@ -182,7 +186,7 @@ def _ensure_mono(self, wav_path: str) -> str: return mono_path os.makedirs(self.mono_dir, exist_ok=True) subprocess.run( # noqa: S603 - ["sox", wav_path, "-c", "1", "-r", "16000", mono_path], # noqa: S607 + ["ffmpeg", "-i", wav_path, "-ac", "1", "-ar", "16000", "-y", mono_path], # noqa: S607 check=True, capture_output=True, )