Skip to content

Commit ffcf72f

Browse files
feat: add tests for generateFromPixels method
1 parent 98395af commit ffcf72f

7 files changed

Lines changed: 103 additions & 129 deletions

File tree

Lines changed: 8 additions & 125 deletions
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,16 @@
11
import Spinner from '../../components/Spinner';
2+
import { BottomBar } from '../../components/BottomBar';
23
import { getImage } from '../../utils';
34
import {
45
Detection,
56
useObjectDetection,
67
SSDLITE_320_MOBILENET_V3_LARGE,
7-
ScalarType,
8-
PixelData,
98
} from 'react-native-executorch';
10-
import { View, StyleSheet, Image, TouchableOpacity, Text } from 'react-native';
9+
import { View, StyleSheet, Image } from 'react-native';
1110
import ImageWithBboxes from '../../components/ImageWithBboxes';
1211
import React, { useContext, useEffect, useState } from 'react';
1312
import { GeneratingContext } from '../../context';
1413
import ScreenWrapper from '../../ScreenWrapper';
15-
import ColorPalette from '../../colors';
1614

1715
export default function ObjectDetectionScreen() {
1816
const [imageUri, setImageUri] = useState('');
@@ -44,59 +42,14 @@ export default function ObjectDetectionScreen() {
4442
const runForward = async () => {
4543
if (imageUri) {
4644
try {
47-
console.log('Running forward with string URI...');
48-
const output = await ssdLite.forward(imageUri, 0.5);
49-
console.log('String URI result:', output.length, 'detections');
45+
const output = await ssdLite.forward(imageUri);
5046
setResults(output);
5147
} catch (e) {
52-
console.error('Error in runForward:', e);
48+
console.error(e);
5349
}
5450
}
5551
};
5652

57-
const runForwardPixels = async () => {
58-
try {
59-
console.log('Testing with hardcoded pixel data...');
60-
61-
// Create a simple 320x320 test image (all zeros - black image)
62-
// In a real scenario, you would load actual image pixel data here
63-
const width = 320;
64-
const height = 320;
65-
const channels = 3; // RGB
66-
67-
// Create a black image (you can replace this with actual pixel data)
68-
const rgbData = new Uint8Array(width * height * channels);
69-
70-
// Optionally, add some test pattern (e.g., white square in center)
71-
for (let y = 100; y < 220; y++) {
72-
for (let x = 100; x < 220; x++) {
73-
const idx = (y * width + x) * 3;
74-
rgbData[idx + 0] = 255; // R
75-
rgbData[idx + 1] = 255; // G
76-
rgbData[idx + 2] = 255; // B
77-
}
78-
}
79-
80-
const pixelData: PixelData = {
81-
dataPtr: rgbData,
82-
sizes: [height, width, channels],
83-
scalarType: ScalarType.BYTE,
84-
};
85-
86-
console.log('Running forward with hardcoded pixel data...', {
87-
sizes: pixelData.sizes,
88-
dataSize: pixelData.dataPtr.byteLength,
89-
});
90-
91-
// Run inference using unified forward() API
92-
const output = await ssdLite.forward(pixelData, 0.3);
93-
console.log('Pixel data result:', output.length, 'detections');
94-
setResults(output);
95-
} catch (e) {
96-
console.error('Error in runForwardPixels:', e);
97-
}
98-
};
99-
10053
if (!ssdLite.isReady) {
10154
return (
10255
<Spinner
@@ -128,41 +81,10 @@ export default function ObjectDetectionScreen() {
12881
)}
12982
</View>
13083
</View>
131-
132-
{/* Custom bottom bar with two buttons */}
133-
<View style={styles.bottomContainer}>
134-
<View style={styles.bottomIconsContainer}>
135-
<TouchableOpacity onPress={() => handleCameraPress(false)}>
136-
<Text style={styles.iconText}>📷 Gallery</Text>
137-
</TouchableOpacity>
138-
</View>
139-
140-
<View style={styles.buttonsRow}>
141-
<TouchableOpacity
142-
style={[
143-
styles.button,
144-
styles.halfButton,
145-
!imageUri && styles.buttonDisabled,
146-
]}
147-
onPress={runForward}
148-
disabled={!imageUri}
149-
>
150-
<Text style={styles.buttonText}>Run (String)</Text>
151-
</TouchableOpacity>
152-
153-
<TouchableOpacity
154-
style={[
155-
styles.button,
156-
styles.halfButton,
157-
!imageUri && styles.buttonDisabled,
158-
]}
159-
onPress={runForwardPixels}
160-
disabled={!imageUri}
161-
>
162-
<Text style={styles.buttonText}>Run (Pixels)</Text>
163-
</TouchableOpacity>
164-
</View>
165-
</View>
84+
<BottomBar
85+
handleCameraPress={handleCameraPress}
86+
runForward={runForward}
87+
/>
16688
</ScreenWrapper>
16789
);
16890
}
@@ -207,43 +129,4 @@ const styles = StyleSheet.create({
207129
width: '100%',
208130
height: '100%',
209131
},
210-
bottomContainer: {
211-
width: '100%',
212-
gap: 15,
213-
alignItems: 'center',
214-
padding: 16,
215-
flex: 1,
216-
},
217-
bottomIconsContainer: {
218-
flexDirection: 'row',
219-
justifyContent: 'center',
220-
width: '100%',
221-
},
222-
iconText: {
223-
fontSize: 16,
224-
color: ColorPalette.primary,
225-
},
226-
buttonsRow: {
227-
flexDirection: 'row',
228-
width: '100%',
229-
gap: 10,
230-
},
231-
button: {
232-
height: 50,
233-
justifyContent: 'center',
234-
alignItems: 'center',
235-
backgroundColor: ColorPalette.primary,
236-
color: '#fff',
237-
borderRadius: 8,
238-
},
239-
halfButton: {
240-
flex: 1,
241-
},
242-
buttonDisabled: {
243-
opacity: 0.5,
244-
},
245-
buttonText: {
246-
color: '#fff',
247-
fontSize: 16,
248-
},
249132
});

