Skip to content

Commit d30f515

Browse files
committed
Add broadcast
1 parent c37e812 commit d30f515

2 files changed

Lines changed: 159 additions & 1 deletion

File tree

violetear/app.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,13 +107,27 @@ def get_service_worker(scope_hash: str):
107107
headers=headers,
108108
)
109109

110+
# In App.__init__
111+
self.socket_manager = SocketManager()
112+
113+
@self.api.websocket("/_violetear/ws")
114+
async def websocket_endpoint(websocket: WebSocket):
115+
await self.socket_manager.connect(websocket)
116+
try:
117+
while True:
118+
# Keep the connection alive.
119+
# We can also listen for client-to-server messages here if needed later.
120+
await websocket.receive()
121+
except WebSocketDisconnect:
122+
self.socket_manager.disconnect(websocket)
123+
110124
def client(self, func: Callable):
111125
"""Decorator to mark a function to be compiled to the client."""
112126
if not inspect.iscoroutinefunction(func):
113127
raise ValueError("func must be async")
114128

115129
self.client_functions[func.__name__] = func
116-
return func
130+
return ClientFunctionWrapper(self, func)
117131

118132
def startup(self, func: Callable):
119133
"""
@@ -304,6 +318,25 @@ def _generate_bundle(self) -> str:
304318
code = [c for c in code if not c.startswith("@")] # remove decorators
305319
user_code.append("\n".join(code))
306320

321+
# --- SAFETY INJECTION START ---
322+
# We attach a dummy .broadcast() method to client functions running in the browser.
323+
# This prevents confusion if a user tries to call await my_func.broadcast() in client code.
324+
safety_checks = []
325+
safety_checks.append(
326+
dedent(
327+
"""
328+
def _server_only_broadcast(*args, **kwargs):
329+
raise RuntimeError("❌ .broadcast() cannot be called from the Client (Browser).\\nIt must be called from the Server to trigger client updates.")
330+
"""
331+
)
332+
)
333+
334+
for name in self.client_functions.keys():
335+
safety_checks.append(f"{name}.broadcast = _server_only_broadcast")
336+
337+
safety_code = "\n".join(safety_checks)
338+
# --- SAFETY INJECTION END ---
339+
307340
# 5. Generate Server Stubs
308341
server_stubs = self._generate_server_stubs()
309342

@@ -321,6 +354,7 @@ def _generate_bundle(self) -> str:
321354
storage_injection,
322355
runtime_code,
323356
"\n\n".join(user_code),
357+
safety_code,
324358
server_stubs,
325359
init_code,
326360
]
@@ -537,3 +571,85 @@ def mount_static(self, directory: str, path: str = "/static"):
537571
def run(self, host="0.0.0.0", port=8000, **kwargs):
538572
"""Helper to run via uvicorn programmatically."""
539573
uvicorn.run(self.api, host=host, port=port, **kwargs)
574+
575+
576+
class ClientFunctionWrapper:
577+
"""
578+
Wraps a client-side function to provide Server-Side RPC capabilities.
579+
580+
When you define:
581+
@app.client
582+
async def my_func(...): ...
583+
584+
This wrapper ensures that:
585+
1. Calling `await my_func(...)` on the server runs the function (or warns).
586+
2. Calling `await my_func.broadcast(...)` triggers the WebSocket dispatcher.
587+
"""
588+
589+
def __init__(self, app: "App", func: Callable):
590+
self.app = app
591+
self.func = func
592+
# Mimic the original function's identity
593+
self.__name__ = func.__name__
594+
self.__doc__ = func.__doc__
595+
596+
def __call__(self, *args, **kwargs):
597+
"""
598+
Standard call: await my_func(...)
599+
Executes the function logic locally (useful for testing or shared logic).
600+
"""
601+
raise RuntimeError(
602+
f"Cannot call client-side functions in the server! Did you meant {self.func.__name__}.broadcast(...)?"
603+
)
604+
605+
async def broadcast(self, *args, **kwargs):
606+
"""
607+
RPC Call: await my_func.broadcast(...)
608+
609+
Tells the server to instructing ALL connected clients to run this function.
610+
"""
611+
# FUTURE: This will hook into the WebSocket manager
612+
if hasattr(self.app, "socket_manager"):
613+
# print(f"Broadcasting {self.__name__} to all clients...")
614+
await self.app.socket_manager.broadcast(
615+
func_name=self.__name__, args=args, kwargs=kwargs
616+
)
617+
else:
618+
print(
619+
f"[Violetear] Warning: Broadcast called on '{self.__name__}' but no SocketManager is active."
620+
)
621+
622+
623+
# Add this class to violetear/app.py
624+
625+
from fastapi import WebSocket, WebSocketDisconnect
626+
627+
628+
class SocketManager:
629+
def __init__(self):
630+
# Keep track of active connections
631+
self.active_connections: List[WebSocket] = []
632+
633+
async def connect(self, websocket: WebSocket):
634+
await websocket.accept()
635+
self.active_connections.append(websocket)
636+
637+
def disconnect(self, websocket: WebSocket):
638+
self.active_connections.remove(websocket)
639+
640+
async def broadcast(self, func_name: str, args: tuple, kwargs: dict):
641+
"""
642+
Sends a command to all connected clients to run a specific function.
643+
"""
644+
payload = json.dumps(
645+
{"type": "rpc", "func": func_name, "args": args, "kwargs": kwargs}
646+
)
647+
648+
# Iterate over all connections and send the message
649+
# We use a copy of the list to avoid modification errors during iteration
650+
for connection in self.active_connections[:]:
651+
try:
652+
await connection.send_text(payload)
653+
except Exception:
654+
# If sending fails (e.g. client disconnected), remove it
655+
self.disconnect(connection)

violetear/client.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
"""
55

66
import sys
7+
import json
8+
import asyncio
79

810
# We define IS_BROWSER to avoid import errors if this is imported on the server
911
IS_BROWSER = "pyodide" in sys.modules or "emscripten" in sys.platform
@@ -72,3 +74,43 @@ def hydrate(namespace: dict):
7274
)
7375

7476
print(f"[Violetear] Hydrated {bound_count} interactive elements.")
77+
setup_socket_listener(namespace)
78+
79+
80+
def setup_socket_listener(namespace):
81+
"""
82+
Connects to the server and listens for RPC commands.
83+
"""
84+
from js import WebSocket, window
85+
86+
# Calculate the WebSocket URL (ws:// or wss://)
87+
protocol = "wss" if window.location.protocol == "https:" else "ws"
88+
ws_url = f"{protocol}://{window.location.host}/_violetear/ws"
89+
90+
socket = WebSocket.new(ws_url)
91+
92+
def on_message(event):
93+
data = json.loads(event.data)
94+
95+
if data.get("type") == "rpc":
96+
func_name = data["func"]
97+
args = data["args"]
98+
kwargs = data["kwargs"]
99+
100+
# 1. Look up the function in the global scope
101+
if func_name in namespace:
102+
func = namespace[func_name]
103+
104+
# 2. Schedule the async function to run on the event loop
105+
# We use asyncio.create_task because we are inside a sync callback
106+
asyncio.create_task(func(*args, **kwargs))
107+
else:
108+
print(f"[Violetear] Received RPC for unknown function: {func_name}")
109+
110+
# Attach the callback (converting Python function to JS proxy not strictly needed for simple events in recent Pyodide, but good practice)
111+
socket.onmessage = on_message
112+
113+
# Keep a reference so it doesn't get garbage collected
114+
window.violetear_socket = socket
115+
116+
print("[Violetear] Attached socket connection")

0 commit comments

Comments
 (0)