Skip to content

Commit 9987421

Browse files
committed
C++ and library part of iamge embeddings. Docs and examples will follow
1 parent 189fa97 commit 9987421

8 files changed

Lines changed: 152 additions & 1 deletion

File tree

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
#include <rnexecutorch/host_objects/JsiConversions.h>
44
#include <rnexecutorch/models/classification/Classification.h>
5+
#include <rnexecutorch/models/image_embeddings/ImageEmbeddings.h>
56
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
67
#include <rnexecutorch/models/object_detection/ObjectDetection.h>
78
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
@@ -42,6 +43,10 @@ void RnExecutorchInstaller::injectJSIBindings(
4243
*jsiRuntime, "loadExecutorchModule",
4344
RnExecutorchInstaller::loadModel<BaseModel>(jsiRuntime, jsCallInvoker,
4445
"loadExecutorchModule"));
46+
jsiRuntime->global().setProperty(
47+
*jsiRuntime, "loadImageEmbeddings",
48+
RnExecutorchInstaller::loadModel<ImageEmbeddings>(
49+
jsiRuntime, jsCallInvoker, "loadImageEmbeddings"));
4550
}
4651

4752
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,23 @@ getJsiValue(const std::vector<std::shared_ptr<OwningArrayBuffer>> &vec,
182182
return jsi::Value(runtime, array);
183183
}
184184

