Skip to content

Commit bf29217

Browse files
benITo47msluszniak
andauthored
test: Semantic Segmentation tests (#989)
## Description During a rename that recently took place - we forgot to change tests. ### Introduces a breaking change? - [ ] Yes - [x] No ### Type of change - [ ] Bug fix (change which fixes an issue) - [ ] New feature (change which adds functionality) - [ ] Documentation update (improves or adds clarity to existing documentation) - [x] Other (chores, tests, code style improvements etc.) ### Tested on - [ ] iOS - [ ] Android ### Testing instructions Run the tests ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues Closes #818 ### Checklist - [x] I have performed a self-review of my code - [x] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [x] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. --> --------- Co-authored-by: Mateusz Słuszniak <mateusz.sluszniak@swmansion.com> Co-authored-by: Mateusz Sluszniak <56299341+msluszniak@users.noreply.github.com>
1 parent 992d04b commit bf29217

File tree

4 files changed

+95
-46
lines changed

4 files changed

+95
-46
lines changed

packages/react-native-executorch/common/rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,12 @@ BaseSemanticSegmentation::computeResult(
218218
}
219219
}
220220

221-
// Filter classes of interest
222221
auto buffersToReturn = std::make_shared<
223222
std::unordered_map<std::string, std::shared_ptr<OwningArrayBuffer>>>();
223+
bool returnAllClasses = classesOfInterest.empty();
224224
for (std::size_t cl = 0; cl < resultClasses.size(); ++cl) {
225-
if (cl < allClasses.size() && classesOfInterest.contains(allClasses[cl])) {
225+
if (cl < allClasses.size() &&
226+
(returnAllClasses || classesOfInterest.contains(allClasses[cl]))) {
226227
(*buffersToReturn)[allClasses[cl]] = resultClasses[cl];
227228
}
228229
}

packages/react-native-executorch/common/rnexecutorch/tests/CMakeLists.txt

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,12 +300,24 @@ add_rn_test(TextToImageTests integration/TextToImageTest.cpp
300300
LIBS tokenizers_deps
301301
)
302302

303+
add_rn_test(SemanticSegmentationTests integration/SemanticSegmentationTest.cpp
304+
SOURCES
305+
${RNEXECUTORCH_DIR}/models/semantic_segmentation/BaseSemanticSegmentation.cpp
306+
${RNEXECUTORCH_DIR}/models/VisionModel.cpp
307+
${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
308+
${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
309+
${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp
310+
${IMAGE_UTILS_SOURCES}
311+
LIBS opencv_deps android
312+
)
313+
303314
add_rn_test(InstanceSegmentationTests integration/InstanceSegmentationTest.cpp
304315
SOURCES
305316
${RNEXECUTORCH_DIR}/models/instance_segmentation/BaseInstanceSegmentation.cpp
306317
${RNEXECUTORCH_DIR}/models/VisionModel.cpp
307318
${RNEXECUTORCH_DIR}/utils/FrameProcessor.cpp
308319
${RNEXECUTORCH_DIR}/utils/FrameExtractor.cpp
320+
${RNEXECUTORCH_DIR}/utils/FrameTransform.cpp
309321
${RNEXECUTORCH_DIR}/utils/computer_vision/Processing.cpp
310322
${IMAGE_UTILS_SOURCES}
311323
LIBS opencv_deps android

packages/react-native-executorch/common/rnexecutorch/tests/integration/SemanticSegmentationTest.cpp

Lines changed: 79 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1+
#include <algorithm>
12
#include <executorch/extension/tensor/tensor.h>
23
#include <executorch/runtime/core/exec_aten/exec_aten.h>
34
#include <gtest/gtest.h>
45
#include <rnexecutorch/Error.h>
56
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
6-
#include <rnexecutorch/models/semantic_segmentation/Constants.h>
7-
#include <rnexecutorch/models/semantic_segmentation/SemanticSegmentation.h>
7+
#include <rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h>
88
#include <string>
99
#include <vector>
1010

@@ -19,6 +19,18 @@ constexpr auto kValidSemanticSegmentationModelPath =
1919
constexpr auto kValidTestImagePath =
2020
"file:///data/local/tmp/rnexecutorch_tests/test_image.jpg";
2121

22+
// DeepLab V3 class labels (Pascal VOC)
23+
static const std::vector<std::string> kDeeplabV3Labels = {
24+
"BACKGROUND", "AEROPLANE", "BICYCLE", "BIRD", "BOAT",
25+
"BOTTLE", "BUS", "CAR", "CAT", "CHAIR",
26+
"COW", "DININGTABLE", "DOG", "HORSE", "MOTORBIKE",
27+
"PERSON", "POTTEDPLANT", "SHEEP", "SOFA", "TRAIN",
28+
"TVMONITOR"};
29+
30+
// ImageNet normalization constants
31+
static const std::vector<float> kImageNetMean = {0.485f, 0.456f, 0.406f};
32+
static const std::vector<float> kImageNetStd = {0.229f, 0.224f, 0.225f};
33+
2234
static JSTensorViewIn makeRgbView(std::vector<uint8_t> &buf, int32_t h,
2335
int32_t w) {
2436
buf.assign(static_cast<size_t>(h * w * 3), 128);
@@ -30,8 +42,9 @@ static JSTensorViewIn makeRgbView(std::vector<uint8_t> &buf, int32_t h,
3042
class SemanticSegmentationForwardTest : public ::testing::Test {
3143
protected:
3244
void SetUp() override {
33-
model = std::make_unique<SemanticSegmentation>(
34-
kValidSemanticSegmentationModelPath, nullptr);
45+
model = std::make_unique<BaseSemanticSegmentation>(
46+
kValidSemanticSegmentationModelPath, kImageNetMean, kImageNetStd,
47+
kDeeplabV3Labels, nullptr);
3548
auto shapes = model->getAllInputShapes("forward");
3649
ASSERT_FALSE(shapes.empty());
3750
shape = shapes[0];
@@ -47,21 +60,24 @@ class SemanticSegmentationForwardTest : public ::testing::Test {
4760
make_tensor_ptr(sizes, dummyData.data(), exec_aten::ScalarType::Float);
4861
}
4962

50-
std::unique_ptr<SemanticSegmentation> model;
63+
std::unique_ptr<BaseSemanticSegmentation> model;
5164
std::vector<int32_t> shape;
5265
std::vector<float> dummyData;
5366
std::vector<int32_t> sizes;
5467
TensorPtr inputTensor;
5568
};
5669

5770
TEST(SemanticSegmentationCtorTests, InvalidPathThrows) {
58-
EXPECT_THROW(SemanticSegmentation("this_file_does_not_exist.pte", nullptr),
71+
EXPECT_THROW(BaseSemanticSegmentation("this_file_does_not_exist.pte",
72+
kImageNetMean, kImageNetStd,
73+
kDeeplabV3Labels, nullptr),
5974
RnExecutorchError);
6075
}
6176

6277
TEST(SemanticSegmentationCtorTests, ValidPathDoesntThrow) {
63-
EXPECT_NO_THROW(
64-
SemanticSegmentation(kValidSemanticSegmentationModelPath, nullptr));
78+
EXPECT_NO_THROW(BaseSemanticSegmentation(kValidSemanticSegmentationModelPath,
79+
kImageNetMean, kImageNetStd,
80+
kDeeplabV3Labels, nullptr));
6581
}
6682

6783
TEST_F(SemanticSegmentationForwardTest, ForwardWithValidTensorSucceeds) {
@@ -108,40 +124,52 @@ TEST_F(SemanticSegmentationForwardTest, ForwardAfterUnloadThrows) {
108124
// generateFromString tests
109125
// ============================================================================
110126
TEST(SemanticSegmentationGenerateTests, InvalidImagePathThrows) {
111-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
127+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
128+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
129+
nullptr);
112130
EXPECT_THROW(
113131
(void)model.generateFromString("nonexistent_image.jpg", {}, true),
114132
RnExecutorchError);
115133
}
116134

117135
TEST(SemanticSegmentationGenerateTests, EmptyImagePathThrows) {
118-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
136+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
137+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
138+
nullptr);
119139
EXPECT_THROW((void)model.generateFromString("", {}, true), RnExecutorchError);
120140
}
121141

122142
TEST(SemanticSegmentationGenerateTests, MalformedURIThrows) {
123-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
143+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
144+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
145+
nullptr);
124146
EXPECT_THROW(
125147
(void)model.generateFromString("not_a_valid_uri://bad", {}, true),
126148
RnExecutorchError);
127149
}
128150

129151
TEST(SemanticSegmentationGenerateTests, ValidImageNoFilterReturnsResult) {
130-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
152+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
153+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
154+
nullptr);
131155
auto result = model.generateFromString(kValidTestImagePath, {}, true);
132156
EXPECT_NE(result.argmax, nullptr);
133157
EXPECT_NE(result.classBuffers, nullptr);
134158
}
135159

136160
TEST(SemanticSegmentationGenerateTests, ValidImageReturnsAllClasses) {
137-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
161+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
162+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
163+
nullptr);
138164
auto result = model.generateFromString(kValidTestImagePath, {}, true);
139165
ASSERT_NE(result.classBuffers, nullptr);
140166
EXPECT_EQ(result.classBuffers->size(), 21u);
141167
}
142168

143169
TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) {
144-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
170+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
171+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
172+
nullptr);
145173
std::set<std::string, std::less<>> filter = {"PERSON", "CAT"};
146174
auto result = model.generateFromString(kValidTestImagePath, filter, true);
147175
ASSERT_NE(result.classBuffers, nullptr);
@@ -152,7 +180,9 @@ TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) {
152180
}
153181

154182
TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) {
155-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
183+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
184+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
185+
nullptr);
156186
auto result = model.generateFromString(kValidTestImagePath, {}, false);
157187
EXPECT_NE(result.argmax, nullptr);
158188
}
@@ -161,7 +191,9 @@ TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) {
161191
// generateFromPixels tests
162192
// ============================================================================
163193
TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) {
164-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
194+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
195+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
196+
nullptr);
165197
std::vector<uint8_t> buf;
166198
auto view = makeRgbView(buf, 64, 64);
167199
auto result = model.generateFromPixels(view, {}, true);
@@ -170,7 +202,9 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) {
170202
}
171203

