Skip to content

Commit 6ce3bdf

Browse files
committed
Fix incoming binary messages
1 parent b0560ab commit 6ce3bdf

1 file changed

Lines changed: 26 additions & 22 deletions

File tree

webgpu/link/base.py

Lines changed: 26 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ def _pack_message(data: dict) -> tuple[dict, bytes | str]:
3232
return data, bdata
3333

3434

35+
def _unpack_message(msg: str | memoryview | bytes) -> tuple[dict, list]:
36+
if isinstance(msg, (memoryview, bytes)):
37+
prefix_size = 4 + int.from_bytes(msg[:4], byteorder="little")
38+
data = json.loads(bytes(msg[4:prefix_size]).decode("utf-8"))
39+
buffer_offsets = data["buffer_offsets"]
40+
buffers = []
41+
42+
for i in range(len(buffer_offsets) - 1):
43+
offset = buffer_offsets[i]
44+
size = buffer_offsets[i + 1] - offset
45+
buffers.append(bytes(msg[prefix_size + offset : prefix_size + offset + size]))
46+
return data, buffers
47+
return json.loads(msg), []
48+
49+
3550
class LinkBase:
3651
_request_id: itertools.count
3752
_requests: dict
@@ -97,7 +112,7 @@ def __init__(self):
97112
self._objects = {}
98113
self._cache = {}
99114

100-
next(self._request_id) # make sure first id is 1, in case 0 is interpreted as None
115+
next(self._request_id) # make sure first id is 1, in case 0 is interpreted as None
101116

102117
def _call_data(self, id, prop, args, ignore_result=False):
103118
buffer = []
@@ -236,7 +251,7 @@ def _dump_data(self, data, buffer=None):
236251
self._objects[id_] = data
237252
return {"__is_crosslink_type__": True, "type": "proxy", "id": id_}
238253

239-
def _load_data(self, data, buffers = None):
254+
def _load_data(self, data, buffers=None):
240255
"""Parse the result of a message from the remote environment"""
241256
from .proxy import Proxy
242257

@@ -302,20 +317,8 @@ def _get_obj(self, data):
302317
obj = obj[data["key"]]
303318
return obj
304319

305-
async def _on_message_async(self, message: str | memoryview):
306-
if isinstance(message, memoryview):
307-
prefix_size = 4 + int.from_bytes(message[:4], byteorder="little")
308-
data = json.loads(bytes(message[4:prefix_size]).decode("utf-8"))
309-
buffer_offsets = data['buffer_offsets']
310-
buffers = []
311-
312-
for i in range(len(buffer_offsets)-1):
313-
offset = buffer_offsets[i]
314-
size = buffer_offsets[i+1] - offset
315-
buffers.append(bytes(message[prefix_size+offset:prefix_size+offset+size]))
316-
else:
317-
data = json.loads(message)
318-
buffers = []
320+
async def _on_message_async(self, message: str | memoryview | bytes):
321+
data, buffers = _unpack_message(message)
319322
obj = None
320323
try:
321324
msg_type = data.get("type", None)
@@ -329,7 +332,7 @@ async def _on_message_async(self, message: str | memoryview):
329332
self._requests[request_id] = self._load_data(data.get("value", None), buffers)
330333
if key and data.get("cache", False):
331334
self._cache[key] = self._requests[request_id]
332-
335+
333336
if isinstance(event, asyncio.Future):
334337
event.set_result(self._requests[request_id])
335338
else:
@@ -376,7 +379,7 @@ async def _on_message_async(self, message: str | memoryview):
376379
traceback.print_exception(*sys.exc_info(), file=sys.stderr)
377380

378381
def _on_message(self, message: str):
379-
data = json.loads(message)
382+
data, buffers = _unpack_message(message)
380383
try:
381384
msg_type = data.get("type", None)
382385
request_id = data.get("request_id", None)
@@ -386,7 +389,7 @@ def _on_message(self, message: str):
386389
match msg_type:
387390
case "response":
388391
event, key = self._requests[request_id]
389-
response = self._load_data(data.get("value", None))
392+
response = self._load_data(data.get("value", None), buffers)
390393
self._requests[request_id] = response
391394
if data.get("cache", False):
392395
self._cache[key] = response
@@ -395,7 +398,7 @@ def _on_message(self, message: str):
395398

396399
case "call":
397400
func = self._get_obj(data)
398-
args = self._load_data(data["args"])
401+
args = self._load_data(data["args"], buffers)
399402
# print("call", func, args)
400403
response = func(*args)
401404

@@ -412,10 +415,10 @@ def _on_message(self, message: str):
412415
if prop is not None:
413416
obj.__setattr__(prop, data["value"])
414417
elif key is not None:
415-
obj[key] = self._load_data(data["value"])
418+
obj[key] = self._load_data(data["value"], buffers)
416419

417420
case _:
418-
print("unknown message type", msg_type)
421+
print("unknown message type", msg_type, data, type(message))
419422

420423
if request_id is not None:
421424
self._send_response(request_id, response)
@@ -452,6 +455,7 @@ def _send_data(self, metadata, data, key=None):
452455
if type != "response" and request_id is not None:
453456
# from pyodide.ffi import run_sync
454457
import asyncio
458+
455459
event = asyncio.Future()
456460
self._requests[request_id] = event, key
457461
# todo: this shouldn't be necessary

0 commit comments

Comments
 (0)