Skip to content

Commit 3527d01

Browse files
feat: initial version of vision model API
1 parent e4e0f95 commit 3527d01

18 files changed

Lines changed: 577 additions & 71 deletions

File tree

.cspell-wordlist.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,4 +111,5 @@ logprob
111111
RNFS
112112
pogodin
113113
kesha
114-
antonov
114+
antonov
115+
worklet

apps/computer-vision/app/object_detection/index.tsx

Lines changed: 159 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,66 @@
11
import Spinner from '../../components/Spinner';
2-
import { BottomBar } from '../../components/BottomBar';
32
import { getImage } from '../../utils';
43
import {
54
Detection,
65
useObjectDetection,
76
SSDLITE_320_MOBILENET_V3_LARGE,
87
} from 'react-native-executorch';
9-
import { View, StyleSheet, Image } from 'react-native';
8+
import { View, StyleSheet, Image, TouchableOpacity, Text } from 'react-native';
109
import ImageWithBboxes from '../../components/ImageWithBboxes';
1110
import React, { useContext, useEffect, useState } from 'react';
1211
import { GeneratingContext } from '../../context';
1312
import ScreenWrapper from '../../ScreenWrapper';
13+
import ColorPalette from '../../colors';
14+
import { Images } from 'react-native-nitro-image';
15+
16+
// Helper function to convert image URI to raw pixel data using NitroImage
17+
async function imageUriToPixelData(
18+
uri: string,
19+
targetWidth: number,
20+
targetHeight: number
21+
): Promise<{
22+
data: ArrayBuffer;
23+
width: number;
24+
height: number;
25+
channels: number;
26+
}> {
27+
try {
28+
// Load image and resize to target dimensions
29+
const image = await Images.loadFromFileAsync(uri);
30+
const resized = image.resize(targetWidth, targetHeight);
31+
32+
// Get pixel data as ArrayBuffer (RGBA format)
33+
const pixelData = resized.toRawPixelData();
34+
const buffer =
35+
pixelData instanceof ArrayBuffer ? pixelData : pixelData.buffer;
36+
37+
// Calculate actual buffer dimensions (accounts for device pixel ratio)
38+
const bufferSize = buffer?.byteLength || 0;
39+
const totalPixels = bufferSize / 4; // RGBA = 4 bytes per pixel
40+
const aspectRatio = targetWidth / targetHeight;
41+
const actualHeight = Math.sqrt(totalPixels / aspectRatio);
42+
const actualWidth = totalPixels / actualHeight;
43+
44+
console.log('Requested:', targetWidth, 'x', targetHeight);
45+
console.log('Buffer size:', bufferSize);
46+
console.log(
47+
'Actual dimensions:',
48+
Math.round(actualWidth),
49+
'x',
50+
Math.round(actualHeight)
51+
);
52+
53+
return {
54+
data: buffer,
55+
width: Math.round(actualWidth),
56+
height: Math.round(actualHeight),
57+
channels: 4, // RGBA
58+
};
59+
} catch (error) {
60+
console.error('Error loading image with NitroImage:', error);
61+
throw error;
62+
}
63+
}
1464

