@@ -10,21 +10,22 @@ import {
1010 MODEL_CONFIGS ,
1111 ModelConfig ,
1212} from '../constants/sttDefaults' ;
13+ import { unicodeToBytes } from '../utils/tokenizerUtils' ;
1314
1415const longCommonInfPref = ( seq1 : number [ ] , seq2 : number [ ] ) => {
1516 let maxInd = 0 ;
1617 let maxLength = 0 ;
1718
1819 for ( let i = 0 ; i < seq1 . length ; i ++ ) {
1920 let j = 0 ;
20- let hamming_dist = 0 ;
21+ let hammingDist = 0 ;
2122 while (
2223 j < seq2 . length &&
2324 i + j < seq1 . length &&
24- ( seq1 [ i + j ] === seq2 [ j ] || hamming_dist < HAMMING_DIST_THRESHOLD )
25+ ( seq1 [ i + j ] === seq2 [ j ] || hammingDist < HAMMING_DIST_THRESHOLD )
2526 ) {
2627 if ( seq1 [ i + j ] !== seq2 [ j ] ) {
27- hamming_dist ++ ;
28+ hammingDist ++ ;
2829 }
2930 j ++ ;
3031 }
@@ -54,6 +55,8 @@ export class SpeechToTextController {
5455
5556 // tokenizer tokens to string mapping used for decoding sequence
5657 private tokenMapping ! : { [ key : number ] : string } ;
58+ private textDecoder : any ;
59+ private charDecoder : { [ key : string ] : number } ;
5760
5861 // User callbacks
5962 private decodedTranscribeCallback : ( sequence : number [ ] ) => void ;
@@ -93,6 +96,8 @@ export class SpeechToTextController {
9396 this . isGenerating = isGenerating ;
9497 isGeneratingCallback ?.( isGenerating ) ;
9598 } ;
99+ this . charDecoder = unicodeToBytes ( ) ;
100+ this . textDecoder = new TextDecoder ( 'utf-8' , { fatal : false } ) ;
96101 this . onErrorCallback = onErrorCallback ;
97102 this . audioContext = new AudioContext ( { sampleRate : SAMPLE_RATE } ) ;
98103 this . nativeModule = new _SpeechToTextModule ( ) ;
@@ -205,7 +210,6 @@ export class SpeechToTextController {
205210
206211 if ( ! waveform ) {
207212 this . isGeneratingCallback ( false ) ;
208-
209213 this . onErrorCallback ?.(
210214 new Error (
211215 `Nothing to transcribe, perhaps you forgot to call this.loadAudio().`
@@ -216,49 +220,51 @@ export class SpeechToTextController {
216220 this . chunkWaveform ( waveform ) ;
217221
218222 let seqs : number [ ] [ ] = [ ] ;
219- let prevseq : number [ ] = [ ] ;
220- for ( let chunk_id = 0 ; chunk_id < this . chunks . length ; chunk_id ++ ) {
221- let last_token = this . config . tokenizer . sos ;
222- let prev_seq_token_idx = 0 ;
223- let final_seq : number [ ] = [ ] ;
224- let seq = [ last_token ] ;
225- let enc_output ;
223+ let prevSeq : number [ ] = [ ] ;
224+ for ( let chunkId = 0 ; chunkId < this . chunks . length ; chunkId ++ ) {
225+ let lastToken = this . config . tokenizer . bos ;
226+ let prevSeqTokenIdx = 0 ;
227+ let finalSeq : number [ ] = [ ] ;
228+ let seq = [ lastToken ] ;
229+ let encoderOutput ;
226230 try {
227- enc_output = await this . nativeModule . encode ( this . chunks ! . at ( chunk_id ) ! ) ;
231+ encoderOutput = await this . nativeModule . encode (
232+ this . chunks ! . at ( chunkId ) !
233+ ) ;
228234 } catch ( error ) {
229235 this . onErrorCallback ?.( `Encode ${ error } ` ) ;
230236 return '' ;
231237 }
232- while ( last_token !== this . config . tokenizer . eos ) {
238+ while ( lastToken !== this . config . tokenizer . eos ) {
233239 let output ;
234240 try {
235- output = await this . nativeModule . decode ( seq , [ enc_output ] ) ;
241+ output = await this . nativeModule . decode ( seq , [ encoderOutput ] ) ;
236242 } catch ( error ) {
237243 this . onErrorCallback ?.( `Decode ${ error } ` ) ;
238244 return '' ;
239245 }
240246 if ( typeof output === 'number' ) {
241- last_token = output ;
247+ lastToken = output ;
242248 } else {
243- last_token = output [ output . length - 1 ] ;
249+ lastToken = output [ output . length - 1 ] ;
244250 }
245- seq = [ ... seq , last_token ] ;
251+ seq . push ( lastToken ) ;
246252 if (
247253 seqs . length > 0 &&
248254 seq . length < seqs . at ( - 1 ) ! . length &&
249255 seq . length % 3 !== 0
250256 ) {
251- prevseq = [ ...prevseq , seqs . at ( - 1 ) ! [ prev_seq_token_idx ++ ] ! ] ;
252- this . decodedTranscribeCallback ( prevseq ) ;
257+ prevSeq = [ ...prevSeq , seqs . at ( - 1 ) ! [ prevSeqTokenIdx ++ ] ! ] ;
258+ this . decodedTranscribeCallback ( prevSeq ) ;
253259 }
254260 }
255261 if ( this . chunks . length === 1 ) {
256- final_seq = seq ;
257- this . sequence = final_seq ;
258- this . decodedTranscribeCallback ( final_seq ) ;
262+ finalSeq = seq ;
263+ this . sequence = finalSeq ;
264+ this . decodedTranscribeCallback ( finalSeq ) ;
259265 break ;
260266 }
261- // remove sos /eos token and 3 additional ones
267+ // remove bos /eos token and 3 additional ones
262268 if ( seqs . length === 0 ) {
263269 seqs = [ seq . slice ( 0 , - 4 ) ] ;
264270 } else if (
@@ -274,25 +280,25 @@ export class SpeechToTextController {
274280 }
275281
276282 const maxInd = longCommonInfPref ( seqs . at ( - 2 ) ! , seqs . at ( - 1 ) ! ) ;
277- final_seq = [ ...this . sequence , ...seqs . at ( - 2 ) ! . slice ( 0 , maxInd ) ] ;
278- this . sequence = final_seq ;
279- this . decodedTranscribeCallback ( final_seq ) ;
280- prevseq = final_seq ;
283+ finalSeq = [ ...this . sequence , ...seqs . at ( - 2 ) ! . slice ( 0 , maxInd ) ] ;
284+ this . sequence = finalSeq ;
285+ this . decodedTranscribeCallback ( finalSeq ) ;
286+ prevSeq = finalSeq ;
281287
282- //last sequence processed
288+ // last sequence processed
283289 if ( seqs . length === Math . ceil ( waveform . length / this . windowSize ) ) {
284- final_seq = [ ...this . sequence , ...seqs . at ( - 1 ) ! ] ;
285- this . sequence = final_seq ;
286- this . decodedTranscribeCallback ( final_seq ) ;
287- prevseq = final_seq ;
290+ finalSeq = [ ...this . sequence , ...seqs . at ( - 1 ) ! ] ;
291+ this . sequence = finalSeq ;
292+ this . decodedTranscribeCallback ( finalSeq ) ;
293+ prevSeq = finalSeq ;
288294 }
289295 }
290296 const decodedSeq = this . decodeSeq ( this . sequence ) ;
291297 this . isGeneratingCallback ( false ) ;
292298 return decodedSeq ;
293299 }
294300
295- public decodeSeq ( seq ?: number [ ] ) : string {
301+ private decodeSeq ( seq ?: number [ ] ) : string {
296302 if ( ! this . modelName ) {
297303 this . onErrorCallback ?.(
298304 new Error ( 'Model is not loaded, call `loadModel` first' )
@@ -302,14 +308,22 @@ export class SpeechToTextController {
302308 this . onErrorCallback ?.( undefined ) ;
303309 if ( ! seq ) seq = this . sequence ;
304310
305- return seq
311+ const decodedSeq = seq
306312 . filter (
307- ( token ) =>
308- token !== this . config . tokenizer . eos &&
309- token !== this . config . tokenizer . sos
313+ ( tokenId ) =>
314+ tokenId !== this . config . tokenizer . eos &&
315+ tokenId !== this . config . tokenizer . bos
310316 )
311- . map ( ( token ) => this . tokenMapping [ token ] )
312- . join ( '' )
313- . replaceAll ( this . config . tokenizer . special_char , ' ' ) ;
317+ . map ( ( tokenId ) => this . tokenMapping [ tokenId ] )
318+ . join ( '' ) ;
319+
320+ let byteArray = Array . from ( decodedSeq ) . map (
321+ ( char ) => this . charDecoder [ char ]
322+ ) ;
323+ const text = this . textDecoder . decode (
324+ new Uint8Array ( byteArray as number [ ] ) ,
325+ { stream : false }
326+ ) ;
327+ return text ;
314328 }
315329}
0 commit comments