Skip to content

Commit aeb603f

Browse files
committed
Enable passing extra args to load function in classes dervied from BaseModule
1 parent 6e945b0 commit aeb603f

4 files changed

Lines changed: 8 additions & 5 deletions

File tree

packages/react-native-executorch/src/modules/BaseModule.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@ export class BaseModule {
77
static onDownloadProgressCallback: (downloadProgress: number) => void =
88
() => {};
99

10-
static async load(...sources: ResourceSource[]): Promise<void> {
10+
static async load(
11+
sources: ResourceSource[],
12+
...loadArgs: any[] // this can be used in derived classes to pass extra args to load method
13+
): Promise<void> {
1114
try {
1215
const paths = await ResourceFetcher.fetchMultipleResources(
1316
this.onDownloadProgressCallback,
1417
...sources
1518
);
16-
await this.nativeModule.loadModule(...paths);
19+
await this.nativeModule.loadModule(...paths, ...loadArgs);
1720
} catch (error) {
1821
throw new Error(getError(error));
1922
}

packages/react-native-executorch/src/modules/general/ExecutorchModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export class ExecutorchModule extends BaseModule {
99
protected static override nativeModule = ETModuleNativeModule;
1010

1111
static override async load(modelSource: ResourceSource) {
12-
return await super.load(modelSource);
12+
return await super.load([modelSource]);
1313
}
1414

1515
static override async forward(input: ETInput[] | ETInput, shape: number[][]) {

packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export class TextEmbeddingsModule extends BaseModule {
99
modelSource: ResourceSource,
1010
tokenizerSource: ResourceSource
1111
) {
12-
await super.load(modelSource, tokenizerSource);
12+
await super.load([modelSource, tokenizerSource]);
1313
}
1414

1515
static override async forward(input: string): Promise<number[]> {

packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ export class TokenizerModule extends BaseModule {
66
protected static override nativeModule = TokenizerNativeModule;
77

88
static override async load(tokenizerSource: ResourceSource) {
9-
await super.load(tokenizerSource);
9+
await super.load([tokenizerSource]);
1010
}
1111

1212
static async decode(

0 commit comments

Comments
 (0)