1565
export default function ObjectDetectionScreen() {
1666
const [imageUri, setImageUri] = useState('');
@@ -42,10 +92,41 @@ export default function ObjectDetectionScreen() {
4292
const runForward = async () => {
4393
if (imageUri) {
4494
try {
45-
const output = await ssdLite.forward(imageUri);
95+
console.log('Running forward with string URI...');
96+
const output = await ssdLite.forward(imageUri, 0.5);
97+
console.log('String URI result:', output.length, 'detections');
4698
setResults(output);
4799
} catch (e) {
48-
console.error(e);
100+
console.error('Error in runForward:', e);
101+
}
102+
}
103+
};
104+
105+
const runForwardPixels = async () => {
106+
if (imageUri && imageDimensions) {
107+
try {
108+
console.log('Converting image to pixel data...');
109+
// Resize to 640x640 to avoid memory issues
110+
const intermediateSize = 640;
111+
const pixelData = await imageUriToPixelData(
112+
imageUri,
113+
intermediateSize,
114+
intermediateSize
115+
);
116+
117+
console.log('Running forward with pixel data...', {
118+
width: pixelData.width,
119+
height: pixelData.height,
120+
channels: pixelData.channels,
121+
dataSize: pixelData.data.byteLength,
122+
});
123+
124+
// Run inference using unified forward() API
125+
const output = await ssdLite.forward(pixelData, 0.5);
126+
console.log('Pixel data result:', output.length, 'detections');
127+
setResults(output);
128+
} catch (e) {
129+
console.error('Error in runForwardPixels:', e);
49130
}
50131
}
51132
};
@@ -81,10 +162,41 @@ export default function ObjectDetectionScreen() {
81162
)}
82163
</View>
83164
</View>
84-
<BottomBar
85-
handleCameraPress={handleCameraPress}
86-
runForward={runForward}
87-
/>
165+
166+
{/* Custom bottom bar with two buttons */}
167+
<View style={styles.bottomContainer}>
168+
<View style={styles.bottomIconsContainer}>
169+
<TouchableOpacity onPress={() => handleCameraPress(false)}>
170+
<Text style={styles.iconText}>📷 Gallery</Text>
171+
</TouchableOpacity>
172+
</View>
173+
174+
<View style={styles.buttonsRow}>
175+
<TouchableOpacity
176+
style={[
177+
styles.button,
178+
styles.halfButton,
179+
!imageUri && styles.buttonDisabled,
180+
]}
181+
onPress={runForward}
182+
disabled={!imageUri}
183+
>
184+
<Text style={styles.buttonText}>Run (String)</Text>
185+
</TouchableOpacity>
186+
187+
<TouchableOpacity
188+
style={[
189+
styles.button,
190+
styles.halfButton,
191+
!imageUri && styles.buttonDisabled,
192+
]}
193+
onPress={runForwardPixels}
194+
disabled={!imageUri}
195+
>
196+
<Text style={styles.buttonText}>Run (Pixels)</Text>
197+
</TouchableOpacity>
198+
</View>
199+
</View>
88200
</ScreenWrapper>
89201
);
90202
}
@@ -129,4 +241,43 @@ const styles = StyleSheet.create({
129241
width: '100%',
130242
height: '100%',
131243
},
244+
bottomContainer: {
245+
width: '100%',
246+
gap: 15,
247+
alignItems: 'center',
248+
padding: 16,
249+
flex: 1,
250+
},
251+
bottomIconsContainer: {
252+
flexDirection: 'row',
253+
justifyContent: 'center',
254+
width: '100%',
255+
},
256+
iconText: {
257+
fontSize: 16,
258+
color: ColorPalette.primary,
259+
},
260+
buttonsRow: {
261+
flexDirection: 'row',
262+
width: '100%',
263+
gap: 10,
264+
},
265+
button: {
266+
height: 50,
267+
justifyContent: 'center',
268+
alignItems: 'center',
269+
backgroundColor: ColorPalette.primary,
270+
color: '#fff',
271+
borderRadius: 8,
272+
},
273+
halfButton: {
274+
flex: 1,
275+
},
276+
buttonDisabled: {
277+
opacity: 0.5,
278+
},
279+
buttonText: {
280+
color: '#fff',
281+
fontSize: 16,
282+
},
132283
});

packages/react-native-executorch/common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,16 @@ class RnExecutorchInstaller {
5454
meta::createConstructorArgsWithCallInvoker<ModelT>(
5555
args, runtime, jsCallInvoker);
5656

57-
auto modelImplementationPtr = std::make_shared<ModelT>(
58-
std::make_from_tuple<ModelT>(constructorArgs));
57+
// This unpacks the tuple and calls the constructor directly inside
58+
// make_shared. It avoids creating a temporary object, so no
59+
// move/copy is required.
60+
auto modelImplementationPtr = std::apply(
61+
[](auto &&...unpackedArgs) {
62+
return std::make_shared<ModelT>(
63+
std::forward<decltype(unpackedArgs)>(unpackedArgs)...);
64+
},
65+
std::move(constructorArgs));
66+
5967
auto modelHostObject = std::make_shared<ModelHostObject<ModelT>>(
6068
modelImplementationPtr, jsCallInvoker);
6169

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,11 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
4545
"getInputShape"));
4646
}
4747

48-
if constexpr (meta::HasGenerate<Model>) {
49-
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
50-
promiseHostFunction<&Model::generate>,
51-
"generate"));
48+
if constexpr (meta::HasGenerateFromString<Model>) {
49+
addFunctions(
50+
JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
51+
promiseHostFunction<&Model::generateFromString>,
52+
"generateFromString"));
5253
}
5354

5455
if constexpr (meta::HasEncode<Model>) {
@@ -155,10 +156,22 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
155156
addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
156157
promiseHostFunction<&Model::stream>,
157158
"stream"));
159+
}
160+
161+
// Register generateFromFrame for all VisionModel subclasses
162+
if constexpr (meta::DerivedFromOrSameAs<Model, models::VisionModel>) {
158163
addFunctions(JSI_EXPORT_FUNCTION(
159164
ModelHostObject<Model>, synchronousHostFunction<&Model::streamStop>,
160165
"streamStop"));
161166
}
167+
168+
// Register generateFromPixels for models that support it
169+
if constexpr (meta::HasGenerateFromPixels<Model>) {
170+
addFunctions(
171+
JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
172+
visionHostFunction<&Model::generateFromPixels>,
173+
"generateFromPixels"));
174+
}
162175
}
163176

164177
// A generic host function that runs synchronously, works analogously to the

packages/react-native-executorch/common/rnexecutorch/metaprogramming/TypeConcepts.h

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,13 @@ template <typename T, typename Base>
1212
concept SameAs = std::is_same_v<Base, T>;
1313

