diff --git a/docs/docs/05-utilities/model-registry.md b/docs/docs/05-utilities/model-registry.md new file mode 100644 index 0000000000..3b4241c986 --- /dev/null +++ b/docs/docs/05-utilities/model-registry.md @@ -0,0 +1,37 @@ +--- +title: Model Registry +--- + +The [Model Registry](/react-native-executorch/docs/next/api-reference/variables/MODEL_REGISTRY) is a collection of all pre-configured model definitions shipped with React Native ExecuTorch. Each entry contains the model's name and all source URLs needed to download and run it, so you don't have to manage URLs manually. + +## Usage + +```typescript +import { MODEL_REGISTRY, LLAMA3_2_1B } from 'react-native-executorch'; +``` + +### Accessing a model directly + +Every model config is exported as a standalone constant: + +```typescript +import { LLAMA3_2_1B } from 'react-native-executorch'; + +const llm = useLLM({ model: LLAMA3_2_1B }); +``` + +### Listing all models + +Use `MODEL_REGISTRY` to discover and enumerate all available models: + +```typescript +import { MODEL_REGISTRY } from 'react-native-executorch'; + +// Get all model names +const names = Object.values(MODEL_REGISTRY.ALL_MODELS).map((m) => m.modelName); + +// Find models by name +const whisperModels = Object.values(MODEL_REGISTRY.ALL_MODELS).filter((m) => + m.modelName.includes('whisper') +); +``` diff --git a/packages/bare-resource-fetcher/src/ResourceFetcher.ts b/packages/bare-resource-fetcher/src/ResourceFetcher.ts index fc148526f0..4a7c639ff9 100644 --- a/packages/bare-resource-fetcher/src/ResourceFetcher.ts +++ b/packages/bare-resource-fetcher/src/ResourceFetcher.ts @@ -300,7 +300,6 @@ export const BareResourceFetcher: BareResourceFetcherInterface = { await RNFS.moveFile(extendedInfo.cacheFileUri!, extendedInfo.fileUri!); this.downloads.delete(source); - ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(extendedInfo.uri!); const filename = extendedInfo.fileUri!.split('/').pop(); if (filename) { diff --git a/packages/bare-resource-fetcher/src/ResourceFetcherUtils.ts b/packages/bare-resource-fetcher/src/ResourceFetcherUtils.ts index 859e6f892b..74657c0cdc 100644 --- a/packages/bare-resource-fetcher/src/ResourceFetcherUtils.ts +++ b/packages/bare-resource-fetcher/src/ResourceFetcherUtils.ts @@ -26,6 +26,7 @@ export namespace ResourceFetcherUtils { export const calculateDownloadProgress = CoreUtils.calculateDownloadProgress; export const triggerHuggingFaceDownloadCounter = CoreUtils.triggerHuggingFaceDownloadCounter; + export const triggerDownloadEvent = CoreUtils.triggerDownloadEvent; export const getFilenameFromUri = CoreUtils.getFilenameFromUri; export function getType(source: ResourceSource): SourceType { diff --git a/packages/expo-resource-fetcher/src/ResourceFetcher.ts b/packages/expo-resource-fetcher/src/ResourceFetcher.ts index e1c0373936..efbc260e20 100644 --- a/packages/expo-resource-fetcher/src/ResourceFetcher.ts +++ b/packages/expo-resource-fetcher/src/ResourceFetcher.ts @@ -264,9 +264,6 @@ export const ExpoResourceFetcher: ExpoResourceFetcherInterface = { to: resource.extendedInfo.fileUri, }); this.downloads.delete(source); - ResourceFetcherUtils.triggerHuggingFaceDownloadCounter( - resource.extendedInfo.uri - ); return this.returnOrStartNext( resource.extendedInfo, @@ -526,7 +523,6 @@ export const ExpoResourceFetcher: ExpoResourceFetcherInterface = { to: sourceExtended.fileUri, }); this.downloads.delete(source); - ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(uri); return ResourceFetcherUtils.removeFilePrefix(sourceExtended.fileUri); }, diff --git a/packages/expo-resource-fetcher/src/ResourceFetcherUtils.ts b/packages/expo-resource-fetcher/src/ResourceFetcherUtils.ts index 98347e0daf..0194442f77 100644 --- a/packages/expo-resource-fetcher/src/ResourceFetcherUtils.ts +++ b/packages/expo-resource-fetcher/src/ResourceFetcherUtils.ts @@ -39,6 +39,7 @@ export namespace ResourceFetcherUtils { export const calculateDownloadProgress = CoreUtils.calculateDownloadProgress; export const triggerHuggingFaceDownloadCounter = CoreUtils.triggerHuggingFaceDownloadCounter; + export const triggerDownloadEvent = CoreUtils.triggerDownloadEvent; export const getFilenameFromUri = CoreUtils.getFilenameFromUri; export function getType(source: ResourceSource): SourceType { diff --git a/packages/react-native-executorch/src/constants/modelUrls.ts b/packages/react-native-executorch/src/constants/modelUrls.ts index 0e4bcdf080..2eb69dfb0a 100644 --- a/packages/react-native-executorch/src/constants/modelUrls.ts +++ b/packages/react-native-executorch/src/constants/modelUrls.ts @@ -976,3 +976,129 @@ export const FSMN_VAD = { modelName: 'fsmn-vad', modelSource: FSMN_VAD_MODEL, } as const; + +/** + * Registry of all available model configurations. + * + * Use this to discover and enumerate all models shipped with the library. + * @example + * ```ts + * import { MODEL_REGISTRY } from 'react-native-executorch'; + * + * // List all model names + * const names = Object.values(MODEL_REGISTRY).map(m => m.modelName); + * + * // Find models by name substring + * const whisperModels = Object.values(MODEL_REGISTRY) + * .filter(m => m.modelName.includes('whisper')); + * ``` + * @category Utils + */ +export const MODEL_REGISTRY = { + ALL_MODELS: { + LLAMA3_2_3B, + LLAMA3_2_3B_QLORA, + LLAMA3_2_3B_SPINQUANT, + LLAMA3_2_1B, + LLAMA3_2_1B_QLORA, + LLAMA3_2_1B_SPINQUANT, + QWEN3_0_6B, + QWEN3_0_6B_QUANTIZED, + QWEN3_1_7B, + QWEN3_1_7B_QUANTIZED, + QWEN3_4B, + QWEN3_4B_QUANTIZED, + HAMMER2_1_0_5B, + HAMMER2_1_0_5B_QUANTIZED, + HAMMER2_1_1_5B, + HAMMER2_1_1_5B_QUANTIZED, + HAMMER2_1_3B, + HAMMER2_1_3B_QUANTIZED, + SMOLLM2_1_135M, + SMOLLM2_1_135M_QUANTIZED, + SMOLLM2_1_360M, + SMOLLM2_1_360M_QUANTIZED, + SMOLLM2_1_1_7B, + SMOLLM2_1_1_7B_QUANTIZED, + QWEN2_5_0_5B, + QWEN2_5_0_5B_QUANTIZED, + QWEN2_5_1_5B, + QWEN2_5_1_5B_QUANTIZED, + QWEN2_5_3B, + QWEN2_5_3B_QUANTIZED, + PHI_4_MINI_4B, + PHI_4_MINI_4B_QUANTIZED, + LFM2_5_1_2B_INSTRUCT, + LFM2_5_1_2B_INSTRUCT_QUANTIZED, + LFM2_VL_1_6B_QUANTIZED, + EFFICIENTNET_V2_S, + EFFICIENTNET_V2_S_QUANTIZED, + SSDLITE_320_MOBILENET_V3_LARGE, + RF_DETR_NANO, + STYLE_TRANSFER_CANDY, + STYLE_TRANSFER_CANDY_QUANTIZED, + STYLE_TRANSFER_MOSAIC, + STYLE_TRANSFER_MOSAIC_QUANTIZED, + STYLE_TRANSFER_RAIN_PRINCESS, + STYLE_TRANSFER_RAIN_PRINCESS_QUANTIZED, + STYLE_TRANSFER_UDNIE, + STYLE_TRANSFER_UDNIE_QUANTIZED, + WHISPER_TINY_EN, + WHISPER_TINY_EN_QUANTIZED, + WHISPER_BASE_EN, + WHISPER_BASE_EN_QUANTIZED, + WHISPER_SMALL_EN, + WHISPER_SMALL_EN_QUANTIZED, + WHISPER_TINY, + WHISPER_BASE, + WHISPER_SMALL, + DEEPLAB_V3_RESNET50, + DEEPLAB_V3_RESNET101, + DEEPLAB_V3_MOBILENET_V3_LARGE, + LRASPP_MOBILENET_V3_LARGE, + FCN_RESNET50, + FCN_RESNET101, + DEEPLAB_V3_RESNET50_QUANTIZED, + DEEPLAB_V3_RESNET101_QUANTIZED, + DEEPLAB_V3_MOBILENET_V3_LARGE_QUANTIZED, + LRASPP_MOBILENET_V3_LARGE_QUANTIZED, + FCN_RESNET50_QUANTIZED, + FCN_RESNET101_QUANTIZED, + SELFIE_SEGMENTATION, + YOLO26N_SEG, + YOLO26S_SEG, + YOLO26M_SEG, + YOLO26L_SEG, + YOLO26X_SEG, + RF_DETR_NANO_SEG, + CLIP_VIT_BASE_PATCH32_IMAGE, + CLIP_VIT_BASE_PATCH32_IMAGE_QUANTIZED, + ALL_MINILM_L6_V2, + ALL_MPNET_BASE_V2, + MULTI_QA_MINILM_L6_COS_V1, + MULTI_QA_MPNET_BASE_DOT_V1, + CLIP_VIT_BASE_PATCH32_TEXT, + BK_SDM_TINY_VPRED_512, + BK_SDM_TINY_VPRED_256, + FSMN_VAD, + }, +} as const; + +const urlToModelName = new Map(); +for (const config of Object.values(MODEL_REGISTRY.ALL_MODELS)) { + const modelName = config.modelName; + for (const [key, value] of Object.entries(config)) { + if (key !== 'modelName' && typeof value === 'string') { + urlToModelName.set(value, modelName); + } + } +} + +/** + * Looks up the model name for a given source URL. + * @param url - The source URL to look up. + * @returns The model name if found, otherwise undefined. + */ +export function getModelNameForUrl(url: string): string | undefined { + return urlToModelName.get(url); +} diff --git a/packages/react-native-executorch/src/constants/resourceFetcher.ts b/packages/react-native-executorch/src/constants/resourceFetcher.ts new file mode 100644 index 0000000000..58a57fdd93 --- /dev/null +++ b/packages/react-native-executorch/src/constants/resourceFetcher.ts @@ -0,0 +1,2 @@ +export const DOWNLOAD_EVENT_ENDPOINT = + 'https://ai.swmansion.com/telemetry/downloads/api/downloads'; diff --git a/packages/react-native-executorch/src/utils/ResourceFetcher.ts b/packages/react-native-executorch/src/utils/ResourceFetcher.ts index bd92c07f2f..51f2abb025 100644 --- a/packages/react-native-executorch/src/utils/ResourceFetcher.ts +++ b/packages/react-native-executorch/src/utils/ResourceFetcher.ts @@ -75,6 +75,7 @@ export interface ResourceFetcherAdapter { */ export class ResourceFetcher { private static adapter: ResourceFetcherAdapter | null = null; + private static reportedUrls = new Set(); /** * Sets a custom resource fetcher adapter for resource operations. @@ -123,16 +124,21 @@ export class ResourceFetcher { callback: (downloadProgress: number) => void = () => {}, ...sources: ResourceSource[] ) { - for (const source of sources) { - if (typeof source === 'string') { - try { - ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(source); - } catch (error) { - throw error; + const result = await this.getAdapter().fetch(callback, ...sources); + if (result) { + for (const source of sources) { + if (typeof source === 'string' && !this.reportedUrls.has(source)) { + this.reportedUrls.add(source); + try { + ResourceFetcherUtils.triggerDownloadEvent(source); + ResourceFetcherUtils.triggerHuggingFaceDownloadCounter(source); + } catch (error) { + throw error; + } } } } - return this.getAdapter().fetch(callback, ...sources); + return result; } /** diff --git a/packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts b/packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts index b98496af58..689775f93d 100644 --- a/packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts +++ b/packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts @@ -1,4 +1,6 @@ import { ResourceSource } from '..'; +import { getModelNameForUrl } from '../constants/modelUrls'; +import { DOWNLOAD_EVENT_ENDPOINT } from '../constants/resourceFetcher'; /** * Http status codes @@ -193,6 +195,44 @@ export namespace ResourceFetcherUtils { } } + function getCountryCode(): string { + try { + const locale = Intl.DateTimeFormat().resolvedOptions().locale; + const regionTag = locale.split('-').pop(); + if (regionTag && regionTag.length === 2) { + return regionTag.toUpperCase(); + } + } catch {} + return 'UNKNOWN'; + } + + function getModelNameFromUri(uri: string): string { + const knownName = getModelNameForUrl(uri); + if (knownName) { + return knownName; + } + const pathname = new URL(uri).pathname; + const filename = pathname.split('/').pop() ?? uri; + return filename.replace(/\.[^.]+$/, ''); + } + + /** + * Sends a download event to the analytics endpoint. + * @param uri - The URI of the downloaded resource. + */ + export function triggerDownloadEvent(uri: string) { + try { + fetch(DOWNLOAD_EVENT_ENDPOINT, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ + modelName: getModelNameFromUri(uri), + countryCode: getCountryCode(), + }), + }); + } catch (e) {} + } + /** * Generates a safe filename from a URI by removing the protocol and replacing special characters. * @param uri - The source URI.