11# -*- coding: utf-8 -*-
22from __future__ import annotations
33
4- from typing import Any , List , Optional , Sequence , Type , cast
4+ import select
5+ import time
6+ from threading import Thread
7+ from typing import Any , Callable , List , Optional , Sequence , Type , cast
58from uuid import UUID
69
710from eventsourcing .persistence import (
811 AggregateRecorder ,
912 ApplicationRecorder ,
1013 IntegrityError ,
14+ ListenNotifySubscription ,
1115 Notification ,
1216 ProcessRecorder ,
17+ ProgrammingError ,
1318 StoredEvent ,
1419 Subscription ,
1520 Tracking ,
2934class SQLAlchemyRecorder :
3035 """Base class for recorders that use SQLAlchemy."""
3136
37+ POSTGRES_MAX_IDENTIFIER_LEN = 63
38+
3239 def __init__ (
3340 self ,
3441 datastore : SQLAlchemyDatastore ,
@@ -38,6 +45,13 @@ def __init__(
3845 self .schema_name = schema_name
3946 self .tables : List [Table ] = []
4047
48+ def check_identifier_length (self , table_name : str ) -> None :
49+ assert self .datastore .engine is not None
50+ if self .datastore .engine .dialect .name == "postgresql" :
51+ if len (table_name ) > SQLAlchemyRecorder .POSTGRES_MAX_IDENTIFIER_LEN :
52+ msg = f"Identifier too long: { table_name } "
53+ raise ProgrammingError (msg )
54+
4155 def create_table (self ) -> None :
4256 assert self .datastore .engine is not None
4357 for table in self .tables :
@@ -234,12 +248,12 @@ def _insert_stored_events(
234248 session .add (record )
235249 if self ._has_autoincrementing_ids :
236250 session .flush () # We want the autoincremented IDs now.
251+ self ._notify_channel (session )
237252 return [cast (StoredEventRecord , r ).id for r in records ]
238253
239254 def max_notification_id (self ) -> int | None :
240255 try :
241256 with self .transaction (commit = False ) as session :
242- # record_class = cast(Type[StoredEventRecord], self.events_record_cls)
243257 record_class = self .events_record_cls
244258 q = session .query (record_class )
245259 q = q .order_by (record_class .id .desc ())
@@ -258,7 +272,6 @@ def select_notifications(
258272 inclusive_of_start : bool = True ,
259273 ) -> list [Notification ]:
260274 with self .transaction (commit = False ) as session :
261- # record_class = cast(Type[StoredEventRecord], self.events_record_cls)
262275 record_class = self .events_record_cls
263276 q = session .query (record_class )
264277 if start is not None :
@@ -290,8 +303,86 @@ def select_notifications(
290303 def subscribe (
291304 self , gt : int | None = None , topics : Sequence [str ] = ()
292305 ) -> Subscription [ApplicationRecorder ]:
293- msg = "SQLAlchemyApplicationRecorder.subscribe() is not implemented"
294- raise NotImplementedError (msg )
306+ assert self .datastore .engine
307+ if self .datastore .engine .dialect .name == "postgresql" :
308+ return SQLAlchemySubscription (recorder = self , gt = gt , topics = topics )
309+ else :
310+ msg = "SQLAlchemyApplicationRecorder.subscribe() is not implemented for"
311+ msg += f"{ self .datastore .engine .dialect } "
312+ raise NotImplementedError (msg )
313+
314+ def _notify_channel (self , session : Session ) -> None :
315+ """
316+ Send a NOTIFY on the channel using a SQLAlchemy connection.
317+ """
318+ assert self .datastore .engine
319+ if self .datastore .engine .dialect .name == "postgresql" :
320+ # Get the raw psycopg connection
321+ cursor = session .connection ().connection .cursor ()
322+ cursor .execute (f"NOTIFY { self .channel_name } ;" )
323+
324+
325+ class SQLAlchemySubscription (ListenNotifySubscription [SQLAlchemyApplicationRecorder ]):
326+ def __init__ (
327+ self ,
328+ recorder : SQLAlchemyApplicationRecorder ,
329+ gt : int | None = None ,
330+ topics : Sequence [str ] = (),
331+ ) -> None :
332+ assert isinstance (recorder , SQLAlchemyApplicationRecorder )
333+ super ().__init__ (recorder = recorder , gt = gt , topics = topics )
334+ self ._listen_thread = Thread (target = self ._listen )
335+ self ._listen_thread .start ()
336+
337+ def __exit__ (self , * args : object , ** kwargs : Any ) -> None :
338+ super ().__exit__ (* args , ** kwargs )
339+ self ._listen_thread .join ()
340+
341+ def _listen (self ) -> None :
342+ assert self ._recorder .datastore .engine
343+ assert self ._recorder .datastore .engine .dialect .name == "postgresql"
344+ notification_handler = self .__get_notification_handler ()
345+
346+ try :
347+ with self ._recorder .datastore .get_connection () as sa_conn :
348+ sa_conn .execution_options (isolation_level = "AUTOCOMMIT" )
349+ raw_conn = sa_conn .connection
350+
351+ cur = raw_conn .cursor ()
352+ cur .execute (f"LISTEN { self ._recorder .channel_name } ;" )
353+
354+ while not self ._has_been_stopped and not self ._thread_error :
355+ if select .select ([raw_conn ], [], [], 0.1 )[0 ]:
356+ notification_handler (raw_conn )
357+ else :
358+ time .sleep (0.1 )
359+
360+ except BaseException as e : # noqa: B036
361+ if self ._thread_error is None :
362+ self ._thread_error = e
363+ self .stop ()
364+
365+ def __get_notification_handler (self ) -> Callable [[Any ], None ]:
366+ assert self ._recorder .datastore .engine
367+ driver_name = self ._recorder .datastore .engine .dialect .driver
368+ handlers = {
369+ "psycopg" : self .__handle_psycopg_notification ,
370+ "psycopg2" : self .__handle_psycopg2_notification ,
371+ }
372+ try :
373+ return handlers [driver_name ]
374+ except KeyError as e :
375+ raise NotImplementedError (f"Unsupported driver: { driver_name } " ) from e
376+
377+ def __handle_psycopg_notification (self , raw_conn : Any ) -> None :
378+ next (raw_conn .notifies ())
379+ self ._has_been_notified .set ()
380+
381+ def __handle_psycopg2_notification (self , raw_conn : Any ) -> None :
382+ raw_conn .poll ()
383+ if raw_conn .notifies :
384+ raw_conn .notifies .pop (0 )
385+ self ._has_been_notified .set ()
295386
296387
297388class SQLAlchemyTrackingRecorder (SQLAlchemyRecorder , TrackingRecorder ):
0 commit comments