Skip to content

Commit e3e7ee8

Browse files
authored
Merge pull request #27 from mobuild-io/SQLAlchemyTrackingRecorder
SQLAlchemyTrackingRecorder implementationn
2 parents 0c330bc + 6c2d9c4 commit e3e7ee8

11 files changed

Lines changed: 219 additions & 133 deletions

File tree

.github/workflows/github-actions.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ jobs:
66
strategy:
77
fail-fast: false
88
matrix:
9-
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
9+
python-version: ["3.10", "3.11", "3.12", "3.13"]
1010
sqlalchemy-version: ["<2.0", "default"]
1111
runs-on: ubuntu-latest
1212
services:

eventsourcing_sqlalchemy/factory.py

Lines changed: 30 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,13 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
33

4-
from typing import Optional
4+
from typing import Optional, cast
55

66
from eventsourcing.persistence import (
77
AggregateRecorder,
88
ApplicationRecorder,
99
InfrastructureFactory,
1010
ProcessRecorder,
11-
TrackingRecorder,
1211
)
1312
from eventsourcing.utils import Environment, resolve_topic, strtobool
1413
from sqlalchemy.orm import scoped_session
@@ -18,10 +17,12 @@
1817
SQLAlchemyAggregateRecorder,
1918
SQLAlchemyApplicationRecorder,
2019
SQLAlchemyProcessRecorder,
20+
SQLAlchemyTrackingRecorder,
21+
TSQLAlchemyTrackingRecorder,
2122
)
2223

2324

24-
class SQLAlchemyFactory(InfrastructureFactory[TrackingRecorder]):
25+
class SQLAlchemyFactory(InfrastructureFactory[SQLAlchemyTrackingRecorder]):
2526
SQLALCHEMY_URL = "SQLALCHEMY_URL"
2627
SQLALCHEMY_AUTOFLUSH = "SQLALCHEMY_AUTOFLUSH"
2728
SQLALCHEMY_CONNECTION_CREATOR_TOPIC = "SQLALCHEMY_CONNECTION_CREATOR_TOPIC"
@@ -32,6 +33,7 @@ class SQLAlchemyFactory(InfrastructureFactory[TrackingRecorder]):
3233
datastore_class = SQLAlchemyDatastore
3334
aggregate_recorder_class = SQLAlchemyAggregateRecorder
3435
application_recorder_class = SQLAlchemyApplicationRecorder
36+
tracking_recorder_class = SQLAlchemyTrackingRecorder
3537
process_recorder_class = SQLAlchemyProcessRecorder
3638

3739
def __init__(self, env: Environment):
@@ -90,6 +92,31 @@ def application_recorder(self) -> ApplicationRecorder:
9092
recorder.create_table()
9193
return recorder
9294

95+
def tracking_recorder(
96+
self, tracking_recorder_class: type[TSQLAlchemyTrackingRecorder] | None = None
97+
) -> TSQLAlchemyTrackingRecorder:
98+
prefix = self.env.name.lower() or "notification"
99+
tracking_table_name = prefix + "_tracking"
100+
if tracking_recorder_class is None:
101+
tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC)
102+
if tracking_recorder_topic:
103+
tracking_recorder_class = resolve_topic(tracking_recorder_topic)
104+
else:
105+
tracking_recorder_class = cast(
106+
"type[TSQLAlchemyTrackingRecorder]",
107+
type(self).tracking_recorder_class,
108+
)
109+
assert tracking_recorder_class is not None
110+
assert issubclass(tracking_recorder_class, SQLAlchemyTrackingRecorder)
111+
recorder = tracking_recorder_class(
112+
datastore=self.datastore,
113+
tracking_table_name=tracking_table_name,
114+
schema_name=self._schema_name,
115+
)
116+
if self.env_create_table():
117+
recorder.create_table()
118+
return recorder
119+
93120
def process_recorder(self) -> ProcessRecorder:
94121
prefix = self.env.name.lower() or "stored"
95122
events_table_name = prefix + "_events"
@@ -105,11 +132,6 @@ def process_recorder(self) -> ProcessRecorder:
105132
recorder.create_table()
106133
return recorder
107134

108-
def tracking_recorder(
109-
self, tracking_recorder_class: type[TrackingRecorder] | None = None
110-
) -> TrackingRecorder:
111-
raise NotImplementedError
112-
113135
def env_create_table(self) -> bool:
114136
default = "yes"
115137
return bool(strtobool(self.env.get(self.CREATE_TABLE) or default))

eventsourcing_sqlalchemy/recorders.py

Lines changed: 136 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,11 @@
1313
StoredEvent,
1414
Subscription,
1515
Tracking,
16+
TrackingRecorder,
1617
)
1718
from sqlalchemy import Column, Table, text
1819
from sqlalchemy.orm import Session
20+
from typing_extensions import TypeVar
1921

2022
from eventsourcing_sqlalchemy.datastore import SQLAlchemyDatastore, Transaction
2123
from eventsourcing_sqlalchemy.models import ( # type: ignore
@@ -24,18 +26,41 @@
2426
)
2527

