Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
151 changes: 151 additions & 0 deletions benchmarking/scripts/audio_sortformer_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
Comment thread
melllinia marked this conversation as resolved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Audio Sortformer diarization benchmarking script.

This script runs Streaming Sortformer diarization benchmarks with
comprehensive metrics collection including real-time factor (RTF),
per-file segment counts, and throughput.
"""

import argparse
import time
import traceback
from pathlib import Path
from typing import Any

from loguru import logger
from utils import setup_executor, write_benchmark_results

from nemo_curator.pipeline import Pipeline
from nemo_curator.stages.audio.alm.alm_manifest_reader import ALMManifestReader
from nemo_curator.stages.audio.inference.sortformer import InferenceSortformerStage
from nemo_curator.stages.resources import Resources


def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any]:
"""Extract diarization-specific metrics from output tasks."""
num_files = len(tasks) if tasks else 0
total_audio_duration_s = 0.0
total_segments = 0

for task in tasks or []:
data = task.data if hasattr(task, "data") else {}
total_audio_duration_s += float(data.get("duration", 0))
segments = data.get("diar_segments", [])
total_segments += len(segments)

throughput = num_files / elapsed_s if elapsed_s > 0 else 0.0
rtf = elapsed_s / total_audio_duration_s if total_audio_duration_s > 0 else 0.0

return {
"is_success": True,
"num_files_processed": num_files,
"exec_time_s": round(elapsed_s, 2),
"total_audio_duration_s": round(total_audio_duration_s, 2),
"total_segments_detected": total_segments,
"real_time_factor": round(rtf, 4),
"throughput_files_per_sec": round(throughput, 4),
}
Comment thread
melllinia marked this conversation as resolved.


def run_audio_sortformer_benchmark( # noqa: PLR0913
benchmark_results_path: str,
manifest_path: str,
model_name: str,
gpus: int,
rttm_out_dir: str | None = None,
executor: str = "xenna",
**kwargs, # noqa: ARG001
) -> dict[str, Any]:
"""Run the audio Sortformer diarization benchmark and collect metrics."""
benchmark_results_path = Path(benchmark_results_path)

logger.info("Starting audio Sortformer diarization benchmark")
logger.info(f"Executor: {executor}")
logger.info(f"Model: {model_name}")
logger.info(f"Manifest: {manifest_path}")
logger.info(f"GPUs: {gpus}")

executor_obj = setup_executor(executor)
pipeline = Pipeline(
name="audio_sortformer_diarization",
description="Streaming Sortformer speaker diarization inference.",
)

pipeline.add_stage(ALMManifestReader(manifest_path=manifest_path))
pipeline.add_stage(
InferenceSortformerStage(
model_name=model_name,
rttm_out_dir=rttm_out_dir,
).with_(resources=Resources(gpus=gpus)),
Comment thread
melllinia marked this conversation as resolved.
Outdated
)

t0 = time.perf_counter()
results = pipeline.run(executor_obj)
elapsed_s = time.perf_counter() - t0

metrics = _collect_diarization_metrics(results, elapsed_s)

logger.success(
f"Benchmark completed: {metrics['num_files_processed']} files in {elapsed_s:.1f}s "
f"(RTF={metrics['real_time_factor']:.3f}, {metrics['throughput_files_per_sec']:.2f} files/sec)"
)

return {
Comment thread
melllinia marked this conversation as resolved.
"metrics": metrics,
"tasks": results,
}


def main() -> int:
parser = argparse.ArgumentParser(description="Audio Sortformer diarization benchmark for nightly benchmarking")
parser.add_argument("--benchmark-results-path", required=True, help="Path to benchmark results")
parser.add_argument("--manifest-path", required=True, help="Path to input JSONL manifest")
parser.add_argument(
"--model-name",
default="nvidia/diar_streaming_sortformer_4spk-v2",
help="HF Sortformer model id",
)
parser.add_argument("--executor", default="xenna", choices=["xenna", "ray_data"], help="Executor to use")
parser.add_argument("--gpus", type=int, default=1, help="Number of GPUs to use")
parser.add_argument("--rttm-out-dir", default=None, help="Optional directory to write RTTM output files")

args = parser.parse_args()

logger.info("=== Audio Sortformer Diarization Benchmark Starting ===")
logger.info(f"Arguments: {vars(args)}")

success_code = 1
result_dict: dict[str, Any] = {
"params": vars(args),
"metrics": {
"is_success": False,
},
"tasks": [],
}
try:
result_dict.update(run_audio_sortformer_benchmark(**vars(args)))
success_code = 0 if result_dict["metrics"]["is_success"] else 1
except Exception as e:
error_traceback = traceback.format_exc()
logger.error(f"Benchmark failed: {e}")
logger.debug(f"Full traceback:\n{error_traceback}")
finally:
write_benchmark_results(result_dict, args.benchmark_results_path)
return success_code


if __name__ == "__main__":
raise SystemExit(main())
27 changes: 16 additions & 11 deletions nemo_curator/stages/audio/inference/sortformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,14 +130,20 @@ def setup_on_node(
if self.model_path is not None:
return
try:
repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir)
nemo_files = [f for f in os.listdir(repo_dir) if f.endswith(".nemo")]
if nemo_files:
self.model_path = os.path.join(repo_dir, nemo_files[0])
else:
logger.warning(f"No .nemo file found in {repo_dir}; setup() will fail")
snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir)
except Exception: # noqa: BLE001
logger.info(f"Could not pre-cache {self.model_name}; actors will download on first use")
logger.info(f"Could not pre-cache {self.model_name}; workers will download on first use")

def _resolve_model_path(self) -> str:
"""Resolve the path to the .nemo checkpoint, downloading if needed."""
if self.model_path is not None:
return self.model_path
repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir)
Comment thread
melllinia marked this conversation as resolved.
nemo_files = [f for f in os.listdir(repo_dir) if f.endswith(".nemo")]
if not nemo_files:
msg = f"No .nemo file found in {repo_dir} for model {self.model_name}"
raise FileNotFoundError(msg)
return os.path.join(repo_dir, nemo_files[0])
Comment thread
melllinia marked this conversation as resolved.
Outdated

def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None:
"""Load Sortformer model from Hugging Face or a local .nemo file."""
Expand All @@ -146,8 +152,9 @@ def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None:
self._configure_streaming()
return

resolved_path = self._resolve_model_path()
self.diar_model = SortformerEncLabelModel.restore_from(
restore_path=self.model_path,
restore_path=resolved_path,
map_location="cuda",
strict=False,
)
Expand Down Expand Up @@ -189,9 +196,7 @@ def process(self, task: AudioTask) -> AudioTask:

file_path = task.data[self.filepath_key]
sess_name = task.data.get("session_name")
resolved_sess_name = (
sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0]
)
resolved_sess_name = sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0]

all_segments = self.diarize([file_path])
segments = all_segments[0]
Expand Down
5 changes: 3 additions & 2 deletions tutorials/audio/callhome_diar/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ Inference runs in parallel via `Pipeline` + `XennaExecutor` for high throughput.

- Python 3.10+
- NeMo Curator installed (see [installation guide](https://docs.nvidia.com/nemo/curator/latest/admin/installation.html))
- [`sox`](https://sox.sourceforge.net/) command-line tool (for stereo-to-mono conversion; install via `apt install sox`, `brew install sox`, or `conda install -c conda-forge sox`)
- [`ffmpeg`](https://ffmpeg.org/) command-line tool (for stereo-to-mono conversion; pre-installed in the NeMo Curator container)
- CallHome English dataset with `.wav` files and `eng/*.cha` ground-truth annotations

### Dataset layout
Expand Down Expand Up @@ -67,7 +67,7 @@ All values are in **80 ms frames**. Override via `--chunk-len`, `--chunk-right-c
## What the script does

1. **File discovery (`CallHomeReaderStage`)** — Scans the dataset directory for WAV files with matching `.cha` annotations, skipping already-processed files. Emits one `AudioTask` per file.
2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `sox` so the model sees both speakers.
2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `ffmpeg` so the model sees both speakers.
3. **Diarization inference (`InferenceSortformerStage`)** — Runs Streaming Sortformer on each mono file. Also writes RTTM files to `<output-dir>/rttm/`.
4. **DER evaluation (`DERComputationStage`)** — Compares predicted segments against CHA ground truth. Scoring is restricted to the UEM region (min/max annotated timestamps from CHA) with a configurable collar tolerance (default 0.25 s).

Expand Down Expand Up @@ -116,3 +116,4 @@ results = pipeline.run(executor=XennaExecutor())
- Maximum 4 speakers per recording
- Trained primarily on English speech
- Performance may degrade on noisy or very long recordings
- Audio must be mono 16 kHz; running on raw stereo or narrow-band (8 kHz) files without proper conversion will produce very high false-alarm rates
8 changes: 6 additions & 2 deletions tutorials/audio/callhome_diar/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from collections import Counter
from dataclasses import dataclass
from pathlib import Path
from typing import Any

from loguru import logger

Expand Down Expand Up @@ -144,6 +145,9 @@ def inputs(self) -> tuple[list[str], list[str]]:
def outputs(self) -> tuple[list[str], list[str]]:
return ["data"], [self.filepath_key]

def xenna_stage_spec(self) -> dict[str, Any]:
return {"num_workers_per_node": 1}

def process(self, task: _EmptyTask) -> list[AudioTask]: # noqa: ARG002
cha_path = Path(self.cha_dir)
done = {p.stem for p in Path(self.rttm_out_dir).glob("*.rttm")} if self.rttm_out_dir else set()
Expand All @@ -164,7 +168,7 @@ def process(self, task: _EmptyTask) -> list[AudioTask]: # noqa: ARG002

@dataclass
class EnsureMonoStage(ProcessingStage[AudioTask, AudioTask]):
"""Downmix stereo WAVs to mono 16 kHz via sox."""
"""Downmix stereo WAVs to mono 16 kHz via ffmpeg."""

mono_dir: str = "mono"
filepath_key: str = "audio_filepath"
Expand All @@ -182,7 +186,7 @@ def _ensure_mono(self, wav_path: str) -> str:
return mono_path
os.makedirs(self.mono_dir, exist_ok=True)
subprocess.run( # noqa: S603
["sox", wav_path, "-c", "1", "-r", "16000", mono_path], # noqa: S607
["ffmpeg", "-i", wav_path, "-ac", "1", "-ar", "16000", "-y", mono_path], # noqa: S607
check=True,
capture_output=True,
)
Expand Down
Loading