
Commit 131eb70 (1 parent: 3863425)

feat: make classification module a vision labeled module

File tree: 11 files changed, +1358 −112 lines

.cspell-wordlist.txt

Lines changed: 39 additions & 0 deletions
@@ -129,6 +129,45 @@ metaprogramming
 ktlint
 lefthook
 espeak
+KOMODO
+DUNGENESS
+SHIH
+RIDGEBACK
+BLUETICK
+REDBONE
+IBIZAN
+OTTERHOUND
+BULLTERRIER
+BEDLINGTON
+SEALYHAM
+DANDIE
+DINMONT
+VIZSLA
+CLUMBER
+MALINOIS
+KOMONDOR
+BOUVIER
+FLANDRES
+APPENZELLER
+ENTLEBUCHER
+LEONBERG
+BRABANCON
+LYCAENID
+PATAS
+INDRI
+BARRACOUTA
+ABAYA
+BOTTLECAP
+CHAINLINK
+GASMASK
+GOLFCART
+HOOPSKIRT
+LUMBERMILL
+PADDLEWHEEL
+PICKELHAUBE
+CARBONARA
+GYROMITRA
+BOLETE
 NCHW
 həlˈO
 wˈɜɹld

docs/docs/03-hooks/02-computer-vision/useClassification.md

Lines changed: 13 additions & 3 deletions
@@ -37,9 +37,13 @@ try {
 
 `useClassification` takes [`ClassificationProps`](../../06-api-reference/interfaces/ClassificationProps.md) that consists of:
 
-- `model` containing [`modelSource`](../../06-api-reference/interfaces/ClassificationProps.md#modelsource).
+- `model` - An object containing:
+  - `modelName` - The name of a built-in model. See [`ClassificationModelSources`](../../06-api-reference/interfaces/ClassificationProps.md) for the list of supported models.
+  - `modelSource` - The location of the model binary (a URL or a bundled resource).
 - An optional flag [`preventLoad`](../../06-api-reference/interfaces/ClassificationProps.md#preventload) which prevents auto-loading of the model.
 
+The hook is generic over the model config — TypeScript automatically infers the correct label type based on the `modelName` you provide. No explicit generic parameter is needed.
+
 You need more details? Check the following resources:
 
 - For detailed information about `useClassification` arguments check this section: [`useClassification` arguments](../../06-api-reference/functions/useClassification.md#parameters).
@@ -48,11 +52,17 @@ You need more details? Check the following resources:
 
 ### Returns
 
-`useClassification` returns an object called `ClassificationType` containing bunch of functions to interact with Classification models. To get more details please read: [`ClassificationType` API Reference](../../06-api-reference/interfaces/ClassificationType.md).
+`useClassification` returns a [`ClassificationType`](../../06-api-reference/interfaces/ClassificationType.md) object containing:
+
+- `isReady` - Whether the model is loaded and ready to process images.
+- `isGenerating` - Whether the model is currently processing an image.
+- `error` - An error object if the model failed to load or encountered a runtime error.
+- `downloadProgress` - A value between 0 and 1 representing the download progress of the model binary.
+- `forward` - A function to run inference on an image.
 
 ## Running the model
 
-To run the model, use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to an object containing categories with their probabilities.
+To run the model, use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to an object mapping label keys to their probabilities.
 
 :::info
 Images from external sources are stored in your application's temporary directory.
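The `forward` call described above resolves to a plain label-to-probability map. As a self-contained illustration (the labels and scores below are made up for the sketch, not real library output), ranking that map takes only a few lines of TypeScript:

```typescript
// Hypothetical forward() result: a label → probability map, as the docs describe.
const result: Record<string, number> = {
  VIZSLA: 0.82,
  KOMONDOR: 0.11,
  MALINOIS: 0.07,
};

// Return the k most probable labels, highest first.
function topK(probs: Record<string, number>, k: number): [string, number][] {
  return Object.entries(probs)
    .sort(([, a], [, b]) => b - a)
    .slice(0, k);
}

console.log(topK(result, 2)); // top-2 [label, probability] pairs
```

Because the result is an ordinary object, any standard `Object.entries`-based processing works on it without library helpers.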

docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md

Lines changed: 32 additions & 11 deletions
@@ -33,24 +33,45 @@ All methods of `ClassificationModule` are explained in details here: [`Classific
 
 ## Loading the model
 
-To create a ready-to-use instance, call the static [`fromModelName`](../../06-api-reference/classes/ClassificationModule.md#frommodelname) factory with the following parameters:
-
-- `namedSources` - Object containing:
-  - `modelName` - Model name identifier.
-  - `modelSource` - Location of the model binary.
-
-- `onDownloadProgress` - Optional callback to track download progress (value between 0 and 1).
-
-The factory returns a promise that resolves to a loaded `ClassificationModule` instance.
+Use the static [`fromModelName`](../../06-api-reference/classes/ClassificationModule.md#frommodelname) factory method. It accepts a model config object (e.g. `EFFICIENTNET_V2_S`) and an optional `onDownloadProgress` callback. It returns a promise resolving to a `ClassificationModule` instance.
 
 For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page.
 
 ## Running the model
 
-To run the model, use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an object containing categories with their probabilities.
+To run the model, use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an object mapping label keys to their probabilities.
 
 For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
 
+## Using a custom model
+
+Use [`fromCustomModel`](../../06-api-reference/classes/ClassificationModule.md#fromcustommodel) to load your own exported model binary instead of a built-in preset.
+
+```typescript
+import { ClassificationModule } from 'react-native-executorch';
+
+const MyLabels = { CAT: 0, DOG: 1, BIRD: 2 } as const;
+
+const classifier = await ClassificationModule.fromCustomModel(
+  'https://example.com/custom_classifier.pte',
+  { labelMap: MyLabels },
+  (progress) => console.log(progress)
+);
+
+const result = await classifier.forward(imageUri);
+// result is typed as Record<'CAT' | 'DOG' | 'BIRD', number>
+```
+
+### Required model contract
+
+The `.pte` binary must expose a single `forward` method with the following interface:
+
+**Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`. H and W are read from the model's declared input shape at load time.
+
+**Output:** one `float32` tensor of shape `[1, C]` containing raw logits — one value per class, in the same order as the entries in your `labelMap`. Softmax is applied by the native runtime.
+
+Preprocessing (resize → normalize) is handled by the native runtime — your model only needs to produce the raw logits.
+
 ## Managing memory
 
-The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ClassificationModule.md#forward) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) after [`delete`](../../06-api-reference/classes/ClassificationModule.md#forward) unless you load the module again.
+The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ClassificationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) after [`delete`](../../06-api-reference/classes/ClassificationModule.md#delete) unless you load the module again.
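The model contract in the docs above says the binary emits raw logits and the native runtime applies softmax before building the label-keyed result. A standalone TypeScript sketch of that final step (the `CAT`/`DOG`/`BIRD` label map and logit values are hypothetical, reused from the custom-model example, not library code):

```typescript
// Numerically stable softmax over a flat logit vector ([1, C] flattened to C values).
function softmax(logits: number[]): number[] {
  const max = Math.max(...logits);
  const exps = logits.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((v) => v / sum);
}

// Pair each probability with its label, in labelMap order, as the contract requires.
const labelMap = { CAT: 0, DOG: 1, BIRD: 2 } as const;
const logits = [2.0, 1.0, 0.1]; // hypothetical raw model output
const probs = softmax(logits);
const result = Object.fromEntries(
  Object.entries(labelMap).map(([label, idx]) => [label, probs[idx]])
);
// Probabilities sum to 1; the largest logit yields the largest probability.
```

The subtraction of the maximum logit before exponentiating is the standard trick to avoid overflow; it does not change the resulting distribution.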

packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp

Lines changed: 20 additions & 8 deletions
@@ -1,18 +1,26 @@
 #include "Classification.h"
 
-#include <future>
-
 #include <rnexecutorch/Error.h>
 #include <rnexecutorch/ErrorCodes.h>
 #include <rnexecutorch/data_processing/ImageProcessing.h>
 #include <rnexecutorch/data_processing/Numerical.h>
-#include <rnexecutorch/models/classification/Constants.h>
 
 namespace rnexecutorch::models::classification {
 
 Classification::Classification(const std::string &modelSource,
+                               std::vector<float> normMean,
+                               std::vector<float> normStd,
+                               std::vector<std::string> labelNames,
                                std::shared_ptr<react::CallInvoker> callInvoker)
-    : VisionModel(modelSource, callInvoker) {
+    : VisionModel(modelSource, callInvoker),
+      labelNames_(std::move(labelNames)) {
+  if (normMean.size() == 3) {
+    normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
+  }
+  if (normStd.size() == 3) {
+    normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]);
+  }
+
   auto inputShapes = getAllInputShapes();
   if (inputShapes.size() == 0) {
     throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
@@ -37,7 +45,11 @@ Classification::runInference(cv::Mat image) {
   cv::Mat preprocessed = preprocess(image);
 
   auto inputTensor =
-      image_processing::getTensorFromMatrix(modelInputShape_, preprocessed);
+      (normMean_ && normStd_)
+          ? image_processing::getTensorFromMatrix(
+                modelInputShape_, preprocessed, *normMean_, *normStd_)
+          : image_processing::getTensorFromMatrix(modelInputShape_,
+                                                  preprocessed);
 
   auto forwardResult = BaseModel::forward(inputTensor);
   if (!forwardResult.ok()) {
@@ -78,13 +90,13 @@ Classification::postprocess(const Tensor &tensor) {
       static_cast<const float *>(tensor.const_data_ptr()), tensor.numel());
   std::vector<float> resultVec(resultData.begin(), resultData.end());
 
-  if (resultVec.size() != constants::kImagenet1kV1Labels.size()) {
+  if (resultVec.size() != labelNames_.size()) {
     char errorMessage[100];
     std::snprintf(
         errorMessage, sizeof(errorMessage),
         "Unexpected classification output size, was expecting: %zu classes "
        "but got: %zu classes",
-        constants::kImagenet1kV1Labels.size(), resultVec.size());
+        labelNames_.size(), resultVec.size());
    throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput,
                            errorMessage);
  }
@@ -93,7 +105,7 @@ Classification::postprocess(const Tensor &tensor) {
 
   std::unordered_map<std::string_view, float> probs;
   for (std::size_t cl = 0; cl < resultVec.size(); ++cl) {
-    probs[constants::kImagenet1kV1Labels[cl]] = resultVec[cl];
+    probs[labelNames_[cl]] = resultVec[cl];
   }
 
   return probs;
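The `normMean_`/`normStd_` fields wired up in this constructor feed the standard per-channel `(pixel - mean) / std` normalization inside `getTensorFromMatrix`. A standalone TypeScript sketch of that arithmetic for a single RGB pixel, using the ImageNet constants the test file passes in:

```typescript
// ImageNet normalization constants (the same values the tests supply).
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

// Normalize one RGB pixel whose channels are already scaled to [0, 1].
function normalizePixel(rgb: [number, number, number]): number[] {
  return rgb.map((v, c) => (v - MEAN[c]) / STD[c]);
}

// A mid-gray pixel (128/255 ≈ 0.502) lands close to zero on every channel.
const normalized = normalizePixel([0.502, 0.502, 0.502]);
```

When the constructor receives mean/std vectors that are not exactly three elements long, the optionals stay empty and this step is skipped entirely, which matches the unnormalized `getTensorFromMatrix` branch in `runInference`.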

packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h

Lines changed: 11 additions & 2 deletions
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <optional>
 #include <unordered_map>
 
 #include <executorch/extension/tensor/tensor_ptr.h>
@@ -16,7 +17,9 @@ using executorch::extension::TensorPtr;
 
 class Classification : public VisionModel {
 public:
-  Classification(const std::string &modelSource,
+  Classification(const std::string &modelSource, std::vector<float> normMean,
+                 std::vector<float> normStd,
+                 std::vector<std::string> labelNames,
                  std::shared_ptr<react::CallInvoker> callInvoker);
 
   [[nodiscard("Registered non-void function")]] std::unordered_map<
@@ -35,9 +38,15 @@ class Classification : public VisionModel {
   std::unordered_map<std::string_view, float> runInference(cv::Mat image);
 
   std::unordered_map<std::string_view, float> postprocess(const Tensor &tensor);
+
+  std::vector<std::string> labelNames_;
+  std::optional<cv::Scalar> normMean_;
+  std::optional<cv::Scalar> normStd_;
 };
 } // namespace models::classification
 
 REGISTER_CONSTRUCTOR(models::classification::Classification, std::string,
+                     std::vector<float>, std::vector<float>,
+                     std::vector<std::string>,
                      std::shared_ptr<react::CallInvoker>);
-} // namespace rnexecutorch
+} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/tests/integration/ClassificationTest.cpp

Lines changed: 44 additions & 13 deletions
@@ -15,6 +15,18 @@ constexpr auto kValidClassificationModelPath = "efficientnet_v2_s_xnnpack.pte";
 constexpr auto kValidTestImagePath =
     "file:///data/local/tmp/rnexecutorch_tests/test_image.jpg";
 
+static std::vector<float> kImagenetNormMean = {0.485f, 0.456f, 0.406f};
+static std::vector<float> kImagenetNormStd = {0.229f, 0.224f, 0.225f};
+
+static std::vector<std::string> getImagenetLabelNames() {
+  std::vector<std::string> names;
+  names.reserve(constants::kImagenet1kV1Labels.size());
+  for (const auto &label : constants::kImagenet1kV1Labels) {
+    names.emplace_back(label);
+  }
+  return names;
+}
+
 // ============================================================================
 // Common tests via typed test suite
 // ============================================================================
@@ -23,11 +35,12 @@ template <> struct ModelTraits<Classification> {
   using ModelType = Classification;
 
   static ModelType createValid() {
-    return ModelType(kValidClassificationModelPath, nullptr);
+    return ModelType(kValidClassificationModelPath, kImagenetNormMean,
+                     kImagenetNormStd, getImagenetLabelNames(), nullptr);
   }
 
   static ModelType createInvalid() {
-    return ModelType("nonexistent.pte", nullptr);
+    return ModelType("nonexistent.pte", {}, {}, {}, nullptr);
   }
 
   static void callGenerate(ModelType &model) {
@@ -46,37 +59,43 @@ INSTANTIATE_TYPED_TEST_SUITE_P(Classification, VisionModelTest,
 // Model-specific tests
 // ============================================================================
 TEST(ClassificationGenerateTests, InvalidImagePathThrows) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   EXPECT_THROW((void)model.generateFromString("nonexistent_image.jpg"),
                RnExecutorchError);
 }
 
 TEST(ClassificationGenerateTests, EmptyImagePathThrows) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   EXPECT_THROW((void)model.generateFromString(""), RnExecutorchError);
 }
 
 TEST(ClassificationGenerateTests, MalformedURIThrows) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   EXPECT_THROW((void)model.generateFromString("not_a_valid_uri://bad"),
                RnExecutorchError);
 }
 
 TEST(ClassificationGenerateTests, ValidImageReturnsResults) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   auto results = model.generateFromString(kValidTestImagePath);
   EXPECT_FALSE(results.empty());
 }
 
 TEST(ClassificationGenerateTests, ResultsHaveCorrectSize) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   auto results = model.generateFromString(kValidTestImagePath);
   auto expectedNumClasses = constants::kImagenet1kV1Labels.size();
   EXPECT_EQ(results.size(), expectedNumClasses);
 }
 
 TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   auto results = model.generateFromString(kValidTestImagePath);
 
   float sum = 0.0f;
@@ -89,7 +108,8 @@ TEST(ClassificationGenerateTests, ResultsContainValidProbabilities) {
 }
 
 TEST(ClassificationGenerateTests, TopPredictionHasReasonableConfidence) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   auto results = model.generateFromString(kValidTestImagePath);
 
   float maxProb = 0.0f;
@@ -101,22 +121,32 @@ TEST(ClassificationGenerateTests, TopPredictionHasReasonableConfidence) {
   EXPECT_GT(maxProb, 0.0f);
 }
 
+TEST(ClassificationGenerateTests, WrongLabelCountThrows) {
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, {"A", "B", "C"}, nullptr);
+  EXPECT_THROW((void)model.generateFromString(kValidTestImagePath),
+               RnExecutorchError);
+}
+
 TEST(ClassificationInheritedTests, GetInputShapeWorks) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   auto shape = model.getInputShape("forward", 0);
   EXPECT_EQ(shape.size(), 4);
   EXPECT_EQ(shape[0], 1);
   EXPECT_EQ(shape[1], 3);
 }
 
 TEST(ClassificationInheritedTests, GetAllInputShapesWorks) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   auto shapes = model.getAllInputShapes("forward");
   EXPECT_FALSE(shapes.empty());
 }
 
 TEST(ClassificationInheritedTests, GetMethodMetaWorks) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   auto result = model.getMethodMeta("forward");
   EXPECT_TRUE(result.ok());
 }
@@ -125,7 +155,8 @@ TEST(ClassificationInheritedTests, GetMethodMetaWorks) {
 // generateFromPixels smoke test
 // ============================================================================
 TEST(ClassificationPixelTests, ValidPixelsReturnsResults) {
-  Classification model(kValidClassificationModelPath, nullptr);
+  Classification model(kValidClassificationModelPath, kImagenetNormMean,
+                       kImagenetNormStd, getImagenetLabelNames(), nullptr);
   std::vector<uint8_t> buf(64 * 64 * 3, 128);
   JSTensorViewIn view{
       buf.data(), {64, 64, 3}, executorch::aten::ScalarType::Byte};
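The new `WrongLabelCountThrows` test exercises the size guard this commit adds to `postprocess`: a three-entry label list against the model's 1000-class output must throw. The same check can be sketched in standalone TypeScript (a plain `Error` stands in for `RnExecutorchError`; this mirrors the C++ logic, it is not library code):

```typescript
// Mirror of the C++ postprocess guard: output size must match the label list.
function postprocess(
  scores: number[],
  labelNames: string[]
): Record<string, number> {
  if (scores.length !== labelNames.length) {
    throw new Error(
      `Unexpected classification output size, was expecting: ` +
        `${labelNames.length} classes but got: ${scores.length} classes`
    );
  }
  // Zip labels with scores in order, like the C++ loop over labelNames_.
  return Object.fromEntries(labelNames.map((name, i) => [name, scores[i]]));
}
```

Failing fast here is what lets the hook surface a clear error instead of silently mislabeling classes when a custom `labelMap` does not match the exported model.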
