-
Notifications
You must be signed in to change notification settings - Fork 71
Expand file tree
/
Copy pathJsiConversions.h
More file actions
262 lines (217 loc) · 9.26 KB
/
JsiConversions.h
File metadata and controls
262 lines (217 loc) · 9.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
#pragma once
#include <cmath>
#include <limits>
#include <set>
#include <type_traits>
#include <unordered_map>
#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <jsi/jsi.h>
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
#include <rnexecutorch/TypeConcepts.h>
#include <rnexecutorch/models/object_detection/Constants.h>
#include <rnexecutorch/models/object_detection/Utils.h>
namespace rnexecutorch::jsiconversion {
using namespace facebook;
// Conversion from jsi to C++ types --------------------------------------------
// Primary template for JS -> C++ conversion. It is only declared here; the
// constrained numeric overload and the explicit specializations below supply
// the actual conversions for each supported T.
template <class T> T getValue(const jsi::Value &val, jsi::Runtime &runtime);
// Converts a JS number into the arithmetic type T.
//
// For integral T the value is truncated toward zero (exactly as static_cast
// does), but only after checking that the truncated value is representable in
// T: converting NaN, +/-Infinity or an out-of-range double to an integer type
// is undefined behavior in C++, and this number comes directly from JS.
//
// Throws (via asNumber) if val is not a number, and jsi::JSError if an
// integral T cannot represent the value.
template <typename T>
  requires IsNumeric<T>
inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) {
  static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value,
                "Only integral and floating-point types are supported");
  const double number = val.asNumber();
  if constexpr (std::is_integral_v<T>) {
    // [lower, upper) is exactly the set of truncated doubles that fit in T.
    // Both bounds are zero or powers of two, so they are exact doubles.
    const double upper = std::ldexp(1.0, std::numeric_limits<T>::digits);
    const double lower = std::is_signed_v<T> ? -upper : 0.0;
    const double truncated = std::trunc(number);
    // Written as a negated conjunction so that NaN is rejected as well.
    if (!(truncated >= lower && truncated < upper)) {
      throw jsi::JSError(
          runtime, "Number is out of range for the requested integer type");
    }
  }
  return static_cast<T>(number);
}
// Converts a JS boolean into a C++ bool using JSI's checked accessor asBool().
// Booleans are primitives, so the runtime is not consulted.
template <>
inline bool getValue<bool>(const jsi::Value &val,
                           [[maybe_unused]] jsi::Runtime &runtime) {
  const bool flag = val.asBool();
  return flag;
}
// Converts a JS string into a UTF-8 encoded std::string.
//
// Uses the checked accessor asString(), consistent with the asNumber() /
// asBool() / asObject() calls in this file. getValue's getString() variant
// only assert()s the value kind, which makes a non-string argument undefined
// behavior in release builds; asString() reports it as a JSI exception.
template <>
inline std::string getValue<std::string>(const jsi::Value &val,
                                         jsi::Runtime &runtime) {
  return val.asString(runtime).utf8(runtime);
}
// Builds a tensor view from a JS object with the properties
//   scalarType: number, sizes: number[], dataPtr: ArrayBuffer | TypedArray.
//
// Only a raw pointer into the JS-owned buffer is stored in the view, so the
// buffer must outlive the returned JSTensorViewIn.
// NOTE(review): confirm callers keep the JS tensor alive for that long.
//
// Throws jsi::JSError if dataPtr is neither an ArrayBuffer nor a
// TypedArray-like object, or if a TypedArray's byteOffset lies outside its
// backing buffer.
template <>
inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
                                               jsi::Runtime &runtime) {
  jsi::Object obj = val.asObject(runtime);
  JSTensorViewIn tensorView;
  int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber();
  tensorView.scalarType = static_cast<ScalarType>(scalarTypeInt);

  jsi::Value shapeValue = obj.getProperty(runtime, "sizes");
  jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime);
  size_t numShapeDims = shapeArray.size(runtime);
  tensorView.sizes.reserve(numShapeDims);
  for (size_t i = 0; i < numShapeDims; ++i) {
    int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime);
    tensorView.sizes.push_back(static_cast<int32_t>(dim));
  }

  // On the JS side the tensor's 'dataPtr' property holds either an
  // ArrayBuffer or a TypedArray view over one.
  jsi::Value dataValue = obj.getProperty(runtime, "dataPtr");
  jsi::Object dataObj = dataValue.asObject(runtime);
  if (dataObj.isArrayBuffer(runtime)) {
    jsi::ArrayBuffer arrayBuffer = dataObj.getArrayBuffer(runtime);
    tensorView.dataPtr = arrayBuffer.data(runtime);
    return tensorView;
  }

  // Typed arrays (Float32Array, Int32Array, ...) are recognized only by the
  // presence of these properties, so their values are validated below.
  const bool isValidTypedArray = dataObj.hasProperty(runtime, "buffer") &&
                                 dataObj.hasProperty(runtime, "byteOffset") &&
                                 dataObj.hasProperty(runtime, "byteLength") &&
                                 dataObj.hasProperty(runtime, "length");
  if (!isValidTypedArray) {
    throw jsi::JSError(runtime, "Data must be an ArrayBuffer or TypedArray");
  }
  jsi::Value bufferValue = dataObj.getProperty(runtime, "buffer");
  if (!bufferValue.isObject() ||
      !bufferValue.asObject(runtime).isArrayBuffer(runtime)) {
    throw jsi::JSError(runtime,
                       "TypedArray buffer property must be an ArrayBuffer");
  }
  jsi::ArrayBuffer arrayBuffer =
      bufferValue.asObject(runtime).getArrayBuffer(runtime);

  // Read byteOffset as a double and range-check it before converting. Going
  // through int would truncate offsets >= 2 GiB, and casting a negative or NaN
  // double to size_t is undefined behavior.
  const double rawOffset =
      dataObj.getProperty(runtime, "byteOffset").asNumber();
  const size_t bufferSize = arrayBuffer.size(runtime);
  if (!(rawOffset >= 0.0 && rawOffset <= static_cast<double>(bufferSize))) {
    throw jsi::JSError(runtime,
                       "TypedArray byteOffset is outside of its buffer");
  }
  const auto byteOffset = static_cast<size_t>(rawOffset);
  tensorView.dataPtr =
      static_cast<uint8_t *>(arrayBuffer.data(runtime)) + byteOffset;
  return tensorView;
}
// Converts a JS array of tensor objects into tensor views, converting each
// element with getValue<JSTensorViewIn>.
template <>
inline std::vector<JSTensorViewIn>
getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val,
                                      jsi::Runtime &runtime) {
  const jsi::Array jsArray = val.asObject(runtime).asArray(runtime);
  const size_t count = jsArray.size(runtime);
  std::vector<JSTensorViewIn> views;
  views.reserve(count);
  for (size_t idx = 0; idx != count; ++idx) {
    views.push_back(getValue<JSTensorViewIn>(
        jsArray.getValueAtIndex(runtime, idx), runtime));
  }
  return views;
}
// Converts a JS array of strings into a std::vector<std::string>, preserving
// element order and converting each element with getValue<std::string>.
template <>
inline std::vector<std::string>
getValue<std::vector<std::string>>(const jsi::Value &val,
                                   jsi::Runtime &runtime) {
  const jsi::Array jsArray = val.asObject(runtime).asArray(runtime);
  const size_t count = jsArray.size(runtime);
  std::vector<std::string> strings;
  strings.reserve(count);
  size_t idx = 0;
  while (idx < count) {
    strings.push_back(
        getValue<std::string>(jsArray.getValueAtIndex(runtime, idx), runtime));
    ++idx;
  }
  return strings;
}
// Converts a JS array of strings into a std::set. The transparent comparator
// std::less<> enables heterogeneous lookup, so the set can be queried with a
// std::string_view without building a temporary std::string. Duplicate
// strings in the JS array collapse into a single entry.
template <>
inline std::set<std::string, std::less<>>
getValue<std::set<std::string, std::less<>>>(const jsi::Value &val,
                                             jsi::Runtime &runtime) {
  const jsi::Array jsArray = val.asObject(runtime).asArray(runtime);
  std::set<std::string, std::less<>> names;
  for (size_t idx = 0, count = jsArray.size(runtime); idx < count; ++idx) {
    names.insert(
        getValue<std::string>(jsArray.getValueAtIndex(runtime, idx), runtime));
  }
  return names;
}
// Conversion from C++ types to jsi --------------------------------------------
// Implementation functions may return arbitrary C++ types, but a promise can
// only be resolved with a jsi::Value (or jsi::Object), so each C++ type that an
// implementation returns needs a getJsiValue overload here.
// Hands a heap-held jsi::Object back to JS as a jsi::Value.
//
// The object is moved out of *valuePtr, so any other owner of the same
// shared_ptr is left holding a moved-from jsi::Object afterwards.
// A null pointer is returned as JS `null` instead of being dereferenced.
inline jsi::Value getJsiValue(std::shared_ptr<jsi::Object> valuePtr,
                              jsi::Runtime &runtime) {
  if (!valuePtr) {
    return jsi::Value::null();
  }
  return std::move(*valuePtr);
}
// Converts a vector of 32-bit integers into a JS array of numbers with the
// same length and order.
inline jsi::Value getJsiValue(const std::vector<int32_t> &vec,
                              jsi::Runtime &runtime) {
  jsi::Array jsArray(runtime, vec.size());
  size_t idx = 0;
  for (const int32_t element : vec) {
    jsArray.setValueAtIndex(runtime, idx++,
                            jsi::Value(static_cast<int>(element)));
  }
  return jsi::Value(runtime, jsArray);
}
// Wraps an int as a JS number. Numbers are primitive JSI values, so the value
// is constructed directly and the runtime is not consulted.
inline jsi::Value getJsiValue(int val, [[maybe_unused]] jsi::Runtime &runtime) {
  return jsi::Value(val);
}
// Exposes each native buffer to JS as a jsi::ArrayBuffer constructed from the
// shared buffer pointer, and returns them as a JS array in the same order.
inline jsi::Value
getJsiValue(const std::vector<std::shared_ptr<OwningArrayBuffer>> &vec,
            jsi::Runtime &runtime) {
  jsi::Array jsArray(runtime, vec.size());
  size_t idx = 0;
  for (const auto &buffer : vec) {
    jsi::ArrayBuffer jsBuffer(runtime, buffer);
    jsArray.setValueAtIndex(runtime, idx, jsi::Value(runtime, jsBuffer));
    ++idx;
  }
  return jsi::Value(runtime, jsArray);
}
// Converts output tensors into a JS array of objects with the properties
//   sizes: number[], scalarType: number, dataPtr: ArrayBuffer.
// The ArrayBuffer is constructed from each tensor's dataPtr member.
inline jsi::Value
getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec,
            jsi::Runtime &runtime) {
  jsi::Array jsArray(runtime, vec.size());
  for (size_t idx = 0; idx < vec.size(); ++idx) {
    auto &tensor = *vec[idx];
    jsi::Object tensorObj(runtime);
    tensorObj.setProperty(runtime, "sizes", getJsiValue(tensor.sizes, runtime));
    tensorObj.setProperty(runtime, "scalarType",
                          jsi::Value(static_cast<int>(tensor.scalarType)));
    jsi::ArrayBuffer dataBuffer(runtime, tensor.dataPtr);
    tensorObj.setProperty(runtime, "dataPtr", dataBuffer);
    jsArray.setValueAtIndex(runtime, idx, tensorObj);
  }
  return jsi::Value(runtime, jsArray);
}
// Converts a UTF-8 encoded std::string into a JS string.
//
// createFromUtf8 is used rather than createFromAscii because the string may
// hold arbitrary UTF-8 text (non-ASCII characters). getValue<std::string>
// decodes JS strings with utf8(), so this keeps the two directions symmetric.
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {
  return jsi::String::createFromUtf8(runtime, str);
}
// Converts a string_view -> float map into a plain JS object: each key becomes
// a property name and each float becomes the property's numeric value.
//
// Property names are created from (data, size) explicitly, so the
// string_view keys do not need to be null-terminated, and non-ASCII UTF-8 keys
// are handled correctly.
inline jsi::Value
getJsiValue(const std::unordered_map<std::string_view, float> &map,
            jsi::Runtime &runtime) {
  jsi::Object mapObj{runtime};
  for (const auto &[key, value] : map) {
    const auto name = jsi::PropNameID::forUtf8(
        runtime, reinterpret_cast<const uint8_t *>(key.data()), key.size());
    mapObj.setProperty(runtime, name, value);
  }
  return mapObj;
}
// Converts object-detection results into a JS array of objects with the
// properties
//   bbox: { x1, y1, x2, y2 }, label: string, score: number.
// The label string is looked up with cocoLabelsMap.at(), which is
// bounds-checked and throws for an unknown label id.
inline jsi::Value getJsiValue(const std::vector<Detection> &detections,
                              jsi::Runtime &runtime) {
  jsi::Array jsDetections(runtime, detections.size());
  std::size_t idx = 0;
  for (const Detection &det : detections) {
    jsi::Object jsDetection(runtime);
    jsi::Object bbox(runtime);
    bbox.setProperty(runtime, "x1", det.x1);
    bbox.setProperty(runtime, "y1", det.y1);
    bbox.setProperty(runtime, "x2", det.x2);
    bbox.setProperty(runtime, "y2", det.y2);
    jsDetection.setProperty(runtime, "bbox", bbox);
    jsDetection.setProperty(
        runtime, "label",
        jsi::String::createFromAscii(runtime, cocoLabelsMap.at(det.label)));
    jsDetection.setProperty(runtime, "score", det.score);
    jsDetections.setValueAtIndex(runtime, idx++, jsDetection);
  }
  return jsDetections;
}
// Returns, at compile time, the number of parameters of a (non-const) member
// function. The pointer argument is used only to deduce the parameter pack.
template <typename Model, typename R, typename... Types>
constexpr std::size_t getArgumentCount(R (Model::*)(Types...)) {
  constexpr std::size_t arity = sizeof...(Types);
  return arity;
}
// Converts the JS arguments args[0..sizeof...(Types)) into a
// std::tuple<Types...>, converting args[I] with getValue<Types_I>.
//
// The tuple is brace-initialized because initializer-clauses in a
// braced-init-list are evaluated strictly left to right. The arguments of a
// std::make_tuple call would be converted in an unspecified,
// compiler-dependent order. Left-to-right order means the first invalid
// argument is the one reported, and any JS-visible effects of the conversions
// (e.g. property getters) happen in argument order.
template <typename... Types, std::size_t... I>
std::tuple<Types...> fillTupleFromArgs(std::index_sequence<I...>,
                                       const jsi::Value *args,
                                       jsi::Runtime &runtime) {
  return std::tuple<Types...>{getValue<Types>(args[I], runtime)...};
}
// Converts the JS call arguments into a tuple matching the parameter list of
// the given (non-const) member function. The member pointer is used only to
// deduce the parameter types; args must hold at least sizeof...(Types)
// values.
template <typename Model, typename R, typename... Types>
std::tuple<Types...> createArgsTupleFromJsi(R (Model::*)(Types...),
                                            const jsi::Value *args,
                                            jsi::Runtime &runtime) {
  using ArgIndices = std::index_sequence_for<Types...>;
  return fillTupleFromArgs<Types...>(ArgIndices{}, args, runtime);
}
} // namespace rnexecutorch::jsiconversion