Commit 1446d54

Add tests
1 parent d917cd5 commit 1446d54

File tree

6 files changed: +343 −26 lines changed

docs/docs/03-hooks/02-computer-vision/useInstanceSegmentation.md

Lines changed: 12 additions & 8 deletions

@@ -66,7 +66,7 @@ To run the model, use the [`forward`](../../06-api-reference/interfaces/Instance
 - `confidenceThreshold` - Minimum confidence score for including instances. Defaults to the model's configured threshold (typically `0.5`).
 - `iouThreshold` - IoU threshold for non-maximum suppression. Defaults to `0.5`.
 - `maxInstances` - Maximum number of instances to return. Defaults to `100`.
-- `classesOfInterest` - Filter results to include only specific classes (e.g. `['PERSON', 'CAR']`).
+- `classesOfInterest` - Filter results to include only specific classes (e.g. `['PERSON', 'CAR']`). Use label names from the model's label enum (e.g. [`CocoLabelYolo`](../../06-api-reference/enumerations/CocoLabelYolo.md) for YOLO models).
 - `returnMaskAtOriginalResolution` - Whether to resize masks to the original image resolution. Defaults to `true`.
 - `inputSize` - Input size for the model (e.g. `384`, `512`, `640`). Must be one of the model's available input sizes. If the model has only one forward method (i.e. no `availableInputSizes` configured), this option is not needed.
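The options in the hunk above all describe post-processing filters applied to raw detections. A minimal TypeScript sketch of how they compose — the `Detection` shape, the sort-then-cap behavior of `maxInstances`, and the default values are illustrative assumptions, not the library's actual implementation:

```typescript
// Hypothetical, simplified detection shape; the real hook returns
// SegmentedInstance objects that also carry bounding boxes and masks.
interface Detection {
  label: string;
  score: number;
}

// Sketch of the filtering the options describe: drop low-confidence
// detections, optionally keep only classes of interest, and cap the count
// (here: highest-scoring first — an assumption about ordering).
function filterDetections(
  detections: Detection[],
  confidenceThreshold = 0.5,
  maxInstances = 100,
  classesOfInterest?: string[],
): Detection[] {
  return detections
    .filter((d) => d.score >= confidenceThreshold)
    .filter((d) => !classesOfInterest || classesOfInterest.includes(d.label))
    .sort((a, b) => b.score - a.score)
    .slice(0, maxInstances);
}

const sample: Detection[] = [
  { label: "PERSON", score: 0.9 },
  { label: "CAR", score: 0.4 },
  { label: "DOG", score: 0.7 },
];
const kept = filterDetections(sample, 0.5, 100, ["PERSON", "CAR"]);
// Only PERSON survives: CAR falls below the threshold, DOG is not a class
// of interest.
```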

@@ -121,10 +121,14 @@ function App() {

 ## Supported models

-| Model       | Number of classes | Class list                                               | Available input sizes |
-| ----------- | ----------------- | -------------------------------------------------------- | --------------------- |
-| yolo26n-seg | 80                | [COCO](../../06-api-reference/enumerations/CocoLabel.md) | 384, 512, 640         |
-| yolo26s-seg | 80                | [COCO](../../06-api-reference/enumerations/CocoLabel.md) | 384, 512, 640         |
-| yolo26m-seg | 80                | [COCO](../../06-api-reference/enumerations/CocoLabel.md) | 384, 512, 640         |
-| yolo26l-seg | 80                | [COCO](../../06-api-reference/enumerations/CocoLabel.md) | 384, 512, 640         |
-| yolo26x-seg | 80                | [COCO](../../06-api-reference/enumerations/CocoLabel.md) | 384, 512, 640         |
+:::info
+YOLO models use the [`CocoLabelYolo`](../../06-api-reference/enumerations/CocoLabelYolo.md) enum (80 classes, 0-indexed), which differs from [`CocoLabel`](../../06-api-reference/enumerations/CocoLabel.md) used by RF-DETR and SSDLite object detection models (91 classes, 1-indexed). When filtering with `classesOfInterest`, use the label names from `CocoLabelYolo`.
+:::
+
+| Model       | Number of classes | Class list                                                          | Available input sizes |
+| ----------- | ----------------- | ------------------------------------------------------------------- | --------------------- |
+| yolo26n-seg | 80                | [COCO (YOLO)](../../06-api-reference/enumerations/CocoLabelYolo.md) | 384, 512, 640         |
+| yolo26s-seg | 80                | [COCO (YOLO)](../../06-api-reference/enumerations/CocoLabelYolo.md) | 384, 512, 640         |
+| yolo26m-seg | 80                | [COCO (YOLO)](../../06-api-reference/enumerations/CocoLabelYolo.md) | 384, 512, 640         |
+| yolo26l-seg | 80                | [COCO (YOLO)](../../06-api-reference/enumerations/CocoLabelYolo.md) | 384, 512, 640         |
+| yolo26x-seg | 80                | [COCO (YOLO)](../../06-api-reference/enumerations/CocoLabelYolo.md) | 384, 512, 640         |
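The indexing mismatch the info box warns about can be made concrete with abbreviated, hypothetical excerpts of the two enums (the real `CocoLabelYolo` and `CocoLabel` live in the API reference and contain 80 and 91 entries respectively):

```typescript
// Hypothetical excerpt: the YOLO COCO enum is 0-indexed.
enum CocoLabelYoloSketch {
  PERSON = 0,
  BICYCLE = 1,
  CAR = 2,
}

// Hypothetical excerpt: the 91-class COCO enum is 1-indexed.
enum CocoLabelSketch {
  PERSON = 1,
  BICYCLE = 2,
  CAR = 3,
}

// The same class name maps to different numeric indices in the two enums,
// which is why a class filter must use the enum matching the model family.
const yoloCarIndex = CocoLabelYoloSketch.CAR; // 2
const detrCarIndex = CocoLabelSketch.CAR; // 3
```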

docs/docs/04-typescript-api/02-computer-vision/InstanceSegmentationModule.md

Lines changed: 5 additions & 1 deletion

@@ -104,7 +104,11 @@ To run the model, use the [`forward`](../../06-api-reference/classes/InstanceSeg
 - `imageSource` (required) - The image to process. Can be a remote URL, a local file URI, or a base64-encoded image (whole URI or only raw base64).
 - `options` (optional) - An [`InstanceSegmentationOptions`](../../06-api-reference/interfaces/InstanceSegmentationOptions.md) object for configuring the segmentation (confidence threshold, IoU threshold, input size, classes of interest, etc.).

-The method returns a promise resolving to an array of [`SegmentedInstance`](../../06-api-reference/interfaces/SegmentedInstance.md) objects. Each object contains bounding box coordinates, a binary segmentation mask, the label of the detected instance, and the confidence score.
+The method returns a promise resolving to an array of [`SegmentedInstance`](../../06-api-reference/interfaces/SegmentedInstance.md) objects. Each object contains bounding box coordinates, a binary segmentation mask, a string `label` (resolved from the model's label enum), and the confidence score.
+
+:::info
+Built-in YOLO models use [`CocoLabelYolo`](../../06-api-reference/enumerations/CocoLabelYolo.md) (80 classes, 0-indexed), not [`CocoLabel`](../../06-api-reference/enumerations/CocoLabel.md) (91 classes, 1-indexed, used by RF-DETR / SSDLite). When filtering with `classesOfInterest`, use the key names from `CocoLabelYolo`.
+:::

 ## Managing memory
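The hunk above describes the `SegmentedInstance` result shape (bounding box, binary mask, label, score). A small sketch of consuming such a result, using a stand-in interface whose field names (`mask`, `maskWidth`, `maskHeight`) are assumptions borrowed from the C++ tests in this commit, not necessarily the documented TypeScript interface:

```typescript
// Minimal stand-in for the documented SegmentedInstance shape; see the API
// reference for the actual interface.
interface SegmentedInstanceSketch {
  label: string;
  score: number;
  mask: Uint8Array; // binary mask, row-major, values 0 or 1
  maskWidth: number;
  maskHeight: number;
}

// Fraction of mask pixels that belong to the instance. The mask length must
// equal maskWidth * maskHeight for this to be meaningful.
function maskCoverage(inst: SegmentedInstanceSketch): number {
  let on = 0;
  for (const v of inst.mask) on += v;
  return on / (inst.maskWidth * inst.maskHeight);
}

const inst: SegmentedInstanceSketch = {
  label: "PERSON",
  score: 0.92,
  mask: Uint8Array.from([1, 1, 0, 0]),
  maskWidth: 2,
  maskHeight: 2,
};
// Half of the 2x2 mask is set, so coverage is 0.5.
```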

packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt

Lines changed: 7 additions & 0 deletions

@@ -234,6 +234,13 @@ add_rn_test(TextToImageTests integration/TextToImageTest.cpp
   LIBS tokenizers_deps
 )

+add_rn_test(InstanceSegmentationTests integration/InstanceSegmentationTest.cpp
+  SOURCES
+  ${RNEXECUTORCH_DIR}/models/instance_segmentation/BaseInstanceSegmentation.cpp
+  ${IMAGE_UTILS_SOURCES}
+  LIBS opencv_deps android
+)
+
 add_rn_test(OCRTests integration/OCRTest.cpp
   SOURCES
   ${RNEXECUTORCH_DIR}/models/ocr/OCR.cpp
packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp

Lines changed: 285 additions & 0 deletions (new file)

```cpp
#include "BaseModelTests.h"
#include <gtest/gtest.h>
#include <rnexecutorch/Error.h>
#include <rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h>
#include <rnexecutorch/models/instance_segmentation/Types.h>

#include <cstdint> // int32_t
#include <set>     // std::set in the unique-id test
#include <vector>

using namespace rnexecutorch;
using namespace rnexecutorch::models::instance_segmentation;
using namespace model_tests;

constexpr auto kValidInstanceSegModelPath = "yolo26n-seg.pte";
constexpr auto kValidTestImagePath =
    "file:///data/local/tmp/rnexecutorch_tests/segmentation_image.jpg";
constexpr auto kMethodName = "forward_384";

// ============================================================================
// Common tests via typed test suite
// ============================================================================
namespace model_tests {
template <> struct ModelTraits<BaseInstanceSegmentation> {
  using ModelType = BaseInstanceSegmentation;

  static ModelType createValid() {
    return ModelType(kValidInstanceSegModelPath, {}, {}, true, nullptr);
  }

  static ModelType createInvalid() {
    return ModelType("nonexistent.pte", {}, {}, true, nullptr);
  }

  static void callGenerate(ModelType &model) {
    (void)model.generate(kValidTestImagePath, 0.5, 0.5, 100, {}, true,
                         kMethodName);
  }
};
} // namespace model_tests

using InstanceSegmentationTypes = ::testing::Types<BaseInstanceSegmentation>;
INSTANTIATE_TYPED_TEST_SUITE_P(InstanceSegmentation, CommonModelTest,
                               InstanceSegmentationTypes);

// ============================================================================
// Generate tests (from string)
// ============================================================================
TEST(InstanceSegGenerateTests, InvalidImagePathThrows) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  EXPECT_THROW((void)model.generate("nonexistent_image.jpg", 0.5, 0.5, 100, {},
                                    true, kMethodName),
               RnExecutorchError);
}

TEST(InstanceSegGenerateTests, EmptyImagePathThrows) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  EXPECT_THROW((void)model.generate("", 0.5, 0.5, 100, {}, true, kMethodName),
               RnExecutorchError);
}

TEST(InstanceSegGenerateTests, EmptyMethodNameThrows) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  EXPECT_THROW(
      (void)model.generate(kValidTestImagePath, 0.5, 0.5, 100, {}, true, ""),
      RnExecutorchError);
}

TEST(InstanceSegGenerateTests, NegativeConfidenceThrows) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  EXPECT_THROW((void)model.generate(kValidTestImagePath, -0.1, 0.5, 100, {},
                                    true, kMethodName),
               RnExecutorchError);
}

TEST(InstanceSegGenerateTests, ConfidenceAboveOneThrows) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  EXPECT_THROW((void)model.generate(kValidTestImagePath, 1.1, 0.5, 100, {},
                                    true, kMethodName),
               RnExecutorchError);
}

TEST(InstanceSegGenerateTests, NegativeIouThresholdThrows) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  EXPECT_THROW((void)model.generate(kValidTestImagePath, 0.5, -0.1, 100, {},
                                    true, kMethodName),
               RnExecutorchError);
}

TEST(InstanceSegGenerateTests, IouThresholdAboveOneThrows) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  EXPECT_THROW((void)model.generate(kValidTestImagePath, 0.5, 1.1, 100, {},
                                    true, kMethodName),
               RnExecutorchError);
}

TEST(InstanceSegGenerateTests, ValidImageReturnsResults) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto results =
      model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
  EXPECT_FALSE(results.empty());
}

TEST(InstanceSegGenerateTests, HighThresholdReturnsFewerResults) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto lowResults =
      model.generate(kValidTestImagePath, 0.1, 0.5, 100, {}, true, kMethodName);
  auto highResults =
      model.generate(kValidTestImagePath, 0.9, 0.5, 100, {}, true, kMethodName);
  EXPECT_GE(lowResults.size(), highResults.size());
}

TEST(InstanceSegGenerateTests, MaxInstancesLimitsResults) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto results =
      model.generate(kValidTestImagePath, 0.1, 0.5, 2, {}, true, kMethodName);
  EXPECT_LE(results.size(), 2u);
}

// ============================================================================
// Result validation tests
// ============================================================================
TEST(InstanceSegResultTests, InstancesHaveValidBoundingBoxes) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto results =
      model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);

  for (const auto &inst : results) {
    EXPECT_LE(inst.x1, inst.x2);
    EXPECT_LE(inst.y1, inst.y2);
    EXPECT_GE(inst.x1, 0.0f);
    EXPECT_GE(inst.y1, 0.0f);
  }
}

TEST(InstanceSegResultTests, InstancesHaveValidScores) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto results =
      model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);

  for (const auto &inst : results) {
    EXPECT_GE(inst.score, 0.0f);
    EXPECT_LE(inst.score, 1.0f);
  }
}

TEST(InstanceSegResultTests, InstancesHaveValidMasks) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto results =
      model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);

  for (const auto &inst : results) {
    EXPECT_GT(inst.maskWidth, 0);
    EXPECT_GT(inst.maskHeight, 0);
    EXPECT_EQ(inst.mask.size(),
              static_cast<size_t>(inst.maskWidth) * inst.maskHeight);

    for (uint8_t val : inst.mask) {
      EXPECT_TRUE(val == 0 || val == 1);
    }
  }
}

TEST(InstanceSegResultTests, InstancesHaveUniqueIds) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto results =
      model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);

  std::set<int> ids;
  for (const auto &inst : results) {
    EXPECT_TRUE(ids.insert(inst.instanceId).second)
        << "Duplicate instanceId: " << inst.instanceId;
  }
}

TEST(InstanceSegResultTests, InstancesHaveValidClassIndices) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto results =
      model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);

  for (const auto &inst : results) {
    EXPECT_GE(inst.classIndex, 0);
    EXPECT_LT(inst.classIndex, 80); // COCO YOLO has 80 classes
  }
}

// ============================================================================
// Class filtering tests
// ============================================================================
TEST(InstanceSegFilterTests, ClassFilterReturnsOnlyMatchingClasses) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  // Filter to class index 0 (PERSON in CocoLabelYolo)
  std::vector<int32_t> classIndices = {0};
  auto results = model.generate(kValidTestImagePath, 0.3, 0.5, 100,
                                classIndices, true, kMethodName);

  for (const auto &inst : results) {
    EXPECT_EQ(inst.classIndex, 0);
  }
}

TEST(InstanceSegFilterTests, EmptyFilterReturnsAllClasses) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  // An explicitly empty class filter must behave exactly like passing no
  // filter at all.
  const std::vector<int32_t> emptyFilter{};
  auto unfilteredResults =
      model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
  auto filteredResults = model.generate(kValidTestImagePath, 0.3, 0.5, 100,
                                        emptyFilter, true, kMethodName);

  EXPECT_EQ(unfilteredResults.size(), filteredResults.size());
}

// ============================================================================
// Mask resolution tests
// ============================================================================
TEST(InstanceSegMaskTests, LowResMaskIsSmallerThanOriginal) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto hiRes =
      model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
  auto loRes = model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, false,
                              kMethodName);

  if (!hiRes.empty() && !loRes.empty()) {
    EXPECT_LE(loRes[0].mask.size(), hiRes[0].mask.size());
  }
}

// ============================================================================
// NMS tests
// ============================================================================
TEST(InstanceSegNMSTests, NMSEnabledReturnsFewerOrEqualResults) {
  BaseInstanceSegmentation modelWithNMS(kValidInstanceSegModelPath, {}, {},
                                        true, nullptr);
  BaseInstanceSegmentation modelWithoutNMS(kValidInstanceSegModelPath, {}, {},
                                           false, nullptr);

  auto nmsResults = modelWithNMS.generate(kValidTestImagePath, 0.3, 0.5, 100,
                                          {}, true, kMethodName);
  auto noNmsResults = modelWithoutNMS.generate(kValidTestImagePath, 0.3, 0.5,
                                               100, {}, true, kMethodName);

  EXPECT_LE(nmsResults.size(), noNmsResults.size());
}

// ============================================================================
// Inherited method tests
// ============================================================================
TEST(InstanceSegInheritedTests, GetMethodMetaWorks) {
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
                                 nullptr);
  auto result = model.getMethodMeta(kMethodName);
  EXPECT_TRUE(result.ok());
}

// ============================================================================
// Normalisation tests
// ============================================================================
TEST(InstanceSegNormTests, ValidNormParamsDoesntThrow) {
  const std::vector<float> mean = {0.485f, 0.456f, 0.406f};
  const std::vector<float> std = {0.229f, 0.224f, 0.225f};
  EXPECT_NO_THROW(BaseInstanceSegmentation(kValidInstanceSegModelPath, mean,
                                           std, true, nullptr));
}

TEST(InstanceSegNormTests, ValidNormParamsGenerateSucceeds) {
  const std::vector<float> mean = {0.485f, 0.456f, 0.406f};
  const std::vector<float> std = {0.229f, 0.224f, 0.225f};
  BaseInstanceSegmentation model(kValidInstanceSegModelPath, mean, std, true,
                                 nullptr);
  EXPECT_NO_THROW((void)model.generate(kValidTestImagePath, 0.5, 0.5, 100, {},
                                       true, kMethodName));
}
```

packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh

Lines changed: 3 additions & 0 deletions

@@ -32,6 +32,7 @@ TEST_EXECUTABLES=(
   "TextToImageTests"
   "OCRTests"
   "VerticalOCRTests"
+  "InstanceSegmentationTests"
 )

 # ============================================================================

@@ -66,6 +67,8 @@ MODELS=(
   "t2i_encoder.pte|https://huggingface.co/software-mansion/react-native-executorch-bk-sdm-tiny/resolve/v0.6.0/text_encoder/model.pte"
   "t2i_unet.pte|https://huggingface.co/software-mansion/react-native-executorch-bk-sdm-tiny/resolve/v0.6.0/unet/model.256.pte"
   "t2i_decoder.pte|https://huggingface.co/software-mansion/react-native-executorch-bk-sdm-tiny/resolve/v0.6.0/vae/model.256.pte"
+  "yolo26n-seg.pte|https://huggingface.co/software-mansion/react-native-executorch-yolo26-seg/resolve/v0.8.0/yolo26n-seg/xnnpack/yolo26n-seg.pte"
+  "segmentation_image.jpg|https://upload.wikimedia.org/wikipedia/commons/thumb/8/85/Collage_audi.jpg/1280px-Collage_audi.jpg"
 )

 # ============================================================================
