Skip to content

Commit 058381a

Browse files
committed
Improve tests
1 parent b71c041 commit 058381a

File tree

4 files changed

+261
-99
lines changed

4 files changed

+261
-99
lines changed

packages/react-native-executorch/common/rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ cv::Size BaseInstanceSegmentation::modelInputSize() const {
4040
return VisionModel::modelInputSize();
4141
}
4242
const auto &shape = inputShapes[0];
43-
return cv::Size(shape[shape.size() - 1], shape[shape.size() - 2]);
43+
return cv::Size(shape[shape.size() - 2], shape[shape.size() - 1]);
4444
}
4545

4646
std::vector<types::Instance> BaseInstanceSegmentation::runInference(
@@ -61,7 +61,7 @@ std::vector<types::Instance> BaseInstanceSegmentation::runInference(
6161

6262
modelInputShape_ = inputShapes[0];
6363
const auto &shape = modelInputShape_;
64-
cv::Size modelInputSize(shape[shape.size() - 1], shape[shape.size() - 2]);
64+
cv::Size modelInputSize(shape[shape.size() - 2], shape[shape.size() - 1]);
6565
cv::Size originalSize(image.cols, image.rows);
6666

6767
cv::Mat preprocessed = preprocess(image);

packages/react-native-executorch/common/rnexecutorch/tests/integration/InstanceSegmentationTest.cpp

Lines changed: 110 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
#include "BaseModelTests.h"
2+
#include "VisionModelTests.h"
3+
#include <executorch/extension/tensor/tensor.h>
24
#include <gtest/gtest.h>
35
#include <rnexecutorch/Error.h>
6+
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
47
#include <rnexecutorch/models/instance_segmentation/BaseInstanceSegmentation.h>
58
#include <rnexecutorch/models/instance_segmentation/Types.h>
69

@@ -29,97 +32,100 @@ template <> struct ModelTraits<BaseInstanceSegmentation> {
2932
}
3033

3134
static void callGenerate(ModelType &model) {
32-
(void)model.generate(kValidTestImagePath, 0.5, 0.5, 100, {}, true,
33-
kMethodName);
35+
(void)model.generateFromString(kValidTestImagePath, 0.5, 0.5, 100, {}, true,
36+
kMethodName);
3437
}
3538
};
3639
} // namespace model_tests
3740

3841
using InstanceSegmentationTypes = ::testing::Types<BaseInstanceSegmentation>;
3942
INSTANTIATE_TYPED_TEST_SUITE_P(InstanceSegmentation, CommonModelTest,
4043
InstanceSegmentationTypes);
44+
INSTANTIATE_TYPED_TEST_SUITE_P(InstanceSegmentation, VisionModelTest,
45+
InstanceSegmentationTypes);
4146

4247
// ============================================================================
43-
// Generate tests (from string)
48+
// Model-specific tests
4449
// ============================================================================
4550
TEST(InstanceSegGenerateTests, InvalidImagePathThrows) {
4651
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
4752
nullptr);
48-
EXPECT_THROW((void)model.generate("nonexistent_image.jpg", 0.5, 0.5, 100, {},
49-
true, kMethodName),
53+
EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg", 0.5, 0.5,
54+
100, {}, true, kMethodName),
5055
RnExecutorchError);
5156
}
5257

5358
TEST(InstanceSegGenerateTests, EmptyImagePathThrows) {
5459
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
5560
nullptr);
56-
EXPECT_THROW((void)model.generate("", 0.5, 0.5, 100, {}, true, kMethodName),
57-
RnExecutorchError);
61+
EXPECT_THROW(
62+
(void)model.generateFromString("", 0.5, 0.5, 100, {}, true, kMethodName),
63+
RnExecutorchError);
5864
}
5965

6066
TEST(InstanceSegGenerateTests, EmptyMethodNameThrows) {
6167
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
6268
nullptr);
63-
EXPECT_THROW(
64-
(void)model.generate(kValidTestImagePath, 0.5, 0.5, 100, {}, true, ""),
65-
RnExecutorchError);
69+
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5,
70+
100, {}, true, ""),
71+
RnExecutorchError);
6672
}
6773

6874
TEST(InstanceSegGenerateTests, NegativeConfidenceThrows) {
6975
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
7076
nullptr);
71-
EXPECT_THROW((void)model.generate(kValidTestImagePath, -0.1, 0.5, 100, {},
72-
true, kMethodName),
77+
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, -0.1, 0.5,
78+
100, {}, true, kMethodName),
7379
RnExecutorchError);
7480
}
7581

7682
TEST(InstanceSegGenerateTests, ConfidenceAboveOneThrows) {
7783
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
7884
nullptr);
79-
EXPECT_THROW((void)model.generate(kValidTestImagePath, 1.1, 0.5, 100, {},
80-
true, kMethodName),
85+
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 1.1, 0.5,
86+
100, {}, true, kMethodName),
8187
RnExecutorchError);
8288
}
8389

