Merged
39 changes: 39 additions & 0 deletions .cspell-wordlist.txt
@@ -129,6 +129,45 @@ metaprogramming
ktlint
lefthook
espeak
KOMODO
DUNGENESS
SHIH
RIDGEBACK
BLUETICK
REDBONE
IBIZAN
OTTERHOUND
BULLTERRIER
BEDLINGTON
SEALYHAM
DANDIE
DINMONT
VIZSLA
CLUMBER
MALINOIS
KOMONDOR
BOUVIER
FLANDRES
APPENZELLER
ENTLEBUCHER
LEONBERG
BRABANCON
LYCAENID
PATAS
INDRI
BARRACOUTA
ABAYA
BOTTLECAP
CHAINLINK
GASMASK
GOLFCART
HOOPSKIRT
LUMBERMILL
PADDLEWHEEL
PICKELHAUBE
CARBONARA
GYROMITRA
BOLETE
NCHW
həlˈO
wˈɜɹld
22 changes: 16 additions & 6 deletions docs/docs/03-hooks/02-computer-vision/useClassification.md
@@ -37,9 +37,13 @@ try {

`useClassification` takes [`ClassificationProps`](../../06-api-reference/interfaces/ClassificationProps.md) that consists of:

- `model` containing [`modelSource`](../../06-api-reference/interfaces/ClassificationProps.md#modelsource).
- `model` - An object containing:
- `modelName` - The name of a built-in model. See [`ClassificationModelSources`](../../06-api-reference/interfaces/ClassificationProps.md) for the list of supported models.
- `modelSource` - The location of the model binary (a URL or a bundled resource).
- An optional [`preventLoad`](../../06-api-reference/interfaces/ClassificationProps.md#preventload) flag that prevents auto-loading of the model.

The hook is generic over the model config — TypeScript automatically infers the correct label type based on the `modelName` you provide. No explicit generic parameter is needed.
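As a rough illustration of how that inference works (the label map and helper types below are made up for the example, not the library's actual internals):

```typescript
// Illustrative only: a label map declared `as const` lets TypeScript
// derive the result's key type, which is what the hook does internally
// from the `modelName` you pass.
const labelMap = { TABBY: 0, BEAGLE: 1, SIAMESE: 2 } as const;

// The keys of the label map become the label type of the result record.
type ResultOf<T> = Record<keyof T, number>;

// A classification result typed against the label map above.
const fakeResult: ResultOf<typeof labelMap> = {
  TABBY: 0.7,
  BEAGLE: 0.2,
  SIAMESE: 0.1,
};

console.log(Object.keys(fakeResult));
```

Accessing `fakeResult.POODLE` here would be a compile-time error, which is exactly the safety the inferred generic gives you.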

Need more details? Check the following resources:

- For detailed information about `useClassification` arguments, see [`useClassification` arguments](../../06-api-reference/functions/useClassification.md#parameters).
@@ -48,11 +48,17 @@

### Returns

`useClassification` returns a [`ClassificationType`](../../06-api-reference/interfaces/ClassificationType.md) object containing:

- `isReady` - Whether the model is loaded and ready to process images.
- `isGenerating` - Whether the model is currently processing an image.
- `error` - An error object if the model failed to load or encountered a runtime error.
- `downloadProgress` - A value between 0 and 1 representing the download progress of the model binary.
- `forward` - A function to run inference on an image.

## Running the model

To run the model, use the [`forward`](../../06-api-reference/interfaces/ClassificationType.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The function returns a promise resolving to an object mapping label keys to their probabilities.
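Because the resolved value is a plain record of label to probability, picking the top prediction takes only a few lines. A self-contained sketch (the result record below is made up):

```typescript
// Pick the most probable [label, probability] pair from a result record.
function topLabel(result: Record<string, number>): [string, number] {
  return Object.entries(result).reduce((best, cur) =>
    cur[1] > best[1] ? cur : best
  );
}

const example = { TABBY: 0.82, BEAGLE: 0.13, SIAMESE: 0.05 };
console.log(topLabel(example)); // -> ['TABBY', 0.82]
```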

:::info
Images from external sources are stored in your application's temporary directory.
@@ -90,6 +100,6 @@ function App() {

## Supported models

| Model | Number of classes | Class list | Quantized |
| ------------------------------------------------------------------------------------------------------ | ----------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------: |
| [efficientnet_v2_s](https://huggingface.co/software-mansion/react-native-executorch-efficientnet-v2-s) | 1000 | [ImageNet1k_v1](https://github.com/software-mansion/react-native-executorch/blob/main/packages/react-native-executorch/src/constants/classification.ts) | Yes |
@@ -33,24 +33,45 @@ All methods of `ClassificationModule` are explained in details here: [`Classific

## Loading the model

Use the static [`fromModelName`](../../06-api-reference/classes/ClassificationModule.md#frommodelname) factory method. It accepts a model config object (e.g. `EFFICIENTNET_V2_S`) and an optional `onDownloadProgress` callback. It returns a promise resolving to a `ClassificationModule` instance.

For more information on loading resources, take a look at the [loading models](../../01-fundamentals/02-loading-models.md) page.
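The `onDownloadProgress` callback receives a value between 0 and 1. A small helper for turning that into a display string (hypothetical, not part of the library):

```typescript
// Convert a 0..1 download-progress value into a clamped percentage string.
function formatProgress(progress: number): string {
  const clamped = Math.min(1, Math.max(0, progress));
  return `${Math.round(clamped * 100)}%`;
}

// Would be passed as the optional callback, e.g. (assuming a model config
// named EFFICIENTNET_V2_S is in scope):
// ClassificationModule.fromModelName(EFFICIENTNET_V2_S, (p) =>
//   console.log(formatProgress(p))
// );
console.log(formatProgress(0.42)); // -> 42%
```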

## Running the model

To run the model, use the [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) method. It accepts one argument — the image to classify. The image can be a remote URL, a local file URI, a base64-encoded image (whole URI or only raw base64), or a [`PixelData`](../../06-api-reference/interfaces/PixelData.md) object (raw RGB pixel buffer). The method returns a promise resolving to an object mapping label keys to their probabilities.

For real-time frame processing, use [`runOnFrame`](../../03-hooks/02-computer-vision/visioncamera-integration.md) instead.

## Using a custom model

Use [`fromCustomModel`](../../06-api-reference/classes/ClassificationModule.md#fromcustommodel) to load your own exported model binary instead of a built-in preset.

```typescript
import { ClassificationModule } from 'react-native-executorch';

const MyLabels = { CAT: 0, DOG: 1, BIRD: 2 } as const;

const classifier = await ClassificationModule.fromCustomModel(
  'https://example.com/custom_classifier.pte',
  { labelMap: MyLabels },
  (progress) => console.log(progress)
);

const result = await classifier.forward(imageUri);
// result is typed as Record<'CAT' | 'DOG' | 'BIRD', number>
```

### Required model contract

The `.pte` binary must expose a single `forward` method with the following interface:

**Input:** one `float32` tensor of shape `[1, 3, H, W]` — a single RGB image, values in `[0, 1]` after optional per-channel normalization `(pixel − mean) / std`. H and W are read from the model's declared input shape at load time.

**Output:** one `float32` tensor of shape `[1, C]` containing raw logits — one value per class, in the same order as the entries in your `labelMap`. Softmax is applied by the native runtime.

Preprocessing (resize → normalize) is handled by the native runtime — your model only needs to produce the raw logits.
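The contract above can be sanity-checked in plain TypeScript; this sketch mirrors the per-channel normalization and the softmax the runtime applies to the raw logits (the values are made up, and this is a reference for the math, not the runtime's actual code):

```typescript
// Per-channel normalization applied before inference: (pixel - mean) / std.
function normalize(pixel: number, mean: number, std: number): number {
  return (pixel - mean) / std;
}

// Softmax over raw logits, as the runtime applies to the [1, C] output.
function softmax(logits: number[]): number[] {
  const max = Math.max(...logits); // subtract max for numerical stability
  const exps = logits.map((l) => Math.exp(l - max));
  const sum = exps.reduce((a, b) => a + b, 0);
  return exps.map((e) => e / sum);
}

const probs = softmax([2.0, 1.0, 0.1]);
console.log(probs); // probabilities summing to 1, same ordering as the logits
```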

## Managing memory

The module is a regular JavaScript object, so its lifetime is managed by the garbage collector. In most cases this is enough and you do not need to free the module's memory yourself, but if you want to release it before the garbage collector steps in, call [`delete`](../../06-api-reference/classes/ClassificationModule.md#delete) on a module object you no longer need. Note that you cannot call [`forward`](../../06-api-reference/classes/ClassificationModule.md#forward) after [`delete`](../../06-api-reference/classes/ClassificationModule.md#delete) unless you load the module again.
@@ -1,18 +1,34 @@
#include "Classification.h"

#include <future>

#include <rnexecutorch/Error.h>
#include <rnexecutorch/ErrorCodes.h>
#include <rnexecutorch/Log.h>

#include <rnexecutorch/data_processing/ImageProcessing.h>
#include <rnexecutorch/data_processing/Numerical.h>
#include <rnexecutorch/models/classification/Constants.h>

namespace rnexecutorch::models::classification {

Classification::Classification(const std::string &modelSource,
std::vector<float> normMean,
std::vector<float> normStd,
std::vector<std::string> labelNames,
std::shared_ptr<react::CallInvoker> callInvoker)
: VisionModel(modelSource, callInvoker),
labelNames_(std::move(labelNames)) {
if (normMean.size() == 3) {
normMean_ = cv::Scalar(normMean[0], normMean[1], normMean[2]);
} else if (!normMean.empty()) {
log(LOG_LEVEL::Warn,
"normMean must have 3 elements — ignoring provided value.");
}
if (normStd.size() == 3) {
normStd_ = cv::Scalar(normStd[0], normStd[1], normStd[2]);
} else if (!normStd.empty()) {
log(LOG_LEVEL::Warn,
"normStd must have 3 elements — ignoring provided value.");
}

auto inputShapes = getAllInputShapes();
if (inputShapes.size() == 0) {
throw RnExecutorchError(RnExecutorchErrorCode::UnexpectedNumInputs,
@@ -37,7 +53,11 @@ Classification::runInference(cv::Mat image) {
cv::Mat preprocessed = preprocess(image);

auto inputTensor =
(normMean_ && normStd_)
? image_processing::getTensorFromMatrix(
modelInputShape_, preprocessed, *normMean_, *normStd_)
: image_processing::getTensorFromMatrix(modelInputShape_,
preprocessed);

auto forwardResult = BaseModel::forward(inputTensor);
if (!forwardResult.ok()) {
@@ -78,13 +98,13 @@ Classification::postprocess(const Tensor &tensor) {
static_cast<const float *>(tensor.const_data_ptr()), tensor.numel());
std::vector<float> resultVec(resultData.begin(), resultData.end());

if (resultVec.size() != labelNames_.size()) {
char errorMessage[100];
std::snprintf(
errorMessage, sizeof(errorMessage),
"Unexpected classification output size, was expecting: %zu classes "
"but got: %zu classes",
labelNames_.size(), resultVec.size());
throw RnExecutorchError(RnExecutorchErrorCode::InvalidModelOutput,
errorMessage);
}
@@ -93,7 +113,7 @@

std::unordered_map<std::string_view, float> probs;
for (std::size_t cl = 0; cl < resultVec.size(); ++cl) {
probs[labelNames_[cl]] = resultVec[cl];
}

return probs;
@@ -1,5 +1,6 @@
#pragma once

#include <optional>
#include <unordered_map>

#include <executorch/extension/tensor/tensor_ptr.h>
@@ -16,7 +17,9 @@ using executorch::extension::TensorPtr;

class Classification : public VisionModel {
public:
Classification(const std::string &modelSource, std::vector<float> normMean,
std::vector<float> normStd,
std::vector<std::string> labelNames,
std::shared_ptr<react::CallInvoker> callInvoker);

[[nodiscard("Registered non-void function")]] std::unordered_map<
@@ -35,9 +38,15 @@ class Classification : public VisionModel {
std::unordered_map<std::string_view, float> runInference(cv::Mat image);

std::unordered_map<std::string_view, float> postprocess(const Tensor &tensor);

std::vector<std::string> labelNames_;
std::optional<cv::Scalar> normMean_;
std::optional<cv::Scalar> normStd_;
};
} // namespace models::classification

REGISTER_CONSTRUCTOR(models::classification::Classification, std::string,
std::vector<float>, std::vector<float>,
std::vector<std::string>,
std::shared_ptr<react::CallInvoker>);
} // namespace rnexecutorch