Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions benchmarking/nightly-benchmark.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,34 @@ entries:
- metric: throughput_images_per_sec
min_value: 3.0

- name: audio_sortformer_xenna
enabled: false
script: audio_sortformer_benchmark.py
args: >-
--benchmark-results-path={session_entry_dir}
--manifest-path={datasets_path}/sortformer_diarization/manifest.jsonl
--model-name=nvidia/diar_streaming_sortformer_4spk-v2.1
--rttm-out-dir={session_entry_dir}/scratch/rttm
timeout_s: 1800
sink_data:
- name: slack
additional_metrics:
- num_files_processed
- throughput_files_per_sec
- real_time_factor
- total_segments_detected
ping_on_failure:
- U03C41SNADV # Aaftab V
ray:
num_cpus: 64
num_gpus: 4
enable_object_spilling: false
requirements:
- metric: is_success
exact_value: true
- metric: num_files_processed
min_value: 1

- name: audio_fleurs_xenna
enabled: true
script: audio_fleurs_benchmark.py
Expand Down
149 changes: 149 additions & 0 deletions benchmarking/scripts/audio_sortformer_benchmark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
Comment thread
melllinia marked this conversation as resolved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Audio Sortformer diarization benchmarking script.

This script runs Streaming Sortformer diarization benchmarks with
comprehensive metrics collection including real-time factor (RTF),
per-file segment counts, and throughput.
"""

import argparse
import time
import traceback
from typing import Any

from loguru import logger
from utils import setup_executor, write_benchmark_results

from nemo_curator.pipeline import Pipeline
from nemo_curator.stages.audio.alm.alm_manifest_reader import ALMManifestReader
from nemo_curator.stages.audio.inference.sortformer import InferenceSortformerStage


def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any]:
"""Extract diarization-specific metrics from output tasks."""
num_files = len(tasks) if tasks else 0
total_audio_duration_s = 0.0
total_segments = 0

for task in tasks or []:
data = task.data if hasattr(task, "data") else {}
total_audio_duration_s += float(data.get("duration", 0))
segments = data.get("diar_segments", [])
total_segments += len(segments)

throughput = num_files / elapsed_s if elapsed_s > 0 else 0.0
rtf = elapsed_s / total_audio_duration_s if total_audio_duration_s > 0 else 0.0

return {
"is_success": num_files > 0,
"num_files_processed": num_files,
"exec_time_s": round(elapsed_s, 2),
"total_audio_duration_s": round(total_audio_duration_s, 2),
"total_segments_detected": total_segments,
"real_time_factor": round(rtf, 4),
"throughput_files_per_sec": round(throughput, 4),
}
Comment thread
melllinia marked this conversation as resolved.


def run_audio_sortformer_benchmark(
    manifest_path: str,
    model_name: str,
    rttm_out_dir: str | None = None,
    executor: str = "xenna",
    **kwargs,  # noqa: ARG001
) -> dict[str, Any]:
    """Run the audio Sortformer diarization benchmark and collect metrics.

    Args:
        manifest_path: Path to the input JSONL manifest.
        model_name: Hugging Face Sortformer model id.
        rttm_out_dir: Optional directory for RTTM output files.
        executor: Executor backend name (e.g. "xenna", "ray_data").
        **kwargs: Extra CLI arguments, accepted and ignored.

    Returns:
        Dict with "params", "metrics", and the raw output "tasks".
    """
    logger.info("Starting audio Sortformer diarization benchmark")
    logger.info(f"Executor: {executor}")
    logger.info(f"Model: {model_name}")
    logger.info(f"Manifest: {manifest_path}")

    pipeline = Pipeline(
        name="audio_sortformer_diarization",
        description="Streaming Sortformer speaker diarization inference.",
    )
    # Reader feeds one task per manifest entry into the diarization stage.
    for stage in (
        ALMManifestReader(manifest_path=manifest_path),
        InferenceSortformerStage(model_name=model_name, rttm_out_dir=rttm_out_dir),
    ):
        pipeline.add_stage(stage)

    backend = setup_executor(executor)
    start = time.perf_counter()
    output_tasks = pipeline.run(backend)
    wall_s = time.perf_counter() - start

    metrics = _collect_diarization_metrics(output_tasks, wall_s)

    logger.success(
        f"Benchmark completed: {metrics['num_files_processed']} files in {wall_s:.1f}s "
        f"(RTF={metrics['real_time_factor']:.3f}, {metrics['throughput_files_per_sec']:.2f} files/sec)"
    )

    return {
        "params": {
            "executor": executor,
            "manifest_path": manifest_path,
            "model_name": model_name,
            "rttm_out_dir": rttm_out_dir,
        },
        "metrics": metrics,
        "tasks": output_tasks,
    }


def main() -> int:
    """CLI entry point: parse arguments, run the benchmark, persist results.

    Results (including partial/failed runs) are always written to the
    benchmark results path so the nightly sink can pick them up.

    Returns:
        0 on success, 1 on failure (exception raised or no files processed).
    """
    parser = argparse.ArgumentParser(description="Audio Sortformer diarization benchmark for nightly benchmarking")
    parser.add_argument("--benchmark-results-path", required=True, help="Path to benchmark results")
    parser.add_argument("--manifest-path", required=True, help="Path to input JSONL manifest")
    parser.add_argument(
        "--model-name",
        default="nvidia/diar_streaming_sortformer_4spk-v2.1",
        help="HF Sortformer model id",
    )
    parser.add_argument("--executor", default="xenna", choices=["xenna", "ray_data"], help="Executor to use")
    parser.add_argument("--rttm-out-dir", default=None, help="Optional directory to write RTTM output files")

    args = parser.parse_args()

    logger.info("=== Audio Sortformer Diarization Benchmark Starting ===")
    logger.info(f"Arguments: {vars(args)}")

    success_code = 1
    # Pre-populate a failure-shaped result so a crash still writes valid output.
    result_dict: dict[str, Any] = {
        "params": vars(args),
        "metrics": {
            "is_success": False,
        },
        "tasks": [],
    }
    try:
        result_dict.update(run_audio_sortformer_benchmark(**vars(args)))
        success_code = 0 if result_dict["metrics"]["is_success"] else 1
    except Exception as e:
        error_traceback = traceback.format_exc()
        logger.error(f"Benchmark failed: {e}")
        logger.debug(f"Full traceback:\n{error_traceback}")
    finally:
        write_benchmark_results(result_dict, args.benchmark_results_path)
    # Return OUTSIDE the finally block: a `return` inside `finally` silently
    # swallows any in-flight BaseException (e.g. KeyboardInterrupt, SystemExit)
    # that the `except Exception` clause does not catch (ruff B012).
    return success_code


if __name__ == "__main__":
    # Propagate the benchmark status (0 = success, 1 = failure) as the
    # process exit code for the nightly harness.
    raise SystemExit(main())
63 changes: 45 additions & 18 deletions nemo_curator/stages/audio/inference/sortformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,17 +87,18 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]):

Uses the NeMo SortformerEncLabelModel for end-to-end neural speaker
diarization with streaming support. See:
https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2
https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1

Args:
model_name: Hugging Face model id. Defaults to "nvidia/diar_streaming_sortformer_4spk-v2".
model_name: Hugging Face model id. Defaults to "nvidia/diar_streaming_sortformer_4spk-v2.1".
model_path: Local path to a .nemo checkpoint file; if set, takes precedence over model_name.
cache_dir: Directory for caching downloaded model weights. Defaults to HF hub default.
diar_model: Pre-loaded SortformerEncLabelModel; if provided, setup() is a no-op.
filepath_key: Key in data for path to audio file. Defaults to "audio_filepath".
diar_segments_key: Key in output data for diarization segments list. Defaults to "diar_segments".
rttm_out_dir: Optional directory to write RTTM files. Defaults to None.
chunk_len: Streaming chunk size in 80 ms frames. Defaults to 340 (~30.4 s latency).
chunk_left_context: Left context frames. Defaults to 1.
chunk_right_context: Right context frames. Defaults to 40.
fifo_len: FIFO queue size in frames. Defaults to 40.
spkcache_update_period: Speaker cache update period in frames. Defaults to 300.
Expand All @@ -106,14 +107,15 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]):
name: Stage name. Defaults to "Sortformer_inference".
"""

