|
6 | 6 |
|
7 | 7 | #include <executorch/runtime/core/exec_aten/util/scalar_type_util.h> |
8 | 8 | #include <jsi/jsi.h> |
9 | | -#include <rnexecutorch/host_objects/JSTensorView.h> |
| 9 | +#include <rnexecutorch/host_objects/JSTensorViewIn.h> |
10 | 10 | #include <rnexecutorch/host_objects/JSTensorViewOut.h> |
11 | 11 | #include <rnexecutorch/jsi/OwningArrayBuffer.h> |
12 | 12 |
|
@@ -43,27 +43,27 @@ inline std::string getValue<std::string>(const jsi::Value &val, |
43 | 43 | } |
44 | 44 |
|
45 | 45 | template <> |
46 | | -inline JSTensorView getValue<JSTensorView>(const jsi::Value &val, |
47 | | - jsi::Runtime &runtime) { |
| 46 | +inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val, |
| 47 | + jsi::Runtime &runtime) { |
48 | 48 | jsi::Object obj = val.asObject(runtime); |
49 | | - JSTensorView tensorView; |
| 49 | + JSTensorViewIn tensorView; |
50 | 50 |
|
51 | 51 | int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber(); |
52 | 52 | tensorView.scalarType = static_cast<ScalarType>(scalarTypeInt); |
53 | 53 |
|
54 | | - jsi::Value shapeValue = obj.getProperty(runtime, "shape"); |
| 54 | + jsi::Value shapeValue = obj.getProperty(runtime, "sizes"); |
55 | 55 | jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime); |
56 | | - size_t shapeDims = shapeArray.size(runtime); |
57 | | - tensorView.shape.reserve(shapeDims); |
| 56 | + size_t numShapeDims = shapeArray.size(runtime); |
| 57 | + tensorView.sizes.reserve(numShapeDims); |
58 | 58 |
|
59 | | - for (size_t i = 0; i < shapeDims; ++i) { |
| 59 | + for (size_t i = 0; i < numShapeDims; ++i) { |
60 | 60 | int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime); |
61 | | - tensorView.shape.push_back(static_cast<int32_t>(dim)); |
| 61 | + tensorView.sizes.push_back(static_cast<int32_t>(dim)); |
62 | 62 | } |
63 | 63 |
|
64 | 64 | // On JS side, TensorPtr objects hold a 'data' property which should be either |
65 | 65 | // an ArrayBuffer or TypedArray |
66 | | - jsi::Value dataValue = obj.getProperty(runtime, "data"); |
| 66 | + jsi::Value dataValue = obj.getProperty(runtime, "dataPtr"); |
67 | 67 | jsi::Object dataObj = dataValue.asObject(runtime); |
68 | 68 |
|
69 | 69 | // Check if it's an ArrayBuffer or TypedArray |
@@ -99,17 +99,17 @@ inline JSTensorView getValue<JSTensorView>(const jsi::Value &val, |
99 | 99 | } |
100 | 100 |
|
101 | 101 | template <> |
102 | | -inline std::vector<JSTensorView> |
103 | | -getValue<std::vector<JSTensorView>>(const jsi::Value &val, |
104 | | - jsi::Runtime &runtime) { |
| 102 | +inline std::vector<JSTensorViewIn> |
| 103 | +getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val, |
| 104 | + jsi::Runtime &runtime) { |
105 | 105 | jsi::Array array = val.asObject(runtime).asArray(runtime); |
106 | 106 | size_t length = array.size(runtime); |
107 | | - std::vector<JSTensorView> result; |
| 107 | + std::vector<JSTensorViewIn> result; |
108 | 108 | result.reserve(length); |
109 | 109 |
|
110 | 110 | for (size_t i = 0; i < length; ++i) { |
111 | 111 | jsi::Value element = array.getValueAtIndex(runtime, i); |
112 | | - result.push_back(getValue<JSTensorView>(element, runtime)); |
| 112 | + result.push_back(getValue<JSTensorViewIn>(element, runtime)); |
113 | 113 | } |
114 | 114 | return result; |
115 | 115 | } |
@@ -190,14 +190,14 @@ getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec, |
190 | 190 | for (size_t i = 0; i < vec.size(); i++) { |
191 | 191 | jsi::Object tensorObj(runtime); |
192 | 192 |
|
193 | | - tensorObj.setProperty(runtime, "shape", |
| 193 | + tensorObj.setProperty(runtime, "sizes", |
194 | 194 | getJsiValue(vec[i]->sizes, runtime)); |
195 | 195 |
|
196 | 196 | tensorObj.setProperty(runtime, "scalarType", |
197 | 197 | jsi::Value(static_cast<int>(vec[i]->scalarType))); |
198 | 198 |
|
199 | | - jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->data); |
200 | | - tensorObj.setProperty(runtime, "data", arrayBuffer); |
| 199 | + jsi::ArrayBuffer arrayBuffer(runtime, vec[i]->dataPtr); |
| 200 | + tensorObj.setProperty(runtime, "dataPtr", arrayBuffer); |
201 | 201 |
|
202 | 202 | array.setValueAtIndex(runtime, i, tensorObj); |
203 | 203 | } |
|
0 commit comments