1414
template <typename T>
15-
concept HasGenerate = requires(T t) {
16-
{ &T::generate };
15+
concept HasGenerateFromString = requires(T t) {
16+
{ &T::generateFromString };
17+
};
18+
19+
template <typename T>
20+
concept HasGenerateFromPixels = requires(T t) {
21+
{ &T::generateFromPixels };
1722
};
1823

1924
template <typename T>

packages/react-native-executorch/common/rnexecutorch/models/VisionModel.cpp

Lines changed: 45 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ namespace models {
66

77
using namespace facebook;
88

9-
cv::Mat VisionModel::extractAndPreprocess(jsi::Runtime &runtime,
10-
const jsi::Value &frameData) const {
9+
cv::Mat VisionModel::extractFromFrame(jsi::Runtime &runtime,
10+
const jsi::Value &frameData) const {
1111
// Extract frame using FrameProcessor utility
1212
auto frameObj = frameData.asObject(runtime);
1313
cv::Mat frame = utils::FrameProcessor::extractFrame(runtime, frameObj);
@@ -16,5 +16,48 @@ cv::Mat VisionModel::extractAndPreprocess(jsi::Runtime &runtime,
1616
return preprocessFrame(frame);
1717
}
1818

19+
cv::Mat VisionModel::extractFromPixels(jsi::Runtime &runtime,
20+
const jsi::Object &pixelData) const {
21+
// Extract width, height, and channels
22+
if (!pixelData.hasProperty(runtime, "width") ||
23+
!pixelData.hasProperty(runtime, "height") ||
24+
!pixelData.hasProperty(runtime, "channels") ||
25+
!pixelData.hasProperty(runtime, "data")) {
26+
throw std::runtime_error(
27+
"Invalid pixel data: must contain width, height, channels, and data");
28+
}
29+
30+
int width = pixelData.getProperty(runtime, "width").asNumber();
31+
int height = pixelData.getProperty(runtime, "height").asNumber();
32+
int channels = pixelData.getProperty(runtime, "channels").asNumber();
33+
34+
// Get the ArrayBuffer
35+
auto dataValue = pixelData.getProperty(runtime, "data");
36+
if (!dataValue.isObject() ||
37+
!dataValue.asObject(runtime).isArrayBuffer(runtime)) {
38+
throw std::runtime_error(
39+
"pixel data 'data' property must be an ArrayBuffer");
40+
}
41+
42+
auto arrayBuffer = dataValue.asObject(runtime).getArrayBuffer(runtime);
43+
size_t expectedSize = width * height * channels;
44+
45+
if (arrayBuffer.size(runtime) != expectedSize) {
46+
throw std::runtime_error(
47+
"ArrayBuffer size does not match width * height * channels");
48+
}
49+
50+
// Create cv::Mat and copy the data
51+
// OpenCV uses BGR/BGRA format internally, but we'll create as-is and let
52+
// preprocessFrame handle conversion
53+
int cvType = (channels == 3) ? CV_8UC3 : CV_8UC4;
54+
cv::Mat image(height, width, cvType);
55+
56+
// Copy data from ArrayBuffer to cv::Mat
57+
std::memcpy(image.data, arrayBuffer.data(runtime), expectedSize);
58+
59+
return image;
60+
}
61+
1962
} // namespace models
2063
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/VisionModel.h

Lines changed: 39 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -121,12 +121,48 @@ class VisionModel : public BaseModel {
121121
* responsible
122122
* @note Typical usage:
123123
* @code
124-
* cv::Mat preprocessed = extractAndPreprocess(runtime, frameData);
124+
* cv::Mat preprocessed = extractFromFrame(runtime, frameData);
125125
* auto tensor = image_processing::getTensorFromMatrix(dims, preprocessed);
126126
* @endcode
127127
*/
128-
cv::Mat extractAndPreprocess(jsi::Runtime &runtime,
129-
const jsi::Value &frameData) const;
128+
cv::Mat extractFromFrame(jsi::Runtime &runtime,
129+
const jsi::Value &frameData) const;
130+
131+
/**
132+
* @brief Extract cv::Mat from raw pixel data (ArrayBuffer) sent from
133+
* JavaScript
134+
*
135+
* This method enables users to run inference on raw pixel data without file
136+
* I/O. Useful for processing images already in memory (e.g., from canvas,
137+
* image library).
138+
*
139+
* @param runtime JSI runtime
140+
* @param pixelData JSI object containing:
141+
* - data: ArrayBuffer with raw pixel values
142+
* - width: number - image width
143+
* - height: number - image height
144+
* - channels: number - number of channels (3 for RGB, 4 for
145+
* RGBA)
146+
*
147+
* @return cv::Mat containing the pixel data
148+
*
149+
* @throws std::runtime_error if pixelData format is invalid
150+
*
151+
* @note The returned cv::Mat owns a copy of the data
152+
* @note Expected pixel format: RGB or RGBA, row-major order
153+
* @note Typical usage from JS:
154+
* @code
155+
* const pixels = new Uint8Array([...]); // Raw pixel data
156+
* const result = model.generateFromPixels({
157+
* data: pixels.buffer,
158+
* width: 640,
159+
* height: 480,
160+
* channels: 3
161+
* }, 0.5);
162+
* @endcode
163+
*/
164+
cv::Mat extractFromPixels(jsi::Runtime &runtime,
165+
const jsi::Object &pixelData) const;
130166
};
131167

132168
} // namespace models

0 commit comments

Comments
 (0)