Skip to content

Commit 4b4b584

Browse files
Fix Sortformer tutorial issues and add InferenceSortformerStage benchmark (#1764)
* Fix Sortformer tutorial issues and add InferenceSortformerStage benchmark Signed-off-by: mmkrtchyan <mmkrtchyan@nvidia.com> * Address PR review feedback and update model to v2.1 Signed-off-by: mmkrtchyan <mmkrtchyan@nvidia.com> * Address second round of review feedback Signed-off-by: mmkrtchyan <mmkrtchyan@nvidia.com> * fixing some more issues Signed-off-by: mmkrtchyan <mmkrtchyan@nvidia.com> --------- Signed-off-by: mmkrtchyan <mmkrtchyan@nvidia.com> Co-authored-by: Sarah Yurick <53962159+sarahyurick@users.noreply.github.com>
1 parent f70af1f commit 4b4b584

5 files changed

Lines changed: 235 additions & 26 deletions

File tree

benchmarking/nightly-benchmark.yaml

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -631,6 +631,34 @@ entries:
631631
- metric: throughput_images_per_sec
632632
min_value: 3.0
633633

634+
- name: audio_sortformer_xenna
635+
enabled: false
636+
script: audio_sortformer_benchmark.py
637+
args: >-
638+
--benchmark-results-path={session_entry_dir}
639+
--manifest-path={datasets_path}/sortformer_diarization/manifest.jsonl
640+
--model-name=nvidia/diar_streaming_sortformer_4spk-v2.1
641+
--rttm-out-dir={session_entry_dir}/scratch/rttm
642+
timeout_s: 1800
643+
sink_data:
644+
- name: slack
645+
additional_metrics:
646+
- num_files_processed
647+
- throughput_files_per_sec
648+
- real_time_factor
649+
- total_segments_detected
650+
ping_on_failure:
651+
- U03C41SNADV # Aaftab V
652+
ray:
653+
num_cpus: 64
654+
num_gpus: 4
655+
enable_object_spilling: false
656+
requirements:
657+
- metric: is_success
658+
exact_value: true
659+
- metric: num_files_processed
660+
min_value: 1
661+
634662
- name: audio_fleurs_xenna
635663
enabled: true
636664
script: audio_fleurs_benchmark.py
Lines changed: 149 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Audio Sortformer diarization benchmarking script.
16+
17+
This script runs Streaming Sortformer diarization benchmarks with
18+
comprehensive metrics collection including real-time factor (RTF),
19+
per-file segment counts, and throughput.
20+
"""
21+
22+
import argparse
23+
import time
24+
import traceback
25+
from typing import Any
26+
27+
from loguru import logger
28+
from utils import setup_executor, write_benchmark_results
29+
30+
from nemo_curator.pipeline import Pipeline
31+
from nemo_curator.stages.audio.alm.alm_manifest_reader import ALMManifestReader
32+
from nemo_curator.stages.audio.inference.sortformer import InferenceSortformerStage
33+
34+
35+
def _collect_diarization_metrics(tasks: list, elapsed_s: float) -> dict[str, Any]:
36+
"""Extract diarization-specific metrics from output tasks."""
37+
num_files = len(tasks) if tasks else 0
38+
total_audio_duration_s = 0.0
39+
total_segments = 0
40+
41+
for task in tasks or []:
42+
data = task.data if hasattr(task, "data") else {}
43+
total_audio_duration_s += float(data.get("duration", 0))
44+
segments = data.get("diar_segments", [])
45+
total_segments += len(segments)
46+
47+
throughput = num_files / elapsed_s if elapsed_s > 0 else 0.0
48+
rtf = elapsed_s / total_audio_duration_s if total_audio_duration_s > 0 else 0.0
49+
50+
return {
51+
"is_success": num_files > 0,
52+
"num_files_processed": num_files,
53+
"exec_time_s": round(elapsed_s, 2),
54+
"total_audio_duration_s": round(total_audio_duration_s, 2),
55+
"total_segments_detected": total_segments,
56+
"real_time_factor": round(rtf, 4),
57+
"throughput_files_per_sec": round(throughput, 4),
58+
}
59+
60+
61+
def run_audio_sortformer_benchmark(
    manifest_path: str,
    model_name: str,
    rttm_out_dir: str | None = None,
    executor: str = "xenna",
    **kwargs,  # noqa: ARG001
) -> dict[str, Any]:
    """Run the audio Sortformer diarization benchmark and collect metrics."""
    logger.info("Starting audio Sortformer diarization benchmark")
    logger.info(f"Executor: {executor}")
    logger.info(f"Model: {model_name}")
    logger.info(f"Manifest: {manifest_path}")

    runner = setup_executor(executor)

    # Two-stage pipeline: read the JSONL manifest, then run diarization.
    diarization_pipeline = Pipeline(
        name="audio_sortformer_diarization",
        description="Streaming Sortformer speaker diarization inference.",
    )
    for stage in (
        ALMManifestReader(manifest_path=manifest_path),
        InferenceSortformerStage(
            model_name=model_name,
            rttm_out_dir=rttm_out_dir,
        ),
    ):
        diarization_pipeline.add_stage(stage)

    start = time.perf_counter()
    output_tasks = diarization_pipeline.run(runner)
    elapsed_s = time.perf_counter() - start

    metrics = _collect_diarization_metrics(output_tasks, elapsed_s)

    logger.success(
        f"Benchmark completed: {metrics['num_files_processed']} files in {elapsed_s:.1f}s "
        f"(RTF={metrics['real_time_factor']:.3f}, {metrics['throughput_files_per_sec']:.2f} files/sec)"
    )

    return {
        "params": {
            "executor": executor,
            "manifest_path": manifest_path,
            "model_name": model_name,
            "rttm_out_dir": rttm_out_dir,
        },
        "metrics": metrics,
        "tasks": output_tasks,
    }
109+
110+
111+
def main() -> int:
    """CLI entry point: run the benchmark and persist results.

    Returns:
        0 when the benchmark reports success, 1 otherwise. Results are
        written to ``--benchmark-results-path`` even when the run fails.
    """
    parser = argparse.ArgumentParser(description="Audio Sortformer diarization benchmark for nightly benchmarking")
    parser.add_argument("--benchmark-results-path", required=True, help="Path to benchmark results")
    parser.add_argument("--manifest-path", required=True, help="Path to input JSONL manifest")
    parser.add_argument(
        "--model-name",
        default="nvidia/diar_streaming_sortformer_4spk-v2.1",
        help="HF Sortformer model id",
    )
    parser.add_argument("--executor", default="xenna", choices=["xenna", "ray_data"], help="Executor to use")
    parser.add_argument("--rttm-out-dir", default=None, help="Optional directory to write RTTM output files")

    args = parser.parse_args()

    logger.info("=== Audio Sortformer Diarization Benchmark Starting ===")
    logger.info(f"Arguments: {vars(args)}")

    success_code = 1
    # Pre-populate a failure record so a crash still produces a results file.
    result_dict: dict[str, Any] = {
        "params": vars(args),
        "metrics": {
            "is_success": False,
        },
        "tasks": [],
    }
    try:
        result_dict.update(run_audio_sortformer_benchmark(**vars(args)))
        success_code = 0 if result_dict["metrics"]["is_success"] else 1
    except Exception as e:
        error_traceback = traceback.format_exc()
        logger.error(f"Benchmark failed: {e}")
        logger.debug(f"Full traceback:\n{error_traceback}")
    finally:
        write_benchmark_results(result_dict, args.benchmark_results_path)
    # Fix: return OUTSIDE the `finally` block. A `return` inside `finally`
    # silently swallows any in-flight exception (e.g. KeyboardInterrupt,
    # SystemExit) that the `except Exception` handler did not catch
    # (flake8-bugbear B012).
    return success_code
146+
147+
148+
if __name__ == "__main__":
    # Propagate the benchmark's exit status (0 = success, 1 = failure) to the shell.
    raise SystemExit(main())

nemo_curator/stages/audio/inference/sortformer.py

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -87,17 +87,18 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]):
8787
8888
Uses the NeMo SortformerEncLabelModel for end-to-end neural speaker
8989
diarization with streaming support. See:
90-
https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2
90+
https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1
9191
9292
Args:
93-
model_name: Hugging Face model id. Defaults to "nvidia/diar_streaming_sortformer_4spk-v2".
93+
model_name: Hugging Face model id. Defaults to "nvidia/diar_streaming_sortformer_4spk-v2.1".
9494
model_path: Local path to a .nemo checkpoint file; if set, takes precedence over model_name.
9595
cache_dir: Directory for caching downloaded model weights. Defaults to HF hub default.
9696
diar_model: Pre-loaded SortformerEncLabelModel; if provided, setup() is a no-op.
9797
filepath_key: Key in data for path to audio file. Defaults to "audio_filepath".
9898
diar_segments_key: Key in output data for diarization segments list. Defaults to "diar_segments".
9999
rttm_out_dir: Optional directory to write RTTM files. Defaults to None.
100100
chunk_len: Streaming chunk size in 80 ms frames. Defaults to 340 (~30.4 s latency).
101+
chunk_left_context: Left context frames. Defaults to 1.
101102
chunk_right_context: Right context frames. Defaults to 40.
102103
fifo_len: FIFO queue size in frames. Defaults to 40.
103104
spkcache_update_period: Speaker cache update period in frames. Defaults to 300.
@@ -106,14 +107,15 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]):
106107
name: Stage name. Defaults to "Sortformer_inference".
107108
"""
108109

109-
model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2"
110+
model_name: str = "nvidia/diar_streaming_sortformer_4spk-v2.1"
110111
model_path: str | None = None
111112
cache_dir: str | None = None
112113
diar_model: Any | None = None
113114
filepath_key: str = "audio_filepath"
114115
diar_segments_key: str = "diar_segments"
115116
rttm_out_dir: str | None = None
116117
chunk_len: int = 340
118+
chunk_left_context: int = 1
117119
chunk_right_context: int = 40
118120
fifo_len: int = 40
119121
spkcache_update_period: int = 300
@@ -126,42 +128,69 @@ class InferenceSortformerStage(ProcessingStage[AudioTask, AudioTask]):
126128
def setup_on_node(
    self, _node_info: NodeInfo | None = None, _worker_metadata: WorkerMetadata | None = None
) -> None:
    """Pre-download model weights on the node so workers load from cache."""
    # An explicit local checkpoint means there is nothing to fetch from the hub.
    if self.model_path is None:
        snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir)
