Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 86 additions & 48 deletions src/pycrdt/websocket/yroom.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class YRoom:
ystore: BaseYStore | None
ready_event: Event
_on_message: Callable[[bytes], Awaitable[bool] | bool] | None
_on_message_error: Callable[[Exception, bytes, Channel], Awaitable[bool] | bool] | None
_update_send_stream: MemoryObjectSendStream
_update_receive_stream: MemoryObjectReceiveStream
_task_group: TaskGroup | None = None
Expand Down Expand Up @@ -90,7 +91,10 @@ def __init__(
provider_factory: An optional provider factory used to synchronize the room with
an external document.
exception_handler: An optional callback to call when an exception is raised, that
returns True if the exception was handled.
returns True if the exception was handled. Handling exceptions does not prevent
the room from stopping, but it prevents the exception from propagating.
Use on_message_error to handle exceptions raised while processing messages
without stopping the room.
log: An optional logger.
ydoc: An optional document for the room (a new one is created otherwise).
"""
Expand All @@ -104,6 +108,7 @@ def __init__(
self.awareness.observe(self.send_server_awareness)
self.clients = set()
self._on_message = None
self._on_message_error = None
self.exception_handler = exception_handler
self._stopped = Event()
self._provider_stop_event = Event()
Expand Down Expand Up @@ -158,6 +163,32 @@ def on_message(self, value: Callable[[bytes], Awaitable[bool] | bool] | None):
"""
self._on_message = value

@property
def on_message_error(
self,
) -> Callable[[Exception, bytes, Channel], Awaitable[bool] | bool] | None:
"""
Returns:
The optional callback to call when an exception is raised while processing
a message. The callback receives the exception, the raw message bytes,
and the channel. If it returns True the error is considered handled and
processing continues with the next message; otherwise the exception propagates.
"""
return self._on_message_error

@on_message_error.setter
def on_message_error(
self,
value: Callable[[Exception, bytes, Channel], Awaitable[bool] | bool] | None,
):
"""
Arguments:
value: An optional callback to call when an exception is raised while
processing a message. If the callback returns True, the error is
handled and the message is skipped.
"""
self._on_message_error = value

async def _broadcast_updates(self):
if self.ystore is not None:
async with self.ystore.start_lock:
Expand Down Expand Up @@ -294,59 +325,66 @@ async def serve(self, channel: Channel):
)
await channel.send(sync_message)
async for message in channel:
# filter messages (e.g. awareness)
skip = False
if self.on_message:
_skip = self.on_message(message)
skip = await _skip if isawaitable(_skip) else _skip
if skip:
continue
message_type = message[0]
if message_type == YMessageType.SYNC:
# update our internal state in the background
# changes to the internal state are then forwarded to all clients
# and stored in the YStore (if any)
self.log.debug(
"Received %s message from endpoint: %s",
YSyncMessageType(message[1]).name,
channel.path,
)
reply = handle_sync_message(message[1:], self.ydoc)
if reply is not None:
try:
# filter messages (e.g. awareness)
skip = False
if self.on_message:
_skip = self.on_message(message)
skip = await _skip if isawaitable(_skip) else _skip
if skip:
continue
message_type = message[0]
if message_type == YMessageType.SYNC:
# update our internal state in the background
# changes to the internal state are then forwarded to all clients
# and stored in the YStore (if any)
self.log.debug(
"Sending %s message to endpoint: %s",
YSyncMessageType.SYNC_STEP2.name,
"Received %s message from endpoint: %s",
YSyncMessageType(message[1]).name,
channel.path,
)
tg.start_soon(channel.send, reply)
elif message_type == YMessageType.AWARENESS:
# forward awareness messages from this client to all clients,
# including itself, because it's used to keep the connection alive
self.log.debug(
"Received %s message from endpoint: %s",
YMessageType.AWARENESS.name,
channel.path,
)

# Check if the message is a client awareness disconnect.
disconnection = is_awareness_disconnect_message(message[1:])

# Propagate the message to all clients except itself if it is a
# disconnection from the client. This avoid an error when trying
# to send the message to the disconnected client.
for client in self.clients:
if disconnection and client == channel:
continue

reply = handle_sync_message(message[1:], self.ydoc)
if reply is not None:
self.log.debug(
"Sending %s message to endpoint: %s",
YSyncMessageType.SYNC_STEP2.name,
channel.path,
)
tg.start_soon(channel.send, reply)
elif message_type == YMessageType.AWARENESS:
# forward awareness messages from this client to all clients,
# including itself, because it's used to keep the connection alive
self.log.debug(
"Sending Y awareness from client with endpoint "
"%s to client with endpoint: %s",
"Received %s message from endpoint: %s",
YMessageType.AWARENESS.name,
channel.path,
client.path,
)
tg.start_soon(client.send, message)
# apply awareness update to the server's awareness
self.awareness.apply_awareness_update(read_message(message[1:]), self)

# Check if the message is a client awareness disconnect.
disconnection = is_awareness_disconnect_message(message[1:])

# Propagate the message to all clients except itself if it is a
# disconnection from the client. This avoid an error when trying
# to send the message to the disconnected client.
for client in self.clients:
if disconnection and client == channel:
continue
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It looks like we can have this logic once for all message types just after async for message in channel:, instead of duplicating it for all message types?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like 8c0a645?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this looks good.


self.log.debug(
"Sending Y awareness from client with endpoint "
"%s to client with endpoint: %s",
channel.path,
client.path,
)
tg.start_soon(client.send, message)
# apply awareness update to the server's awareness
self.awareness.apply_awareness_update(read_message(message[1:]), self)
except Exception as exc:
if self._on_message_error is not None:
_handled = self._on_message_error(exc, message, channel)
handled = await _handled if isawaitable(_handled) else _handled
if handled:
continue
except Exception as exception:
self._handle_exception(exception)
finally:
Expand Down
Loading