model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"
model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2.1"
model_path: str | None = None
cache_dir: str | None = None
diar_model: Any | None = None
filepath_key: str = "audio_filepath"
diar_segments_key: str = "diar_segments"
rttm_out_dir: str | None = None
chunk_len: int = 340
chunk_left_context: int = 1
chunk_right_context: int = 40
fifo_len: int = 40
spkcache_update_period: int = 300
Expand All @@ -126,42 +128,69 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]):
def setup_on_node(
self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None
) -> None:
"""Pre-download model weights on the node so actors load from cache."""
"""Pre-download model weights on the node so workers load from cache."""
if self.model_path is not None:
return
try:
repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir)
nemo_files = [f for f in os.listdir(repo_dir) if f.endswith(".nemo")]
if nemo_files:
self.model_path = os.path.join(repo_dir, nemo_files[0])
else:
logger.warning(f"No .nemo file found in {repo_dir}; setup() will fail")
except Exception: # noqa: BLE001
logger.info(f"Could not pre-cache {self.model_name}; actors will download on first use")
snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir)

def _resolve_model_path(self) -> str:
"""Resolve the path to the .nemo checkpoint from the HF cache."""
if self.model_path is not None:
return self.model_path
repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir)
nemo_files = sorted(f for f in os.listdir(repo_dir) if f.endswith(".nemo"))
if not nemo_files:
msg = f"No .nemo file found in {repo_dir} for model {self.model_name}"
raise FileNotFoundError(msg)
return os.path.join(repo_dir, nemo_files[0])

