Skip to content

Commit 51549bb

Browse files
committed
wip: working ish brute force
1 parent 9fa5c9f commit 51549bb

5 files changed

Lines changed: 154 additions & 19 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/bindings/ExecutorchModule.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
#include "ExecutorchModule.h"
22

3+
#include <executorch/extension/module/module.h>
34
#include <rnexecutorch/Log.h>
45
#include <sstream>
56

67
namespace rnexecutorch {
78

89
using ::executorch::extension::Module;
10+
using namespace executorch::aten;
11+
using namespace executorch::extension;
912
using ::executorch::runtime::Error;
1013
using namespace facebook;
1114

@@ -22,6 +25,19 @@ ExecutorchModule::ExecutorchModule(
2225
}
2326
}
2427

28+
int ExecutorchModule::forward(std::vector<JsiTensorView> tensorViewVec) {
29+
auto currTensor = tensorViewVec[0];
30+
auto myTensor =
31+
make_tensor_ptr(currTensor.shape, currTensor.dataPtr, ScalarType::Float);
32+
auto result = module->forward(myTensor);
33+
if (!result.ok()) {
34+
std::string errorStr = std::to_string(static_cast<int>(result.error()));
35+
log(LOG_LEVEL::Debug, errorStr.c_str());
36+
throw std::runtime_error("Failed to run forward! Error: " + errorStr);
37+
}
38+
return 1;
39+
}
40+
2541
std::vector<int32_t> ExecutorchModule::getInputShape(std::string method_name,
2642
int index) {
2743
auto method_meta = module->method_meta(method_name);

packages/react-native-executorch/common/rnexecutorch/bindings/ExecutorchModule.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
#include <ReactCommon/CallInvoker.h>
66
#include <executorch/extension/module/module.h>
77
#include <jsi/jsi.h>
8+
#include <rnexecutorch/utils/JsiTensorView.h>
89

910
namespace rnexecutorch {
1011

@@ -13,6 +14,7 @@ class ExecutorchModule {
1314
ExecutorchModule(const std::string &modelSource,
1415
std::shared_ptr<facebook::react::CallInvoker> callInvoker);
1516
std::vector<int32_t> getInputShape(std::string method_name, int index);
17+
int forward(std::vector<JsiTensorView> tensorViewVec);
1618

1719
protected:
1820
std::unique_ptr<executorch::extension::Module> module;

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,9 @@
33
#include <set>
44
#include <type_traits>
55

6+
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
67
#include <jsi/jsi.h>
8+
#include <rnexecutorch/utils/JsiTensorView.h>
79

810
namespace rnexecutorch::jsiconversion {
911

@@ -34,6 +36,93 @@ inline std::string getValue<std::string>(const jsi::Value &val,
3436
return val.getString(runtime).utf8(runtime);
3537
}
3638

39+
template <>
40+
inline JsiTensorView getValue<JsiTensorView>(const jsi::Value &val,
41+
jsi::Runtime &runtime) {
42+
jsi::Object obj = val.asObject(runtime);
43+
JsiTensorView tensorView;
44+
45+
int scalarTypeInt =
46+
getValue<int>(obj.getProperty(runtime, "scalarType"), runtime);
47+
tensorView.scalarType = static_cast<ScalarType>(scalarTypeInt);
48+
49+
jsi::Value shapeValue = obj.getProperty(runtime, "shape");
50+
jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime);
51+
size_t shapeDims = shapeArray.size(runtime);
52+
tensorView.shape.reserve(shapeDims);
53+
54+
for (size_t i = 0; i < shapeDims; ++i) {
55+
int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime);
56+
tensorView.shape.push_back(static_cast<int32_t>(dim));
57+
}
58+
59+
// On JS side, TensorPtr objects hold a 'data' property which should be either
60+
// an ArrayBuffer or TypedArray
61+
jsi::Value dataValue = obj.getProperty(runtime, "data");
62+
if (!dataValue.isObject()) {
63+
throw jsi::JSError(runtime, "Data must be a typed array or ArrayBuffer");
64+
}
65+
66+
jsi::Object dataObj = dataValue.asObject(runtime);
67+
68+
// Check if it's an ArrayBuffer or TypedArray
69+
if (dataObj.isArrayBuffer(runtime)) {
70+
jsi::ArrayBuffer arrayBuffer = dataObj.getArrayBuffer(runtime);
71+
tensorView.dataPtr = arrayBuffer.data(runtime);
72+
73+
// Get the array size in bytes
74+
size_t arrayBytes = arrayBuffer.size(runtime);
75+
size_t elementBytes =
76+
executorch::runtime::elementSize(tensorView.scalarType);
77+
tensorView.numel = arrayBytes / elementBytes;
78+
79+
} else {
80+
// Handle typed arrays (Float32Array, Int32Array, etc.)
81+
if (dataObj.hasProperty(runtime, "buffer") &&
82+
dataObj.hasProperty(runtime, "byteOffset") &&
83+
dataObj.hasProperty(runtime, "byteLength") &&
84+
dataObj.hasProperty(runtime, "length")) {
85+
86+
tensorView.numel =
87+
getValue<int>(dataObj.getProperty(runtime, "length"), runtime);
88+
89+
jsi::Value bufferValue = dataObj.getProperty(runtime, "buffer");
90+
if (!bufferValue.isObject() ||
91+
!bufferValue.asObject(runtime).isArrayBuffer(runtime)) {
92+
throw jsi::JSError(runtime,
93+
"TypedArray buffer property must be an ArrayBuffer");
94+
}
95+
96+
jsi::ArrayBuffer arrayBuffer =
97+
bufferValue.asObject(runtime).getArrayBuffer(runtime);
98+
size_t byteOffset =
99+
getValue<int>(dataObj.getProperty(runtime, "byteOffset"), runtime);
100+
101+
tensorView.dataPtr =
102+
static_cast<uint8_t *>(arrayBuffer.data(runtime)) + byteOffset;
103+
} else {
104+
throw jsi::JSError(runtime, "Data must be an ArrayBuffer or TypedArray");
105+
}
106+
}
107+
return std::move(tensorView);
108+
}
109+
110+
template <>
111+
inline std::vector<JsiTensorView>
112+
getValue<std::vector<JsiTensorView>>(const jsi::Value &val,
113+
jsi::Runtime &runtime) {
114+
jsi::Array array = val.asObject(runtime).asArray(runtime);
115+
size_t length = array.size(runtime);
116+
std::vector<JsiTensorView> result;
117+
result.reserve(length);
118+
119+
for (size_t i = 0; i < length; ++i) {
120+
jsi::Value element = array.getValueAtIndex(runtime, i);
121+
result.push_back(getValue<JsiTensorView>(element, runtime));
122+
}
123+
return result;
124+
}
125+
37126
template <>
38127
inline std::vector<std::string>
39128
getValue<std::vector<std::string>>(const jsi::Value &val,
@@ -88,6 +177,10 @@ inline jsi::Value getJsiValue(const std::vector<int32_t> &vec,
88177
return jsi::Value(runtime, array);
89178
}
90179

180+
inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) {
181+
return jsi::Value(runtime, val);
182+
}
183+
91184
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {
92185
return jsi::String::createFromAscii(runtime, str);
93186
}

packages/react-native-executorch/src/index.tsx

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,16 @@ declare global {
99
var loadExecutorchModule: (source: string) => any;
1010
}
1111
// eslint-disable no-var
12-
13-
if (global.loadStyleTransfer == null) {
12+
if (
13+
global.loadStyleTransfer == null ||
14+
global.loadImageSegmentation == null ||
15+
global.loadExecutorchModule == null
16+
) {
1417
if (!ETInstallerNativeModule) {
1518
throw new Error(
1619
`Failed to install react-native-executorch: The native module could not be found.`
1720
);
1821
}
19-
2022
ETInstallerNativeModule.install();
2123
}
2224

@@ -34,6 +36,7 @@ export * from './hooks/natural_language_processing/useTextEmbeddings';
3436
export * from './hooks/natural_language_processing/useTokenizer';
3537

3638
export * from './hooks/general/useExecutorchModule';
39+
export * from './modules/general/NewExecutorchModule';
3740

3841
// modules
3942
export * from './modules/computer_vision/ClassificationModule';
Lines changed: 37 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,47 @@
1-
import { ETError, getError } from '../../Error';
2-
import { ETModuleNativeModule } from '../../native/RnExecutorchModules';
1+
import { ResourceFetcher } from '../../utils/ResourceFetcher';
32
import { ResourceSource } from '../../types/common';
4-
import { ETInput } from '../../types/common';
5-
import { getTypeIdentifier } from '../../types/common';
6-
import { BaseModule } from '../BaseModule';
73

8-
export class NewExecutorchModule {
9-
private nativeModule;
4+
type TensorBuffer =
5+
| Float32Array
6+
| Float64Array
7+
| Int8Array
8+
| Int16Array
9+
| Int32Array
10+
| Uint8Array
11+
| Uint16Array
12+
| Uint32Array
13+
| BigInt64Array
14+
| BigUint64Array;
1015

11-
constructor(modelSource: ResourceSource) {
12-
this.nativeModule = global.loadExecutorchModule(modelSource as string);
13-
}
16+
enum ScalarType {
17+
FLOAT16 = 1,
18+
}
19+
20+
interface TensorPtr {
21+
data: TensorBuffer;
22+
shape: number[];
23+
scalarType: ScalarType;
24+
}
25+
26+
export class NewExecutorchModule {
27+
nativeModule: any = null;
1428

15-
static async forward() {
16-
throw Error('Not yet implemented!');
29+
async load(
30+
modelSource: ResourceSource,
31+
onDownloadProgressCallback: (_: number) => void = () => {}
32+
): Promise<void> {
33+
const paths = await ResourceFetcher.fetchMultipleResources(
34+
onDownloadProgressCallback,
35+
modelSource
36+
);
37+
this.nativeModule = global.loadExecutorchModule(paths[0] || '');
1738
}
1839

19-
static async getInputShape(methodName: string, index: number) {
20-
this.nativeModule.getInputShape();
40+
async forward(inputTensor: TensorPtr[]): Promise<void> {
41+
return await this.nativeModule.forward(inputTensor);
2142
}
2243

23-
static async loadForward() {
24-
await this.loadMethod('forward');
44+
async getInputShape(methodName: string, index: number): Promise<number[]> {
45+
return this.nativeModule.getInputShape(methodName, index);
2546
}
2647
}

0 commit comments

Comments
 (0)