Skip to content

Commit a39b623

Browse files
committed
Support receiving binary messages
1 parent f7dde13 commit a39b623

1 file changed

Lines changed: 21 additions & 8 deletions

File tree

webgpu/link/base.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -236,12 +236,10 @@ def _dump_data(self, data, buffer=None):
236236
self._objects[id_] = data
237237
return {"__is_crosslink_type__": True, "type": "proxy", "id": id_}
238238

239-
def _load_data(self, data):
239+
def _load_data(self, data, buffers = None):
240240
"""Parse the result of a message from the remote environment"""
241241
from .proxy import Proxy
242242

243-
# print("load data", data, type(data))
244-
245243
if isinstance(data, list):
246244
return [self._load_data(v) for v in data]
247245

@@ -261,6 +259,9 @@ def _load_data(self, data):
261259
if data["type"] == "bytes":
262260
return base64.b64decode(data["value"])
263261

262+
if data["type"] == "buffer":
263+
return buffers[data["index"]]
264+
264265
raise Exception(f"Unknown result type: {data}")
265266

266267
def expose(self, name: str, obj):
@@ -301,8 +302,20 @@ def _get_obj(self, data):
301302
obj = obj[data["key"]]
302303
return obj
303304

304-
async def _on_message_async(self, message: str):
305-
data = json.loads(message)
305+
async def _on_message_async(self, message: str | memoryview):
306+
if isinstance(message, memoryview):
307+
prefix_size = 4 + int.from_bytes(message[:4], byteorder="little")
308+
data = json.loads(bytes(message[4:prefix_size]).decode("utf-8"))
309+
buffer_offsets = data['buffer_offsets']
310+
buffers = []
311+
312+
for i in range(len(buffer_offsets)-1):
313+
offset = buffer_offsets[i]
314+
size = buffer_offsets[i+1] - offset
315+
buffers.append(bytes(message[prefix_size+offset:prefix_size+offset+size]))
316+
else:
317+
data = json.loads(message)
318+
buffers = []
306319
obj = None
307320
try:
308321
msg_type = data.get("type", None)
@@ -313,7 +326,7 @@ async def _on_message_async(self, message: str):
313326
match msg_type:
314327
case "response":
315328
event, key = self._requests[request_id]
316-
self._requests[request_id] = self._load_data(data.get("value", None))
329+
self._requests[request_id] = self._load_data(data.get("value", None), buffers)
317330
if key and data.get("cache", False):
318331
self._cache[key] = self._requests[request_id]
319332

@@ -325,7 +338,7 @@ async def _on_message_async(self, message: str):
325338

326339
case "call":
327340
func = obj = self._get_obj(data)
328-
args = self._load_data(data["args"])
341+
args = self._load_data(data["args"], buffers)
329342
response = func(*args)
330343
try:
331344
response = await response
@@ -347,7 +360,7 @@ async def _on_message_async(self, message: str):
347360
if prop is not None:
348361
obj.__setattr__(prop, data["value"])
349362
elif key is not None:
350-
obj[key] = self._load_data(data["value"])
363+
obj[key] = self._load_data(data["value"], buffers)
351364

352365
case _:
353366
print("unknown message type", msg_type)

0 commit comments

Comments
 (0)