Skip to content

Commit e9ed5b0

Browse files
committed
fix: complete scalartype enum & move types to common.ts
1 parent 495639f commit e9ed5b0

2 files changed

Lines changed: 46 additions & 23 deletions

File tree

packages/react-native-executorch/src/modules/general/NewExecutorchModule.ts

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,6 @@
11
import { ResourceFetcher } from '../../utils/ResourceFetcher';
22
import { ResourceSource } from '../../types/common';
3-
4-
type TensorBuffer =
5-
| Float32Array
6-
| Float64Array
7-
| Int8Array
8-
| Int16Array
9-
| Int32Array
10-
| Uint8Array
11-
| Uint16Array
12-
| Uint32Array
13-
| BigInt64Array
14-
| BigUint64Array;
15-
16-
enum ScalarType {
17-
FLOAT16 = 1,
18-
}
19-
20-
interface TensorPtr {
21-
data: TensorBuffer;
22-
shape: number[];
23-
scalarType: ScalarType;
24-
}
3+
import { TensorPtr } from '../../types/common';
254

265
export class NewExecutorchModule {
276
nativeModule: any = null;
@@ -37,7 +16,7 @@ export class NewExecutorchModule {
3716
this.nativeModule = global.loadExecutorchModule(paths[0] || '');
3817
}
3918

40-
async forward(inputTensor: TensorPtr[]): Promise<void> {
19+
async forward(inputTensor: TensorPtr[]): Promise<ArrayBuffer[]> {
4120
return await this.nativeModule.forward(inputTensor);
4221
}
4322

packages/react-native-executorch/src/types/common.ts

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,3 +15,47 @@ export type ETInput =
1515
| BigInt64Array
1616
| Float32Array
1717
| Float64Array;
18+
19+
export enum ScalarType {
20+
BYTE = 0,
21+
CHAR = 1,
22+
SHORT = 2,
23+
INT = 3,
24+
LONG = 4,
25+
HALF = 5,
26+
FLOAT = 6,
27+
DOUBLE = 7,
28+
BOOL = 11,
29+
QINT8 = 12,
30+
QUINT8 = 13,
31+
QINT32 = 14,
32+
QUINT4X2 = 16,
33+
QUINT2X4 = 17,
34+
BITS16 = 22,
35+
FLOAT8E5M2 = 23,
36+
FLOAT8E4M3FN = 24,
37+
FLOAT8E5M2FNUZ = 25,
38+
FLOAT8E4M3FNUZ = 26,
39+
UINT16 = 27,
40+
UINT32 = 28,
41+
UINT64 = 29,
42+
}
43+
44+
export type TensorBuffer =
45+
| ArrayBuffer
46+
| Float32Array
47+
| Float64Array
48+
| Int8Array
49+
| Int16Array
50+
| Int32Array
51+
| Uint8Array
52+
| Uint16Array
53+
| Uint32Array
54+
| BigInt64Array
55+
| BigUint64Array;
56+
57+
export interface TensorPtr {
58+
data: TensorBuffer;
59+
shape: number[];
60+
scalarType: ScalarType;
61+
}

0 commit comments

Comments
 (0)