8490
TEST(InstanceSegGenerateTests, NegativeIouThresholdThrows) {
8591
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
8692
nullptr);
87-
EXPECT_THROW((void)model.generate(kValidTestImagePath, 0.5, -0.1, 100, {},
88-
true, kMethodName),
93+
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, -0.1,
94+
100, {}, true, kMethodName),
8995
RnExecutorchError);
9096
}
9197

9298
TEST(InstanceSegGenerateTests, IouThresholdAboveOneThrows) {
9399
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
94100
nullptr);
95-
EXPECT_THROW((void)model.generate(kValidTestImagePath, 0.5, 1.1, 100, {},
96-
true, kMethodName),
101+
EXPECT_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 1.1,
102+
100, {}, true, kMethodName),
97103
RnExecutorchError);
98104
}
99105

100106
TEST(InstanceSegGenerateTests, ValidImageReturnsResults) {
101107
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
102108
nullptr);
103-
auto results =
104-
model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
109+
auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100,
110+
{}, true, kMethodName);
105111
EXPECT_FALSE(results.empty());
106112
}
107113

108114
TEST(InstanceSegGenerateTests, HighThresholdReturnsFewerResults) {
109115
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
110116
nullptr);
111-
auto lowResults =
112-
model.generate(kValidTestImagePath, 0.1, 0.5, 100, {}, true, kMethodName);
113-
auto highResults =
114-
model.generate(kValidTestImagePath, 0.9, 0.5, 100, {}, true, kMethodName);
117+
auto lowResults = model.generateFromString(kValidTestImagePath, 0.1, 0.5, 100,
118+
{}, true, kMethodName);
119+
auto highResults = model.generateFromString(kValidTestImagePath, 0.9, 0.5,
120+
100, {}, true, kMethodName);
115121
EXPECT_GE(lowResults.size(), highResults.size());
116122
}
117123

118124
TEST(InstanceSegGenerateTests, MaxInstancesLimitsResults) {
119125
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
120126
nullptr);
121-
auto results =
122-
model.generate(kValidTestImagePath, 0.1, 0.5, 2, {}, true, kMethodName);
127+
auto results = model.generateFromString(kValidTestImagePath, 0.1, 0.5, 2, {},
128+
true, kMethodName);
123129
EXPECT_LE(results.size(), 2u);
124130
}
125131

@@ -129,8 +135,8 @@ TEST(InstanceSegGenerateTests, MaxInstancesLimitsResults) {
129135
TEST(InstanceSegResultTests, InstancesHaveValidBoundingBoxes) {
130136
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
131137
nullptr);
132-
auto results =
133-
model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
138+
auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100,
139+
{}, true, kMethodName);
134140

