Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
43 commits
Select commit Hold shift + click to select a range
a29a7bd
wip
chmjkb Jun 4, 2025
4f4b0b0
wip: working ish brute force
chmjkb Jun 5, 2025
0c94769
feat: finish forward of ExecutorchModule
chmjkb Jun 6, 2025
a0e1438
feat: add a JSI conversion from OwningArrayBuffers to JS types
chmjkb Jun 6, 2025
08fed8d
fix: update ExecutorchModule header
chmjkb Jun 6, 2025
8df926f
fix: change the concept to check for external memory pressure properties
chmjkb Jun 6, 2025
a786646
fix: post-rebase concept fixes
chmjkb Jun 6, 2025
01f97ac
chore: 🧹
chmjkb Jun 6, 2025
feede80
chore: 🧹
chmjkb Jun 6, 2025
efddd91
chore: 🧹
chmjkb Jun 6, 2025
47540a9
fix: complete scalartype enum & move types to common.ts
chmjkb Jun 6, 2025
225dcac
chore: get rid of numel within JsiTensorView ✨
chmjkb Jun 6, 2025
4b96db0
chore: move using namespace to rnexecutorch namespace
chmjkb Jun 9, 2025
f0ba890
chore: improve error handling, make getAllInputshapes accept a method…
chmjkb Jun 9, 2025
6820c4f
chore: fix namespace mess
chmjkb Jun 9, 2025
305acac
fix: make NewExecutorchModule inherit from base
chmjkb Jun 9, 2025
8cd1602
fix: ensure loadClassification is installed
chmjkb Jun 9, 2025
54ab273
chore: use .asNumber() instead of getValue<int>
chmjkb Jun 9, 2025
8891ebe
chore: 🧹✨
chmjkb Jun 9, 2025
b2272a0
chore: 🧹✨
chmjkb Jun 9, 2025
303e732
chore: 🧹✨
chmjkb Jun 9, 2025
f525ab5
fix: pass proper ScalarType
chmjkb Jun 9, 2025
3208506
refactor: make BaseModel an ExecutorchModule
chmjkb Jun 9, 2025
58d9131
chore: get rid of ExecutorchModule, cleanup TypeConcepts
chmjkb Jun 9, 2025
b251c72
chore: improve error messages
chmjkb Jun 9, 2025
c411f28
chore: remove redundant log include
chmjkb Jun 9, 2025
af6cd0e
feat: make basenonstaticmodule able to call all the functions from Ba…
chmjkb Jun 9, 2025
08c2645
doesn't work :(
chmjkb Jun 10, 2025
6210bad
friendship with forward ended, now forwardjs is my best friend
chmjkb Jun 10, 2025
0e19cdd
chore: update return type of TS module
chmjkb Jun 10, 2025
19450bc
chore: remove redundant comments
chmjkb Jun 10, 2025
1d7535b
chore: remove log include
chmjkb Jun 10, 2025
19ea798
refactor: unify members of in/out tensors, adjust conversions & usage
chmjkb Jun 11, 2025
9d22549
chore: update TS types, refactor getTensorShape
chmjkb Jun 11, 2025
890bde3
chore: use .reserve() for input shapes
chmjkb Jun 11, 2025
0b0be90
chore: make memorySizeLowerBound a private member
chmjkb Jun 11, 2025
3776252
refactor: add a generic getValue for numeric types
chmjkb Jun 11, 2025
e3d63c1
chore: make creating input shapes prettier
chmjkb Jun 11, 2025
bc75e3a
fix: post rebase fixes
chmjkb Jun 11, 2025
45234d6
fix: forward->generate
chmjkb Jun 11, 2025
be50835
chore: make load within basenonstaticmodule abstract
chmjkb Jun 11, 2025
d1f31a1
fix: update index.tsx
chmjkb Jun 11, 2025
551c5c0
chore: change the usage of is_arithmetic<T>::value to is_arithmetic_v<T>
chmjkb Jun 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,11 @@ void RnExecutorchInstaller::injectJSIBindings(
*jsiRuntime, "loadObjectDetection",
RnExecutorchInstaller::loadModel<ObjectDetection>(
jsiRuntime, jsCallInvoker, "loadObjectDetection"));

jsiRuntime->global().setProperty(
*jsiRuntime, "loadExecutorchModule",
RnExecutorchInstaller::loadModel<BaseModel>(jsiRuntime, jsCallInvoker,
"loadExecutorchModule"));
}

} // namespace rnexecutorch
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
#include <ReactCommon/CallInvoker.h>
#include <jsi/jsi.h>

