diff --git a/.cspell-wordlist.txt b/.cspell-wordlist.txt index b9179fa09c..7746a87919 100644 --- a/.cspell-wordlist.txt +++ b/.cspell-wordlist.txt @@ -128,4 +128,15 @@ detr metaprogramming ktlint lefthook -espeak \ No newline at end of file +espeak +NCHW +həlˈO +wˈɜɹld +mˈæn +dˈʌzᵊnt +tɹˈʌst +hɪmsˈɛlf +nˈɛvəɹ +ɹˈiᵊli +ˈɛniwˌʌn +ˈɛls diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md index 5394f9a3d7..280509f7ac 100644 --- a/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md +++ b/docs/docs/04-typescript-api/01-natural-language-processing/LLMModule.md @@ -15,14 +15,13 @@ TypeScript API implementation of the [useLLM](../../03-hooks/01-natural-language ```typescript import { LLMModule, LLAMA3_2_1B_QLORA } from 'react-native-executorch'; -// Creating an instance -const llm = new LLMModule({ - tokenCallback: (token) => console.log(token), - messageHistoryCallback: (messages) => console.log(messages), -}); - -// Loading the model -await llm.load(LLAMA3_2_1B_QLORA, (progress) => console.log(progress)); +// Creating an instance and loading the model +const llm = await LLMModule.fromModelName( + LLAMA3_2_1B_QLORA, + (progress) => console.log(progress), + (token) => console.log(token), + (messages) => console.log(messages), +); // Running the model - returns the generated response const response = await llm.sendMessage('Hello, World!'); @@ -41,30 +40,26 @@ All methods of `LLMModule` are explained in details here: [LLMModule API Referen ## Loading the model -To create a new instance of `LLMModule`, use the [constructor](../../06-api-reference/classes/LLMModule.md#constructor) with optional callbacks: - -- [`tokenCallback`](../../06-api-reference/classes/LLMModule.md#tokencallback) - Function called on every generated token. 
- -- [`messageHistoryCallback`](../../06-api-reference/classes/LLMModule.md#messagehistorycallback) - Function called on every finished message. - -Then, to load the model, use the [`load`](../../06-api-reference/classes/LLMModule.md#load) method. It accepts an object with the following fields: +Use the static [`fromModelName`](../../06-api-reference/classes/LLMModule.md#frommodelname) factory method: -- [`model`](../../06-api-reference/classes/LLMModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/LLMModule.md#modelsource) - The location of the used model. - - - [`tokenizerSource`](../../06-api-reference/classes/LLMModule.md#tokenizersource) - The location of the used tokenizer. - - - [`tokenizerConfigSource`](../../06-api-reference/classes/LLMModule.md#tokenizerconfigsource) - The location of the used tokenizer config. +```typescript +const llm = await LLMModule.fromModelName( + LLAMA3_2_3B, // model config constant + onDownloadProgress, // optional, progress 0–1 + tokenCallback, // optional, called on every token + messageHistoryCallback // optional, called when generation finishes +); +``` -- [`onDownloadProgressCallback`](../../06-api-reference/classes/LLMModule.md#ondownloadprogresscallback) - Callback to track download progress. +The model config object contains `modelSource`, `tokenizerSource`, `tokenizerConfigSource`, and optional `capabilities`. Pass one of the built-in constants (e.g. `LLAMA3_2_3B`) or construct it manually. -This method returns a promise, which can resolve to an error or void. +This method returns a promise resolving to an `LLMModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. 
## Listening for download progress -To subscribe to the download progress event, you can pass the [`onDownloadProgressCallback`](../../06-api-reference/classes/LLMModule.md#ondownloadprogresscallback) function to the [`load`](../../06-api-reference/classes/LLMModule.md#load) method. This function is called whenever the download progress changes. +To subscribe to the download progress event, you can pass the `onDownloadProgress` callback as the second argument to [`fromModelName`](../../06-api-reference/classes/LLMModule.md#frommodelname). This function is called whenever the download progress changes. ## Running the model @@ -116,25 +111,26 @@ To configure model (i.e. change system prompt, load initial conversation history ## Vision-Language Models (VLM) -Some models support multimodal input — text and images together. To use them, pass `capabilities` in the model object when calling [`load`](../../06-api-reference/classes/LLMModule.md#load): +Some models support multimodal input — text and images together. To use them, pass `capabilities` in the model object when calling [`fromModelName`](../../06-api-reference/classes/LLMModule.md#frommodelname): ```typescript import { LLMModule, LFM2_VL_1_6B_QUANTIZED } from 'react-native-executorch'; -const llm = new LLMModule({ - tokenCallback: (token) => console.log(token), -}); - -await llm.load(LFM2_VL_1_6B_QUANTIZED); +const llm = await LLMModule.fromModelName( + LFM2_VL_1_6B_QUANTIZED, + undefined, + (token) => console.log(token) +); ``` The `capabilities` field is already set on the model constant. 
You can also construct the model object explicitly: ```typescript -await llm.load({ - modelSource: '...', - tokenizerSource: '...', - tokenizerConfigSource: '...', +const llm = await LLMModule.fromModelName({ + modelName: 'lfm2.5-vl-1.6b-quantized', + modelSource: require('./path/to/model.pte'), + tokenizerSource: require('./path/to/tokenizer.json'), + tokenizerConfigSource: require('./path/to/tokenizer_config.json'), capabilities: ['vision'], }); ``` @@ -161,6 +157,27 @@ const chat: Message[] = [ const response = await llm.generate(chat); ``` +## Using a custom model + +Use [`fromCustomModel`](../../06-api-reference/classes/LLMModule.md#fromcustommodel) to load your own exported LLM instead of a built-in preset: + +```typescript +import { LLMModule } from 'react-native-executorch'; + +const llm = await LLMModule.fromCustomModel( + 'https://example.com/model.pte', + 'https://example.com/tokenizer.json', + 'https://example.com/tokenizer_config.json', + (progress) => console.log(progress), + (token) => console.log(token), + (messages) => console.log(messages) +); +``` + +### Required model contract + +The `.pte` model binary must be exported following the [ExecuTorch LLM export process](https://docs.pytorch.org/executorch/1.1/llm/export-llm.html). The native runner expects the standard ExecuTorch text-generation interface — KV-cache management, prefill/decode phases, and logit sampling are all handled by the runtime. + ## Deleting the model from memory To delete the model from memory, you can use the [`delete`](../../06-api-reference/classes/LLMModule.md#delete) method. 
diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md index 5d2351e66e..d4d8897e7c 100644 --- a/docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md +++ b/docs/docs/04-typescript-api/01-natural-language-processing/SpeechToTextModule.md @@ -14,10 +14,12 @@ TypeScript API implementation of the [useSpeechToText](../../03-hooks/01-natural ```typescript import { SpeechToTextModule, WHISPER_TINY_EN } from 'react-native-executorch'; -const model = new SpeechToTextModule(); -await model.load(WHISPER_TINY_EN, (progress) => { - console.log(progress); -}); +const model = await SpeechToTextModule.fromModelName( + WHISPER_TINY_EN, + (progress) => { + console.log(progress); + } +); // Standard transcription (returns string) const text = await model.transcribe(waveform); @@ -40,18 +42,17 @@ All methods of `SpeechToTextModule` are explained in details here: [`SpeechToTex ## Loading the model -Create an instance of [`SpeechToTextModule`](../../06-api-reference/classes/SpeechToTextModule.md) and use the [`load`](../../06-api-reference/classes/SpeechToTextModule.md#load) method. It accepts an object with the following fields: - -- [`model`](../../06-api-reference/classes/SpeechToTextModule.md#model) - Object containing: - - [`isMultilingual`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#ismultilingual) - Flag indicating if model is multilingual. +Use the static [`fromModelName`](../../06-api-reference/classes/SpeechToTextModule.md#frommodelname) factory method. It accepts an object with the following fields: - - [`modelSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#modelsource) - The location of the used model (bundled encoder + decoder functionality). +- [`isMultilingual`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#ismultilingual) - Flag indicating if model is multilingual. 
+- [`modelSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#modelsource) - The location of the used model (bundled encoder + decoder functionality). +- [`tokenizerSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#tokenizersource) - The location of the used tokenizer. - - [`tokenizerSource`](../../06-api-reference/interfaces/SpeechToTextModelConfig.md#tokenizersource) - The location of the used tokenizer. +And an optional second argument: -- [`onDownloadProgressCallback`](../../06-api-reference/classes/SpeechToTextModule.md#ondownloadprogresscallback) - Callback to track download progress. +- `onDownloadProgress` - Callback to track download progress. -This method returns a promise, which can resolve to an error or void. +This method returns a promise resolving to a `SpeechToTextModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. @@ -66,10 +67,12 @@ If you aim to obtain a transcription in other languages than English, use the mu ```typescript import { SpeechToTextModule, WHISPER_TINY } from 'react-native-executorch'; -const model = new SpeechToTextModule(); -await model.load(WHISPER_TINY, (progress) => { - console.log(progress); -}); +const model = await SpeechToTextModule.fromModelName( + WHISPER_TINY, + (progress) => { + console.log(progress); + } +); const transcription = await model.transcribe(spanishAudio, { language: 'es' }); ``` @@ -121,10 +124,12 @@ import * as FileSystem from 'expo-file-system'; const transcribeAudio = async () => { // Initialize with the model config - const model = new SpeechToTextModule(); - await model.load(WHISPER_TINY_EN, (progress) => { - console.log(progress); - }); + const model = await SpeechToTextModule.fromModelName( + WHISPER_TINY_EN, + (progress) => { + console.log(progress); + } + ); // Download the audio file const { uri } = await FileSystem.downloadAsync( @@ -163,10 +168,12 @@ import { 
SpeechToTextModule, WHISPER_TINY_EN } from 'react-native-executorch'; import { AudioManager, AudioRecorder } from 'react-native-audio-api'; // Load the model -const model = new SpeechToTextModule(); -await model.load(WHISPER_TINY_EN, (progress) => { - console.log(progress); -}); +const model = await SpeechToTextModule.fromModelName( + WHISPER_TINY_EN, + (progress) => { + console.log(progress); + } +); // Configure audio session AudioManager.setAudioSessionOptions({ diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md index 456a62c305..054299290b 100644 --- a/docs/docs/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md +++ b/docs/docs/04-typescript-api/01-natural-language-processing/TextEmbeddingsModule.md @@ -17,11 +17,9 @@ import { ALL_MINILM_L6_V2, } from 'react-native-executorch'; -// Creating an instance -const textEmbeddingsModule = new TextEmbeddingsModule(); - -// Loading the model -await textEmbeddingsModule.load(ALL_MINILM_L6_V2); +// Creating an instance and loading the model +const textEmbeddingsModule = + await TextEmbeddingsModule.fromModelName(ALL_MINILM_L6_V2); // Running the model const embedding = await textEmbeddingsModule.forward('Hello World!'); @@ -33,15 +31,12 @@ All methods of `TextEmbeddingsModule` are explained in details here: [`TextEmbed ## Loading the model -To load the model, use the [`load`](../../06-api-reference/classes/TextEmbeddingsModule.md#load) method. It accepts an object: - -- [`model`](../../06-api-reference/classes/TextEmbeddingsModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/TextEmbeddingsModule.md#modelsource) - Location of the used model. - - [`tokenizerSource`](../../06-api-reference/classes/TextEmbeddingsModule.md#tokenizersource) - Location of the used tokenizer. 
+Use the static [`fromModelName`](../../06-api-reference/classes/TextEmbeddingsModule.md#frommodelname) factory method. It accepts a model config object (e.g. `ALL_MINILM_L6_V2`) containing: -- [`onDownloadProgressCallback`](../../06-api-reference/classes/TextEmbeddingsModule.md#ondownloadprogresscallback) - Callback to track download progress. +- [`modelSource`](../../06-api-reference/classes/TextEmbeddingsModule.md#modelsource) - Location of the used model. +- [`tokenizerSource`](../../06-api-reference/classes/TextEmbeddingsModule.md#tokenizersource) - Location of the used tokenizer. -This method returns a promise, which can resolve to an error or void. +And an optional `onDownloadProgress` callback. It returns a promise resolving to a `TextEmbeddingsModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/TextToSpeechModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/TextToSpeechModule.md index 53bde1685e..00fd04f53b 100644 --- a/docs/docs/04-typescript-api/01-natural-language-processing/TextToSpeechModule.md +++ b/docs/docs/04-typescript-api/01-natural-language-processing/TextToSpeechModule.md @@ -19,15 +19,9 @@ import { KOKORO_VOICE_AF_HEART, } from 'react-native-executorch'; -const model = new TextToSpeechModule(); -await model.load( - { - model: KOKORO_MEDIUM, - voice: KOKORO_VOICE_AF_HEART, - }, - (progress) => { - console.log(progress); - } +const model = await TextToSpeechModule.fromModelName( + { model: KOKORO_MEDIUM, voice: KOKORO_VOICE_AF_HEART }, + (progress) => console.log(progress) ); await model.forward(text, 1.0); @@ -39,15 +33,15 @@ All methods of `TextToSpeechModule` are explained in details here: [`TextToSpeec ## Loading the model -To initialize the module, create an instance and call the [`load`](../../06-api-reference/classes/TextToSpeechModule.md#load) method with 
the following parameters: +Use the static [`fromModelName`](../../06-api-reference/classes/TextToSpeechModule.md#frommodelname) factory method with the following parameters: -- [`config`](../../06-api-reference/classes/TextToSpeechModule.md#config) - Object containing: - - [`model`](../../06-api-reference/interfaces/TextToSpeechConfig.md#model) - Model configuration. - - [`voice`](../../06-api-reference/interfaces/TextToSpeechConfig.md#voice) - Voice configuration. +- [`config`](../../06-api-reference/interfaces/TextToSpeechConfig.md) - Object containing: + - [`model`](../../06-api-reference/interfaces/TextToSpeechConfig.md#model) - Model configuration (e.g. `KOKORO_MEDIUM`). + - [`voice`](../../06-api-reference/interfaces/TextToSpeechConfig.md#voice) - Voice configuration (e.g. `KOKORO_VOICE_AF_HEART`). -- [`onDownloadProgressCallback`](../../06-api-reference/classes/TextToSpeechModule.md#ondownloadprogresscallback) - Callback to track download progress. +- [`onDownloadProgress`](../../06-api-reference/classes/TextToSpeechModule.md#frommodelname) - Optional callback to track download progress (value between 0 and 1). -This method returns a promise that resolves once the assets are downloaded and loaded into memory. +This method returns a promise that resolves to a `TextToSpeechModule` instance once the assets are downloaded and loaded into memory. For more information on resource sources, see [loading models](../../01-fundamentals/02-loading-models.md). 
@@ -83,15 +77,13 @@ import { } from 'react-native-executorch'; import { AudioContext } from 'react-native-audio-api'; -const tts = new TextToSpeechModule(); +const tts = await TextToSpeechModule.fromModelName({ + model: KOKORO_MEDIUM, + voice: KOKORO_VOICE_AF_HEART, +}); const audioContext = new AudioContext({ sampleRate: 24000 }); try { - await tts.load({ - model: KOKORO_MEDIUM, - voice: KOKORO_VOICE_AF_HEART, - }); - const waveform = await tts.forward('Hello from ExecuTorch!', 1.0); // Create audio buffer and play @@ -117,11 +109,12 @@ import { } from 'react-native-executorch'; import { AudioContext } from 'react-native-audio-api'; -const tts = new TextToSpeechModule(); +const tts = await TextToSpeechModule.fromModelName({ + model: KOKORO_MEDIUM, + voice: KOKORO_VOICE_AF_HEART, +}); const audioContext = new AudioContext({ sampleRate: 24000 }); -await tts.load({ model: KOKORO_MEDIUM, voice: KOKORO_VOICE_AF_HEART }); - try { for await (const chunk of tts.stream({ text: 'This is a streaming test, with a sample input.', @@ -155,9 +148,7 @@ import { KOKORO_VOICE_AF_HEART, } from 'react-native-executorch'; -const tts = new TextToSpeechModule(); - -await tts.load({ +const tts = await TextToSpeechModule.fromModelName({ model: KOKORO_MEDIUM, voice: KOKORO_VOICE_AF_HEART, }); diff --git a/docs/docs/04-typescript-api/01-natural-language-processing/VADModule.md b/docs/docs/04-typescript-api/01-natural-language-processing/VADModule.md index aa3d14455d..d32a7c3920 100644 --- a/docs/docs/04-typescript-api/01-natural-language-processing/VADModule.md +++ b/docs/docs/04-typescript-api/01-natural-language-processing/VADModule.md @@ -14,10 +14,9 @@ TypeScript API implementation of the [useVAD](../../03-hooks/01-natural-language ```typescript import { VADModule, FSMN_VAD } from 'react-native-executorch'; -const model = new VADModule(); -await model.load(FSMN_VAD, (progress) => { - console.log(progress); -}); +const model = await VADModule.fromModelName(FSMN_VAD, (progress) => + 
console.log(progress) +); await model.forward(waveform); ``` @@ -28,14 +27,15 @@ All methods of `VADModule` are explained in details here: [`VADModule` API Refer ## Loading the model -To initialize the module, create an instance and call the [`load`](../../06-api-reference/classes/VADModule.md#load) method with the following parameters: +To create a ready-to-use instance, call the static [`fromModelName`](../../06-api-reference/classes/VADModule.md#frommodelname) factory with the following parameters: -- [`model`](../../06-api-reference/classes/VADModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/VADModule.md#modelsource) - Location of the used model. +- `namedSources` - Object containing: + - `modelName` - Model name identifier. + - `modelSource` - Location of the model binary. -- [`onDownloadProgressCallback`](../../06-api-reference/classes/VADModule.md#ondownloadprogresscallback) - Callback to track download progress. +- `onDownloadProgress` - Optional callback to track download progress (value between 0 and 1). -This method returns a promise, which can resolve to an error or void. +The factory returns a promise that resolves to a loaded `VADModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. 
diff --git a/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md b/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md index df94656e78..4234fa865d 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md @@ -19,11 +19,9 @@ import { const imageUri = 'path/to/image.png'; -// Creating an instance -const classificationModule = new ClassificationModule(); - -// Loading the model -await classificationModule.load(EFFICIENTNET_V2_S); +// Creating and loading the module +const classificationModule = + await ClassificationModule.fromModelName(EFFICIENTNET_V2_S); // Running the model const classesWithProbabilities = await classificationModule.forward(imageUri); @@ -35,14 +33,15 @@ All methods of `ClassificationModule` are explained in details here: [`Classific ## Loading the model -To initialize the module, create an instance and call the [`load`](../../06-api-reference/classes/ClassificationModule.md#load) method with the following parameters: +To create a ready-to-use instance, call the static [`fromModelName`](../../06-api-reference/classes/ClassificationModule.md#frommodelname) factory with the following parameters: -- [`model`](../../06-api-reference/classes/ClassificationModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/ClassificationModule.md#modelsource) - Location of the used model. +- `namedSources` - Object containing: + - `modelName` - Model name identifier. + - `modelSource` - Location of the model binary. -- [`onDownloadProgressCallback`](../../06-api-reference/classes/ClassificationModule.md#ondownloadprogresscallback) - Callback to track download progress. +- `onDownloadProgress` - Optional callback to track download progress (value between 0 and 1). -This method returns a promise, which can resolve to an error or void. 
+The factory returns a promise that resolves to a loaded `ClassificationModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. diff --git a/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md b/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md index 8c66917544..7388416334 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ImageEmbeddingsModule.md @@ -17,11 +17,10 @@ import { CLIP_VIT_BASE_PATCH32_IMAGE, } from 'react-native-executorch'; -// Creating an instance -const imageEmbeddingsModule = new ImageEmbeddingsModule(); - -// Loading the model -await imageEmbeddingsModule.load(CLIP_VIT_BASE_PATCH32_IMAGE); +// Creating and loading the module +const imageEmbeddingsModule = await ImageEmbeddingsModule.fromModelName( + CLIP_VIT_BASE_PATCH32_IMAGE +); // Running the model const embedding = await imageEmbeddingsModule.forward( @@ -35,14 +34,15 @@ All methods of `ImageEmbeddingsModule` are explained in details here: [`ImageEmb ## Loading the model -To initialize the module, create an instance and call the [`load`](../../06-api-reference/classes/ImageEmbeddingsModule.md#load) method with the following parameters: +To create a ready-to-use instance, call the static [`fromModelName`](../../06-api-reference/classes/ImageEmbeddingsModule.md#frommodelname) factory with the following parameters: -- [`model`](../../06-api-reference/classes/ImageEmbeddingsModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/ImageEmbeddingsModule.md#modelsource) - Location of the used model. +- `namedSources` - Object containing: + - `modelName` - Model name identifier. + - `modelSource` - Location of the model binary. 
-- [`onDownloadProgressCallback`](../../06-api-reference/classes/ImageEmbeddingsModule.md#ondownloadprogresscallback) - Callback to track download progress. +- `onDownloadProgress` - Optional callback to track download progress (value between 0 and 1). -This method returns a promise, which can resolve to an error or void. +The factory returns a promise that resolves to a loaded `ImageEmbeddingsModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. diff --git a/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md b/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md index cfcc14a054..1524859ed2 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/OCRModule.md @@ -15,11 +15,8 @@ TypeScript API implementation of the [useOCR](../../03-hooks/02-computer-vision/ import { OCRModule, OCR_ENGLISH } from 'react-native-executorch'; const imageUri = 'path/to/image.png'; -// Creating an instance -const ocrModule = new OCRModule(); - -// Loading the model -await ocrModule.load(OCR_ENGLISH); +// Creating an instance and loading the model +const ocrModule = await OCRModule.fromModelName(OCR_ENGLISH); // Running the model const detections = await ocrModule.forward(imageUri); @@ -31,16 +28,14 @@ All methods of `OCRModule` are explained in details here: [`OCRModule` API Refer ## Loading the model -To load the model, use the [`load`](../../06-api-reference/classes/OCRModule.md#load) method. It accepts an object: - -- [`model`](../../06-api-reference/classes/OCRModule.md#model) - Object containing: - - [`detectorSource`](../../06-api-reference/classes/OCRModule.md#detectorsource) - Location of the used detector. - - [`recognizerSource`](../../06-api-reference/classes/OCRModule.md#recognizersource) - Location of the used recognizer. 
-  - [`language`](../../06-api-reference/classes/OCRModule.md#recognizersource) - Language used in OCR.
+Use the static [`fromModelName`](../../06-api-reference/classes/OCRModule.md#frommodelname) factory method. It accepts a `namedSources` object (e.g. `OCR_ENGLISH`) containing:
 
-- [`onDownloadProgressCallback`](../../06-api-reference/classes/OCRModule.md#ondownloadprogresscallback) - Callback to track download progress.
+- `modelName` - Model name identifier.
+- [`detectorSource`](../../06-api-reference/classes/OCRModule.md#detectorsource) - Location of the used detector.
+- [`recognizerSource`](../../06-api-reference/classes/OCRModule.md#recognizersource) - Location of the used recognizer.
+- [`language`](../../06-api-reference/classes/OCRModule.md#language) - Language used in OCR.
 
-This method returns a promise, which can resolve to an error or void.
+And an optional `onDownloadProgress` callback. It returns a promise resolving to an `OCRModule` instance.
 
 For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page.
diff --git a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md index 1fa95b1ba6..d942eded65 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/ObjectDetectionModule.md @@ -19,11 +19,10 @@ import { const imageUri = 'path/to/image.png'; -// Creating an instance -const objectDetectionModule = new ObjectDetectionModule(); - -// Loading the model -await objectDetectionModule.load(SSDLITE_320_MOBILENET_V3_LARGE); +// Creating an instance and loading the model +const objectDetectionModule = await ObjectDetectionModule.fromModelName( + SSDLITE_320_MOBILENET_V3_LARGE +); // Running the model const detections = await objectDetectionModule.forward(imageUri); @@ -35,14 +34,7 @@ All methods of `ObjectDetectionModule` are explained in details here: [`ObjectDe ## Loading the model -To initialize the module, create an instance and call the [`load`](../../06-api-reference/classes/ObjectDetectionModule.md#load) method with the following parameters: - -- [`model`](../../06-api-reference/classes/ObjectDetectionModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/ObjectDetectionModule.md#modelsource) - Location of the used model. - -- [`onDownloadProgressCallback`](../../06-api-reference/classes/ObjectDetectionModule.md#ondownloadprogresscallback) - Callback to track download progress. - -This method returns a promise, which can resolve to an error or void. +Use the static [`fromModelName`](../../06-api-reference/classes/ObjectDetectionModule.md#frommodelname) factory method. It accepts a model config object (e.g. `SSDLITE_320_MOBILENET_V3_LARGE`) and an optional `onDownloadProgress` callback. It returns a promise resolving to an `ObjectDetectionModule` instance. 
For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. @@ -50,6 +42,36 @@ For more information on loading resources, take a look at [loading models](../.. To run the model, you can use the [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) method on the module object. It accepts one argument, which is the image. The image can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64). The method returns a promise, which can resolve either to an error or an array of [`Detection`](../../06-api-reference/interfaces/Detection.md) objects. Each object contains coordinates of the bounding box, the label of the detected object, and the confidence score. +## Using a custom model + +Use [`fromCustomModel`](../../06-api-reference/classes/ObjectDetectionModule.md#fromcustommodel) to load your own exported model binary instead of a built-in preset. + +```typescript +import { ObjectDetectionModule } from 'react-native-executorch'; + +const MyLabels = { BACKGROUND: 0, CAT: 1, DOG: 2 } as const; + +const detector = await ObjectDetectionModule.fromCustomModel( + 'https://example.com/custom_detector.pte', + { labelMap: MyLabels }, + (progress) => console.log(progress) +); +``` + +### Required model contract + +The `.pte` binary must expose a single `forward` method with the following interface: + +**Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`. H and W are read from the model's declared input shape at load time. + +**Outputs:** exactly three `float32` tensors, in this order: + +1. **Bounding boxes** — flat `[4·N]` array of `(x1, y1, x2, y2)` coordinates in model-input pixel space. +2. **Confidence scores** — flat `[N]` array of values in `[0, 1]`. +3. 
**Class indices** — flat `[N]` array of `float32`-encoded integer class indices (0-based, matching the order of entries in your `labelMap`). + +Preprocessing (resize → normalize) and postprocessing (coordinate rescaling, threshold filtering, NMS) are handled by the native runtime. + ## Managing memory The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ObjectDetectionModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ObjectDetectionModule.md#forward) after [`delete`](../../06-api-reference/classes/ObjectDetectionModule.md#delete) unless you load the module again. 
diff --git a/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md b/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md index 7ba0182ee8..bf88690bdb 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/SemanticSegmentationModule.md @@ -20,10 +20,8 @@ import { const imageUri = 'path/to/image.png'; // Creating an instance from a built-in model -const segmentation = await SemanticSegmentationModule.fromModelName({ - modelName: 'deeplab-v3', - modelSource: DEEPLAB_V3_RESNET50, -}); +const segmentation = + await SemanticSegmentationModule.fromModelName(DEEPLAB_V3_RESNET50); // Running the model const result = await segmentation.forward(imageUri); @@ -51,14 +49,14 @@ const segmentation = await SemanticSegmentationModule.fromModelName( The `config` parameter is a discriminated union — TypeScript ensures you provide the correct fields for each model name. Available built-in models: `'deeplab-v3-resnet50'`, `'deeplab-v3-resnet50-quantized'`, `'deeplab-v3-resnet101'`, `'deeplab-v3-resnet101-quantized'`, `'deeplab-v3-mobilenet-v3-large'`, `'deeplab-v3-mobilenet-v3-large-quantized'`, `'lraspp-mobilenet-v3-large'`, `'lraspp-mobilenet-v3-large-quantized'`, `'fcn-resnet50'`, `'fcn-resnet50-quantized'`, `'fcn-resnet101'`, `'fcn-resnet101-quantized'`, and `'selfie-segmentation'`. 
-### Custom models — `fromCustomConfig` +### Custom models — `fromCustomModel` -Use [`fromCustomConfig`](../../06-api-reference/classes/SemanticSegmentationModule.md#fromcustomconfig) for custom-exported segmentation models with your own label map: +Use [`fromCustomModel`](../../06-api-reference/classes/SemanticSegmentationModule.md#fromcustommodel) for custom-exported segmentation models with your own label map: ```typescript const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; -const segmentation = await SemanticSegmentationModule.fromCustomConfig( +const segmentation = await SemanticSegmentationModule.fromCustomModel( 'https://example.com/custom_model.pte', { labelMap: MyLabels, @@ -72,6 +70,16 @@ const segmentation = await SemanticSegmentationModule.fromCustomConfig( The `preprocessorConfig` is optional. If omitted, no input normalization is applied. The module instance will be typed to your custom label map — `forward` will accept and return keys from `MyLabels`. +### Required model contract + +The `.pte` binary must expose a single `forward` method with the following interface: + +**Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`. H and W are read from the model's declared input shape at load time. + +**Output:** one `float32` tensor of shape `[1, C, H_out, W_out]` (NCHW) containing raw logits — one channel per class, in the same order as the entries in your `labelMap`. For binary segmentation a single-channel output is also supported: channel 0 is treated as the foreground probability and a synthetic background channel is added automatically. + +Preprocessing (resize → normalize) and postprocessing (softmax, argmax, resize back to original dimensions) are handled by the native runtime. + For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. 
## Running the model diff --git a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md index 3f26a44bb7..0e70a24796 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/StyleTransferModule.md @@ -19,11 +19,9 @@ import { const imageUri = 'path/to/image.png'; -// Creating an instance -const styleTransferModule = new StyleTransferModule(); - -// Loading the model -await styleTransferModule.load(STYLE_TRANSFER_CANDY); +// Creating and loading the module +const styleTransferModule = + await StyleTransferModule.fromModelName(STYLE_TRANSFER_CANDY); // Running the model const generatedImageUrl = await styleTransferModule.forward(imageUri); @@ -35,14 +33,15 @@ All methods of `StyleTransferModule` are explained in details here: [`StyleTrans ## Loading the model -To load the model, create a new instance of the module and use the [`load`](../../06-api-reference/classes/StyleTransferModule.md#load) method on it. It accepts an object: +To create a ready-to-use instance, call the static [`fromModelName`](../../06-api-reference/classes/StyleTransferModule.md#frommodelname) factory with the following parameters: -- [`model`](../../06-api-reference/classes/StyleTransferModule.md#model) - Object containing: - - [`modelSource`](../../06-api-reference/classes/StyleTransferModule.md#modelsource) - Location of the used model. +- `namedSources` - Object containing: + - `modelName` - Model name identifier. + - `modelSource` - Location of the model binary. -- [`onDownloadProgressCallback`](../../06-api-reference/classes/StyleTransferModule.md#ondownloadprogresscallback) - Callback to track download progress. +- `onDownloadProgress` - Optional callback to track download progress (value between 0 and 1). -This method returns a promise, which can resolve to an error or void. 
+The factory returns a promise that resolves to a loaded `StyleTransferModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. diff --git a/docs/docs/04-typescript-api/02-computer-vision/TextToImageModule.md b/docs/docs/04-typescript-api/02-computer-vision/TextToImageModule.md index 474d48bc83..d6ade747bd 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/TextToImageModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/TextToImageModule.md @@ -19,11 +19,10 @@ import { const input = 'a castle'; -// Creating an instance -const textToImageModule = new TextToImageModule(); - -// Loading the model -await textToImageModule.load(BK_SDM_TINY_VPRED_256); +// Creating an instance and loading the model +const textToImageModule = await TextToImageModule.fromModelName( + BK_SDM_TINY_VPRED_256 +); // Running the model const image = await textToImageModule.forward(input); @@ -35,22 +34,16 @@ All methods of `TextToImageModule` are explained in details here: [`TextToImageM ## Loading the model -To load the model, use the [`load`](../../06-api-reference/classes/TextToImageModule.md#load) method. It accepts an object: - -- [`model`](../../06-api-reference/classes/TextToImageModule.md#model) - Object containing: - - [`schedulerSource`](../../06-api-reference/classes/TextToImageModule.md#schedulersource) - Location of the used scheduler. - - - [`tokenizerSource`](../../06-api-reference/classes/TextToImageModule.md#tokenizersource) - Location of the used tokenizer. - - - [`encoderSource`](../../06-api-reference/classes/TextToImageModule.md#encodersource) - Location of the used encoder. - - - [`unetSource`](../../06-api-reference/classes/TextToImageModule.md#unetsource) - Location of the used unet. - - - [`decoderSource`](../../06-api-reference/classes/TextToImageModule.md#decodersource) - Location of the used decoder. 
+Use the static [`fromModelName`](../../06-api-reference/classes/TextToImageModule.md#frommodelname) factory method. It accepts a model config object (e.g. `BK_SDM_TINY_VPRED_256`) containing: -- [`onDownloadProgressCallback`](../../06-api-reference/classes/TextToImageModule.md#ondownloadprogresscallback) - Callback to track download progress. +- [`schedulerSource`](../../06-api-reference/classes/TextToImageModule.md#schedulersource) - Location of the used scheduler. +- [`tokenizerSource`](../../06-api-reference/classes/TextToImageModule.md#tokenizersource) - Location of the used tokenizer. +- [`encoderSource`](../../06-api-reference/classes/TextToImageModule.md#encodersource) - Location of the used encoder. +- [`unetSource`](../../06-api-reference/classes/TextToImageModule.md#unetsource) - Location of the used unet. +- [`decoderSource`](../../06-api-reference/classes/TextToImageModule.md#decodersource) - Location of the used decoder. +- [`inferenceCallback`](../../06-api-reference/classes/TextToImageModule.md#inferencecallback) - Optional callback invoked at each denoising step. -This method returns a promise, which can resolve to an error or void. +And an optional `onDownloadProgress` callback. It returns a promise resolving to a `TextToImageModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. @@ -64,7 +57,7 @@ The seed value should be a positive integer. ## Listening for inference steps -To monitor the progress of image generation, you can pass an [`inferenceCallback`](../../06-api-reference/classes/TextToImageModule.md#inferencecallback) function to the [constructor](../../06-api-reference/classes/TextToImageModule.md#constructor). The callback is invoked at each denoising step (for a total of `numSteps + 1` times), yielding the current step index that can be used, for example, to display a progress bar. 
+To monitor the progress of image generation, you can pass an [`inferenceCallback`](../../06-api-reference/classes/TextToImageModule.md#inferencecallback) in the model config object passed to `fromModelName`. The callback is invoked at each denoising step (for a total of `numSteps + 1` times), yielding the current step index that can be used, for example, to display a progress bar. ## Deleting the model from memory diff --git a/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md b/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md index 94c2225187..eb47efa857 100644 --- a/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md +++ b/docs/docs/04-typescript-api/02-computer-vision/VerticalOCRModule.md @@ -16,11 +16,8 @@ import { VerticalOCRModule, OCR_ENGLISH } from 'react-native-executorch'; const imageUri = 'path/to/image.png'; -// Creating an instance -const verticalOCRModule = new VerticalOCRModule(); - -// Loading the model -await verticalOCRModule.load(OCR_ENGLISH); +// Creating an instance and loading the model +const verticalOCRModule = await VerticalOCRModule.fromModelName(OCR_ENGLISH); // Running the model const detections = await verticalOCRModule.forward(imageUri); @@ -32,18 +29,15 @@ All methods of `VerticalOCRModule` are explained in details here: [`VerticalOCRM ## Loading the model -To load the model, use the [`load`](../../06-api-reference/classes/VerticalOCRModule.md#load) method. It accepts an object: - -- [`model`](../../06-api-reference/classes/VerticalOCRModule.md#model) - Object containing: - - [`detectorSource`](../../06-api-reference/classes/VerticalOCRModule.md#detectorsource) - Location of the used detector. - - [`recognizerSource`](../../06-api-reference/classes/VerticalOCRModule.md#recognizersource) - Location of the used recognizer. - - [`language`](../../06-api-reference/classes/VerticalOCRModule.md#recognizersource) - Language used in OCR. 
- -- [`independentCharacters`](../../06-api-reference/classes/VerticalOCRModule.md#independentcharacters) - Flag indicating to either treat characters as independent or not. +Use the static [`fromModelName`](../../06-api-reference/classes/VerticalOCRModule.md#frommodelname) factory method. It accepts a `namedSources` object (e.g. `{ ...OCR_ENGLISH, independentCharacters: true }`) containing: -- [`onDownloadProgressCallback`](../../06-api-reference/classes/VerticalOCRModule.md#ondownloadprogresscallback) - Callback to track download progress. +- `modelName` - Model name identifier. +- [`detectorSource`](../../06-api-reference/classes/VerticalOCRModule.md#detectorsource) - Location of the used detector. +- [`recognizerSource`](../../06-api-reference/classes/VerticalOCRModule.md#recognizersource) - Location of the used recognizer. +- [`language`](../../06-api-reference/classes/VerticalOCRModule.md#language) - Language used in OCR. +- [`independentCharacters`](../../06-api-reference/classes/VerticalOCRModule.md#independentcharacters) - Flag indicating whether to treat characters as independent. -This method returns a promise, which can resolve to an error or void. +And an optional `onDownloadProgress` callback. It returns a promise resolving to a `VerticalOCRModule` instance. For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page. 
diff --git a/docs/docs/05-utilities/04-error-handling.md b/docs/docs/05-utilities/04-error-handling.md index 4b3f0674a0..331e57d473 100644 --- a/docs/docs/05-utilities/04-error-handling.md +++ b/docs/docs/05-utilities/04-error-handling.md @@ -16,12 +16,12 @@ import { RnExecutorchErrorCode, } from 'react-native-executorch'; -const llm = new LLMModule({ - tokenCallback: (token) => console.log(token), - messageHistoryCallback: (messages) => console.log(messages), -}); - -await llm.load(LLAMA3_2_1B_QLORA, (progress) => console.log(progress)); +const llm = await LLMModule.fromModelName( + LLAMA3_2_1B_QLORA, + (progress) => console.log(progress), + (token) => console.log(token), + (messages) => console.log(messages) +); // Try to set an invalid configuration try { diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index cf3700a7dd..472b48ef60 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -17,55 +17,61 @@ const LLAMA3_2_TOKENIZER_CONFIG = `${URL_PREFIX}-llama-3.2/${VERSION_TAG}/tokeni * @category Models - LMM */ export const LLAMA3_2_3B = { + modelName: 'llama-3.2-3b', modelSource: LLAMA3_2_3B_MODEL, tokenizerSource: LLAMA3_2_TOKENIZER, tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const LLAMA3_2_3B_QLORA = { + modelName: 'llama-3.2-3b-qlora', modelSource: LLAMA3_2_3B_QLORA_MODEL, tokenizerSource: LLAMA3_2_TOKENIZER, tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const LLAMA3_2_3B_SPINQUANT = { + modelName: 'llama-3.2-3b-spinquant', modelSource: LLAMA3_2_3B_SPINQUANT_MODEL, tokenizerSource: LLAMA3_2_TOKENIZER, tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const LLAMA3_2_1B = { + modelName: 'llama-3.2-1b', modelSource: 
LLAMA3_2_1B_MODEL, tokenizerSource: LLAMA3_2_TOKENIZER, tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const LLAMA3_2_1B_QLORA = { + modelName: 'llama-3.2-1b-qlora', modelSource: LLAMA3_2_1B_QLORA_MODEL, tokenizerSource: LLAMA3_2_TOKENIZER, tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const LLAMA3_2_1B_SPINQUANT = { + modelName: 'llama-3.2-1b-spinquant', modelSource: LLAMA3_2_1B_SPINQUANT_MODEL, tokenizerSource: LLAMA3_2_TOKENIZER, tokenizerConfigSource: LLAMA3_2_TOKENIZER_CONFIG, -}; +} as const; // QWEN 3 const QWEN3_0_6B_MODEL = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/qwen-3-0.6B/original/qwen3_0_6b_bf16.pte`; @@ -81,55 +87,61 @@ const QWEN3_TOKENIZER_CONFIG = `${URL_PREFIX}-qwen-3/${VERSION_TAG}/tokenizer_co * @category Models - LMM */ export const QWEN3_0_6B = { + modelName: 'qwen3-0.6b', modelSource: QWEN3_0_6B_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN3_0_6B_QUANTIZED = { + modelName: 'qwen3-0.6b-quantized', modelSource: QWEN3_0_6B_QUANTIZED_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN3_1_7B = { + modelName: 'qwen3-1.7b', modelSource: QWEN3_1_7B_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN3_1_7B_QUANTIZED = { + modelName: 'qwen3-1.7b-quantized', modelSource: QWEN3_1_7B_QUANTIZED_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN3_4B = { + modelName: 'qwen3-4b', modelSource: QWEN3_4B_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM 
*/ export const QWEN3_4B_QUANTIZED = { + modelName: 'qwen3-4b-quantized', modelSource: QWEN3_4B_QUANTIZED_MODEL, tokenizerSource: QWEN3_TOKENIZER, tokenizerConfigSource: QWEN3_TOKENIZER_CONFIG, -}; +} as const; // HAMMER 2.1 const HAMMER2_1_0_5B_MODEL = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/hammer-2.1-0.5B/original/hammer2_1_0_5B_bf16.pte`; @@ -145,55 +157,61 @@ const HAMMER2_1_TOKENIZER_CONFIG = `${URL_PREFIX}-hammer-2.1/${VERSION_TAG}/toke * @category Models - LMM */ export const HAMMER2_1_0_5B = { + modelName: 'hammer2.1-0.5b', modelSource: HAMMER2_1_0_5B_MODEL, tokenizerSource: HAMMER2_1_TOKENIZER, tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const HAMMER2_1_0_5B_QUANTIZED = { + modelName: 'hammer2.1-0.5b-quantized', modelSource: HAMMER2_1_0_5B_QUANTIZED_MODEL, tokenizerSource: HAMMER2_1_TOKENIZER, tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const HAMMER2_1_1_5B = { + modelName: 'hammer2.1-1.5b', modelSource: HAMMER2_1_1_5B_MODEL, tokenizerSource: HAMMER2_1_TOKENIZER, tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const HAMMER2_1_1_5B_QUANTIZED = { + modelName: 'hammer2.1-1.5b-quantized', modelSource: HAMMER2_1_1_5B_QUANTIZED_MODEL, tokenizerSource: HAMMER2_1_TOKENIZER, tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const HAMMER2_1_3B = { + modelName: 'hammer2.1-3b', modelSource: HAMMER2_1_3B_MODEL, tokenizerSource: HAMMER2_1_TOKENIZER, tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const HAMMER2_1_3B_QUANTIZED = { + modelName: 'hammer2.1-3b-quantized', modelSource: HAMMER2_1_3B_QUANTIZED_MODEL, tokenizerSource: HAMMER2_1_TOKENIZER, tokenizerConfigSource: HAMMER2_1_TOKENIZER_CONFIG, -}; +} as const; // SMOLLM2 const SMOLLM2_1_135M_MODEL = 
`${URL_PREFIX}-smolLm-2/${VERSION_TAG}/smolLm-2-135M/original/smolLm2_135M_bf16.pte`; @@ -209,55 +227,61 @@ const SMOLLM2_1_TOKENIZER_CONFIG = `${URL_PREFIX}-smolLm-2/${VERSION_TAG}/tokeni * @category Models - LMM */ export const SMOLLM2_1_135M = { + modelName: 'smollm2.1-135m', modelSource: SMOLLM2_1_135M_MODEL, tokenizerSource: SMOLLM2_1_TOKENIZER, tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const SMOLLM2_1_135M_QUANTIZED = { + modelName: 'smollm2.1-135m-quantized', modelSource: SMOLLM2_1_135M_QUANTIZED_MODEL, tokenizerSource: SMOLLM2_1_TOKENIZER, tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const SMOLLM2_1_360M = { + modelName: 'smollm2.1-360m', modelSource: SMOLLM2_1_360M_MODEL, tokenizerSource: SMOLLM2_1_TOKENIZER, tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const SMOLLM2_1_360M_QUANTIZED = { + modelName: 'smollm2.1-360m-quantized', modelSource: SMOLLM2_1_360M_QUANTIZED_MODEL, tokenizerSource: SMOLLM2_1_TOKENIZER, tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const SMOLLM2_1_1_7B = { + modelName: 'smollm2.1-1.7b', modelSource: SMOLLM2_1_1_7B_MODEL, tokenizerSource: SMOLLM2_1_TOKENIZER, tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const SMOLLM2_1_1_7B_QUANTIZED = { + modelName: 'smollm2.1-1.7b-quantized', modelSource: SMOLLM2_1_1_7B_QUANTIZED_MODEL, tokenizerSource: SMOLLM2_1_TOKENIZER, tokenizerConfigSource: SMOLLM2_1_TOKENIZER_CONFIG, -}; +} as const; // QWEN 2.5 const QWEN2_5_0_5B_MODEL = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/qwen-2.5-0.5B/original/qwen2_5_0_5b_bf16.pte`; @@ -273,55 +297,61 @@ const QWEN2_5_TOKENIZER_CONFIG = `${URL_PREFIX}-qwen-2.5/${VERSION_TAG}/tokenize * @category Models - LMM */ export const QWEN2_5_0_5B = { + modelName: 
'qwen2.5-0.5b', modelSource: QWEN2_5_0_5B_MODEL, tokenizerSource: QWEN2_5_TOKENIZER, tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN2_5_0_5B_QUANTIZED = { + modelName: 'qwen2.5-0.5b-quantized', modelSource: QWEN2_5_0_5B_QUANTIZED_MODEL, tokenizerSource: QWEN2_5_TOKENIZER, tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN2_5_1_5B = { + modelName: 'qwen2.5-1.5b', modelSource: QWEN2_5_1_5B_MODEL, tokenizerSource: QWEN2_5_TOKENIZER, tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN2_5_1_5B_QUANTIZED = { + modelName: 'qwen2.5-1.5b-quantized', modelSource: QWEN2_5_1_5B_QUANTIZED_MODEL, tokenizerSource: QWEN2_5_TOKENIZER, tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN2_5_3B = { + modelName: 'qwen2.5-3b', modelSource: QWEN2_5_3B_MODEL, tokenizerSource: QWEN2_5_TOKENIZER, tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const QWEN2_5_3B_QUANTIZED = { + modelName: 'qwen2.5-3b-quantized', modelSource: QWEN2_5_3B_QUANTIZED_MODEL, tokenizerSource: QWEN2_5_TOKENIZER, tokenizerConfigSource: QWEN2_5_TOKENIZER_CONFIG, -}; +} as const; // PHI 4 const PHI_4_MINI_4B_MODEL = `${URL_PREFIX}-phi-4-mini/${VERSION_TAG}/original/phi-4-mini_bf16.pte`; @@ -333,19 +363,21 @@ const PHI_4_MINI_TOKENIZER_CONFIG = `${URL_PREFIX}-phi-4-mini/${VERSION_TAG}/tok * @category Models - LMM */ export const PHI_4_MINI_4B = { + modelName: 'phi-4-mini-4b', modelSource: PHI_4_MINI_4B_MODEL, tokenizerSource: PHI_4_MINI_TOKENIZER, tokenizerConfigSource: PHI_4_MINI_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const PHI_4_MINI_4B_QUANTIZED = { + modelName: 'phi-4-mini-4b-quantized', modelSource: PHI_4_MINI_4B_QUANTIZED_MODEL, tokenizerSource: PHI_4_MINI_TOKENIZER, 
tokenizerConfigSource: PHI_4_MINI_TOKENIZER_CONFIG, -}; +} as const; // LFM2.5-1.2B-Instruct const LFM2_5_1_2B_INSTRUCT_MODEL = `${URL_PREFIX}-lfm2.5-1.2B-instruct/${NEXT_VERSION_TAG}/original/lfm2_5_1_2b_fp16.pte`; @@ -357,19 +389,21 @@ const LFM2_5_1_2B_TOKENIZER_CONFIG = `${URL_PREFIX}-lfm2.5-1.2B-instruct/${NEXT_ * @category Models - LMM */ export const LFM2_5_1_2B_INSTRUCT = { + modelName: 'lfm2.5-1.2b-instruct', modelSource: LFM2_5_1_2B_INSTRUCT_MODEL, tokenizerSource: LFM2_5_1_2B_TOKENIZER, tokenizerConfigSource: LFM2_5_1_2B_TOKENIZER_CONFIG, -}; +} as const; /** * @category Models - LMM */ export const LFM2_5_1_2B_INSTRUCT_QUANTIZED = { + modelName: 'lfm2.5-1.2b-instruct-quantized', modelSource: LFM2_5_1_2B_INSTRUCT_QUANTIZED_MODEL, tokenizerSource: LFM2_5_1_2B_TOKENIZER, tokenizerConfigSource: LFM2_5_1_2B_TOKENIZER_CONFIG, -}; +} as const; // LFM2.5-VL-1.6B (Vision-Language) const LFM2_VL_1_6B_QUANTIZED_MODEL = `https://huggingface.co/software-mansion/react-native-executorch-lfm2.5-VL-1.6B/resolve/main/quantized/lfm2_5_vl_1_6b_8da4w_xnnpack.pte`; @@ -403,7 +437,7 @@ const EFFICIENTNET_V2_S_QUANTIZED_MODEL = export const EFFICIENTNET_V2_S = { modelName: 'efficientnet-v2-s', modelSource: EFFICIENTNET_V2_S_MODEL, -}; +} as const; /** * @category Models - Classification @@ -411,7 +445,7 @@ export const EFFICIENTNET_V2_S = { export const EFFICIENTNET_V2_S_QUANTIZED = { modelName: 'efficientnet-v2-s-quantized', modelSource: EFFICIENTNET_V2_S_QUANTIZED_MODEL, -}; +} as const; // Object detection const SSDLITE_320_MOBILENET_V3_LARGE_MODEL = @@ -476,7 +510,7 @@ const STYLE_TRANSFER_UDNIE_QUANTIZED_MODEL = export const STYLE_TRANSFER_CANDY = { modelName: 'style-transfer-candy', modelSource: STYLE_TRANSFER_CANDY_MODEL, -}; +} as const; /** * @category Models - Style Transfer @@ -484,7 +518,7 @@ export const STYLE_TRANSFER_CANDY = { export const STYLE_TRANSFER_CANDY_QUANTIZED = { modelName: 'style-transfer-candy-quantized', modelSource: 
STYLE_TRANSFER_CANDY_QUANTIZED_MODEL, -}; +} as const; /** * @category Models - Style Transfer @@ -492,7 +526,7 @@ export const STYLE_TRANSFER_CANDY_QUANTIZED = { export const STYLE_TRANSFER_MOSAIC = { modelName: 'style-transfer-mosaic', modelSource: STYLE_TRANSFER_MOSAIC_MODEL, -}; +} as const; /** * @category Models - Style Transfer @@ -500,7 +534,7 @@ export const STYLE_TRANSFER_MOSAIC = { export const STYLE_TRANSFER_MOSAIC_QUANTIZED = { modelName: 'style-transfer-mosaic-quantized', modelSource: STYLE_TRANSFER_MOSAIC_QUANTIZED_MODEL, -}; +} as const; /** * @category Models - Style Transfer @@ -508,7 +542,7 @@ export const STYLE_TRANSFER_MOSAIC_QUANTIZED = { export const STYLE_TRANSFER_RAIN_PRINCESS = { modelName: 'style-transfer-rain-princess', modelSource: STYLE_TRANSFER_RAIN_PRINCESS_MODEL, -}; +} as const; /** * @category Models - Style Transfer @@ -516,7 +550,7 @@ export const STYLE_TRANSFER_RAIN_PRINCESS = { export const STYLE_TRANSFER_RAIN_PRINCESS_QUANTIZED = { modelName: 'style-transfer-rain-princess-quantized', modelSource: STYLE_TRANSFER_RAIN_PRINCESS_QUANTIZED_MODEL, -}; +} as const; /** * @category Models - Style Transfer @@ -524,7 +558,7 @@ export const STYLE_TRANSFER_RAIN_PRINCESS_QUANTIZED = { export const STYLE_TRANSFER_UDNIE = { modelName: 'style-transfer-udnie', modelSource: STYLE_TRANSFER_UDNIE_MODEL, -}; +} as const; /** * @category Models - Style Transfer @@ -532,7 +566,7 @@ export const STYLE_TRANSFER_UDNIE = { export const STYLE_TRANSFER_UDNIE_QUANTIZED = { modelName: 'style-transfer-udnie-quantized', modelSource: STYLE_TRANSFER_UDNIE_QUANTIZED_MODEL, -}; +} as const; // S2T const WHISPER_TINY_EN_TOKENIZER = `${URL_PREFIX}-whisper-tiny.en/${VERSION_TAG}/tokenizer.json`; @@ -566,91 +600,91 @@ const WHISPER_SMALL_MODEL = `${URL_PREFIX}-whisper-small/${NEXT_VERSION_TAG}/xnn * @category Models - Speech To Text */ export const WHISPER_TINY_EN = { - type: 'whisper' as const, + modelName: 'whisper-tiny-en', isMultilingual: false, modelSource: 
WHISPER_TINY_EN_MODEL, tokenizerSource: WHISPER_TINY_EN_TOKENIZER, -}; +} as const; /** * @category Models - Speech To Text */ export const WHISPER_TINY_EN_QUANTIZED = { - type: 'whisper' as const, + modelName: 'whisper-tiny-en-quantized', isMultilingual: false, modelSource: WHISPER_TINY_EN_QUANTIZED_MODEL, tokenizerSource: WHISPER_TINY_EN_QUANTIZED_TOKENIZER, -}; +} as const; /** * @category Models - Speech To Text */ export const WHISPER_BASE_EN = { - type: 'whisper' as const, + modelName: 'whisper-base-en', isMultilingual: false, modelSource: WHISPER_BASE_EN_MODEL, tokenizerSource: WHISPER_BASE_EN_TOKENIZER, -}; +} as const; /** * @category Models - Speech To Text */ export const WHISPER_BASE_EN_QUANTIZED = { - type: 'whisper' as const, + modelName: 'whisper-base-en-quantized', isMultilingual: false, modelSource: WHISPER_BASE_EN_QUANTIZED_MODEL, tokenizerSource: WHISPER_BASE_EN_QUANTIZED_TOKENIZER, -}; +} as const; /** * @category Models - Speech To Text */ export const WHISPER_SMALL_EN = { - type: 'whisper' as const, + modelName: 'whisper-small-en', isMultilingual: false, modelSource: WHISPER_SMALL_EN_MODEL, tokenizerSource: WHISPER_SMALL_EN_TOKENIZER, -}; +} as const; /** * @category Models - Speech To Text */ export const WHISPER_SMALL_EN_QUANTIZED = { - type: 'whisper' as const, + modelName: 'whisper-small-en-quantized', isMultilingual: false, modelSource: WHISPER_SMALL_EN_QUANTIZED_MODEL, tokenizerSource: WHISPER_SMALL_EN_QUANTIZED_TOKENIZER, -}; +} as const; /** * @category Models - Speech To Text */ export const WHISPER_TINY = { - type: 'whisper' as const, + modelName: 'whisper-tiny', isMultilingual: true, modelSource: WHISPER_TINY_MODEL, tokenizerSource: WHISPER_TINY_TOKENIZER, -}; +} as const; /** * @category Models - Speech To Text */ export const WHISPER_BASE = { - type: 'whisper' as const, + modelName: 'whisper-base', isMultilingual: true, modelSource: WHISPER_BASE_MODEL, tokenizerSource: WHISPER_BASE_TOKENIZER, -}; +} as const; /** * @category 
Models - Speech To Text */ export const WHISPER_SMALL = { - type: 'whisper' as const, + modelName: 'whisper-small', isMultilingual: true, modelSource: WHISPER_SMALL_MODEL, tokenizerSource: WHISPER_SMALL_TOKENIZER, -}; +} as const; // Semantic Segmentation const DEEPLAB_V3_RESNET50_MODEL = `${URL_PREFIX}-deeplab-v3/${NEXT_VERSION_TAG}/deeplab-v3-resnet50/xnnpack/deeplabv3_resnet50_xnnpack_fp32.pte`; @@ -763,6 +797,7 @@ export const FCN_RESNET101_QUANTIZED = { } as const; const SELFIE_SEGMENTATION_MODEL = `${URL_PREFIX}-selfie-segmentation/${NEXT_VERSION_TAG}/xnnpack/selfie-segmentation.pte`; + /** * @category Models - Semantic Segmentation */ @@ -781,7 +816,7 @@ const CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED_MODEL = `${URL_PREFIX}-clip-vit-base export const CLIP_VIT_BASE_PATCH32_IMAGE = { modelName: 'clip-vit-base-patch32-image', modelSource: CLIP_VIT_BASE_PATCH32_IMAGE_MODEL, -}; +} as const; /** * @category Models - Image Embeddings @@ -789,7 +824,7 @@ export const CLIP_VIT_BASE_PATCH32_IMAGE = { export const CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED = { modelName: 'clip-vit-base-patch32-image-quantized', modelSource: CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED_MODEL, -}; +} as const; // Text Embeddings const ALL_MINILM_L6_V2_MODEL = `${URL_PREFIX}-all-MiniLM-L6-v2/${VERSION_TAG}/all-MiniLM-L6-v2_xnnpack.pte`; @@ -807,33 +842,37 @@ const CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER = `${URL_PREFIX}-clip-vit-base-patch3 * @category Models - Text Embeddings */ export const ALL_MINILM_L6_V2 = { + modelName: 'all-minilm-l6-v2', modelSource: ALL_MINILM_L6_V2_MODEL, tokenizerSource: ALL_MINILM_L6_V2_TOKENIZER, -}; +} as const; /** * @category Models - Text Embeddings */ export const ALL_MPNET_BASE_V2 = { + modelName: 'all-mpnet-base-v2', modelSource: ALL_MPNET_BASE_V2_MODEL, tokenizerSource: ALL_MPNET_BASE_V2_TOKENIZER, -}; +} as const; /** * @category Models - Text Embeddings */ export const MULTI_QA_MINILM_L6_COS_V1 = { + modelName: 'multi-qa-minilm-l6-cos-v1', modelSource: 
MULTI_QA_MINILM_L6_COS_V1_MODEL, tokenizerSource: MULTI_QA_MINILM_L6_COS_V1_TOKENIZER, -}; +} as const; /** * @category Models - Text Embeddings */ export const MULTI_QA_MPNET_BASE_DOT_V1 = { + modelName: 'multi-qa-mpnet-base-dot-v1', modelSource: MULTI_QA_MPNET_BASE_DOT_V1_MODEL, tokenizerSource: MULTI_QA_MPNET_BASE_DOT_V1_TOKENIZER, -}; +} as const; /** * @category Models - Text Embeddings @@ -842,7 +881,7 @@ export const CLIP_VIT_BASE_PATCH32_TEXT = { modelName: 'clip-vit-base-patch32-text', modelSource: CLIP_VIT_BASE_PATCH32_TEXT_MODEL, tokenizerSource: CLIP_VIT_BASE_PATCH32_TEXT_TOKENIZER, -}; +} as const; // Image generation @@ -850,23 +889,25 @@ export const CLIP_VIT_BASE_PATCH32_TEXT = { * @category Models - Image Generation */ export const BK_SDM_TINY_VPRED_512 = { + modelName: 'bk-sdm-tiny-vpred-512', schedulerSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/scheduler/scheduler_config.json`, tokenizerSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/tokenizer/tokenizer.json`, encoderSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/text_encoder/model.pte`, unetSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/unet/model.pte`, decoderSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/vae/model.pte`, -}; +} as const; /** * @category Models - Image Generation */ export const BK_SDM_TINY_VPRED_256 = { + modelName: 'bk-sdm-tiny-vpred-256', schedulerSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/scheduler/scheduler_config.json`, tokenizerSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/tokenizer/tokenizer.json`, encoderSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/text_encoder/model.pte`, unetSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/unet/model.256.pte`, decoderSource: `${URL_PREFIX}-bk-sdm-tiny/${VERSION_TAG}/vae/model.256.pte`, -}; +} as const; // Voice Activity Detection const FSMN_VAD_MODEL = `${URL_PREFIX}-fsmn-vad/${VERSION_TAG}/xnnpack/fsmn-vad_xnnpack.pte`; @@ -875,5 +916,6 @@ const FSMN_VAD_MODEL = 
`${URL_PREFIX}-fsmn-vad/${VERSION_TAG}/xnnpack/fsmn-vad_x * @category Models - Voice Activity Detection */ export const FSMN_VAD = { + modelName: 'fsmn-vad', modelSource: FSMN_VAD_MODEL, -}; +} as const; diff --git a/packages/react-native-executorch/src/constants/ocr/models.ts b/packages/react-native-executorch/src/constants/ocr/models.ts index 864a49a85d..2a9e325c3d 100644 --- a/packages/react-native-executorch/src/constants/ocr/models.ts +++ b/packages/react-native-executorch/src/constants/ocr/models.ts @@ -21,6 +21,7 @@ const createOCRObject = ( language: keyof typeof symbols ) => { return { + modelName: `ocr-${language}` as const, detectorSource: DETECTOR_CRAFT_MODEL, recognizerSource, language, diff --git a/packages/react-native-executorch/src/constants/tts/models.ts b/packages/react-native-executorch/src/constants/tts/models.ts index 96bca3842f..a7ac8703a1 100644 --- a/packages/react-native-executorch/src/constants/tts/models.ts +++ b/packages/react-native-executorch/src/constants/tts/models.ts @@ -13,7 +13,7 @@ const KOKORO_EN_MEDIUM_MODELS_ROOT = `${KOKORO_EN_MODELS_ROOT}/medium`; * @category Models - Text to Speech */ export const KOKORO_SMALL = { - type: 'kokoro' as const, + modelName: 'kokoro-small' as const, durationPredictorSource: `${KOKORO_EN_SMALL_MODELS_ROOT}/duration_predictor.pte`, synthesizerSource: `${KOKORO_EN_SMALL_MODELS_ROOT}/synthesizer.pte`, }; @@ -24,7 +24,7 @@ export const KOKORO_SMALL = { * @category Models - Text to Speech */ export const KOKORO_MEDIUM = { - type: 'kokoro' as const, + modelName: 'kokoro-medium' as const, durationPredictorSource: `${KOKORO_EN_MEDIUM_MODELS_ROOT}/duration_predictor.pte`, synthesizerSource: `${KOKORO_EN_MEDIUM_MODELS_ROOT}/synthesizer.pte`, }; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts index bee943d675..c014d6b0ed 100644 --- 
a/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useClassification.ts @@ -1,9 +1,9 @@ -import { useModule } from '../useModule'; import { ClassificationModule } from '../../modules/computer_vision/ClassificationModule'; import { ClassificationProps, ClassificationType, } from '../../types/classification'; +import { useModuleFactory } from '../useModuleFactory'; /** * React hook for managing a Classification model instance. @@ -15,9 +15,18 @@ import { export const useClassification = ({ model, preventLoad = false, -}: ClassificationProps): ClassificationType => - useModule({ - module: ClassificationModule, - model, - preventLoad: preventLoad, - }); +}: ClassificationProps): ClassificationType => { + const { error, isReady, isGenerating, downloadProgress, runForward } = + useModuleFactory({ + factory: (config, onProgress) => + ClassificationModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); + + const forward = (imageSource: string) => + runForward((inst) => inst.forward(imageSource)); + + return { error, isReady, isGenerating, downloadProgress, forward }; +}; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts index d5d82f68f1..b4e79c9263 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useImageEmbeddings.ts @@ -3,7 +3,7 @@ import { ImageEmbeddingsProps, ImageEmbeddingsType, } from '../../types/imageEmbeddings'; -import { useModule } from '../useModule'; +import { useModuleFactory } from '../useModuleFactory'; /** * React hook for managing an Image Embeddings model instance. 
@@ -15,9 +15,18 @@ import { useModule } from '../useModule'; export const useImageEmbeddings = ({ model, preventLoad = false, -}: ImageEmbeddingsProps): ImageEmbeddingsType => - useModule({ - module: ImageEmbeddingsModule, - model, - preventLoad, - }); +}: ImageEmbeddingsProps): ImageEmbeddingsType => { + const { error, isReady, isGenerating, downloadProgress, runForward } = + useModuleFactory({ + factory: (config, onProgress) => + ImageEmbeddingsModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); + + const forward = (imageSource: string) => + runForward((inst) => inst.forward(imageSource)); + + return { error, isReady, isGenerating, downloadProgress, forward }; +}; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts index 6b28688340..473d3631ba 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useOCR.ts @@ -1,7 +1,7 @@ -import { useEffect, useState } from 'react'; -import { OCRProps, OCRType } from '../../types/ocr'; +import { useCallback, useEffect, useState } from 'react'; import { OCRController } from '../../controllers/OCRController'; import { RnExecutorchError } from '../../errors/errorUtils'; +import { OCRDetection, OCRProps, OCRType } from '../../types/ocr'; /** * React hook for managing an OCR instance. @@ -11,12 +11,12 @@ import { RnExecutorchError } from '../../errors/errorUtils'; * @returns Ready to use OCR model. 
*/ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { - const [error, setError] = useState(null); const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); + const [error, setError] = useState(null); - const [controllerInstance] = useState( + const [controller] = useState( () => new OCRController({ isReadyCallback: setIsReady, @@ -26,33 +26,37 @@ export const useOCR = ({ model, preventLoad = false }: OCRProps): OCRType => { ); useEffect(() => { + setDownloadProgress(0); + setError(null); + if (preventLoad) return; - (async () => { - await controllerInstance.load( - model.detectorSource, - model.recognizerSource, - model.language, - setDownloadProgress - ); - })(); + controller.load( + model.detectorSource, + model.recognizerSource, + model.language, + setDownloadProgress + ); return () => { - controllerInstance.delete(); + if (controller.isReady) { + controller.delete(); + } }; }, [ - controllerInstance, + controller, + model.modelName, model.detectorSource, model.recognizerSource, model.language, preventLoad, ]); - return { - error, - isReady, - isGenerating, - forward: controllerInstance.forward, - downloadProgress, - }; + const forward = useCallback( + (imageSource: string): Promise => + controller.forward(imageSource), + [controller] + ); + + return { error, isReady, isGenerating, downloadProgress, forward }; }; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts index 5333b8a322..81c81ce22f 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useObjectDetection.ts @@ -36,6 +36,7 @@ export const useObjectDetection = ({ factory: (config, onProgress) => ObjectDetectionModule.fromModelName(config, 
onProgress), config: model, + deps: [model.modelName, model.modelSource], preventLoad, }); diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts index cad249110f..dd43aaf8b3 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useSemanticSegmentation.ts @@ -39,6 +39,7 @@ export const useSemanticSegmentation = < factory: (config, onProgress) => SemanticSegmentationModule.fromModelName(config, onProgress), config: model, + deps: [model.modelName, model.modelSource], preventLoad, }); diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts index d51ff7a9ea..bfa42eee71 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useStyleTransfer.ts @@ -1,9 +1,9 @@ -import { useModule } from '../useModule'; import { StyleTransferModule } from '../../modules/computer_vision/StyleTransferModule'; import { StyleTransferProps, StyleTransferType, } from '../../types/styleTransfer'; +import { useModuleFactory } from '../useModuleFactory'; /** * React hook for managing a Style Transfer model instance. 
@@ -15,9 +15,18 @@ import { export const useStyleTransfer = ({ model, preventLoad = false, -}: StyleTransferProps): StyleTransferType => - useModule({ - module: StyleTransferModule, - model, - preventLoad: preventLoad, - }); +}: StyleTransferProps): StyleTransferType => { + const { error, isReady, isGenerating, downloadProgress, runForward } = + useModuleFactory({ + factory: (config, onProgress) => + StyleTransferModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); + + const forward = (imageSource: string) => + runForward((inst) => inst.forward(imageSource)); + + return { error, isReady, isGenerating, downloadProgress, forward }; +}; diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts b/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts index 0487e9e864..6a393ebd5b 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useTextToImage.ts @@ -20,28 +20,63 @@ export const useTextToImage = ({ const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); const [error, setError] = useState(null); - - const [module] = useState(() => new TextToImageModule(inferenceCallback)); + const [moduleInstance, setModuleInstance] = + useState(null); useEffect(() => { if (preventLoad) return; - (async () => { - setDownloadProgress(0); - setError(null); - try { - setIsReady(false); - await module.load(model, setDownloadProgress); - setIsReady(true); - } catch (err) { - setError(parseUnknownError(err)); + let active = true; + setDownloadProgress(0); + setError(null); + setIsReady(false); + + TextToImageModule.fromModelName( + { + modelName: model.modelName, + tokenizerSource: model.tokenizerSource, + schedulerSource: model.schedulerSource, + encoderSource: model.encoderSource, + unetSource: 
model.unetSource, + decoderSource: model.decoderSource, + inferenceCallback, + }, + (p) => { + if (active) setDownloadProgress(p); } - })(); + ) + .then((mod) => { + if (!active) { + mod.delete(); + return; + } + setModuleInstance((prev) => { + prev?.delete(); + return mod; + }); + setIsReady(true); + }) + .catch((err) => { + if (active) setError(parseUnknownError(err)); + }); return () => { - module.delete(); + active = false; + setModuleInstance((prev) => { + prev?.delete(); + return null; + }); }; - }, [module, model, preventLoad]); + // eslint-disable-next-line react-hooks/exhaustive-deps + }, [ + model.modelName, + model.tokenizerSource, + model.schedulerSource, + model.encoderSource, + model.unetSource, + model.decoderSource, + preventLoad, + ]); const generate = async ( input: string, @@ -49,7 +84,7 @@ export const useTextToImage = ({ numSteps?: number, seed?: number ): Promise => { - if (!isReady) + if (!isReady || !moduleInstance) throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling forward().' 
@@ -61,17 +96,17 @@ export const useTextToImage = ({ ); try { setIsGenerating(true); - return await module.forward(input, imageSize, numSteps, seed); + return await moduleInstance.forward(input, imageSize, numSteps, seed); } finally { setIsGenerating(false); } }; const interrupt = useCallback(() => { - if (isGenerating) { - module.interrupt(); + if (isGenerating && moduleInstance) { + moduleInstance.interrupt(); } - }, [module, isGenerating]); + }, [moduleInstance, isGenerating]); return { isReady, diff --git a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts index eb9d289eb8..71774198fc 100644 --- a/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts +++ b/packages/react-native-executorch/src/hooks/computer_vision/useVerticalOCR.ts @@ -1,7 +1,7 @@ -import { useEffect, useState } from 'react'; -import { OCRType, VerticalOCRProps } from '../../types/ocr'; +import { useCallback, useEffect, useState } from 'react'; import { VerticalOCRController } from '../../controllers/VerticalOCRController'; import { RnExecutorchError } from '../../errors/errorUtils'; +import { OCRDetection, OCRType, VerticalOCRProps } from '../../types/ocr'; /** * React hook for managing a Vertical OCR instance. 
@@ -15,12 +15,12 @@ export const useVerticalOCR = ({ independentCharacters = false, preventLoad = false, }: VerticalOCRProps): OCRType => { - const [error, setError] = useState(null); const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); + const [error, setError] = useState(null); - const [controllerInstance] = useState( + const [controller] = useState( () => new VerticalOCRController({ isReadyCallback: setIsReady, @@ -30,23 +30,27 @@ export const useVerticalOCR = ({ ); useEffect(() => { + setDownloadProgress(0); + setError(null); + if (preventLoad) return; - (async () => { - await controllerInstance.load( - model.detectorSource, - model.recognizerSource, - model.language, - independentCharacters, - setDownloadProgress - ); - })(); + controller.load( + model.detectorSource, + model.recognizerSource, + model.language, + independentCharacters, + setDownloadProgress + ); return () => { - controllerInstance.delete(); + if (controller.isReady) { + controller.delete(); + } }; }, [ - controllerInstance, + controller, + model.modelName, model.detectorSource, model.recognizerSource, model.language, @@ -54,11 +58,11 @@ export const useVerticalOCR = ({ preventLoad, ]); - return { - error, - isReady, - isGenerating, - forward: controllerInstance.forward, - downloadProgress, - }; + const forward = useCallback( + (imageSource: string): Promise => + controller.forward(imageSource), + [controller] + ); + + return { error, isReady, isGenerating, downloadProgress, forward }; }; diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts index 0f2d818748..13c17e00a2 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useLLM.ts @@ -78,6 +78,7 
@@ export function useLLM({ // eslint-disable-next-line react-hooks/exhaustive-deps }, [ controllerInstance, + model.modelName, model.modelSource, model.tokenizerSource, model.tokenizerConfigSource, diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts index e26b6e3c7e..5a3d77fcac 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useSpeechToText.ts @@ -24,42 +24,52 @@ export const useSpeechToText = ({ const [isReady, setIsReady] = useState(false); const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); - - const [moduleInstance, _] = useState(() => new SpeechToTextModule()); + const [moduleInstance, setModuleInstance] = + useState(null); useEffect(() => { if (preventLoad) return; - let isMounted = true; - (async () => { - setDownloadProgress(0); - setError(null); - try { - setIsReady(false); - await moduleInstance.load( - { - type: model.type, - isMultilingual: model.isMultilingual, - modelSource: model.modelSource, - tokenizerSource: model.tokenizerSource, - }, - (progress) => { - if (isMounted) setDownloadProgress(progress); - } - ); - if (isMounted) setIsReady(true); - } catch (err) { - if (isMounted) setError(parseUnknownError(err)); + let active = true; + setDownloadProgress(0); + setError(null); + setIsReady(false); + + SpeechToTextModule.fromModelName( + { + modelName: model.modelName, + isMultilingual: model.isMultilingual, + modelSource: model.modelSource, + tokenizerSource: model.tokenizerSource, + }, + (p) => { + if (active) setDownloadProgress(p); } - })(); + ) + .then((mod) => { + if (!active) { + mod.delete(); + return; + } + setModuleInstance((prev) => { + prev?.delete(); + return mod; + }); + setIsReady(true); + }) + 
.catch((err) => { + if (active) setError(parseUnknownError(err)); + }); return () => { - isMounted = false; - moduleInstance.delete(); + active = false; + setModuleInstance((prev) => { + prev?.delete(); + return null; + }); }; }, [ - moduleInstance, - model.type, + model.modelName, model.isMultilingual, model.modelSource, model.tokenizerSource, @@ -71,7 +81,7 @@ export const useSpeechToText = ({ waveform: Float32Array, options: DecodingOptions = {} ): Promise => { - if (!isReady) { + if (!isReady || !moduleInstance) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling this function.' @@ -103,7 +113,7 @@ export const useSpeechToText = ({ void, unknown > { - if (!isReady) { + if (!isReady || !moduleInstance) { throw new RnExecutorchError( RnExecutorchErrorCode.ModuleNotLoaded, 'The model is currently not loaded. Please load the model before calling this function.' @@ -131,17 +141,44 @@ export const useSpeechToText = ({ const streamInsert = useCallback( (waveform: Float32Array) => { - if (!isReady) return; + if (!isReady || !moduleInstance) return; moduleInstance.streamInsert(waveform); }, [isReady, moduleInstance] ); const streamStop = useCallback(() => { - if (!isReady) return; + if (!isReady || !moduleInstance) return; moduleInstance.streamStop(); }, [isReady, moduleInstance]); + const encode = useCallback( + (waveform: Float32Array): Promise => { + if (!moduleInstance) + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling this function.' + ); + return moduleInstance.encode(waveform); + }, + [moduleInstance] + ); + + const decode = useCallback( + ( + tokens: Int32Array, + encoderOutput: Float32Array + ): Promise => { + if (!moduleInstance) + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. 
Please load the model before calling this function.' + ); + return moduleInstance.decode(tokens, encoderOutput); + }, + [moduleInstance] + ); + return { error, isReady, @@ -151,7 +188,7 @@ export const useSpeechToText = ({ stream, streamInsert, streamStop, - encode: moduleInstance.encode.bind(moduleInstance), - decode: moduleInstance.decode.bind(moduleInstance), + encode, + decode, }; }; diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts index 4ffa119006..664f12caf9 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextEmbeddings.ts @@ -1,5 +1,5 @@ import { TextEmbeddingsModule } from '../../modules/natural_language_processing/TextEmbeddingsModule'; -import { useModule } from '../useModule'; +import { useModuleFactory } from '../useModuleFactory'; import { TextEmbeddingsType, TextEmbeddingsProps, @@ -15,9 +15,17 @@ import { export const useTextEmbeddings = ({ model, preventLoad = false, -}: TextEmbeddingsProps): TextEmbeddingsType => - useModule({ - module: TextEmbeddingsModule, - model, - preventLoad, - }); +}: TextEmbeddingsProps): TextEmbeddingsType => { + const { error, isReady, isGenerating, downloadProgress, runForward } = + useModuleFactory({ + factory: (config, onProgress) => + TextEmbeddingsModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource, model.tokenizerSource], + preventLoad, + }); + + const forward = (input: string) => runForward((inst) => inst.forward(input)); + + return { error, isReady, isGenerating, downloadProgress, forward }; +}; diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextToSpeech.ts 
b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextToSpeech.ts index 1a751f42dc..1d48aef34e 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useTextToSpeech.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useTextToSpeech.ts @@ -29,35 +29,43 @@ export const useTextToSpeech = ({ const [isGenerating, setIsGenerating] = useState(false); const [downloadProgress, setDownloadProgress] = useState(0); - const [moduleInstance] = useState(() => new TextToSpeechModule()); + const [moduleInstance, setModuleInstance] = + useState(null); useEffect(() => { if (preventLoad) return; - (async () => { - setDownloadProgress(0); - setError(null); - try { - setIsReady(false); - await moduleInstance.load( - { - model, - voice, - }, - setDownloadProgress - ); + let active = true; + setDownloadProgress(0); + setError(null); + setIsReady(false); + + TextToSpeechModule.fromModelName({ model, voice }, setDownloadProgress) + .then((mod) => { + if (!active) { + mod.delete(); + return; + } + setModuleInstance((prev) => { + prev?.delete(); + return mod; + }); setIsReady(true); - } catch (err) { - setError(parseUnknownError(err)); - } - })(); + }) + .catch((err) => { + if (active) setError(parseUnknownError(err)); + }); return () => { - moduleInstance.delete(); + active = false; + setModuleInstance((prev) => { + prev?.delete(); + return null; + }); }; // eslint-disable-next-line react-hooks/exhaustive-deps }, [ - moduleInstance, + model.modelName, model.durationPredictorSource, model.synthesizerSource, voice?.voiceSource, @@ -66,18 +74,22 @@ export const useTextToSpeech = ({ ]); // Shared guard for all generation methods - const guardReady = (methodName: string) => { - if (!isReady) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModuleNotLoaded, - `The model is currently not loaded. 
Please load the model before calling ${methodName}().` - ); - if (isGenerating) - throw new RnExecutorchError( - RnExecutorchErrorCode.ModelGenerating, - 'The model is currently generating. Please wait until previous model run is complete.' - ); - }; + const guardReady = useCallback( + (methodName: string): TextToSpeechModule => { + if (!isReady || !moduleInstance) + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + `The model is currently not loaded. Please load the model before calling ${methodName}().` + ); + if (isGenerating) + throw new RnExecutorchError( + RnExecutorchErrorCode.ModelGenerating, + 'The model is currently generating. Please wait until previous model run is complete.' + ); + return moduleInstance; + }, + [isReady, isGenerating, moduleInstance] + ); // Shared streaming orchestration (guards + onBegin/onNext/onEnd lifecycle) const runStream = useCallback( @@ -105,20 +117,20 @@ export const useTextToSpeech = ({ ); const forward = async (input: TextToSpeechInput) => { - guardReady('forward'); + const instance = guardReady('forward'); try { setIsGenerating(true); - return await moduleInstance.forward(input.text, input.speed ?? 1.0); + return await instance.forward(input.text, input.speed ?? 1.0); } finally { setIsGenerating(false); } }; const forwardFromPhonemes = async (input: TextToSpeechPhonemeInput) => { - guardReady('forwardFromPhonemes'); + const instance = guardReady('forwardFromPhonemes'); try { setIsGenerating(true); - return await moduleInstance.forwardFromPhonemes( + return await instance.forwardFromPhonemes( input.phonemes, input.speed ?? 1.0 ); @@ -129,27 +141,29 @@ export const useTextToSpeech = ({ const stream = useCallback( async (input: TextToSpeechStreamingInput) => { + const instance = guardReady('stream'); await runStream( 'stream', - moduleInstance.stream({ text: input.text, speed: input.speed ?? 1.0 }), + instance.stream({ text: input.text, speed: input.speed ?? 
1.0 }), input ); }, - [runStream, moduleInstance] + [guardReady, runStream] ); const streamFromPhonemes = useCallback( async (input: TextToSpeechStreamingPhonemeInput) => { + const instance = guardReady('streamFromPhonemes'); await runStream( 'streamFromPhonemes', - moduleInstance.streamFromPhonemes({ + instance.streamFromPhonemes({ phonemes: input.phonemes, speed: input.speed ?? 1.0, }), input ); }, - [runStream, moduleInstance] + [guardReady, runStream] ); return { @@ -160,7 +174,7 @@ export const useTextToSpeech = ({ forwardFromPhonemes, stream, streamFromPhonemes, - streamStop: moduleInstance.streamStop, + streamStop: () => moduleInstance?.streamStop(), downloadProgress, }; }; diff --git a/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts b/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts index abb3dca6a2..3e5be0214a 100644 --- a/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts +++ b/packages/react-native-executorch/src/hooks/natural_language_processing/useVAD.ts @@ -1,6 +1,6 @@ -import { useModule } from '../useModule'; import { VADModule } from '../../modules/natural_language_processing/VADModule'; import { VADType, VADProps } from '../../types/vad'; +import { useModuleFactory } from '../useModuleFactory'; /** * React hook for managing a VAD model instance. @@ -9,9 +9,18 @@ import { VADType, VADProps } from '../../types/vad'; * @param VADProps - Configuration object containing `model` source and optional `preventLoad` flag. * @returns Ready to use VAD model. 
*/ -export const useVAD = ({ model, preventLoad = false }: VADProps): VADType => - useModule({ - module: VADModule, - model, - preventLoad: preventLoad, - }); +export const useVAD = ({ model, preventLoad = false }: VADProps): VADType => { + const { error, isReady, isGenerating, downloadProgress, runForward } = + useModuleFactory({ + factory: (config, onProgress) => + VADModule.fromModelName(config, onProgress), + config: model, + deps: [model.modelName, model.modelSource], + preventLoad, + }); + + const forward = (waveform: Float32Array) => + runForward((inst) => inst.forward(waveform)); + + return { error, isReady, isGenerating, downloadProgress, forward }; +}; diff --git a/packages/react-native-executorch/src/hooks/useModuleFactory.ts b/packages/react-native-executorch/src/hooks/useModuleFactory.ts index 2be8821268..3d7f474052 100644 --- a/packages/react-native-executorch/src/hooks/useModuleFactory.ts +++ b/packages/react-native-executorch/src/hooks/useModuleFactory.ts @@ -14,12 +14,10 @@ type Deletable = { delete: () => void }; * * @internal */ -export function useModuleFactory< - M extends Deletable, - Config extends { modelName: string; modelSource: unknown }, ->({ +export function useModuleFactory({ factory, config, + deps, preventLoad = false, }: { factory: ( @@ -27,6 +25,7 @@ export function useModuleFactory< onProgress: (progress: number) => void ) => Promise; config: Config; + deps: ReadonlyArray; preventLoad?: boolean; }) { const [error, setError] = useState(null); @@ -38,27 +37,39 @@ export function useModuleFactory< useEffect(() => { if (preventLoad) return; - let currentInstance: M | null = null; + let active = true; + setDownloadProgress(0); + setError(null); + setIsReady(false); - (async () => { - setDownloadProgress(0); - setError(null); - setIsReady(false); - try { - currentInstance = await factory(config, setDownloadProgress); - setInstance(currentInstance); + factory(config, (p) => { + if (active) setDownloadProgress(p); + }) + .then((mod) => { 
+ if (!active) { + mod.delete(); + return; + } + setInstance((prev) => { + prev?.delete(); + return mod; + }); setIsReady(true); - } catch (err) { - setError(parseUnknownError(err)); - } - })(); + }) + .catch((err) => { + if (active) setError(parseUnknownError(err)); + }); return () => { - currentInstance?.delete(); + active = false; + setInstance((prev) => { + prev?.delete(); + return null; + }); }; // eslint-disable-next-line react-hooks/exhaustive-deps - }, [config.modelName, config.modelSource, preventLoad]); + }, [...deps, preventLoad]); const runForward = async (fn: (instance: M) => Promise): Promise => { if (!isReady || !instance) { diff --git a/packages/react-native-executorch/src/modules/BaseLabeledModule.ts b/packages/react-native-executorch/src/modules/BaseLabeledModule.ts index 6d8719b65a..01678f83d5 100644 --- a/packages/react-native-executorch/src/modules/BaseLabeledModule.ts +++ b/packages/react-native-executorch/src/modules/BaseLabeledModule.ts @@ -56,7 +56,4 @@ export abstract class BaseLabeledModule< this.labelMap = labelMap; this.nativeModule = nativeModule; } - - // TODO: figure it out so we can delete this (we need this because of basemodule inheritance) - override async load() {} } diff --git a/packages/react-native-executorch/src/modules/BaseModule.ts b/packages/react-native-executorch/src/modules/BaseModule.ts index 41a2da6cfd..c844cf358b 100644 --- a/packages/react-native-executorch/src/modules/BaseModule.ts +++ b/packages/react-native-executorch/src/modules/BaseModule.ts @@ -1,4 +1,4 @@ -import { Frame, ResourceSource } from '../types/common'; +import { Frame } from '../types/common'; import { TensorPtr } from '../types/common'; /** @@ -55,20 +55,6 @@ export abstract class BaseModule { */ public generateFromFrame!: (frameData: Frame, ...args: any[]) => any; - /** - * Load the model and prepare it for inference. 
- * - * @param modelSource - Resource location of the model binary - * @param onDownloadProgressCallback - Optional callback to monitor download progress (0-1) - * @param args - Additional model-specific loading arguments - */ - - abstract load( - modelSource: ResourceSource, - onDownloadProgressCallback: (_: number) => void, - ...args: any[] - ): Promise; - /** * Runs the model's forward method with the given input tensors. * It returns the output tensors that mimic the structure of output from ExecuTorch. diff --git a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts index 6d9b5dce56..43691c2047 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts @@ -1,5 +1,6 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource } from '../../types/common'; +import { ClassificationModelName } from '../../types/classification'; import { BaseModule } from '../BaseModule'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; @@ -11,21 +12,29 @@ import { Logger } from '../../common/Logger'; * @category Typescript API */ export class ClassificationModule extends BaseModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } + /** - * Loads the model, where `modelSource` is a string that specifies the location of the model binary. - * To track the download progress, supply a callback function `onDownloadProgressCallback`. + * Creates a classification instance for a built-in model. * - * @param model - Object containing `modelSource`. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. 
+ * @param namedSources - An object specifying which built-in model to load and where to fetch it from. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `ClassificationModule` instance. */ - async load( - model: { modelSource: ResourceSource }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { + static async fromModelName( + namedSources: { + modelName: ClassificationModelName; + modelSource: ResourceSource; + }, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { try { const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - model.modelSource + onDownloadProgress, + namedSources.modelSource ); if (!paths?.[0]) { @@ -35,7 +44,9 @@ export class ClassificationModule extends BaseModule { ); } - this.nativeModule = await global.loadClassification(paths[0]); + return new ClassificationModule( + await global.loadClassification(paths[0]) + ); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); @@ -43,10 +54,31 @@ export class ClassificationModule extends BaseModule { } /** - * Executes the model's forward pass, where `imageSource` can be a fetchable resource or a Base64-encoded string. + * Creates a classification instance with a user-provided model binary. + * Use this when working with a custom-exported model that is not one of the built-in presets. + * + * @remarks The native model contract for this method is not formally defined and may change + * between releases. Refer to the native source code for the current expected tensor interface. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `ClassificationModule` instance. 
+ */ + static fromCustomModel( + modelSource: ResourceSource, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + return ClassificationModule.fromModelName( + { modelName: 'custom' as ClassificationModelName, modelSource }, + onDownloadProgress + ); + } + + /** + * Executes the model's forward pass to classify the provided image. * - * @param imageSource - The image source to be classified. - * @returns The classification result. + * @param imageSource - A string image source (file path, URI, or Base64). + * @returns A Promise resolving to an object mapping category labels to confidence scores. */ async forward(imageSource: string): Promise<{ [category: string]: number }> { if (this.nativeModule == null) diff --git a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts index 172750e72c..dd81408e36 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts @@ -1,5 +1,6 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource } from '../../types/common'; +import { ImageEmbeddingsModelName } from '../../types/imageEmbeddings'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { BaseModule } from '../BaseModule'; @@ -11,20 +12,29 @@ import { Logger } from '../../common/Logger'; * @category Typescript API */ export class ImageEmbeddingsModule extends BaseModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } + /** - * Loads the model, where `modelSource` is a string that specifies the location of the model binary. + * Creates an image embeddings instance for a built-in model. 
* - * @param model - Object containing `modelSource`. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. + * @param namedSources - An object specifying which built-in model to load and where to fetch it from. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to an `ImageEmbeddingsModule` instance. */ - async load( - model: { modelSource: ResourceSource }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { + static async fromModelName( + namedSources: { + modelName: ImageEmbeddingsModelName; + modelSource: ResourceSource; + }, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { try { const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - model.modelSource + onDownloadProgress, + namedSources.modelSource ); if (!paths?.[0]) { @@ -34,7 +44,9 @@ export class ImageEmbeddingsModule extends BaseModule { ); } - this.nativeModule = await global.loadImageEmbeddings(paths[0]); + return new ImageEmbeddingsModule( + await global.loadImageEmbeddings(paths[0]) + ); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); @@ -42,10 +54,31 @@ export class ImageEmbeddingsModule extends BaseModule { } /** - * Executes the model's forward pass. Returns an embedding array for a given sentence. + * Creates an image embeddings instance with a user-provided model binary. + * Use this when working with a custom-exported model that is not one of the built-in presets. + * + * @remarks The native model contract for this method is not formally defined and may change + * between releases. Refer to the native source code for the current expected tensor interface. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. 
+ * @returns A Promise resolving to an `ImageEmbeddingsModule` instance. + */ + static fromCustomModel( + modelSource: ResourceSource, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + return ImageEmbeddingsModule.fromModelName( + { modelName: 'custom' as ImageEmbeddingsModelName, modelSource }, + onDownloadProgress + ); + } + + /** + * Executes the model's forward pass to generate an embedding for the provided image. * - * @param imageSource - The image source (URI/URL) to image that will be embedded. - * @returns A Float32Array containing the image embeddings. + * @param imageSource - A string image source (file path, URI, or Base64). + * @returns A Promise resolving to a `Float32Array` containing the image embedding vector. */ async forward(imageSource: string): Promise { if (this.nativeModule == null) diff --git a/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts b/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts index 41a931a390..dbdce026ef 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/OCRModule.ts @@ -1,6 +1,6 @@ import { OCRController } from '../../controllers/OCRController'; import { ResourceSource } from '../../types/common'; -import { OCRDetection, OCRLanguage } from '../../types/ocr'; +import { OCRDetection, OCRLanguage, OCRModelName } from '../../types/ocr'; import { Logger } from '../../common/Logger'; import { parseUnknownError } from '../../errors/errorUtils'; @@ -12,39 +12,78 @@ import { parseUnknownError } from '../../errors/errorUtils'; export class OCRModule { private controller: OCRController; - constructor() { - this.controller = new OCRController(); + private constructor(controller: OCRController) { + this.controller = controller; } /** - * Loads the model, where `detectorSource` is a string that specifies the location of the detector binary, - * 
`recognizerSource` is a string that specifies the location of the recognizer binary,
-   * and `language` is a parameter that specifies the language of the text to be recognized by the OCR.
+   * Creates an OCR instance for a built-in model.
    *
-   * @param model - Object containing `detectorSource`, `recognizerSource`, and `language`.
-   * @param onDownloadProgressCallback - Optional callback to monitor download progress.
+   * @param namedSources - An object specifying the model name, detector source, recognizer source, and language.
+   * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1.
+   * @returns A Promise resolving to an `OCRModule` instance.
+   *
+   * @example
+   * ```ts
+   * import { OCRModule, OCR_ENGLISH } from 'react-native-executorch';
+   * const ocr = await OCRModule.fromModelName(OCR_ENGLISH);
+   * ```
    */
-  async load(
-    model: {
+  static async fromModelName(
+    namedSources: {
+      modelName: OCRModelName;
       detectorSource: ResourceSource;
       recognizerSource: ResourceSource;
       language: OCRLanguage;
     },
-    onDownloadProgressCallback: (progress: number) => void = () => {}
-  ) {
+    onDownloadProgress: (progress: number) => void = () => {}
+  ): Promise<OCRModule> {
     try {
-      await this.controller.load(
-        model.detectorSource,
-        model.recognizerSource,
-        model.language,
-        onDownloadProgressCallback
+      const controller = new OCRController();
+      await controller.load(
+        namedSources.detectorSource,
+        namedSources.recognizerSource,
+        namedSources.language,
+        onDownloadProgress
       );
+      return new OCRModule(controller);
     } catch (error) {
       Logger.error('Load failed:', error);
       throw parseUnknownError(error);
     }
   }

+  /**
+   * Creates an OCR instance with a user-provided model binary.
+   * Use this when working with a custom-exported OCR model.
+   * Internally uses `ocr-${language}` as the model name for telemetry.
+   *
+   * @remarks The native model contract for this method is not formally defined and may change
+   * between releases. 
Refer to the native source code for the current expected tensor interface. + * + * @param detectorSource - A fetchable resource pointing to the text detector model binary. + * @param recognizerSource - A fetchable resource pointing to the text recognizer model binary. + * @param language - The language for the OCR model. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to an `OCRModule` instance. + */ + static fromCustomModel( + detectorSource: ResourceSource, + recognizerSource: ResourceSource, + language: OCRLanguage, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + return OCRModule.fromModelName( + { + modelName: `ocr-${language}` as OCRModelName, + detectorSource, + recognizerSource, + language, + }, + onDownloadProgress + ); + } + /** * Executes the model's forward pass, where `imageSource` can be a fetchable resource or a Base64-encoded string. * diff --git a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts index aa640a0476..c24bbd1369 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts @@ -66,17 +66,17 @@ export class ObjectDetectionModule< /** * Creates an object detection instance for a built-in model. * - * @param config - A {@link ObjectDetectionModelSources} object specifying which model to load and where to fetch it from. + * @param namedSources - A {@link ObjectDetectionModelSources} object specifying which model to load and where to fetch it from. * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. * @returns A Promise resolving to an `ObjectDetectionModule` instance typed to the chosen model's label map. 
*/ static async fromModelName( - config: C, + namedSources: C, onDownloadProgress: (progress: number) => void = () => {} ): Promise>> { - const { modelSource } = config; + const { modelSource } = namedSources; const { labelMap, preprocessorConfig } = ModelConfigs[ - config.modelName + namedSources.modelName ] as ObjectDetectionConfig; const normMean = preprocessorConfig?.normMean ?? []; const normStd = preprocessorConfig?.normStd ?? []; @@ -100,14 +100,6 @@ export class ObjectDetectionModule< ); } - /** - * Creates an object detection instance with a user-provided label map and custom config. - * - * @param modelSource - A fetchable resource pointing to the model binary. - * @param config - A {@link ObjectDetectionConfig} object with the label map. - * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. - * @returns A Promise resolving to an `ObjectDetectionModule` instance typed to the provided label map. - */ /** * Executes the model's forward pass to detect objects within the provided image. * @@ -122,7 +114,36 @@ export class ObjectDetectionModule< return super.forward(input, detectionThreshold); } - static async fromCustomConfig( + /** + * Creates an object detection instance with a user-provided model binary and label map. + * Use this when working with a custom-exported model that is not one of the built-in presets. + * Internally uses `'custom'` as the model name for telemetry unless overridden. + * + * ## Required model contract + * + * The `.pte` model binary must expose a single `forward` method with the following interface: + * + * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in + * `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`. + * H and W are read from the model's declared input shape at load time. + * + * **Outputs:** exactly three `float32` tensors, in this order: + * 1. 
Bounding boxes — flat `[4·N]` array of `(x1, y1, x2, y2)` coordinates in model-input + * pixel space, repeated for N detections. + * 2. Confidence scores — flat `[N]` array of values in `[0, 1]`. + * 3. Class indices — flat `[N]` array of `float32`-encoded integer class indices + * (0-based, matching the order of entries in your `labelMap`). + * + * Preprocessing (resize → normalize) and postprocessing (coordinate rescaling, threshold + * filtering, NMS) are handled by the native runtime — your model only needs to produce + * the raw detections above. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param config - A {@link ObjectDetectionConfig} object with the label map and optional preprocessing parameters. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to an `ObjectDetectionModule` instance typed to the provided label map. + */ + static async fromCustomModel( modelSource: ResourceSource, config: ObjectDetectionConfig, onDownloadProgress: (progress: number) => void = () => {} diff --git a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts index 408212eb48..d24e988930 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts @@ -88,24 +88,21 @@ export class SemanticSegmentationModule< * Creates a segmentation instance for a built-in model. * The config object is discriminated by `modelName` — each model can require different fields. * - * @param config - A {@link SemanticSegmentationModelSources} object specifying which model to load and where to fetch it from. 
+ * @param namedSources - A {@link SemanticSegmentationModelSources} object specifying which model to load and where to fetch it from. * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. * @returns A Promise resolving to a `SemanticSegmentationModule` instance typed to the chosen model's label map. * * @example * ```ts - * const segmentation = await SemanticSegmentationModule.fromModelName({ - * modelName: 'deeplab-v3', - * modelSource: 'https://example.com/deeplab.pte', - * }); + * const segmentation = await SemanticSegmentationModule.fromModelName(DEEPLAB_V3_RESNET50); * ``` */ static async fromModelName( - config: C, + namedSources: C, onDownloadProgress: (progress: number) => void = () => {} ): Promise>> { - const { modelName, modelSource } = config; + const { modelName, modelSource } = namedSources; const { labelMap } = ModelConfigs[modelName]; const { preprocessorConfig } = ModelConfigs[ modelName @@ -127,8 +124,25 @@ export class SemanticSegmentationModule< } /** - * Creates a segmentation instance with a user-provided label map and custom config. + * Creates a segmentation instance with a user-provided model binary and label map. * Use this when working with a custom-exported segmentation model that is not one of the built-in models. + * Internally uses `'custom'` as the model name for telemetry unless overridden. + * + * ## Required model contract + * + * The `.pte` model binary must expose a single `forward` method with the following interface: + * + * **Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in + * `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`. + * H and W are read from the model's declared input shape at load time. + * + * **Output:** one `float32` tensor of shape `[1, C, H_out, W_out]` (NCHW) containing raw + * logits — one channel per class, in the same order as the entries in your `labelMap`. 
+ * For binary segmentation a single-channel output is also supported: channel 0 is treated + * as the foreground probability and a synthetic background channel is added automatically. + * + * Preprocessing (resize → normalize) and postprocessing (softmax, argmax, resize back to + * original dimensions) are handled by the native runtime. * * @param modelSource - A fetchable resource pointing to the model binary. * @param config - A {@link SemanticSegmentationConfig} object with the label map and optional preprocessing parameters. @@ -138,13 +152,13 @@ export class SemanticSegmentationModule< * @example * ```ts * const MyLabels = { BACKGROUND: 0, FOREGROUND: 1 } as const; - * const segmentation = await SemanticSegmentationModule.fromCustomConfig( + * const segmentation = await SemanticSegmentationModule.fromCustomModel( * 'https://example.com/custom_model.pte', * { labelMap: MyLabels }, * ); * ``` */ - static async fromCustomConfig( + static async fromCustomModel( modelSource: ResourceSource, config: SemanticSegmentationConfig, onDownloadProgress: (progress: number) => void = () => {} diff --git a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts index beed7f3ab3..6027d1fd2d 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts @@ -1,5 +1,6 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource } from '../../types/common'; +import { StyleTransferModelName } from '../../types/styleTransfer'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; import { BaseModule } from '../BaseModule'; @@ -11,21 +12,29 @@ import { Logger } from '../../common/Logger'; * @category Typescript API */ export class 
StyleTransferModule extends BaseModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } + /** - * Loads the model, where `modelSource` is a string that specifies the location of the model binary. - * To track the download progress, supply a callback function `onDownloadProgressCallback`. + * Creates a style transfer instance for a built-in model. * - * @param model - Object containing `modelSource`. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. + * @param namedSources - An object specifying which built-in model to load and where to fetch it from. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `StyleTransferModule` instance. */ - async load( - model: { modelSource: ResourceSource }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { + static async fromModelName( + namedSources: { + modelName: StyleTransferModelName; + modelSource: ResourceSource; + }, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { try { const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - model.modelSource + onDownloadProgress, + namedSources.modelSource ); if (!paths?.[0]) { @@ -35,7 +44,7 @@ export class StyleTransferModule extends BaseModule { ); } - this.nativeModule = await global.loadStyleTransfer(paths[0]); + return new StyleTransferModule(await global.loadStyleTransfer(paths[0])); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); @@ -43,10 +52,31 @@ export class StyleTransferModule extends BaseModule { } /** - * Executes the model's forward pass, where `imageSource` can be a fetchable resource or a Base64-encoded string. + * Creates a style transfer instance with a user-provided model binary. 
+ * Use this when working with a custom-exported model that is not one of the built-in presets. + * + * @remarks The native model contract for this method is not formally defined and may change + * between releases. Refer to the native source code for the current expected tensor interface. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `StyleTransferModule` instance. + */ + static fromCustomModel( + modelSource: ResourceSource, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + return StyleTransferModule.fromModelName( + { modelName: 'custom' as StyleTransferModelName, modelSource }, + onDownloadProgress + ); + } + + /** + * Executes the model's forward pass to apply the selected style to the provided image. * - * @param imageSource - The image source to be processed. - * @returns The stylized image as a Base64-encoded string. + * @param imageSource - A string image source (file path, URI, or Base64). + * @returns A Promise resolving to the stylized image as a Base64-encoded string. 
*/ async forward(imageSource: string): Promise { if (this.nativeModule == null) diff --git a/packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts b/packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts index 6b05e2b79b..7f290b0a89 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/TextToImageModule.ts @@ -1,5 +1,6 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource } from '../../types/common'; +import { TextToImageModelName } from '../../types/tti'; import { BaseModule } from '../BaseModule'; import { PNG } from 'pngjs/browser'; @@ -15,77 +16,50 @@ import { Logger } from '../../common/Logger'; export class TextToImageModule extends BaseModule { private inferenceCallback: (stepIdx: number) => void; - /** - * Creates a new instance of `TextToImageModule` with optional callback on inference step. - * - * @param inferenceCallback - Optional callback function that receives the current step index during inference. - */ - constructor(inferenceCallback?: (stepIdx: number) => void) { + private constructor( + nativeModule: unknown, + inferenceCallback?: (stepIdx: number) => void + ) { super(); + this.nativeModule = nativeModule; this.inferenceCallback = (stepIdx: number) => { inferenceCallback?.(stepIdx); }; } /** - * Loads the model from specified resources. + * Creates a Text to Image instance for a built-in model. * - * @param model - Object containing sources for tokenizer, scheduler, encoder, unet, and decoder. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. + * @param namedSources - An object specifying the model name, pipeline sources, and optional inference callback. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. 
+ * @returns A Promise resolving to a `TextToImageModule` instance. + * + * @example + * ```ts + * import { TextToImageModule, BK_SDM_TINY_VPRED_512 } from 'react-native-executorch'; + * const tti = await TextToImageModule.fromModelName(BK_SDM_TINY_VPRED_512); + * ``` */ - async load( - model: { + static async fromModelName( + namedSources: { + modelName: TextToImageModelName; tokenizerSource: ResourceSource; schedulerSource: ResourceSource; encoderSource: ResourceSource; unetSource: ResourceSource; decoderSource: ResourceSource; + inferenceCallback?: (stepIdx: number) => void; }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { try { - const results = await ResourceFetcher.fetch( - onDownloadProgressCallback, - model.tokenizerSource, - model.schedulerSource, - model.encoderSource, - model.unetSource, - model.decoderSource + const nativeModule = await TextToImageModule.load( + namedSources, + onDownloadProgress ); - if (!results) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' - ); - } - const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] = - results; - - if ( - !tokenizerPath || - !schedulerPath || - !encoderPath || - !unetPath || - !decoderPath - ) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' 
- ); - } - - const response = await fetch('file://' + schedulerPath); - const schedulerConfig = await response.json(); - - this.nativeModule = await global.loadTextToImage( - tokenizerPath, - encoderPath, - unetPath, - decoderPath, - schedulerConfig.beta_start, - schedulerConfig.beta_end, - schedulerConfig.num_train_timesteps, - schedulerConfig.steps_offset + return new TextToImageModule( + nativeModule, + namedSources.inferenceCallback ); } catch (error) { Logger.error('Load failed:', error); @@ -93,6 +67,95 @@ export class TextToImageModule extends BaseModule { } } + /** + * Creates a Text to Image instance with user-provided model binaries. + * Use this when working with a custom-exported diffusion pipeline. + * Internally uses `'custom'` as the model name for telemetry. + * + * @remarks The native model contract for this method is not formally defined and may change + * between releases. Refer to the native source code for the current expected tensor interface. + * + * @param sources - An object containing the pipeline source paths. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @param inferenceCallback - Optional callback triggered after each diffusion step. + * @returns A Promise resolving to a `TextToImageModule` instance. 
+ */ + static fromCustomModel( + sources: { + tokenizerSource: ResourceSource; + schedulerSource: ResourceSource; + encoderSource: ResourceSource; + unetSource: ResourceSource; + decoderSource: ResourceSource; + }, + onDownloadProgress: (progress: number) => void = () => {}, + inferenceCallback?: (stepIdx: number) => void + ): Promise { + return TextToImageModule.fromModelName( + { + modelName: 'custom' as TextToImageModelName, + ...sources, + inferenceCallback, + }, + onDownloadProgress + ); + } + + private static async load( + model: { + tokenizerSource: ResourceSource; + schedulerSource: ResourceSource; + encoderSource: ResourceSource; + unetSource: ResourceSource; + decoderSource: ResourceSource; + }, + onDownloadProgressCallback: (progress: number) => void + ): Promise { + const results = await ResourceFetcher.fetch( + onDownloadProgressCallback, + model.tokenizerSource, + model.schedulerSource, + model.encoderSource, + model.unetSource, + model.decoderSource + ); + if (!results || results.length !== 5) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' + ); + } + const [tokenizerPath, schedulerPath, encoderPath, unetPath, decoderPath] = + results; + + if ( + !tokenizerPath || + !schedulerPath || + !encoderPath || + !unetPath || + !decoderPath + ) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' 
+ ); + } + + const response = await fetch('file://' + schedulerPath); + const schedulerConfig = await response.json(); + + return global.loadTextToImage( + tokenizerPath, + encoderPath, + unetPath, + decoderPath, + schedulerConfig.beta_start, + schedulerConfig.beta_end, + schedulerConfig.num_train_timesteps, + schedulerConfig.steps_offset + ); + } + /** * Runs the model to generate an image described by `input`, and conditioned by `seed`, performing `numSteps` inference steps. * The resulting image, with dimensions `imageSize`×`imageSize` pixels, is returned as a base64-encoded string. diff --git a/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts index 824a15021b..069c705ebc 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VerticalOCRModule.ts @@ -2,7 +2,7 @@ import { Logger } from '../../common/Logger'; import { VerticalOCRController } from '../../controllers/VerticalOCRController'; import { parseUnknownError } from '../../errors/errorUtils'; import { ResourceSource } from '../../types/common'; -import { OCRDetection, OCRLanguage } from '../../types/ocr'; +import { OCRDetection, OCRLanguage, OCRModelName } from '../../types/ocr'; /** * Module for Vertical Optical Character Recognition (Vertical OCR) tasks. 
@@ -12,42 +12,83 @@ import { OCRDetection, OCRLanguage } from '../../types/ocr'; export class VerticalOCRModule { private controller: VerticalOCRController; - constructor() { - this.controller = new VerticalOCRController(); + private constructor(controller: VerticalOCRController) { + this.controller = controller; } /** - * Loads the model, where `detectorSource` is a string that specifies the location of the detector binary, - * `recognizerSource` is a string that specifies the location of the recognizer binary, - * and `language` is a parameter that specifies the language of the text to be recognized by the OCR. + * Creates a Vertical OCR instance for a built-in model. * - * @param model - Object containing `detectorSource`, `recognizerSource`, and `language`. - * @param independentCharacters - Whether to treat characters independently during recognition. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. + * @param namedSources - An object specifying the model name, detector source, recognizer source, language, and optional independent characters flag. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `VerticalOCRModule` instance. 
+ *
+   * @example
+   * ```ts
+   * import { VerticalOCRModule, OCR_JAPANESE } from 'react-native-executorch';
+   * const ocr = await VerticalOCRModule.fromModelName({ ...OCR_JAPANESE, independentCharacters: true });
+   * ```
    */
-  async load(
-    model: {
+  static async fromModelName(
+    namedSources: {
+      modelName: OCRModelName;
       detectorSource: ResourceSource;
       recognizerSource: ResourceSource;
       language: OCRLanguage;
+      independentCharacters?: boolean;
     },
-    independentCharacters: boolean,
-    onDownloadProgressCallback: (progress: number) => void = () => {}
-  ) {
+    onDownloadProgress: (progress: number) => void = () => {}
+  ): Promise<VerticalOCRModule> {
     try {
-      await this.controller.load(
-        model.detectorSource,
-        model.recognizerSource,
-        model.language,
-        independentCharacters,
-        onDownloadProgressCallback
+      const controller = new VerticalOCRController();
+      await controller.load(
+        namedSources.detectorSource,
+        namedSources.recognizerSource,
+        namedSources.language,
+        namedSources.independentCharacters ?? false,
+        onDownloadProgress
       );
+      return new VerticalOCRModule(controller);
     } catch (error) {
       Logger.error('Load failed:', error);
       throw parseUnknownError(error);
     }
   }

+  /**
+   * Creates a Vertical OCR instance with a user-provided model binary.
+   * Use this when working with a custom-exported Vertical OCR model.
+   * Internally uses `ocr-${language}` as the model name for telemetry.
+   *
+   * @remarks The native model contract for this method is not formally defined and may change
+   * between releases. Refer to the native source code for the current expected tensor interface.
+   *
+   * @param detectorSource - A fetchable resource pointing to the text detector model binary.
+   * @param recognizerSource - A fetchable resource pointing to the text recognizer model binary.
+   * @param language - The language for the OCR model.
+   * @param independentCharacters - Whether to treat characters independently during recognition. 
+ * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `VerticalOCRModule` instance. + */ + static fromCustomModel( + detectorSource: ResourceSource, + recognizerSource: ResourceSource, + language: OCRLanguage, + independentCharacters: boolean = false, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + return VerticalOCRModule.fromModelName( + { + modelName: `ocr-${language}` as OCRModelName, + detectorSource, + recognizerSource, + language, + independentCharacters, + }, + onDownloadProgress + ); + } + /** * Executes the model's forward pass, where `imageSource` can be a fetchable resource or a Base64-encoded string. * diff --git a/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts b/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts index 914ea61950..61a0bab091 100644 --- a/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts +++ b/packages/react-native-executorch/src/modules/computer_vision/VisionLabeledModule.ts @@ -20,7 +20,4 @@ export abstract class VisionLabeledModule< this.labelMap = labelMap; this.nativeModule = nativeModule; } - - // TODO: figure it out so we can delete this (we need this because of basemodule inheritance) - override async load() {} } diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts index b8acf7ac9e..707b08813e 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/LLMModule.ts @@ -1,6 +1,14 @@ import { LLMController } from '../../controllers/LLMController'; +import { Logger } from '../../common/Logger'; +import { parseUnknownError } from '../../errors/errorUtils'; 
import { ResourceSource } from '../../types/common'; -import { LLMCapability, LLMConfig, LLMTool, Message } from '../../types/llm'; +import { + LLMCapability, + LLMConfig, + LLMModelName, + LLMTool, + Message, +} from '../../types/llm'; /** * Module for managing a Large Language Model (LLM) instance. @@ -9,26 +17,12 @@ import { LLMCapability, LLMConfig, LLMTool, Message } from '../../types/llm'; */ export class LLMModule { private controller: LLMController; - private pendingConfig?: LLMConfig; - /** - * Creates a new instance of `LLMModule` with optional callbacks. - * @param optionalCallbacks - Object containing optional callbacks. - * - * @returns A new LLMModule instance. - */ - constructor({ + private constructor({ tokenCallback, messageHistoryCallback, }: { - /** - * An optional function that will be called on every generated token (`string`) with that token as its only argument. - */ tokenCallback?: (token: string) => void; - /** - * An optional function called on every finished message (`Message[]`). - * Returns the entire message history. - */ messageHistoryCallback?: (messageHistory: Message[]) => void; } = {}) { this.controller = new LLMController({ @@ -38,34 +32,89 @@ export class LLMModule { } /** - * Loads the LLM model and tokenizer. + * Creates an LLM instance for a built-in model. + * + * @param namedSources - An object specifying the model name, model source, tokenizer source, + * tokenizer config source, and optional capabilities. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @param tokenCallback - Optional callback invoked on every generated token. + * @param messageHistoryCallback - Optional callback invoked when the model finishes a response, with the full message history. + * @returns A Promise resolving to an `LLMModule` instance. * - * @param model - Object containing model, tokenizer, and tokenizer config sources. 
- * @param model.modelSource - `ResourceSource` that specifies the location of the model binary. - * @param model.tokenizerSource - `ResourceSource` pointing to the JSON file which contains the tokenizer. - * @param model.tokenizerConfigSource - `ResourceSource` pointing to the JSON file which contains the tokenizer config. - * @param onDownloadProgressCallback - Optional callback to track download progress (value between 0 and 1). + * @example + * ```ts + * import { LLMModule, LLAMA3_2_3B } from 'react-native-executorch'; + * const llm = await LLMModule.fromModelName(LLAMA3_2_3B); + * ``` */ - async load( - model: { + static async fromModelName( + namedSources: { + modelName: LLMModelName; modelSource: ResourceSource; tokenizerSource: ResourceSource; tokenizerConfigSource: ResourceSource; capabilities?: readonly LLMCapability[]; }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ) { - await this.controller.load({ - ...model, - onDownloadProgressCallback, - }); - - if (this.pendingConfig) { - this.controller.configure(this.pendingConfig); - this.pendingConfig = undefined; + onDownloadProgress: (progress: number) => void = () => {}, + tokenCallback?: (token: string) => void, + messageHistoryCallback?: (messageHistory: Message[]) => void + ): Promise { + const instance = new LLMModule({ tokenCallback, messageHistoryCallback }); + try { + await instance.controller.load({ + modelSource: namedSources.modelSource, + tokenizerSource: namedSources.tokenizerSource, + tokenizerConfigSource: namedSources.tokenizerConfigSource, + onDownloadProgressCallback: onDownloadProgress, + }); + return instance; + } catch (error) { + Logger.error('Load failed:', error); + throw parseUnknownError(error); } } + /** + * Creates an LLM instance with a user-provided model binary. + * Use this when working with a custom-exported LLM. + * Internally uses `'custom'` as the model name for telemetry. 
+ * + * ## Required model contract + * + * The `.pte` model binary must be exported following the + * [ExecuTorch LLM export process](https://docs.pytorch.org/executorch/1.1/llm/export-llm.html). + * The native runner expects the standard ExecuTorch text-generation interface — KV-cache + * management, prefill/decode phases, and logit sampling are all handled by the runtime. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param tokenizerSource - A fetchable resource pointing to the tokenizer JSON file. + * @param tokenizerConfigSource - A fetchable resource pointing to the tokenizer config JSON file. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @param tokenCallback - Optional callback invoked on every generated token. + * @param messageHistoryCallback - Optional callback invoked when the model finishes a response, with the full message history. + * @returns A Promise resolving to an `LLMModule` instance. + */ + static fromCustomModel( + modelSource: ResourceSource, + tokenizerSource: ResourceSource, + tokenizerConfigSource: ResourceSource, + onDownloadProgress: (progress: number) => void = () => {}, + tokenCallback?: (token: string) => void, + messageHistoryCallback?: (messageHistory: Message[]) => void + ): Promise { + return LLMModule.fromModelName( + { + modelName: 'custom' as LLMModelName, + modelSource, + tokenizerSource, + tokenizerConfigSource, + }, + onDownloadProgress, + tokenCallback, + messageHistoryCallback + ); + } + /** * Sets new token callback invoked on every token batch. * @@ -86,11 +135,7 @@ export class LLMModule { * @param config - Configuration object containing `chatConfig`, `toolsConfig`, and `generationConfig`. 
*/ configure(config: LLMConfig) { - if (this.controller.isReady) { - this.controller.configure(config); - } else { - this.pendingConfig = config; - } + this.controller.configure(config); } /** diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts index 187dff334d..d981e7e1ec 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/SpeechToTextModule.ts @@ -1,11 +1,14 @@ import { DecodingOptions, SpeechToTextModelConfig, + SpeechToTextModelName, TranscriptionResult, } from '../../types/stt'; import { ResourceFetcher } from '../../utils/ResourceFetcher'; +import { ResourceSource } from '../../types/common'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; +import { Logger } from '../../common/Logger'; /** * Module for Speech to Text (STT) functionalities. @@ -14,21 +17,81 @@ import { RnExecutorchError, parseUnknownError } from '../../errors/errorUtils'; */ export class SpeechToTextModule { private nativeModule: any; - private modelConfig!: SpeechToTextModelConfig; + private modelConfig: SpeechToTextModelConfig; + + private constructor( + nativeModule: unknown, + modelConfig: SpeechToTextModelConfig + ) { + this.nativeModule = nativeModule; + this.modelConfig = modelConfig; + } /** - * Loads the model specified by the config object. - * `onDownloadProgressCallback` allows you to monitor the current progress of the model download. + * Creates a Speech to Text instance for a built-in model. * - * @param model - Configuration object containing model sources. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. 
+ * @param namedSources - Configuration object containing model name, sources, and multilingual flag. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `SpeechToTextModule` instance. + * + * @example + * ```ts + * import { SpeechToTextModule, WHISPER_TINY_EN } from 'react-native-executorch'; + * const stt = await SpeechToTextModule.fromModelName(WHISPER_TINY_EN); + * ``` */ - public async load( - model: SpeechToTextModelConfig, - onDownloadProgressCallback: (progress: number) => void = () => {} - ) { - this.modelConfig = model; + static async fromModelName( + namedSources: SpeechToTextModelConfig, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + try { + const nativeModule = await SpeechToTextModule.loadWhisper( + namedSources, + onDownloadProgress + ); + return new SpeechToTextModule(nativeModule, namedSources); + } catch (error) { + Logger.error('Load failed:', error); + throw parseUnknownError(error); + } + } + /** + * Creates a Speech to Text instance with user-provided model binaries. + * Use this when working with a custom-exported STT model. + * Internally uses `'custom'` as the model name for telemetry. + * + * @remarks The native model contract for this method is not formally defined and may change + * between releases. Currently only the Whisper architecture is supported by the native runner. + * Refer to the native source code for the current expected interface. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param tokenizerSource - A fetchable resource pointing to the tokenizer file. + * @param isMultilingual - Whether the model supports multiple languages. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `SpeechToTextModule` instance. 
+ */ + static fromCustomModel( + modelSource: ResourceSource, + tokenizerSource: ResourceSource, + isMultilingual: boolean, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + return SpeechToTextModule.fromModelName( + { + modelName: 'custom' as SpeechToTextModelName, + modelSource, + tokenizerSource, + isMultilingual, + }, + onDownloadProgress + ); + } + + private static async loadWhisper( + model: SpeechToTextModelConfig, + onDownloadProgressCallback: (progress: number) => void + ): Promise { const tokenizerLoadPromise = ResourceFetcher.fetch( undefined, model.tokenizerSource @@ -47,8 +110,9 @@ export class SpeechToTextModule { 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' ); } - this.nativeModule = await global.loadSpeechToText( - model.type, + // Currently only Whisper architecture is supported + return await global.loadSpeechToText( + 'whisper', modelSources[0], tokenizerSources[0] ); diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts index 20372ca85b..5720c6399c 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts @@ -1,4 +1,5 @@ import { ResourceSource } from '../../types/common'; +import { TextEmbeddingsModelName } from '../../types/textEmbeddings'; import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { BaseModule } from '../BaseModule'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; @@ -11,30 +12,30 @@ import { Logger } from '../../common/Logger'; * @category Typescript API */ export class TextEmbeddingsModule extends BaseModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + 
} + /** - * Loads the model and tokenizer specified by the config object. + * Creates a text embeddings instance for a built-in model. * - * @param model - Object containing model and tokenizer sources. - * @param model.modelSource - `ResourceSource` that specifies the location of the text embeddings model binary. - * @param model.tokenizerSource - `ResourceSource` that specifies the location of the tokenizer JSON file. - * @param onDownloadProgressCallback - Optional callback to track download progress (value between 0 and 1). + * @param namedSources - An object specifying which built-in model to load and where to fetch it from. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `TextEmbeddingsModule` instance. */ - async load( - model: { modelSource: ResourceSource; tokenizerSource: ResourceSource }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { + static async fromModelName( + namedSources: { + modelName: TextEmbeddingsModelName; + modelSource: ResourceSource; + tokenizerSource: ResourceSource; + }, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { try { - const modelPromise = ResourceFetcher.fetch( - onDownloadProgressCallback, - model.modelSource - ); - const tokenizerPromise = ResourceFetcher.fetch( - undefined, - model.tokenizerSource - ); const [modelResult, tokenizerResult] = await Promise.all([ - modelPromise, - tokenizerPromise, + ResourceFetcher.fetch(onDownloadProgress, namedSources.modelSource), + ResourceFetcher.fetch(undefined, namedSources.tokenizerSource), ]); const modelPath = modelResult?.[0]; const tokenizerPath = tokenizerResult?.[0]; @@ -44,9 +45,8 @@ export class TextEmbeddingsModule extends BaseModule { 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' 
); } - this.nativeModule = await global.loadTextEmbeddings( - modelPath, - tokenizerPath + return new TextEmbeddingsModule( + await global.loadTextEmbeddings(modelPath, tokenizerPath) ); } catch (error) { Logger.error('Load failed:', error); @@ -55,12 +55,44 @@ export class TextEmbeddingsModule extends BaseModule { } /** - * Executes the model's forward pass, where `input` is a text that will be embedded. + * Creates a text embeddings instance with a user-provided model binary and tokenizer. + * Use this when working with a custom-exported model that is not one of the built-in presets. + * + * @remarks The native model contract for this method is not formally defined and may change + * between releases. Refer to the native source code for the current expected tensor interface. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param tokenizerSource - A fetchable resource pointing to the tokenizer file. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `TextEmbeddingsModule` instance. + */ + static fromCustomModel( + modelSource: ResourceSource, + tokenizerSource: ResourceSource, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + return TextEmbeddingsModule.fromModelName( + { + modelName: 'custom' as TextEmbeddingsModelName, + modelSource, + tokenizerSource, + }, + onDownloadProgress + ); + } + + /** + * Executes the model's forward pass to generate an embedding for the provided text. * * @param input - The text string to embed. - * @returns A Float32Array containing the vector embeddings. + * @returns A Promise resolving to a `Float32Array` containing the embedding vector. */ async forward(input: string): Promise { + if (this.nativeModule == null) + throw new RnExecutorchError( + RnExecutorchErrorCode.ModuleNotLoaded, + 'The model is currently not loaded. Please load the model before calling forward().' 
+ ); return new Float32Array(await this.nativeModule.generate(input)); } } diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TextToSpeechModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TextToSpeechModule.ts index 126747403d..a12285057b 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/TextToSpeechModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/TextToSpeechModule.ts @@ -16,87 +16,89 @@ import { Logger } from '../../common/Logger'; * @category Typescript API */ export class TextToSpeechModule { - /** - * Native module instance - */ - nativeModule: any = null; + private nativeModule: any; + + private constructor(nativeModule: unknown) { + this.nativeModule = nativeModule; + } /** - * Loads the model and voice assets specified by the config object. - * `onDownloadProgressCallback` allows you to monitor the current progress. + * Creates a Text to Speech instance. + * + * @param config - Configuration object containing `model` and `voice`. + * Pass one of the built-in constants (e.g. `{ model: KOKORO_MEDIUM, voice: KOKORO_VOICE_AF_HEART }`), or use require() to pass them. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `TextToSpeechModule` instance. * - * @param config - Configuration object containing `model` source and `voice`. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. 
+ * @example + * ```ts + * import { TextToSpeechModule, KOKORO_MEDIUM, KOKORO_VOICE_AF_HEART } from 'react-native-executorch'; + * const tts = await TextToSpeechModule.fromModelName( + * { model: KOKORO_MEDIUM, voice: KOKORO_VOICE_AF_HEART }, + * ); + * ``` */ - public async load( + static async fromModelName( config: TextToSpeechConfig, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { - // Select the text to speech model based on it's fixed identifier - if (config.model.type === 'kokoro') { - await this.loadKokoro( + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + try { + const nativeModule = await TextToSpeechModule.loadKokoro( config.model, config.voice, - onDownloadProgressCallback + onDownloadProgress ); + return new TextToSpeechModule(nativeModule); + } catch (error) { + Logger.error('Load failed:', error); + throw parseUnknownError(error); } - // ... more models? ... } - // Specialized loader - Kokoro model - private async loadKokoro( + private static async loadKokoro( model: KokoroConfig, voice: VoiceConfig, onDownloadProgressCallback: (progress: number) => void - ): Promise { - try { - if ( - !voice.extra || - !voice.extra.taggerSource || - !voice.extra.lexiconSource - ) { - throw new RnExecutorchError( - RnExecutorchErrorCode.InvalidConfig, - 'Kokoro: voice config is missing required extra fields: taggerSource and/or lexiconSource.' - ); - } - - const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - model.durationPredictorSource, - model.synthesizerSource, - voice.voiceSource, - voice.extra.taggerSource, - voice.extra.lexiconSource + ): Promise { + if ( + !voice.extra || + !voice.extra.taggerSource || + !voice.extra.lexiconSource + ) { + throw new RnExecutorchError( + RnExecutorchErrorCode.InvalidConfig, + 'Kokoro: voice config is missing required extra fields: taggerSource and/or lexiconSource.' 
); + } - if ( - paths === null || - paths.length !== 5 || - paths.some((p) => p == null) - ) { - throw new RnExecutorchError( - RnExecutorchErrorCode.DownloadInterrupted, - 'Download interrupted or missing resource.' - ); - } + const paths = await ResourceFetcher.fetch( + onDownloadProgressCallback, + model.durationPredictorSource, + model.synthesizerSource, + voice.voiceSource, + voice.extra.taggerSource, + voice.extra.lexiconSource + ); - const modelPaths = paths.slice(0, 2) as [string, string, string, string]; - const voiceDataPath = paths[2] as string; - const phonemizerPaths = paths.slice(3, 5) as [string, string]; - - this.nativeModule = await global.loadTextToSpeechKokoro( - voice.lang, - phonemizerPaths[0], - phonemizerPaths[1], - modelPaths[0], - modelPaths[1], - voiceDataPath + if (paths === null || paths.length !== 5) { + throw new RnExecutorchError( + RnExecutorchErrorCode.DownloadInterrupted, + 'Download interrupted or missing resource.' ); - } catch (error) { - Logger.error('Load failed:', error); - throw parseUnknownError(error); } + + const modelPaths = paths.slice(0, 2) as [string, string]; + const voiceDataPath = paths[2] as string; + const phonemizerPaths = paths.slice(3, 5) as [string, string]; + + return await global.loadTextToSpeechKokoro( + voice.lang, + phonemizerPaths[0], + phonemizerPaths[1], + modelPaths[0], + modelPaths[1], + voiceDataPath + ); } private ensureLoaded(methodName: string): void { diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts index 02e5c7b3ae..c59ec5d0c9 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/VADModule.ts @@ -1,6 +1,6 @@ import { ResourceFetcher } from '../../utils/ResourceFetcher'; import { ResourceSource } from '../../types/common'; -import { 
Segment } from '../../types/vad'; +import { Segment, VADModelName } from '../../types/vad'; import { BaseModule } from '../BaseModule'; import { RnExecutorchErrorCode } from '../../errors/ErrorCodes'; import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils'; @@ -12,21 +12,26 @@ import { Logger } from '../../common/Logger'; * @category Typescript API */ export class VADModule extends BaseModule { + private constructor(nativeModule: unknown) { + super(); + this.nativeModule = nativeModule; + } + /** - * Loads the model, where `modelSource` is a string that specifies the location of the model binary. - * To track the download progress, supply a callback function `onDownloadProgressCallback`. + * Creates a VAD instance for a built-in model. * - * @param model - Object containing `modelSource`. - * @param onDownloadProgressCallback - Optional callback to monitor download progress. + * @param namedSources - An object specifying which built-in model to load and where to fetch it from. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `VADModule` instance. */ - async load( - model: { modelSource: ResourceSource }, - onDownloadProgressCallback: (progress: number) => void = () => {} - ): Promise { + static async fromModelName( + namedSources: { modelName: VADModelName; modelSource: ResourceSource }, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { try { const paths = await ResourceFetcher.fetch( - onDownloadProgressCallback, - model.modelSource + onDownloadProgress, + namedSources.modelSource ); if (!paths?.[0]) { throw new RnExecutorchError( @@ -34,7 +39,7 @@ export class VADModule extends BaseModule { 'The download has been interrupted. As a result, not every file was downloaded. Please retry the download.' 
); } - this.nativeModule = await global.loadVAD(paths[0]); + return new VADModule(await global.loadVAD(paths[0])); } catch (error) { Logger.error('Load failed:', error); throw parseUnknownError(error); @@ -42,10 +47,31 @@ export class VADModule extends BaseModule { } /** - * Executes the model's forward pass, where `waveform` is a Float32Array representing the audio signal (16kHz). + * Creates a VAD instance with a user-provided model binary. + * Use this when working with a custom-exported model that is not one of the built-in presets. + * + * @remarks The native model contract for this method is not formally defined and may change + * between releases. Refer to the native source code for the current expected tensor interface. + * + * @param modelSource - A fetchable resource pointing to the model binary. + * @param onDownloadProgress - Optional callback to monitor download progress, receiving a value between 0 and 1. + * @returns A Promise resolving to a `VADModule` instance. + */ + static fromCustomModel( + modelSource: ResourceSource, + onDownloadProgress: (progress: number) => void = () => {} + ): Promise { + return VADModule.fromModelName( + { modelName: 'custom' as VADModelName, modelSource }, + onDownloadProgress + ); + } + + /** + * Executes the model's forward pass to detect speech segments within the provided audio. * - * @param waveform - The input audio waveform as a Float32Array. It must represent a mono audio signal sampled at 16kHz. - * @returns A promise resolving to an array of detected speech segments. + * @param waveform - A `Float32Array` representing a mono audio signal sampled at 16kHz. + * @returns A Promise resolving to an array of {@link Segment} objects. 
*/ async forward(waveform: Float32Array): Promise { if (this.nativeModule == null) diff --git a/packages/react-native-executorch/src/types/classification.ts b/packages/react-native-executorch/src/types/classification.ts index 51152ec080..144f2af5ae 100644 --- a/packages/react-native-executorch/src/types/classification.ts +++ b/packages/react-native-executorch/src/types/classification.ts @@ -1,16 +1,26 @@ import { RnExecutorchError } from '../errors/errorUtils'; import { ResourceSource } from './common'; +/** + * Union of all built-in classification model names. + * + * @category Types + */ +export type ClassificationModelName = + | 'efficientnet-v2-s' + | 'efficientnet-v2-s-quantized'; + /** * Props for the `useClassification` hook. * * @category Types - * @property {Object} model - An object containing the model source. + * @property {Object} model - An object containing the model configuration. + * @property {ClassificationModelName} model.modelName - Unique name identifying the model. * @property {ResourceSource} model.modelSource - The source of the classification model binary. * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. */ export interface ClassificationProps { - model: { modelSource: ResourceSource }; + model: { modelName: ClassificationModelName; modelSource: ResourceSource }; preventLoad?: boolean; } diff --git a/packages/react-native-executorch/src/types/imageEmbeddings.ts b/packages/react-native-executorch/src/types/imageEmbeddings.ts index 5dc23d66f2..88308ddd6f 100644 --- a/packages/react-native-executorch/src/types/imageEmbeddings.ts +++ b/packages/react-native-executorch/src/types/imageEmbeddings.ts @@ -1,16 +1,26 @@ import { RnExecutorchError } from '../errors/errorUtils'; import { ResourceSource } from './common'; +/** + * Union of all built-in image embeddings model names. 
+ * + * @category Types + */ +export type ImageEmbeddingsModelName = + | 'clip-vit-base-patch32-image' + | 'clip-vit-base-patch32-image-quantized'; + /** * Props for the `useImageEmbeddings` hook. * * @category Types - * @property {Object} model - An object containing the model source. + * @property {Object} model - An object containing the model configuration. + * @property {ImageEmbeddingsModelName} model.modelName - Unique name identifying the model. * @property {ResourceSource} model.modelSource - The source of the image embeddings model binary. * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. */ export interface ImageEmbeddingsProps { - model: { modelSource: ResourceSource }; + model: { modelName: ImageEmbeddingsModelName; modelSource: ResourceSource }; preventLoad?: boolean; } diff --git a/packages/react-native-executorch/src/types/llm.ts b/packages/react-native-executorch/src/types/llm.ts index 94020883f1..a1e610256a 100644 --- a/packages/react-native-executorch/src/types/llm.ts +++ b/packages/react-native-executorch/src/types/llm.ts @@ -14,6 +14,48 @@ export type LLMCapability = 'vision'; export type MediaArg = 'vision' extends C[number] ? { imagePath?: string } : object; +/** + * Union of all built-in LLM model names. 
+ * + * @category Types + */ +export type LLMModelName = + | 'llama-3.2-3b' + | 'llama-3.2-3b-qlora' + | 'llama-3.2-3b-spinquant' + | 'llama-3.2-1b' + | 'llama-3.2-1b-qlora' + | 'llama-3.2-1b-spinquant' + | 'qwen3-0.6b' + | 'qwen3-0.6b-quantized' + | 'qwen3-1.7b' + | 'qwen3-1.7b-quantized' + | 'qwen3-4b' + | 'qwen3-4b-quantized' + | 'hammer2.1-0.5b' + | 'hammer2.1-0.5b-quantized' + | 'hammer2.1-1.5b' + | 'hammer2.1-1.5b-quantized' + | 'hammer2.1-3b' + | 'hammer2.1-3b-quantized' + | 'smollm2.1-135m' + | 'smollm2.1-135m-quantized' + | 'smollm2.1-360m' + | 'smollm2.1-360m-quantized' + | 'smollm2.1-1.7b' + | 'smollm2.1-1.7b-quantized' + | 'qwen2.5-0.5b' + | 'qwen2.5-0.5b-quantized' + | 'qwen2.5-1.5b' + | 'qwen2.5-1.5b-quantized' + | 'qwen2.5-3b' + | 'qwen2.5-3b-quantized' + | 'phi-4-mini-4b' + | 'phi-4-mini-4b-quantized' + | 'lfm2.5-1.2b-instruct' + | 'lfm2.5-1.2b-instruct-quantized' + | 'lfm2.5-vl-1.6b-quantized'; + /** * Properties for initializing and configuring a Large Language Model (LLM) instance. * @@ -21,6 +63,11 @@ export type MediaArg = */ export interface LLMProps { model: { + /** + * The built-in model name (e.g. `'llama-3.2-3b'`). Used for telemetry and hook reload triggers. + * Pass one of the pre-built LLM constants (e.g. `LLAMA3_2_3B`) to populate all required fields. + */ + modelName: LLMModelName; /** * `ResourceSource` that specifies the location of the model binary. */ diff --git a/packages/react-native-executorch/src/types/ocr.ts b/packages/react-native-executorch/src/types/ocr.ts index 6ca2f43249..e31d618478 100644 --- a/packages/react-native-executorch/src/types/ocr.ts +++ b/packages/react-native-executorch/src/types/ocr.ts @@ -39,6 +39,12 @@ export interface OCRProps { * Object containing the necessary model sources and configuration for the OCR pipeline. */ model: { + /** + * The built-in model name, e.g. `'ocr-en'`. Used for telemetry and hook reload triggers. + * Pass one of the pre-built OCR constants (e.g. 
`OCR_ENGLISH`) to populate all required fields. + */ + modelName: OCRModelName; + /** * `ResourceSource` that specifies the location of the text detector model binary. */ @@ -117,3 +123,11 @@ export interface OCRType { * @category Types */ export type OCRLanguage = keyof typeof symbols; + +/** + * Union of all built-in OCR model names. + * Each name is derived from the language code, e.g. `'ocr-en'`, `'ocr-ja'`. + * + * @category Types + */ +export type OCRModelName = `ocr-${OCRLanguage}`; diff --git a/packages/react-native-executorch/src/types/stt.ts b/packages/react-native-executorch/src/types/stt.ts index df0ab063f2..e3217c8da6 100644 --- a/packages/react-native-executorch/src/types/stt.ts +++ b/packages/react-native-executorch/src/types/stt.ts @@ -1,6 +1,22 @@ import { ResourceSource } from './common'; import { RnExecutorchError } from '../errors/errorUtils'; +/** + * Named Speech to Text model variants. + * + * @category Types + */ +export type SpeechToTextModelName = + | 'whisper-tiny-en' + | 'whisper-tiny-en-quantized' + | 'whisper-base-en' + | 'whisper-base-en-quantized' + | 'whisper-small-en' + | 'whisper-small-en-quantized' + | 'whisper-tiny' + | 'whisper-base' + | 'whisper-small'; + /** * Configuration for Speech to Text model. * @@ -261,7 +277,11 @@ export interface TranscriptionResult { * @category Types */ export interface SpeechToTextModelConfig { - type: 'whisper'; // | ... (add more in the future) + /** + * The built-in model name (e.g. `'whisper-tiny-en'`). Used for telemetry and hook reload triggers. + * Pass one of the pre-built STT constants (e.g. `WHISPER_TINY_EN`) to populate all required fields. + */ + modelName: SpeechToTextModelName; /** * A boolean flag indicating whether the model supports multiple languages. 
diff --git a/packages/react-native-executorch/src/types/styleTransfer.ts b/packages/react-native-executorch/src/types/styleTransfer.ts index 1620867226..2571203ee1 100644 --- a/packages/react-native-executorch/src/types/styleTransfer.ts +++ b/packages/react-native-executorch/src/types/styleTransfer.ts @@ -1,16 +1,32 @@ import { RnExecutorchError } from '../errors/errorUtils'; import { ResourceSource } from './common'; +/** + * Union of all built-in style transfer model names. + * + * @category Types + */ +export type StyleTransferModelName = + | 'style-transfer-candy' + | 'style-transfer-candy-quantized' + | 'style-transfer-mosaic' + | 'style-transfer-mosaic-quantized' + | 'style-transfer-rain-princess' + | 'style-transfer-rain-princess-quantized' + | 'style-transfer-udnie' + | 'style-transfer-udnie-quantized'; + /** * Configuration properties for the `useStyleTransfer` hook. * * @category Types - * @property {Object} model - Object containing the `modelSource` for the style transfer model. + * @property {Object} model - Object containing the model configuration. + * @property {StyleTransferModelName} model.modelName - Unique name identifying the model. * @property {ResourceSource} model.modelSource - `ResourceSource` that specifies the location of the style transfer model binary. * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if loaded for the first time) after running the hook. 
*/ export interface StyleTransferProps { - model: { modelSource: ResourceSource }; + model: { modelName: StyleTransferModelName; modelSource: ResourceSource }; preventLoad?: boolean; } diff --git a/packages/react-native-executorch/src/types/textEmbeddings.ts b/packages/react-native-executorch/src/types/textEmbeddings.ts index 43bf6606d3..2614bbbc40 100644 --- a/packages/react-native-executorch/src/types/textEmbeddings.ts +++ b/packages/react-native-executorch/src/types/textEmbeddings.ts @@ -1,15 +1,34 @@ import { RnExecutorchError } from '../errors/errorUtils'; import { ResourceSource } from '../types/common'; +/** + * Union of all built-in text embeddings model names. + * + * @category Types + */ +export type TextEmbeddingsModelName = + | 'all-minilm-l6-v2' + | 'all-mpnet-base-v2' + | 'multi-qa-minilm-l6-cos-v1' + | 'multi-qa-mpnet-base-dot-v1' + | 'clip-vit-base-patch32-text'; + /** * Props for the useTextEmbeddings hook. * * @category Types - * @property {Object} model - An object containing the model and tokenizer sources. + * @property {Object} model - An object containing the model configuration. + * @property {TextEmbeddingsModelName} model.modelName - Unique name identifying the model. + * @property {ResourceSource} model.modelSource - The source of the text embeddings model binary. + * @property {ResourceSource} model.tokenizerSource - The source of the tokenizer JSON file. * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. */ export interface TextEmbeddingsProps { model: { + /** + * The unique name of the text embeddings model. + */ + modelName: TextEmbeddingsModelName; /** * The source of the text embeddings model binary. 
*/ diff --git a/packages/react-native-executorch/src/types/tti.ts b/packages/react-native-executorch/src/types/tti.ts index 7cdce9fd86..543c99ce19 100644 --- a/packages/react-native-executorch/src/types/tti.ts +++ b/packages/react-native-executorch/src/types/tti.ts @@ -1,6 +1,15 @@ import { RnExecutorchError } from '../errors/errorUtils'; import { ResourceSource } from '../types/common'; +/** + * Union of all built-in Text-to-Image model names. + * + * @category Types + */ +export type TextToImageModelName = + | 'bk-sdm-tiny-vpred-512' + | 'bk-sdm-tiny-vpred-256'; + /** * Configuration properties for the `useTextToImage` hook. * @@ -11,6 +20,11 @@ export interface TextToImageProps { * Object containing the required model sources for the diffusion pipeline. */ model: { + /** + * The built-in model name (e.g. `'bk-sdm-tiny-vpred-512'`). Used for telemetry and hook reload triggers. + * Pass one of the pre-built TTI constants (e.g. `BK_SDM_TINY_VPRED_512`) to populate all required fields. + */ + modelName: TextToImageModelName; /** Source for the text tokenizer binary/config. */ tokenizerSource: ResourceSource; /** Source for the diffusion scheduler binary/config. */ diff --git a/packages/react-native-executorch/src/types/tts.ts b/packages/react-native-executorch/src/types/tts.ts index ebc4b065a5..097f35976a 100644 --- a/packages/react-native-executorch/src/types/tts.ts +++ b/packages/react-native-executorch/src/types/tts.ts @@ -1,6 +1,13 @@ import { ResourceSource } from './common'; import { RnExecutorchError } from '../errors/errorUtils'; +/** + * Union of all built-in Text to Speech model names. + * + * @category Types + */ +export type TextToSpeechModelName = 'kokoro-small' | 'kokoro-medium'; + /** * List all the languages available in TTS models (as lang shorthands) * @@ -43,12 +50,12 @@ export interface KokoroVoiceExtras { * Only the core Kokoro model sources, as phonemizer sources are included in voice configuration. 
* * @category Types - * @property {'kokoro'} type - model type identifier + * @property {TextToSpeechModelName} modelName - model name identifier * @property {ResourceSource} durationPredictorSource - source to Kokoro's duration predictor model binary * @property {ResourceSource} synthesizerSource - source to Kokoro's synthesizer model binary */ export interface KokoroConfig { - type: 'kokoro'; + modelName: TextToSpeechModelName; durationPredictorSource: ResourceSource; synthesizerSource: ResourceSource; } diff --git a/packages/react-native-executorch/src/types/vad.ts b/packages/react-native-executorch/src/types/vad.ts index cf379124d6..57bd112628 100644 --- a/packages/react-native-executorch/src/types/vad.ts +++ b/packages/react-native-executorch/src/types/vad.ts @@ -1,16 +1,24 @@ import { ResourceSource } from '../types/common'; import { RnExecutorchError } from '../errors/errorUtils'; +/** + * Union of all built-in VAD model names. + * + * @category Types + */ +export type VADModelName = 'fsmn-vad'; + /** * Props for the useVAD hook. * * @category Types - * @property {Object} model - An object containing the model source. + * @property {Object} model - An object containing the model configuration. + * @property {VADModelName} model.modelName - Unique name identifying the model. * @property {ResourceSource} model.modelSource - The source of the VAD model binary. * @property {boolean} [preventLoad] - Boolean that can prevent automatic model loading (and downloading the data if you load it for the first time) after running the hook. */ export interface VADProps { - model: { modelSource: ResourceSource }; + model: { modelName: VADModelName; modelSource: ResourceSource }; preventLoad?: boolean; }