Skip to content

Commit 6dd750d

Browse files
authored
Merge pull request #8 from CERBSim/sec_fix
Sec fix
2 parents 66333d4 + db9b1ec commit 6dd750d

6 files changed

Lines changed: 90 additions & 91 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ classifiers = [
1818
]
1919
readme = "README.md"
2020
requires-python = ">=3.8"
21-
dependencies = ["numpy", "websockets"]
21+
dependencies = ["numpy", "websockets>=14.0"]
2222

2323
[tool.setuptools_scm]
2424
version_file = "webgpu/_version.py"

webgpu/export/screenshot.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import sys
1717
import os
1818
import base64
19+
import shutil
1920
import subprocess
2021
import tempfile
2122
import threading
@@ -39,14 +40,16 @@ def main():
3940
os.environ['DISPLAY'] = f':{disp_num}'
4041
os.environ.pop('WAYLAND_DISPLAY', None)
4142

43+
tmpdir = Path(tempfile.mkdtemp(prefix="webgpu_ss_"))
4244
try:
43-
_run_worker()
45+
_run_worker(tmpdir)
4446
finally:
47+
shutil.rmtree(tmpdir, ignore_errors=True)
4548
xvfb_proc.terminate()
4649
xvfb_proc.wait()
4750

4851

49-
def _run_worker():
52+
def _run_worker(tmpdir):
5053
from playwright.sync_api import sync_playwright
5154

