Skip to content

Commit c8ae48a

Browse files
buxukulinxiaodong
authored and committed
addon.node : support cancelling transcription via AbortSignal
1 parent 43d78af commit c8ae48a

3 files changed

Lines changed: 198 additions & 4 deletions

File tree

examples/addon.node/README.md

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,39 @@ Run the VAD example with performance comparison:
4444
node vad-example.js
4545
```
4646

47+
### Cancellation Usage
48+
49+
Run the cancellation example (cancels an in-flight transcription via `AbortSignal`):
50+
51+
```shell
52+
node cancel-example.js
53+
```
54+
55+
## Cancelling a transcription
56+
57+
An in-flight transcription can be cancelled by passing an `AbortSignal` as the `signal` parameter:
58+
59+
```javascript
60+
const ac = new AbortController();
61+
62+
const promise = whisperAsync({
63+
// ... other params ...
64+
signal: ac.signal,
65+
});
66+
67+
// cancel at any time
68+
ac.abort();
69+
70+
const result = await promise;
71+
// result.cancelled === true
72+
// result.transcription contains the segments transcribed before cancellation
73+
```
74+
75+
Cancellation is checked before each encoder run and before each ggml graph
76+
computation, so it usually takes effect within a fraction of a second.
77+
The promise resolves normally (it does not reject): `result.cancelled` is `true`
78+
and `result.transcription` contains the segments completed before the abort.
79+
4780
## Voice Activity Detection (VAD) Support
4881

4982
VAD can significantly improve transcription performance by only processing speech segments, which is especially beneficial for audio files with long periods of silence.
@@ -112,4 +145,5 @@ Both traditional whisper.cpp parameters and new VAD parameters are supported:
112145
- `comma_in_time`: Use comma in timestamps (default: true)
113146
- `print_progress`: Print progress info (default: false)
114147
- `progress_callback`: Progress callback function
148+
- `signal`: `AbortSignal` used to cancel the transcription (see the "Cancelling a transcription" section above)
115149
- VAD parameters (see the "Voice Activity Detection (VAD) Support" section above)

examples/addon.node/addon.cpp

Lines changed: 49 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
#include "whisper.h"
66

7+
#include <atomic>
8+
#include <memory>
79
#include <string>
810
#include <thread>
911
#include <vector>
@@ -149,8 +151,9 @@ struct whisper_result {
149151

150152
class ProgressWorker : public Napi::AsyncWorker {
151153
public:
152-
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env)
153-
: Napi::AsyncWorker(callback), params(params), env(env) {
154+
ProgressWorker(Napi::Function& callback, whisper_params params, Napi::Function progress_callback, Napi::Env env,
155+
std::shared_ptr<std::atomic<bool>> is_aborted)
156+
: Napi::AsyncWorker(callback), params(params), env(env), is_aborted(std::move(is_aborted)) {
154157
// Create thread-safe function
155158
if (!progress_callback.IsEmpty()) {
156159
tsfn = Napi::ThreadSafeFunction::New(
@@ -185,6 +188,7 @@ class ProgressWorker : public Napi::AsyncWorker {
185188
}
186189

187190
Napi::Object returnObj = Napi::Object::New(Env());
191+
returnObj.Set("cancelled", Napi::Boolean::New(Env(), is_aborted->load()));
188192
if (!result.language.empty()) {
189193
returnObj.Set("language", Napi::String::New(Env(), result.language));
190194
}
@@ -217,6 +221,7 @@ class ProgressWorker : public Napi::AsyncWorker {
217221
whisper_result result;
218222
Napi::Env env;
219223
Napi::ThreadSafeFunction tsfn;
224+
std::shared_ptr<std::atomic<bool>> is_aborted;
220225

221226
// Custom run function with progress callback support
222227
int run_with_progress(whisper_params &params, whisper_result & result) {
@@ -344,6 +349,18 @@ class ProgressWorker : public Napi::AsyncWorker {
344349
};
345350
wparams.progress_callback_user_data = this;
346351

352+
// Cancellation support: checked before each encoder run (coarse)
353+
// and before each ggml graph computation (fine)
354+
wparams.encoder_begin_callback = [](struct whisper_context * /*ctx*/, struct whisper_state * /*state*/, void * user_data) {
355+
return !static_cast<std::atomic<bool>*>(user_data)->load();
356+
};
357+
wparams.encoder_begin_callback_user_data = is_aborted.get();
358+
359+
wparams.abort_callback = [](void * user_data) {
360+
return static_cast<std::atomic<bool>*>(user_data)->load();
361+
};
362+
wparams.abort_callback_user_data = is_aborted.get();
363+
347364
// Set VAD parameters
348365
wparams.vad = params.vad;
349366
wparams.vad_model_path = params.vad_model.c_str();
@@ -355,8 +372,16 @@ class ProgressWorker : public Napi::AsyncWorker {
355372
wparams.vad_params.speech_pad_ms = params.vad_speech_pad_ms;
356373
wparams.vad_params.samples_overlap = params.vad_samples_overlap;
357374

358-
if (whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors) != 0) {
375+
const int ret = whisper_full_parallel(ctx, wparams, pcmf32.data(), pcmf32.size(), params.n_processors);
376+
377+
if (is_aborted->load()) {
378+
// cancelled - keep the segments transcribed so far
379+
break;
380+
}
381+
382+
if (ret != 0) {
359383
fprintf(stderr, "failed to process audio\n");
384+
whisper_free(ctx);
360385
return 10;
361386
}
362387
}
@@ -538,9 +563,29 @@ Napi::Value whisper(const Napi::CallbackInfo& info) {
538563
params.vad_speech_pad_ms = vad_speech_pad_ms;
539564
params.vad_samples_overlap = vad_samples_overlap;
540565

566+
// Cancellation support: an AbortSignal can be passed via params.signal.
567+
// Its "abort" event sets a shared flag which is polled by the whisper.cpp
568+
// abort callbacks on the worker thread.
569+
auto is_aborted = std::make_shared<std::atomic<bool>>(false);
570+
if (whisper_params.Has("signal") && whisper_params.Get("signal").IsObject()) {
571+
Napi::Object signal = whisper_params.Get("signal").As<Napi::Object>();
572+
573+
if (signal.Get("aborted").ToBoolean().Value()) {
574+
is_aborted->store(true);
575+
} else if (signal.Has("addEventListener") && signal.Get("addEventListener").IsFunction()) {
576+
Napi::Function add_listener = signal.Get("addEventListener").As<Napi::Function>();
577+
Napi::Function on_abort = Napi::Function::New(env, [is_aborted](const Napi::CallbackInfo &) {
578+
is_aborted->store(true);
579+
});
580+
Napi::Object options = Napi::Object::New(env);
581+
options.Set("once", Napi::Boolean::New(env, true));
582+
add_listener.Call(signal, { Napi::String::New(env, "abort"), on_abort, options });
583+
}
584+
}
585+
541586
Napi::Function callback = info[1].As<Napi::Function>();
542587
// Create a new Worker class with progress callback support
543-
ProgressWorker* worker = new ProgressWorker(callback, params, progress_callback, env);
588+
ProgressWorker* worker = new ProgressWorker(callback, params, progress_callback, env, is_aborted);
544589
worker->Queue();
545590
return env.Undefined();
546591
}
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
// Demonstrates cancelling an in-flight transcription via AbortSignal (params.signal).
2+
//
3+
// Usage: node cancel-example.js [--model=path/to/model.bin]
4+
5+
const path = require("path");
6+
const os = require("os");
7+
const { promisify } = require("util");
8+
9+
const isWindows = os.platform() === "win32";
10+
const buildPath = isWindows ? "../../build/bin/Release/addon.node" : "../../build/Release/addon.node";
11+
const { whisper } = require(path.join(__dirname, buildPath));
12+
13+
const whisperAsync = promisify(whisper);
14+
15+
const modelArg = process.argv.find((a) => a.startsWith("--model="));
16+
const model = modelArg
17+
? modelArg.slice("--model=".length)
18+
: path.join(__dirname, "../../models/ggml-base.en.bin");
19+
20+
// Long synthetic audio (440 Hz tone + uniform noise) so the transcription runs
// long enough to be cancelled mid-flight.
//
// seconds    - duration of the generated signal; must be a finite number >= 0.
//              A fractional sample count is rounded down.
// sampleRate - samples per second (default 16000).
//
// Returns a Float32Array holding floor(sampleRate * seconds) samples.
// Throws RangeError for a negative/non-finite duration or a non-positive rate,
// instead of silently producing an empty buffer for NaN/undefined.
function syntheticAudio(seconds, sampleRate = 16000) {
    if (!Number.isFinite(seconds) || seconds < 0) {
        throw new RangeError("syntheticAudio: seconds must be a finite, non-negative number (got " + String(seconds) + ")");
    }
    if (!Number.isFinite(sampleRate) || sampleRate <= 0) {
        throw new RangeError("syntheticAudio: sampleRate must be a finite, positive number (got " + String(sampleRate) + ")");
    }

    // Floor so the loop bound and the buffer length always agree.
    const n = Math.floor(sampleRate * seconds);
    const pcm = new Float32Array(n);
    for (let i = 0; i < n; i++) {
        // tone of amplitude 0.05 plus noise in [-0.01, 0.01)
        pcm[i] = 0.05 * Math.sin((2 * Math.PI * 440 * i) / sampleRate) + (Math.random() - 0.5) * 0.02;
    }
    return pcm;
}
30+
31+
// Parameters shared by every test run below; each test spreads this object
// and adds its own input (and, where relevant, a signal).
const baseParams = {
    language: "en",
    model: model, // path resolved from --model= or the default model file
    use_gpu: true,
    no_prints: true,
    no_timestamps: false,
    comma_in_time: false,
};
39+
40+
// Test 1: start a transcription of 600 s of synthetic audio and abort it from
// the first progress callback. Passes when the result reports
// `cancelled === true` and no progress value of 100 was ever seen.
async function cancelMidFlight() {
    console.log("--- test 1: cancel mid-transcription ---");
    const controller = new AbortController();
    const reported = [];

    // Records each progress value; the first invocation triggers the abort.
    const onProgress = (pct) => {
        reported.push(pct);
        console.log(`progress: ${pct}%`);
        if (controller.signal.aborted) {
            return;
        }
        console.log(">>> calling abort()");
        controller.abort();
    };

    const startedAt = Date.now();
    const result = await whisperAsync({
        ...baseParams,
        fname_inp: "",
        pcmf32: syntheticAudio(600),
        signal: controller.signal,
        progress_callback: onProgress,
    });
    const elapsed = Date.now() - startedAt;

    console.log(`cancelled = ${result.cancelled}, segments = ${result.transcription.length}, elapsed = ${elapsed} ms`);
    if (result.cancelled !== true) {
        throw new Error("FAIL: expected cancelled === true");
    }
    if (reported.includes(100)) {
        throw new Error("FAIL: transcription ran to completion, was not cancelled");
    }
    console.log("PASS\n");
}
69+
70+
// Test 2: pass a signal that is already aborted before the call is made.
// Passes when the result reports `cancelled === true` and contains no segments.
async function preAbortedSignal() {
    console.log("--- test 2: already-aborted signal ---");
    const controller = new AbortController();
    controller.abort();

    const startedAt = Date.now();
    const request = {
        ...baseParams,
        fname_inp: "",
        pcmf32: syntheticAudio(600),
        signal: controller.signal,
    };
    const result = await whisperAsync(request);
    const elapsed = Date.now() - startedAt;

    console.log(`cancelled = ${result.cancelled}, segments = ${result.transcription.length}, elapsed = ${elapsed} ms`);
    if (result.cancelled !== true) {
        throw new Error("FAIL: expected cancelled === true");
    }
    if (result.transcription.length !== 0) {
        throw new Error("FAIL: expected no segments");
    }
    console.log("PASS\n");
}
89+
90+
// Test 3 (regression): transcribe samples/jfk.wav without any signal.
// Passes when the result reports `cancelled === false` and the joined text
// contains "ask not" (case-insensitive).
async function normalRun() {
    console.log("--- test 3: normal run without signal (regression) ---");
    const startedAt = Date.now();
    const request = {
        ...baseParams,
        fname_inp: path.join(__dirname, "../../samples/jfk.wav"),
    };
    const result = await whisperAsync(request);
    const elapsed = Date.now() - startedAt;

    // Element 2 of each segment entry is used as that segment's text.
    const pieces = result.transcription.map((segment) => segment[2]);
    const text = pieces.join(" ");

    console.log(`cancelled = ${result.cancelled}, segments = ${result.transcription.length}, elapsed = ${elapsed} ms`);
    console.log(`text: ${text.trim()}`);
    if (result.cancelled !== false) {
        throw new Error("FAIL: expected cancelled === false");
    }
    if (!text.toLowerCase().includes("ask not")) {
        throw new Error("FAIL: unexpected transcription");
    }
    console.log("PASS\n");
}
106+
107+
// Entry point: run the three tests strictly one after another. The first
// failure is printed and the process exits with status 1.
cancelMidFlight()
    .then(() => preAbortedSignal())
    .then(() => normalRun())
    .then(() => {
        console.log("ALL TESTS PASSED");
    })
    .catch((err) => {
        console.error(err);
        process.exit(1);
    });

0 commit comments

Comments
 (0)