Skip to content

Commit 271e6bb

Browse files
committed
feat: s2t streaming transcription
1 parent 2145796 commit 271e6bb

4 files changed

Lines changed: 160 additions & 57 deletions

File tree

examples/speech-to-text/screens/SpeechToTextScreen.tsx

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { useSpeechToText } from 'react-native-executorch';
1+
import { useSpeechToText, STREAMING_ACTION } from 'react-native-executorch';
22
import {
33
Text,
44
View,
@@ -51,7 +51,13 @@ export const SpeechToTextScreen = () => {
5151
sequence,
5252
error,
5353
transcribe,
54-
} = useSpeechToText({ modelName: 'moonshine', streamingConfig: 'balanced' });
54+
streamingTranscribe,
55+
} = useSpeechToText({
56+
modelName: 'moonshine',
57+
streamingConfig: 'balanced',
58+
windowSize: 5,
59+
overlapSeconds: 1.2,
60+
});
5561

5662
const loadAudio = async (url: string) => {
5763
const audioContext = new AudioContext({ sampleRate: 16e3 });
@@ -72,16 +78,18 @@ export const SpeechToTextScreen = () => {
7278
const onChunk = (data: string) => {
7379
const float32Chunk = float32ArrayFromPCMBinaryBuffer(data);
7480
audioBuffer.current?.push(...float32Chunk);
81+
streamingTranscribe(audioBuffer.current, STREAMING_ACTION.DATA);
7582
};
7683

7784
const handleRecordPress = async () => {
7885
if (isRecording) {
7986
LiveAudioStream.stop();
8087
setIsRecording(false);
81-
await transcribe(audioBuffer.current);
88+
await streamingTranscribe(audioBuffer.current, STREAMING_ACTION.STOP);
8289
audioBuffer.current = [];
8390
} else {
8491
setIsRecording(true);
92+
streamingTranscribe(audioBuffer.current, STREAMING_ACTION.START);
8593
startStreamingAudio(audioStreamOptions, onChunk);
8694
}
8795
};
@@ -133,6 +141,7 @@ export const SpeechToTextScreen = () => {
133141
setModalVisible(visible);
134142
if (audioUrl) {
135143
await transcribe(await loadAudio(audioUrl));
144+
setAudioUrl('');
136145
}
137146
}}
138147
onChangeText={setAudioUrl}
@@ -156,6 +165,7 @@ export const SpeechToTextScreen = () => {
156165
setModalVisible(true);
157166
} else {
158167
await transcribe(await loadAudio(audioUrl));
168+
setAudioUrl('');
159169
}
160170
}}
161171
>

src/constants/sttDefaults.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,9 @@ export const MODES = {
6666
overlapSeconds: 3,
6767
},
6868
};
69+
70+
export enum STREAMING_ACTION {
71+
START,
72+
DATA,
73+
STOP,
74+
}

src/controllers/SpeechToTextController.ts

Lines changed: 134 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import {
88
MODEL_CONFIGS,
99
ModelConfig,
1010
MODES,
11+
STREAMING_ACTION,
1112
} from '../constants/sttDefaults';
1213

