Skip to content

Commit 029547b

Browse files
committed
wip
1 parent 0442aa3 commit 029547b

2 files changed

Lines changed: 65 additions & 0 deletions

File tree

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#include "ExecutorchModule.h"
2+
3+
#include <fmt/core.h>
4+
#include <rnexecutorch/Log.h>
5+
6+
namespace rnexecutorch {
7+
8+
using ::executorch::extension::Module;
9+
using ::executorch::runtime::Error;
10+
11+
ExecutorchModule::ExecutorchModule(const std::string &modelSource,
12+
facebook::jsi::Runtime *runtime)
13+
: module(std::make_unique<Module>(
14+
modelSource, Module::LoadMode::MmapUseMlockIgnoreErrors)),
15+
runtime(runtime) {
16+
Error loadError = module->load();
17+
if (loadError != Error::Ok) {
18+
throw std::runtime_error("Couldn't load the model, error: " +
19+
std::to_string(static_cast<uint32_t>(loadError)));
20+
}
21+
}
22+
23+
std::vector<int32_t> ExecutorchModule::getInputShape(std::string method_name,
24+
int index) {
25+
auto method_meta = module->method_meta(method_name);
26+
if (!method_meta.ok()) {
27+
throw std::runtime_error(
28+
fmt::format("Failed to load method with name {}", method_name));
29+
}
30+
31+
std::vector<int32_t> input_shape;
32+
auto input_meta = method_meta->input_tensor_meta(index);
33+
if (!input_meta.ok()) {
34+
throw std::runtime_error(
35+
fmt::format("Failed to load forward input {}", index));
36+
}
37+
38+
for (auto size : input_meta->sizes()) {
39+
input_shape.push_back(size);
40+
}
41+
return input_shape;
42+
}
43+
} // namespace rnexecutorch
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#pragma once
2+
3+
#include <string>
4+
5+
#include <executorch/extension/module/module.h>
6+
#include <fmt/core.h>
7+
#include <jsi/jsi.h>
8+
9+
namespace rnexecutorch {
10+
11+
class ExecutorchModule {
12+
public:
13+
ExecutorchModule(const std::string &modelSource,
14+
facebook::jsi::Runtime *runtime);
15+
std::vector<int32_t> getInputShape(std::string method_name, int index);
16+
17+
protected:
18+
std::unique_ptr<executorch::extension::Module> module;
19+
facebook::jsi::Runtime *runtime;
20+
};
21+
22+
} // namespace rnexecutorch

0 commit comments

Comments
 (0)