Skip to content

Commit 9fc03a6

Browse files
authored
Enable passing extra args to load function in classes derived from BaseModule (#381)
### Type of change - [ ] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [ ] Documentation update (improves or adds clarity to existing documentation) ### Tested on - [x] iOS - [ ] Android ### Testing instructions Tested in CV object detection - it works ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent 4bbb2c4 commit 9fc03a6

3 files changed

Lines changed: 7 additions & 4 deletions

File tree

packages/react-native-executorch/src/modules/BaseModule.ts

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@ export class BaseModule {
77
static onDownloadProgressCallback: (downloadProgress: number) => void =
88
() => {};
99

10-
static async load(...sources: ResourceSource[]): Promise<void> {
10+
static async load(
11+
sources: ResourceSource[],
12+
...loadArgs: any[] // this can be used in derived classes to pass extra args to load method
13+
): Promise<void> {
1114
try {
1215
const paths = await ResourceFetcher.fetchMultipleResources(
1316
this.onDownloadProgressCallback,
1417
...sources
1518
);
16-
await this.nativeModule.loadModule(...paths);
19+
await this.nativeModule.loadModule(...paths, ...loadArgs);
1720
} catch (error) {
1821
throw new Error(getError(error));
1922
}

packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export class TextEmbeddingsModule extends BaseModule {
99
modelSource: ResourceSource,
1010
tokenizerSource: ResourceSource
1111
) {
12-
await super.load(modelSource, tokenizerSource);
12+
await super.load([modelSource, tokenizerSource]);
1313
}
1414

1515
static override async forward(input: string): Promise<number[]> {

packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ export class TokenizerModule extends BaseModule {
66
protected static override nativeModule = TokenizerNativeModule;
77

88
static override async load(tokenizerSource: ResourceSource) {
9-
await super.load(tokenizerSource);
9+
await super.load([tokenizerSource]);
1010
}
1111

1212
static async decode(

0 commit comments

Comments
 (0)