Skip to content

Commit b62201a

Browse files
chmjkb and claude committed
chore: migrate SpeechToTextModule to factory pattern, add SpeechToTextModelName type
- Add SpeechToTextModelName union type
- Add modelName to SpeechToTextModelConfig
- SpeechToTextModule: private constructor, fromModelName, fromCustomModel
- useSpeechToText: use factory, add model.modelName to deps

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
1 parent 19e9a17 commit b62201a

3 files changed

Lines changed: 176 additions & 67 deletions

File tree

packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,41 +24,53 @@ export const useSpeechToText = ({
2424
const [isReady, setIsReady] = useState(false);
2525
const [isGenerating, setIsGenerating] = useState(false);
2626
const [downloadProgress, setDownloadProgress] = useState(0);
27-
28-
const [moduleInstance, _] = useState(() => new SpeechToTextModule());
27+
const [moduleInstance, setModuleInstance] =
28+
useState<SpeechToTextModule | null>(null);
2929

3030
useEffect(() => {
3131
if (preventLoad) return;
32-
let isMounted = true;
3332

34-
(async () => {
35-
setDownloadProgress(0);
36-
setError(null);
37-
try {
38-
setIsReady(false);
39-
await moduleInstance.load(
40-
{
41-
isMultilingual: model.isMultilingual,
42-
encoderSource: model.encoderSource,
43-
decoderSource: model.decoderSource,
44-
tokenizerSource: model.tokenizerSource,
45-
},
46-
(progress) => {
47-
if (isMounted) setDownloadProgress(progress);
48-
}
49-
);
50-
if (isMounted) setIsReady(true);
51-
} catch (err) {
52-
if (isMounted) setError(parseUnknownError(err));
33+
let active = true;
34+
setDownloadProgress(0);
35+
setError(null);
36+
setIsReady(false);
37+
38+
SpeechToTextModule.fromModelName(
39+
{
40+
modelName: model.modelName,
41+
isMultilingual: model.isMultilingual,
42+
encoderSource: model.encoderSource,
43+
decoderSource: model.decoderSource,
44+
tokenizerSource: model.tokenizerSource,
45+
},
46+
(p) => {
47+
if (active) setDownloadProgress(p);
5348
}
54-
})();
49+
)
50+
.then((mod) => {
51+
if (!active) {
52+
mod.delete();
53+
return;
54+
}
55+
setModuleInstance((prev) => {
56+
prev?.delete();
57+
return mod;
58+
});
59+
setIsReady(true);
60+
})
61+
.catch((err) => {
62+
if (active) setError(parseUnknownError(err));
63+
});
5564

5665
return () => {
57-
isMounted = false;
58-
moduleInstance.delete();
66+
active = false;
67+
setModuleInstance((prev) => {
68+
prev?.delete();
69+
return null;
70+
});
5971
};
6072
}, [
61-
moduleInstance,
73+
model.modelName,
6274
model.isMultilingual,
6375
model.encoderSource,
6476
model.decoderSource,
@@ -71,7 +83,7 @@ export const useSpeechToText = ({
7183
waveform: Float32Array,
7284
options: DecodingOptions = {}
7385
): Promise<TranscriptionResult> => {
74-
if (!isReady) {
86+
if (!isReady || !moduleInstance) {
7587
throw new RnExecutorchError(
7688
RnExecutorchErrorCode.ModuleNotLoaded,
7789
'The model is currently not loaded. Please load the model before calling this function.'
@@ -103,7 +115,7 @@ export const useSpeechToText = ({
103115
void,
104116
unknown
105117
> {
106-
if (!isReady) {
118+
if (!isReady || !moduleInstance) {
107119
throw new RnExecutorchError(
108120
RnExecutorchErrorCode.ModuleNotLoaded,
109121
'The model is currently not loaded. Please load the model before calling this function.'
@@ -131,17 +143,44 @@ export const useSpeechToText = ({
131143

132144
const streamInsert = useCallback(
133145
(waveform: Float32Array) => {
134-
if (!isReady) return;
146+
if (!isReady || !moduleInstance) return;
135147
moduleInstance.streamInsert(waveform);
136148
},
137149
[isReady, moduleInstance]
138150
);
139151

140152
const streamStop = useCallback(() => {
141-
if (!isReady) return;
153+
if (!isReady || !moduleInstance) return;
142154
moduleInstance.streamStop();
143155
}, [isReady, moduleInstance]);
144156

157+
const encode = useCallback(
158+
(waveform: Float32Array): Promise<Float32Array> => {
159+
if (!moduleInstance)
160+
throw new RnExecutorchError(
161+
RnExecutorchErrorCode.ModuleNotLoaded,
162+
'The model is currently not loaded. Please load the model before calling this function.'
163+
);
164+
return moduleInstance.encode(waveform);
165+
},
166+
[moduleInstance]
167+
);
168+
169+
const decode = useCallback(
170+
(
171+
tokens: Int32Array,
172+
encoderOutput: Float32Array
173+
): Promise<Float32Array> => {
174+
if (!moduleInstance)
175+
throw new RnExecutorchError(
176+
RnExecutorchErrorCode.ModuleNotLoaded,
177+
'The model is currently not loaded. Please load the model before calling this function.'
178+
);
179+
return moduleInstance.decode(tokens, encoderOutput);
180+
},
181+
[moduleInstance]
182+
);
183+
145184
return {
146185
error,
147186
isReady,
@@ -151,7 +190,7 @@ export const useSpeechToText = ({
151190
stream,
152191
streamInsert,
153192
streamStop,
154-
encode: moduleInstance.encode.bind(moduleInstance),
155-
decode: moduleInstance.decode.bind(moduleInstance),
193+
encode,
194+
decode,
156195
};
157196
};

packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts

Lines changed: 85 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import {
22
DecodingOptions,
33
SpeechToTextModelConfig,
4+
SpeechToTextModelName,
45
TranscriptionResult,
56
} from '../../types/stt';
67
import { ResourceFetcher } from '../../utils/ResourceFetcher';
8+
import { ResourceSource } from '../../types/common';
79
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
810
import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils';
911
import { Logger } from '../../common/Logger';
@@ -17,50 +19,98 @@ export class SpeechToTextModule {
1719
private nativeModule: any;
1820
private modelConfig!: SpeechToTextModelConfig;
1921

22+
private constructor() {}
23+
2024
/**
21-
* Loads the model specified by the config object.
22-
* `onDownloadProgressCallback` allows you to monitor the current progress of the model download.
25+
* Creates a Speech to Text instance for a built-in model.
26+
*
27+
* @param namedSources - Configuration object containing model name, sources, and multilingual flag.
28+
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
29+
* @returns A Promise resolving to a `SpeechToTextModule` instance.
2330
*
24-
* @param model - Configuration object containing model sources.
25-
* @param onDownloadProgressCallback - Optional callback to monitor download progress.
31+
* @example
32+
* ```ts
33+
* import { SpeechToTextModule, WHISPER_TINY_EN } from 'react-native-executorch';
34+
* const stt = await SpeechToTextModule.fromModelName(WHISPER_TINY_EN);
35+
* ```
2636
*/
27-
public async load(
28-
model: SpeechToTextModelConfig,
29-
onDownloadProgressCallback: (progress: number) => void = () => {}
30-
) {
37+
static async fromModelName(
38+
namedSources: SpeechToTextModelConfig,
39+
onDownloadProgress: (progress: number) => void = () => {}
40+
): Promise<SpeechToTextModule> {
41+
const instance = new SpeechToTextModule();
3142
try {
32-
this.modelConfig = model;
43+
await instance.internalLoad(namedSources, onDownloadProgress);
44+
return instance;
45+
} catch (error) {
46+
Logger.error('Load failed:', error);
47+
throw parseUnknownError(error);
48+
}
49+
}
3350

34-
const tokenizerLoadPromise = ResourceFetcher.fetch(
35-
undefined,
36-
model.tokenizerSource
37-
);
38-
const encoderDecoderPromise = ResourceFetcher.fetch(
39-
onDownloadProgressCallback,
40-
model.encoderSource,
41-
model.decoderSource
42-
);
43-
const [tokenizerSources, encoderDecoderResults] = await Promise.all([
44-
tokenizerLoadPromise,
45-
encoderDecoderPromise,
46-
]);
47-
const encoderSource = encoderDecoderResults?.[0];
48-
const decoderSource = encoderDecoderResults?.[1];
49-
if (!encoderSource || !decoderSource || !tokenizerSources) {
50-
throw new RnExecutorchError(
51-
RnExecutorchErrorCode.DownloadInterrupted,
52-
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
53-
);
54-
}
55-
this.nativeModule = await global.loadSpeechToText(
51+
/**
52+
* Creates a Speech to Text instance with user-provided model binaries.
53+
* Use this when working with a custom-exported STT model.
54+
* Internally uses `'custom'` as the model name for telemetry.
55+
*
56+
* @param encoderSource - A fetchable resource pointing to the encoder model binary.
57+
* @param decoderSource - A fetchable resource pointing to the decoder model binary.
58+
* @param tokenizerSource - A fetchable resource pointing to the tokenizer file.
59+
* @param isMultilingual - Whether the model supports multiple languages.
60+
* @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
61+
* @returns A Promise resolving to a `SpeechToTextModule` instance.
62+
*/
63+
static fromCustomModel(
64+
encoderSource: ResourceSource,
65+
decoderSource: ResourceSource,
66+
tokenizerSource: ResourceSource,
67+
isMultilingual: boolean,
68+
onDownloadProgress: (progress: number) => void = () => {}
69+
): Promise<SpeechToTextModule> {
70+
return SpeechToTextModule.fromModelName(
71+
{
72+
modelName: 'custom' as SpeechToTextModelName,
5673
encoderSource,
5774
decoderSource,
58-
tokenizerSources[0]!
75+
tokenizerSource,
76+
isMultilingual,
77+
},
78+
onDownloadProgress
79+
);
80+
}
81+
82+
private async internalLoad(
83+
model: SpeechToTextModelConfig,
84+
onDownloadProgressCallback: (progress: number) => void = () => {}
85+
) {
86+
this.modelConfig = model;
87+
88+
const tokenizerLoadPromise = ResourceFetcher.fetch(
89+
undefined,
90+
model.tokenizerSource
91+
);
92+
const encoderDecoderPromise = ResourceFetcher.fetch(
93+
onDownloadProgressCallback,
94+
model.encoderSource,
95+
model.decoderSource
96+
);
97+
const [tokenizerSources, encoderDecoderResults] = await Promise.all([
98+
tokenizerLoadPromise,
99+
encoderDecoderPromise,
100+
]);
101+
const encoderSource = encoderDecoderResults?.[0];
102+
const decoderSource = encoderDecoderResults?.[1];
103+
if (!encoderSource || !decoderSource || !tokenizerSources) {
104+
throw new RnExecutorchError(
105+
RnExecutorchErrorCode.DownloadInterrupted,
106+
'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.'
59107
);
60-
} catch (error) {
61-
Logger.error('Load failed:', error);
62-
throw parseUnknownError(error);
63108
}
109+
this.nativeModule = await global.loadSpeechToText(
110+
encoderSource,
111+
decoderSource,
112+
tokenizerSources[0]!
113+
);
64114
}
65115

66116
/**

packages/react-native-executorch/src/types/stt.ts

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,20 @@
11
import { ResourceSource } from './common';
22
import { RnExecutorchError } from '../errors/errorUtils';
33

4+
/**
5+
* Union of all built-in Speech-to-Text model names.
6+
*
7+
* @category Types
8+
*/
9+
export type SpeechToTextModelName =
10+
| 'whisper-tiny-en'
11+
| 'whisper-tiny-en-quantized'
12+
| 'whisper-base-en'
13+
| 'whisper-small-en'
14+
| 'whisper-tiny'
15+
| 'whisper-base'
16+
| 'whisper-small';
17+
418
/**
519
* Configuration for Speech to Text model.
620
*
@@ -261,6 +275,12 @@ export interface TranscriptionResult {
261275
* @category Types
262276
*/
263277
export interface SpeechToTextModelConfig {
278+
/**
279+
* The built-in model name (e.g. `'whisper-tiny-en'`). Used for telemetry and hook reload triggers.
280+
* Pass one of the pre-built STT constants (e.g. `WHISPER_TINY_EN`) to populate all required fields.
281+
*/
282+
modelName: SpeechToTextModelName;
283+
264284
/**
265285
* A boolean flag indicating whether the model supports multiple languages.
266286
*/

0 commit comments

Comments (0)