def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None:
    """Load Sortformer model from Hugging Face or a local .nemo file.

    If a pre-loaded model was injected via `diar_model`, loading is skipped;
    in every case the model is switched to eval mode, streaming parameters
    are applied, and the positional encoding is extended for long audio.
    """
    if self.diar_model is None:
        self.diar_model = SortformerEncLabelModel.restore_from(
            restore_path=self._resolve_model_path(),
            map_location="cuda",
            strict=False,
        )

    # Same post-load configuration whether the model was injected or restored.
    self.diar_model.eval()
    self._configure_streaming()
    self._extend_pos_enc_for_long_audio()

def _extend_pos_enc_for_long_audio(self, max_len: int = 30000) -> None:
    """Extend RelPositionalEncoding buffer to handle long audio files.

    NeMo's streaming Sortformer initialises pos_enc sized for one chunk (~35
    conformer frames). Files longer than a few seconds overflow it at inference
    time. extend_pe() is a NeMo method that resizes the buffer safely — it just
    isn't called automatically. max_len=30000 covers ~1000 s at any subsampling.

    Args:
        max_len: Target positional-encoding length in frames.
    """
    pos_enc = getattr(getattr(self.diar_model, "encoder", None), "pos_enc", None)
    if pos_enc is None or not hasattr(pos_enc, "extend_pe"):
        logger.warning("pos_enc not found or no extend_pe method — skipping extension")
        return
    # extend_pe allocates on the model's device/dtype, so look them up once.
    params = next(self.diar_model.parameters())
    try:
        pos_enc.extend_pe(max_len, params.device, params.dtype)
        # Interpolate the actual requested length; the previous message
        # hard-coded a stale constant (33,314) that did not match max_len.
        logger.info(f"Extended encoder pos_enc to max_len={max_len} for long-form audio")
    except Exception as e:  # noqa: BLE001
        logger.warning(f"Could not extend pos_enc: {e}")

def _configure_streaming(self) -> None:
"""Apply streaming configuration to the loaded model."""
sm = self.diar_model.sortformer_modules
sm.chunk_len = self.chunk_len
sm.chunk_right_context = self.chunk_right_context
sm.fifo_len = self.fifo_len
sm.spkcache_update_period = self.spkcache_update_period
sm.chunk_left_context = self.chunk_left_context
if hasattr(sm, "spkcache_update_period"):
sm.spkcache_update_period = self.spkcache_update_period
sm.spkcache_len = self.spkcache_len
Comment on lines -164 to 194
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Following up on some greptile comments, why is there a hasattr guard for spkcache_update_period? Also should spkcache_len have a guard too?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Older NeMo versions don't have spkcache_update_period on SortformerModules, without the guard it crashes.


