Skip to content

Commit 19ea798

Browse files
committed
refactor: unify members of in/out tensors, adjust conversions & usage
1 parent 1d7535b commit 19ea798

5 files changed

Lines changed: 28 additions & 28 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorView.h renamed to packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ namespace rnexecutorch {
44

55
using executorch::aten::ScalarType;
66

7-
struct JSTensorView {
7+
struct JSTensorViewIn {
88
void *dataPtr;
9+
std::vector<int32_t> sizes;
910
ScalarType scalarType;
10-
std::vector<int32_t> shape;
1111
};
1212
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ namespace rnexecutorch {
1010
using executorch::runtime::etensor::ScalarType;
1111

1212
struct JSTensorViewOut {
13+
std::shared_ptr<OwningArrayBuffer> dataPtr;
1314
std::vector<int32_t> sizes;
1415
ScalarType scalarType;
15-
std::shared_ptr<OwningArrayBuffer> data;
1616

1717
JSTensorViewOut(std::vector<int32_t> sizes, ScalarType scalarType,
18-
std::shared_ptr<OwningArrayBuffer> data)
19-
: sizes(std::move(sizes)), scalarType(scalarType), data(std::move(data)) {
20-
}
18+
std::shared_ptr<OwningArrayBuffer> dataPtr)
19+
: sizes(std::move(sizes)), scalarType(scalarType),
20+
dataPtr(std::move(dataPtr)) {}
2121
};
2222
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
88
#include <jsi/jsi.h>
9-
#include <rnexecutorch/host_objects/JSTensorView.h>
9+
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
1010
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
1111
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
1212

@@ -43,27 +43,27 @@ inline std::string getValue<std::string>(const jsi::Value &val,
4343
}
4444

4545
template <>
46-
inline JSTensorView getValue<JSTensorView>(const jsi::Value &val,
47-
jsi::Runtime &runtime) {
46+
inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
47+
jsi::Runtime &runtime) {
4848
jsi::Object obj = val.asObject(runtime);
49-
JSTensorView tensorView;
49+
JSTensorViewIn tensorView;
5050

5151
int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber();
5252
tensorView.scalarType = static_cast<ScalarType>(scalarTypeInt);
5353

54-
jsi::Value shapeValue = obj.getProperty(runtime, "shape");
54+
jsi::Value shapeValue = obj.getProperty(runtime, "sizes");
5555
jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime);
56-
size_t shapeDims = shapeArray.size(runtime);
57-
tensorView.shape.reserve(shapeDims);
56+
size_t numShapeDims = shapeArray.size(runtime);
57+
tensorView.sizes.reserve(numShapeDims);
5858

59-
for (size_t i = 0; i < shapeDims; ++i) {
59+
for (size_t i = 0; i < numShapeDims; ++i) {
6060
int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime);
61-
tensorView.shape.push_back(static_cast<int32_t>(dim));
61+
tensorView.sizes.push_back(static_cast<int32_t>(dim));
6262
}
6363

6464
// On JS side, TensorPtr objects hold a 'data' property which should be either
6565
// an ArrayBuffer or TypedArray
66-
jsi::Value dataValue = obj.getProperty(runtime, "data");
66+
jsi::Value dataValue = obj.getProperty(runtime, "dataPtr");
6767
jsi::Object dataObj = dataValue.asObject(runtime);
6868

6969
// Check if it's an ArrayBuffer or TypedArray
@@ -99,17 +99,17 @@ inline JSTensorView getValue<JSTensorView>(const jsi::Value &val,
9999
}
100100

101101
template <>
102-
inline std::vector<JSTensorView>
103-
getValue<std::vector<JSTensorView>>(const jsi::Value &val,
104-
jsi::Runtime &runtime) {
102+
inline std::vector<JSTensorViewIn>
103+
getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val,
104+
jsi::Runtime &runtime) {
105105
jsi::Array array = val.asObject(runtime).asArray(runtime);
106106
size_t length = array.size(runtime);
107-
std::vector<JSTensorView> result;
107+
std::vector<JSTensorViewIn> result;
108108
result.reserve(length);
109109

110110
for (size_t i = 0; i < length; ++i) {
111111
jsi::Value element = array.getValueAtIndex(runtime, i);
112-
result.push_back(getValue<JSTensorView>(element, runtime));
112+
result.push_back(getValue<JSTensorViewIn>(element, runtime));
113113
}
114114
return result;
115115
}
@@ -190,14 +190,14 @@ getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec,
190190
for (size_t i = 0; i < vec.size(); i++) {
191191
jsi::Object tensorObj(runtime);
192192

193-
tensorObj.setProperty(runtime, "shape",
193+
tensorObj.setProperty(runtime, "sizes",
194194
getJsiValue(vec[i]->sizes, runtime));
195195

196196
tensorObj.setProperty(runtime, "scalarType",
197197
jsi::Value(static_cast<int>(vec[i]->scalarType)));
198198

199-
jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->data);
200-
tensorObj.setProperty(runtime, "data", arrayBuffer);
199+
jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->dataPtr);
200+
tensorObj.setProperty(runtime, "dataPtr", arrayBuffer);
201201

202202
array.setValueAtIndex(runtime, i, tensorObj);
203203
}

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
8686
}
8787

8888
std::vector<std::shared_ptr<JSTensorViewOut>>
89-
BaseModel::forwardJS(const std::vector<JSTensorView> tensorViewVec) {
89+
BaseModel::forwardJS(const std::vector<JSTensorViewIn> tensorViewVec) {
9090
if (!module) {
9191
throw std::runtime_error("Model not loaded: Cannot perform forward pass");
9292
}
@@ -102,7 +102,7 @@ BaseModel::forwardJS(const std::vector<JSTensorView> tensorViewVec) {
102102
for (size_t i = 0; i < tensorViewVec.size(); i++) {
103103
const auto &currTensorView = tensorViewVec[i];
104104
auto tensorPtr =
105-
make_tensor_ptr(currTensorView.shape, currTensorView.dataPtr,
105+
make_tensor_ptr(currTensorView.sizes, currTensorView.dataPtr,
106106
currTensorView.scalarType);
107107
tensorPtrs.emplace_back(tensorPtr);
108108
evalues.emplace_back(*tensorPtr); // Dereference TensorPtr to get Tensor,

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <ReactCommon/CallInvoker.h>
77
#include <executorch/extension/module/module.h>
88
#include <jsi/jsi.h>
9-
#include <rnexecutorch/host_objects/JSTensorView.h>
9+
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
1010
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
1111
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
1212

@@ -25,7 +25,7 @@ class BaseModel {
2525
getAllInputShapes(std::string methodName = "forward");
2626

2727
std::vector<std::shared_ptr<JSTensorViewOut>>
28-
forwardJS(std::vector<JSTensorView> tensorViewVec);
28+
forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
2929

3030
protected:
3131
Result<std::vector<EValue>> forward(const EValue &input_value);

0 commit comments

Comments
 (0)