Skip to content

Commit 87e8371

Browse files
authored
refactor: ImageEmbedding and TextEmbedding, remove normalization, update URLs (#445)
## Description

Changes:
- Refactored CLIP URLs
- Updated TextEmbeddings model URLs to v0.5.0
- Refactored TextEmbeddings demo
- Removed normalization from ImageEmbeddings
- Refactored ImageEmbeddings C++ code
- Removed attentionMask from TextEmbeddings post-processing

### Type of change

- [ ] Bug fix (non-breaking change which fixes an issue)
- [ ] New feature (non-breaking change which adds functionality)
- [x] Breaking change (fix or feature that would cause existing functionality to not work as expected)
- [ ] Documentation update (improves or adds clarity to existing documentation)

### Tested on

- [x] iOS
- [x] Android

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [ ] I have updated the documentation accordingly
- [x] My changes generate no new warnings
1 parent 63cdb16 commit 87e8371

11 files changed

Lines changed: 127 additions & 117 deletions

File tree

apps/text-embeddings/app/clip-embeddings/index.tsx

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,12 @@ import { Ionicons } from '@expo/vector-icons';
1414
import {
1515
useTextEmbeddings,
1616
useImageEmbeddings,
17-
CLIP_VIT_BASE_PATCH_32_TEXT_ENCODER,
18-
CLIP_VIT_BASE_PATCH_32_IMAGE_ENCODER_MODEL,
17+
CLIP_VIT_BASE_PATCH32_TEXT,
18+
CLIP_VIT_BASE_PATCH32_IMAGE,
1919
} from 'react-native-executorch';
2020
import { launchImageLibrary } from 'react-native-image-picker';
2121
import { useIsFocused } from '@react-navigation/native';
22+
import { dotProduct } from '../../utils/math';
2223

2324
export default function ClipEmbeddingsScreenWrapper() {
2425
const isFocused = useIsFocused();
@@ -27,11 +28,8 @@ export default function ClipEmbeddingsScreenWrapper() {
2728
}
2829

2930
function ClipEmbeddingsScreen() {
30-
const model = useTextEmbeddings({ ...CLIP_VIT_BASE_PATCH_32_TEXT_ENCODER });
31-
32-
const imageModel = useImageEmbeddings({
33-
modelSource: CLIP_VIT_BASE_PATCH_32_IMAGE_ENCODER_MODEL,
34-
});
31+
const textModel = useTextEmbeddings(CLIP_VIT_BASE_PATCH32_TEXT);
32+
const imageModel = useImageEmbeddings(CLIP_VIT_BASE_PATCH32_IMAGE);
3533

3634
const [inputSentence, setInputSentence] = useState('');
3735
const [sentencesWithEmbeddings, setSentencesWithEmbeddings] = useState<
@@ -41,18 +39,10 @@ function ClipEmbeddingsScreen() {
4139
{ sentence: string; similarity: number }[]
4240
>([]);
4341

44-
const dotProduct = (a: Float32Array, b: Float32Array) => {
45-
let sum = 0;
46-
for (let i = 0; i < a.length; i++) {
47-
sum += a[i] * b[i];
48-
}
49-
return sum;
50-
};
51-
5242
useEffect(
5343
() => {
5444
const computeEmbeddings = async () => {
55-
if (!model.isReady) return;
45+
if (!textModel.isReady) return;
5646

5747
const sentences = [
5848
'The weather is lovely today.',
@@ -64,7 +54,7 @@ function ClipEmbeddingsScreen() {
6454
try {
6555
const embeddings = [];
6656
for (const sentence of sentences) {
67-
const embedding = await model.forward(sentence);
57+
const embedding = await textModel.forward(sentence);
6858
embeddings.push({ sentence, embedding });
6959
}
7060
setSentencesWithEmbeddings(embeddings);
@@ -76,14 +66,14 @@ function ClipEmbeddingsScreen() {
7666
computeEmbeddings();
7767
},
7868
// eslint-disable-next-line react-hooks/exhaustive-deps
79-
[model.isReady]
69+
[textModel.isReady]
8070
);
8171

8272
const checkSimilarities = async () => {
83-
if (!model.isReady || !inputSentence.trim()) return;
73+
if (!textModel.isReady || !inputSentence.trim()) return;
8474

8575
try {
86-
const inputEmbedding = await model.forward(inputSentence);
76+
const inputEmbedding = await textModel.forward(inputSentence);
8777
const matches = sentencesWithEmbeddings.map(
8878
({ sentence, embedding }) => ({
8979
sentence,
@@ -98,10 +88,10 @@ function ClipEmbeddingsScreen() {
9888
};
9989

10090
const addToSentences = async () => {
101-
if (!model.isReady || !inputSentence.trim()) return;
91+
if (!textModel.isReady || !inputSentence.trim()) return;
10292

10393
try {
104-
const embedding = await model.forward(inputSentence);
94+
const embedding = await textModel.forward(inputSentence);
10595
setSentencesWithEmbeddings((prev) => [
10696
...prev,
10797
{ sentence: inputSentence, embedding },
@@ -115,7 +105,7 @@ function ClipEmbeddingsScreen() {
115105
};
116106

117107
const clearList = async () => {
118-
if (!model.isReady) return;
108+
if (!textModel.isReady) return;
119109
try {
120110
setSentencesWithEmbeddings([]);
121111
} catch (error) {
@@ -149,16 +139,14 @@ function ClipEmbeddingsScreen() {
149139
}
150140
};
151141

152-
const getModelStatusText = () => {
153-
if (model.error || imageModel.error) {
154-
return `Oops! Error: ${model.error || imageModel.error}`;
142+
const getModelStatusText = (model: typeof textModel | typeof imageModel) => {
143+
if (model.error) {
144+
return `Oops! Error: ${model.error}`;
155145
}
156-
if (!model.isReady || !imageModel.isReady) {
157-
return `Loading model ${(((model.downloadProgress + imageModel.downloadProgress) / 2) * 100).toFixed(2)}%`;
146+
if (!model.isReady) {
147+
return `Loading model ${(model.downloadProgress * 100).toFixed(2)}%`;
158148
}
159-
return model.isGenerating || imageModel.isGenerating
160-
? 'Generating...'
161-
: 'Model is ready';
149+
return model.isGenerating ? 'Generating...' : 'Model is ready';
162150
};
163151

164152
return (
@@ -169,8 +157,12 @@ function ClipEmbeddingsScreen() {
169157
>
170158
<ScrollView contentContainerStyle={styles.scrollContainer}>
171159
<Text style={styles.heading}>Text Embeddings Playground</Text>
172-
<Text style={styles.sectionTitle}>{getModelStatusText()}</Text>
173-
160+
<Text style={styles.sectionTitle}>
161+
Text Model: {getModelStatusText(textModel)}
162+
</Text>
163+
<Text style={styles.sectionTitle}>
164+
Image Model: {getModelStatusText(imageModel)}
165+
</Text>
174166
<View style={styles.card}>
175167
<Text style={styles.sectionTitle}>List of Existing Sentences</Text>
176168
{sentencesWithEmbeddings.map((item, index) => (

apps/text-embeddings/app/text-embeddings/index.tsx

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ import {
1717
ALL_MINILM_L6_V2_TOKENIZER,
1818
} from 'react-native-executorch';
1919
import { useIsFocused } from '@react-navigation/native';
20+
import { dotProduct } from '../../utils/math';
2021

2122
export default function TextEmbeddingsScreenWrapper() {
2223
const isFocused = useIsFocused();
@@ -38,14 +39,6 @@ function TextEmbeddingsScreen() {
3839
{ sentence: string; similarity: number }[]
3940
>([]);
4041

41-
const dotProduct = (a: Float32Array, b: Float32Array) => {
42-
let sum = 0;
43-
for (let i = 0; i < a.length; i++) {
44-
sum += a[i] * b[i];
45-
}
46-
return sum;
47-
};
48-
4942
useEffect(
5043
() => {
5144
const computeEmbeddings = async () => {

apps/text-embeddings/utils/math.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
export const dotProduct = (a: Float32Array, b: Float32Array) => {
2+
if (a.length !== b.length) {
3+
throw new Error('Vectors must be of the same length');
4+
}
5+
6+
let sum = 0;
7+
for (let i = 0; i < a.length; i++) {
8+
sum += a[i] * b[i];
9+
}
10+
return sum;
11+
};

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,13 @@
33
#include <rnexecutorch/TokenizerModule.h>
44
#include <rnexecutorch/host_objects/JsiConversions.h>
55
#include <rnexecutorch/models/classification/Classification.h>
6-
#include <rnexecutorch/models/image_embeddings/ImageEmbeddings.h>
6+
#include <rnexecutorch/models/embeddings/image/ImageEmbeddings.h>
7+
#include <rnexecutorch/models/embeddings/text/TextEmbeddings.h>
78
#include <rnexecutorch/models/image_segmentation/ImageSegmentation.h>
89
#include <rnexecutorch/models/llm/LLM.h>
910
#include <rnexecutorch/models/object_detection/ObjectDetection.h>
1011
#include <rnexecutorch/models/speech_to_text/SpeechToText.h>
1112
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
12-
#include <rnexecutorch/models/text_embeddings/TextEmbeddings.h>
1313

1414
namespace rnexecutorch {
1515

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
#include "BaseEmbeddings.h"
2+
3+
#include <span>
4+
5+
namespace rnexecutorch {
6+
7+
BaseEmbeddings::BaseEmbeddings(const std::string &modelSource,
8+
std::shared_ptr<react::CallInvoker> callInvoker)
9+
: BaseModel(modelSource, callInvoker) {}
10+
11+
std::shared_ptr<OwningArrayBuffer>
12+
BaseEmbeddings::postprocess(const Result<std::vector<EValue>> &forwardResult) {
13+
auto forwardResultTensor = forwardResult->at(0).toTensor();
14+
auto dataPtr = forwardResultTensor.mutable_data_ptr();
15+
auto outputNumel = forwardResultTensor.numel();
16+
17+
std::span<float> modelOutput(static_cast<float *>(dataPtr), outputNumel);
18+
19+
auto createBuffer = [](const auto &data, size_t size) {
20+
auto buffer = std::make_shared<OwningArrayBuffer>(size);
21+
std::memcpy(buffer->data(), data, size);
22+
return buffer;
23+
};
24+
return createBuffer(modelOutput.data(), modelOutput.size_bytes());
25+
}
26+
27+
} // namespace rnexecutorch
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#pragma once
2+
3+
#include <rnexecutorch/models/BaseModel.h>
4+
5+
namespace rnexecutorch {
6+
7+
class BaseEmbeddings : public BaseModel {
8+
public:
9+
BaseEmbeddings(const std::string &modelSource,
10+
std::shared_ptr<react::CallInvoker> callInvoker);
11+
12+
protected:
13+
std::shared_ptr<OwningArrayBuffer>
14+
postprocess(const Result<std::vector<EValue>> &forwardResult);
15+
};
16+
17+
}; // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/image_embeddings/ImageEmbeddings.cpp renamed to packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.cpp

Lines changed: 9 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,12 +3,13 @@
33
#include <executorch/extension/tensor/tensor.h>
44
#include <rnexecutorch/data_processing/ImageProcessing.h>
55
#include <rnexecutorch/data_processing/Numerical.h>
6+
67
namespace rnexecutorch {
78

89
ImageEmbeddings::ImageEmbeddings(
910
const std::string &modelSource,
1011
std::shared_ptr<react::CallInvoker> callInvoker)
11-
: BaseModel(modelSource, callInvoker) {
12+
: BaseEmbeddings(modelSource, callInvoker) {
1213
auto inputTensors = getAllInputShapes();
1314
if (inputTensors.size() == 0) {
1415
throw std::runtime_error("Model seems to not take any input tensors.");
@@ -31,30 +32,14 @@ ImageEmbeddings::generate(std::string imageSource) {
3132
auto [inputTensor, originalSize] =
3233
imageprocessing::readImageToTensor(imageSource, getAllInputShapes()[0]);
3334

34-
auto result = BaseModel::forward(inputTensor);
35-
if (!result.ok()) {
36-
throw std::runtime_error("Forward pass failed: Error " +
37-
std::to_string(static_cast<int>(result.error())));
35+
auto forwardResult = BaseModel::forward(inputTensor);
36+
if (!forwardResult.ok()) {
37+
throw std::runtime_error(
38+
"Function forward in ImageEmbeddings failed with error code: " +
39+
std::to_string(static_cast<uint32_t>(forwardResult.error())));
3840
}
3941

40-
auto &outputs = result.get();
41-
42-
if (outputs.size() > 1) {
43-
throw std::runtime_error("It returned multiple outputs!");
44-
}
45-
46-
auto &outputTensor = outputs.at(0).toTensor();
47-
std::span<float> outputTensorSpan(
48-
static_cast<float *>(outputTensor.mutable_data_ptr()),
49-
outputTensor.numel());
50-
51-
numerical::normalize(outputTensorSpan);
52-
53-
size_t bufferSize = outputTensorSpan.size_bytes();
54-
auto buffer = std::make_shared<OwningArrayBuffer>(bufferSize);
55-
56-
std::memcpy(buffer->data(), outputTensorSpan.data(), bufferSize);
57-
58-
return buffer;
42+
return BaseEmbeddings::postprocess(forwardResult);
5943
}
44+
6045
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/image_embeddings/ImageEmbeddings.h renamed to packages/react-native-executorch/common/rnexecutorch/models/embeddings/image/ImageEmbeddings.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
#include <executorch/runtime/core/evalue.h>
55
#include <opencv2/opencv.hpp>
66

7-
#include <rnexecutorch/models/BaseModel.h>
7+
#include <rnexecutorch/models/embeddings/BaseEmbeddings.h>
88

99
namespace rnexecutorch {
1010
using executorch::extension::TensorPtr;
1111
using executorch::runtime::EValue;
1212

13-
class ImageEmbeddings : public BaseModel {
13+
class ImageEmbeddings final : public BaseEmbeddings {
1414
public:
1515
ImageEmbeddings(const std::string &modelSource,
1616
std::shared_ptr<react::CallInvoker> callInvoker);
@@ -20,4 +20,4 @@ class ImageEmbeddings : public BaseModel {
2020
cv::Size modelImageSize{0, 0};
2121
};
2222

23-
} // namespace rnexecutorch
23+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/text_embeddings/TextEmbeddings.cpp renamed to packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.cpp

Lines changed: 8 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ using namespace executorch::extension;
99
TextEmbeddings::TextEmbeddings(const std::string &modelSource,
1010
const std::string &tokenizerSource,
1111
std::shared_ptr<react::CallInvoker> callInvoker)
12-
: BaseModel(modelSource, callInvoker),
12+
: BaseEmbeddings(modelSource, callInvoker),
1313
tokenizer(
1414
std::make_unique<TokenizerModule>(tokenizerSource, callInvoker)) {}
1515

@@ -36,42 +36,26 @@ TokenIdsWithAttentionMask TextEmbeddings::preprocess(const std::string &input) {
3636
std::shared_ptr<OwningArrayBuffer>
3737
TextEmbeddings::generate(const std::string input) {
3838
auto preprocessed = preprocess(input);
39+
3940
std::vector<int32_t> tokenIdsShape = {
4041
1, static_cast<int32_t>(preprocessed.inputIds.size())};
4142
std::vector<int32_t> attnMaskShape = {
4243
1, static_cast<int32_t>(preprocessed.attentionMask.size())};
44+
4345
auto tokenIds = make_tensor_ptr(tokenIdsShape, preprocessed.inputIds.data(),
4446
ScalarType::Long);
4547
auto attnMask = make_tensor_ptr(
4648
attnMaskShape, preprocessed.attentionMask.data(), ScalarType::Long);
49+
4750
auto forwardResult = BaseModel::forward({tokenIds, attnMask});
51+
4852
if (!forwardResult.ok()) {
4953
throw std::runtime_error(
50-
"Failed to forward, error: " +
54+
"Function forward in TextEmbeddings failed with error code: " +
5155
std::to_string(static_cast<uint32_t>(forwardResult.error())));
5256
}
5357

54-
auto forwardResultTensor = forwardResult->at(0).toTensor();
55-
auto dataPtr = forwardResultTensor.mutable_data_ptr();
56-
auto outputNumel = forwardResultTensor.numel();
57-
58-
std::span<float> modelOutputSpan(static_cast<float *>(dataPtr), outputNumel);
59-
std::span<const int64_t> attnMaskSpan(preprocessed.attentionMask.data(),
60-
preprocessed.attentionMask.size());
61-
62-
return postprocess(modelOutputSpan, attnMaskSpan);
63-
}
64-
65-
std::shared_ptr<OwningArrayBuffer>
66-
TextEmbeddings::postprocess(std::span<float> modelOutput,
67-
std::span<const int64_t> attnMask) {
68-
auto createBuffer = [](const auto &data, size_t size) {
69-
auto buffer = std::make_shared<OwningArrayBuffer>(size);
70-
std::memcpy(buffer->data(), data, size);
71-
return buffer;
72-
};
73-
74-
return createBuffer(modelOutput.data(), modelOutput.size_bytes());
58+
return BaseEmbeddings::postprocess(forwardResult);
7559
}
7660

77-
} // namespace rnexecutorch
61+
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/text_embeddings/TextEmbeddings.h renamed to packages/react-native-executorch/common/rnexecutorch/models/embeddings/text/TextEmbeddings.h

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
#pragma once
22

33
#include <rnexecutorch/TokenizerModule.h>
4-
#include <rnexecutorch/models/BaseModel.h>
5-
#include <span>
4+
#include <rnexecutorch/models/embeddings/BaseEmbeddings.h>
65

76
namespace rnexecutorch {
87

@@ -11,7 +10,7 @@ struct TokenIdsWithAttentionMask {
1110
std::vector<int64_t> attentionMask;
1211
};
1312

14-
class TextEmbeddings : public BaseModel {
13+
class TextEmbeddings final : public BaseEmbeddings {
1514
public:
1615
TextEmbeddings(const std::string &modelSource,
1716
const std::string &tokenizerSource,
@@ -21,9 +20,7 @@ class TextEmbeddings : public BaseModel {
2120
private:
2221
std::vector<std::vector<int32_t>> inputShapes;
2322
TokenIdsWithAttentionMask preprocess(const std::string &input);
24-
std::shared_ptr<OwningArrayBuffer>
25-
postprocess(std::span<float> modelOutput,
26-
std::span<const int64_t> attentionMask);
2723
std::unique_ptr<TokenizerModule> tokenizer;
2824
};
25+
2926
}; // namespace rnexecutorch

0 commit comments

Comments
 (0)