11# -*- coding: utf-8 -*-
22from __future__ import annotations
33
4- from typing import Any , List , Optional , Sequence , Type , cast
4+ import select
5+ import time
6+ from threading import Thread
7+ from typing import Any , Callable , List , Optional , Sequence , Type , cast
58from uuid import UUID
69
710from eventsourcing .persistence import (
811 AggregateRecorder ,
912 ApplicationRecorder ,
1013 IntegrityError ,
14+ ListenNotifySubscription ,
1115 Notification ,
1216 ProcessRecorder ,
17+ ProgrammingError ,
1318 StoredEvent ,
1419 Subscription ,
1520 Tracking ,
2934class SQLAlchemyRecorder :
3035 """Base class for recorders that use SQLAlchemy."""
3136
37+ POSTGRES_MAX_IDENTIFIER_LEN = 63
38+
3239 def __init__ (
3340 self ,
3441 datastore : SQLAlchemyDatastore ,
@@ -38,6 +45,13 @@ def __init__(
3845 self .schema_name = schema_name
3946 self .tables : List [Table ] = []
4047
48+ def check_identifier_length (self , table_name : str ) -> None :
49+ assert self .datastore .engine is not None
50+ if self .datastore .engine .dialect .name == "postgresql" :
51+ if len (table_name ) > SQLAlchemyRecorder .POSTGRES_MAX_IDENTIFIER_LEN :
52+ msg = f"Identifier too long: { table_name } "
53+ raise ProgrammingError (msg )
54+
4155 def create_table (self ) -> None :
4256 assert self .datastore .engine is not None
4357 for table in self .tables :
@@ -233,12 +247,12 @@ def _insert_stored_events(
233247 session .add (record )
234248 if self ._has_autoincrementing_ids :
235249 session .flush () # We want the autoincremented IDs now.
250+ self ._notify_channel (session )
236251 return [cast (StoredEventRecord , r ).id for r in records ]
237252
238253 def max_notification_id (self ) -> int | None :
239254 try :
240255 with self .transaction (commit = False ) as session :
241- # record_class = cast(Type[StoredEventRecord], self.events_record_cls)
242256 record_class = self .events_record_cls
243257 q = session .query (record_class )
244258 q = q .order_by (record_class .id .desc ())
@@ -257,7 +271,6 @@ def select_notifications(
257271 inclusive_of_start : bool = True ,
258272 ) -> list [Notification ]:
259273 with self .transaction (commit = False ) as session :
260- # record_class = cast(Type[StoredEventRecord], self.events_record_cls)
261274 record_class = self .events_record_cls
262275 q = session .query (record_class )
263276 if start is not None :
@@ -289,8 +302,86 @@ def select_notifications(
289302 def subscribe (
290303 self , gt : int | None = None , topics : Sequence [str ] = ()
291304 ) -> Subscription [ApplicationRecorder ]:
292- msg = "SQLAlchemyApplicationRecorder.subscribe() is not implemented"
293- raise NotImplementedError (msg )
305+ assert self .datastore .engine
306+ if self .datastore .engine .dialect .name == "postgresql" :
307+ return SQLAlchemySubscription (recorder = self , gt = gt , topics = topics )
308+ else :
309+ msg = "SQLAlchemyApplicationRecorder.subscribe() is not implemented for"
310+ msg += f"{ self .datastore .engine .dialect } "
311+ raise NotImplementedError (msg )
312+
313+ def _notify_channel (self , session : Session ) -> None :
314+ """
315+ Send a NOTIFY on the channel using a SQLAlchemy connection.
316+ """
317+ assert self .datastore .engine
318+ if self .datastore .engine .dialect .name == "postgresql" :
319+ # Get the raw psycopg connection
320+ cursor = session .connection ().connection .cursor ()
321+ cursor .execute (f"NOTIFY { self .channel_name } ;" )
322+
323+
324+ class SQLAlchemySubscription (ListenNotifySubscription [SQLAlchemyApplicationRecorder ]):
325+ def __init__ (
326+ self ,
327+ recorder : SQLAlchemyApplicationRecorder ,
328+ gt : int | None = None ,
329+ topics : Sequence [str ] = (),
330+ ) -> None :
331+ assert isinstance (recorder , SQLAlchemyApplicationRecorder )
332+ super ().__init__ (recorder = recorder , gt = gt , topics = topics )
333+ self ._listen_thread = Thread (target = self ._listen )
334+ self ._listen_thread .start ()
335+
336+ def __exit__ (self , * args : object , ** kwargs : Any ) -> None :
337+ super ().__exit__ (* args , ** kwargs )
338+ self ._listen_thread .join ()
339+
340+ def _listen (self ) -> None :
341+ assert self ._recorder .datastore .engine
342+ assert self ._recorder .datastore .engine .dialect .name == "postgresql"
343+ notification_handler = self .__get_notification_handler ()
344+
345+ try :
346+ with self ._recorder .datastore .get_connection () as sa_conn :
347+ sa_conn .execution_options (isolation_level = "AUTOCOMMIT" )
348+ raw_conn = sa_conn .connection
349+
350+ cur = raw_conn .cursor ()
351+ cur .execute (f"LISTEN { self ._recorder .channel_name } ;" )
352+
353+ while not self ._has_been_stopped and not self ._thread_error :
354+ if select .select ([raw_conn ], [], [], 0.1 )[0 ]:
355+ notification_handler (raw_conn )
356+ else :
357+ time .sleep (0.1 )
358+
359+ except BaseException as e : # noqa: B036
360+ if self ._thread_error is None :
361+ self ._thread_error = e
362+ self .stop ()
363+
364+ def __get_notification_handler (self ) -> Callable [[Any ], None ]:
365+ assert self ._recorder .datastore .engine
366+ driver_name = self ._recorder .datastore .engine .dialect .driver
367+ handlers = {
368+ "psycopg" : self .__handle_psycopg_notification ,
369+ "psycopg2" : self .__handle_psycopg2_notification ,
370+ }
371+ try :
372+ return handlers [driver_name ]
373+ except KeyError as e :
374+ raise NotImplementedError (f"Unsupported driver: { driver_name } " ) from e
375+
376+ def __handle_psycopg_notification (self , raw_conn : Any ) -> None :
377+ next (raw_conn .notifies ())
378+ self ._has_been_notified .set ()
379+
380+ def __handle_psycopg2_notification (self , raw_conn : Any ) -> None :
381+ raw_conn .poll ()
382+ if raw_conn .notifies :
383+ raw_conn .notifies .pop (0 )
384+ self ._has_been_notified .set ()
294385
295386
296387class SQLAlchemyTrackingRecorder (SQLAlchemyRecorder , TrackingRecorder ):
0 commit comments