Skip to content

Commit d38a1df

Browse files
committed
feat: port style transfer implementation to C++ (#229)
Port native implementation of style transfer to C++ #227. This is a part of a larger effort to merge native code from Kotlin and Objective C to a single implementation in C++. Old style transfer modules have been replaced by new, stateful implementation. For this purpose non-static versions of useModule and BaseModule have been introduced. Old implementations of style transfer have been removed. Added methods for handling fetching data via https that are passed on as function objects to C++. - Ported image processing capabilities equivalent to the ObjC/Kotlin implementations. - `ModelHostObject` serves as an automatic interface between model implementations and JSI. This is done by template metaprogramming; defining methods in `JsiConversions.h` for types used by the model is necessary for it to work. - A factory method for loading a style transfer model is registered in `RnExecutorchInstaller`. - `common/rnexecutorch/jsi` originates from react-native-audio-api - `base64.*` is used for converting from base64 due to https://renenyffenegger.ch/notes/development/Base64/Encoding-and-decoding-base-64-with-cpp/ - `ada` is a header-only library for parsing urls due to https://github.com/ada-url/ada - headers for Executorch 0.6 have been updated - OpenCV 4.11.0 dependency is introduced for Android/C++. For iOS the version is bumped via Cocoapods to 4.11.0. - C10 headers are the dependency of OpenCV. - Exceptions in native C++ do not result in promise getting rejected when forwarding yet. - Garbage collection can malfunction for host objects. C++ host objects need to notify JS runtime about their size via external memory pressure for the garbage collector to know to free them. This is not yet done.
1 parent 45a390b commit d38a1df

File tree

254 files changed

+129993
-2030
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

254 files changed

+129993
-2030
lines changed

common/ada/ada.cpp

Lines changed: 17406 additions & 0 deletions
Large diffs are not rendered by default.

common/ada/ada.h

Lines changed: 10274 additions & 0 deletions
Large diffs are not rendered by default.

common/Log.h renamed to common/rnexecutorch/Log.cpp

Lines changed: 9 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
#pragma once
1+
#include "Log.h"
22

33
#include <cstdarg>
4-
#include <string>
4+
#include <cstdio>
55

66
#ifdef __ANDROID__
77
#include <android/log.h>
@@ -12,23 +12,20 @@
1212

1313
namespace rnexecutorch {
1414

15-
enum class LOG_LEVEL { INFO, ERROR, DEBUG };
16-
1715
#ifdef __ANDROID__
1816
android_LogPriority androidLogLevel(LOG_LEVEL logLevel) {
1917
switch (logLevel) {
20-
case LOG_LEVEL::INFO:
18+
case LOG_LEVEL::Info:
2119
default:
2220
return ANDROID_LOG_INFO;
23-
case LOG_LEVEL::ERROR:
21+
case LOG_LEVEL::Error:
2422
return ANDROID_LOG_ERROR;
25-
case LOG_LEVEL::DEBUG:
23+
case LOG_LEVEL::Debug:
2624
return ANDROID_LOG_DEBUG;
2725
}
2826
}
2927
#endif
3028

31-
// const char* instead of const std::string& as va_start doesn't take references
3229
void log(LOG_LEVEL logLevel, const char *fmt, ...) {
3330
va_list args;
3431
va_start(args, fmt);
@@ -52,14 +49,14 @@ void log(LOG_LEVEL logLevel, const char *fmt, ...) {
5249
#ifdef __APPLE__
5350

5451
switch (logLevel) {
55-
case LOG_LEVEL::INFO:
52+
case LOG_LEVEL::Info:
5653
default:
5754
os_log_info(OS_LOG_DEFAULT, "%s", buf);
5855
break;
59-
case LOG_LEVEL::ERROR:
56+
case LOG_LEVEL::Error:
6057
os_log_error(OS_LOG_DEFAULT, "%s", buf);
6158
break;
62-
case LOG_LEVEL::DEBUG:
59+
case LOG_LEVEL::Debug:
6360
os_log_debug(OS_LOG_DEFAULT, "%s", buf);
6461
break;
6562
}
@@ -68,4 +65,4 @@ void log(LOG_LEVEL logLevel, const char *fmt, ...) {
6865
va_end(args);
6966
}
7067

71-
} // namespace rnexecutorch
68+
} // namespace rnexecutorch

common/rnexecutorch/Log.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
#pragma once
2+
3+
namespace rnexecutorch {
4+
5+
enum class LOG_LEVEL { Info, Error, Debug };
6+
7+
// const char* instead of const std::string& as va_start doesn't take references
8+
void log(LOG_LEVEL logLevel, const char *fmt, ...);
9+
10+
} // namespace rnexecutorch
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
#include "RnExecutorchInstaller.h"
2+
3+
#include <rnexecutorch/host_objects/JsiConversions.h>
4+
#include <rnexecutorch/host_objects/ModelHostObject.h>
5+
#include <rnexecutorch/jsi/JsiPromise.h>
6+
#include <rnexecutorch/models/StyleTransfer.h>
7+
8+
namespace rnexecutorch {
9+
10+
// This function fetches data from a url address. It is implemented in
11+
// Kotlin/ObjectiveC++ and then bound to this variable. It's done to not handle
12+
// SSL intricacies manually, as it is done automagically in ObjC++/Kotlin.
13+
FetchUrlFunc_t fetchUrlFunc;
14+
15+
jsi::Function RnExecutorchInstaller::loadStyleTransfer(
16+
jsi::Runtime *jsiRuntime,
17+
const std::shared_ptr<react::CallInvoker> &jsCallInvoker) {
18+
return jsi::Function::createFromHostFunction(
19+
*jsiRuntime, jsi::PropNameID::forAscii(*jsiRuntime, "loadStyleTransfer"),
20+
0,
21+
[jsCallInvoker](jsi::Runtime &runtime, const jsi::Value &thisValue,
22+
const jsi::Value *args, size_t count) -> jsi::Value {
23+
assert(count == 1);
24+
try {
25+
auto source = jsiconversion::getValue<std::string>(args[0], runtime);
26+
27+
auto styleTransferPtr =
28+
std::make_shared<StyleTransfer>(source, &runtime);
29+
auto styleTransferHostObject =
30+
std::make_shared<ModelHostObject<StyleTransfer>>(
31+
styleTransferPtr, &runtime, jsCallInvoker);
32+
33+
return jsi::Object::createFromHostObject(runtime,
34+
styleTransferHostObject);
35+
} catch (const std::runtime_error &e) {
36+
// This catch should be merged with the next one
37+
// (std::runtime_error inherits from std::exception) HOWEVER react
38+
// native has broken RTTI which breaks proper exception type
39+
// checking. Remove when the following change is present in our
40+
// version:
41+
// https://github.com/facebook/react-native/commit/3132cc88dd46f95898a756456bebeeb6c248f20e
42+
throw jsi::JSError(runtime, e.what());
43+
} catch (const std::exception &e) {
44+
throw jsi::JSError(runtime, e.what());
45+
} catch (...) {
46+
throw jsi::JSError(runtime, "Unknown error");
47+
}
48+
});
49+
}
50+
51+
void RnExecutorchInstaller::injectJSIBindings(
52+
jsi::Runtime *jsiRuntime,
53+
const std::shared_ptr<react::CallInvoker> &jsCallInvoker,
54+
FetchUrlFunc_t fetchDataFromUrl) {
55+
fetchUrlFunc = fetchDataFromUrl;
56+
57+
jsiRuntime->global().setProperty(
58+
*jsiRuntime, "loadStyleTransfer",
59+
loadStyleTransfer(jsiRuntime, jsCallInvoker));
60+
}
61+
} // namespace rnexecutorch

common/RnExecutorchInstaller.h renamed to common/rnexecutorch/RnExecutorchInstaller.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,23 +4,26 @@
44
#include <string>
55
#include <thread>
66

7+
#include <ReactCommon/CallInvoker.h>
78
#include <jsi/jsi.h>
89

9-
#include "jsi/JsiPromise.h"
10-
1110
namespace rnexecutorch {
1211

12+
using FetchUrlFunc_t = std::function<std::vector<std::byte>(std::string)>;
13+
1314
using namespace facebook;
1415

1516
class RnExecutorchInstaller {
1617
public:
1718
static void
1819
injectJSIBindings(jsi::Runtime *jsiRuntime,
19-
const std::shared_ptr<react::CallInvoker> &jsCallInvoker) {
20-
// Install JSI methods here
21-
}
20+
const std::shared_ptr<react::CallInvoker> &jsCallInvoker,
21+
FetchUrlFunc_t fetchDataFromUrl);
2222

2323
private:
24+
static jsi::Function
25+
loadStyleTransfer(jsi::Runtime *jsiRuntime,
26+
const std::shared_ptr<react::CallInvoker> &jsCallInvoker);
2427
};
2528

2629
} // namespace rnexecutorch
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
#pragma once
2+
3+
#include <chrono>
4+
#include <string>
5+
6+
namespace rnexecutorch::fileutils {
7+
8+
inline std::string getTimeID() {
9+
return std::to_string(std::chrono::duration_cast<std::chrono::milliseconds>(
10+
std::chrono::system_clock::now().time_since_epoch())
11+
.count());
12+
}
13+
14+
} // namespace rnexecutorch::fileutils
Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
#include "ImageProcessing.h"
2+
3+
#include <chrono>
4+
#include <filesystem>
5+
6+
#include <ada/ada.h>
7+
8+
#include <rnexecutorch/Log.h>
9+
#include <rnexecutorch/RnExecutorchInstaller.h>
10+
#include <rnexecutorch/data_processing/FileUtils.h>
11+
#include <rnexecutorch/data_processing/base64.h>
12+
13+
namespace rnexecutorch {
14+
// This is defined in RnExecutorchInstaller.cpp. This function fetches data
15+
// from a url address. It is implemented in Kotlin/ObjectiveC++ and then bound
16+
// to this variable. It's done to not handle SSL intricacies manually, as it is
17+
// done automagically in ObjC++/Kotlin.
18+
extern FetchUrlFunc_t fetchUrlFunc;
19+
namespace imageprocessing {
20+
std::vector<float> colorMatToVector(const cv::Mat &mat) {
21+
return colorMatToVector(mat, cv::Scalar(0.0, 0.0, 0.0),
22+
cv::Scalar(1.0, 1.0, 1.0));
23+
}
24+
25+
std::vector<float> colorMatToVector(const cv::Mat &mat, cv::Scalar mean,
26+
cv::Scalar variance) {
27+
int pixelCount = mat.cols * mat.rows;
28+
std::vector<float> v(pixelCount * 3);
29+
30+
for (int i = 0; i < pixelCount; i++) {
31+
int row = i / mat.cols;
32+
int col = i % mat.cols;
33+
cv::Vec3b pixel = mat.at<cv::Vec3b>(row, col);
34+
v[0 * pixelCount + i] =
35+
(pixel[0] - mean[0] * 255.0) / (variance[0] * 255.0);
36+
v[1 * pixelCount + i] =
37+
(pixel[1] - mean[1] * 255.0) / (variance[1] * 255.0);
38+
v[2 * pixelCount + i] =
39+
(pixel[2] - mean[2] * 255.0) / (variance[2] * 255.0);
40+
}
41+
42+
return v;
43+
}
44+
45+
cv::Mat bufferToColorMat(const std::span<const float> &buffer,
46+
cv::Size matSize) {
47+
cv::Mat mat(matSize, CV_8UC3);
48+
49+
int pixelCount = matSize.width * matSize.height;
50+
for (int i = 0; i < pixelCount; i++) {
51+
int row = i / matSize.width;
52+
int col = i % matSize.width;
53+
54+
float r = buffer[0 * pixelCount + i];
55+
float g = buffer[1 * pixelCount + i];
56+
float b = buffer[2 * pixelCount + i];
57+
58+
cv::Vec3b color(static_cast<uchar>(b * 255), static_cast<uchar>(g * 255),
59+
static_cast<uchar>(r * 255));
60+
mat.at<cv::Vec3b>(row, col) = color;
61+
}
62+
63+
return mat;
64+
}
65+
66+
std::string saveToTempFile(const cv::Mat &image) {
67+
std::string filename = "rn_executorch_" + fileutils::getTimeID() + ".png";
68+
69+
std::filesystem::path tempDir = std::filesystem::temp_directory_path();
70+
std::filesystem::path filePath = tempDir / filename;
71+
72+
if (!cv::imwrite(filePath.string(), image)) {
73+
throw std::runtime_error("Failed to save the image: " + filePath.string());
74+
}
75+
76+
return "file://" + filePath.string();
77+
}
78+
79+
cv::Mat readImage(const std::string &imageURI) {
80+
cv::Mat image;
81+
82+
if (imageURI.starts_with("data")) {
83+
// base64
84+
std::stringstream uriStream(imageURI);
85+
std::string stringData;
86+
std::size_t segmentIndex{0};
87+
while (std::getline(uriStream, stringData, ',')) {
88+
++segmentIndex;
89+
}
90+
if (segmentIndex != 1) {
91+
throw std::runtime_error("Read image error: invalid base64 URI");
92+
}
93+
auto data = base64_decode(stringData);
94+
cv::Mat encodedData(1, data.size(), CV_8UC1, (void *)data.data());
95+
image = cv::imdecode(encodedData, cv::IMREAD_COLOR);
96+
} else if (imageURI.starts_with("file")) {
97+
// local file
98+
auto url = ada::parse(imageURI);
99+
image = cv::imread(std::string{url->get_pathname()}, cv::IMREAD_COLOR);
100+
} else {
101+
// remote file
102+
std::vector<std::byte> imageData = fetchUrlFunc(imageURI);
103+
image = cv::imdecode(
104+
cv::Mat(1, imageData.size(), CV_8UC1, (void *)imageData.data()),
105+
cv::IMREAD_COLOR);
106+
}
107+
108+
if (image.empty()) {
109+
throw std::runtime_error("Read image error: invalid argument");
110+
}
111+
112+
return image;
113+
}
114+
115+
TensorPtr getTensorFromMatrix(const std::vector<int32_t> &sizes,
116+
const cv::Mat &matrix) {
117+
std::vector<float> inputVector = colorMatToVector(matrix);
118+
return executorch::extension::make_tensor_ptr(sizes, inputVector);
119+
}
120+
121+
cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor) {
122+
auto resultData = static_cast<const float *>(tensor.const_data_ptr());
123+
return bufferToColorMat(std::span<const float>(resultData, tensor.numel()),
124+
size);
125+
}
126+
} // namespace imageprocessing
127+
} // namespace rnexecutorch
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
#pragma once
2+
3+
#include <executorch/extension/tensor/tensor.h>
4+
#include <executorch/extension/tensor/tensor_ptr.h>
5+
#include <opencv2/opencv.hpp>
6+
#include <span>
7+
#include <string>
8+
#include <vector>
9+
10+
namespace rnexecutorch::imageprocessing {
11+
using executorch::aten::Tensor;
12+
using executorch::extension::TensorPtr;
13+
14+
/// @brief Convert a OpenCV matrix to channel-first vector representation
15+
std::vector<float> colorMatToVector(const cv::Mat &mat, cv::Scalar mean,
16+
cv::Scalar variance);
17+
/// @brief Convert a OpenCV matrix to channel-first vector representation
18+
std::vector<float> colorMatToVector(const cv::Mat &mat);
19+
/// @brief Convert a channel-first representation of an RGB image to OpenCV
20+
/// matrix
21+
cv::Mat bufferToColorMat(const std::span<const float> &buffer,
22+
cv::Size matSize);
23+
std::string saveToTempFile(const cv::Mat &image);
24+
cv::Mat readImage(const std::string &imageURI);
25+
TensorPtr getTensorFromMatrix(const std::vector<int32_t> &sizes,
26+
const cv::Mat &mat);
27+
cv::Mat getMatrixFromTensor(cv::Size size, const Tensor &tensor);
28+
29+
} // namespace rnexecutorch::imageprocessing

0 commit comments

Comments
 (0)