Skip to content

Commit 1cf016a

Browse files
authored
Refactor jsi conversion & NITs (#562)
## Description As in the title, some refactor of jsi. ### Introduces a breaking change? - [ ] Yes - [x] No ### Type of change - [ ] Bug fix (change which fixes an issue) - [ ] New feature (change which adds functionality) - [ ] Documentation update (improves or adds clarity to existing documentation) - [x] Other (chores, tests, code style improvements etc.) ### Tested on - [ ] iOS - [ ] Android ### Testing instructions Needs to be manually testes / unit test (TODO) ### Screenshots <!-- Add screenshots here, if applicable --> ### Related issues <!-- Link related issues here using #issue-number --> ### Checklist - [ ] I have performed a self-review of my code - [ ] I have commented my code, particularly in hard-to-understand areas - [ ] I have updated the documentation accordingly - [ ] My changes generate no new warnings ### Additional notes <!-- Include any additional information, assumptions, or context that reviewers might need to understand this PR. -->
1 parent f205414 commit 1cf016a

2 files changed

Lines changed: 33 additions & 64 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/host_objects/JsiConversions.h

Lines changed: 26 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,6 @@ template <typename T> T getValue(const jsi::Value &val, jsi::Runtime &runtime);
2929
template <typename T>
3030
requires meta::IsNumeric<T>
3131
inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) {
32-
static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value,
33-
"Only integral and floating-point types are supported");
3432
return static_cast<T>(val.asNumber());
3533
}
3634

@@ -53,21 +51,6 @@ getValue<std::shared_ptr<jsi::Function>>(const jsi::Value &val,
5351
val.asObject(runtime).asFunction(runtime));
5452
}
5553

56-
template <>
57-
inline std::vector<int32_t>
58-
getValue<std::vector<int32_t>>(const jsi::Value &val, jsi::Runtime &runtime) {
59-
jsi::Array array = val.asObject(runtime).asArray(runtime);
60-
size_t length = array.size(runtime);
61-
std::vector<int32_t> result;
62-
result.reserve(length);
63-
64-
for (size_t i = 0; i < length; ++i) {
65-
jsi::Value element = array.getValueAtIndex(runtime, i);
66-
result.push_back(getValue<int32_t>(element, runtime));
67-
}
68-
return result;
69-
}
70-
7154
template <>
7255
inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
7356
jsi::Runtime &runtime) {
@@ -83,8 +66,8 @@ inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
8366
tensorView.sizes.reserve(numShapeDims);
8467

8568
for (size_t i = 0; i < numShapeDims; ++i) {
86-
int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime);
87-
tensorView.sizes.push_back(static_cast<int32_t>(dim));
69+
int32_t dim = getValue<int32_t>(shapeArray.getValueAtIndex(runtime, i), runtime);
70+
tensorView.sizes.push_back(dim);
8871
}
8972

9073
// On JS side, TensorPtr objects hold a 'data' property which should be either
@@ -123,38 +106,6 @@ inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
123106
return tensorView;
124107
}
125108

126-
template <>
127-
inline std::vector<JSTensorViewIn>
128-
getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val,
129-
jsi::Runtime &runtime) {
130-
jsi::Array array = val.asObject(runtime).asArray(runtime);
131-
size_t length = array.size(runtime);
132-
std::vector<JSTensorViewIn> result;
133-
result.reserve(length);
134-
135-
for (size_t i = 0; i < length; ++i) {
136-
jsi::Value element = array.getValueAtIndex(runtime, i);
137-
result.push_back(getValue<JSTensorViewIn>(element, runtime));
138-
}
139-
return result;
140-
}
141-
142-
template <>
143-
inline std::vector<std::string>
144-
getValue<std::vector<std::string>>(const jsi::Value &val,
145-
jsi::Runtime &runtime) {
146-
jsi::Array array = val.asObject(runtime).asArray(runtime);
147-
size_t length = array.size(runtime);
148-
std::vector<std::string> result;
149-
result.reserve(length);
150-
151-
for (size_t i = 0; i < length; ++i) {
152-
jsi::Value element = array.getValueAtIndex(runtime, i);
153-
result.push_back(getValue<std::string>(element, runtime));
154-
}
155-
return result;
156-
}
157-
158109
// C++ set from JS array. Set with heterogenerous look-up (adding std::less<>
159110
// enables querying with std::string_view).
160111
template <>
@@ -222,7 +173,26 @@ inline std::vector<T> getArrayAsVector(const jsi::Value &val,
222173
return result;
223174
}
224175

