88 MODEL_CONFIGS ,
99 ModelConfig ,
1010 MODES ,
11+ STREAMING_ACTION ,
1112} from '../constants/sttDefaults' ;
1213
1314const longCommonInfPref = ( seq1 : number [ ] , seq2 : number [ ] ) => {
@@ -46,6 +47,11 @@ export class SpeechToTextController {
4647 public isReady = false ;
4748 public isGenerating = false ;
4849 private modelName ! : 'moonshine' | 'whisper' ;
50+ private seqs : number [ ] [ ] = [ ] ;
51+ private prevSeq : number [ ] = [ ] ;
52+ private streamWaveform : number [ ] = [ ] ;
53+ private isDecodingChunk = false ;
54+ private numberOfDecodedChunks = 0 ;
4955
5056 // tokenizer tokens to string mapping used for decoding sequence
5157 private tokenMapping ! : { [ key : number ] : string } ;
@@ -190,6 +196,44 @@ export class SpeechToTextController {
190196 }
191197 }
192198
199+ private async decodeChunk ( chunk : number [ ] ) : Promise < number [ ] > {
200+ let lastToken = this . config . tokenizer . sos ;
201+ let seq = [ lastToken ] ;
202+ let prevSeqTokenIdx = 0 ;
203+ try {
204+ await this . nativeModule . encode ( chunk ) ;
205+ } catch ( error ) {
206+ this . onErrorCallback ?.( `Encode ${ error } ` ) ;
207+ return [ ] ;
208+ }
209+ while ( lastToken !== this . config . tokenizer . eos ) {
210+ try {
211+ lastToken = await this . nativeModule . decode ( seq ) ;
212+ } catch ( error ) {
213+ this . onErrorCallback ?.( `Decode ${ error } ` ) ;
214+ return [ ] ;
215+ }
216+ seq = [ ...seq , lastToken ] ;
217+ if (
218+ this . seqs . length > 0 &&
219+ seq . length < this . seqs . at ( - 1 ) ! . length &&
220+ seq . length % 3 !== 0
221+ ) {
222+ this . prevSeq = [ ...this . prevSeq , this . seqs . at ( - 1 ) ! [ prevSeqTokenIdx ++ ] ! ] ;
223+ this . decodedTranscribeCallback ( this . prevSeq ) ;
224+ }
225+ }
226+ return seq ;
227+ }
228+
229+ private handleOverlaps ( seqs : number [ ] [ ] ) : number [ ] {
230+ const maxInd = longCommonInfPref ( seqs . at ( - 2 ) ! , seqs . at ( - 1 ) ! ) ;
231+ let finalSeq = [ ...this . sequence , ...seqs . at ( - 2 ) ! . slice ( 0 , maxInd ) ] ;
232+ this . sequence = finalSeq ;
233+ this . decodedTranscribeCallback ( finalSeq ) ;
234+ return finalSeq ;
235+ }
236+
193237 public async transcribe ( waveform : number [ ] ) : Promise < string > {
194238 if ( ! this . isReady ) {
195239 this . onErrorCallback ?.( new Error ( 'Model is not yet ready' ) ) ;
@@ -200,90 +244,126 @@ export class SpeechToTextController {
200244 return '' ;
201245 }
202246 this . onErrorCallback ?.( undefined ) ;
247+ this . decodedTranscribeCallback ( [ ] ) ;
203248 this . isGeneratingCallback ( true ) ;
204249
205250 this . sequence = [ ] ;
206-
207- if ( ! waveform ) {
208- this . isGeneratingCallback ( false ) ;
209-
210- this . onErrorCallback ?.(
211- new Error (
212- `Nothing to transcribe, perhaps you forgot to call this.loadAudio().`
213- )
214- ) ;
215- }
216-
251+ this . seqs = [ ] ;
252+ this . prevSeq = [ ] ;
217253 this . chunkWaveform ( waveform ) ;
218254
219- let seqs : number [ ] [ ] = [ ] ;
220- let prevseq : number [ ] = [ ] ;
221255 for ( let chunkId = 0 ; chunkId < this . chunks . length ; chunkId ++ ) {
222- let lastToken = this . config . tokenizer . sos ;
223- let prevSeqTokenIdx = 0 ;
256+ let seq = await this . decodeChunk ( this . chunks ! . at ( chunkId ) ! ) ;
224257 let finalSeq : number [ ] = [ ] ;
225- let seq = [ lastToken ] ;
226- try {
227- await this . nativeModule . encode ( this . chunks ! . at ( chunkId ) ! ) ;
228- } catch ( error ) {
229- this . onErrorCallback ?.( `Encode ${ error } ` ) ;
230- return '' ;
231- }
232- while ( lastToken !== this . config . tokenizer . eos ) {
233- try {
234- lastToken = await this . nativeModule . decode ( seq ) ;
235- } catch ( error ) {
236- this . onErrorCallback ?.( `Decode ${ error } ` ) ;
237- return '' ;
238- }
239- seq = [ ...seq , lastToken ] ;
240- if (
241- seqs . length > 0 &&
242- seq . length < seqs . at ( - 1 ) ! . length &&
243- seq . length % 3 !== 0
244- ) {
245- prevseq = [ ...prevseq , seqs . at ( - 1 ) ! [ prevSeqTokenIdx ++ ] ! ] ;
246- this . decodedTranscribeCallback ( prevseq ) ;
247- }
248- }
249-
250258 if ( this . chunks . length === 1 ) {
251259 finalSeq = seq ;
252260 this . sequence = finalSeq ;
253261 this . decodedTranscribeCallback ( finalSeq ) ;
254262 break ;
255263 }
256264 // remove sos/eos token and 3 additional ones
257- if ( seqs . length === 0 ) {
258- seqs = [ seq . slice ( 0 , - 4 ) ] ;
259- } else if ( seqs . length === this . chunks . length - 1 ) {
260- seqs = [ ...seqs , seq . slice ( 4 ) ] ;
265+ if ( this . seqs . length === 0 ) {
266+ this . seqs = [ seq . slice ( 0 , - 4 ) ] ;
267+ } else if ( this . seqs . length === this . chunks . length - 1 ) {
268+ this . seqs = [ ... this . seqs , seq . slice ( 4 ) ] ;
261269 } else {
262- seqs = [ ...seqs , seq . slice ( 4 , - 4 ) ] ;
270+ this . seqs = [ ... this . seqs , seq . slice ( 4 , - 4 ) ] ;
263271 }
264- if ( seqs . length < 2 ) {
272+ if ( this . seqs . length < 2 ) {
265273 continue ;
266274 }
267275
268- const maxInd = longCommonInfPref ( seqs . at ( - 2 ) ! , seqs . at ( - 1 ) ! ) ;
269- finalSeq = [ ...this . sequence , ...seqs . at ( - 2 ) ! . slice ( 0 , maxInd ) ] ;
270- this . sequence = finalSeq ;
271- this . decodedTranscribeCallback ( finalSeq ) ;
272- prevseq = finalSeq ;
276+ this . prevSeq = this . handleOverlaps ( this . seqs ) ;
273277
274278 //last sequence processed
275- if ( seqs . length === this . chunks . length ) {
276- finalSeq = [ ...this . sequence , ...seqs . at ( - 1 ) ! ] ;
279+ if ( this . seqs . length === this . chunks . length ) {
280+ finalSeq = [ ...this . sequence , ...this . seqs . at ( - 1 ) ! ] ;
277281 this . sequence = finalSeq ;
278282 this . decodedTranscribeCallback ( finalSeq ) ;
279- prevseq = finalSeq ;
283+ this . prevSeq = finalSeq ;
280284 }
281285 }
282286 const decodedSeq = this . decodeSeq ( this . sequence ) ;
283287 this . isGeneratingCallback ( false ) ;
284288 return decodedSeq ;
285289 }
286290
291+ public async streamingTranscribe (
292+ waveform : number [ ] ,
293+ streamAction : STREAMING_ACTION
294+ ) : Promise < string > {
295+ if ( ! this . isReady ) {
296+ this . onErrorCallback ?.( new Error ( 'Model is not yet ready' ) ) ;
297+ return '' ;
298+ }
299+ this . onErrorCallback ?.( undefined ) ;
300+
301+ if ( streamAction == STREAMING_ACTION . START ) {
302+ this . sequence = [ ] ;
303+ this . seqs = [ ] ;
304+ this . streamWaveform = [ ] ;
305+ this . prevSeq = [ ] ;
306+ this . numberOfDecodedChunks = 0 ;
307+ this . decodedTranscribeCallback ( [ ] ) ;
308+ this . isGeneratingCallback ( true ) ;
309+ }
310+ this . streamWaveform = [ ...this . streamWaveform , ...waveform ] ;
311+ this . chunkWaveform ( this . streamWaveform ) ;
312+ if ( ! this . isDecodingChunk && streamAction != 2 ) {
313+ this . isDecodingChunk = true ;
314+ while (
315+ this . chunks . at ( this . numberOfDecodedChunks ) ?. length ==
316+ 2 * this . overlapSeconds + this . windowSize ||
317+ ( this . numberOfDecodedChunks == 0 &&
318+ this . chunks . at ( this . numberOfDecodedChunks ) ?. length ==
319+ this . windowSize + this . overlapSeconds )
320+ ) {
321+ let seq = await this . decodeChunk (
322+ this . chunks . at ( this . numberOfDecodedChunks ) !
323+ ) ;
324+ // remove sos/eos token and 3 additional ones
325+ if ( this . numberOfDecodedChunks == 0 ) {
326+ this . seqs = [ seq . slice ( 0 , - 4 ) ] ;
327+ } else {
328+ this . seqs = [ ...this . seqs , seq . slice ( 4 , - 4 ) ] ;
329+ this . prevSeq = this . handleOverlaps ( this . seqs ) ;
330+ }
331+ this . numberOfDecodedChunks ++ ;
332+ if ( this . seqs . length < 2 ) {
333+ continue ;
334+ }
335+ }
336+ this . isDecodingChunk = false ;
337+ }
338+ while (
339+ this . numberOfDecodedChunks < this . chunks . length &&
340+ streamAction == STREAMING_ACTION . STOP
341+ ) {
342+ let seq = await this . decodeChunk (
343+ this . chunks . at ( this . numberOfDecodedChunks ) !
344+ ) ;
345+ if ( this . numberOfDecodedChunks == 0 ) {
346+ this . sequence = seq ;
347+ this . decodedTranscribeCallback ( seq ) ;
348+ this . isGeneratingCallback ( false ) ;
349+ break ;
350+ }
351+ //last sequence processed
352+ if ( this . numberOfDecodedChunks == this . chunks . length - 1 ) {
353+ let finalSeq = [ ...this . sequence , ...seq ] ;
354+ this . sequence = finalSeq ;
355+ this . decodedTranscribeCallback ( finalSeq ) ;
356+ this . isGeneratingCallback ( false ) ;
357+ } else {
358+ this . seqs = [ ...this . seqs , seq . slice ( 4 , - 4 ) ] ;
359+ this . handleOverlaps ( this . seqs ) ;
360+ }
361+ this . numberOfDecodedChunks ++ ;
362+ }
363+ const decodedSeq = this . decodeSeq ( this . sequence ) ;
364+ return decodedSeq ;
365+ }
366+
287367 public decodeSeq ( seq ?: number [ ] ) : string {
288368 if ( ! this . modelName ) {
289369 this . onErrorCallback ?.(
0 commit comments