1313 StoredEvent ,
1414 Subscription ,
1515 Tracking ,
16+ TrackingRecorder ,
1617)
1718from sqlalchemy import Column , Table , text
1819from sqlalchemy .orm import Session
20+ from typing_extensions import TypeVar
1921
2022from eventsourcing_sqlalchemy .datastore import SQLAlchemyDatastore , Transaction
2123from eventsourcing_sqlalchemy .models import ( # type: ignore
2426)
2527
2628
27- class SQLAlchemyAggregateRecorder (AggregateRecorder ):
29+ class SQLAlchemyRecorder :
30+ """Base class for recorders that use SQLAlchemy."""
31+
2832 def __init__ (
2933 self ,
3034 datastore : SQLAlchemyDatastore ,
35+ schema_name : str | None = None ,
36+ ):
37+ self .datastore = datastore
38+ self .schema_name = schema_name
39+ self .tables : List [Table ] = []
40+
41+ def create_table (self ) -> None :
42+ assert self .datastore .engine is not None
43+ for table in self .tables :
44+ table .create (self .datastore .engine , checkfirst = True )
45+
46+ def transaction (self , commit : bool = True ) -> Transaction :
47+ return self .datastore .transaction (commit = commit )
48+
49+
50+ class SQLAlchemyAggregateRecorder (SQLAlchemyRecorder , AggregateRecorder ):
51+ def __init__ (
52+ self ,
53+ datastore : SQLAlchemyDatastore ,
54+ * ,
3155 events_table_name : str ,
3256 schema_name : str | None = None ,
3357 for_snapshots : bool = False ,
3458 ):
35- super ().__init__ ()
36- self .datastore = datastore
59+ super ().__init__ (
60+ datastore ,
61+ schema_name = schema_name ,
62+ )
3763 self .events_table_name = events_table_name
38- self .schema_name = schema_name
3964 record_cls_name = "" .join (
4065 [
4166 s .capitalize ()
@@ -55,27 +80,18 @@ def __init__(
5580 schema_name = self .schema_name ,
5681 base_cls = base_cls ,
5782 )
58- self .stored_events_table = self .events_record_cls .__table__
59-
60- def transaction (self , commit : bool = True ) -> Transaction :
61- return self .datastore .transaction (commit = commit )
62-
63- def create_table (self ) -> None :
64- assert self .datastore .engine is not None
65- self .stored_events_table .create (self .datastore .engine , checkfirst = True )
83+ self .tables .append (self .events_record_cls .__table__ )
6684
6785 def insert_events (
6886 self , stored_events : Sequence [StoredEvent ], ** kwargs : Any
6987 ) -> Optional [Sequence [int ]]:
7088 with self .transaction (commit = True ) as session :
71- self ._insert_events (session , stored_events , ** kwargs )
89+ self ._insert_stored_events (session , stored_events , ** kwargs )
7290 return None
7391
74- def _insert_events (
92+ def _insert_stored_events (
7593 self , session : Session , stored_events : Sequence [StoredEvent ], ** kwargs : Any
7694 ) -> Optional [Sequence [int ]]:
77- if len (stored_events ) == 0 :
78- return []
7995 records = [
8096 self .events_record_cls (
8197 originator_id = e .originator_id ,
@@ -91,9 +107,7 @@ def _insert_events(
91107 session .add (record )
92108 if self ._has_autoincrementing_ids :
93109 session .flush () # We want the autoincremented IDs now.
94- return [cast (StoredEventRecord , r ).id for r in records ]
95- else :
96- return None
110+ return None
97111
98112 def _lock_table (self , session : Session ) -> None :
99113 assert self .datastore .engine is not None
@@ -160,6 +174,18 @@ def select_events(
160174
161175
162176class SQLAlchemyApplicationRecorder (SQLAlchemyAggregateRecorder , ApplicationRecorder ):
177+ def __init__ (
178+ self ,
179+ datastore : SQLAlchemyDatastore ,
180+ * ,
181+ events_table_name : str ,
182+ schema_name : str | None = None ,
183+ ):
184+ super ().__init__ (
185+ datastore , events_table_name = events_table_name , schema_name = schema_name
186+ )
187+ self .channel_name = self .events_table_name .replace ("." , "_" )
188+
163189 def insert_events (
164190 self ,
165191 stored_events : Sequence [StoredEvent ],
@@ -169,12 +195,46 @@ def insert_events(
169195 ) -> Optional [Sequence [int ]]:
170196 if session is not None :
171197 assert isinstance (session , Session ), type (session )
172- notification_ids = self ._insert_events (session , stored_events , ** kwargs )
198+ self ._insert_events (session , stored_events , ** kwargs )
199+ notification_ids = self ._insert_stored_events (
200+ session , stored_events , ** kwargs
201+ )
173202 else :
174203 with self .transaction (commit = True ) as session :
175- notification_ids = self ._insert_events (session , stored_events , ** kwargs )
204+ self ._insert_events (session , stored_events , ** kwargs )
205+ notification_ids = self ._insert_stored_events (
206+ session , stored_events , ** kwargs
207+ )
176208 return notification_ids
177209
210+ def _insert_events (
211+ self ,
212+ session : Session ,
213+ stored_events : Sequence [StoredEvent ],
214+ ** _ : Any ,
215+ ) -> Optional [Sequence [int ]]:
216+ pass
217+
218+ def _insert_stored_events (
219+ self , session : Session , stored_events : Sequence [StoredEvent ], ** kwargs : Any
220+ ) -> Sequence [int ]:
221+ records = [
222+ self .events_record_cls (
223+ originator_id = e .originator_id ,
224+ originator_version = e .originator_version ,
225+ topic = e .topic ,
226+ state = e .state ,
227+ )
228+ for e in stored_events
229+ ]
230+ if self ._has_autoincrementing_ids :
231+ self ._lock_table (session )
232+ for record in records :
233+ session .add (record )
234+ if self ._has_autoincrementing_ids :
235+ session .flush () # We want the autoincremented IDs now.
236+ return [cast (StoredEventRecord , r ).id for r in records ]
237+
178238 def max_notification_id (self ) -> int | None :
179239 try :
180240 with self .transaction (commit = False ) as session :
@@ -233,51 +293,30 @@ def subscribe(
233293 raise NotImplementedError (msg )
234294
235295
236- class SQLAlchemyProcessRecorder ( SQLAlchemyApplicationRecorder , ProcessRecorder ):
296+ class SQLAlchemyTrackingRecorder ( SQLAlchemyRecorder , TrackingRecorder ):
237297 def __init__ (
238298 self ,
239299 datastore : SQLAlchemyDatastore ,
240- events_table_name : str ,
241- tracking_table_name : str ,
300+ * ,
301+ tracking_table_name : str = "notification_tracking" ,
242302 schema_name : str | None = None ,
303+ ** kwargs : Any ,
243304 ):
244- super ().__init__ (
245- datastore = datastore ,
246- events_table_name = events_table_name ,
247- schema_name = schema_name ,
248- )
305+ super ().__init__ (datastore = datastore , ** kwargs )
249306 self .tracking_table_name = tracking_table_name
250307 self .tracking_record_cls = self .datastore .define_record_class (
251308 cls_name = "NotificationTrackingRecord" ,
252309 table_name = self .tracking_table_name ,
253- schema_name = self . schema_name ,
310+ schema_name = schema_name ,
254311 base_cls = datastore .base_notification_tracking_record_cls ,
255312 )
256313 self .tracking_table : Table = self .tracking_record_cls .__table__
257314
258315 def create_table (self ) -> None :
259316 super ().create_table ()
317+ assert self .datastore .engine is not None
260318 self .tracking_table .create (self .datastore .engine , checkfirst = True )
261319
262- def _insert_events (
263- self , session : Session , stored_events : Sequence [StoredEvent ], ** kwargs : Any
264- ) -> Optional [Sequence [int ]]:
265- notification_ids = super (SQLAlchemyProcessRecorder , self )._insert_events (
266- session , stored_events , ** kwargs
267- )
268- tracking : Optional [Tracking ] = kwargs .get ("tracking" , None )
269- if tracking is not None :
270- if self .has_tracking_id (
271- tracking .application_name , tracking .notification_id
272- ):
273- raise IntegrityError
274- record = self .tracking_record_cls (
275- application_name = tracking .application_name ,
276- notification_id = tracking .notification_id ,
277- )
278- session .add (record )
279- return notification_ids
280-
281320 def max_tracking_id (self , application_name : str ) -> int | None :
282321 with self .transaction (commit = False ) as session :
283322 q = session .query (self .tracking_record_cls )
@@ -290,4 +329,51 @@ def max_tracking_id(self, application_name: str) -> int | None:
290329 return max_id
291330
292331 def insert_tracking (self , tracking : Tracking ) -> None :
293- raise NotImplementedError
332+ with self .transaction (commit = True ) as session :
333+ self ._insert_tracking (session = session , tracking = tracking )
334+
335+ def _insert_tracking (self , session : Session , tracking : Tracking ) -> None :
336+ if tracking is not None :
337+ if self .has_tracking_id (
338+ tracking .application_name , tracking .notification_id
339+ ):
340+ raise IntegrityError
341+ record = self .tracking_record_cls (
342+ application_name = tracking .application_name ,
343+ notification_id = tracking .notification_id ,
344+ )
345+ session .add (record )
346+
347+
348+ TSQLAlchemyTrackingRecorder = TypeVar (
349+ "TSQLAlchemyTrackingRecorder" ,
350+ bound = SQLAlchemyTrackingRecorder ,
351+ default = SQLAlchemyTrackingRecorder ,
352+ )
353+
354+
355+ class SQLAlchemyProcessRecorder (
356+ SQLAlchemyTrackingRecorder , SQLAlchemyApplicationRecorder , ProcessRecorder
357+ ):
358+ def __init__ (
359+ self ,
360+ datastore : SQLAlchemyDatastore ,
361+ * ,
362+ events_table_name : str ,
363+ tracking_table_name : str ,
364+ schema_name : str | None = None ,
365+ ):
366+ super ().__init__ (
367+ datastore = datastore ,
368+ tracking_table_name = tracking_table_name ,
369+ events_table_name = events_table_name ,
370+ schema_name = schema_name ,
371+ )
372+
373+ def _insert_events (
374+ self , session : Session , stored_events : Sequence [StoredEvent ], ** kwargs : Any
375+ ) -> None :
376+ tracking : Optional [Tracking ] = kwargs .get ("tracking" , None )
377+ if tracking is not None :
378+ self ._insert_tracking (session , tracking )
379+ super ()._insert_events (session = session , stored_events = stored_events , ** kwargs )
0 commit comments