Skip to content

Commit cd31498

Browse files
committed
feat: add inference mutex for thread safety in VAD, Text Embeddings and Text-to-Image
1 parent 92981bf commit cd31498

File tree

5 files changed

+14
-0
lines changed

5 files changed

+14
-0
lines changed

packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,14 @@ TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) {
3535
return {.inputIds = inputIds64, .attentionMask = attentionMask};
3636
}
3737

38+
void TextEmbeddings::unload() noexcept {
39+
std::scoped_lock lock(generate_mutex_);
40+
BaseModel::unload();
41+
}
42+
3843
std::shared_ptr<OwningArrayBuffer>
3944
TextEmbeddings::generate(const std::string input) {
45+
std::scoped_lock lock(generate_mutex_);
4046
auto preprocessed = preprocess(input);
4147

4248
std::vector<int32_t> tokenIdsShape = {

packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
4+
#include <mutex>
45
#include <rnexecutorch/TokenizerModule.h>
56
#include <rnexecutorch/models/embeddings/BaseEmbeddings.h>
67

@@ -20,8 +21,10 @@ class TextEmbeddings final : public BaseEmbeddings {
2021
[[nodiscard(
2122
"Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
2223
generate(const std::string input);
24+
void unload() noexcept;
2325

2426
private:
27+
mutable std::mutex generate_mutex_;
2528
std::vector<std::vector<int32_t>> inputShapes;
2629
TokenIdsWithAttentionMask preprocess(const std::string &input);
2730
std::unique_ptr<TokenizerModule> tokenizer;

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ std::shared_ptr<OwningArrayBuffer>
5858
TextToImage::generate(std::string input, int32_t imageSize,
5959
size_t numInferenceSteps, int32_t seed,
6060
std::shared_ptr<jsi::Function> callback) {
61+
std::scoped_lock lock(generate_mutex_);
6162
setImageSize(imageSize);
6263
setSeed(seed);
6364

@@ -137,6 +138,7 @@ size_t TextToImage::getMemoryLowerBound() const noexcept {
137138
}
138139

139140
void TextToImage::unload() noexcept {
141+
std::scoped_lock lock(generate_mutex_);
140142
encoder->unload();
141143
unet->unload();
142144
decoder->unload();

packages/react-native-executorch/common/rnexecutorch/models/text_to_image/TextToImage.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#pragma once
22

33
#include <memory>
4+
#include <mutex>
45
#include <string>
56
#include <vector>
67

@@ -49,6 +50,7 @@ class TextToImage final {
4950
static constexpr float guidanceScale = 7.5f;
5051
static constexpr float latentsScale = 0.18215f;
5152
bool interrupted = false;
53+
mutable std::mutex generate_mutex_;
5254

5355
std::shared_ptr<react::CallInvoker> callInvoker;
5456
std::unique_ptr<Scheduler> scheduler;

packages/react-native-executorch/common/rnexecutorch/models/voice_activity_detection/VoiceActivityDetection.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ class VoiceActivityDetection : public BaseModel {
2323
std::shared_ptr<react::CallInvoker> callInvoker);
2424
[[nodiscard("Registered non-void function")]] std::vector<types::Segment>
2525
generate(std::span<float> waveform) const;
26+
void unload() noexcept;
2627

2728
void unload() noexcept;
2829

0 commit comments

Comments
 (0)