|
6 | 6 |
|
7 | 7 | #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> |
8 | 8 | #include <jsi/jsi.h> |
9 | | -#include <rnexecutorch/host_objects/JSTensorView.h> |
| 9 | +#include <rnexecutorch/host_objects/JSTensorViewIn.h> |
10 | 10 | #include <rnexecutorch/host_objects/JSTensorViewOut.h> |
11 | 11 | #include <rnexecutorch/jsi/OwningArrayBuffer.h> |
12 | 12 |
|
@@ -40,27 +40,27 @@ inline std::string getValue<std::string>(const jsi::Value &val, |
40 | 40 | } |
41 | 41 |
|
42 | 42 | template <> |
43 | | -inline JSTensorView getValue<JSTensorView>(const jsi::Value &val, |
44 | | - jsi::Runtime &runtime) { |
| 43 | +inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val, |
| 44 | + jsi::Runtime &runtime) { |
45 | 45 | jsi::Object obj = val.asObject(runtime); |
46 | | - JSTensorView tensorView; |
| 46 | + JSTensorViewIn tensorView; |
47 | 47 |
|
48 | 48 | int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber(); |
49 | 49 | tensorView.scalarType = static_cast<ScalarType>(scalarTypeInt); |
50 | 50 |
|
51 | | - jsi::Value shapeValue = obj.getProperty(runtime, "shape"); |
| 51 | + jsi::Value shapeValue = obj.getProperty(runtime, "sizes"); |
52 | 52 | jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime); |
53 | | - size_t shapeDims = shapeArray.size(runtime); |
54 | | - tensorView.shape.reserve(shapeDims); |
| 53 | + size_t numShapeDims = shapeArray.size(runtime); |
| 54 | + tensorView.sizes.reserve(numShapeDims); |
55 | 55 |
|
56 | | - for (size_t i = 0; i < shapeDims; ++i) { |
| 56 | + for (size_t i = 0; i < numShapeDims; ++i) { |
57 | 57 | int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime); |
58 | | - tensorView.shape.push_back(static_cast<int32_t>(dim)); |
| 58 | + tensorView.sizes.push_back(static_cast<int32_t>(dim)); |
59 | 59 | } |
60 | 60 |
|
61 | 61 | // On JS side, TensorPtr objects hold a 'data' property which should be either |
62 | 62 | // an ArrayBuffer or TypedArray |
63 | | - jsi::Value dataValue = obj.getProperty(runtime, "data"); |
| 63 | + jsi::Value dataValue = obj.getProperty(runtime, "dataPtr"); |
64 | 64 | jsi::Object dataObj = dataValue.asObject(runtime); |
65 | 65 |
|
66 | 66 | // Check if it's an ArrayBuffer or TypedArray |
@@ -96,17 +96,17 @@ inline JSTensorView getValue<JSTensorView>(const jsi::Value &val, |
96 | 96 | } |
97 | 97 |
|
98 | 98 | template <> |
99 | | -inline std::vector<JSTensorView> |
100 | | -getValue<std::vector<JSTensorView>>(const jsi::Value &val, |
101 | | - jsi::Runtime &runtime) { |
| 99 | +inline std::vector<JSTensorViewIn> |
| 100 | +getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val, |
| 101 | + jsi::Runtime &runtime) { |
102 | 102 | jsi::Array array = val.asObject(runtime).asArray(runtime); |
103 | 103 | size_t length = array.size(runtime); |
104 | | - std::vector<JSTensorView> result; |
| 104 | + std::vector<JSTensorViewIn> result; |
105 | 105 | result.reserve(length); |
106 | 106 |
|
107 | 107 | for (size_t i = 0; i < length; ++i) { |
108 | 108 | jsi::Value element = array.getValueAtIndex(runtime, i); |
109 | | - result.push_back(getValue<JSTensorView>(element, runtime)); |
| 109 | + result.push_back(getValue<JSTensorViewIn>(element, runtime)); |
110 | 110 | } |
111 | 111 | return result; |
112 | 112 | } |
@@ -187,14 +187,14 @@ getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec, |
187 | 187 | for (size_t i = 0; i < vec.size(); i++) { |
188 | 188 | jsi::Object tensorObj(runtime); |
189 | 189 |
|
190 | | - tensorObj.setProperty(runtime, "shape", |
| 190 | + tensorObj.setProperty(runtime, "sizes", |
191 | 191 | getJsiValue(vec[i]->sizes, runtime)); |
192 | 192 |
|
193 | 193 | tensorObj.setProperty(runtime, "scalarType", |
194 | 194 | jsi::Value(static_cast<int>(vec[i]->scalarType))); |
195 | 195 |
|
196 | | - jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->data); |
197 | | - tensorObj.setProperty(runtime, "data", arrayBuffer); |
| 196 | + jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->dataPtr); |
| 197 | + tensorObj.setProperty(runtime, "dataPtr", arrayBuffer); |
198 | 198 |
|
199 | 199 | array.setValueAtIndex(runtime, i, tensorObj); |
200 | 200 | } |
|
0 commit comments