Skip to content

Commit 4a736a5

Browse files
committed
feat: add char decoding to speechToTextController.ts
1 parent 343bb02 commit 4a736a5

1 file changed

Lines changed: 54 additions & 40 deletions

File tree

src/controllers/SpeechToTextController.ts

Lines changed: 54 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -10,21 +10,22 @@ import {
1010
MODEL_CONFIGS,
1111
ModelConfig,
1212
} from '../constants/sttDefaults';
13+
import { unicodeToBytes } from '../utils/tokenizerUtils';
1314

1415
const longCommonInfPref = (seq1: number[], seq2: number[]) => {
1516
let maxInd = 0;
1617
let maxLength = 0;
1718

1819
for (let i = 0; i < seq1.length; i++) {
1920
let j = 0;
20-
let hamming_dist = 0;
21+
let hammingDist = 0;
2122
while (
2223
j < seq2.length &&
2324
i + j < seq1.length &&
24-
(seq1[i + j] === seq2[j] || hamming_dist < HAMMING_DIST_THRESHOLD)
25+
(seq1[i + j] === seq2[j] || hammingDist < HAMMING_DIST_THRESHOLD)
2526
) {
2627
if (seq1[i + j] !== seq2[j]) {
27-
hamming_dist++;
28+
hammingDist++;
2829
}
2930
j++;
3031
}
@@ -54,6 +55,8 @@ export class SpeechToTextController {
5455

5556
// tokenizer tokens to string mapping used for decoding sequence
5657
private tokenMapping!: { [key: number]: string };
58+
private textDecoder: any;
59+
private charDecoder: { [key: string]: number };
5760

5861
// User callbacks
5962
private decodedTranscribeCallback: (sequence: number[]) => void;
@@ -93,6 +96,8 @@ export class SpeechToTextController {
9396
this.isGenerating = isGenerating;
9497
isGeneratingCallback?.(isGenerating);
9598
};
99+
this.charDecoder = unicodeToBytes();
100+
this.textDecoder = new TextDecoder('utf-8', { fatal: false });
96101
this.onErrorCallback = onErrorCallback;
97102
this.audioContext = new AudioContext({ sampleRate: SAMPLE_RATE });
98103
this.nativeModule = new _SpeechToTextModule();
@@ -205,7 +210,6 @@ export class SpeechToTextController {
205210

206211
if (!waveform) {
207212
this.isGeneratingCallback(false);
208-
209213
this.onErrorCallback?.(
210214
new Error(
211215
`Nothing to transcribe, perhaps you forgot to call this.loadAudio().`
@@ -216,49 +220,51 @@ export class SpeechToTextController {
216220
this.chunkWaveform(waveform);
217221

218222
let seqs: number[][] = [];
219-
let prevseq: number[] = [];
220-
for (let chunk_id = 0; chunk_id < this.chunks.length; chunk_id++) {
221-
let last_token = this.config.tokenizer.sos;
222-
let prev_seq_token_idx = 0;
223-
let final_seq: number[] = [];
224-
let seq = [last_token];
225-
let enc_output;
223+
let prevSeq: number[] = [];
224+
for (let chunkId = 0; chunkId < this.chunks.length; chunkId++) {
225+
let lastToken = this.config.tokenizer.bos;
226+
let prevSeqTokenIdx = 0;
227+
let finalSeq: number[] = [];
228+
let seq = [lastToken];
229+
let encoderOutput;
226230
try {
227-
enc_output = await this.nativeModule.encode(this.chunks!.at(chunk_id)!);
231+
encoderOutput = await this.nativeModule.encode(
232+
this.chunks!.at(chunkId)!
233+
);
228234
} catch (error) {
229235
this.onErrorCallback?.(`Encode ${error}`);
230236
return '';
231237
}
232-
while (last_token !== this.config.tokenizer.eos) {
238+
while (lastToken !== this.config.tokenizer.eos) {
233239
let output;
234240
try {
235-
output = await this.nativeModule.decode(seq, [enc_output]);
241+
output = await this.nativeModule.decode(seq, [encoderOutput]);
236242
} catch (error) {
237243
this.onErrorCallback?.(`Decode ${error}`);
238244
return '';
239245
}
240246
if (typeof output === 'number') {
241-
last_token = output;
247+
lastToken = output;
242248
} else {
243-
last_token = output[output.length - 1];
249+
lastToken = output[output.length - 1];
244250
}
245-
seq = [...seq, last_token];
251+
seq.push(lastToken);
246252
if (
247253
seqs.length > 0 &&
248254
seq.length < seqs.at(-1)!.length &&
249255
seq.length % 3 !== 0
250256
) {
251-
prevseq = [...prevseq, seqs.at(-1)![prev_seq_token_idx++]!];
252-
this.decodedTranscribeCallback(prevseq);
257+
prevSeq = [...prevSeq, seqs.at(-1)![prevSeqTokenIdx++]!];
258+
this.decodedTranscribeCallback(prevSeq);
253259
}
254260
}
255261
if (this.chunks.length === 1) {
256-
final_seq = seq;
257-
this.sequence = final_seq;
258-
this.decodedTranscribeCallback(final_seq);
262+
finalSeq = seq;
263+
this.sequence = finalSeq;
264+
this.decodedTranscribeCallback(finalSeq);
259265
break;
260266
}
261-
// remove sos/eos token and 3 additional ones
267+
// remove bos/eos token and 3 additional ones
262268
if (seqs.length === 0) {
263269
seqs = [seq.slice(0, -4)];
264270
} else if (
@@ -274,25 +280,25 @@ export class SpeechToTextController {
274280
}
275281

276282
const maxInd = longCommonInfPref(seqs.at(-2)!, seqs.at(-1)!);
277-
final_seq = [...this.sequence, ...seqs.at(-2)!.slice(0, maxInd)];
278-
this.sequence = final_seq;
279-
this.decodedTranscribeCallback(final_seq);
280-
prevseq = final_seq;
283+
finalSeq = [...this.sequence, ...seqs.at(-2)!.slice(0, maxInd)];
284+
this.sequence = finalSeq;
285+
this.decodedTranscribeCallback(finalSeq);
286+
prevSeq = finalSeq;
281287

282-
//last sequence processed
288+
// last sequence processed
283289
if (seqs.length === Math.ceil(waveform.length / this.windowSize)) {
284-
final_seq = [...this.sequence, ...seqs.at(-1)!];
285-
this.sequence = final_seq;
286-
this.decodedTranscribeCallback(final_seq);
287-
prevseq = final_seq;
290+
finalSeq = [...this.sequence, ...seqs.at(-1)!];
291+
this.sequence = finalSeq;
292+
this.decodedTranscribeCallback(finalSeq);
293+
prevSeq = finalSeq;
288294
}
289295
}
290296
const decodedSeq = this.decodeSeq(this.sequence);
291297
this.isGeneratingCallback(false);
292298
return decodedSeq;
293299
}
294300

295-
public decodeSeq(seq?: number[]): string {
301+
private decodeSeq(seq?: number[]): string {
296302
if (!this.modelName) {
297303
this.onErrorCallback?.(
298304
new Error('Model is not loaded, call `loadModel` first')
@@ -302,14 +308,22 @@ export class SpeechToTextController {
302308
this.onErrorCallback?.(undefined);
303309
if (!seq) seq = this.sequence;
304310

305-
return seq
311+
const decodedSeq = seq
306312
.filter(
307-
(token) =>
308-
token !== this.config.tokenizer.eos &&
309-
token !== this.config.tokenizer.sos
313+
(tokenId) =>
314+
tokenId !== this.config.tokenizer.eos &&
315+
tokenId !== this.config.tokenizer.bos
310316
)
311-
.map((token) => this.tokenMapping[token])
312-
.join('')
313-
.replaceAll(this.config.tokenizer.special_char, ' ');
317+
.map((tokenId) => this.tokenMapping[tokenId])
318+
.join('');
319+
320+
let byteArray = Array.from(decodedSeq).map(
321+
(char) => this.charDecoder[char]
322+
);
323+
const text = this.textDecoder.decode(
324+
new Uint8Array(byteArray as number[]),
325+
{ stream: false }
326+
);
327+
return text;
314328
}
315329
}

0 commit comments

Comments
 (0)