Skip to content

Commit f293cb2

Browse files
committed
fix: move model construction off the JS thread
The `loadModel` template in `RnExecutorchInstaller.h` constructs model objects synchronously on the JS thread. For models like Kokoro TTS, the constructor loads .pte files, initializes the phonemizer, and reads voice data — blocking the JS thread for several seconds. This prevents React from rendering loading states (spinners, progress indicators) until construction completes. This change makes `loadModel` return a Promise and dispatches the model construction to `GlobalThreadPool::detach`, matching the pattern already used by `promiseHostFunction` for inference calls like `generate()`. On the JS side, all `global.load*()` call sites are updated to `await` the now-async result, and the global type declarations are updated to return `Promise<any>`. This is a breaking change for any consumers calling `global.load*()` directly and expecting a synchronous return value. The public module APIs (`load()` methods) are already async, so no user-facing API changes are needed.
1 parent fdbd3e9 commit f293cb2

File tree

16 files changed

+102
-72
lines changed

16 files changed

+102
-72
lines changed

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 71 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,11 @@
77
#include <rnexecutorch/Error.h>
88
#include <rnexecutorch/host_objects/JsiConversions.h>
99
#include <rnexecutorch/host_objects/ModelHostObject.h>
10+
#include <rnexecutorch/jsi/Promise.h>
1011
#include <rnexecutorch/metaprogramming/ConstructorHelpers.h>
1112
#include <rnexecutorch/metaprogramming/FunctionHelpers.h>
1213
#include <rnexecutorch/metaprogramming/TypeConcepts.h>
14+
#include <rnexecutorch/threads/GlobalThreadPool.h>
1315

1416
namespace rnexecutorch {
1517

@@ -49,50 +51,78 @@ class RnExecutorchInstaller {
4951
expectedCount, count);
5052
throw jsi::JSError(runtime, errorMessage);
5153
}
52-
try {
53-
auto constructorArgs =
54-
meta::createConstructorArgsWithCallInvoker<ModelT>(
55-
args, runtime, jsCallInvoker);
5654

57-
auto modelImplementationPtr = std::apply(
58-
[](auto &&...unpackedArgs) {
59-
return std::make_shared<ModelT>(
60-
std::forward<decltype(unpackedArgs)>(unpackedArgs)...);
61-
},
62-
std::move(constructorArgs));
55+
// Parse JSI arguments on the JS thread (required for jsi::Value
56+
// access), then dispatch the heavy model construction to a background
57+
// thread and return a Promise.
58+
auto constructorArgs =
59+
meta::createConstructorArgsWithCallInvoker<ModelT>(
60+
args, runtime, jsCallInvoker);
6361

64-
auto modelHostObject = std::make_shared<ModelHostObject<ModelT>>(
65-
modelImplementationPtr, jsCallInvoker);
62+
return Promise::createPromise(
63+
runtime, jsCallInvoker,
64+
[jsCallInvoker,
65+
constructorArgs =
66+
std::move(constructorArgs)](std::shared_ptr<Promise> promise) {
67+
threads::GlobalThreadPool::detach(
68+
[jsCallInvoker, promise,
69+
constructorArgs = std::move(constructorArgs)]() {
70+
try {
71+
auto modelImplementationPtr = std::apply(
72+
[](auto &&...unpackedArgs) {
73+
return std::make_shared<ModelT>(
74+
std::forward<decltype(unpackedArgs)>(
75+
unpackedArgs)...);
76+
},
77+
std::move(constructorArgs));
6678

67-
auto jsiObject =
68-
jsi::Object::createFromHostObject(runtime, modelHostObject);
69-
jsiObject.setExternalMemoryPressure(
70-
runtime, modelImplementationPtr->getMemoryLowerBound());
71-
return jsiObject;
72-
} catch (const rnexecutorch::RnExecutorchError &e) {
73-
jsi::Object errorData(runtime);
74-
errorData.setProperty(runtime, "code", e.getNumericCode());
75-
errorData.setProperty(
76-
runtime, "message",
77-
jsi::String::createFromUtf8(runtime, e.what()));
78-
throw jsi::JSError(runtime,
79-
jsi::Value(runtime, std::move(errorData)));
80-
} catch (const std::runtime_error &e) {
81-
// This catch should be merged with the next one
82-
// (std::runtime_error inherits from std::exception) HOWEVER react
83-
// native has broken RTTI which breaks proper exception type
84-
// checking. Remove when the following change is present in our
85-
// version:
86-
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
87-
throw jsi::JSError(runtime, e.what());
88-
return jsi::Value();
89-
} catch (const std::exception &e) {
90-
throw jsi::JSError(runtime, e.what());
91-
return jsi::Value();
92-
} catch (...) {
93-
throw jsi::JSError(runtime, "Unknown error");
94-
return jsi::Value();
95-
}
79+
auto modelHostObject =
80+
std::make_shared<ModelHostObject<ModelT>>(
81+
modelImplementationPtr, jsCallInvoker);
82+
83+
auto memoryLowerBound =
84+
modelImplementationPtr->getMemoryLowerBound();
85+
86+
jsCallInvoker->invokeAsync(
87+
[promise, modelHostObject,
88+
memoryLowerBound](jsi::Runtime &rt) {
89+
auto jsiObject =
90+
jsi::Object::createFromHostObject(
91+
rt, modelHostObject);
92+
jsiObject.setExternalMemoryPressure(
93+
rt, memoryLowerBound);
94+
promise->resolve(std::move(jsiObject));
95+
});
96+
} catch (const rnexecutorch::RnExecutorchError &e) {
97+
auto code = e.getNumericCode();
98+
auto msg = std::string(e.what());
99+
jsCallInvoker->invokeAsync(
100+
[promise, code, msg](jsi::Runtime &rt) {
101+
jsi::Object errorData(rt);
102+
errorData.setProperty(rt, "code", code);
103+
errorData.setProperty(
104+
rt, "message",
105+
jsi::String::createFromUtf8(rt, msg));
106+
promise->reject(
107+
jsi::Value(rt, std::move(errorData)));
108+
});
109+
} catch (const std::runtime_error &e) {
110+
jsCallInvoker->invokeAsync(
111+
[promise, msg = std::string(e.what())]() {
112+
promise->reject(msg);
113+
});
114+
} catch (const std::exception &e) {
115+
jsCallInvoker->invokeAsync(
116+
[promise, msg = std::string(e.what())]() {
117+
promise->reject(msg);
118+
});
119+
} catch (...) {
120+
jsCallInvoker->invokeAsync([promise]() {
121+
promise->reject(std::string("Unknown error"));
122+
});
123+
}
124+
});
125+
});
96126
});
97127
}
98128
};