135+
136+
def _resolve_model_path(self) -> str:
137+
"""Resolve the path to the .nemo checkpoint from the HF cache."""
138+
if self.model_path is not None:
139+
return self.model_path
140+
repo_dir = snapshot_download(repo_id=self.model_name, cache_dir=self.cache_dir)
141+
nemo_files = sorted(f for f in os.listdir(repo_dir) if f.endswith(".nemo"))
142+
if not nemo_files:
143+
msg = f"No .nemo file found in {repo_dir} for model {self.model_name}"
144+
raise FileNotFoundError(msg)
145+
return os.path.join(repo_dir, nemo_files[0])
141146

142147
def setup(self, _worker_metadata: WorkerMetadata | None = None) -> None:
    """Load Sortformer model from Hugging Face or a local .nemo file.

    If a pre-loaded model was injected via ``diar_model``, restoring is
    skipped; in either case the model is put in eval mode, streaming
    parameters are applied, and the positional-encoding buffer is extended
    for long-form audio.
    """
    if self.diar_model is None:
        # No injected model: restore from the resolved .nemo checkpoint.
        self.diar_model = SortformerEncLabelModel.restore_from(
            restore_path=self._resolve_model_path(),
            map_location="cuda",
            strict=False,
        )
    # Common post-load configuration — previously duplicated in both the
    # injected-model and restored-model branches; hoisted here (DRY).
    self.diar_model.eval()
    self._configure_streaming()
    self._extend_pos_enc_for_long_audio()