1314
const longCommonInfPref = (seq1: number[], seq2: number[]) => {
@@ -46,6 +47,11 @@ export class SpeechToTextController {
4647
public isReady = false;
4748
public isGenerating = false;
4849
private modelName!: 'moonshine' | 'whisper';
50+
private seqs: number[][] = [];
51+
private prevSeq: number[] = [];
52+
private streamWaveform: number[] = [];
53+
private isDecodingChunk = false;
54+
private numberOfDecodedChunks = 0;
4955

5056
// tokenizer tokens to string mapping used for decoding sequence
5157
private tokenMapping!: { [key: number]: string };
@@ -190,6 +196,44 @@ export class SpeechToTextController {
190196
}
191197
}
192198

199+
private async decodeChunk(chunk: number[]): Promise<number[]> {
200+
let lastToken = this.config.tokenizer.sos;
201+
let seq = [lastToken];
202+
let prevSeqTokenIdx = 0;
203+
try {
204+
await this.nativeModule.encode(chunk);
205+
} catch (error) {
206+
this.onErrorCallback?.(`Encode ${error}`);
207+
return [];
208+
}
209+
while (lastToken !== this.config.tokenizer.eos) {
210+
try {
211+
lastToken = await this.nativeModule.decode(seq);
212+
} catch (error) {
213+
this.onErrorCallback?.(`Decode ${error}`);
214+
return [];
215+
}
216+
seq = [...seq, lastToken];
217+
if (
218+
this.seqs.length > 0 &&
219+
seq.length < this.seqs.at(-1)!.length &&
220+
seq.length % 3 !== 0
221+
) {
222+
this.prevSeq = [...this.prevSeq, this.seqs.at(-1)![prevSeqTokenIdx++]!];
223+
this.decodedTranscribeCallback(this.prevSeq);
224+
}
225+
}
226+
return seq;
227+
}
228+
229+
private handleOverlaps(seqs: number[][]): number[] {
230+
const maxInd = longCommonInfPref(seqs.at(-2)!, seqs.at(-1)!);
231+
let finalSeq = [...this.sequence, ...seqs.at(-2)!.slice(0, maxInd)];
232+
this.sequence = finalSeq;
233+
this.decodedTranscribeCallback(finalSeq);
234+
return finalSeq;
235+
}
236+
193237
public async transcribe(waveform: number[]): Promise<string> {
194238
if (!this.isReady) {
195239
this.onErrorCallback?.(new Error('Model is not yet ready'));
@@ -200,90 +244,126 @@ export class SpeechToTextController {
200244
return '';
201245
}
202246
this.onErrorCallback?.(undefined);
247+
this.decodedTranscribeCallback([]);
203248
this.isGeneratingCallback(true);
204249

205250
this.sequence = [];
206-
207-
if (!waveform) {
208-
this.isGeneratingCallback(false);
209-
210-
this.onErrorCallback?.(
211-
new Error(
212-
`Nothing to transcribe, perhaps you forgot to call this.loadAudio().`
213-
)
214-
);
215-
}
216-
251+
this.seqs = [];
252+
this.prevSeq = [];
217253
this.chunkWaveform(waveform);
218254

219-
let seqs: number[][] = [];
220-
let prevseq: number[] = [];
221255
for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) {
222-
let lastToken = this.config.tokenizer.sos;
223-
let prevSeqTokenIdx = 0;
256+
let seq = await this.decodeChunk(this.chunks!.at(chunkId)!);
224257
let finalSeq: number[] = [];
225-
let seq = [lastToken];
226-
try {
227-
await this.nativeModule.encode(this.chunks!.at(chunkId)!);
228-
} catch (error) {
229-
this.onErrorCallback?.(`Encode ${error}`);
230-
return '';
231-
}
232-
while (lastToken !== this.config.tokenizer.eos) {
233-
try {
234-
lastToken = await this.nativeModule.decode(seq);
235-
} catch (error) {
236-
this.onErrorCallback?.(`Decode ${error}`);
237-
return '';
238-
}
239-
seq = [...seq, lastToken];
240-
if (
241-
seqs.length > 0 &&
242-
seq.length < seqs.at(-1)!.length &&
243-
seq.length % 3 !== 0
244-
) {
245-
prevseq = [...prevseq, seqs.at(-1)![prevSeqTokenIdx++]!];
246-
this.decodedTranscribeCallback(prevseq);
247-
}
248-
}
249-
250258
if (this.chunks.length === 1) {
251259
finalSeq = seq;
252260
this.sequence = finalSeq;
253261
this.decodedTranscribeCallback(finalSeq);
254262
break;
255263
}
256264
// remove sos/eos token and 3 additional ones
257-
if (seqs.length === 0) {
258-
seqs = [seq.slice(0, -4)];
259-
} else if (seqs.length === this.chunks.length - 1) {
260-
seqs = [...seqs, seq.slice(4)];
265+
if (this.seqs.length === 0) {
266+
this.seqs = [seq.slice(0, -4)];
267+
} else if (this.seqs.length === this.chunks.length - 1) {
268+
this.seqs = [...this.seqs, seq.slice(4)];
261269
} else {
262-
seqs = [...seqs, seq.slice(4, -4)];
270+
this.seqs = [...this.seqs, seq.slice(4, -4)];
263271
}
264-
if (seqs.length < 2) {
272+
if (this.seqs.length < 2) {
265273
continue;
266274
}
267275

268-
const maxInd = longCommonInfPref(seqs.at(-2)!, seqs.at(-1)!);
269-
finalSeq = [...this.sequence, ...seqs.at(-2)!.slice(0, maxInd)];
270-
this.sequence = finalSeq;
271-
this.decodedTranscribeCallback(finalSeq);
272-
prevseq = finalSeq;
276+
this.prevSeq = this.handleOverlaps(this.seqs);
273277

274278
//last sequence processed
275-
if (seqs.length === this.chunks.length) {
276-
finalSeq = [...this.sequence, ...seqs.at(-1)!];
279+
if (this.seqs.length === this.chunks.length) {
280+
finalSeq = [...this.sequence, ...this.seqs.at(-1)!];
277281
this.sequence = finalSeq;
278282
this.decodedTranscribeCallback(finalSeq);
279-
prevseq = finalSeq;
283+
this.prevSeq = finalSeq;
280284
}
281285
}
282286
const decodedSeq = this.decodeSeq(this.sequence);
283287
this.isGeneratingCallback(false);
284288
return decodedSeq;
285289
}
286290

291+
public async streamingTranscribe(
292+
waveform: number[],
293+
streamAction: STREAMING_ACTION
294+
): Promise<string> {
295+
if (!this.isReady) {
296+
this.onErrorCallback?.(new Error('Model is not yet ready'));
297+
return '';
298+
}
299+
this.onErrorCallback?.(undefined);
300+
301+
if (streamAction == STREAMING_ACTION.START) {
302+
this.sequence = [];
303+
this.seqs = [];
304+
this.streamWaveform = [];
305+
this.prevSeq = [];
306+
this.numberOfDecodedChunks = 0;
307+
this.decodedTranscribeCallback([]);
308+
this.isGeneratingCallback(true);
309+
}
310+
this.streamWaveform = [...this.streamWaveform, ...waveform];
311+
this.chunkWaveform(this.streamWaveform);
312+
if (!this.isDecodingChunk && streamAction != 2) {
313+
this.isDecodingChunk = true;
314+
while (
315+
this.chunks.at(this.numberOfDecodedChunks)?.length ==
316+
2 * this.overlapSeconds + this.windowSize ||
317+
(this.numberOfDecodedChunks == 0 &&
318+
this.chunks.at(this.numberOfDecodedChunks)?.length ==
319+
this.windowSize + this.overlapSeconds)
320+
) {
321+
let seq = await this.decodeChunk(
322+
this.chunks.at(this.numberOfDecodedChunks)!
323+
);
324+
// remove sos/eos token and 3 additional ones
325+
if (this.numberOfDecodedChunks == 0) {
326+
this.seqs = [seq.slice(0, -4)];
327+
} else {
328+
this.seqs = [...this.seqs, seq.slice(4, -4)];
329+
this.prevSeq = this.handleOverlaps(this.seqs);
330+
}
331+
this.numberOfDecodedChunks++;
332+
if (this.seqs.length < 2) {
333+
continue;
334+
}
335+
}
336+
this.isDecodingChunk = false;
337+
}
338+
while (
339+
this.numberOfDecodedChunks < this.chunks.length &&
340+
streamAction == STREAMING_ACTION.STOP
341+
) {
342+
let seq = await this.decodeChunk(
343+
this.chunks.at(this.numberOfDecodedChunks)!
344+
);
345+
if (this.numberOfDecodedChunks == 0) {
346+
this.sequence = seq;
347+
this.decodedTranscribeCallback(seq);
348+
this.isGeneratingCallback(false);
349+
break;
350+
}
351+
//last sequence processed
352+
if (this.numberOfDecodedChunks == this.chunks.length - 1) {
353+
let finalSeq = [...this.sequence, ...seq];
354+
this.sequence = finalSeq;
355+
this.decodedTranscribeCallback(finalSeq);
356+
this.isGeneratingCallback(false);
357+
} else {
358+
this.seqs = [...this.seqs, seq.slice(4, -4)];
359+
this.handleOverlaps(this.seqs);
360+
}
361+
this.numberOfDecodedChunks++;
362+
}
363+
const decodedSeq = this.decodeSeq(this.sequence);
364+
return decodedSeq;
365+
}
366+
287367
public decodeSeq(seq?: number[]): string {
288368
if (!this.modelName) {
289369
this.onErrorCallback?.(

src/hooks/natural_language_processing/useSpeechToText.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { useEffect, useState } from 'react';
22
import { SpeechToTextController } from '../../controllers/SpeechToTextController';
33
import { ResourceSource } from '../../types/common';
4+
import { STREAMING_ACTION } from '../../constants/sttDefaults';
45

56
interface SpeechToTextModule {
67
isReady: boolean;
@@ -12,6 +13,10 @@ interface SpeechToTextModule {
1213
transcribe: (
1314
input: number[]
1415
) => ReturnType<SpeechToTextController['transcribe']>;
16+
streamingTranscribe: (
17+
input: number[],
18+
streamAction: STREAMING_ACTION
19+
) => ReturnType<SpeechToTextController['streamingTranscribe']>;
1520
}
1621

1722
export const useSpeechToText = ({
@@ -77,5 +82,7 @@ export const useSpeechToText = ({
7782
sequence,
7883
error,
7984
transcribe: (waveform: number[]) => model.transcribe(waveform),
85+
streamingTranscribe: (waveform: number[], streamAction: STREAMING_ACTION) =>
86+
model.streamingTranscribe(waveform, streamAction),
8087
};
8188
};

0 commit comments

Comments
 (0)