packages/react-native-executorch/common/rnexecutorch/models/object_detection/ObjectDetection.cpp

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -106,6 +106,10 @@ ObjectDetection::postprocess(const std::vector<EValue> &tensors,
106106

107107
std::vector<types::Detection>
108108
ObjectDetection::runInference(cv::Mat image, double detectionThreshold) {
109+
if (detectionThreshold < 0.0 || detectionThreshold > 1.0) {
110+
throw RnExecutorchError(RnExecutorchErrorCode::InvalidUserInput,
111+
"detectionThreshold must be in range [0, 1]");
112+
}
109113
std::scoped_lock lock(inference_mutex_);
110114

111115
cv::Size originalSize = image.size();

packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -156,8 +156,11 @@ add_rn_test(ObjectDetectionTests integration/ObjectDetectionTest.cpp
156156
SOURCES
157157
${RNEXECUTORCH_DIR}/models/object_detection/ObjectDetection.cpp
158158
${RNEXECUTORCH_DIR}/models/object_detection/Utils.cpp
159+
${RNEXECUTORCH_DIR}/models/VisionModel.cpp
160+
${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
161+
${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
159162
${IMAGE_UTILS_SOURCES}
160-
LIBS opencv_deps
163+
LIBS opencv_deps android
161164
)
162165

163166
add_rn_test(ImageEmbeddingsTests integration/ImageEmbeddingsTest.cpp

packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp

Lines changed: 69 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,8 @@
11
#include "BaseModelTests.h"
2+
#include <executorch/extension/tensor/tensor.h>
23
#include <gtest/gtest.h>
34
#include <rnexecutorch/Error.h>
5+
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
46
#include <rnexecutorch/models/object_detection/Constants.h>
57
#include <rnexecutorch/models/object_detection/ObjectDetection.h>
68

@@ -115,6 +117,73 @@ TEST(ObjectDetectionGenerateTests, DetectionsHaveValidLabels) {
115117
}
116118
}
117119

120+
// ============================================================================
121+
// generateFromPixels tests
122+
// ============================================================================
123+
TEST(ObjectDetectionPixelTests, ValidPixelDataReturnsResults) {
124+
ObjectDetection model(kValidObjectDetectionModelPath, nullptr);
125+
constexpr int width = 4, height = 4, channels = 3;
126+
std::vector<uint8_t> pixelData(width * height * channels, 128);
127+
JSTensorViewIn tensorView{pixelData.data(),
128+
{height, width, channels},
129+
executorch::aten::ScalarType::Byte};
130+
auto results = model.generateFromPixels(tensorView, 0.3);
131+
EXPECT_GE(results.size(), 0u);
132+
}
133+
134+
TEST(ObjectDetectionPixelTests, WrongSizesLengthThrows) {
135+
ObjectDetection model(kValidObjectDetectionModelPath, nullptr);
136+
std::vector<uint8_t> pixelData(16, 0);
137+
JSTensorViewIn tensorView{
138+
pixelData.data(), {4, 4}, executorch::aten::ScalarType::Byte};
139+
EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5),
140+
RnExecutorchError);
141+
}
142+
143+
TEST(ObjectDetectionPixelTests, WrongChannelCountThrows) {
144+
ObjectDetection model(kValidObjectDetectionModelPath, nullptr);
145+
constexpr int width = 4, height = 4, channels = 4;
146+
std::vector<uint8_t> pixelData(width * height * channels, 0);
147+
JSTensorViewIn tensorView{pixelData.data(),
148+
{height, width, channels},
149+
executorch::aten::ScalarType::Byte};
150+
EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5),
151+
RnExecutorchError);
152+
}
153+
154+
TEST(ObjectDetectionPixelTests, WrongScalarTypeThrows) {
155+
ObjectDetection model(kValidObjectDetectionModelPath, nullptr);
156+
constexpr int width = 4, height = 4, channels = 3;
157+
std::vector<uint8_t> pixelData(width * height * channels, 0);
158+
JSTensorViewIn tensorView{pixelData.data(),
159+
{height, width, channels},
160+
executorch::aten::ScalarType::Float};
161+
EXPECT_THROW((void)model.generateFromPixels(tensorView, 0.5),
162+
RnExecutorchError);
163+
}
164+
165+
TEST(ObjectDetectionPixelTests, NegativeThresholdThrows) {
166+
ObjectDetection model(kValidObjectDetectionModelPath, nullptr);
167+
constexpr int width = 4, height = 4, channels = 3;
168+
std::vector<uint8_t> pixelData(width * height * channels, 128);
169+
JSTensorViewIn tensorView{pixelData.data(),
170+
{height, width, channels},
171+
executorch::aten::ScalarType::Byte};
172+
EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1),
173+
RnExecutorchError);
174+
}
175+
176+
TEST(ObjectDetectionPixelTests, ThresholdAboveOneThrows) {
177+
ObjectDetection model(kValidObjectDetectionModelPath, nullptr);
178+
constexpr int width = 4, height = 4, channels = 3;
179+
std::vector<uint8_t> pixelData(width * height * channels, 128);
180+
JSTensorViewIn tensorView{pixelData.data(),
181+
{height, width, channels},
182+
executorch::aten::ScalarType::Byte};
183+
EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1),
184+
RnExecutorchError);
185+
}
186+
118187
TEST(ObjectDetectionInheritedTests, GetInputShapeWorks) {
119188
ObjectDetection model(kValidObjectDetectionModelPath, nullptr);
120189
auto shape = model.getInputShape("forward", 0);

packages/react-native-executorch/common/rnexecutorch/tests/integration/stubs/jsi_stubs.cpp

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,14 @@ namespace facebook::jsi {
1414
MutableBuffer::~MutableBuffer() {}
1515
Value::~Value() {}
1616
Value::Value(Value &&other) noexcept {}
17+
18+
// Needed to link ObjectDetectionTests: generateFromFrame and FrameProcessor
19+
// pull in these JSI symbols, but they are never called in tests.
20+
Object Value::asObject(Runtime &) const & { __builtin_unreachable(); }
21+
BigInt Value::asBigInt(Runtime &) const & { __builtin_unreachable(); }
22+
23+
uint64_t BigInt::asUint64(Runtime &) const { return 0; }
24+
1725
} // namespace facebook::jsi
1826

1927
namespace facebook::react {

packages/react-native-executorch/src/modules/computer_vision/ObjectDetectionModule.ts

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
import { ResourceFetcher } from '../../utils/ResourceFetcher';
2-
import { ResourceSource } from '../../types/common';
2+
import { ResourceSource, PixelData } from '../../types/common';
33
import { Detection } from '../../types/objectDetection';
44
import { RnExecutorchErrorCode } from '../../errors/ErrorCodes';
55
import { parseUnknownError, RnExecutorchError } from '../../errors/errorUtils';
@@ -41,4 +41,11 @@ export class ObjectDetectionModule extends VisionModule<Detection[]> {
4141
throw parseUnknownError(error);
4242
}
4343
}
44+
45+
async forward(
46+
input: string | PixelData,
47+
detectionThreshold: number = 0.5
48+
): Promise<Detection[]> {
49+
return super.forward(input, detectionThreshold);
50+
}
4451
}

packages/react-native-executorch/src/types/objectDetection.ts

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -179,7 +179,7 @@ export interface ObjectDetectionType {
179179
* **Note**: For VisionCamera frame processing, use `processFrame` instead.
180180
*
181181
* @param input - Image source (string or PixelData object)
182-
* @param detectionThreshold - An optional number between 0 and 1 representing the minimum confidence score. Default is 0.7.
182+
* @param detectionThreshold - An optional number between 0 and 1 representing the minimum confidence score. Default is 0.5.
183183
* @returns A Promise that resolves to an array of `Detection` objects.
184184
* @throws {RnExecutorchError} If the model is not loaded or is currently processing another image.
185185
*
@@ -225,7 +225,7 @@ export interface ObjectDetectionType {
225225
* ```
226226
*
227227
* @param frame - VisionCamera Frame object
228-
* @param detectionThreshold - The threshold for detection sensitivity. Default is 0.7.
228+
* @param detectionThreshold - The threshold for detection sensitivity. Default is 0.5.
229229
* @returns Array of Detection objects representing detected items in the frame.
230230
*/
231231
runOnFrame:

0 commit comments

Comments (0)