176+
225177
// Template specializations for std::vector<T> types
178+
template <>
179+
inline std::vector<JSTensorViewIn> getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val,
180+
jsi::Runtime &runtime) {
181+
return getArrayAsVector<JSTensorViewIn>(val, runtime);
182+
}
183+
184+
template <>
185+
inline std::vector<std::string> getValue<std::vector<std::string>>(const jsi::Value &val,
186+
jsi::Runtime &runtime) {
187+
return getArrayAsVector<std::string>(val, runtime);
188+
}
189+
190+
template <>
191+
inline std::vector<int32_t> getValue<std::vector<int32_t>>(const jsi::Value &val,
192+
jsi::Runtime &runtime) {
193+
return getArrayAsVector<int32_t>(val, runtime);
194+
}
195+
226196
template <>
227197
inline std::vector<float> getValue<std::vector<float>>(const jsi::Value &val,
228198
jsi::Runtime &runtime) {
@@ -307,17 +277,17 @@ inline jsi::Value getJsiValue(const std::vector<int32_t> &vec,
307277
for (size_t i = 0; i < vec.size(); i++) {
308278
array.setValueAtIndex(runtime, i, jsi::Value(static_cast<int>(vec[i])));
309279
}
310-
return jsi::Value(runtime, array);
280+
return {runtime, array};
311281
}
312282

313283
inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) {
314-
return jsi::Value(runtime, val);
284+
return {runtime, val};
315285
}
316286

317287
inline jsi::Value getJsiValue(const std::shared_ptr<OwningArrayBuffer> &buf,
318288
jsi::Runtime &runtime) {
319289
jsi::ArrayBuffer arrayBuffer(runtime, buf);
320-
return jsi::Value(runtime, arrayBuffer);
290+
return {runtime, arrayBuffer};
321291
}
322292

323293
inline jsi::Value
@@ -328,7 +298,7 @@ getJsiValue(const std::vector<std::shared_ptr<OwningArrayBuffer>> &vec,
328298
jsi::ArrayBuffer arrayBuffer(runtime, vec[i]);
329299
array.setValueAtIndex(runtime, i, jsi::Value(runtime, arrayBuffer));
330300
}
331-
return jsi::Value(runtime, array);
301+
return {runtime, array};
332302
}
333303

334304
inline jsi::Value getJsiValue(const std::vector<JSTensorViewOut> &vec,
@@ -347,7 +317,7 @@ inline jsi::Value getJsiValue(const std::vector<JSTensorViewOut> &vec,
347317

348318
array.setValueAtIndex(runtime, i, tensorObj);
349319
}
350-
return jsi::Value(runtime, array);
320+
return {runtime, array};
351321
}
352322

353323
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {

packages/react-native-executorch/common/rnexecutorch/host_objects/ModelHostObject.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#pragma once
22

3+
#include <format>
34
#include <ReactCommon/CallInvoker.h>
45
#include <string>
56
#include <thread>
@@ -109,10 +110,10 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
109110
template <auto FnPtr> JSI_HOST_FUNCTION(synchronousHostFunction) {
110111
constexpr std::size_t functionArgCount = meta::getArgumentCount(FnPtr);
111112
if (functionArgCount != count) {
112-
char errorMessage[100];
113-
std::snprintf(errorMessage, sizeof(errorMessage),
114-
"Argument count mismatch, was expecting: %zu but got: %zu",
115-
functionArgCount, count);
113+
const auto errorMessage = std::format(
114+
"Argument count mismatch, was expecting: {} but got: {}",
115+
functionArgCount, count
116+
);
116117
throw jsi::JSError(runtime, errorMessage);
117118
}
118119

@@ -155,10 +156,8 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
155156
constexpr std::size_t functionArgCount =
156157
meta::getArgumentCount(FnPtr);
157158
if (functionArgCount != count) {
158-
char errorMessage[100];
159-
std::snprintf(
160-
errorMessage, sizeof(errorMessage),
161-
"Argument count mismatch, was expecting: %zu but got: %zu",
159+
const auto errorMessage = std::format(
160+
"Argument count mismatch, was expecting: {} but got: {}",
162161
functionArgCount, count);
163162
promise->reject(errorMessage);
164163
return;

0 commit comments

Comments
 (0)