diff --git a/packages/react-native-executorch/src/modules/BaseModule.ts b/packages/react-native-executorch/src/modules/BaseModule.ts index 0061c57e3d..0c0d8c7881 100644 --- a/packages/react-native-executorch/src/modules/BaseModule.ts +++ b/packages/react-native-executorch/src/modules/BaseModule.ts @@ -7,13 +7,16 @@ export class BaseModule { static onDownloadProgressCallback: (downloadProgress: number) => void = () => {}; - static async load(...sources: ResourceSource[]): Promise { + static async load( + sources: ResourceSource[], + ...loadArgs: any[] // this can be used in derived classes to pass extra args to load method + ): Promise { try { const paths = await ResourceFetcher.fetchMultipleResources( this.onDownloadProgressCallback, ...sources ); - await this.nativeModule.loadModule(...paths); + await this.nativeModule.loadModule(...paths, ...loadArgs); } catch (error) { throw new Error(getError(error)); } diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts index 45f41cd8ea..d60222f1f6 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts @@ -9,7 +9,7 @@ export class TextEmbeddingsModule extends BaseModule { modelSource: ResourceSource, tokenizerSource: ResourceSource ) { - await super.load(modelSource, tokenizerSource); + await super.load([modelSource, tokenizerSource]); } static override async forward(input: string): Promise { diff --git a/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts b/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts index f5a69c734e..589a147d78 100644 --- a/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts +++ b/packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts @@ -6,7 +6,7 @@ export class TokenizerModule extends BaseModule { protected static override nativeModule = TokenizerNativeModule; static override async load(tokenizerSource: ResourceSource) { - await super.load(tokenizerSource); + await super.load([tokenizerSource]); } static async decode(