Skip to content

Commit cdf2ea4

Browse files
committed
Add stable timestamps module and verification scripts
1 parent 21411d8 commit cdf2ea4

15 files changed

+2824
-18
lines changed

examples/cli/cli.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,7 @@ struct whisper_params {
105105

106106
// Voice Activity Detection (VAD) parameters
107107
bool vad = false;
108+
bool stable_timestamps = false;
108109
std::string vad_model = "";
109110
float vad_threshold = 0.5f;
110111
int vad_min_speech_duration_ms = 250;
@@ -210,6 +211,7 @@ static bool whisper_params_parse(int argc, char ** argv, whisper_params & params
210211
else if ( arg == "--grammar-penalty") { params.grammar_penalty = std::stof(ARGV_NEXT); }
211212
// Voice Activity Detection (VAD)
212213
else if ( arg == "--vad") { params.vad = true; }
214+
else if ( arg == "--stable-timestamps") { params.stable_timestamps = true; }
213215
else if (arg == "-vm" || arg == "--vad-model") { params.vad_model = ARGV_NEXT; }
214216
else if (arg == "-vt" || arg == "--vad-threshold") { params.vad_threshold = std::stof(ARGV_NEXT); }
215217
else if (arg == "-vspd" || arg == "--vad-min-speech-duration-ms") { params.vad_min_speech_duration_ms = std::stoi(ARGV_NEXT); }
@@ -293,6 +295,7 @@ static void whisper_print_usage(int /*argc*/, char ** argv, const whisper_params
293295
// Voice Activity Detection (VAD) parameters
294296
fprintf(stderr, "\nVoice Activity Detection (VAD) options:\n");
295297
fprintf(stderr, " --vad [%-7s] enable Voice Activity Detection (VAD)\n", params.vad ? "true" : "false");
298+
fprintf(stderr, " --stable-timestamps [%-7s] enable stable timestamps (requires --vad-model)\n", params.stable_timestamps ? "true" : "false");
296299
fprintf(stderr, " -vm FNAME, --vad-model FNAME [%-7s] VAD model path\n", params.vad_model.c_str());
297300
fprintf(stderr, " -vt N, --vad-threshold N [%-7.2f] VAD threshold for speech recognition\n", params.vad_threshold);
298301
fprintf(stderr, " -vspd N, --vad-min-speech-duration-ms N [%-7d] VAD min speech duration (0.0-1.0)\n", params.vad_min_speech_duration_ms);
@@ -1002,6 +1005,12 @@ int main(int argc, char ** argv) {
10021005
exit(0);
10031006
}
10041007

1008+
if (params.stable_timestamps && params.vad_model.empty()) {
1009+
fprintf(stderr, "error: --stable-timestamps requires --vad-model\n");
1010+
whisper_print_usage(argc, argv, params);
1011+
return 2;
1012+
}
1013+
10051014
if (params.no_prints) {
10061015
whisper_log_set(cb_log_disable, NULL);
10071016
}
@@ -1211,6 +1220,7 @@ int main(int argc, char ** argv) {
12111220

12121221
wparams.suppress_nst = params.suppress_nst;
12131222

1223+
wparams.stable_timestamps = params.stable_timestamps;
12141224
wparams.vad = params.vad;
12151225
wparams.vad_model_path = params.vad_model.c_str();
12161226

include/whisper.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,10 @@ extern "C" {
583583
size_t i_start_rule;
584584
float grammar_penalty;
585585

586+
// Stable timestamps - snap word boundaries to speech edges using VAD
587+
// Requires vad_model_path to be set. Forces vad=true, token_timestamps=true, max_initial_ts=0.
588+
bool stable_timestamps;
589+
586590
// Voice Activity Detection (VAD) params
587591
bool vad; // Enable VAD
588592
const char * vad_model_path; // Path to VAD model
Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
# Stable Timestamps - How stable-ts Works
2+
3+
Reference repo: https://github.com/jianfch/stable-ts
4+
5+
## Overview
6+
7+
stable-ts improves Whisper's word-level timestamps with near-zero performance cost. The core idea: Whisper gives rough timestamps, then stable-ts clips them to where sound actually exists. No model weights are changed.
8+
9+
## The 5 Mechanisms
10+
11+
### 1. Post-Hoc Silence Snapping (main workhorse, always on)
12+
13+
**Files:** `stabilization/__init__.py`, `stabilization/nonvad.py`, `stabilization/silero_vad.py`
14+
15+
After Whisper produces timestamps, every word boundary is checked against a silence map and snapped to speech edges.
16+
17+
**Silence map construction (non-VAD mode):**
18+
1. `abs(waveform)` -> normalize by 99.9th percentile
19+
2. Interpolate down to one value per audio token position (320 samples per token at 16kHz)
20+
3. Average-pool with kernel size 5 (reflection padding) to smooth
21+
4. Quantize: `mask = (mask * 20).round()` -> anything rounding to 0 = silent
22+
5. Convert boolean mask to start/end silence timing arrays
23+
24+
**Snapping logic:**
25+
- If word.start falls inside silence -> move start to silence_end
26+
- If word.end falls inside silence -> move end to silence_start
27+
- If silence is contained within a word -> snap the boundary with less "error" (ratio of overshoot vs silence duration, threshold 10%)
28+
- First word in segment: prefer keeping end (snap start forward)
29+
- Last word in segment: prefer keeping start (snap end backward)
30+
- Minimum word duration is enforced during snapping
31+
32+
### 2. Better Cross-Attention / DTW Alignment
33+
34+
**File:** `timing.py`
35+
36+
Four improvements to how word timestamps are extracted from cross-attention:
37+
38+
**a) Gap padding:**
39+
Prepend `" ..."` tokens before each segment's tokens in DTW. This absorbs early cross-attention energy that would otherwise cause timestamps to start too early.
40+
41+
**b) Dynamic head selection (`dynamic_heads`):**
42+
Instead of hardcoded `model.alignment_heads`, score ALL attention heads by how monotonically their peaks track the DTW path. Select best k (default 6) per token. Can run multiple iterations where each pass refines head selection using previous DTW result.
43+
44+
**c) `max_initial_timestamp=None`:**
45+
Vanilla Whisper forces first timestamp <= 1s. stable-ts removes this constraint so speech starting later in a 30s chunk isn't forced early.
46+
47+
**d) New alignment algorithm (`aligner='new'`, from arXiv:2509.09987):**
48+
Score all (layer, head) pairs by column-norm and row-norm of attention matrix. Select top-k (default 20) globally, normalize each by column norm, average, then DTW.
49+
50+
### 3. Constrained Decoding (opt-in, off by default)
51+
52+
**File:** `decode.py`
53+
54+
Subclasses Whisper's `DecodingTask`. During token sampling, timestamp tokens corresponding to silent audio regions are set to `-inf`. The decoder literally cannot predict a timestamp in silence.
55+
56+
```
57+
ts_logits[:, ts_token_mask] = -inf
58+
```
59+
60+
Controlled by `suppress_ts_tokens=True` (defaults to `False`).
61+
62+
Also caches audio features across temperature fallbacks (vanilla Whisper re-encodes mel each time).
63+
64+
### 4. Binary-Search Refinement (optional, expensive)
65+
66+
**File:** `non_whisper/refinement.py`
67+
68+
Called explicitly via `model.refine()`. For each word boundary:
69+
1. Progressively mute audio inward from the boundary
70+
2. Run inference, monitor token probability
71+
3. If probability holds -> mute more (boundary can be tighter)
72+
4. If probability drops -> restore (speech is there)
73+
5. Binary search converges to latest-possible-start / earliest-possible-end
74+
75+
Precision ~0.1s default. Runs inference dozens of times per word - slow but optional.
76+
77+
### 5. Hallucination Filtering
78+
79+
**File:** `whisper_word_level/original_whisper.py`
80+
81+
- Segments with >50% zero-duration words -> discarded
82+
- Segments below avg probability threshold -> discarded
83+
- Entirely silent 30s chunks -> skipped without running decoder
84+
- Long silence gaps within chunks -> audio truncated to prevent hallucinated text
85+
- Punctuation-only segments -> deleted
86+
87+
## Cost Summary
88+
89+
| Mechanism | Speed Cost | Always On? | Benefit |
90+
|-----------|-----------|------------|---------|
91+
| Silence snapping | ~0 | Yes | 60% of improvement |
92+
| Better DTW (gap padding, dynamic heads) | ~0 | Yes | 20% of improvement |
93+
| Hallucination filtering | ~0 | Yes | Cleaner output |
94+
| Constrained decoding | ~0 | No (opt-in) | Prevents silent timestamps |
95+
| Binary-search refinement | Very high | No (explicit call) | Tightest possible boundaries |
96+
97+
## What to Port to whisper.cpp
98+
99+
**Priority 1 (easy, high impact):** Post-hoc silence snapping. ~100 lines of C. No model changes needed. Just audio analysis + timestamp adjustment on existing output.
100+
101+
**Priority 2 (medium effort):** Gap padding in DTW step. Requires touching `whisper_exp_compute_token_level_timestamps()`.
102+
103+
**Priority 3 (medium effort):** Dynamic attention head selection. whisper.cpp already extracts cross-attention for DTW. Need to expose all heads and score them.
104+
105+
**Priority 4 (low priority):** Constrained decoding. Invasive to sampling loop.
106+
107+
**Priority 5 (skip):** Binary-search refinement. Too expensive, wrong fit for whisper.cpp's use case.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
# Stable Timestamps - How whisper.cpp Works (Relevant Internals)
2+
3+
## Codebase Structure
4+
5+
- `include/whisper.h` (741 lines) -- Public C API
6+
- `src/whisper.cpp` (9016 lines) -- Entire implementation in one file
7+
- `src/whisper-arch.h` -- Tensor name maps (encoder/decoder/VAD)
8+
- `ggml/` -- Tensor library backend
9+
- `examples/cli/cli.cpp` -- Main CLI
10+
11+
## Key Data Structures (all in `src/whisper.cpp`)
12+
13+
### Token Data (`whisper_token_data`, whisper.h:131)
14+
```c
15+
id, tid (timestamp token), p (probability), plog, pt (timestamp prob),
16+
ptsum (sum of timestamp probs), t0/t1, t_dtw, vlen (voice length)
17+
```
18+
19+
### Segment (`whisper_segment`, line 460)
20+
```c
21+
t0, t1, text, no_speech_prob, tokens (vector<whisper_token_data>)
22+
```
23+
24+
### State (`whisper_state`, line 834)
25+
Holds: `mel`, `kv_self/kv_cross`, `decoders[8]`, `result_all` (segments), `energy` (PCM signal energy), `aheads_masks`, `aheads_cross_QKs`, `vad_context/segments/mapping`
26+
27+
## Decoding Pipeline
28+
29+
Entry: **`whisper_full_with_state()`** at line 6805
30+
31+
1. **PCM -> Mel** (line 6818): `whisper_pcm_to_mel_with_state()` -- FFT + mel filterbank, 80 bands, hop=160 (10ms/frame)
32+
2. **Signal energy** (line 6847): `get_signal_energy(samples, n_samples, 32)` -- smoothed abs amplitude for token timestamps
33+
3. **Main loop** (line 7012): `while(true)` over 30s chunks, advancing by `seek`
34+
4. **Encoder** (line 7033): `whisper_encode_internal()` -- conv + encoder + cross-attn KV cache
35+
5. **Prompt setup** (line 7098-7157): `[<prev>] + past + [<sot>] + [<lang>] + [<transcribe>]`
36+
6. **Token-by-token** (line 7197): `for (i = 0; i < n_max; ++i)` where `n_max = n_text_ctx/2 - 4`
37+
38+
### Logit Processing -- `whisper_process_logits()` at line 6155
39+
40+
This is WHERE ALL LOGIT FILTERING HAPPENS:
41+
42+
- **Line 6232**: `logits_filter_callback` -- user-supplied callback (external injection point)
43+
- **Line 6268-6308**: Timestamp pairing constraints (must come in pairs, must increase)
44+
- **Line 6291-6298**: `max_initial_ts` constraint -- limits first timestamp to <= 1.0s
45+
- **stable-ts removes this** by setting it to `None`
46+
- whisper.cpp param: `params.max_initial_ts` (default 1.0f, line 5950)
47+
- **Line 6300-6308**: Increasing timestamp enforcement via `decoder.seek_delta/2`
48+
- **Line 6314-6365**: Force timestamp when `sum(ts_probs) > max(text_probs)`
49+
50+
**INJECTION POINT for constrained decoding:** Between lines 6300-6308 (after increasing-ts check), add `logits[token_beg + t] = -INFINITY` for silent positions. Or use the existing `logits_filter_callback` externally.
51+
52+
### Sampling -- `whisper_sample_token()` at line 6438
53+
Greedy: argmax. Also computes `tid` (best timestamp), `pt` (timestamp prob), `ptsum` (sum timestamp probs).
54+
55+
## Word-Level Timestamps
56+
57+
### Method 1: Non-DTW (simpler, existing)
58+
59+
**`whisper_exp_compute_token_level_timestamps()`** at line 8433
60+
61+
1. Uses `state.energy` (smoothed PCM amplitude)
62+
2. Confident timestamps from `token.tid` when `pt > thold_pt && ptsum > thold_ptsum`
63+
3. Fills gaps by proportional splitting based on `vlen`
64+
4. **Energy-based refinement** (lines 8563-8631): Expands/contracts token boundaries using signal energy. This is a PRIMITIVE form of silence snapping already present -- but crude.
65+
66+
### Method 2: DTW (experimental, more accurate)
67+
68+
**`whisper_exp_compute_token_level_timestamps_dtw()`** at line 8815
69+
70+
1. Build token sequence: `[sot] + [lang] + [no_timestamps] + all_text_tokens + [eot]`
71+
2. Full decoder pass with `save_alignment_heads_QKs=true`
72+
3. Copy cross-attention QKs to CPU: shape `[n_tokens, n_audio_tokens, n_heads]`
73+
4. Normalize (line 8907)
74+
5. Median filter width 7 over audio dimension (line 8914)
75+
6. **Mean across heads** (line 8919) -- all selected heads weighted equally
76+
7. Scale by -1 (line 8920)
77+
8. Standard DTW + backtrace via `dtw_and_backtrace()` (line 8690)
78+
9. Assign timestamps from DTW path (lines 8940-8963)
79+
80+
**IMPORTANT:** DTW does NOT work with `flash_attn=true` (line 3708-3710) because flash attention doesn't expose intermediate attention weights.
81+
82+
Called at lines 7725-7728 after all segments created for a 30s window.
83+
84+
### Alignment Heads -- Hardcoded (lines 384-409)
85+
86+
```c
87+
static const whisper_ahead g_aheads_large_v3[] = {
88+
{7,0}, {10,17}, {12,18}, {13,12}, {16,1}, {17,14}, {19,11}, {21,4}, {24,1}, {25,6}
89+
};
90+
static const whisper_ahead g_aheads_large_v3_turbo[] = {
91+
{2,4}, {2,11}, {3,3}, {3,6}, {3,11}, {3,14}
92+
};
93+
```
94+
95+
Selected via `get_alignment_heads_by_layer()` (line 8666). Modes: preset-specific, N-top-most layers, or custom user-provided heads.
96+
97+
Masks built in `aheads_masks_init()` (line 1160), used during decoder graph construction at lines 2720-2734 in the cross-attention block.
98+
99+
### WHERE TO ADD IMPROVEMENTS:
100+
101+
**Gap padding:** In DTW function at line 8843-8860 when building token sequence. Insert `" ..."` tokens after `no_timestamps` but before text tokens. Adjust `sot_sequence_length`.
102+
103+
**Dynamic head selection:** At line 8919 (currently takes mean). Instead: score each head for monotonicity, select top-k, then average only those. Would need to expose all heads first (currently only preset heads captured).
104+
105+
## VAD Support (Already Exists!)
106+
107+
whisper.cpp has full Silero-style neural VAD:
108+
109+
- **`whisper_vad()`** at line 6621 -- called from `whisper_full()` when `params.vad == true`
110+
- Strips silence, concatenates speech segments with overlap
111+
- Builds `vad_mapping_table` to remap timestamps back to original audio
112+
- **Per-frame speech probabilities** available via `whisper_vad_probs()` API
113+
- Params: `threshold`, `min_speech_duration_ms`, `min_silence_duration_ms`, etc.
114+
115+
This is relevant because we could use the existing VAD probabilities as input for the silence mask instead of building our own loudness-based detector (or offer both options like stable-ts).
116+
117+
## Segment Creation & Output
118+
119+
### How Segments Are Created (lines 7616-7718)
120+
1. Scan tokens for timestamp tokens (`id > whisper_token_beg()`)
121+
2. Text between timestamps -> segment with `t0`, `t1`, text, tokens
122+
3. Pushed to `result_all`
123+
4. If `token_timestamps == true`: per-segment token timestamps computed
124+
5. If DTW enabled: DTW timestamps computed per-window after all segments
125+
126+
### WHERE TO HOOK POST-HOC SNAPPING:
127+
128+
**Option A -- Internal:** After DTW (line 7735) or after non-DTW token timestamps (lines 7663/7708), iterate all segments and snap word boundaries to speech edges using silence mask.
129+
130+
**Option B -- End of pipeline:** Before `whisper_full_with_state()` returns (line 7753), as a final pass over all `result_all`.
131+
132+
**Option C -- New public API:** `whisper_snap_timestamps(ctx, state)` that callers invoke after `whisper_full()`. Cleanest, non-invasive.
133+
134+
## Existing Energy-Based "Snapping" (Primitive)
135+
136+
Lines 8563-8631 in `whisper_exp_compute_token_level_timestamps()`:
137+
- Computes energy sum in token's time range
138+
- Expands/contracts boundaries based on energy threshold
139+
- Already exists but is crude compared to stable-ts
140+
141+
## Key Constants
142+
143+
| Constant | Value | Meaning |
144+
|----------|-------|---------|
145+
| `WHISPER_SAMPLE_RATE` | 16000 | Hz |
146+
| `WHISPER_HOP_LENGTH` | 160 | samples per mel frame = 10ms |
147+
| `WHISPER_CHUNK_SIZE` | 30 | seconds per chunk |
148+
| `WHISPER_N_FFT` | 400 | FFT window size |
149+
| Audio token resolution | 320 samples = 20ms | Each audio ctx position |
150+
| Timestamp token resolution | 20ms | Each increment of timestamp token |
151+
| `n_audio_ctx` | 1500 | Audio tokens per 30s chunk |
152+
| `n_text_ctx` | 448 | Max text tokens |
153+
154+
## Public API Surface (relevant)
155+
156+
```c
157+
// After transcription:
158+
whisper_full_n_segments(ctx)
159+
whisper_full_get_segment_t0/t1(ctx, i) // centiseconds (1 = 10ms)
160+
whisper_full_get_segment_text(ctx, i)
161+
whisper_full_n_tokens(ctx, i)
162+
whisper_full_get_token_data(ctx, i, j) // -> whisper_token_data
163+
whisper_full_get_segment_no_speech_prob(ctx, i)
164+
165+
// Params:
166+
params.token_timestamps // enable non-DTW word timestamps
167+
params.max_initial_ts // default 1.0s (stable-ts sets to 0)
168+
params.logits_filter_callback // can inject custom logit filters externally
169+
ctx_params.dtw_token_timestamps // enable DTW mode
170+
ctx_params.dtw_aheads_preset // which alignment heads
171+
params.vad // enable built-in VAD
172+
```

0 commit comments

Comments
 (0)