135141
for (const auto &inst : results) {
136142
EXPECT_LE(inst.x1, inst.x2);
@@ -143,8 +149,8 @@ TEST(InstanceSegResultTests, InstancesHaveValidBoundingBoxes) {
143149
TEST(InstanceSegResultTests, InstancesHaveValidScores) {
144150
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
145151
nullptr);
146-
auto results =
147-
model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
152+
auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100,
153+
{}, true, kMethodName);
148154

149155
for (const auto &inst : results) {
150156
EXPECT_GE(inst.score, 0.0f);
@@ -155,8 +161,8 @@ TEST(InstanceSegResultTests, InstancesHaveValidScores) {
155161
TEST(InstanceSegResultTests, InstancesHaveValidMasks) {
156162
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
157163
nullptr);
158-
auto results =
159-
model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
164+
auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100,
165+
{}, true, kMethodName);
160166

161167
for (const auto &inst : results) {
162168
EXPECT_GT(inst.maskWidth, 0);
@@ -174,8 +180,8 @@ TEST(InstanceSegResultTests, InstancesHaveValidMasks) {
174180
TEST(InstanceSegResultTests, InstancesHaveValidClassIndices) {
175181
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
176182
nullptr);
177-
auto results =
178-
model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
183+
auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100,
184+
{}, true, kMethodName);
179185

180186
for (const auto &inst : results) {
181187
EXPECT_GE(inst.classIndex, 0);
@@ -191,8 +197,8 @@ TEST(InstanceSegFilterTests, ClassFilterReturnsOnlyMatchingClasses) {
191197
nullptr);
192198
// Filter to class index 0 (PERSON in CocoLabelYolo)
193199
std::vector<int32_t> classIndices = {0};
194-
auto results = model.generate(kValidTestImagePath, 0.3, 0.5, 100,
195-
classIndices, true, kMethodName);
200+
auto results = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100,
201+
classIndices, true, kMethodName);
196202

197203
for (const auto &inst : results) {
198204
EXPECT_EQ(inst.classIndex, 0);
@@ -202,12 +208,12 @@ TEST(InstanceSegFilterTests, ClassFilterReturnsOnlyMatchingClasses) {
202208
TEST(InstanceSegFilterTests, EmptyFilterReturnsAllClasses) {
203209
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
204210
nullptr);
205-
auto allResults =
206-
model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
211+
auto allResults = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100,
212+
{}, true, kMethodName);
207213
EXPECT_FALSE(allResults.empty());
208214

209-
auto noResults = model.generate(kValidTestImagePath, 0.3, 0.5, 100, {50},
210-
true, kMethodName);
215+
auto noResults = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100,
216+
{50}, true, kMethodName);
211217
EXPECT_TRUE(noResults.empty());
212218
}
213219

@@ -217,10 +223,10 @@ TEST(InstanceSegFilterTests, EmptyFilterReturnsAllClasses) {
217223
TEST(InstanceSegMaskTests, LowResMaskIsSmallerThanOriginal) {
218224
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
219225
nullptr);
220-
auto hiRes =
221-
model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, true, kMethodName);
222-
auto loRes = model.generate(kValidTestImagePath, 0.3, 0.5, 100, {}, false,
223-
kMethodName);
226+
auto hiRes = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {},
227+
true, kMethodName);
228+
auto loRes = model.generateFromString(kValidTestImagePath, 0.3, 0.5, 100, {},
229+
false, kMethodName);
224230

225231
if (!hiRes.empty() && !loRes.empty()) {
226232
EXPECT_LE(loRes[0].mask->size(), hiRes[0].mask->size());
@@ -244,9 +250,67 @@ TEST(InstanceSegNMSTests, NMSEnabledReturnsFewerOrEqualResults) {
244250
EXPECT_LE(nmsResults.size(), noNmsResults.size());
245251
}
246252

253+
// ============================================================================
254+
// generateFromPixels tests
255+
// ============================================================================
256+
TEST(InstanceSegPixelTests, ValidPixelDataReturnsResults) {
257+
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
258+
nullptr);
259+
constexpr int32_t width = 4, height = 4, channels = 3;
260+
std::vector<uint8_t> pixelData(width * height * channels, 128);
261+
JSTensorViewIn tensorView{pixelData.data(),
262+
{height, width, channels},
263+
executorch::aten::ScalarType::Byte};
264+
auto results = model.generateFromPixels(tensorView, 0.3, 0.5, 100, {}, true,
265+
kMethodName);
266+
EXPECT_GE(results.size(), 0u);
267+
}
268+
269+
TEST(InstanceSegPixelTests, NegativeConfidenceThrows) {
270+
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
271+
nullptr);
272+
constexpr int32_t width = 4, height = 4, channels = 3;
273+
std::vector<uint8_t> pixelData(width * height * channels, 128);
274+
JSTensorViewIn tensorView{pixelData.data(),
275+
{height, width, channels},
276+
executorch::aten::ScalarType::Byte};
277+
EXPECT_THROW((void)model.generateFromPixels(tensorView, -0.1, 0.5, 100, {},
278+
true, kMethodName),
279+
RnExecutorchError);
280+
}
281+
282+
TEST(InstanceSegPixelTests, ConfidenceAboveOneThrows) {
283+
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
284+
nullptr);
285+
constexpr int32_t width = 4, height = 4, channels = 3;
286+
std::vector<uint8_t> pixelData(width * height * channels, 128);
287+
JSTensorViewIn tensorView{pixelData.data(),
288+
{height, width, channels},
289+
executorch::aten::ScalarType::Byte};
290+
EXPECT_THROW((void)model.generateFromPixels(tensorView, 1.1, 0.5, 100, {},
291+
true, kMethodName),
292+
RnExecutorchError);
293+
}
294+
247295
// ============================================================================
248296
// Inherited method tests
249297
// ============================================================================
298+
TEST(InstanceSegInheritedTests, GetInputShapeWorks) {
299+
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
300+
nullptr);
301+
auto shape = model.getInputShape(kMethodName, 0);
302+
EXPECT_EQ(shape.size(), 4);
303+
EXPECT_EQ(shape[0], 1);
304+
EXPECT_EQ(shape[1], 3);
305+
}
306+
307+
TEST(InstanceSegInheritedTests, GetAllInputShapesWorks) {
308+
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
309+
nullptr);
310+
auto shapes = model.getAllInputShapes(kMethodName);
311+
EXPECT_FALSE(shapes.empty());
312+
}
313+
250314
TEST(InstanceSegInheritedTests, GetMethodMetaWorks) {
251315
BaseInstanceSegmentation model(kValidInstanceSegModelPath, {}, {}, true,
252316
nullptr);
@@ -269,6 +333,6 @@ TEST(InstanceSegNormTests, ValidNormParamsGenerateSucceeds) {
269333
const std::vector<float> std = {0.229f, 0.224f, 0.225f};
270334
BaseInstanceSegmentation model(kValidInstanceSegModelPath, mean, std, true,
271335
nullptr);
272-
EXPECT_NO_THROW((void)model.generate(kValidTestImagePath, 0.5, 0.5, 100, {},
273-
true, kMethodName));
336+
EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, 0.5, 0.5,
337+
100, {}, true, kMethodName));
274338
}

0 commit comments

Comments
 (0)