@@ -32,6 +32,21 @@ def _pack_message(data: dict) -> tuple[dict, bytes | str]:
3232 return data , bdata
3333
3434
35+ def _unpack_message (msg : str | memoryview | bytes ) -> tuple [dict , list ]:
36+ if isinstance (msg , (memoryview , bytes )):
37+ prefix_size = 4 + int .from_bytes (msg [:4 ], byteorder = "little" )
38+ data = json .loads (bytes (msg [4 :prefix_size ]).decode ("utf-8" ))
39+ buffer_offsets = data ["buffer_offsets" ]
40+ buffers = []
41+
42+ for i in range (len (buffer_offsets ) - 1 ):
43+ offset = buffer_offsets [i ]
44+ size = buffer_offsets [i + 1 ] - offset
45+ buffers .append (bytes (msg [prefix_size + offset : prefix_size + offset + size ]))
46+ return data , buffers
47+ return json .loads (msg ), []
48+
49+
3550class LinkBase :
3651 _request_id : itertools .count
3752 _requests : dict
@@ -97,7 +112,7 @@ def __init__(self):
97112 self ._objects = {}
98113 self ._cache = {}
99114
100- next (self ._request_id ) # make sure first id is 1, in case 0 is interpreted as None
115+ next (self ._request_id ) # make sure first id is 1, in case 0 is interpreted as None
101116
102117 def _call_data (self , id , prop , args , ignore_result = False ):
103118 buffer = []
@@ -236,7 +251,7 @@ def _dump_data(self, data, buffer=None):
236251 self ._objects [id_ ] = data
237252 return {"__is_crosslink_type__" : True , "type" : "proxy" , "id" : id_ }
238253
239- def _load_data (self , data , buffers = None ):
254+ def _load_data (self , data , buffers = None ):
240255 """Parse the result of a message from the remote environment"""
241256 from .proxy import Proxy
242257
@@ -302,20 +317,8 @@ def _get_obj(self, data):
302317 obj = obj [data ["key" ]]
303318 return obj
304319
305- async def _on_message_async (self , message : str | memoryview ):
306- if isinstance (message , memoryview ):
307- prefix_size = 4 + int .from_bytes (message [:4 ], byteorder = "little" )
308- data = json .loads (bytes (message [4 :prefix_size ]).decode ("utf-8" ))
309- buffer_offsets = data ['buffer_offsets' ]
310- buffers = []
311-
312- for i in range (len (buffer_offsets )- 1 ):
313- offset = buffer_offsets [i ]
314- size = buffer_offsets [i + 1 ] - offset
315- buffers .append (bytes (message [prefix_size + offset :prefix_size + offset + size ]))
316- else :
317- data = json .loads (message )
318- buffers = []
320+ async def _on_message_async (self , message : str | memoryview | bytes ):
321+ data , buffers = _unpack_message (message )
319322 obj = None
320323 try :
321324 msg_type = data .get ("type" , None )
@@ -329,7 +332,7 @@ async def _on_message_async(self, message: str | memoryview):
329332 self ._requests [request_id ] = self ._load_data (data .get ("value" , None ), buffers )
330333 if key and data .get ("cache" , False ):
331334 self ._cache [key ] = self ._requests [request_id ]
332-
335+
333336 if isinstance (event , asyncio .Future ):
334337 event .set_result (self ._requests [request_id ])
335338 else :
@@ -376,7 +379,7 @@ async def _on_message_async(self, message: str | memoryview):
376379 traceback .print_exception (* sys .exc_info (), file = sys .stderr )
377380
378381 def _on_message (self , message : str ):
379- data = json . loads (message )
382+ data , buffers = _unpack_message (message )
380383 try :
381384 msg_type = data .get ("type" , None )
382385 request_id = data .get ("request_id" , None )
@@ -386,7 +389,7 @@ def _on_message(self, message: str):
386389 match msg_type :
387390 case "response" :
388391 event , key = self ._requests [request_id ]
389- response = self ._load_data (data .get ("value" , None ))
392+ response = self ._load_data (data .get ("value" , None ), buffers )
390393 self ._requests [request_id ] = response
391394 if data .get ("cache" , False ):
392395 self ._cache [key ] = response
@@ -395,7 +398,7 @@ def _on_message(self, message: str):
395398
396399 case "call" :
397400 func = self ._get_obj (data )
398- args = self ._load_data (data ["args" ])
401+ args = self ._load_data (data ["args" ], buffers )
399402 # print("call", func, args)
400403 response = func (* args )
401404
@@ -412,10 +415,10 @@ def _on_message(self, message: str):
412415 if prop is not None :
413416 obj .__setattr__ (prop , data ["value" ])
414417 elif key is not None :
415- obj [key ] = self ._load_data (data ["value" ])
418+ obj [key ] = self ._load_data (data ["value" ], buffers )
416419
417420 case _:
418- print ("unknown message type" , msg_type )
421+ print ("unknown message type" , msg_type , data , type ( message ) )
419422
420423 if request_id is not None :
421424 self ._send_response (request_id , response )
@@ -452,6 +455,7 @@ def _send_data(self, metadata, data, key=None):
452455 if type != "response" and request_id is not None :
453456 # from pyodide.ffi import run_sync
454457 import asyncio
458+
455459 event = asyncio .Future ()
456460 self ._requests [request_id ] = event , key
457461 # todo: this shouldn't be necessary
0 commit comments