diff --git a/docs/docs/executorch-bindings/useExecutorchModule.md b/docs/docs/executorch-bindings/useExecutorchModule.md index d90c9d8b57..7678434b06 100644 --- a/docs/docs/executorch-bindings/useExecutorchModule.md +++ b/docs/docs/executorch-bindings/useExecutorchModule.md @@ -30,19 +30,19 @@ The `modelSource` parameter expects a location string pointing to the model bina ### Returns -| Field | Type | Description | -| :----------------: | :--------------------------------------------------------: | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | -| `error` | string | null | Contains the error message if the model failed to load. | -| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. | -| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. | -| `loadMethod` | `(methodName: string) => Promise` | Loads resources specific to `methodName` into memory before execution. | -| `loadForward` | `() => Promise` | Loads resources specific to `forward` method into memory before execution. Uses `loadMethod` under the hood. | -| `forward` | `(input: ETInput, shape: number[]) => Promise` | Executes the model's forward pass, where `input` is a Javascript typed array and `shape` is an array of integers representing input Tensor shape. The output is a Tensor - raw result of inference. | -| `downloadProgress` | `number` | Represents the download progress as a value between 0 and 1. | +| Field | Type | Description | +| :----------------: | :----------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------: | +| `error` | string | null | Contains the error message if the model failed to load. | +| `isGenerating` | `boolean` | Indicates whether the model is currently processing an inference. | +| `isReady` | `boolean` | Indicates whether the model has successfully loaded and is ready for inference. | +| `loadMethod` | `(methodName: string) => Promise` | Loads resources specific to `methodName` into memory before execution. | +| `loadForward` | `() => Promise` | Loads resources specific to `forward` method into memory before execution. Uses `loadMethod` under the hood. | +| `forward` | `(input: Tensor[] \| Tensor): Promise` | Executes the model's forward pass, where `input` is a `Tensor` or array of tensors `Tensor[]`. Tensor is a compound type consisting of two elements: data and shape. Data is a JavaScript typed array, and shape is an array of integers representing the input tensor shape. | +| `downloadProgress` | `number` | Represents the download progress as a value between 0 and 1. | ## ETInput -The `ETInput` type defines the typed arrays that can be used as inputs in the `forward` method: +The `ETInput` type defines the typed arrays that can be used as data in `Tensor`: - Int8Array - Int32Array @@ -50,6 +50,10 @@ The `ETInput` type defines the typed arrays that can be used as inputs in the `f - Float32Array - Float64Array +## Tensor + +The `Tensor` is a complex type that aggregates both data and shape of the tensor passed to the `forward` method. + ## Errors All functions provided by the `useExecutorchModule` hook are asynchronous and may throw an error. The `ETError` enum includes errors [defined by the ExecuTorch team](https://github.com/pytorch/executorch/blob/main/runtime/core/error.h) and additional errors specified by our library. diff --git a/docs/docs/typescript-api/ExecutorchModule.md b/docs/docs/typescript-api/ExecutorchModule.md index 484515e6d9..0875f31638 100644 --- a/docs/docs/typescript-api/ExecutorchModule.md +++ b/docs/docs/typescript-api/ExecutorchModule.md @@ -13,38 +13,46 @@ import { } from 'react-native-executorch'; // Creating the input array -const shape = [1, 3, 640, 640]; -const input = new Float32Array(1 * 3 * 640 * 640); +const inputShape = [1, 3, 640, 640]; +const inputData = new Float32Array(1 * 3 * 640 * 640); // Loading the model await ExecutorchModule.load(STYLE_TRANSFER_CANDY); // Running the model -const output = await ExecutorchModule.forward(input, shape); +const output = await ExecutorchModule.forward({ + data: inputData, + shape: inputShape, +}); ``` ### Methods -| Method | Type | Description | -| -------------------- | ------------------------------------------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `load` | `(modelSource: ResourceSource): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. | -| `forward` | `(input: ETInput, shape: number[]): Promise` | Executes the model's forward pass, where `input` is a JavaScript typed array and `shape` is an array of integers representing input Tensor shape. The output is a Tensor - raw result of inference. | -| `loadMethod` | `(methodName: string): Promise` | Loads resources specific to `methodName` into memory before execution. | -| `loadForward` | `(): Promise` | Loads resources specific to `forward` method into memory before execution. Uses `loadMethod` under the hood. | -| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. | +| Method | Type | Description | +| -------------------- | ----------------------------------------------------- | ----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `load` | `(modelSource: ResourceSource): Promise` | Loads the model, where `modelSource` is a string that specifies the location of the model binary. | +| `forward` | `(input: Tensor[] \| Tensor): Promise` | Executes the model's forward pass, where `input` is a `Tensor` or array of tensors `Tensor[]`. Tensor is a compound type consisting of two elements: data and shape. Data is a JavaScript typed array, and shape is an array of integers representing the input tensor shape. | +| `loadMethod` | `(methodName: string): Promise` | Loads resources specific to `methodName` into memory before execution. | +| `loadForward` | `(): Promise` | Loads resources specific to `forward` method into memory before execution. Uses `loadMethod` under the hood. | +| `onDownloadProgress` | `(callback: (downloadProgress: number) => void): any` | Subscribe to the download progress event. |
Type definitions ```typescript -type ResourceSource = string | number | object; +export type ResourceSource = string | number | object; -export type ETInput = +type ETInput = | Int8Array | Int32Array | BigInt64Array | Float32Array | Float64Array; + +export interface Tensor { + data: ETInput; + shape: number[]; +} ```
@@ -55,7 +63,7 @@ To load the model, use the `load` method. It accepts the `modelSource` which is ## Running the model -To run the model use the `forward` method. It accepts two arguments: `input` and `shape`. The `input` is a JavaScript typed array, and `shape` is an array of integers representing the input tensor shape. There's no need to explicitly define the input type, as it will automatically be inferred from the typed array you pass to forward method. Outputs from the model, such as classification probabilities, are returned in raw format. +To run the model use the `forward` method. It accepts one argument: `input`. The `input` is a `Tensor` or array of tensors `Tensor[]`. Tensor is a compound type consisting of two elements: data and shape. Data is a JavaScript typed array, and `shape` is an array of integers representing the input tensor shape. There's no need to explicitly define the input type, as it will automatically be inferred from the typed array you pass to forward method. Outputs from the model, such as classification probabilities, are returned in raw format. ## Loading methods diff --git a/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts b/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts index c356c17be5..9b838896c8 100644 --- a/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts +++ b/packages/react-native-executorch/src/modules/general/ExecutorchModule.ts @@ -1,7 +1,7 @@ import { ETError, getError } from '../../Error'; import { ETModuleNativeModule } from '../../native/RnExecutorchModules'; import { ResourceSource } from '../../types/common'; -import { ETInput } from '../../types/common'; +import { Tensor } from '../../types/common'; import { getTypeIdentifier } from '../../types/common'; import { BaseModule } from '../BaseModule'; @@ -12,21 +12,28 @@ export class ExecutorchModule extends BaseModule { return await super.load(modelSource); } - static override async forward(input: ETInput[] | ETInput, shape: number[][]) { + static override async forward(input: Tensor[] | Tensor) { if (!Array.isArray(input)) { input = [input]; } let inputTypeIdentifiers = []; + let shape = []; let modelInputs = []; for (let idx = 0; idx < input.length; idx++) { - let currentInputTypeIdentifier = getTypeIdentifier(input[idx] as ETInput); + const currentInput = input[idx]; + if (!currentInput || !currentInput.data) { + throw new Error('Input tensor is undefined.'); + } + + let currentInputTypeIdentifier = getTypeIdentifier(currentInput.data); if (currentInputTypeIdentifier === -1) { throw new Error(getError(ETError.InvalidArgument)); } + shape.push(currentInput.shape); inputTypeIdentifiers.push(currentInputTypeIdentifier); - modelInputs.push([...(input[idx] as unknown as number[])]); + modelInputs.push([...(currentInput as unknown as number[])]); } try { diff --git a/packages/react-native-executorch/src/types/common.ts b/packages/react-native-executorch/src/types/common.ts index 688ac3869e..54e7fd95b7 100644 --- a/packages/react-native-executorch/src/types/common.ts +++ b/packages/react-native-executorch/src/types/common.ts @@ -15,3 +15,8 @@ export type ETInput = | BigInt64Array | Float32Array | Float64Array; + +export interface Tensor { + data: ETInput; + shape: number[]; +}