forked from wcandillon/react-native-webgpu
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathAndroidPlatformContext.h
More file actions
206 lines (173 loc) · 6.78 KB
/
Copy pathAndroidPlatformContext.h
File metadata and controls
206 lines (173 loc) · 6.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
#pragma once

#include <android/bitmap.h>
#include <jni.h>

#include <cstdint>
#include <cstring>
#include <functional>
#include <limits>
#include <memory>
#include <span>
#include <stdexcept>
#include <string>
#include <thread>
#include <utility>
#include <vector>

#include "webgpu/webgpu_cpp.h"

#include "PlatformContext.h"
#include "RNWebGPUManager.h"
namespace rnwgpu {
namespace jsi = facebook::jsi;
namespace jni = facebook::jni;
class AndroidPlatformContext : public PlatformContext {
private:
jobject _blobModule;
std::vector<uint8_t> resolveBlob(JNIEnv *env, const std::string &blobId,
double offset, double size) {
if (!_blobModule) {
throw std::runtime_error("BlobModule instance is null");
}
jclass blobModuleClass = env->GetObjectClass(_blobModule);
if (!blobModuleClass) {
throw std::runtime_error("Couldn't find BlobModule class");
}
jmethodID resolveMethod = env->GetMethodID(blobModuleClass, "resolve",
"(Ljava/lang/String;II)[B");
env->DeleteLocalRef(blobModuleClass);
if (!resolveMethod) {
throw std::runtime_error("Couldn't find resolve method in BlobModule");
}
jstring jBlobId = env->NewStringUTF(blobId.c_str());
jbyteArray blobData = (jbyteArray)env->CallObjectMethod(
_blobModule, resolveMethod, jBlobId, static_cast<jint>(offset),
static_cast<jint>(size));
env->DeleteLocalRef(jBlobId);
if (!blobData) {
throw std::runtime_error("Couldn't retrieve blob data");
}
jsize len = env->GetArrayLength(blobData);
std::vector<uint8_t> data(len);
env->GetByteArrayRegion(blobData, 0, len,
reinterpret_cast<jbyte *>(data.data()));
env->DeleteLocalRef(blobData);
return data;
}
public:
explicit AndroidPlatformContext(jobject blobModule)
: _blobModule(blobModule) {}
~AndroidPlatformContext() {
if (_blobModule) {
JNIEnv *env = facebook::jni::Environment::current();
env->DeleteGlobalRef(_blobModule);
_blobModule = nullptr;
}
}
wgpu::Surface makeSurface(wgpu::Instance instance, void *window, int width,
int height) override {
wgpu::SurfaceSourceAndroidNativeWindow androidSurfaceDesc;
androidSurfaceDesc.window = reinterpret_cast<ANativeWindow *>(window);
wgpu::SurfaceDescriptor surfaceDescriptor;
surfaceDescriptor.nextInChain = &androidSurfaceDesc;
return instance.CreateSurface(&surfaceDescriptor);
}
ImageData createImageBitmap(std::string blobId, double offset,
double size) override {
jni::Environment::ensureCurrentThreadIsAttached();
JNIEnv *env = facebook::jni::Environment::current();
if (!env) {
throw std::runtime_error("Couldn't get JNI environment");
}
auto data = resolveBlob(env, blobId, offset, size);
return createImageBitmapFromData(data);
}
void createImageBitmapAsync(
std::string blobId, double offset, double size,
std::function<void(ImageData)> onSuccess,
std::function<void(std::string)> onError) override {
std::thread([this, blobId = std::move(blobId), offset, size,
onSuccess = std::move(onSuccess),
onError = std::move(onError)]() {
jni::Environment::ensureCurrentThreadIsAttached();
try {
JNIEnv *env = facebook::jni::Environment::current();
if (!env) {
throw std::runtime_error("Couldn't get JNI environment");
}
auto data = resolveBlob(env, blobId, offset, size);
auto result = createImageBitmapFromData(data);
onSuccess(std::move(result));
} catch (const std::exception &e) {
onError(e.what());
}
}).detach();
}
ImageData createImageBitmapFromData(std::span<const uint8_t> data) override {
jni::Environment::ensureCurrentThreadIsAttached();
JNIEnv *env = facebook::jni::Environment::current();
if (!env) {
throw std::runtime_error("Couldn't get JNI environment");
}
// Create jbyteArray from the raw bytes
jbyteArray byteArray = env->NewByteArray(static_cast<jsize>(data.size()));
if (!byteArray) {
throw std::runtime_error("Couldn't allocate byte array");
}
env->SetByteArrayRegion(byteArray, 0, static_cast<jsize>(data.size()),
reinterpret_cast<const jbyte *>(data.data()));
// Decode via BitmapFactory
jclass bitmapFactoryClass =
env->FindClass("android/graphics/BitmapFactory");
if (!bitmapFactoryClass) {
env->DeleteLocalRef(byteArray);
throw std::runtime_error("Couldn't find BitmapFactory class");
}
jmethodID decodeByteArrayMethod =
env->GetStaticMethodID(bitmapFactoryClass, "decodeByteArray",
"([BII)Landroid/graphics/Bitmap;");
if (!decodeByteArrayMethod) {
env->DeleteLocalRef(byteArray);
env->DeleteLocalRef(bitmapFactoryClass);
throw std::runtime_error("Couldn't find decodeByteArray method");
}
jint length = static_cast<jint>(data.size());
jobject bitmap = env->CallStaticObjectMethod(
bitmapFactoryClass, decodeByteArrayMethod, byteArray, 0, length);
env->DeleteLocalRef(bitmapFactoryClass);
if (!bitmap) {
env->DeleteLocalRef(byteArray);
throw std::runtime_error("Couldn't decode image");
}
AndroidBitmapInfo bitmapInfo;
if (AndroidBitmap_getInfo(env, bitmap, &bitmapInfo) !=
ANDROID_BITMAP_RESULT_SUCCESS) {
env->DeleteLocalRef(byteArray);
env->DeleteLocalRef(bitmap);
throw std::runtime_error("Couldn't get bitmap info");
}
void *bitmapPixels;
if (AndroidBitmap_lockPixels(env, bitmap, &bitmapPixels) !=
ANDROID_BITMAP_RESULT_SUCCESS) {
env->DeleteLocalRef(byteArray);
env->DeleteLocalRef(bitmap);
throw std::runtime_error("Couldn't lock bitmap pixels");
}
ImageData result;
result.width = static_cast<int>(bitmapInfo.width);
result.height = static_cast<int>(bitmapInfo.height);
result.data.resize(bitmapInfo.height * bitmapInfo.stride);
memcpy(result.data.data(), bitmapPixels, result.data.size());
AndroidBitmap_unlockPixels(env, bitmap);
env->DeleteLocalRef(byteArray);
env->DeleteLocalRef(bitmap);
return result;
}
void createImageBitmapFromDataAsync(
std::span<const uint8_t> data, std::function<void(ImageData)> onSuccess,
std::function<void(std::string)> onError) override {
std::thread([this, ownedData = std::vector<uint8_t>(data.begin(), data.end()),
onSuccess = std::move(onSuccess),
onError = std::move(onError)]() mutable {
jni::Environment::ensureCurrentThreadIsAttached();
try {
auto result = createImageBitmapFromData(ownedData);
onSuccess(std::move(result));
} catch (const std::exception &e) {
onError(e.what());
}
}).detach();
}
};
} // namespace rnwgpu