Commit 6c1dea0

feat!: make classification module a vision labeled module (#986)
## Description

- Refactors `ClassificationModule` into a type-safe, labeled vision module — `forward()` now returns `Record<'GOLDEN_RETRIEVER' | 'TABBY' | ..., number>` instead of `Record<string, number>`
- Adds `Imagenet1kLabel` enum (1000 ImageNet class names) as a TS constant, replacing the hardcoded C++ `Constants.h` label array
- Makes normalization (`normMean`/`normStd`) and label names configurable from JS, passed through to the native `Classification` constructor
- Introduces `fromCustomModel` factory accepting a `ClassificationConfig` with a user-provided `labelMap` and optional preprocessing params
- Adds `VisionLabeledModule` base class for shared label-mapping logic

### Introduces a breaking change?

- [x] Yes
- [ ] No

### Type of change

- [ ] Bug fix (change which fixes an issue)
- [ ] New feature (change which adds functionality)
- [ ] Documentation update (improves or adds clarity to existing documentation)
- [x] Other (chores, tests, code style improvements etc.)

### Tested on

- [x] iOS
- [x] Android

### Testing instructions

Run classification in the computer vision example app.

### Checklist

- [x] I have performed a self-review of my code
- [x] I have commented my code, particularly in hard-to-understand areas
- [x] I have updated the documentation accordingly
- [x] My changes generate no new warnings
1 parent 2d535a4 commit 6c1dea0
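The typed-record behavior described in the commit message can be sketched in plain TypeScript. `MyLabels` and `toLabeledRecord` are illustrative names for this sketch, not the library's API:

```typescript
// A const-asserted label map: its keys form a literal union type,
// so a lookup like result.SIAMESE fails at compile time.
const MyLabels = { GOLDEN_RETRIEVER: 0, TABBY: 1 } as const;

// Turn a raw score vector into a record keyed by the exact label names,
// i.e. Record<'GOLDEN_RETRIEVER' | 'TABBY', number> rather than the
// loosely typed Record<string, number>.
function toLabeledRecord<T extends Record<string, number>>(
  labelMap: T,
  scores: number[]
): Record<keyof T & string, number> {
  const out = {} as Record<keyof T & string, number>;
  for (const [name, index] of Object.entries(labelMap)) {
    out[name as keyof T & string] = scores[index];
  }
  return out;
}

const result = toLabeledRecord(MyLabels, [0.9, 0.1]);
```

The same pattern underlies the built-in presets: the 1000-entry `Imagenet1kLabel` constant plays the role of `MyLabels` here.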

File tree

12 files changed (+1369, −1141 lines)


.cspell-wordlist.txt

Lines changed: 39 additions & 0 deletions

```diff
@@ -129,6 +129,45 @@ metaprogramming
 ktlint
 lefthook
 espeak
+KOMODO
+DUNGENESS
+SHIH
+RIDGEBACK
+BLUETICK
+REDBONE
+IBIZAN
+OTTERHOUND
+BULLTERRIER
+BEDLINGTON
+SEALYHAM
+DANDIE
+DINMONT
+VIZSLA
+CLUMBER
+MALINOIS
+KOMONDOR
+BOUVIER
+FLANDRES
+APPENZELLER
+ENTLEBUCHER
+LEONBERG
+BRABANCON
+LYCAENID
+PATAS
+INDRI
+BARRACOUTA
+ABAYA
+BOTTLECAP
+CHAINLINK
+GASMASK
+GOLFCART
+HOOPSKIRT
+LUMBERMILL
+PADDLEWHEEL
+PICKELHAUBE
+CARBONARA
+GYROMITRA
+BOLETE
 NCHW
 həlˈO
 wˈɜɹld
```

docs/docs/03-hooks/02-computer-vision/useClassification.md

Lines changed: 16 additions & 6 deletions

```diff
@@ -37,9 +37,13 @@ try {
 
 `useClassification` takes [`ClassificationProps`](../../06-api-reference/interfaces/ClassificationProps.md) that consists of:
 
-- `model` containing [`modelSource`](../../06-api-reference/interfaces/ClassificationProps.md#modelsource).
+- `model` - An object containing:
+  - `modelName` - The name of a built-in model. See [`ClassificationModelSources`](../../06-api-reference/interfaces/ClassificationProps.md) for the list of supported models.
+  - `modelSource` - The location of the model binary (a URL or a bundled resource).
 - An optional flag [`preventLoad`](../../06-api-reference/interfaces/ClassificationProps.md#preventload) which prevents auto-loading of the model.
 
+The hook is generic over the model config — TypeScript automatically infers the correct label type based on the `modelName` you provide. No explicit generic parameter is needed.
+
 You need more details? Check the following resources:
 
 - For detailed information about `useClassification` arguments check this section: [`useClassification` arguments](../../06-api-reference/functions/useClassification.md#parameters).
@@ -48,11 +52,17 @@ You need more details? Check the following resources:
 
 ### Returns
 
-`useClassification` returns an object called `ClassificationType` containing bunch of functions to interact with Classification models. To get more details please read: [`ClassificationType` API Reference](../../06-api-reference/interfaces/ClassificationType.md).
+`useClassification` returns a [`ClassificationType`](../../06-api-reference/interfaces/ClassificationType.md) object containing:
+
+- `isReady` - Whether the model is loaded and ready to process images.
+- `isGenerating` - Whether the model is currently processing an image.
+- `error` - An error object if the model failed to load or encountered a runtime error.
+- `downloadProgress` - A value between 0 and 1 representing the download progress of the model binary.
+- `forward` - A function to run inference on an image.
 
 ## Running the model
 
-To run the model, use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to an object containing categories with their probabilities.
+To run the model, use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to an object mapping label keys to their probabilities.
 
 :::info
 Images from external sources are stored in your application's temporary directory.
@@ -90,6 +100,6 @@ function App() {
 
 ## Supported models
 
-| Model | Number of classes | Class list | Quantized |
-| ------------------------------------------------------------------------------------------------------ | ----------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------: |
-| [efficientnet_v2_s](https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s) | 1000 | [ImageNet1k_v1](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/common/rnexecutorch/models/classification/Constants.h) | Yes |
+| Model | Number of classes | Class list | Quantized |
+| ------------------------------------------------------------------------------------------------------ | ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------: |
+| [efficientnet_v2_s](https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s) | 1000 | [ImageNet1k_v1](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/classification.ts) | Yes |
```
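The label-type inference described in these docs can be illustrated with a small stand-in. The real hook is backed by a native module and its `forward` returns a promise; `useClassificationSketch` below is a hypothetical, synchronous sketch, not the library's API:

```typescript
// Hypothetical config shape; the real ClassificationProps differ.
interface SketchConfig<L extends Record<string, number>> {
  modelName: string;
  labelMap: L;
}

// The function is generic over the config, so the returned record's keys
// are inferred from labelMap with no explicit type parameter supplied.
function useClassificationSketch<L extends Record<string, number>>(
  config: SketchConfig<L>
) {
  return {
    isReady: true,
    forward: (_image: string): Record<keyof L & string, number> => {
      const keys = Object.keys(config.labelMap) as (keyof L & string)[];
      const out = {} as Record<keyof L & string, number>;
      for (const k of keys) out[k] = 1 / keys.length; // uniform placeholder
      return out;
    },
  };
}

const pets = useClassificationSketch({
  modelName: 'my-model',
  labelMap: { CAT: 0, DOG: 1 } as const,
});
const petResult = pets.forward('file://photo.jpg');
// petResult.CAT type-checks; petResult.HAMSTER would be a compile error.
```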

docs/docs/04-typescript-api/02-computer-vision/ClassificationModule.md

Lines changed: 32 additions & 11 deletions

````diff
@@ -33,24 +33,45 @@ All methods of `ClassificationModule` are explained in details here: [`Classific
 
 ## Loading the model
 
-To create a ready-to-use instance, call the static [`fromModelName`](../../06-api-reference/classes/ClassificationModule.md#frommodelname) factory with the following parameters:
-
-- `namedSources` - Object containing:
-  - `modelName` - Model name identifier.
-  - `modelSource` - Location of the model binary.
-
-- `onDownloadProgress` - Optional callback to track download progress (value between 0 and 1).
-
-The factory returns a promise that resolves to a loaded `ClassificationModule` instance.
+Use the static [`fromModelName`](../../06-api-reference/classes/ClassificationModule.md#frommodelname) factory method. It accepts a model config object (e.g. `EFFICIENTNET_V2_S`) and an optional `onDownloadProgress` callback. It returns a promise resolving to a `ClassificationModule` instance.
 
 For more information on loading resources, take a look at [loading models](../../01-fundamentals/02-loading-models.md) page.
 
 ## Running the model
 
-To run the model, use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an object containing categories with their probabilities.
+To run the model, use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an object mapping label keys to their probabilities.
 
 For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.
 
+## Using a custom model
+
+Use [`fromCustomModel`](../../06-api-reference/classes/ClassificationModule.md#fromcustommodel) to load your own exported model binary instead of a built-in preset.
+
+```typescript
+import { ClassificationModule } from 'react-native-executorch';
+
+const MyLabels = { CAT: 0, DOG: 1, BIRD: 2 } as const;
+
+const classifier = await ClassificationModule.fromCustomModel(
+  'https://example.com/custom_classifier.pte',
+  { labelMap: MyLabels },
+  (progress) => console.log(progress)
+);
+
+const result = await classifier.forward(imageUri);
+// result is typed as Record<'CAT' | 'DOG' | 'BIRD', number>
+```
+
+### Required model contract
+
+The `.pte` binary must expose a single `forward` method with the following interface:
+
+**Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`. H and W are read from the model's declared input shape at load time.
+
+**Output:** one `float32` tensor of shape `[1, C]` containing raw logits — one value per class, in the same order as the entries in your `labelMap`. Softmax is applied by the native runtime.
+
+Preprocessing (resize → normalize) is handled by the native runtime — your model only needs to produce the raw logits.
+
 ## Managing memory
 
-The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ClassificationModule.md#forward) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) after [`delete`](../../06-api-reference/classes/ClassificationModule.md#forward) unless you load the module again.
+The module is a regular JavaScript object, and as such its lifespan will be managed by the garbage collector. In most cases this should be enough, and you should not worry about freeing the memory of the module yourself, but in some cases you may want to release the memory occupied by the module before the garbage collector steps in. In this case use the method [`delete`](../../06-api-reference/classes/ClassificationModule.md#delete) on the module object you will no longer use, and want to remove from the memory. Note that you cannot use [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) after [`delete`](../../06-api-reference/classes/ClassificationModule.md#delete) unless you load the module again.
````
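The per-channel normalization in the model contract, `(pixel − mean) / std` applied after scaling to `[0, 1]`, amounts to the following. The mean/std values here are the commonly used ImageNet defaults, shown only as an illustration of what `normMean`/`normStd` configure:

```typescript
// Common ImageNet normalization constants (illustrative defaults).
const MEAN = [0.485, 0.456, 0.406];
const STD = [0.229, 0.224, 0.225];

// Scale an 8-bit RGB pixel to [0, 1], then normalize each channel:
// value = (pixel / 255 - mean[c]) / std[c]
function normalizePixel(rgb: [number, number, number]): number[] {
  return rgb.map((v, c) => (v / 255 - MEAN[c]) / STD[c]);
}

const normalized = normalizePixel([124, 116, 104]);
```

In the library this step runs natively over the whole image tensor; the sketch only makes the arithmetic explicit for a single pixel.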

packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.cpp

Lines changed: 28 additions & 8 deletions

```diff
@@ -1,18 +1,34 @@
 #include "Classification.h"
 
-#include <future>
-
 #include <rnexecutorch/Error.h>
 #include <rnexecutorch/ErrorCodes.h>
+#include <rnexecutorch/Log.h>
+
 #include <rnexecutorch/data_processing/ImageProcessing.h>
 #include <rnexecutorch/data_processing/Numerical.h>
-#include <rnexecutorch/models/classification/Constants.h>
 
 namespace rnexecutorch::models::classification {
 
 Classification::Classification(const std::string &modelSource,
+                               std::vector<float> normMean,
+                               std::vector<float> normStd,
+                               std::vector<std::string> labelNames,
                                std::shared_ptr<react::CallInvoker> callInvoker)
-    : VisionModel(modelSource, callInvoker) {
+    : VisionModel(modelSource, callInvoker),
+      labelNames_(std::move(labelNames)) {
+  if (normMean.size() == 3) {
+    normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
+  } else if (!normMean.empty()) {
+    log(LOG_LEVEL::Warn,
+        "normMean must have 3 elements — ignoring provided value.");
+  }
+  if (normStd.size() == 3) {
+    normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]);
+  } else if (!normStd.empty()) {
+    log(LOG_LEVEL::Warn,
+        "normStd must have 3 elements — ignoring provided value.");
+  }
+
   auto inputShapes = getAllInputShapes();
   if (inputShapes.size() == 0) {
     throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
@@ -37,7 +53,11 @@ Classification::runInference(cv::Mat image)
   cv::Mat preprocessed = preprocess(image);
 
   auto inputTensor =
-      image_processing::getTensorFromMatrix(modelInputShape_, preprocessed);
+      (normMean_ && normStd_)
+          ? image_processing::getTensorFromMatrix(
+                modelInputShape_, preprocessed, *normMean_, *normStd_)
+          : image_processing::getTensorFromMatrix(modelInputShape_,
+                                                  preprocessed);
 
   auto forwardResult = BaseModel::forward(inputTensor);
   if (!forwardResult.ok()) {
@@ -78,13 +98,13 @@ Classification::postprocess(const Tensor &tensor)
       static_cast<const float *>(tensor.const_data_ptr()), tensor.numel());
   std::vector<float> resultVec(resultData.begin(), resultData.end());
 
-  if (resultVec.size() != constants::kImagenet1kV1Labels.size()) {
+  if (resultVec.size() != labelNames_.size()) {
     char errorMessage[100];
     std::snprintf(
         errorMessage, sizeof(errorMessage),
         "Unexpected classification output size, was expecting: %zu classes "
         "but got: %zu classes",
-        constants::kImagenet1kV1Labels.size(), resultVec.size());
+        labelNames_.size(), resultVec.size());
     throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput,
                             errorMessage);
   }
@@ -93,7 +113,7 @@ Classification::postprocess(const Tensor &tensor)
 
   std::unordered_map<std::string_view, float> probs;
   for (std::size_t cl = 0; cl < resultVec.size(); ++cl) {
-    probs[constants::kImagenet1kV1Labels[cl]] = resultVec[cl];
+    probs[labelNames_[cl]] = resultVec[cl];
   }
 
   return probs;
```
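The reworked `postprocess` validates the output size against `labelNames_` and maps each score to its label. A TypeScript sketch of the same flow, with an explicit softmax step (the native code applies softmax via its numerical helpers; the labels and function names here are illustrative):

```typescript
// Numerically stable softmax: shift by the max before exponentiating.
function softmax(logits: number[]): number[] {
  const max = Math.max(...logits);
  const exps = logits.map((v) => Math.exp(v - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((v) => v / sum);
}

// Mirror of the C++ postprocess: check the class count against the
// label list, then build a label -> probability map.
function postprocess(
  logits: number[],
  labelNames: string[]
): Record<string, number> {
  if (logits.length !== labelNames.length) {
    throw new Error(
      `Unexpected classification output size, was expecting: ` +
        `${labelNames.length} classes but got: ${logits.length} classes`
    );
  }
  const probs = softmax(logits);
  const out: Record<string, number> = {};
  labelNames.forEach((name, i) => (out[name] = probs[i]));
  return out;
}

const scores = postprocess([2, 1, 0], ['CAT', 'DOG', 'BIRD']);
```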

packages/react-native-executorch/common/rnexecutorch/models/classification/Classification.h

Lines changed: 11 additions & 2 deletions

```diff
@@ -1,5 +1,6 @@
 #pragma once
 
+#include <optional>
 #include <unordered_map>
 
 #include <executorch/extension/tensor/tensor_ptr.h>
@@ -16,7 +17,9 @@ using executorch::extension::TensorPtr;
 
 class Classification : public VisionModel {
 public:
-  Classification(const std::string &modelSource,
+  Classification(const std::string &modelSource, std::vector<float> normMean,
+                 std::vector<float> normStd,
+                 std::vector<std::string> labelNames,
                  std::shared_ptr<react::CallInvoker> callInvoker);
 
   [[nodiscard("Registered non-void function")]] std::unordered_map<
@@ -35,9 +38,15 @@ class Classification : public VisionModel {
   std::unordered_map<std::string_view, float> runInference(cv::Mat image);
 
   std::unordered_map<std::string_view, float> postprocess(const Tensor &tensor);
+
+  std::vector<std::string> labelNames_;
+  std::optional<cv::Scalar> normMean_;
+  std::optional<cv::Scalar> normStd_;
 };
 } // namespace models::classification
 
 REGISTER_CONSTRUCTOR(models::classification::Classification, std::string,
+                     std::vector<float>, std::vector<float>,
+                     std::vector<std::string>,
                      std::shared_ptr<react::CallInvoker>);
-} // namespace rnexecutorch
+} // namespace rnexecutorch
```
