Skip to content

Commit 5560f47

Browse files
committed
refactor: unify members of in/out tensors, adjust conversions & usage
1 parent 6f9ce02 commit 5560f47

5 files changed

Lines changed: 28 additions & 28 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorView.h renamed to packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewIn.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@ namespace rnexecutorch {
44

55
using executorch::aten::ScalarType;
66

7-
struct JSTensorView {
7+
struct JSTensorViewIn {
88
void *dataPtr;
9+
std::vector<int32_t> sizes;
910
ScalarType scalarType;
10-
std::vector<int32_t> shape;
1111
};
1212
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/JSTensorViewOut.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,13 @@ namespace rnexecutorch {
1010
using executorch::runtime::etensor::ScalarType;
1111

1212
struct JSTensorViewOut {
13+
std::shared_ptr<OwningArrayBuffer> dataPtr;
1314
std::vector<int32_t> sizes;
1415
ScalarType scalarType;
15-
std::shared_ptr<OwningArrayBuffer> data;
1616

1717
JSTensorViewOut(std::vector<int32_t> sizes, ScalarType scalarType,
18-
std::shared_ptr<OwningArrayBuffer> data)
19-
: sizes(std::move(sizes)), scalarType(scalarType), data(std::move(data)) {
20-
}
18+
std::shared_ptr<OwningArrayBuffer> dataPtr)
19+
: sizes(std::move(sizes)), scalarType(scalarType),
20+
dataPtr(std::move(dataPtr)) {}
2121
};
2222
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
88
#include <jsi/jsi.h>
9-
#include <rnexecutorch/host_objects/JSTensorView.h>
9+
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
1010
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
1111
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
1212

@@ -40,27 +40,27 @@ inline std::string getValue<std::string>(const jsi::Value &val,
4040
}
4141

4242
template <>
43-
inline JSTensorView getValue<JSTensorView>(const jsi::Value &val,
44-
jsi::Runtime &runtime) {
43+
inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
44+
jsi::Runtime &runtime) {
4545
jsi::Object obj = val.asObject(runtime);
46-
JSTensorView tensorView;
46+
JSTensorViewIn tensorView;
4747

4848
int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber();
4949
tensorView.scalarType = static_cast<ScalarType>(scalarTypeInt);
5050

51-
jsi::Value shapeValue = obj.getProperty(runtime, "shape");
51+
jsi::Value shapeValue = obj.getProperty(runtime, "sizes");
5252
jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime);
53-
size_t shapeDims = shapeArray.size(runtime);
54-
tensorView.shape.reserve(shapeDims);
53+
size_t numShapeDims = shapeArray.size(runtime);
54+
tensorView.sizes.reserve(numShapeDims);
5555

56-
for (size_t i = 0; i < shapeDims; ++i) {
56+
for (size_t i = 0; i < numShapeDims; ++i) {
5757
int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime);
58-
tensorView.shape.push_back(static_cast<int32_t>(dim));
58+
tensorView.sizes.push_back(static_cast<int32_t>(dim));
5959
}
6060

6161
// On JS side, TensorPtr objects hold a 'data' property which should be either
6262
// an ArrayBuffer or TypedArray
63-
jsi::Value dataValue = obj.getProperty(runtime, "data");
63+
jsi::Value dataValue = obj.getProperty(runtime, "dataPtr");
6464
jsi::Object dataObj = dataValue.asObject(runtime);
6565

6666
// Check if it's an ArrayBuffer or TypedArray
@@ -96,17 +96,17 @@ inline JSTensorView getValue<JSTensorView>(const jsi::Value &val,
9696
}
9797

9898
template <>
99-
inline std::vector<JSTensorView>
100-
getValue<std::vector<JSTensorView>>(const jsi::Value &val,
101-
jsi::Runtime &runtime) {
99+
inline std::vector<JSTensorViewIn>
100+
getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val,
101+
jsi::Runtime &runtime) {
102102
jsi::Array array = val.asObject(runtime).asArray(runtime);
103103
size_t length = array.size(runtime);
104-
std::vector<JSTensorView> result;
104+
std::vector<JSTensorViewIn> result;
105105
result.reserve(length);
106106

107107
for (size_t i = 0; i < length; ++i) {
108108
jsi::Value element = array.getValueAtIndex(runtime, i);
109-
result.push_back(getValue<JSTensorView>(element, runtime));
109+
result.push_back(getValue<JSTensorViewIn>(element, runtime));
110110
}
111111
return result;
112112
}
@@ -187,14 +187,14 @@ getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec,
187187
for (size_t i = 0; i < vec.size(); i++) {
188188
jsi::Object tensorObj(runtime);
189189

190-
tensorObj.setProperty(runtime, "shape",
190+
tensorObj.setProperty(runtime, "sizes",
191191
getJsiValue(vec[i]->sizes, runtime));
192192

193193
tensorObj.setProperty(runtime, "scalarType",
194194
jsi::Value(static_cast<int>(vec[i]->scalarType)));
195195

196-
jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->data);
197-
tensorObj.setProperty(runtime, "data", arrayBuffer);
196+
jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->dataPtr);
197+
tensorObj.setProperty(runtime, "dataPtr", arrayBuffer);
198198

199199
array.setValueAtIndex(runtime, i, tensorObj);
200200
}

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ BaseModel::getAllInputShapes(std::string methodName) {
8686
}
8787

8888
std::vector<std::shared_ptr<JSTensorViewOut>>
89-
BaseModel::forwardJS(const std::vector<JSTensorView> tensorViewVec) {
89+
BaseModel::forwardJS(const std::vector<JSTensorViewIn> tensorViewVec) {
9090
if (!module) {
9191
throw std::runtime_error("Model not loaded: Cannot perform forward pass");
9292
}
@@ -102,7 +102,7 @@ BaseModel::forwardJS(const std::vector<JSTensorView> tensorViewVec) {
102102
for (size_t i = 0; i < tensorViewVec.size(); i++) {
103103
const auto &currTensorView = tensorViewVec[i];
104104
auto tensorPtr =
105-
make_tensor_ptr(currTensorView.shape, currTensorView.dataPtr,
105+
make_tensor_ptr(currTensorView.sizes, currTensorView.dataPtr,
106106
currTensorView.scalarType);
107107
tensorPtrs.emplace_back(tensorPtr);
108108
evalues.emplace_back(*tensorPtr); // Dereference TensorPtr to get Tensor,

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
#include <ReactCommon/CallInvoker.h>
77
#include <executorch/extension/module/module.h>
88
#include <jsi/jsi.h>
9-
#include <rnexecutorch/host_objects/JSTensorView.h>
9+
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
1010
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
1111
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
1212

@@ -25,7 +25,7 @@ class BaseModel {
2525
getAllInputShapes(std::string methodName = "forward");
2626

2727
std::vector<std::shared_ptr<JSTensorViewOut>>
28-
forwardJS(std::vector<JSTensorView> tensorViewVec);
28+
forwardJS(std::vector<JSTensorViewIn> tensorViewVec);
2929

3030
protected:
3131
Result<std::vector<EValue>> forward(const EValue &input_value);

0 commit comments

Comments
 (0)