5255
ARGS = [
@@ -69,8 +72,6 @@ def _run_worker():
6972
engine_js += "\nif (typeof window !== 'undefined') { window.RenderEngine = RenderEngine; }\n"
7073

7174
# Start HTTP server for serving pages to Chrome
72-
tmpdir = Path(tempfile.mkdtemp(prefix="webgpu_ss_"))
73-
7475
class Quiet(SimpleHTTPRequestHandler):
7576
def __init__(self, *a, **kw):
7677
super().__init__(*a, directory=str(tmpdir), **kw)

webgpu/link/base.py

Lines changed: 0 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -388,71 +388,6 @@ def _get_obj(self, data):
388388
obj = obj[data["key"]]
389389
return obj
390390

391-
async def _on_message_async(self, message: str | memoryview | bytes):
392-
data, buffers = _unpack_message(message)
393-
obj = None
394-
try:
395-
msg_type = data.get("type", None)
396-
request_id = data.get("request_id", None)
397-
398-
response = None
399-
400-
match msg_type:
401-
case "response":
402-
event, key = self._requests[request_id]
403-
self._requests[request_id] = self._load_data(data.get("value", None), buffers)
404-
if key and data.get("cache", False):
405-
self._cache_add(key, data.get("value", None))
406-
407-
if isinstance(event, asyncio.Future):
408-
event.set_result(self._requests[request_id])
409-
else:
410-
event.set()
411-
return
412-
413-
case "call":
414-
func = obj = self._get_obj(data)
415-
args = self._load_data(data["args"], buffers)
416-
response = func(*args)
417-
try:
418-
response = await response
419-
except TypeError:
420-
pass
421-
except Exception as e:
422-
print("error in call", type(e), str(e))
423-
424-
case "get":
425-
response = obj = self._get_obj(data)
426-
427-
case "get_keys":
428-
response = []
429-
430-
case "set":
431-
prop = data.pop("prop", None)
432-
key = data.pop("key", None)
433-
obj = self._get_obj(data)
434-
if prop is not None:
435-
obj.__setattr__(prop, data["value"])
436-
elif key is not None:
437-
obj[key] = self._load_data(data["value"], buffers)
438-
439-
case "release_batch":
440-
for id_ in data["ids"]:
441-
self._objects.pop(id_, None)
442-
443-
case _:
444-
print("unknown message type", msg_type)
445-
446-
if request_id is not None:
447-
self._send_response(request_id, response)
448-
except Exception as e:
449-
import sys
450-
import traceback
451-
452-
print("error in on_message", data, obj, type(e), str(e), file=sys.stderr)
453-
if not isinstance(e, str):
454-
traceback.print_exception(*sys.exc_info(), file=sys.stderr)
455-
456391
def _on_message(self, message: str):
457392
data, buffers = _unpack_message(message)
458393
try:
@@ -682,7 +617,6 @@ async def handle_callbacks():
682617
print("error in callback", type(e), str(e))
683618

684619
try:
685-
self._callback_loop = asyncio.new_event_loop()
686620
asyncio.set_event_loop(self._callback_loop)
687621
self._callback_task = self._callback_loop.create_task(handle_callbacks())
688622
try:

webgpu/link/link.js

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
/* eslint-disable */
22

3+
const MAX_MESSAGE_SIZE = 100 * 1024 * 1024;
4+
35
function serializeEvent(event) {
46
try {
57
event.preventDefault();
@@ -403,11 +405,40 @@ class CrossLink {
403405
const prefixLen = 4 + jsonMsg.byteLength;
404406
const size = 4 + jsonMsg.byteLength + offset;
405407
var msg = new Uint8Array(size);
406-
msg.set(new Uint32Array([jsonMsg.byteLength]), 0);
408+
new DataView(msg.buffer).setUint32(0, jsonMsg.byteLength, true);
407409
msg.set(jsonMsg, 4);
408410

409411
for (var bufferIndex = 0; bufferIndex < buffers.length; bufferIndex++)
410412
msg.set(buffers[bufferIndex], prefixLen + buffer_offsets[bufferIndex]);
413+
this._sendFrame(msg.buffer, request_id);
414+
}
415+
}
416+
417+
_sendFrame(frame, parent_request_id) {
418+
const total = frame.byteLength;
419+
if (total <= MAX_MESSAGE_SIZE) {
420+
this.connection.send(frame);
421+
return;
422+
}
423+
const n_chunks = Math.ceil(total / MAX_MESSAGE_SIZE);
424+
for (let i = 0; i < n_chunks; i++) {
425+
const offset = i * MAX_MESSAGE_SIZE;
426+
const chunk = new Uint8Array(frame, offset, Math.min(MAX_MESSAGE_SIZE, total - offset));
427+
const meta = {
428+
type: 'chunk',
429+
parent_request_id,
430+
chunk_id: i,
431+
n_chunks,
432+
offset,
433+
size: chunk.byteLength,
434+
total_size: total,
435+
buffer_offsets: [0, chunk.byteLength],
436+
};
437+
const json = new TextEncoder().encode(JSON.stringify(meta));
438+
const msg = new Uint8Array(4 + json.byteLength + chunk.byteLength);
439+
new DataView(msg.buffer).setUint32(0, json.byteLength, true);
440+
msg.set(json, 4);
441+
msg.set(chunk, 4 + json.byteLength);
411442
this.connection.send(msg.buffer);
412443
}
413444
}

webgpu/link/websocket.py

Lines changed: 49 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from websockets.http11 import Response
2020
from websockets.datastructures import Headers
2121

22-
from .base import LinkBaseAsync
22+
from .base import LinkBaseAsync, _unpack_message
2323

2424

2525
class WebsocketLinkBase(LinkBaseAsync):
@@ -35,7 +35,6 @@ def __init__(self):
3535
self._event_is_connected = threading.Event()
3636
self._event_is_running = threading.Event()
3737
self._start_handling_messages = threading.Event()
38-
self._send_loop = asyncio.new_event_loop()
3938

4039
self._websocket_thread = threading.Thread(target=self._connect, daemon=True)
4140
self._websocket_thread.start()
@@ -65,6 +64,7 @@ def __init__(self):
6564
self._port = 8700
6665
self._auth_token = secrets.token_urlsafe(32)
6766
self._executor = ThreadPoolExecutor(max_workers=8)
67+
self._chunk_buffers = {}
6868
self._stop = None
6969
super().__init__()
7070

@@ -80,22 +80,50 @@ def _check_auth(self, connection, request):
8080
"""Reject WebSocket connections that don't carry a valid token."""
8181
params = parse_qs(urlparse(request.path).query)
8282
tokens = params.get("token", [])
83-
if not tokens or tokens[0] != self._auth_token:
83+
if not tokens or not secrets.compare_digest(tokens[0], self._auth_token):
8484
return Response(403, "Forbidden", Headers())
8585
return None
8686

8787
@staticmethod
88-
def _is_response(message):
89-
"""Quick check if a message is a response (cheap, avoids full deserialization)."""
90-
if isinstance(message, (memoryview, bytes)):
91-
# Binary message: JSON metadata starts at byte 4
92-
try:
88+
def _message_type(message):
89+
"""Return the top-level message type, parsing only the JSON header
90+
(not buffer payloads). Returns None on malformed input."""
91+
try:
92+
if isinstance(message, (memoryview, bytes)):
9393
prefix_size = 4 + int.from_bytes(message[:4], byteorder="little")
94-
header = message[4:prefix_size]
95-
return b'"type":"response"' in bytes(header) or b'"type": "response"' in bytes(header)
96-
except Exception:
97-
return False
98-
return '"type":"response"' in message or '"type": "response"' in message
94+
header = json.loads(bytes(message[4:prefix_size]).decode("utf-8"))
95+
else:
96+
header = json.loads(message)
97+
return header.get("type") if isinstance(header, dict) else None
98+
except Exception:
99+
return None
100+
101+
def _is_response(self, message):
102+
return self._message_type(message) == "response"
103+
104+
def _is_chunk(self, message):
105+
return isinstance(message, (memoryview, bytes)) and self._message_type(message) == "chunk"
106+
107+
def _reassemble_chunk(self, message):
108+
data, buffers = _unpack_message(message)
109+
pid = data["parent_request_id"]
110+
buf = self._chunk_buffers.get(pid)
111+
if buf is None:
112+
buf = bytearray(data["total_size"])
113+
self._chunk_buffers[pid] = buf
114+
chunk = buffers[0]
115+
offset = data["offset"]
116+
buf[offset : offset + len(chunk)] = chunk
117+
if data["chunk_id"] + 1 == data["n_chunks"]:
118+
del self._chunk_buffers[pid]
119+
return bytes(buf)
120+
return None
121+
122+
def _dispatch(self, message):
123+
if self._is_response(message):
124+
self._on_message(message)
125+
else:
126+
self._executor.submit(self._on_message, message)
99127

100128
async def _websocket_handler(self, websocket, path=""):
101129
if self._connection is not None:
@@ -107,13 +135,17 @@ async def _websocket_handler(self, websocket, path=""):
107135
async for message in websocket:
108136
# Handle responses inline to avoid deadlock: if all executor
109137
# threads are blocked waiting for JS responses, queued response
110-
# messages would never be processed.
111-
if self._is_response(message):
112-
self._on_message(message)
138+
# messages would never be processed. Chunks are reassembled
139+
# inline (single-threaded, ordered) then dispatched.
140+
if self._is_chunk(message):
141+
full = self._reassemble_chunk(message)
142+
if full is not None:
143+
self._dispatch(full)
113144
else:
114-
self._executor.submit(self._on_message, message)
145+
self._dispatch(message)
115146
finally:
116147
self._connection = None
148+
self._chunk_buffers.clear()
117149

118150
def _connect(self):
119151
async def start_websocket():

webgpu/scene.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -584,8 +584,9 @@ def _on_camera_changed(self):
584584

585585
def _on_resize(self):
586586
"""Called on canvas resize. Update camera uniforms (aspect ratio) and re-render."""
587-
self._select_buffer_valid = False
588-
self.options.update_buffers()
587+
with self._render_mutex:
588+
self._select_buffer_valid = False
589+
self.options.update_buffers()
589590
if self._js_engine is not None:
590591
try:
591592
self._js_engine.handleResize()

0 commit comments

Comments
 (0)