165+
166+
def _extend_pos_enc_for_long_audio(self, max_len: int = 30000) -> None:
    """Extend RelPositionalEncoding buffer to handle long audio files.

    NeMo's streaming Sortformer initialises pos_enc sized for one chunk (~35
    conformer frames). Files longer than a few seconds overflow it at inference
    time. extend_pe() is a NeMo method that resizes the buffer safely — it just
    isn't called automatically. max_len=30000 covers ~1000 s at any subsampling.

    Args:
        max_len: Target positional-encoding length in conformer frames.
    """
    pos_enc = getattr(getattr(self.diar_model, "encoder", None), "pos_enc", None)
    if pos_enc is None or not hasattr(pos_enc, "extend_pe"):
        logger.warning("pos_enc not found or no extend_pe method — skipping extension")
        return
    params = next(self.diar_model.parameters())
    try:
        pos_enc.extend_pe(max_len, params.device, params.dtype)
        # Fix: log the value actually requested. The old message hardcoded
        # "max_len=35,890" inside an f-string with no placeholder, which
        # contradicted the real parameter (default 30000).
        logger.info(f"Extended encoder pos_enc to max_len={max_len} for long-form audio")
    except Exception as e:  # noqa: BLE001
        logger.warning(f"Could not extend pos_enc: {e}")
157184

158185
def _configure_streaming(self) -> None:
159186
"""Apply streaming configuration to the loaded model."""
160187
sm = self.diar_model.sortformer_modules
161188
sm.chunk_len = self.chunk_len
162189
sm.chunk_right_context = self.chunk_right_context
163190
sm.fifo_len = self.fifo_len
164-
sm.spkcache_update_period = self.spkcache_update_period
191+
sm.chunk_left_context = self.chunk_left_context
192+
if hasattr(sm, "spkcache_update_period"):
193+
sm.spkcache_update_period = self.spkcache_update_period
165194
sm.spkcache_len = self.spkcache_len
166195

