Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/github-actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
python-version: ["3.10", "3.11", "3.12", "3.13"]
sqlalchemy-version: ["<2.0", "default"]
runs-on: ubuntu-latest
services:
Expand Down
38 changes: 30 additions & 8 deletions eventsourcing_sqlalchemy/factory.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
# -*- coding: utf-8 -*-
from __future__ import annotations

from typing import Optional
from typing import Optional, cast

from eventsourcing.persistence import (
AggregateRecorder,
ApplicationRecorder,
InfrastructureFactory,
ProcessRecorder,
TrackingRecorder,
)
from eventsourcing.utils import Environment, resolve_topic, strtobool
from sqlalchemy.orm import scoped_session
Expand All @@ -18,10 +17,12 @@
SQLAlchemyAggregateRecorder,
SQLAlchemyApplicationRecorder,
SQLAlchemyProcessRecorder,
SQLAlchemyTrackingRecorder,
TSQLAlchemyTrackingRecorder,
)


class SQLAlchemyFactory(InfrastructureFactory[TrackingRecorder]):
class SQLAlchemyFactory(InfrastructureFactory[SQLAlchemyTrackingRecorder]):
SQLALCHEMY_URL = "SQLALCHEMY_URL"
SQLALCHEMY_AUTOFLUSH = "SQLALCHEMY_AUTOFLUSH"
SQLALCHEMY_CONNECTION_CREATOR_TOPIC = "SQLALCHEMY_CONNECTION_CREATOR_TOPIC"
Expand All @@ -32,6 +33,7 @@ class SQLAlchemyFactory(InfrastructureFactory[TrackingRecorder]):
datastore_class = SQLAlchemyDatastore
aggregate_recorder_class = SQLAlchemyAggregateRecorder
application_recorder_class = SQLAlchemyApplicationRecorder
tracking_recorder_class = SQLAlchemyTrackingRecorder
process_recorder_class = SQLAlchemyProcessRecorder

def __init__(self, env: Environment):
Expand Down Expand Up @@ -90,6 +92,31 @@ def application_recorder(self) -> ApplicationRecorder:
recorder.create_table()
return recorder

def tracking_recorder(
self, tracking_recorder_class: type[TSQLAlchemyTrackingRecorder] | None = None
) -> TSQLAlchemyTrackingRecorder:
prefix = self.env.name.lower() or "notification"
tracking_table_name = prefix + "_tracking"
if tracking_recorder_class is None:
tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC)
if tracking_recorder_topic:
tracking_recorder_class = resolve_topic(tracking_recorder_topic)
else:
tracking_recorder_class = cast(
"type[TSQLAlchemyTrackingRecorder]",
type(self).tracking_recorder_class,
)
assert tracking_recorder_class is not None
assert issubclass(tracking_recorder_class, SQLAlchemyTrackingRecorder)
recorder = tracking_recorder_class(
datastore=self.datastore,
tracking_table_name=tracking_table_name,
schema_name=self._schema_name,
)
if self.env_create_table():
recorder.create_table()
return recorder

def process_recorder(self) -> ProcessRecorder:
prefix = self.env.name.lower() or "stored"
events_table_name = prefix + "_events"
Expand All @@ -105,11 +132,6 @@ def process_recorder(self) -> ProcessRecorder:
recorder.create_table()
return recorder

def tracking_recorder(
self, tracking_recorder_class: type[TrackingRecorder] | None = None
) -> TrackingRecorder:
raise NotImplementedError

def env_create_table(self) -> bool:
default = "yes"
return bool(strtobool(self.env.get(self.CREATE_TABLE) or default))
Expand Down
186 changes: 136 additions & 50 deletions eventsourcing_sqlalchemy/recorders.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
StoredEvent,
Subscription,
Tracking,
TrackingRecorder,
)
from sqlalchemy import Column, Table, text
from sqlalchemy.orm import Session
from typing_extensions import TypeVar

from eventsourcing_sqlalchemy.datastore import SQLAlchemyDatastore, Transaction
from eventsourcing_sqlalchemy.models import ( # type: ignore
Expand All @@ -24,18 +26,41 @@
)


