Skip to content

Commit 6762278

Browse files
committed
chore: improve error handling, make getAllInputshapes accept a methodName param
1 parent e6eba4b commit 6762278

2 files changed

Lines changed: 18 additions & 14 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,22 +28,25 @@ BaseModel::BaseModel(const std::string &modelSource,
2828
memorySizeLowerBound = std::filesystem::file_size(modelPath);
2929
}
3030

31-
std::vector<std::vector<int32_t>> BaseModel::getAllInputShapes() {
32-
if (!module) {
31+
std::vector<std::vector<int32_t>>
32+
BaseModel::getAllInputShapes(std::string methodName) {
33+
if (!module) {
3334
throw std::runtime_error("getInputShape called on unloaded model");
3435
}
35-
auto method_meta = module->method_meta("forward");
36+
auto method_meta = module->method_meta(methodName);
3637

3738
if (!method_meta.ok()) {
38-
throw std::runtime_error("Failed to load forward");
39+
throw std::runtime_error("Failed to load method: " + methodName);
3940
}
4041
std::vector<std::vector<int32_t>> output;
4142
std::size_t numInputs = method_meta->num_inputs();
4243
output.reserve(numInputs);
4344
for (std::size_t input = 0; input < numInputs; ++input) {
4445
auto input_meta = method_meta->input_tensor_meta(input);
4546
if (!input_meta.ok()) {
46-
throw std::runtime_error("Failed to load forward input");
47+
throw std::runtime_error(
48+
"Failed to load input no: " + std::to_string(input) + " for method " +
49+
methodName);
4750
}
4851
auto shape = input_meta->sizes();
4952
output.emplace_back(std::vector<int32_t>(shape.begin(), shape.end()));
@@ -55,11 +58,4 @@ std::size_t BaseModel::getMemoryLowerBound() { return memorySizeLowerBound; }
5558

5659
void BaseModel::unload() { module.reset(nullptr); }
5760

58-
Result<std::vector<EValue>> BaseModel::forwardET(const EValue &input_value) {
59-
if (!module) {
60-
throw std::runtime_error("Forward called on unloaded model");
61-
}
62-
return module->forward(input_value);
63-
}
64-
6561
} // namespace rnexecutorch

packages/react-native-executorch/common/rnexecutorch/models/BaseModel.h

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
#include <ReactCommon/CallInvoker.h>
77
#include <executorch/extension/module/module.h>
88
#include <jsi/jsi.h>
9+
#include <rnexecutorch/jsi/OwningArrayBuffer.h>
10+
#include <rnexecutorch/utils/JsiTensorView.h>
911

1012
namespace rnexecutorch {
1113
using namespace facebook;
@@ -17,10 +19,16 @@ class BaseModel {
1719
std::shared_ptr<react::CallInvoker> callInvoker);
1820
std::size_t getMemoryLowerBound();
1921
void unload();
20-
std::vector<std::vector<int32_t>> getAllInputShapes();
22+
std::vector<std::vector<int32_t>>
23+
getAllInputShapes(std::string methodName = "forward");
24+
std::vector<std::shared_ptr<OwningArrayBuffer>>
25+
forward(std::vector<JsiTensorView> tensorViewVec);
2126

2227
protected:
23-
Result<std::vector<EValue>> forwardET(const EValue &input_value);
28+
// TODO: NEED TO CHANGE THE CONCEPT TO MATCH THE EXACT SIGNATURE OF THE SECOND
29+
// FORWARD SO ITS NOT AMBIGUOUS
30+
Result<std::vector<EValue>> forward(const EValue &input_value);
31+
Result<std::vector<EValue>> forward(const std::vector<EValue> &input_value);
2432
// If possible, models should not use the JS runtime to keep JSI internals
2533
// away from logic, however, sometimes this would incur too big of a penalty
2634
// (unnecessary copies instead of working on JS memory). In this case

0 commit comments

Comments
 (0)