Skip to content

Commit 26f88f6

Browse files
authored
fix: add inference mutex to Text Embedding and Text-to-Image (#1060)
## Description Adds thread-safety to Text Embeddings and Text-to-Image models mirroring what was already done for models inheriting from VisionModel and VAD. ### Introduces a breaking change? - [ ] Yes - [x] No ### Type of change - [x] Bug fix (change which fixes an issue) - [ ] New feature (change which adds functionality) - [ ] Documentation update (improves or adds clarity to existing documentation) - [ ] Other (chores, tests, code style improvements etc.) ### Tested on - [x] iOS - [x] Android ### Testing instructions Use the following app screen and try to trigger the race condition before fix and verify that it doesn't occur after applying the fix. You can use `adb logcat | grep -E "FATAL|SIGSEGV|backtrace"` to observe the error on Android. ```ts import React, { useState } from 'react'; import { Button, ScrollView, Text, View } from 'react-native'; import { BK_SDM_TINY_VPRED_512, TextToImageModule, } from 'react-native-executorch'; import { CLIP_VIT_BASE_PATCH32_TEXT, TextEmbeddingsModule, } from 'react-native-executorch'; import { FSMN_VAD, VADModule } from 'react-native-executorch'; const DELAY_MS = 50; // tune this so that forward() is running when delete() is called const MODEL_VAD = { name: 'VAD', load: (onProgress: (p: number) => void) => VADModule.fromModelName(FSMN_VAD, onProgress), input: () => new Float32Array(16000 * 300), }; const MODEL_TEXT_EMBEDDINGS = { name: 'TextEmbeddings', load: (onProgress: (p: number) => void) => TextEmbeddingsModule.fromModelName(CLIP_VIT_BASE_PATCH32_TEXT, onProgress), input: () => 'hello world', }; const MODEL_TEXT_TO_IMAGE = { name: 'TextToImage', load: (onProgress: (p: number) => void) => TextToImageModule.fromModelName(BK_SDM_TINY_VPRED_512, onProgress), input: () => 'a red apple', }; const MODEL = MODEL_TEXT_EMBEDDINGS; export default function RaceTest() { const [lines, setLines] = useState<string[]>([]); const [downloadProgress, setDownloadProgress] = useState<number | null>(null); const log = (line: string) => 
setLines((prev) => [line, ...prev]); const run = async () => { setLines([]); setDownloadProgress(null); log(`model: ${MODEL.name}`); log('loading'); const model = await MODEL.load((p) => setDownloadProgress(p)); setDownloadProgress(null); log('running forward()'); const result = model.forward(MODEL.input()); log(`waiting ${DELAY_MS} ms`); await new Promise((r) => setTimeout(r, DELAY_MS)); log('calling delete()'); model.delete(); try { await result; log('forward() completed successfully'); } catch (e: any) { log('error: ' + (e?.message ?? String(e))); } }; return ( <View> <Button title="Run Race Test" onPress={run} /> {downloadProgress !== null && ( <Text>downloading: {Math.round(downloadProgress * 100)}%</Text> )} <ScrollView> {lines.map((l, i) => ( <Text key={i}>{l}</Text> ))} </ScrollView> </View> ); } ``` ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues #1055 ### Checklist - [x] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent cb032c6 commit 26f88f6

File tree

5 files changed: +16 −1 lines changed

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 3 additions & 1 deletion
```diff
@@ -375,7 +375,9 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
       // We need to dispatch a thread if we want the function to be
       // asynchronous. In this thread all accesses to jsi::Runtime need to
       // be done via the callInvoker.
-      threads::GlobalThreadPool::detach([this, promise,
+      threads::GlobalThreadPool::detach([model = this->model,
+                                         callInvoker = this->callInvoker,
+                                         promise,
                                          argsConverted =
                                              std::move(argsConverted)]() {
         try {
```

packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp

Lines changed: 6 additions & 0 deletions
```diff
@@ -35,8 +35,14 @@ TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) {
   return {.inputIds = inputIds64, .attentionMask = attentionMask};
 }
 
+void TextEmbeddings::unload() noexcept {
+  std::scoped_lock lock(inference_mutex_);
+  BaseModel::unload();
+}
+
 std::shared_ptr<OwningArrayBuffer>
 TextEmbeddings::generate(const std::string input) {
+  std::scoped_lock lock(inference_mutex_);
   auto preprocessed = preprocess(input);
 
   std::vector<int32_t> tokenIdsShape = {
```

packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h

Lines changed: 3 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
+#include <mutex>
 #include <rnexecutorch/TokenizerModule.h>
 #include <rnexecutorch/models/embeddings/BaseEmbeddings.h>
 
@@ -20,8 +21,10 @@ class TextEmbeddings final : public BaseEmbeddings {
   [[nodiscard(
       "Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
   generate(const std::string input);
+  void unload() noexcept;
 
 private:
+  mutable std::mutex inference_mutex_;
   std::vector<std::vector<int32_t>> inputShapes;
   TokenIdsWithAttentionMask preprocess(const std::string &input);
   std::unique_ptr<TokenizerModule> tokenizer;
```

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp

Lines changed: 2 additions & 0 deletions
```diff
@@ -58,6 +58,7 @@ std::shared_ptr<OwningArrayBuffer>
 TextToImage::generate(std::string input, int32_t imageSize,
                       size_t numInferenceSteps, int32_t seed,
                       std::shared_ptr<jsi::Function> callback) {
+  std::scoped_lock lock(inference_mutex_);
   setImageSize(imageSize);
   setSeed(seed);
 
@@ -137,6 +138,7 @@ size_t TextToImage::getMemoryLowerBound() const noexcept {
 }
 
 void TextToImage::unload() noexcept {
+  std::scoped_lock lock(inference_mutex_);
   encoder->unload();
   unet->unload();
   decoder->unload();
```

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h

Lines changed: 2 additions & 0 deletions
```diff
@@ -1,6 +1,7 @@
 #pragma once
 
 #include <memory>
+#include <mutex>
 #include <string>
 #include <vector>
 
@@ -49,6 +50,7 @@ class TextToImage final {
   static constexpr float guidanceScale = 7.5f;
   static constexpr float latentsScale = 0.18215f;
   bool interrupted = false;
+  mutable std::mutex inference_mutex_;
 
   std::shared_ptr<react::CallInvoker> callInvoker;
   std::unique_ptr<Scheduler> scheduler;
```

0 commit comments

Comments (0)