class SQLAlchemyAggregateRecorder(AggregateRecorder):
class SQLAlchemyRecorder:
"""Base class for recorders that use SQLAlchemy."""

def __init__(
self,
datastore: SQLAlchemyDatastore,
schema_name: str | None = None,
):
self.datastore = datastore
self.schema_name = schema_name
self.tables: List[Table] = []

def create_table(self) -> None:
assert self.datastore.engine is not None
for table in self.tables:
table.create(self.datastore.engine, checkfirst=True)

def transaction(self, commit: bool = True) -> Transaction:
return self.datastore.transaction(commit=commit)


class SQLAlchemyAggregateRecorder(SQLAlchemyRecorder, AggregateRecorder):
def __init__(
self,
datastore: SQLAlchemyDatastore,
*,
events_table_name: str,
schema_name: str | None = None,
for_snapshots: bool = False,
):
super().__init__()
self.datastore = datastore
super().__init__(
datastore,
schema_name=schema_name,
)
self.events_table_name = events_table_name
self.schema_name = schema_name
record_cls_name = "".join(
[
s.capitalize()
Expand All @@ -55,27 +80,18 @@ def __init__(
schema_name=self.schema_name,
base_cls=base_cls,
)
self.stored_events_table = self.events_record_cls.__table__

def transaction(self, commit: bool = True) -> Transaction:
return self.datastore.transaction(commit=commit)

def create_table(self) -> None:
assert self.datastore.engine is not None
self.stored_events_table.create(self.datastore.engine, checkfirst=True)
self.tables.append(self.events_record_cls.__table__)

def insert_events(
self, stored_events: Sequence[StoredEvent], **kwargs: Any
) -> Optional[Sequence[int]]:
with self.transaction(commit=True) as session:
self._insert_events(session, stored_events, **kwargs)
self._insert_stored_events(session, stored_events, **kwargs)
return None

def _insert_events(
def _insert_stored_events(
self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any
) -> Optional[Sequence[int]]:
if len(stored_events) == 0:
return []
records = [
self.events_record_cls(
originator_id=e.originator_id,
Expand All @@ -91,9 +107,7 @@ def _insert_events(
session.add(record)
if self._has_autoincrementing_ids:
session.flush() # We want the autoincremented IDs now.
return [cast(StoredEventRecord, r).id for r in records]
else:
return None
return None

def _lock_table(self, session: Session) -> None:
assert self.datastore.engine is not None
Expand Down Expand Up @@ -160,6 +174,18 @@ def select_events(


class SQLAlchemyApplicationRecorder(SQLAlchemyAggregateRecorder, ApplicationRecorder):
def __init__(
self,
datastore: SQLAlchemyDatastore,
*,
events_table_name: str,
schema_name: str | None = None,
):
super().__init__(
datastore, events_table_name=events_table_name, schema_name=schema_name
)
self.channel_name = self.events_table_name.replace(".", "_")

def insert_events(
self,
stored_events: Sequence[StoredEvent],
Expand All @@ -169,12 +195,46 @@ def insert_events(
) -> Optional[Sequence[int]]:
if session is not None:
assert isinstance(session, Session), type(session)
notification_ids = self._insert_events(session, stored_events, **kwargs)
self._insert_events(session, stored_events, **kwargs)
notification_ids = self._insert_stored_events(
session, stored_events, **kwargs
)
else:
with self.transaction(commit=True) as session:
notification_ids = self._insert_events(session, stored_events, **kwargs)
self._insert_events(session, stored_events, **kwargs)
notification_ids = self._insert_stored_events(
session, stored_events, **kwargs
)
return notification_ids

def _insert_events(
self,
session: Session,
stored_events: Sequence[StoredEvent],
**_: Any,
) -> Optional[Sequence[int]]:
pass

def _insert_stored_events(
self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any
) -> Sequence[int]:
records = [
self.events_record_cls(
originator_id=e.originator_id,
originator_version=e.originator_version,
topic=e.topic,
state=e.state,
)
for e in stored_events
]
if self._has_autoincrementing_ids:
self._lock_table(session)
for record in records:
session.add(record)
if self._has_autoincrementing_ids:
session.flush() # We want the autoincremented IDs now.
return [cast(StoredEventRecord, r).id for r in records]

def max_notification_id(self) -> int | None:
try:
with self.transaction(commit=False) as session:
Expand Down Expand Up @@ -233,51 +293,30 @@ def subscribe(
raise NotImplementedError(msg)


class SQLAlchemyProcessRecorder(SQLAlchemyApplicationRecorder, ProcessRecorder):
class SQLAlchemyTrackingRecorder(SQLAlchemyRecorder, TrackingRecorder):
def __init__(
self,
datastore: SQLAlchemyDatastore,
events_table_name: str,
tracking_table_name: str,
*,
tracking_table_name: str = "notification_tracking",
schema_name: str | None = None,
**kwargs: Any,
):
super().__init__(
datastore=datastore,
events_table_name=events_table_name,
schema_name=schema_name,
)
super().__init__(datastore=datastore, **kwargs)
self.tracking_table_name = tracking_table_name
self.tracking_record_cls = self.datastore.define_record_class(
cls_name="NotificationTrackingRecord",
table_name=self.tracking_table_name,
schema_name=self.schema_name,
schema_name=schema_name,
base_cls=datastore.base_notification_tracking_record_cls,
)
self.tracking_table: Table = self.tracking_record_cls.__table__

def create_table(self) -> None:
super().create_table()
assert self.datastore.engine is not None
self.tracking_table.create(self.datastore.engine, checkfirst=True)

def _insert_events(
self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any
) -> Optional[Sequence[int]]:
notification_ids = super(SQLAlchemyProcessRecorder, self)._insert_events(
session, stored_events, **kwargs
)
tracking: Optional[Tracking] = kwargs.get("tracking", None)
if tracking is not None:
if self.has_tracking_id(
tracking.application_name, tracking.notification_id
):
raise IntegrityError
record = self.tracking_record_cls(
application_name=tracking.application_name,
notification_id=tracking.notification_id,
)
session.add(record)
return notification_ids

def max_tracking_id(self, application_name: str) -> int | None:
with self.transaction(commit=False) as session:
q = session.query(self.tracking_record_cls)
Expand All @@ -290,4 +329,51 @@ def max_tracking_id(self, application_name: str) -> int | None:
return max_id

def insert_tracking(self, tracking: Tracking) -> None:
raise NotImplementedError
with self.transaction(commit=True) as session:
self._insert_tracking(session=session, tracking=tracking)

def _insert_tracking(self, session: Session, tracking: Tracking) -> None:
if tracking is not None:
if self.has_tracking_id(
tracking.application_name, tracking.notification_id
):
raise IntegrityError
record = self.tracking_record_cls(
application_name=tracking.application_name,
notification_id=tracking.notification_id,
)
session.add(record)


TSQLAlchemyTrackingRecorder = TypeVar(
"TSQLAlchemyTrackingRecorder",
bound=SQLAlchemyTrackingRecorder,
default=SQLAlchemyTrackingRecorder,
)


class SQLAlchemyProcessRecorder(
SQLAlchemyTrackingRecorder, SQLAlchemyApplicationRecorder, ProcessRecorder
):
def __init__(
self,
datastore: SQLAlchemyDatastore,
*,
events_table_name: str,
tracking_table_name: str,
schema_name: str | None = None,
):
super().__init__(
datastore=datastore,
tracking_table_name=tracking_table_name,
events_table_name=events_table_name,
schema_name=schema_name,
)

def _insert_events(
self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any
) -> None:
tracking: Optional[Tracking] = kwargs.get("tracking", None)
if tracking is not None:
self._insert_tracking(session, tracking)
super()._insert_events(session=session, stored_events=stored_events, **kwargs)
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
[mypy]
python_version = 3.9
python_version = 3.10
files = eventsourcing_sqlalchemy,tests

check_untyped_defs = True
Expand Down
Loading