#include <rnexecutorch/TypeConstraints.h>
#include <rnexecutorch/TypeConcepts.h>
#include <rnexecutorch/host_objects/JsiConversions.h>
#include <rnexecutorch/host_objects/ModelHostObject.h>

Expand All @@ -26,7 +26,7 @@ class RnExecutorchInstaller {
FetchUrlFunc_t fetchDataFromUrl);

private:
template <DerivedFromBaseModel ModelT>
template <DerivedFromOrSameAs<BaseModel> ModelT>
static jsi::Function
loadModel(jsi::Runtime *jsiRuntime,
std::shared_ptr<react::CallInvoker> jsCallInvoker,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#pragma once

#include <concepts>
#include <type_traits>

namespace rnexecutorch {

/// Satisfied when T is Base itself or derives from Base.
/// Implemented with std::is_base_of_v (rather than std::derived_from) so
/// that non-public inheritance also satisfies the concept — NOTE(review):
/// confirm that is intended; std::derived_from would require a public,
/// unambiguous base.
template <typename T, typename Base>
concept DerivedFromOrSameAs = std::is_base_of_v<Base, T>;

/// Satisfied when T exposes a member named `generate` whose address can be
/// taken (i.e. a non-overloaded member function or data member).
template <typename T>
concept HasGenerate = requires { &T::generate; };

/// Satisfied by all built-in arithmetic types (integral and floating-point).
template <typename T>
concept IsNumeric = std::is_arithmetic_v<T>;

} // namespace rnexecutorch

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
#pragma once

namespace rnexecutorch {

using executorch::aten::ScalarType;

struct JSTensorViewIn {
void *dataPtr;
std::vector<int32_t> sizes;
ScalarType scalarType;
};
} // namespace rnexecutorch
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#pragma once

#include <cstdint>
#include <memory>
#include <vector>

#include <executorch/runtime/core/portable_type/scalar_type.h>
#include <rnexecutorch/jsi/OwningArrayBuffer.h>

namespace rnexecutorch {

using executorch::runtime::etensor::ScalarType;

/// An owning description of a tensor produced by native code and handed back
/// to JS. The data buffer's lifetime is managed by OwningArrayBuffer, so it
/// stays alive for as long as the JS ArrayBuffer wrapping it exists.
struct JSTensorViewOut {
  std::shared_ptr<OwningArrayBuffer> dataPtr; // owned output buffer
  std::vector<int32_t> sizes;                 // tensor dimensions
  ScalarType scalarType;                      // element type

  // The member-initializer list is written in declaration order
  // (dataPtr, sizes, scalarType): members are always initialized in the
  // order they are declared, so matching that order avoids -Wreorder and
  // makes the actual initialization sequence explicit.
  JSTensorViewOut(std::vector<int32_t> sizes, ScalarType scalarType,
                  std::shared_ptr<OwningArrayBuffer> dataPtr)
      : dataPtr(std::move(dataPtr)), sizes(std::move(sizes)),
        scalarType(scalarType) {}
};
} // namespace rnexecutorch
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
#include <type_traits>
#include <unordered_map>

#include <executorch/runtime/core/exec_aten/util/scalar_type_util.h>
#include <jsi/jsi.h>
#include <rnexecutorch/host_objects/JSTensorViewIn.h>
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
#include <rnexecutorch/jsi/OwningArrayBuffer.h>

#include <rnexecutorch/TypeConcepts.h>
#include <rnexecutorch/models/object_detection/Constants.h>
#include <rnexecutorch/models/object_detection/Utils.h>

Expand All @@ -17,9 +22,12 @@ using namespace facebook;

template <typename T> T getValue(const jsi::Value &val, jsi::Runtime &runtime);

template <>
inline double getValue<double>(const jsi::Value &val, jsi::Runtime &runtime) {
return val.asNumber();
template <typename T>
requires IsNumeric<T>
inline T getValue(const jsi::Value &val, jsi::Runtime &runtime) {
static_assert(std::is_integral<T>::value || std::is_floating_point<T>::value,
"Only integral and floating-point types are supported");
return static_cast<T>(val.asNumber());
}

template <>
Expand All @@ -33,6 +41,78 @@ inline std::string getValue<std::string>(const jsi::Value &val,
return val.getString(runtime).utf8(runtime);
}

// Converts a JS object of shape { scalarType, sizes, dataPtr } into a
// non-owning JSTensorViewIn. The resulting dataPtr borrows JS-owned memory
// and is only valid while the backing JS ArrayBuffer stays alive.
template <>
inline JSTensorViewIn getValue<JSTensorViewIn>(const jsi::Value &val,
                                               jsi::Runtime &runtime) {
  jsi::Object obj = val.asObject(runtime);
  JSTensorViewIn tensorView;

  // scalarType arrives as a plain JS number whose value is cast directly to
  // the executorch ScalarType enum — assumes JS and native enums agree;
  // TODO confirm against the TS-side definition.
  int scalarTypeInt = obj.getProperty(runtime, "scalarType").asNumber();
  tensorView.scalarType = static_cast<ScalarType>(scalarTypeInt);

  jsi::Value shapeValue = obj.getProperty(runtime, "sizes");
  jsi::Array shapeArray = shapeValue.asObject(runtime).asArray(runtime);
  size_t numShapeDims = shapeArray.size(runtime);
  tensorView.sizes.reserve(numShapeDims);

  for (size_t i = 0; i < numShapeDims; ++i) {
    int dim = getValue<int>(shapeArray.getValueAtIndex(runtime, i), runtime);
    tensorView.sizes.push_back(static_cast<int32_t>(dim));
  }

  // On JS side, TensorPtr objects hold a 'data' property which should be either
  // an ArrayBuffer or TypedArray
  jsi::Value dataValue = obj.getProperty(runtime, "dataPtr");
  jsi::Object dataObj = dataValue.asObject(runtime);

  // Check if it's an ArrayBuffer or TypedArray
  if (dataObj.isArrayBuffer(runtime)) {
    // Plain ArrayBuffer: point straight at its backing store.
    jsi::ArrayBuffer arrayBuffer = dataObj.getArrayBuffer(runtime);
    tensorView.dataPtr = arrayBuffer.data(runtime);

  } else {
    // Handle typed arrays (Float32Array, Int32Array, etc.) by duck-typing
    // on the standard TypedArray properties.
    const bool isValidTypedArray = dataObj.hasProperty(runtime, "buffer") &&
                                   dataObj.hasProperty(runtime, "byteOffset") &&
                                   dataObj.hasProperty(runtime, "byteLength") &&
                                   dataObj.hasProperty(runtime, "length");
    if (!isValidTypedArray) {
      throw jsi::JSError(runtime, "Data must be an ArrayBuffer or TypedArray");
    }
    jsi::Value bufferValue = dataObj.getProperty(runtime, "buffer");
    if (!bufferValue.isObject() ||
        !bufferValue.asObject(runtime).isArrayBuffer(runtime)) {
      throw jsi::JSError(runtime,
                         "TypedArray buffer property must be an ArrayBuffer");
    }

    jsi::ArrayBuffer arrayBuffer =
        bufferValue.asObject(runtime).getArrayBuffer(runtime);
    size_t byteOffset =
        getValue<int>(dataObj.getProperty(runtime, "byteOffset"), runtime);

    // Honor the view's offset into the underlying buffer. NOTE(review):
    // byteLength is not validated against sizes/scalarType here — presumably
    // checked by the caller; confirm.
    tensorView.dataPtr =
        static_cast<uint8_t *>(arrayBuffer.data(runtime)) + byteOffset;
  }
  return tensorView;
}

// Converts a JS array of tensor-view objects into a vector of
// JSTensorViewIn, delegating per-element conversion to the scalar overload.
template <>
inline std::vector<JSTensorViewIn>
getValue<std::vector<JSTensorViewIn>>(const jsi::Value &val,
                                      jsi::Runtime &runtime) {
  jsi::Array jsArray = val.asObject(runtime).asArray(runtime);
  const size_t count = jsArray.size(runtime);

  std::vector<JSTensorViewIn> tensors;
  tensors.reserve(count);
  for (size_t idx = 0; idx < count; ++idx) {
    tensors.push_back(getValue<JSTensorViewIn>(
        jsArray.getValueAtIndex(runtime, idx), runtime));
  }
  return tensors;
}

template <>
inline std::vector<std::string>
getValue<std::vector<std::string>>(const jsi::Value &val,
Expand Down Expand Up @@ -78,6 +158,51 @@ inline jsi::Value getJsiValue(std::shared_ptr<jsi::Object> valuePtr,
return std::move(*valuePtr);
}

// Converts a vector of 32-bit ints (e.g. tensor dimensions) into a JS array
// of numbers.
inline jsi::Value getJsiValue(const std::vector<int32_t> &vec,
                              jsi::Runtime &runtime) {
  const size_t count = vec.size();
  jsi::Array jsArray(runtime, count);
  for (size_t idx = 0; idx < count; ++idx) {
    jsArray.setValueAtIndex(runtime, idx,
                            jsi::Value(static_cast<int>(vec[idx])));
  }
  return jsi::Value(runtime, jsArray);
}

// Wraps a plain int in a jsi::Value.
inline jsi::Value getJsiValue(int val, jsi::Runtime &runtime) {
  jsi::Value result(runtime, val);
  return result;
}

// Converts a vector of owning buffers into a JS array of ArrayBuffers.
// Each jsi::ArrayBuffer is constructed from the shared_ptr, so the native
// buffer stays alive for as long as the JS side references it.
inline jsi::Value
getJsiValue(const std::vector<std::shared_ptr<OwningArrayBuffer>> &vec,
            jsi::Runtime &runtime) {
  jsi::Array out(runtime, vec.size());
  size_t idx = 0;
  for (const auto &buffer : vec) {
    jsi::ArrayBuffer jsBuffer(runtime, buffer);
    out.setValueAtIndex(runtime, idx++, jsi::Value(runtime, jsBuffer));
  }
  return jsi::Value(runtime, out);
}

// Converts native output tensors into a JS array of objects, each of shape
// { sizes, scalarType, dataPtr } — dataPtr being an ArrayBuffer backed by
// the tensor's OwningArrayBuffer.
inline jsi::Value
getJsiValue(const std::vector<std::shared_ptr<JSTensorViewOut>> &vec,
            jsi::Runtime &runtime) {
  jsi::Array out(runtime, vec.size());
  for (size_t idx = 0; idx < vec.size(); ++idx) {
    const auto &tensor = vec[idx];

    jsi::Object jsTensor(runtime);
    jsTensor.setProperty(runtime, "sizes",
                         getJsiValue(tensor->sizes, runtime));
    jsTensor.setProperty(runtime, "scalarType",
                         jsi::Value(static_cast<int>(tensor->scalarType)));

    jsi::ArrayBuffer jsBuffer(runtime, tensor->dataPtr);
    jsTensor.setProperty(runtime, "dataPtr", jsBuffer);

    out.setValueAtIndex(runtime, idx, jsTensor);
  }
  return jsi::Value(runtime, out);
}

// Converts a std::string to a JS string. Uses createFromUtf8 so that
// non-ASCII (UTF-8 encoded) text survives the conversion intact;
// createFromAscii mangles any byte outside the 7-bit ASCII range.
inline jsi::Value getJsiValue(const std::string &str, jsi::Runtime &runtime) {
  return jsi::String::createFromUtf8(runtime, str);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@
#include <ReactCommon/CallInvoker.h>

#include <rnexecutorch/Log.h>
#include <rnexecutorch/TypeConstraints.h>
#include <rnexecutorch/TypeConcepts.h>
#include <rnexecutorch/host_objects/JSTensorViewOut.h>
#include <rnexecutorch/host_objects/JsiConversions.h>
#include <rnexecutorch/jsi/JsiHostObject.h>
#include <rnexecutorch/jsi/Promise.h>
#include <rnexecutorch/models/BaseModel.h>

namespace rnexecutorch {

Expand All @@ -20,13 +22,28 @@ template <typename Model> class ModelHostObject : public JsiHostObject {
// Wraps a native model in a JSI host object, exposing its methods to JS.
// Which methods are installed is decided at compile time from the model's
// type: BaseModel-derived models get unload/forward/getInputShape, and any
// model with a `generate` member additionally gets generate.
explicit ModelHostObject(const std::shared_ptr<Model> &model,
                         std::shared_ptr<react::CallInvoker> callInvoker)
    : model(model), callInvoker(callInvoker) {
  // A single compile-time guard replaces the three identical
  // `if constexpr (DerivedFromOrSameAs<Model, BaseModel>)` blocks that
  // previously wrapped each registration separately.
  if constexpr (DerivedFromOrSameAs<Model, BaseModel>) {
    // Synchronous unload; promise-returning forward/getInputShape.
    addFunctions(
        JSI_EXPORT_FUNCTION(ModelHostObject<Model>, unload, "unload"));
    addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
                                     promiseHostFunction<&Model::forwardJS>,
                                     "forward"));
    addFunctions(JSI_EXPORT_FUNCTION(
        ModelHostObject<Model>, promiseHostFunction<&Model::getInputShape>,
        "getInputShape"));
  }

  if constexpr (HasGenerate<Model>) {
    addFunctions(JSI_EXPORT_FUNCTION(ModelHostObject<Model>,
                                     promiseHostFunction<&Model::generate>,
                                     "generate"));
  }
}

// A generic host function that resolves a promise with a result of a
Expand Down
Loading