172204
TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) {
173-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
205+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
206+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
207+
nullptr);
174208
std::vector<uint8_t> buf;
175209
auto view = makeRgbView(buf, 64, 64);
176210
auto result = model.generateFromPixels(view, {}, true);
@@ -179,7 +213,9 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) {
179213
}
180214

181215
TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) {
182-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
216+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
217+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
218+
nullptr);
183219
std::vector<uint8_t> buf;
184220
auto view = makeRgbView(buf, 64, 64);
185221
std::set<std::string, std::less<>> filter = {"PERSON"};
@@ -194,32 +230,42 @@ TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) {
194230
// Inherited BaseModel tests
195231
// ============================================================================
196232
TEST(SemanticSegmentationInheritedTests, GetInputShapeWorks) {
197-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
233+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
234+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
235+
nullptr);
198236
auto shape = model.getInputShape("forward", 0);
199237
EXPECT_EQ(shape.size(), 4);
200238
EXPECT_EQ(shape[0], 1); // Batch size
201239
EXPECT_EQ(shape[1], 3); // RGB channels
202240
}
203241

204242
TEST(SemanticSegmentationInheritedTests, GetAllInputShapesWorks) {
205-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
243+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
244+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
245+
nullptr);
206246
auto shapes = model.getAllInputShapes("forward");
207247
EXPECT_FALSE(shapes.empty());
208248
}
209249