packages/react-native-executorch/src/controllers/LLMController.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ export class LLMController {
118118
this.tokenizerConfig = JSON.parse(
119119
await ResourceFetcher.fs.readAsString(tokenizerConfigPath!)
120120
);
121-
this.nativeModule = global.loadLLM(modelPath, tokenizerPath);
121+
this.nativeModule = await global.loadLLM(modelPath, tokenizerPath);
122122
this.isReadyCallback(true);
123123
this.onToken = (data: string) => {
124124
if (!data) {

packages/react-native-executorch/src/controllers/OCRController.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ export class OCRController extends BaseOCRController {
99
recognizerPath: string,
1010
language: OCRLanguage
1111
): any {
12-
return global.loadOCR(detectorPath, recognizerPath, symbols[language]);
12+
return await global.loadOCR(detectorPath, recognizerPath, symbols[language]);
1313
}
1414

1515
public load = async (

packages/react-native-executorch/src/controllers/VerticalOCRController.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ export class VerticalOCRController extends BaseOCRController {
1010
language: OCRLanguage,
1111
independentCharacters?: boolean
1212
): any {
13-
return global.loadVerticalOCR(
13+
return await global.loadVerticalOCR(
1414
detectorPath,
1515
recognizerPath,
1616
symbols[language],

packages/react-native-executorch/src/index.ts

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,26 +35,26 @@ export function cleanupExecutorch() {
3535

3636
// eslint-disable no-var
3737
declare global {
38-
var loadStyleTransfer: (source: string) => any;
38+
var loadStyleTransfer: (source: string) => Promise<any>;
3939
var loadSemanticSegmentation: (
4040
source: string,
4141
normMean: Triple<number> | [],
4242
normStd: Triple<number> | [],
4343
allClasses: string[]
44-
) => any;
45-
var loadClassification: (source: string) => any;
44+
) => Promise<any>;
45+
var loadClassification: (source: string) => Promise<any>;
4646
var loadObjectDetection: (
4747
source: string,
4848
normMean: Triple<number> | [],
4949
normStd: Triple<number> | [],
5050
labelNames: string[]
51-
) => any;
52-
var loadExecutorchModule: (source: string) => any;
53-
var loadTokenizerModule: (source: string) => any;
54-
var loadImageEmbeddings: (source: string) => any;
55-
var loadVAD: (source: string) => any;
56-
var loadTextEmbeddings: (modelSource: string, tokenizerSource: string) => any;
57-
var loadLLM: (modelSource: string, tokenizerSource: string) => any;
51+
) => Promise<any>;
52+
var loadExecutorchModule: (source: string) => Promise<any>;
53+
var loadTokenizerModule: (source: string) => Promise<any>;
54+
var loadImageEmbeddings: (source: string) => Promise<any>;
55+
var loadVAD: (source: string) => Promise<any>;
56+
var loadTextEmbeddings: (modelSource: string, tokenizerSource: string) => Promise<any>;
57+
var loadLLM: (modelSource: string, tokenizerSource: string) => Promise<any>;
5858
var loadTextToImage: (
5959
tokenizerSource: string,
6060
encoderSource: string,
@@ -64,31 +64,31 @@ declare global {
6464
schedulerBetaEnd: number,
6565
schedulerNumTrainTimesteps: number,
6666
schedulerStepsOffset: number
67-
) => any;
67+
) => Promise<any>;
6868
var loadSpeechToText: (
6969
encoderSource: string,
7070
decoderSource: string,
7171
modelName: string
72-
) => any;
72+
) => Promise<any>;
7373
var loadTextToSpeechKokoro: (
7474
lang: string,
7575
taggerData: string,
7676
phonemizerData: string,
7777
durationPredictorSource: string,
7878
synthesizerSource: string,
7979
voice: string
80-
) => any;
80+
) => Promise<any>;
8181
var loadOCR: (
8282
detectorSource: string,
8383
recognizer: string,
8484
symbols: string
85-
) => any;
85+
) => Promise<any>;
8686
var loadVerticalOCR: (
8787
detectorSource: string,
8888
recognizer: string,
8989
symbols: string,
9090
independentCharacters?: boolean
91-
) => any;
91+
) => Promise<any>;
9292
}
9393
// eslint-disable no-var
9494
if (

packages/react-native-executorch/src/modules/computer_vision/ClassificationModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export class ClassificationModule extends BaseModule {
3535
);
3636
}
3737

38-
this.nativeModule = global.loadClassification(paths[0]);
38+
this.nativeModule = await global.loadClassification(paths[0]);
3939
} catch (error) {
4040
Logger.error('Load failed:', error);
4141
throw parseUnknownError(error);

packages/react-native-executorch/src/modules/computer_vision/ImageEmbeddingsModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ export class ImageEmbeddingsModule extends BaseModule {
3434
);
3535
}
3636

37-
this.nativeModule = global.loadImageEmbeddings(paths[0]);
37+
this.nativeModule = await global.loadImageEmbeddings(paths[0]);
3838
} catch (error) {
3939
Logger.error('Load failed:', error);
4040
throw parseUnknownError(error);

packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ export class ObjectDetectionModule<
8888
if (allLabelNames[i] == null) allLabelNames[i] = '';
8989
}
9090
const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
91-
const nativeModule = global.loadObjectDetection(
91+
const nativeModule = await global.loadObjectDetection(
9292
modelPath,
9393
normMean,
9494
normStd,
@@ -137,7 +137,7 @@ export class ObjectDetectionModule<
137137
if (allLabelNames[i] == null) allLabelNames[i] = '';
138138
}
139139
const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
140-
const nativeModule = global.loadObjectDetection(
140+
const nativeModule = await global.loadObjectDetection(
141141
modelPath,
142142
normMean,
143143
normStd,

packages/react-native-executorch/src/modules/computer_vision/SemanticSegmentationModule.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ export class SemanticSegmentationModule<
114114
const normStd = preprocessorConfig?.normStd ?? [];
115115
const allClassNames = Object.keys(labelMap).filter((k) => isNaN(Number(k)));
116116
const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
117-
const nativeModule = global.loadSemanticSegmentation(
117+
const nativeModule = await global.loadSemanticSegmentation(
118118
modelPath,
119119
normMean,
120120
normStd,
@@ -155,7 +155,7 @@ export class SemanticSegmentationModule<
155155
isNaN(Number(k))
156156
);
157157
const modelPath = await fetchModelPath(modelSource, onDownloadProgress);
158-
const nativeModule = global.loadSemanticSegmentation(
158+
const nativeModule = await global.loadSemanticSegmentation(
159159
modelPath,
160160
normMean,
161161
normStd,

packages/react-native-executorch/src/modules/computer_vision/StyleTransferModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export class StyleTransferModule extends BaseModule {
3535
);
3636
}
3737

38-
this.nativeModule = global.loadStyleTransfer(paths[0]);
38+
this.nativeModule = await global.loadStyleTransfer(paths[0]);
3939
} catch (error) {
4040
Logger.error('Load failed:', error);
4141
throw parseUnknownError(error);

0 commit comments

Comments
 (0)