diff --git a/pyproject.toml b/pyproject.toml index de6988e..0b146d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,7 @@ classifiers = [ ] readme = "README.md" requires-python = ">=3.8" -dependencies = ["numpy", "websockets"] +dependencies = ["numpy", "websockets>=14.0"] [tool.setuptools_scm] version_file = "webgpu/_version.py" diff --git a/webgpu/export/screenshot.py b/webgpu/export/screenshot.py index fdf4b70..cad8cc8 100644 --- a/webgpu/export/screenshot.py +++ b/webgpu/export/screenshot.py @@ -16,6 +16,7 @@ import sys import os import base64 +import shutil import subprocess import tempfile import threading @@ -39,14 +40,16 @@ def main(): os.environ['DISPLAY'] = f':{disp_num}' os.environ.pop('WAYLAND_DISPLAY', None) + tmpdir = Path(tempfile.mkdtemp(prefix="webgpu_ss_")) try: - _run_worker() + _run_worker(tmpdir) finally: + shutil.rmtree(tmpdir, ignore_errors=True) xvfb_proc.terminate() xvfb_proc.wait() -def _run_worker(): +def _run_worker(tmpdir): from playwright.sync_api import sync_playwright ARGS = [ @@ -69,8 +72,6 @@ def _run_worker(): engine_js += "\nif (typeof window !== 'undefined') { window.RenderEngine = RenderEngine; }\n" # Start HTTP server for serving pages to Chrome - tmpdir = Path(tempfile.mkdtemp(prefix="webgpu_ss_")) - class Quiet(SimpleHTTPRequestHandler): def __init__(self, *a, **kw): super().__init__(*a, directory=str(tmpdir), **kw) diff --git a/webgpu/link/base.py b/webgpu/link/base.py index e163109..d50f81f 100644 --- a/webgpu/link/base.py +++ b/webgpu/link/base.py @@ -388,71 +388,6 @@ def _get_obj(self, data): obj = obj[data["key"]] return obj - async def _on_message_async(self, message: str | memoryview | bytes): - data, buffers = _unpack_message(message) - obj = None - try: - msg_type = data.get("type", None) - request_id = data.get("request_id", None) - - response = None - - match msg_type: - case "response": - event, key = self._requests[request_id] - self._requests[request_id] = self._load_data(data.get("value", None), buffers) - if key and data.get("cache", False): - self._cache_add(key, data.get("value", None)) - - if isinstance(event, asyncio.Future): - event.set_result(self._requests[request_id]) - else: - event.set() - return - - case "call": - func = obj = self._get_obj(data) - args = self._load_data(data["args"], buffers) - response = func(*args) - try: - response = await response - except TypeError: - pass - except Exception as e: - print("error in call", type(e), str(e)) - - case "get": - response = obj = self._get_obj(data) - - case "get_keys": - response = [] - - case "set": - prop = data.pop("prop", None) - key = data.pop("key", None) - obj = self._get_obj(data) - if prop is not None: - obj.__setattr__(prop, data["value"]) - elif key is not None: - obj[key] = self._load_data(data["value"], buffers) - - case "release_batch": - for id_ in data["ids"]: - self._objects.pop(id_, None) - - case _: - print("unknown message type", msg_type) - - if request_id is not None: - self._send_response(request_id, response) - except Exception as e: - import sys - import traceback - - print("error in on_message", data, obj, type(e), str(e), file=sys.stderr) - if not isinstance(e, str): - traceback.print_exception(*sys.exc_info(), file=sys.stderr) - def _on_message(self, message: str): data, buffers = _unpack_message(message) try: @@ -682,7 +617,6 @@ async def handle_callbacks(): print("error in callback", type(e), str(e)) try: - self._callback_loop = asyncio.new_event_loop() asyncio.set_event_loop(self._callback_loop) self._callback_task = self._callback_loop.create_task(handle_callbacks()) try: diff --git a/webgpu/link/link.js b/webgpu/link/link.js index 3b6bb32..488de93 100644 --- a/webgpu/link/link.js +++ b/webgpu/link/link.js @@ -1,5 +1,7 @@ /* eslint-disable */ +const MAX_MESSAGE_SIZE = 100 * 1024 * 1024; + function serializeEvent(event) { try { event.preventDefault(); @@ -403,11 +405,40 @@ class CrossLink { const prefixLen = 4 + jsonMsg.byteLength; const size = 4 + jsonMsg.byteLength + offset; var msg = new Uint8Array(size); - msg.set(new Uint32Array([jsonMsg.byteLength]), 0); + new DataView(msg.buffer).setUint32(0, jsonMsg.byteLength, true); msg.set(jsonMsg, 4); for (var bufferIndex = 0; bufferIndex < buffers.length; bufferIndex++) msg.set(buffers[bufferIndex], prefixLen + buffer_offsets[bufferIndex]); + this._sendFrame(msg.buffer, request_id); + } + } + + _sendFrame(frame, parent_request_id) { + const total = frame.byteLength; + if (total <= MAX_MESSAGE_SIZE) { + this.connection.send(frame); + return; + } + const n_chunks = Math.ceil(total / MAX_MESSAGE_SIZE); + for (let i = 0; i < n_chunks; i++) { + const offset = i * MAX_MESSAGE_SIZE; + const chunk = new Uint8Array(frame, offset, Math.min(MAX_MESSAGE_SIZE, total - offset)); + const meta = { + type: 'chunk', + parent_request_id, + chunk_id: i, + n_chunks, + offset, + size: chunk.byteLength, + total_size: total, + buffer_offsets: [0, chunk.byteLength], + }; + const json = new TextEncoder().encode(JSON.stringify(meta)); + const msg = new Uint8Array(4 + json.byteLength + chunk.byteLength); + new DataView(msg.buffer).setUint32(0, json.byteLength, true); + msg.set(json, 4); + msg.set(chunk, 4 + json.byteLength); this.connection.send(msg.buffer); } } diff --git a/webgpu/link/websocket.py b/webgpu/link/websocket.py index b875a72..f967d05 100644 --- a/webgpu/link/websocket.py +++ b/webgpu/link/websocket.py @@ -19,7 +19,7 @@ from websockets.http11 import Response from websockets.datastructures import Headers -from .base import LinkBaseAsync +from .base import LinkBaseAsync, _unpack_message class WebsocketLinkBase(LinkBaseAsync): @@ -35,7 +35,6 @@ def __init__(self): self._event_is_connected = threading.Event() self._event_is_running = threading.Event() self._start_handling_messages = threading.Event() - self._send_loop = asyncio.new_event_loop() self._websocket_thread = threading.Thread(target=self._connect, daemon=True) self._websocket_thread.start() @@ -65,6 +64,7 @@ def __init__(self): self._port = 8700 self._auth_token = secrets.token_urlsafe(32) self._executor = ThreadPoolExecutor(max_workers=8) + self._chunk_buffers = {} self._stop = None super().__init__() @@ -80,22 +80,50 @@ def _check_auth(self, connection, request): """Reject WebSocket connections that don't carry a valid token.""" params = parse_qs(urlparse(request.path).query) tokens = params.get("token", []) - if not tokens or tokens[0] != self._auth_token: + if not tokens or not secrets.compare_digest(tokens[0], self._auth_token): return Response(403, "Forbidden", Headers()) return None @staticmethod - def _is_response(message): - """Quick check if a message is a response (cheap, avoids full deserialization).""" - if isinstance(message, (memoryview, bytes)): - # Binary message: JSON metadata starts at byte 4 - try: + def _message_type(message): + """Return the top-level message type, parsing only the JSON header + (not buffer payloads). Returns None on malformed input.""" + try: + if isinstance(message, (memoryview, bytes)): prefix_size = 4 + int.from_bytes(message[:4], byteorder="little") - header = message[4:prefix_size] - return b'"type":"response"' in bytes(header) or b'"type": "response"' in bytes(header) - except Exception: - return False - return '"type":"response"' in message or '"type": "response"' in message + header = json.loads(bytes(message[4:prefix_size]).decode("utf-8")) + else: + header = json.loads(message) + return header.get("type") if isinstance(header, dict) else None + except Exception: + return None + + def _is_response(self, message): + return self._message_type(message) == "response" + + def _is_chunk(self, message): + return isinstance(message, (memoryview, bytes)) and self._message_type(message) == "chunk" + + def _reassemble_chunk(self, message): + data, buffers = _unpack_message(message) + pid = data["parent_request_id"] + buf = self._chunk_buffers.get(pid) + if buf is None: + buf = bytearray(data["total_size"]) + self._chunk_buffers[pid] = buf + chunk = buffers[0] + offset = data["offset"] + buf[offset : offset + len(chunk)] = chunk + if data["chunk_id"] + 1 == data["n_chunks"]: + del self._chunk_buffers[pid] + return bytes(buf) + return None + + def _dispatch(self, message): + if self._is_response(message): + self._on_message(message) + else: + self._executor.submit(self._on_message, message) async def _websocket_handler(self, websocket, path=""): if self._connection is not None: @@ -107,13 +135,17 @@ async def _websocket_handler(self, websocket, path=""): async for message in websocket: # Handle responses inline to avoid deadlock: if all executor # threads are blocked waiting for JS responses, queued response - # messages would never be processed. - if self._is_response(message): - self._on_message(message) + # messages would never be processed. Chunks are reassembled + # inline (single-threaded, ordered) then dispatched. + if self._is_chunk(message): + full = self._reassemble_chunk(message) + if full is not None: + self._dispatch(full) else: - self._executor.submit(self._on_message, message) + self._dispatch(message) finally: self._connection = None + self._chunk_buffers.clear() def _connect(self): async def start_websocket(): diff --git a/webgpu/scene.py b/webgpu/scene.py index 2f08ab4..1f73a8c 100644 --- a/webgpu/scene.py +++ b/webgpu/scene.py @@ -584,8 +584,9 @@ def _on_camera_changed(self): def _on_resize(self): """Called on canvas resize. Update camera uniforms (aspect ratio) and re-render.""" - self._select_buffer_valid = False - self.options.update_buffers() + with self._render_mutex: + self._select_buffer_valid = False + self.options.update_buffers() if self._js_engine is not None: try: self._js_engine.handleResize()