Skip to content

Commit ada8e44

Browse files
authored
fix: prevent negative values from appearing in downloadProgress (#465)
## Description

Changes:
- updated the downloading strategy to prevent negative values from appearing in downloadProgress

### Type of change

- [x] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [ ] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [ ] Android

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
1 parent de57e5d commit ada8e44

7 files changed

Lines changed: 75 additions & 41 deletions

File tree

packages/react-native-executorch/src/controllers/LLMController.ts

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -102,22 +102,34 @@ export class LLMController {
102102
this.isReadyCallback(false);
103103

104104
try {
105-
const paths = await ResourceFetcher.fetch(
106-
onDownloadProgressCallback,
105+
const tokenizersPromise = ResourceFetcher.fetch(
106+
undefined,
107107
tokenizerSource,
108-
tokenizerConfigSource,
108+
tokenizerConfigSource
109+
);
110+
111+
const modelPromise = ResourceFetcher.fetch(
112+
onDownloadProgressCallback,
109113
modelSource
110114
);
111-
if (paths === null || paths?.length < 3) {
115+
116+
const [tokenizersResults, modelResult] = await Promise.all([
117+
tokenizersPromise,
118+
modelPromise,
119+
]);
120+
121+
const tokenizerPath = tokenizersResults?.[0];
122+
const tokenizerConfigPath = tokenizersResults?.[1];
123+
const modelPath = modelResult?.[0];
124+
125+
if (!tokenizerPath || !tokenizerConfigPath || !modelPath) {
112126
throw new Error('Download interrupted!');
113127
}
114-
const tokenizerFileUri = paths[0]!;
115-
const tokenizerConfigFileUri = paths[1]!;
116-
const modelFileUri = paths[2]!;
128+
117129
this.tokenizerConfig = JSON.parse(
118-
await readAsStringAsync('file://' + tokenizerConfigFileUri!)
130+
await readAsStringAsync('file://' + tokenizerConfigPath!)
119131
);
120-
this.nativeModule = global.loadLLM(modelFileUri, tokenizerFileUri);
132+
this.nativeModule = global.loadLLM(modelPath, tokenizerPath);
121133
this.isReadyCallback(true);
122134
this.onToken = (data: string) => {
123135
if (

packages/react-native-executorch/src/controllers/SpeechToTextController.ts

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -100,18 +100,23 @@ export class SpeechToTextController {
100100
this.config = MODEL_CONFIGS[modelName];
101101

102102
try {
103-
await this.tokenizerModule.load(
103+
const tokenizerLoadPromise = this.tokenizerModule.load(
104104
tokenizerSource || this.config.tokenizer.source
105105
);
106-
const paths = await ResourceFetcher.fetch(
106+
const pathsPromise = ResourceFetcher.fetch(
107107
onDownloadProgressCallback,
108108
encoderSource || this.config.sources.encoder,
109109
decoderSource || this.config.sources.decoder
110110
);
111-
if (paths === null || paths.length < 2) {
111+
const [_, encoderDecoderResults] = await Promise.all([
112+
tokenizerLoadPromise,
113+
pathsPromise,
114+
]);
115+
encoderSource = encoderDecoderResults?.[0];
116+
decoderSource = encoderDecoderResults?.[1];
117+
if (!encoderSource || !decoderSource) {
112118
throw new Error('Download interrupted.');
113119
}
114-
[encoderSource, decoderSource] = paths;
115120
} catch (e) {
116121
this.onErrorCallback(e);
117122
return;
@@ -127,8 +132,8 @@ export class SpeechToTextController {
127132

128133
try {
129134
const nativeSpeechToText = await global.loadSpeechToText(
130-
encoderSource!,
131-
decoderSource!,
135+
encoderSource,
136+
decoderSource,
132137
modelName
133138
);
134139
this.speechToTextNativeModule = nativeSpeechToText;

packages/react-native-executorch/src/hooks/natural_language_processing/useTokenizer.ts

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { useEffect, useRef, useState } from 'react';
1+
import { useEffect, useMemo, useState } from 'react';
22
import { TokenizerModule } from '../../modules/natural_language_processing/TokenizerModule';
33
import { ResourceSource } from '../../types/common';
44
import { ETError, getError } from '../../Error';
@@ -14,33 +14,30 @@ export const useTokenizer = ({
1414
const [isReady, setIsReady] = useState(false);
1515
const [isGenerating, setIsGenerating] = useState(false);
1616
const [downloadProgress, setDownloadProgress] = useState(0);
17-
const tokenizerModuleRef = useRef<TokenizerModule | null>(null);
17+
const model = useMemo(() => new TokenizerModule(), []);
1818

1919
useEffect(() => {
20-
const loadModule = async () => {
20+
if (preventLoad) return;
21+
(async () => {
22+
setDownloadProgress(0);
23+
setError(null);
2124
try {
2225
setIsReady(false);
23-
tokenizerModuleRef.current = new TokenizerModule();
24-
tokenizerModuleRef.current.load(tokenizerSource, setDownloadProgress);
26+
await model.load(tokenizerSource, setDownloadProgress);
2527
setIsReady(true);
2628
} catch (err) {
2729
setError((err as Error).message);
2830
}
29-
};
30-
if (!preventLoad) {
31-
loadModule();
32-
}
33-
}, [tokenizerSource, preventLoad]);
31+
})();
32+
}, [model, tokenizerSource, preventLoad]);
3433

3534
const stateWrapper = <T extends (...args: any[]) => Promise<any>>(fn: T) => {
36-
return async (...args: Parameters<T>): Promise<ReturnType<T>> => {
37-
if (!isReady || !tokenizerModuleRef.current)
38-
throw new Error(getError(ETError.ModuleNotLoaded));
35+
return (...args: Parameters<T>): Promise<ReturnType<T>> => {
36+
if (!isReady) throw new Error(getError(ETError.ModuleNotLoaded));
3937
if (isGenerating) throw new Error(getError(ETError.ModelGenerating));
40-
41-
setIsGenerating(true);
4238
try {
43-
return await fn.apply(tokenizerModuleRef.current, args);
39+
setIsGenerating(true);
40+
return fn.apply(model, args);
4441
} finally {
4542
setIsGenerating(false);
4643
}

packages/react-native-executorch/src/modules/natural_language_processing/TextEmbeddingsModule.ts

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,21 @@ export class TextEmbeddingsModule extends BaseNonStaticModule {
88
tokenizerSource: ResourceSource,
99
onDownloadProgressCallback: (_: number) => void = () => {}
1010
): Promise<void> {
11-
const paths = await ResourceFetcher.fetch(
11+
const modelPromise = ResourceFetcher.fetch(
1212
onDownloadProgressCallback,
13-
modelSource,
14-
tokenizerSource
13+
modelSource
1514
);
16-
if (paths === null || paths.length < 2) {
15+
const tokenizerPromise = ResourceFetcher.fetch(undefined, tokenizerSource);
16+
const [modelResult, tokenizerResult] = await Promise.all([
17+
modelPromise,
18+
tokenizerPromise,
19+
]);
20+
const modelPath = modelResult?.[0];
21+
const tokenizerPath = tokenizerResult?.[0];
22+
if (!modelPath || !tokenizerPath) {
1723
throw new Error('Download interrupted.');
1824
}
19-
this.nativeModule = global.loadTextEmbeddings(
20-
paths[0] || '',
21-
paths[1] || ''
22-
);
25+
this.nativeModule = global.loadTextEmbeddings(modelPath, tokenizerPath);
2326
}
2427

2528
async forward(input: string): Promise<Float32Array> {

packages/react-native-executorch/src/modules/natural_language_processing/TokenizerModule.ts

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@ export class TokenizerModule {
1212
onDownloadProgressCallback,
1313
modelSource
1414
);
15-
if (paths === null || paths.length < 1) {
15+
const path = paths?.[0];
16+
if (!path) {
1617
throw new Error('Download interrupted.');
1718
}
18-
this.nativeModule = global.loadTokenizerModule(paths[0] || '');
19+
this.nativeModule = global.loadTokenizerModule(path);
1920
}
2021

2122
async encode(s: string) {

packages/react-native-executorch/src/utils/ResourceFetcher.ts

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,11 @@ export class ResourceFetcher {
376376
sourceExtended.cacheFileUri,
377377
{ sessionType: FileSystemSessionType.BACKGROUND },
378378
({ totalBytesWritten, totalBytesExpectedToWrite }) => {
379+
if (totalBytesExpectedToWrite === -1) {
380+
// If totalBytesExpectedToWrite is -1, it means the server does not provide content length.
381+
sourceExtended.callback!(0);
382+
return;
383+
}
379384
sourceExtended.callback!(totalBytesWritten / totalBytesExpectedToWrite);
380385
}
381386
);

packages/react-native-executorch/src/utils/ResourceFetcherUtils.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ export namespace ResourceFetcherUtils {
9090
}
9191

9292
const contentLength = response.headers.get('content-length');
93+
if (!contentLength) {
94+
Logger.warn(`No content-length header for ${source}`);
95+
}
96+
9397
length = contentLength ? parseInt(contentLength, 10) : 0;
9498
previousFilesTotalLength = totalLength;
9599
totalLength += length;
@@ -134,6 +138,13 @@ export namespace ResourceFetcherUtils {
134138
setProgress(1);
135139
return;
136140
}
141+
142+
// Avoid division by zero
143+
if (totalLength === 0) {
144+
setProgress(0);
145+
return;
146+
}
147+
137148
const baseProgress = previousFilesTotalLength / totalLength;
138149
const scaledProgress = progress * (currentFileLength / totalLength);
139150
const updatedProgress = baseProgress + scaledProgress;

0 commit comments

Comments (0)