167196
def inputs(self) -> tuple[list[str], list[str]]:
@@ -189,9 +218,7 @@ def process(self, task: AudioTask) -> AudioTask:
189218

190219
file_path = task.data[self.filepath_key]
191220
sess_name = task.data.get("session_name")
192-
resolved_sess_name = (
193-
sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0]
194-
)
221+
resolved_sess_name = sess_name if sess_name is not None else os.path.splitext(os.path.basename(file_path))[0]
195222

196223
all_segments = self.diarize([file_path])
197224
segments = all_segments[0]

tutorials/audio/callhome_diar/README.md

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
# Speaker Diarization on CallHome English with NeMo Curator
22

3-
This tutorial runs [Streaming Sortformer](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2) speaker diarization on the [CallHome English](https://catalog.ldc.upenn.edu/LDC97S42) dataset using NeMo Curator's `InferenceSortformerStage`, then evaluates Diarization Error Rate (DER).
3+
This tutorial runs [Streaming Sortformer](https://huggingface.co/nvidia/diar_streaming_sortformer_4spk-v2.1) speaker diarization on the [CallHome English](https://catalog.ldc.upenn.edu/LDC97S42) dataset using NeMo Curator's `InferenceSortformerStage`, then evaluates Diarization Error Rate (DER).
44

55
Inference runs in parallel via `Pipeline` + `XennaExecutor` for high throughput.
66

77
## Prerequisites
88

99
- Python 3.10+
1010
- NeMo Curator installed (see [installation guide](https://docs.nvidia.com/nemo/curator/latest/admin/installation.html))
11-
- [`sox`](https://sox.sourceforge.net/) command-line tool (for stereo-to-mono conversion; install via `apt install sox`, `brew install sox`, or `conda install -c conda-forge sox`)
11+
- [`ffmpeg`](https://ffmpeg.org/) command-line tool (for stereo-to-mono conversion; pre-installed in the NeMo Curator container)
1212
- CallHome English dataset with `.wav` files and `eng/*.cha` ground-truth annotations
1313

1414
### Dataset layout
@@ -51,7 +51,7 @@ Key arguments:
5151
| `--output-dir` | `output` | Root for RTTM files, results JSON, and checkpoints |
5252
| `--collar` | `0.25` | Collar tolerance (seconds) for DER scoring |
5353
| `--clean` | off | Remove entire output directory before re-running |
54-
| `--model` | `nvidia/diar_streaming_sortformer_4spk-v2` | Hugging Face model id |
54+
| `--model` | `nvidia/diar_streaming_sortformer_4spk-v2.1` | Hugging Face model id |
5555

5656
### Streaming configuration
5757

@@ -67,7 +67,7 @@ All values are in **80 ms frames**. Override via `--chunk-len`, `--chunk-right-c
6767
## What the script does
6868

6969
1. **File discovery (`CallHomeReaderStage`)** — Scans the dataset directory for WAV files with matching `.cha` annotations, skipping already-processed files. Emits one `AudioTask` per file.
70-
2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `sox` so the model sees both speakers.
70+
2. **Mono conversion (`EnsureMonoStage`)** — CallHome WAVs are stereo (one channel per speaker). This stage downmixes to mono 16 kHz via `ffmpeg` so the model sees both speakers.
7171
3. **Diarization inference (`InferenceSortformerStage`)** — Runs Streaming Sortformer on each mono file. Also writes RTTM files to `<output-dir>/rttm/`.
7272
4. **DER evaluation (`DERComputationStage`)** — Compares predicted segments against CHA ground truth. Scoring is restricted to the UEM region (min/max annotated timestamps from CHA) with a configurable collar tolerance (default 0.25 s).
7373

@@ -102,7 +102,7 @@ pipeline = Pipeline(
102102
stages=[
103103
MyAudioReaderStage(data_dir="/path/to/audio"), # your reader stage
104104
InferenceSortformerStage(
105-
model_name="nvidia/diar_streaming_sortformer_4spk-v2",
105+
model_name="nvidia/diar_streaming_sortformer_4spk-v2.1",
106106
rttm_out_dir="./rttm",
107107
),
108108
],
@@ -122,3 +122,4 @@ results = pipeline.run(executor=XennaExecutor())
122122
- Maximum 4 speakers per recording
123123
- Trained primarily on English speech
124124
- Performance may degrade on noisy or very long recordings
125+
- Audio must be mono 16 kHz; running on raw stereo or narrow-band (8 kHz) files without proper conversion will produce very high false-alarm rates

0 commit comments

Comments
 (0)