Skip to content

Commit c527caa

Browse files
tests: create typed tests for vision models concurrent generates
1 parent 2562214 commit c527caa

File tree

5 files changed

+73
-55
lines changed

5 files changed

+73
-55
lines changed

packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "BaseModelTests.h"
2+
#include "VisionModelTests.h"
23
#include <executorch/runtime/core/exec_aten/exec_aten.h>
34
#include <gtest/gtest.h>
45
#include <rnexecutorch/Error.h>
@@ -38,6 +39,8 @@ template <> struct ModelTraits<Classification> {
3839
using ClassificationTypes = ::testing::Types<Classification>;
3940
INSTANTIATE_TYPED_TEST_SUITE_P(Classification, CommonModelTest,
4041
ClassificationTypes);
42+
INSTANTIATE_TYPED_TEST_SUITE_P(Classification, VisionModelTest,
43+
ClassificationTypes);
4144

4245
// ============================================================================
4346
// Model-specific tests

packages/react-native-executorch/common/rnexecutorch/tests/integration/ImageEmbeddingsTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "BaseModelTests.h"
2+
#include "VisionModelTests.h"
23
#include <cmath>
34
#include <executorch/runtime/core/exec_aten/exec_aten.h>
45
#include <gtest/gtest.h>
@@ -39,6 +40,8 @@ template <> struct ModelTraits<ImageEmbeddings> {
3940
using ImageEmbeddingsTypes = ::testing::Types<ImageEmbeddings>;
4041
INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, CommonModelTest,
4142
ImageEmbeddingsTypes);
43+
INSTANTIATE_TYPED_TEST_SUITE_P(ImageEmbeddings, VisionModelTest,
44+
ImageEmbeddingsTypes);
4245

4346
// ============================================================================
4447
// Model-specific tests

packages/react-native-executorch/common/rnexecutorch/tests/integration/ObjectDetectionTest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
#include "BaseModelTests.h"
2+
#include "VisionModelTests.h"
23
#include <executorch/extension/tensor/tensor.h>
34
#include <gtest/gtest.h>
45
#include <rnexecutorch/Error.h>
@@ -57,6 +58,8 @@ template <> struct ModelTraits<ObjectDetection> {
5758
using ObjectDetectionTypes = ::testing::Types<ObjectDetection>;
5859
INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, CommonModelTest,
5960
ObjectDetectionTypes);
61+
INSTANTIATE_TYPED_TEST_SUITE_P(ObjectDetection, VisionModelTest,
62+
ObjectDetectionTypes);
6063

6164
// ============================================================================
6265
// Model-specific tests

packages/react-native-executorch/common/rnexecutorch/tests/integration/StyleTransferTest.cpp

Lines changed: 3 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,10 @@
11
#include "BaseModelTests.h"
2-
#include <atomic>
2+
#include "VisionModelTests.h"
33
#include <executorch/runtime/core/exec_aten/exec_aten.h>
44
#include <gtest/gtest.h>
55
#include <rnexecutorch/Error.h>
66
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
77
#include <rnexecutorch/models/style_transfer/StyleTransfer.h>
8-
#include <thread>
98
#include <variant>
109

1110
using namespace rnexecutorch;
@@ -48,6 +47,8 @@ template <> struct ModelTraits<StyleTransfer> {
4847
using StyleTransferTypes = ::testing::Types<StyleTransfer>;
4948
INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, CommonModelTest,
5049
StyleTransferTypes);
50+
INSTANTIATE_TYPED_TEST_SUITE_P(StyleTransfer, VisionModelTest,
51+
StyleTransferTypes);
5152

5253
// ============================================================================
5354
// generateFromString tests
@@ -79,17 +80,6 @@ TEST(StyleTransferGenerateTests, ValidImageReturnsFilePath) {
7980
EXPECT_GT(pr.height, 0);
8081
}
8182

82-
TEST(StyleTransferGenerateTests, MultipleGeneratesWork) {
83-
StyleTransfer model(kValidStyleTransferModelPath, nullptr);
84-
EXPECT_NO_THROW((void)model.generateFromString(kValidTestImagePath, false));
85-
auto result1 = model.generateFromString(kValidTestImagePath, false);
86-
auto result2 = model.generateFromString(kValidTestImagePath, false);
87-
ASSERT_TRUE(std::holds_alternative<PixelDataResult>(result1));
88-
ASSERT_TRUE(std::holds_alternative<PixelDataResult>(result2));
89-
EXPECT_NE(std::get<PixelDataResult>(result1).dataPtr, nullptr);
90-
EXPECT_NE(std::get<PixelDataResult>(result2).dataPtr, nullptr);
91-
}
92-
9383
// ============================================================================
9484
// generateFromString saveToFile tests
9585
// ============================================================================
@@ -173,48 +163,6 @@ TEST(StyleTransferPixelTests, OutputDimensionsMatchInputSize) {
173163
EXPECT_EQ(pr.height, 64);
174164
}
175165

176-
// ============================================================================
177-
// Thread safety tests
178-
// ============================================================================
179-
TEST(StyleTransferThreadSafetyTests, TwoConcurrentGeneratesDoNotCrash) {
180-
StyleTransfer model(kValidStyleTransferModelPath, nullptr);
181-
std::atomic<int32_t> successCount{0};
182-
std::atomic<int32_t> exceptionCount{0};
183-
184-
auto task = [&]() {
185-
try {
186-
(void)model.generateFromString(kValidTestImagePath, false);
187-
successCount++;
188-
} catch (const RnExecutorchError &) {
189-
exceptionCount++;
190-
}
191-
};
192-
193-
std::thread a(task);
194-
std::thread b(task);
195-
a.join();
196-
b.join();
197-
198-
EXPECT_EQ(successCount + exceptionCount, 2);
199-
}
200-
201-
TEST(StyleTransferThreadSafetyTests,
202-
GenerateAndUnloadConcurrentlyDoesNotCrash) {
203-
StyleTransfer model(kValidStyleTransferModelPath, nullptr);
204-
205-
std::thread a([&]() {
206-
try {
207-
(void)model.generateFromString(kValidTestImagePath, false);
208-
} catch (const RnExecutorchError &) {
209-
}
210-
});
211-
std::thread b([&]() { model.unload(); });
212-
213-
a.join();
214-
b.join();
215-
// If we reach here without crashing, the mutex serialized correctly.
216-
}
217-
218166
// ============================================================================
219167
// Inherited BaseModel tests
220168
// ============================================================================
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#pragma once
2+
3+
#include "BaseModelTests.h"
4+
#include <atomic>
5+
#include <gtest/gtest.h>
6+
#include <rnexecutorch/Error.h>
7+
#include <thread>
8+
9+
namespace model_tests {
10+
11+
template <typename T> class VisionModelTest : public ::testing::Test {
12+
protected:
13+
using Traits = ModelTraits<T>;
14+
using ModelType = typename Traits::ModelType;
15+
};
16+
17+
TYPED_TEST_SUITE_P(VisionModelTest);
18+
19+
TYPED_TEST_P(VisionModelTest, TwoConcurrentGeneratesDoNotCrash) {
20+
SETUP_TRAITS();
21+
auto model = Traits::createValid();
22+
std::atomic<int32_t> successCount{0};
23+
std::atomic<int32_t> exceptionCount{0};
24+
25+
auto task = [&]() {
26+
try {
27+
Traits::callGenerate(model);
28+
successCount++;
29+
} catch (const rnexecutorch::RnExecutorchError &) {
30+
exceptionCount++;
31+
}
32+
};
33+
34+
std::thread a(task);
35+
std::thread b(task);
36+
a.join();
37+
b.join();
38+
39+
EXPECT_EQ(successCount + exceptionCount, 2);
40+
}
41+
42+
TYPED_TEST_P(VisionModelTest, GenerateAndUnloadConcurrentlyDoesNotCrash) {
43+
SETUP_TRAITS();
44+
auto model = Traits::createValid();
45+
46+
std::thread a([&]() {
47+
try {
48+
Traits::callGenerate(model);
49+
} catch (const rnexecutorch::RnExecutorchError &) {
50+
}
51+
});
52+
std::thread b([&]() { model.unload(); });
53+
54+
a.join();
55+
b.join();
56+
}
57+
58+
REGISTER_TYPED_TEST_SUITE_P(VisionModelTest, TwoConcurrentGeneratesDoNotCrash,
59+
GenerateAndUnloadConcurrentlyDoesNotCrash);
60+
61+
} // namespace model_tests

0 commit comments

Comments
 (0)