Skip to content

Commit f1f8f2f

Browse files
committed
implement SQLAlchemySubscription
1 parent 2e945c8 commit f1f8f2f

7 files changed

Lines changed: 637 additions & 8 deletions

File tree

eventsourcing_sqlalchemy/datastore.py

Lines changed: 32 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22
from __future__ import annotations
33

44
import sqlite3
5+
from contextlib import contextmanager
56
from contextvars import ContextVar, Token
67
from threading import Lock, Semaphore
7-
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union, cast
8+
from typing import Any, Dict, Iterator, List, Optional, Tuple, Type, Union, cast
89

10+
import psycopg
11+
import psycopg2
912
import sqlalchemy.exc
1013
from eventsourcing.persistence import (
1114
DatabaseError,
@@ -23,6 +26,7 @@
2326
from sqlalchemy.future import create_engine
2427
from sqlalchemy.orm import Session, scoped_session, sessionmaker
2528
from sqlalchemy.pool import StaticPool
29+
from typing_extensions import TypeVar
2630

2731
from eventsourcing_sqlalchemy.models import ( # type: ignore
2832
EventRecord,
@@ -263,3 +267,30 @@ def define_record_class(
263267
)
264268
cls.record_classes[record_classes_key] = (record_class, base_cls)
265269
return cast(Type[TEventRecord], record_class)
270+
271+
@contextmanager
272+
def get_connection(self) -> Iterator[Connection]:
273+
try:
274+
assert self.engine
275+
conn = self.engine.connect()
276+
yield conn
277+
except (psycopg.InterfaceError, psycopg2.InterfaceError) as e:
278+
raise InterfaceError(str(e)) from e
279+
except (psycopg.OperationalError, psycopg2.OperationalError) as e:
280+
raise OperationalError(str(e)) from e
281+
except (psycopg.DataError, psycopg2.DataError) as e:
282+
raise DataError(str(e)) from e
283+
except (psycopg.IntegrityError, psycopg2.IntegrityError) as e:
284+
raise IntegrityError(str(e)) from e
285+
except (psycopg.InternalError, psycopg2.InternalError) as e:
286+
raise InternalError(str(e)) from e
287+
except (psycopg.ProgrammingError, psycopg2.ProgrammingError) as e:
288+
raise ProgrammingError(str(e)) from e
289+
except (psycopg.NotSupportedError, psycopg2.NotSupportedError) as e:
290+
raise NotSupportedError(str(e)) from e
291+
except (psycopg.DatabaseError, psycopg2.DatabaseError) as e:
292+
raise DatabaseError(str(e)) from e
293+
except (psycopg.Error, psycopg2.Error) as e:
294+
raise PersistenceError(str(e)) from e
295+
except Exception:
296+
raise

eventsourcing_sqlalchemy/recorders.py

Lines changed: 96 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
# -*- coding: utf-8 -*-
22
from __future__ import annotations
33

4-
from typing import Any, List, Optional, Sequence, Type, cast
4+
import select
5+
import time
6+
from threading import Thread
7+
from typing import Any, Callable, List, Optional, Sequence, Type, cast
58
from uuid import UUID
69

710
from eventsourcing.persistence import (
811
AggregateRecorder,
912
ApplicationRecorder,
1013
IntegrityError,
14+
ListenNotifySubscription,
1115
Notification,
1216
ProcessRecorder,
17+
ProgrammingError,
1318
StoredEvent,
1419
Subscription,
1520
Tracking,
@@ -29,6 +34,8 @@
2934
class SQLAlchemyRecorder:
3035
"""Base class for recorders that use SQLAlchemy."""
3136

37+
POSTGRES_MAX_IDENTIFIER_LEN = 63
38+
3239
def __init__(
3340
self,
3441
datastore: SQLAlchemyDatastore,
@@ -38,6 +45,13 @@ def __init__(
3845
self.schema_name = schema_name
3946
self.tables: List[Table] = []
4047

48+
def check_identifier_length(self, table_name: str) -> None:
49+
assert self.datastore.engine is not None
50+
if self.datastore.engine.dialect.name == "postgresql":
51+
if len(table_name) > SQLAlchemyRecorder.POSTGRES_MAX_IDENTIFIER_LEN:
52+
msg = f"Identifier too long: {table_name}"
53+
raise ProgrammingError(msg)
54+
4155
def create_table(self) -> None:
4256
assert self.datastore.engine is not None
4357
for table in self.tables:
@@ -233,12 +247,12 @@ def _insert_stored_events(
233247
session.add(record)
234248
if self._has_autoincrementing_ids:
235249
session.flush() # We want the autoincremented IDs now.
250+
self._notify_channel(session)
236251
return [cast(StoredEventRecord, r).id for r in records]
237252

238253
def max_notification_id(self) -> int | None:
239254
try:
240255
with self.transaction(commit=False) as session:
241-
# record_class = cast(Type[StoredEventRecord], self.events_record_cls)
242256
record_class = self.events_record_cls
243257
q = session.query(record_class)
244258
q = q.order_by(record_class.id.desc())
@@ -257,7 +271,6 @@ def select_notifications(
257271
inclusive_of_start: bool = True,
258272
) -> list[Notification]:
259273
with self.transaction(commit=False) as session:
260-
# record_class = cast(Type[StoredEventRecord], self.events_record_cls)
261274
record_class = self.events_record_cls
262275
q = session.query(record_class)
263276
if start is not None:
@@ -289,8 +302,86 @@ def select_notifications(
289302
def subscribe(
290303
self, gt: int | None = None, topics: Sequence[str] = ()
291304
) -> Subscription[ApplicationRecorder]:
292-
msg = "SQLAlchemyApplicationRecorder.subscribe() is not implemented"
293-
raise NotImplementedError(msg)
305+
assert self.datastore.engine
306+
if self.datastore.engine.dialect.name == "postgresql":
307+
return SQLAlchemySubscription(recorder=self, gt=gt, topics=topics)
308+
else:
309+
msg = "SQLAlchemyApplicationRecorder.subscribe() is not implemented for"
310+
msg += f"{self.datastore.engine.dialect}"
311+
raise NotImplementedError(msg)
312+
313+
def _notify_channel(self, session: Session) -> None:
314+
"""
315+
Send a NOTIFY on the channel using a SQLAlchemy connection.
316+
"""
317+
assert self.datastore.engine
318+
if self.datastore.engine.dialect.name == "postgresql":
319+
# Get the raw psycopg connection
320+
cursor = session.connection().connection.cursor()
321+
cursor.execute(f"NOTIFY {self.channel_name};")
322+
323+
324+
class SQLAlchemySubscription(ListenNotifySubscription[SQLAlchemyApplicationRecorder]):
325+
def __init__(
326+
self,
327+
recorder: SQLAlchemyApplicationRecorder,
328+
gt: int | None = None,
329+
topics: Sequence[str] = (),
330+
) -> None:
331+
assert isinstance(recorder, SQLAlchemyApplicationRecorder)
332+
super().__init__(recorder=recorder, gt=gt, topics=topics)
333+
self._listen_thread = Thread(target=self._listen)
334+
self._listen_thread.start()
335+
336+
def __exit__(self, *args: object, **kwargs: Any) -> None:
337+
super().__exit__(*args, **kwargs)
338+
self._listen_thread.join()
339+
340+
def _listen(self) -> None:
341+
assert self._recorder.datastore.engine
342+
assert self._recorder.datastore.engine.dialect.name == "postgresql"
343+
notification_handler = self.__get_notification_handler()
344+
345+
try:
346+
with self._recorder.datastore.get_connection() as sa_conn:
347+
sa_conn.execution_options(isolation_level="AUTOCOMMIT")
348+
raw_conn = sa_conn.connection
349+
350+
cur = raw_conn.cursor()
351+
cur.execute(f"LISTEN {self._recorder.channel_name};")
352+
353+
while not self._has_been_stopped and not self._thread_error:
354+
if select.select([raw_conn], [], [], 0.1)[0]:
355+
notification_handler(raw_conn)
356+
else:
357+
time.sleep(0.1)
358+
359+
except BaseException as e: # noqa: B036
360+
if self._thread_error is None:
361+
self._thread_error = e
362+
self.stop()
363+
364+
def __get_notification_handler(self) -> Callable[[Any], None]:
365+
assert self._recorder.datastore.engine
366+
driver_name = self._recorder.datastore.engine.dialect.driver
367+
handlers = {
368+
"psycopg": self.__handle_psycopg_notification,
369+
"psycopg2": self.__handle_psycopg2_notification,
370+
}
371+
try:
372+
return handlers[driver_name]
373+
except KeyError as e:
374+
raise NotImplementedError(f"Unsupported driver: {driver_name}") from e
375+
376+
def __handle_psycopg_notification(self, raw_conn: Any) -> None:
377+
next(raw_conn.notifies())
378+
self._has_been_notified.set()
379+
380+
def __handle_psycopg2_notification(self, raw_conn: Any) -> None:
381+
raw_conn.poll()
382+
if raw_conn.notifies:
383+
raw_conn.notifies.pop(0)
384+
self._has_been_notified.set()
294385

295386

296387
class SQLAlchemyTrackingRecorder(SQLAlchemyRecorder, TrackingRecorder):

poetry.lock

Lines changed: 72 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ flake8-coding = "*"
4040
flake8-isort = "*"
4141
flake8-tidy-imports = "*"
4242
isort = "*"
43+
msgspec = "*"
4344
mypy = "*"
4445
pre-commit = "*"
4546
pre-commit-hooks = "*"

0 commit comments

Comments
 (0)