Skip to content

Commit 36f3fdb

Browse files
committed
feat: finish forward of ExecutorchModule
1 parent 51549bb commit 36f3fdb

1 file changed

Lines changed: 39 additions & 9 deletions

File tree

packages/react-native-executorch/common/rnexecutorch/bindings/ExecutorchModule.cpp

Lines changed: 39 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,47 @@ ExecutorchModule::ExecutorchModule(
2525
}
2626
}
2727

28-
int ExecutorchModule::forward(std::vector<JsiTensorView> tensorViewVec) {
29-
auto currTensor = tensorViewVec[0];
30-
auto myTensor =
31-
make_tensor_ptr(currTensor.shape, currTensor.dataPtr, ScalarType::Float);
32-
auto result = module->forward(myTensor);
28+
std::vector<std::shared_ptr<OwningArrayBuffer>>
29+
ExecutorchModule::forward(std::vector<JsiTensorView> tensorViewVec) {
30+
std::vector<executorch::runtime::EValue> evalues;
31+
evalues.reserve(tensorViewVec.size());
32+
// Because EValue doesn't hold to the dynamic data and metadata from
33+
// TensorPtr, we need to make sure that the TensorPtr for each EValue is valid
34+
// as long as that EValue is in use. Therefore we create a vec solely for
35+
// keeping references to the TensorPtr
36+
std::vector<TensorPtr> tensorPtrs;
37+
tensorPtrs.reserve(evalues.size());
38+
39+
for (size_t i = 0; i < tensorViewVec.size(); i++) {
40+
const auto &currTensorView = tensorViewVec[i];
41+
auto tensorPtr = make_tensor_ptr(currTensorView.shape,
42+
currTensorView.dataPtr, ScalarType::Float);
43+
tensorPtrs.emplace_back(tensorPtr);
44+
evalues.emplace_back(*tensorPtr); // Dereference TensorPtr to get Tensor,
45+
// which implicitly converts to EValue
46+
}
47+
48+
auto result = module->forward(evalues);
3349
if (!result.ok()) {
34-
std::string errorStr = std::to_string(static_cast<int>(result.error()));
35-
log(LOG_LEVEL::Debug, errorStr.c_str());
36-
throw std::runtime_error("Failed to run forward! Error: " + errorStr);
50+
throw std::runtime_error("Forward error: " +
51+
std::to_string(static_cast<int>(result.error())));
52+
}
53+
54+
auto &outputs = result.get();
55+
std::vector<std::shared_ptr<OwningArrayBuffer>> output;
56+
output.reserve(outputs.size());
57+
58+
// Convert ET outputs to a vector of ArrayBuffers which are later
59+
// converted to JSI array via JsiConversions.h
60+
for (size_t i = 0; i < outputs.size(); i++) {
61+
auto &outputTensor = outputs[i].toTensor();
62+
63+
size_t bufferSize = outputTensor.numel() * outputTensor.element_size();
64+
auto buffer = std::make_shared<OwningArrayBuffer>(bufferSize);
65+
std::memcpy(buffer->data(), outputTensor.const_data_ptr(), bufferSize);
66+
output.emplace_back(buffer);
3767
}
38-
return 1;
68+
return output;
3969
}
4070

4171
std::vector<int32_t> ExecutorchModule::getInputShape(std::string method_name,

0 commit comments

Comments
 (0)