1+ #include < algorithm>
12#include < executorch/extension/tensor/tensor.h>
23#include < executorch/runtime/core/exec_aten/exec_aten.h>
34#include < gtest/gtest.h>
45#include < rnexecutorch/Error.h>
56#include < rnexecutorch/host_objects/JSTensorViewIn.h>
6- #include < rnexecutorch/models/semantic_segmentation/Constants.h>
7- #include < rnexecutorch/models/semantic_segmentation/SemanticSegmentation.h>
7+ #include < rnexecutorch/models/semantic_segmentation/BaseSemanticSegmentation.h>
88#include < string>
99#include < vector>
1010
@@ -19,6 +19,18 @@ constexpr auto kValidSemanticSegmentationModelPath =
1919constexpr auto kValidTestImagePath =
2020 " file:///data/local/tmp/rnexecutorch_tests/test_image.jpg" ;
2121
22+ // DeepLab V3 class labels (Pascal VOC)
23+ static const std::vector<std::string> kDeeplabV3Labels = {
24+ " BACKGROUND" , " AEROPLANE" , " BICYCLE" , " BIRD" , " BOAT" ,
25+ " BOTTLE" , " BUS" , " CAR" , " CAT" , " CHAIR" ,
26+ " COW" , " DININGTABLE" , " DOG" , " HORSE" , " MOTORBIKE" ,
27+ " PERSON" , " POTTEDPLANT" , " SHEEP" , " SOFA" , " TRAIN" ,
28+ " TVMONITOR" };
29+
30+ // ImageNet normalization constants
31+ static const std::vector<float > kImageNetMean = {0 .485f , 0 .456f , 0 .406f };
32+ static const std::vector<float > kImageNetStd = {0 .229f , 0 .224f , 0 .225f };
33+
2234static JSTensorViewIn makeRgbView (std::vector<uint8_t > &buf, int32_t h,
2335 int32_t w) {
2436 buf.assign (static_cast <size_t >(h * w * 3 ), 128 );
@@ -30,8 +42,9 @@ static JSTensorViewIn makeRgbView(std::vector<uint8_t> &buf, int32_t h,
3042class SemanticSegmentationForwardTest : public ::testing::Test {
3143protected:
3244 void SetUp () override {
33- model = std::make_unique<SemanticSegmentation>(
34- kValidSemanticSegmentationModelPath , nullptr );
45+ model = std::make_unique<BaseSemanticSegmentation>(
46+ kValidSemanticSegmentationModelPath , kImageNetMean , kImageNetStd ,
47+ kDeeplabV3Labels , nullptr );
3548 auto shapes = model->getAllInputShapes (" forward" );
3649 ASSERT_FALSE (shapes.empty ());
3750 shape = shapes[0 ];
@@ -47,21 +60,24 @@ class SemanticSegmentationForwardTest : public ::testing::Test {
4760 make_tensor_ptr (sizes, dummyData.data (), exec_aten::ScalarType::Float);
4861 }
4962
50- std::unique_ptr<SemanticSegmentation > model;
63+ std::unique_ptr<BaseSemanticSegmentation > model;
5164 std::vector<int32_t > shape;
5265 std::vector<float > dummyData;
5366 std::vector<int32_t > sizes;
5467 TensorPtr inputTensor;
5568};
5669
5770TEST (SemanticSegmentationCtorTests, InvalidPathThrows) {
58- EXPECT_THROW (SemanticSegmentation (" this_file_does_not_exist.pte" , nullptr ),
71+ EXPECT_THROW (BaseSemanticSegmentation (" this_file_does_not_exist.pte" ,
72+ kImageNetMean , kImageNetStd ,
73+ kDeeplabV3Labels , nullptr ),
5974 RnExecutorchError);
6075}
6176
6277TEST (SemanticSegmentationCtorTests, ValidPathDoesntThrow) {
63- EXPECT_NO_THROW (
64- SemanticSegmentation (kValidSemanticSegmentationModelPath , nullptr ));
78+ EXPECT_NO_THROW (BaseSemanticSegmentation (kValidSemanticSegmentationModelPath ,
79+ kImageNetMean , kImageNetStd ,
80+ kDeeplabV3Labels , nullptr ));
6581}
6682
6783TEST_F (SemanticSegmentationForwardTest, ForwardWithValidTensorSucceeds) {
@@ -108,40 +124,52 @@ TEST_F(SemanticSegmentationForwardTest, ForwardAfterUnloadThrows) {
108124// generateFromString tests
109125// ============================================================================
110126TEST (SemanticSegmentationGenerateTests, InvalidImagePathThrows) {
111- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
127+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
128+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
129+ nullptr );
112130 EXPECT_THROW (
113131 (void )model.generateFromString (" nonexistent_image.jpg" , {}, true ),
114132 RnExecutorchError);
115133}
116134
117135TEST (SemanticSegmentationGenerateTests, EmptyImagePathThrows) {
118- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
136+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
137+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
138+ nullptr );
119139 EXPECT_THROW ((void )model.generateFromString (" " , {}, true ), RnExecutorchError);
120140}
121141
122142TEST (SemanticSegmentationGenerateTests, MalformedURIThrows) {
123- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
143+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
144+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
145+ nullptr );
124146 EXPECT_THROW (
125147 (void )model.generateFromString (" not_a_valid_uri://bad" , {}, true ),
126148 RnExecutorchError);
127149}
128150
129151TEST (SemanticSegmentationGenerateTests, ValidImageNoFilterReturnsResult) {
130- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
152+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
153+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
154+ nullptr );
131155 auto result = model.generateFromString (kValidTestImagePath , {}, true );
132156 EXPECT_NE (result.argmax , nullptr );
133157 EXPECT_NE (result.classBuffers , nullptr );
134158}
135159
136160TEST (SemanticSegmentationGenerateTests, ValidImageReturnsAllClasses) {
137- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
161+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
162+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
163+ nullptr );
138164 auto result = model.generateFromString (kValidTestImagePath , {}, true );
139165 ASSERT_NE (result.classBuffers , nullptr );
140166 EXPECT_EQ (result.classBuffers ->size (), 21u );
141167}
142168
143169TEST (SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) {
144- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
170+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
171+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
172+ nullptr );
145173 std::set<std::string, std::less<>> filter = {" PERSON" , " CAT" };
146174 auto result = model.generateFromString (kValidTestImagePath , filter, true );
147175 ASSERT_NE (result.classBuffers , nullptr );
@@ -152,7 +180,9 @@ TEST(SemanticSegmentationGenerateTests, ClassFilterLimitsClassBuffers) {
152180}
153181
154182TEST (SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) {
155- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
183+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
184+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
185+ nullptr );
156186 auto result = model.generateFromString (kValidTestImagePath , {}, false );
157187 EXPECT_NE (result.argmax , nullptr );
158188}
@@ -161,7 +191,9 @@ TEST(SemanticSegmentationGenerateTests, ResizeFalseReturnsResult) {
161191// generateFromPixels tests
162192// ============================================================================
163193TEST (SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) {
164- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
194+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
195+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
196+ nullptr );
165197 std::vector<uint8_t > buf;
166198 auto view = makeRgbView (buf, 64 , 64 );
167199 auto result = model.generateFromPixels (view, {}, true );
@@ -170,7 +202,9 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsNoFilterReturnsResult) {
170202}
171203
172204TEST (SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) {
173- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
205+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
206+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
207+ nullptr );
174208 std::vector<uint8_t > buf;
175209 auto view = makeRgbView (buf, 64 , 64 );
176210 auto result = model.generateFromPixels (view, {}, true );
@@ -179,7 +213,9 @@ TEST(SemanticSegmentationPixelTests, ValidPixelsReturnsAllClasses) {
179213}
180214
181215TEST (SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) {
182- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
216+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
217+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
218+ nullptr );
183219 std::vector<uint8_t > buf;
184220 auto view = makeRgbView (buf, 64 , 64 );
185221 std::set<std::string, std::less<>> filter = {" PERSON" };
@@ -194,32 +230,42 @@ TEST(SemanticSegmentationPixelTests, ClassFilterLimitsClassBuffers) {
194230// Inherited BaseModel tests
195231// ============================================================================
196232TEST (SemanticSegmentationInheritedTests, GetInputShapeWorks) {
197- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
233+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
234+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
235+ nullptr );
198236 auto shape = model.getInputShape (" forward" , 0 );
199237 EXPECT_EQ (shape.size (), 4 );
200238 EXPECT_EQ (shape[0 ], 1 ); // Batch size
201239 EXPECT_EQ (shape[1 ], 3 ); // RGB channels
202240}
203241
204242TEST (SemanticSegmentationInheritedTests, GetAllInputShapesWorks) {
205- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
243+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
244+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
245+ nullptr );
206246 auto shapes = model.getAllInputShapes (" forward" );
207247 EXPECT_FALSE (shapes.empty ());
208248}
209249
210250TEST (SemanticSegmentationInheritedTests, GetMethodMetaWorks) {
211- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
251+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
252+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
253+ nullptr );
212254 auto result = model.getMethodMeta (" forward" );
213255 EXPECT_TRUE (result.ok ());
214256}
215257
216258TEST (SemanticSegmentationInheritedTests, GetMemoryLowerBoundReturnsPositive) {
217- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
259+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
260+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
261+ nullptr );
218262 EXPECT_GT (model.getMemoryLowerBound (), 0u );
219263}
220264
221265TEST (SemanticSegmentationInheritedTests, InputShapeIsSquare) {
222- SemanticSegmentation model (kValidSemanticSegmentationModelPath , nullptr );
266+ BaseSemanticSegmentation model (kValidSemanticSegmentationModelPath ,
267+ kImageNetMean , kImageNetStd , kDeeplabV3Labels ,
268+ nullptr );
223269 auto shape = model.getInputShape (" forward" , 0 );
224270 EXPECT_EQ (shape[2 ], shape[3 ]); // Height == Width for DeepLabV3
225271}
@@ -228,29 +274,18 @@ TEST(SemanticSegmentationInheritedTests, InputShapeIsSquare) {
228274// Constants tests
229275// ============================================================================
230276TEST (SemanticSegmentationConstantsTests, ClassLabelsHas21Entries) {
231- EXPECT_EQ (constants:: kDeeplabV3Resnet50Labels .size (), 21u );
277+ EXPECT_EQ (kDeeplabV3Labels .size (), 21u );
232278}
233279
234280TEST (SemanticSegmentationConstantsTests, ClassLabelsContainExpectedClasses) {
235- auto &labels = constants::kDeeplabV3Resnet50Labels ;
236- bool hasBackground = false ;
237- bool hasPerson = false ;
238- bool hasCat = false ;
239- bool hasDog = false ;
240-
241- for (const auto &label : labels) {
242- if (label == " BACKGROUND" )
243- hasBackground = true ;
244- if (label == " PERSON" )
245- hasPerson = true ;
246- if (label == " CAT" )
247- hasCat = true ;
248- if (label == " DOG" )
249- hasDog = true ;
250- }
281+ const auto &labels = kDeeplabV3Labels ;
282+
283+ auto contains = [&labels](const std::string &target) {
284+ return std::ranges::find (labels, target) != labels.end ();
285+ };
251286
252- EXPECT_TRUE (hasBackground );
253- EXPECT_TRUE (hasPerson );
254- EXPECT_TRUE (hasCat );
255- EXPECT_TRUE (hasDog );
287+ EXPECT_TRUE (contains ( " BACKGROUND " ) );
288+ EXPECT_TRUE (contains ( " PERSON " ) );
289+ EXPECT_TRUE (contains ( " CAT " ) );
290+ EXPECT_TRUE (contains ( " DOG " ) );
256291}
0 commit comments