def inputs(self) -> tuple[list[str], list[str]]:
Expand Down Expand Up @@ -189,9 +218,7 @@ def process(self, task: AudioTask) -> AudioTask:

file_path = task.data[self.filepath_key]
sess_name = task.data.get("session_name")
resolved_sess_name = (
sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0]
)
resolved_sess_name = sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0]

all_segments = self.diarize([file_path])
segments = all_segments[0]
Expand Down
11 changes: 6 additions & 5 deletions tutorials/audio/callhome_diar/README.md
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
# Speaker Diarization on CallHome English with NeMo Curator

This tutorial runs [Streaming Sortformer](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2) speaker diarization on the [CallHome English](https://catalog.ldc.upenn.edu/LDC97S42) dataset using NeMo Curator's `InferenceSortformerStage`, then evaluates Diarization Error Rate (DER).
This tutorial runs [Streaming Sortformer](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1) speaker diarization on the [CallHome English](https://catalog.ldc.upenn.edu/LDC97S42) dataset using NeMo Curator's `InferenceSortformerStage`, then evaluates Diarization Error Rate (DER).

Inference runs in parallel via `Pipeline` + `XennaExecutor` for high throughput.

## Prerequisites

- Python 3.10+
- NeMo Curator installed (see [installation guide](https://docs.nvidia.com/nemo/curator/latest/admin/installation.html))
- [`sox`](https://sox.sourceforge.net/) command-line tool (for stereo-to-mono conversion; install via `apt install sox`, `brew install sox`, or `conda install -c conda-forge sox`)
- [`ffmpeg`](https://ffmpeg.org/) command-line tool (for stereo-to-mono conversion; pre-installed in the NeMo Curator container)
- CallHome English dataset with `.wav` files and `eng/*.cha` ground-truth annotations

### Dataset layout
Expand Down Expand Up @@ -51,7 +51,7 @@ Key arguments:
| `--output-dir` | `output` | Root for RTTM files, results JSON, and checkpoints |
| `--collar` | `0.25` | Collar tolerance (seconds) for DER scoring |
| `--clean` | off | Remove entire output directory before re-running |
| `--model` | `nvidia/diar_streaming_sortformer_4spk-v2` | Hugging Face model id |
| `--model` | `nvidia/diar_streaming_sortformer_4spk-v2.1` | Hugging Face model id |

### Streaming configuration

Expand All @@ -67,7 +67,7 @@ All values are in **80 ms frames**. Override via `--chunk-len`, `--chunk-right-c
## What the script does

1. **File discovery (`CallHomeReaderStage`)** — Scans the dataset directory for WAV files with matching `.cha` annotations, skipping already-processed files. Emits one `AudioTask` per file.
2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `sox` so the model sees both speakers.
2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `ffmpeg` so the model sees both speakers.
3. **Diarization inference (`InferenceSortformerStage`)** — Runs Streaming Sortformer on each mono file. Also writes RTTM files to `<output-dir>/rttm/`.
4. **DER evaluation (`DERComputationStage`)** — Compares predicted segments against CHA ground truth. Scoring is restricted to the UEM region (min/max annotated timestamps from CHA) with a configurable collar tolerance (default 0.25 s).

Expand Down Expand Up @@ -102,7 +102,7 @@ pipeline = Pipeline(
stages=[
MyAudioReaderStage(data_dir="/path/to/audio"), # your reader stage
InferenceSortformerStage(
model_name="nvidia/diar_streaming_sortformer_4spk-v2",
model_name="nvidia/diar_streaming_sortformer_4spk-v2.1",
rttm_out_dir="./rttm",
),
],
Expand All @@ -122,3 +122,4 @@ results = pipeline.run(executor=XennaExecutor())
- Maximum 4 speakers per recording
- Trained primarily on English speech
- Performance may degrade on noisy or very long recordings
- Audio must be mono 16 kHz; running on raw stereo or narrow-band (8 kHz) files without proper conversion will produce very high false-alarm rates
Loading
Loading