Skip to content

Commit c7a0b9f

Browse files
committed
- Properly detect ArrayBuffer vs TypedArray
- Add unit tests for ArrayBuffer and TypedArray
1 parent 3cf200e commit c7a0b9f

2 files changed

Lines changed: 102 additions & 8 deletions

File tree

packages/webgpu/cpp/rnwgpu/api/RNWebGPU.h

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55

66
#include "NativeObject.h"
77

8-
#include "ArrayBuffer.h"
98
#include "Canvas.h"
109
#include "GPU.h"
1110
#include "GPUCanvasContext.h"
@@ -88,17 +87,39 @@ class RNWebGPU : public NativeObject<RNWebGPU> {
8887
auto platformContext = _platformContext;
8988
auto callInvoker = _callInvoker;
9089

91-
// Check if the argument is an ArrayBuffer or TypedArray
90+
// Check if the argument is an ArrayBuffer or ArrayBufferView
91+
// (TypedArray / DataView)
9292
if (args[0].isObject()) {
9393
auto obj = args[0].getObject(runtime);
94-
if (obj.isArrayBuffer(runtime) || obj.hasProperty(runtime, "buffer")) {
95-
auto arrayBuffer =
96-
JSIConverter<std::shared_ptr<ArrayBuffer>>::fromJSI(
97-
runtime, args[0], false);
94+
95+
const uint8_t *dataPtr = nullptr;
96+
size_t dataSize = 0;
97+
98+
if (obj.isArrayBuffer(runtime)) {
99+
// Plain ArrayBuffer — use the full buffer
100+
auto &ab = obj.getArrayBuffer(runtime);
101+
dataPtr = ab.data(runtime);
102+
dataSize = ab.size(runtime);
103+
} else if (obj.hasProperty(runtime, "buffer")) {
104+
// TypedArray or DataView — respect byteOffset/byteLength
105+
auto bufferVal = obj.getProperty(runtime, "buffer");
106+
if (bufferVal.isObject() &&
107+
bufferVal.getObject(runtime).isArrayBuffer(runtime)) {
108+
auto &ab =
109+
bufferVal.getObject(runtime).getArrayBuffer(runtime);
110+
auto byteOffset = static_cast<size_t>(
111+
obj.getProperty(runtime, "byteOffset").asNumber());
112+
auto byteLength = static_cast<size_t>(
113+
obj.getProperty(runtime, "byteLength").asNumber());
114+
dataPtr = ab.data(runtime) + byteOffset;
115+
dataSize = byteLength;
116+
}
117+
}
118+
119+
if (dataPtr != nullptr) {
98120
// Copy bytes on the JS thread — the ArrayBuffer pointer is into
99121
// JS-owned memory that can be GC'd
100-
std::vector<uint8_t> dataCopy(arrayBuffer->data(),
101-
arrayBuffer->data() + arrayBuffer->size());
122+
std::vector<uint8_t> dataCopy(dataPtr, dataPtr + dataSize);
102123

103124
return Promise::createPromise(
104125
runtime,

packages/webgpu/src/__tests__/ImageData.spec.ts

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import fs from "fs";
12
import path from "path";
23

34
import { checkImage, client, encodeImage, decodeImage } from "./setup";
@@ -23,4 +24,76 @@ describe("Image Bitmap", () => {
2324
const image = encodeImage(result);
2425
checkImage(image, "snapshots/ref.png");
2526
});
27+
// The following tests exercise the React Native ArrayBuffer/TypedArray
28+
// overload of createImageBitmap, which is not part of the standard web API.
29+
it("createImageBitmap from ArrayBuffer", async () => {
30+
if (client.OS === "web") {
31+
return;
32+
}
33+
const pngBytes = Array.from(
34+
fs.readFileSync(path.join(__dirname, "./assets/Di-3d.png")),
35+
);
36+
const expected = decodeImage(path.join(__dirname, "./assets/Di-3d.png"));
37+
const result = await client.eval(
38+
async ({ pngData }) => {
39+
const bytes = new Uint8Array(pngData);
40+
const bmp = await createImageBitmap(bytes.buffer);
41+
return { width: bmp.width, height: bmp.height };
42+
},
43+
{ pngData: pngBytes },
44+
);
45+
expect(result.width).toBe(expected.width);
46+
expect(result.height).toBe(expected.height);
47+
});
48+
it("createImageBitmap from Uint8Array", async () => {
49+
if (client.OS === "web") {
50+
return;
51+
}
52+
const pngBytes = Array.from(
53+
fs.readFileSync(path.join(__dirname, "./assets/Di-3d.png")),
54+
);
55+
const expected = decodeImage(path.join(__dirname, "./assets/Di-3d.png"));
56+
const result = await client.eval(
57+
async ({ pngData }) => {
58+
const bytes = new Uint8Array(pngData);
59+
const bmp = await createImageBitmap(bytes);
60+
return { width: bmp.width, height: bmp.height };
61+
},
62+
{ pngData: pngBytes },
63+
);
64+
expect(result.width).toBe(expected.width);
65+
expect(result.height).toBe(expected.height);
66+
});
67+
it("createImageBitmap from Uint8Array subarray (byteOffset/byteLength)", async () => {
68+
if (client.OS === "web") {
69+
return;
70+
}
71+
const pngBytes = Array.from(
72+
fs.readFileSync(path.join(__dirname, "./assets/Di-3d.png")),
73+
);
74+
const expected = decodeImage(path.join(__dirname, "./assets/Di-3d.png"));
75+
const result = await client.eval(
76+
async ({ pngData }) => {
77+
// Embed PNG bytes at an offset within a larger buffer
78+
const padding = 128;
79+
const totalLength = padding + pngData.length + padding;
80+
const largeBuffer = new ArrayBuffer(totalLength);
81+
const fullView = new Uint8Array(largeBuffer);
82+
// Fill with garbage bytes
83+
fullView.fill(0xff);
84+
// Copy PNG bytes into the middle
85+
const pngView = new Uint8Array(largeBuffer, padding, pngData.length);
86+
for (let i = 0; i < pngData.length; i++) {
87+
pngView[i] = pngData[i];
88+
}
89+
// createImageBitmap must respect byteOffset/byteLength of the view,
90+
// not use the full underlying ArrayBuffer (which has garbage padding)
91+
const bmp = await createImageBitmap(pngView);
92+
return { width: bmp.width, height: bmp.height };
93+
},
94+
{ pngData: pngBytes },
95+
);
96+
expect(result.width).toBe(expected.width);
97+
expect(result.height).toBe(expected.height);
98+
});
2699
});

0 commit comments

Comments
 (0)