diff --git a/src/pycrdt/websocket/yroom.py b/src/pycrdt/websocket/yroom.py index 1539fc3..84a3f92 100644 --- a/src/pycrdt/websocket/yroom.py +++ b/src/pycrdt/websocket/yroom.py @@ -52,6 +52,7 @@ class YRoom: ystore: BaseYStore | None ready_event: Event _on_message: Callable[[bytes], Awaitable[bool] | bool] | None + _on_message_error: Callable[[Exception, bytes, Channel], Awaitable[bool] | bool] | None _update_send_stream: MemoryObjectSendStream _update_receive_stream: MemoryObjectReceiveStream _task_group: TaskGroup | None = None @@ -90,7 +91,10 @@ def __init__( provider_factory: An optional provider factory used to synchronize the room with an external document. exception_handler: An optional callback to call when an exception is raised, that - returns True if the exception was handled. + returns True if the exception was handled. Handling exceptions does not prevent + the room from stopping, but it prevents the exception from propagating. + Use on_message_error to handle exceptions raised while processing messages + without stopping the room. log: An optional logger. ydoc: An optional document for the room (a new one is created otherwise). """ @@ -104,6 +108,7 @@ def __init__( self.awareness.observe(self.send_server_awareness) self.clients = set() self._on_message = None + self._on_message_error = None self.exception_handler = exception_handler self._stopped = Event() self._provider_stop_event = Event() @@ -158,6 +163,32 @@ def on_message(self, value: Callable[[bytes], Awaitable[bool] | bool] | None): """ self._on_message = value + @property + def on_message_error( + self, + ) -> Callable[[Exception, bytes, Channel], Awaitable[bool] | bool] | None: + """ + Returns: + The optional callback to call when an exception is raised while processing + a message. The callback receives the exception, the raw message bytes, + and the channel. If it returns True the error is considered handled and + processing continues with the next message; otherwise the exception propagates. + """ + return self._on_message_error + + @on_message_error.setter + def on_message_error( + self, + value: Callable[[Exception, bytes, Channel], Awaitable[bool] | bool] | None, + ): + """ + Arguments: + value: An optional callback to call when an exception is raised while + processing a message. If the callback returns True, the error is + handled and the message is skipped. + """ + self._on_message_error = value + async def _broadcast_updates(self): if self.ystore is not None: async with self.ystore.start_lock: @@ -294,59 +325,66 @@ async def serve(self, channel: Channel): ) await channel.send(sync_message) async for message in channel: - # filter messages (e.g. awareness) - skip = False - if self.on_message: - _skip = self.on_message(message) - skip = await _skip if isawaitable(_skip) else _skip - if skip: - continue - message_type = message[0] - if message_type == YMessageType.SYNC: - # update our internal state in the background - # changes to the internal state are then forwarded to all clients - # and stored in the YStore (if any) - self.log.debug( - "Received %s message from endpoint: %s", - YSyncMessageType(message[1]).name, - channel.path, - ) - reply = handle_sync_message(message[1:], self.ydoc) - if reply is not None: + try: + # filter messages (e.g. awareness) + skip = False + if self.on_message: + _skip = self.on_message(message) + skip = await _skip if isawaitable(_skip) else _skip + if skip: + continue + message_type = message[0] + if message_type == YMessageType.SYNC: + # update our internal state in the background + # changes to the internal state are then forwarded to all clients + # and stored in the YStore (if any) self.log.debug( - "Sending %s message to endpoint: %s", - YSyncMessageType.SYNC_STEP2.name, + "Received %s message from endpoint: %s", + YSyncMessageType(message[1]).name, channel.path, ) - tg.start_soon(channel.send, reply) - elif message_type == YMessageType.AWARENESS: - # forward awareness messages from this client to all clients, - # including itself, because it's used to keep the connection alive - self.log.debug( - "Received %s message from endpoint: %s", - YMessageType.AWARENESS.name, - channel.path, - ) - - # Check if the message is a client awareness disconnect. - disconnection = is_awareness_disconnect_message(message[1:]) - - # Propagate the message to all clients except itself if it is a - # disconnection from the client. This avoid an error when trying - # to send the message to the disconnected client. - for client in self.clients: - if disconnection and client == channel: - continue - + reply = handle_sync_message(message[1:], self.ydoc) + if reply is not None: + self.log.debug( + "Sending %s message to endpoint: %s", + YSyncMessageType.SYNC_STEP2.name, + channel.path, + ) + tg.start_soon(channel.send, reply) + elif message_type == YMessageType.AWARENESS: + # forward awareness messages from this client to all clients, + # including itself, because it's used to keep the connection alive self.log.debug( - "Sending Y awareness from client with endpoint " - "%s to client with endpoint: %s", + "Received %s message from endpoint: %s", + YMessageType.AWARENESS.name, channel.path, - client.path, ) - tg.start_soon(client.send, message) - # apply awareness update to the server's awareness - self.awareness.apply_awareness_update(read_message(message[1:]), self) + + # Check if the message is a client awareness disconnect. + disconnection = is_awareness_disconnect_message(message[1:]) + + # Propagate the message to all clients except itself if it is a + # disconnection from the client. This avoid an error when trying + # to send the message to the disconnected client. + for client in self.clients: + if disconnection and client == channel: + continue + + self.log.debug( + "Sending Y awareness from client with endpoint " + "%s to client with endpoint: %s", + channel.path, + client.path, + ) + tg.start_soon(client.send, message) + # apply awareness update to the server's awareness + self.awareness.apply_awareness_update(read_message(message[1:]), self) + except Exception as exc: + if self._on_message_error is not None: + _handled = self._on_message_error(exc, message, channel) + handled = await _handled if isawaitable(_handled) else _handled + if handled: + continue except Exception as exception: self._handle_exception(exception) finally: