Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,9 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
// We need to dispatch a thread if we want the function to be
// asynchronous. In this thread all accesses to jsi::Runtime need to
// be done via the callInvoker.
threads::GlobalThreadPool::detach([this, promise,
threads::GlobalThreadPool::detach([model = this->model,
callInvoker = this->callInvoker,
promise,
argsConverted =
std::move(argsConverted)]() {
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,14 @@ TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) {
return {.inputIds = inputIds64, .attentionMask = attentionMask};
}

// Releases the underlying model, serialized against generate(): both take
// inference_mutex_, so the model cannot be unloaded while an inference is
// in flight on another thread.
// NOTE(review): declared noexcept, but std::scoped_lock's constructor may
// throw std::system_error if locking fails, which would call std::terminate
// here — confirm that is the intended failure mode.
// NOTE(review): the header declares unload() without `override`; if
// BaseModel::unload is virtual this hides rather than overrides it —
// verify against BaseModel's declaration.
void TextEmbeddings::unload() noexcept {
std::scoped_lock lock(inference_mutex_);
BaseModel::unload();
}

std::shared_ptr<OwningArrayBuffer>
TextEmbeddings::generate(const std::string input) {
std::scoped_lock lock(inference_mutex_);
auto preprocessed = preprocess(input);

std::vector<int32_t> tokenIdsShape = {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include "rnexecutorch/metaprogramming/ConstructorHelpers.h"
#include <mutex>
#include <rnexecutorch/TokenizerModule.h>
#include <rnexecutorch/models/embeddings/BaseEmbeddings.h>

Expand All @@ -20,8 +21,10 @@ class TextEmbeddings final : public BaseEmbeddings {
[[nodiscard(
"Registered non-void function")]] std::shared_ptr<OwningArrayBuffer>
generate(const std::string input);
void unload() noexcept;

private:
mutable std::mutex inference_mutex_;
std::vector<std::vector<int32_t>> inputShapes;
TokenIdsWithAttentionMask preprocess(const std::string &input);
std::unique_ptr<TokenizerModule> tokenizer;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ std::shared_ptr<OwningArrayBuffer>
TextToImage::generate(std::string input, int32_t imageSize,
size_t numInferenceSteps, int32_t seed,
std::shared_ptr<jsi::Function> callback) {
std::scoped_lock lock(inference_mutex_);
setImageSize(imageSize);
setSeed(seed);

Expand Down Expand Up @@ -137,6 +138,7 @@ size_t TextToImage::getMemoryLowerBound() const noexcept {
}

void TextToImage::unload() noexcept {
std::scoped_lock lock(inference_mutex_);
encoder->unload();
unet->unload();
decoder->unload();
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#pragma once

#include <memory>
#include <mutex>
#include <string>
#include <vector>

Expand Down Expand Up @@ -49,6 +50,7 @@ class TextToImage final {
static constexpr float guidanceScale = 7.5f;
static constexpr float latentsScale = 0.18215f;
bool interrupted = false;
mutable std::mutex inference_mutex_;

std::shared_ptr<react::CallInvoker> callInvoker;
std::unique_ptr<Scheduler> scheduler;
Expand Down
Loading