210250
TEST(SemanticSegmentationInheritedTests, GetMethodMetaWorks) {
211-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
251+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
252+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
253+
nullptr);
212254
auto result = model.getMethodMeta("forward");
213255
EXPECT_TRUE(result.ok());
214256
}
215257

216258
TEST(SemanticSegmentationInheritedTests, GetMemoryLowerBoundReturnsPositive) {
217-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
259+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
260+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
261+
nullptr);
218262
EXPECT_GT(model.getMemoryLowerBound(), 0u);
219263
}
220264

221265
TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) {
222-
SemanticSegmentation model(kValidSemanticSegmentationModelPath, nullptr);
266+
BaseSemanticSegmentation model(kValidSemanticSegmentationModelPath,
267+
kImageNetMean, kImageNetStd, kDeeplabV3Labels,
268+
nullptr);
223269
auto shape = model.getInputShape("forward", 0);
224270
EXPECT_EQ(shape[2], shape[3]); // Height == Width for DeepLabV3
225271
}
@@ -228,29 +274,18 @@ TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) {
228274
// Constants tests
229275
// ============================================================================
230276
TEST(SemanticSegmentationConstantsTests, ClassLabelsHas21Entries) {
231-
EXPECT_EQ(constants::kDeeplabV3Resnet50Labels.size(), 21u);
277+
EXPECT_EQ(kDeeplabV3Labels.size(), 21u);
232278
}
233279

234280
TEST(SemanticSegmentationConstantsTests, ClassLabelsContainExpectedClasses) {
235-
auto &labels = constants::kDeeplabV3Resnet50Labels;
236-
bool hasBackground = false;
237-
bool hasPerson = false;
238-
bool hasCat = false;
239-
bool hasDog = false;
240-
241-
for (const auto &label : labels) {
242-
if (label == "BACKGROUND")
243-
hasBackground = true;
244-
if (label == "PERSON")
245-
hasPerson = true;
246-
if (label == "CAT")
247-
hasCat = true;
248-
if (label == "DOG")
249-
hasDog = true;
250-
}
281+
const auto &labels = kDeeplabV3Labels;
282+
283+
auto contains = [&labels](const std::string &target) {
284+
return std::ranges::find(labels, target) != labels.end();
285+
};
251286

252-
EXPECT_TRUE(hasBackground);
253-
EXPECT_TRUE(hasPerson);
254-
EXPECT_TRUE(hasCat);
255-
EXPECT_TRUE(hasDog);
287+
EXPECT_TRUE(contains("BACKGROUND"));
288+
EXPECT_TRUE(contains("PERSON"));
289+
EXPECT_TRUE(contains("CAT"));
290+
EXPECT_TRUE(contains("DOG"));
256291
}

packages/react-native-executorch/common/rnexecutorch/tests/run_tests.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ TEST_EXECUTABLES=(
3434
"LLMTests"
3535
"TextToImageTests"
3636
"InstanceSegmentationTests"
37+
"SemanticSegmentationTests"
3738
"OCRTests"
3839
"VerticalOCRTests"
3940
)

0 commit comments

Comments
 (0)