2628

27-
class SQLAlchemyAggregateRecorder(AggregateRecorder):
29+
class SQLAlchemyRecorder:
30+
"""Base class for recorders that use SQLAlchemy."""
31+
2832
def __init__(
2933
self,
3034
datastore: SQLAlchemyDatastore,
35+
schema_name: str | None = None,
36+
):
37+
self.datastore = datastore
38+
self.schema_name = schema_name
39+
self.tables: List[Table] = []
40+
41+
def create_table(self) -> None:
42+
assert self.datastore.engine is not None
43+
for table in self.tables:
44+
table.create(self.datastore.engine, checkfirst=True)
45+
46+
def transaction(self, commit: bool = True) -> Transaction:
47+
return self.datastore.transaction(commit=commit)
48+
49+
50+
class SQLAlchemyAggregateRecorder(SQLAlchemyRecorder, AggregateRecorder):
51+
def __init__(
52+
self,
53+
datastore: SQLAlchemyDatastore,
54+
*,
3155
events_table_name: str,
3256
schema_name: str | None = None,
3357
for_snapshots: bool = False,
3458
):
35-
super().__init__()
36-
self.datastore = datastore
59+
super().__init__(
60+
datastore,
61+
schema_name=schema_name,
62+
)
3763
self.events_table_name = events_table_name
38-
self.schema_name = schema_name
3964
record_cls_name = "".join(
4065
[
4166
s.capitalize()
@@ -55,27 +80,18 @@ def __init__(
5580
schema_name=self.schema_name,
5681
base_cls=base_cls,
5782
)
58-
self.stored_events_table = self.events_record_cls.__table__
59-
60-
def transaction(self, commit: bool = True) -> Transaction:
61-
return self.datastore.transaction(commit=commit)
62-
63-
def create_table(self) -> None:
64-
assert self.datastore.engine is not None
65-
self.stored_events_table.create(self.datastore.engine, checkfirst=True)
83+
self.tables.append(self.events_record_cls.__table__)
6684

6785
def insert_events(
6886
self, stored_events: Sequence[StoredEvent], **kwargs: Any
6987
) -> Optional[Sequence[int]]:
7088
with self.transaction(commit=True) as session:
71-
self._insert_events(session, stored_events, **kwargs)
89+
self._insert_stored_events(session, stored_events, **kwargs)
7290
return None
7391

74-
def _insert_events(
92+
def _insert_stored_events(
7593
self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any
7694
) -> Optional[Sequence[int]]:
77-
if len(stored_events) == 0:
78-
return []
7995
records = [
8096
self.events_record_cls(
8197
originator_id=e.originator_id,
@@ -91,9 +107,7 @@ def _insert_events(
91107
session.add(record)
92108
if self._has_autoincrementing_ids:
93109
session.flush() # We want the autoincremented IDs now.
94-
return [cast(StoredEventRecord, r).id for r in records]
95-
else:
96-
return None
110+
return None
97111

98112
def _lock_table(self, session: Session) -> None:
99113
assert self.datastore.engine is not None
@@ -160,6 +174,18 @@ def select_events(
160174

161175

162176
class SQLAlchemyApplicationRecorder(SQLAlchemyAggregateRecorder, ApplicationRecorder):
177+
def __init__(
178+
self,
179+
datastore: SQLAlchemyDatastore,
180+
*,
181+
events_table_name: str,
182+
schema_name: str | None = None,
183+
):
184+
super().__init__(
185+
datastore, events_table_name=events_table_name, schema_name=schema_name
186+
)
187+
self.channel_name = self.events_table_name.replace(".", "_")
188+
163189
def insert_events(
164190
self,
165191
stored_events: Sequence[StoredEvent],
@@ -169,12 +195,46 @@ def insert_events(
169195
) -> Optional[Sequence[int]]:
170196
if session is not None:
171197
assert isinstance(session, Session), type(session)
172-
notification_ids = self._insert_events(session, stored_events, **kwargs)
198+
self._insert_events(session, stored_events, **kwargs)
199+
notification_ids = self._insert_stored_events(
200+
session, stored_events, **kwargs
201+
)
173202
else:
174203
with self.transaction(commit=True) as session:
175-
notification_ids = self._insert_events(session, stored_events, **kwargs)
204+
self._insert_events(session, stored_events, **kwargs)
205+
notification_ids = self._insert_stored_events(
206+
session, stored_events, **kwargs
207+
)
176208
return notification_ids
177209

210+
def _insert_events(
211+
self,
212+
session: Session,
213+
stored_events: Sequence[StoredEvent],
214+
**_: Any,
215+
) -> Optional[Sequence[int]]:
216+
pass
217+
218+
def _insert_stored_events(
219+
self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any
220+
) -> Sequence[int]:
221+
records = [
222+
self.events_record_cls(
223+
originator_id=e.originator_id,
224+
originator_version=e.originator_version,
225+
topic=e.topic,
226+
state=e.state,
227+
)
228+
for e in stored_events
229+
]
230+
if self._has_autoincrementing_ids:
231+
self._lock_table(session)
232+
for record in records:
233+
session.add(record)
234+
if self._has_autoincrementing_ids:
235+
session.flush() # We want the autoincremented IDs now.
236+
return [cast(StoredEventRecord, r).id for r in records]
237+
178238
def max_notification_id(self) -> int | None:
179239
try:
180240
with self.transaction(commit=False) as session:
@@ -233,51 +293,30 @@ def subscribe(
233293
raise NotImplementedError(msg)
234294

235295

236-
class SQLAlchemyProcessRecorder(SQLAlchemyApplicationRecorder, ProcessRecorder):
296+
class SQLAlchemyTrackingRecorder(SQLAlchemyRecorder, TrackingRecorder):
237297
def __init__(
238298
self,
239299
datastore: SQLAlchemyDatastore,
240-
events_table_name: str,
241-
tracking_table_name: str,
300+
*,
301+
tracking_table_name: str = "notification_tracking",
242302
schema_name: str | None = None,
303+
**kwargs: Any,
243304
):
244-
super().__init__(
245-
datastore=datastore,
246-
events_table_name=events_table_name,
247-
schema_name=schema_name,
248-
)
305+
super().__init__(datastore=datastore, **kwargs)
249306
self.tracking_table_name = tracking_table_name
250307
self.tracking_record_cls = self.datastore.define_record_class(
251308
cls_name="NotificationTrackingRecord",
252309
table_name=self.tracking_table_name,
253-
schema_name=self.schema_name,
310+
schema_name=schema_name,
254311
base_cls=datastore.base_notification_tracking_record_cls,
255312
)
256313
self.tracking_table: Table = self.tracking_record_cls.__table__
257314

258315
def create_table(self) -> None:
259316
super().create_table()
317+
assert self.datastore.engine is not None
260318
self.tracking_table.create(self.datastore.engine, checkfirst=True)
261319

262-
def _insert_events(
263-
self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any
264-
) -> Optional[Sequence[int]]:
265-
notification_ids = super(SQLAlchemyProcessRecorder, self)._insert_events(
266-
session, stored_events, **kwargs
267-
)
268-
tracking: Optional[Tracking] = kwargs.get("tracking", None)
269-
if tracking is not None:
270-
if self.has_tracking_id(
271-
tracking.application_name, tracking.notification_id
272-
):
273-
raise IntegrityError
274-
record = self.tracking_record_cls(
275-
application_name=tracking.application_name,
276-
notification_id=tracking.notification_id,
277-
)
278-
session.add(record)
279-
return notification_ids
280-
281320
def max_tracking_id(self, application_name: str) -> int | None:
282321
with self.transaction(commit=False) as session:
283322
q = session.query(self.tracking_record_cls)
@@ -290,4 +329,51 @@ def max_tracking_id(self, application_name: str) -> int | None:
290329
return max_id
291330

292331
def insert_tracking(self, tracking: Tracking) -> None:
293-
raise NotImplementedError
332+
with self.transaction(commit=True) as session:
333+
self._insert_tracking(session=session, tracking=tracking)
334+
335+
def _insert_tracking(self, session: Session, tracking: Tracking) -> None:
336+
if tracking is not None:
337+
if self.has_tracking_id(
338+
tracking.application_name, tracking.notification_id
339+
):
340+
raise IntegrityError
341+
record = self.tracking_record_cls(
342+
application_name=tracking.application_name,
343+
notification_id=tracking.notification_id,
344+
)
345+
session.add(record)
346+
347+
348+
TSQLAlchemyTrackingRecorder = TypeVar(
349+
"TSQLAlchemyTrackingRecorder",
350+
bound=SQLAlchemyTrackingRecorder,
351+
default=SQLAlchemyTrackingRecorder,
352+
)
353+
354+
355+
class SQLAlchemyProcessRecorder(
356+
SQLAlchemyTrackingRecorder, SQLAlchemyApplicationRecorder, ProcessRecorder
357+
):
358+
def __init__(
359+
self,
360+
datastore: SQLAlchemyDatastore,
361+
*,
362+
events_table_name: str,
363+
tracking_table_name: str,
364+
schema_name: str | None = None,
365+
):
366+
super().__init__(
367+
datastore=datastore,
368+
tracking_table_name=tracking_table_name,
369+
events_table_name=events_table_name,
370+
schema_name=schema_name,
371+
)
372+
373+
def _insert_events(
374+
self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any
375+
) -> None:
376+
tracking: Optional[Tracking] = kwargs.get("tracking", None)
377+
if tracking is not None:
378+
self._insert_tracking(session, tracking)
379+
super()._insert_events(session=session, stored_events=stored_events, **kwargs)

mypy.ini

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[mypy]
2-
python_version = 3.9
2+
python_version = 3.10
33
files = eventsourcing_sqlalchemy,tests
44

55
check_untyped_defs = True

0 commit comments

Comments
 (0)