185+
inline jsi::Value getJsiValue(const std::shared_ptr<JSTensorViewOut> &jsTensor,
186+
jsi::Runtime &runtime) {
187+
188+
jsi::Object tensorObj(runtime);
189+
190+
tensorObj.setProperty(runtime, "sizes",
191+
getJsiValue(jsTensor->sizes, runtime));
192+
193+
tensorObj.setProperty(runtime, "scalarType",
194+
jsi::Value(static_cast<int>(jsTensor->scalarType)));
195+
196+
jsi::ArrayBuffer arrayBuffer(runtime, jsTensor->dataPtr);
197+
tensorObj.setProperty(runtime, "dataPtr", arrayBuffer);
198+
199+
return jsi::Value(runtime, tensorObj);
200+
}
201+
185202
inline jsi::Value
186203
getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec,
187204
jsi::Runtime &runtime) {
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include "ImageEmbeddings.h"
2+
3+
#include <cstdint>
4+
#include <executorch/extension/tensor/tensor.h>
5+
#include <iostream>
6+
#include <rnexecutorch/Log.h>
7+
#include <rnexecutorch/data_processing/ImageProcessing.h>
8+
9+
namespace rnexecutorch {
10+
11+
ImageEmbeddings::ImageEmbeddings(
12+
const std::string &modelSource,
13+
std::shared_ptr<react::CallInvoker> callInvoker)
14+
: BaseModel(modelSource, callInvoker) {
15+
auto inputTensors = getAllInputShapes();
16+
if (inputTensors.size() == 0) {
17+
throw std::runtime_error("Model seems to not take any input tensors.");
18+
}
19+
std::vector<int32_t> modelInputShape = inputTensors[0];
20+
if (modelInputShape.size() < 2) {
21+
char errorMessage[100];
22+
std::snprintf(errorMessage, sizeof(errorMessage),
23+
"Unexpected model input size, expected at least 2 dimentions "
24+
"but got: %zu.",
25+
modelInputShape.size());
26+
throw std::runtime_error(errorMessage);
27+
}
28+
modelImageSize = cv::Size(modelInputShape[modelInputShape.size() - 1],
29+
modelInputShape[modelInputShape.size() - 2]);
30+
}
31+
32+
std::shared_ptr<JSTensorViewOut>
33+
ImageEmbeddings::generate(std::string imageSource) {
34+
auto [inputTensor, originalSize] =
35+
imageprocessing::readImageToTensor(imageSource, getAllInputShapes()[0]);
36+
37+
auto result = BaseModel::forward(inputTensor);
38+
if (!result.ok()) {
39+
throw std::runtime_error("Forward pass failed: Error " +
40+
std::to_string(static_cast<int>(result.error())));
41+
}
42+
43+
auto &outputs = result.get();
44+
45+
if (outputs.size() > 1) {
46+
throw std::runtime_error("It returned multiple outputs!");
47+
}
48+
49+
auto &outputTensor = outputs.at(0).toTensor();
50+
auto sizesRaw = outputTensor.sizes();
51+
std::vector<int32_t> sizes =
52+
std::vector<int32_t>(sizesRaw.begin(), sizesRaw.end());
53+
size_t bufferSize = outputTensor.numel() * outputTensor.element_size();
54+
auto buffer = std::make_shared<OwningArrayBuffer>(bufferSize);
55+
std::memcpy(buffer->data(), outputTensor.const_data_ptr(), bufferSize);
56+
auto jsTensor = std::make_shared<JSTensorViewOut>(
57+
sizes, outputTensor.scalar_type(), buffer);
58+
59+
return jsTensor;
60+
}
61+
} // namespace rnexecutorch
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
#pragma once
2+
3+
#include <unordered_map>
4+
5+
#include <executorch/extension/tensor/tensor_ptr.h>
6+
#include <executorch/runtime/core/evalue.h>
7+
#include <opencv2/opencv.hpp>
8+
9+
#include <rnexecutorch/models/BaseModel.h>
10+
11+
namespace rnexecutorch {
12+
using executorch::extension::TensorPtr;
13+
using executorch::runtime::EValue;
14+
15+
class ImageEmbeddings : public BaseModel {
16+
public:
17+
ImageEmbeddings(const std::string &modelSource,
18+
std::shared_ptr<react::CallInvoker> callInvoker);
19+
std::shared_ptr<JSTensorViewOut> generate(std::string imageSource);
20+
21+
private:
22+
cv::Size modelImageSize{0, 0};
23+
};
24+
25+
} // namespace rnexecutorch
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
import { ImageEmbeddingsModule } from '../../modules/computer_vision/ImageEmbeddingsModule';
2+
import { ResourceSource } from '../../types/common';
3+
import { useNonStaticModule } from '../useNonStaticModule';
4+
5+
export const useImageEmbeddings = ({
6+
modelSource,
7+
preventLoad = false,
8+
}: {
9+
modelSource: ResourceSource;
10+
preventLoad?: boolean;
11+
}) =>
12+
useNonStaticModule({
13+
module: ImageEmbeddingsModule,
14+
loadArgs: [modelSource],
15+
preventLoad,
16+
});

packages/react-native-executorch/src/hooks/useNonStaticModule.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ interface ModuleConstructor<M extends Module> {
1414
export const useNonStaticModule = <
1515
M extends Module,
1616
LoadArgs extends Parameters<M['load']>,
17-
ForwardArgs extends any[],
17+
ForwardArgs extends Parameters<M['forward']>,
1818
ForwardReturn extends Awaited<ReturnType<M['forward']>>,
1919
>({
2020
module,

packages/react-native-executorch/src/index.tsx

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ declare global {
99
var loadClassification: (source: string) => any;
1010
var loadObjectDetection: (source: string) => any;
1111
var loadExecutorchModule: (source: string) => any;
12+
var loadImageEmbeddings: (source: string) => any;
1213
}
1314
// eslint-disable no-var
1415
if (
@@ -32,6 +33,7 @@ export * from './hooks/computer_vision/useStyleTransfer';
3233
export * from './hooks/computer_vision/useImageSegmentation';
3334
export * from './hooks/computer_vision/useOCR';
3435
export * from './hooks/computer_vision/useVerticalOCR';
36+
export * from './hooks/computer_vision/useImageEmbeddings';
3537

3638
export * from './hooks/natural_language_processing/useLLM';
3739
export * from './hooks/natural_language_processing/useSpeechToText';
@@ -48,6 +50,7 @@ export * from './modules/computer_vision/ImageSegmentationModule';
4850
export * from './modules/computer_vision/OCRModule';
4951
export * from './modules/computer_vision/VerticalOCRModule';
5052
export * from './modules/general/ExecutorchModule';
53+
export * from './modules/computer_vision/ImageEmbeddingsModule';
5154

5255
export * from './modules/natural_language_processing/LLMModule';
5356
export * from './modules/natural_language_processing/SpeechToTextModule';
Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import { ResourceFetcher } from '../../utils/ResourceFetcher';
2+
import { ResourceSource } from '../../types/common';
3+
import { TensorPtr } from '../../types/common';
4+
import { ETError, getError } from '../../Error';
5+
import { BaseNonStaticModule } from '../BaseNonStaticModule';
6+
7+
export class ImageEmbeddingsModule extends BaseNonStaticModule {
8+
async load(
9+
modelSource: ResourceSource,
10+
onDownloadProgressCallback: (_: number) => void = () => {}
11+
): Promise<void> {
12+
const paths = await ResourceFetcher.fetchMultipleResources(
13+
onDownloadProgressCallback,
14+
modelSource
15+
);
16+
this.nativeModule = global.loadImageEmbeddings(paths[0] || '');
17+
}
18+
19+
async forward(imageSource: string): Promise<TensorPtr> {
20+
if (this.nativeModule == null)
21+
throw new Error(getError(ETError.ModuleNotLoaded));
22+
return await this.nativeModule.generate(imageSource);
23+
}
24+
}

0 commit comments

Comments
 (0)