diff --git a/.github/workflows/github-actions.yml b/.github/workflows/github-actions.yml index 01b4794..f9d0593 100644 --- a/.github/workflows/github-actions.yml +++ b/.github/workflows/github-actions.yml @@ -6,7 +6,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] + python-version: ["3.10", "3.11", "3.12", "3.13"] sqlalchemy-version: ["<2.0", "default"] runs-on: ubuntu-latest services: diff --git a/eventsourcing_sqlalchemy/factory.py b/eventsourcing_sqlalchemy/factory.py index 24b92eb..6a6d2d9 100644 --- a/eventsourcing_sqlalchemy/factory.py +++ b/eventsourcing_sqlalchemy/factory.py @@ -1,14 +1,13 @@ # -*- coding: utf-8 -*- from __future__ import annotations -from typing import Optional +from typing import Optional, cast from eventsourcing.persistence import ( AggregateRecorder, ApplicationRecorder, InfrastructureFactory, ProcessRecorder, - TrackingRecorder, ) from eventsourcing.utils import Environment, resolve_topic, strtobool from sqlalchemy.orm import scoped_session @@ -18,10 +17,12 @@ SQLAlchemyAggregateRecorder, SQLAlchemyApplicationRecorder, SQLAlchemyProcessRecorder, + SQLAlchemyTrackingRecorder, + TSQLAlchemyTrackingRecorder, ) -class SQLAlchemyFactory(InfrastructureFactory[TrackingRecorder]): +class SQLAlchemyFactory(InfrastructureFactory[SQLAlchemyTrackingRecorder]): SQLALCHEMY_URL = "SQLALCHEMY_URL" SQLALCHEMY_AUTOFLUSH = "SQLALCHEMY_AUTOFLUSH" SQLALCHEMY_CONNECTION_CREATOR_TOPIC = "SQLALCHEMY_CONNECTION_CREATOR_TOPIC" @@ -32,6 +33,7 @@ class SQLAlchemyFactory(InfrastructureFactory[TrackingRecorder]): datastore_class = SQLAlchemyDatastore aggregate_recorder_class = SQLAlchemyAggregateRecorder application_recorder_class = SQLAlchemyApplicationRecorder + tracking_recorder_class = SQLAlchemyTrackingRecorder process_recorder_class = SQLAlchemyProcessRecorder def __init__(self, env: Environment): @@ -90,6 +92,31 @@ def application_recorder(self) -> ApplicationRecorder: recorder.create_table() return recorder + def tracking_recorder( + self, tracking_recorder_class: type[TSQLAlchemyTrackingRecorder] | None = None + ) -> TSQLAlchemyTrackingRecorder: + prefix = self.env.name.lower() or "notification" + tracking_table_name = prefix + "_tracking" + if tracking_recorder_class is None: + tracking_recorder_topic = self.env.get(self.TRACKING_RECORDER_TOPIC) + if tracking_recorder_topic: + tracking_recorder_class = resolve_topic(tracking_recorder_topic) + else: + tracking_recorder_class = cast( + "type[TSQLAlchemyTrackingRecorder]", + type(self).tracking_recorder_class, + ) + assert tracking_recorder_class is not None + assert issubclass(tracking_recorder_class, SQLAlchemyTrackingRecorder) + recorder = tracking_recorder_class( + datastore=self.datastore, + tracking_table_name=tracking_table_name, + schema_name=self._schema_name, + ) + if self.env_create_table(): + recorder.create_table() + return recorder + def process_recorder(self) -> ProcessRecorder: prefix = self.env.name.lower() or "stored" events_table_name = prefix + "_events" @@ -105,11 +132,6 @@ def process_recorder(self) -> ProcessRecorder: recorder.create_table() return recorder - def tracking_recorder( - self, tracking_recorder_class: type[TrackingRecorder] | None = None - ) -> TrackingRecorder: - raise NotImplementedError - def env_create_table(self) -> bool: default = "yes" return bool(strtobool(self.env.get(self.CREATE_TABLE) or default)) diff --git a/eventsourcing_sqlalchemy/recorders.py b/eventsourcing_sqlalchemy/recorders.py index 9567417..0816de4 100644 --- a/eventsourcing_sqlalchemy/recorders.py +++ b/eventsourcing_sqlalchemy/recorders.py @@ -13,9 +13,11 @@ StoredEvent, Subscription, Tracking, + TrackingRecorder, ) from sqlalchemy import Column, Table, text from sqlalchemy.orm import Session +from typing_extensions import TypeVar from eventsourcing_sqlalchemy.datastore import SQLAlchemyDatastore, Transaction from eventsourcing_sqlalchemy.models import ( # type: ignore @@ -24,18 +26,41 @@ ) -class SQLAlchemyAggregateRecorder(AggregateRecorder): +class SQLAlchemyRecorder: + """Base class for recorders that use SQLAlchemy.""" + def __init__( self, datastore: SQLAlchemyDatastore, + schema_name: str | None = None, + ): + self.datastore = datastore + self.schema_name = schema_name + self.tables: List[Table] = [] + + def create_table(self) -> None: + assert self.datastore.engine is not None + for table in self.tables: + table.create(self.datastore.engine, checkfirst=True) + + def transaction(self, commit: bool = True) -> Transaction: + return self.datastore.transaction(commit=commit) + + +class SQLAlchemyAggregateRecorder(SQLAlchemyRecorder, AggregateRecorder): + def __init__( + self, + datastore: SQLAlchemyDatastore, + *, events_table_name: str, schema_name: str | None = None, for_snapshots: bool = False, ): - super().__init__() - self.datastore = datastore + super().__init__( + datastore, + schema_name=schema_name, + ) self.events_table_name = events_table_name - self.schema_name = schema_name record_cls_name = "".join( [ s.capitalize() @@ -55,27 +80,18 @@ def __init__( schema_name=self.schema_name, base_cls=base_cls, ) - self.stored_events_table = self.events_record_cls.__table__ - - def transaction(self, commit: bool = True) -> Transaction: - return self.datastore.transaction(commit=commit) - - def create_table(self) -> None: - assert self.datastore.engine is not None - self.stored_events_table.create(self.datastore.engine, checkfirst=True) + self.tables.append(self.events_record_cls.__table__) def insert_events( self, stored_events: Sequence[StoredEvent], **kwargs: Any ) -> Optional[Sequence[int]]: with self.transaction(commit=True) as session: - self._insert_events(session, stored_events, **kwargs) + self._insert_stored_events(session, stored_events, **kwargs) return None - def _insert_events( + def _insert_stored_events( self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any ) -> Optional[Sequence[int]]: - if len(stored_events) == 0: - return [] records = [ self.events_record_cls( originator_id=e.originator_id, @@ -91,9 +107,7 @@ def _insert_events( session.add(record) if self._has_autoincrementing_ids: session.flush() # We want the autoincremented IDs now. - return [cast(StoredEventRecord, r).id for r in records] - else: - return None + return None def _lock_table(self, session: Session) -> None: assert self.datastore.engine is not None @@ -160,6 +174,18 @@ def select_events( class SQLAlchemyApplicationRecorder(SQLAlchemyAggregateRecorder, ApplicationRecorder): + def __init__( + self, + datastore: SQLAlchemyDatastore, + *, + events_table_name: str, + schema_name: str | None = None, + ): + super().__init__( + datastore, events_table_name=events_table_name, schema_name=schema_name + ) + self.channel_name = self.events_table_name.replace(".", "_") + def insert_events( self, stored_events: Sequence[StoredEvent], @@ -169,12 +195,46 @@ def insert_events( ) -> Optional[Sequence[int]]: if session is not None: assert isinstance(session, Session), type(session) - notification_ids = self._insert_events(session, stored_events, **kwargs) + self._insert_events(session, stored_events, **kwargs) + notification_ids = self._insert_stored_events( + session, stored_events, **kwargs + ) else: with self.transaction(commit=True) as session: - notification_ids = self._insert_events(session, stored_events, **kwargs) + self._insert_events(session, stored_events, **kwargs) + notification_ids = self._insert_stored_events( + session, stored_events, **kwargs + ) return notification_ids + def _insert_events( + self, + session: Session, + stored_events: Sequence[StoredEvent], + **_: Any, + ) -> Optional[Sequence[int]]: + pass + + def _insert_stored_events( + self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any + ) -> Sequence[int]: + records = [ + self.events_record_cls( + originator_id=e.originator_id, + originator_version=e.originator_version, + topic=e.topic, + state=e.state, + ) + for e in stored_events + ] + if self._has_autoincrementing_ids: + self._lock_table(session) + for record in records: + session.add(record) + if self._has_autoincrementing_ids: + session.flush() # We want the autoincremented IDs now. + return [cast(StoredEventRecord, r).id for r in records] + def max_notification_id(self) -> int | None: try: with self.transaction(commit=False) as session: @@ -233,51 +293,30 @@ def subscribe( raise NotImplementedError(msg) -class SQLAlchemyProcessRecorder(SQLAlchemyApplicationRecorder, ProcessRecorder): +class SQLAlchemyTrackingRecorder(SQLAlchemyRecorder, TrackingRecorder): def __init__( self, datastore: SQLAlchemyDatastore, - events_table_name: str, - tracking_table_name: str, + *, + tracking_table_name: str = "notification_tracking", schema_name: str | None = None, + **kwargs: Any, ): - super().__init__( - datastore=datastore, - events_table_name=events_table_name, - schema_name=schema_name, - ) + super().__init__(datastore=datastore, **kwargs) self.tracking_table_name = tracking_table_name self.tracking_record_cls = self.datastore.define_record_class( cls_name="NotificationTrackingRecord", table_name=self.tracking_table_name, - schema_name=self.schema_name, + schema_name=schema_name, base_cls=datastore.base_notification_tracking_record_cls, ) self.tracking_table: Table = self.tracking_record_cls.__table__ def create_table(self) -> None: super().create_table() + assert self.datastore.engine is not None self.tracking_table.create(self.datastore.engine, checkfirst=True) - def _insert_events( - self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any - ) -> Optional[Sequence[int]]: - notification_ids = super(SQLAlchemyProcessRecorder, self)._insert_events( - session, stored_events, **kwargs - ) - tracking: Optional[Tracking] = kwargs.get("tracking", None) - if tracking is not None: - if self.has_tracking_id( - tracking.application_name, tracking.notification_id - ): - raise IntegrityError - record = self.tracking_record_cls( - application_name=tracking.application_name, - notification_id=tracking.notification_id, - ) - session.add(record) - return notification_ids - def max_tracking_id(self, application_name: str) -> int | None: with self.transaction(commit=False) as session: q = session.query(self.tracking_record_cls) @@ -290,4 +329,51 @@ def max_tracking_id(self, application_name: str) -> int | None: return max_id def insert_tracking(self, tracking: Tracking) -> None: - raise NotImplementedError + with self.transaction(commit=True) as session: + self._insert_tracking(session=session, tracking=tracking) + + def _insert_tracking(self, session: Session, tracking: Tracking) -> None: + if tracking is not None: + if self.has_tracking_id( + tracking.application_name, tracking.notification_id + ): + raise IntegrityError + record = self.tracking_record_cls( + application_name=tracking.application_name, + notification_id=tracking.notification_id, + ) + session.add(record) + + +TSQLAlchemyTrackingRecorder = TypeVar( + "TSQLAlchemyTrackingRecorder", + bound=SQLAlchemyTrackingRecorder, + default=SQLAlchemyTrackingRecorder, +) + + +class SQLAlchemyProcessRecorder( + SQLAlchemyTrackingRecorder, SQLAlchemyApplicationRecorder, ProcessRecorder +): + def __init__( + self, + datastore: SQLAlchemyDatastore, + *, + events_table_name: str, + tracking_table_name: str, + schema_name: str | None = None, + ): + super().__init__( + datastore=datastore, + tracking_table_name=tracking_table_name, + events_table_name=events_table_name, + schema_name=schema_name, + ) + + def _insert_events( + self, session: Session, stored_events: Sequence[StoredEvent], **kwargs: Any + ) -> None: + tracking: Optional[Tracking] = kwargs.get("tracking", None) + if tracking is not None: + self._insert_tracking(session, tracking) + super()._insert_events(session=session, stored_events=stored_events, **kwargs) diff --git a/mypy.ini b/mypy.ini index 46a1b3f..a60333c 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,5 @@ [mypy] -python_version = 3.9 +python_version = 3.10 files = eventsourcing_sqlalchemy,tests check_untyped_defs = True diff --git a/poetry.lock b/poetry.lock index 605921e..a899bfa 100644 --- a/poetry.lock +++ b/poetry.lock @@ -765,14 +765,14 @@ wmi = ["wmi (>=1.5.1)"] [[package]] name = "eventsourcing" -version = "9.4.6" +version = "9.5.2" description = "Event sourcing in Python" optional = false -python-versions = ">=3.9.2" +python-versions = ">=3.10.0" groups = ["main", "dev"] files = [ - {file = "eventsourcing-9.4.6-py3-none-any.whl", hash = "sha256:5663f940156f55133eb407bb962349a631942fddb69064ab1040c2010cd6a761"}, - {file = "eventsourcing-9.4.6.tar.gz", hash = "sha256:093742d5c2e4fcbe835c4b38c8d2193e1c38810c9d27eb76516f55947b096baa"}, + {file = "eventsourcing-9.5.2-py3-none-any.whl", hash = "sha256:a1343af2cf7aacf3e2af41d90f9dbf0cca4a6475532bc7632b10d730e35fcade"}, + {file = "eventsourcing-9.5.2.tar.gz", hash = "sha256:11dd225de6ae7a5598de950685d380ba395c106c89cacd731ba914dbeae2fd28"}, ] [package.dependencies] @@ -973,7 +973,6 @@ files = [ [package.dependencies] blinker = ">=1.9.0" click = ">=8.1.3" -importlib-metadata = {version = ">=3.6.0", markers = "python_version < \"3.10\""} itsdangerous = ">=2.2.0" jinja2 = ">=3.1.2" markupsafe = ">=2.1.1" @@ -1239,31 +1238,6 @@ files = [ [package.extras] all = ["flake8 (>=7.1.1)", "mypy (>=1.11.2)", "pytest (>=8.3.2)", "ruff (>=0.6.2)"] -[[package]] -name = "importlib-metadata" -version = "8.7.0" -description = "Read metadata from Python packages" -optional = false -python-versions = ">=3.9" -groups = ["dev"] -markers = "python_version < \"3.10\"" -files = [ - {file = "importlib_metadata-8.7.0-py3-none-any.whl", hash = "sha256:e5dd1551894c77868a30651cef00984d50e1002d06942a7101d34870c5f02afd"}, - {file = "importlib_metadata-8.7.0.tar.gz", hash = "sha256:d13b81ad223b890aa16c5471f2ac3056cf76c5f10f82d6f9292f0b415f389000"}, -] - -[package.dependencies] -zipp = ">=3.20" - -[package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] -cover = ["pytest-cov"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -enabler = ["pytest-enabler (>=2.2)"] -perf = ["ipython"] -test = ["flufl.flake8", "importlib_resources (>=1.3) ; python_version < \"3.9\"", "jaraco.test (>=5.4)", "packaging", "pyfakefs", "pytest (>=6,!=8.1.*)", "pytest-perf (>=0.9.2)"] -type = ["pytest-mypy"] - [[package]] name = "iniconfig" version = "2.1.0" @@ -2851,7 +2825,6 @@ files = [ [package.dependencies] anyio = ">=3.6.2,<5" -typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} [package.extras] full = ["httpx (>=0.27.0,<0.29.0)", "itsdangerous", "jinja2", "python-multipart (>=0.0.18)", "pyyaml"] @@ -3130,28 +3103,7 @@ idna = ">=2.0" multidict = ">=4.0" propcache = ">=0.2.1" -[[package]] -name = "zipp" -version = "3.23.0" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.9" -groups = ["dev"] -markers = "python_version < \"3.10\"" -files = [ - {file = "zipp-3.23.0-py3-none-any.whl", hash = "sha256:071652d6115ed432f5ce1d34c336c0adfd6a884660d1e9712a256d3d3bd4b14e"}, - {file = "zipp-3.23.0.tar.gz", hash = "sha256:a07157588a12518c9d4034df3fbbee09c814741a33ff63c05fa29d26a2404166"}, -] - -[package.extras] -check = ["pytest-checkdocs (>=2.4)", "pytest-ruff (>=0.2.1) ; sys_platform != \"cygwin\""] -cover = ["pytest-cov"] -doc = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -enabler = ["pytest-enabler (>=2.2)"] -test = ["big-O", "jaraco.functools", "jaraco.itertools", "jaraco.test", "more_itertools", "pytest (>=6,!=8.1.*)", "pytest-ignore-flaky"] -type = ["pytest-mypy"] - [metadata] lock-version = "2.1" -python-versions = "^3.9.2" -content-hash = "df76ef7c072c695e387efef10fa08f268297c03804c1da67378a7582b3fb7ee1" +python-versions = "^3.10.0" +content-hash = "ac61046780c5ba039cbc209c586d2e2ab271fa92e87346af17f9273b117595e9" diff --git a/pyproject.toml b/pyproject.toml index 77ac382..bd1da06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -14,7 +14,6 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", - "Programming Language :: Python :: 3.9", "Programming Language :: Python", ] readme = "README.md" @@ -23,12 +22,12 @@ repository = "https://github.com/pyeventsourcing/eventsourcing-sqlalchemy" include = ["eventsourcing_sqlalchemy/py.typed"] [tool.poetry.dependencies] -python = "^3.9.2" +python = "^3.10.0" #eventsourcing = { path = "../eventsourcing/", extras = ["crypto"] } #eventsourcing = { path = "../eventsourcing/", extras = ["crypto"], develop = true } #eventsourcing = { git = "https://github.com/pyeventsourcing/eventsourcing.git", branch = "main", extras = ["crypto"]} SQLAlchemy-Utils = ">=0.38.2" -eventsourcing = "^9.4.6" +eventsourcing = "^9.5.2" sqlalchemy = ">=1.4.26, <2.1" [tool.poetry.group.dev.dependencies] @@ -44,7 +43,7 @@ isort = "*" mypy = "*" pre-commit = "*" pre-commit-hooks = "*" -eventsourcing = { version = "^9.4.6", extras = ["crypto"] } +eventsourcing = { version = "^9.5.2", extras = ["crypto"] } psycopg = { version = "*", extras = ["binary", "pool"] } psycopg2-binary = "*" pytest = "*" diff --git a/tests/test_application.py b/tests/test_application.py index 354dff3..935deff 100644 --- a/tests/test_application.py +++ b/tests/test_application.py @@ -6,14 +6,13 @@ from eventsourcing.application import AggregateNotFoundError, Application from eventsourcing.domain import Aggregate from eventsourcing.tests.application import ExampleApplicationTestCase -from eventsourcing.tests.postgres_utils import drop_tables from eventsourcing.utils import clear_topic_cache, get_topic from fastapi_sqlalchemy import DBSessionMiddleware from sqlalchemy.engine.url import URL from sqlalchemy.orm import scoped_session from eventsourcing_sqlalchemy.factory import SQLAlchemyFactory -from tests.utils import drop_mssql_table +from tests.utils import drop_mssql_table, drop_pg_tables try: from sqlalchemy.orm import declarative_base # type: ignore @@ -280,7 +279,7 @@ def tearDown(self) -> None: super().tearDown() def drop_tables(self) -> None: - drop_tables() + drop_pg_tables() def test_example_application(self) -> None: super().test_example_application() @@ -297,7 +296,7 @@ def tearDown(self) -> None: del os.environ["SQLALCHEMY_SCHEMA"] def drop_tables(self) -> None: - drop_tables() + drop_pg_tables() @skip("SQL Server not supported yet") diff --git a/tests/test_factory.py b/tests/test_factory.py index c05e030..2d421e1 100644 --- a/tests/test_factory.py +++ b/tests/test_factory.py @@ -1,7 +1,6 @@ # -*- coding: utf-8 -*- import os from typing import Type -from unittest import skip from eventsourcing.persistence import ( AggregateRecorder, @@ -18,6 +17,7 @@ SQLAlchemyAggregateRecorder, SQLAlchemyApplicationRecorder, SQLAlchemyProcessRecorder, + SQLAlchemyTrackingRecorder, ) @@ -38,10 +38,13 @@ def expected_process_recorder_class(self) -> Type[ProcessRecorder]: return SQLAlchemyProcessRecorder def expected_tracking_recorder_class(self) -> type[TrackingRecorder]: - raise NotImplementedError + return SQLAlchemyTrackingRecorder + + class SQLAchemyTrackingRecorderSubclass(SQLAlchemyTrackingRecorder): + pass def tracking_recorder_subclass(self) -> type[TrackingRecorder]: - raise NotImplementedError + return self.SQLAchemyTrackingRecorderSubclass def setUp(self) -> None: self.env = Environment("TestCase") @@ -56,8 +59,5 @@ def tearDown(self) -> None: del os.environ[SQLAlchemyFactory.SQLALCHEMY_URL] super().tearDown() - def test_create_tracking_recorder(self) -> None: - skip("SQLAlchemyFactory doesn't implement tracking recorders yet") - del InfrastructureFactoryTestCase diff --git a/tests/test_noninterleaving_notification_ids.py b/tests/test_noninterleaving_notification_ids.py index 5981da4..b32366e 100644 --- a/tests/test_noninterleaving_notification_ids.py +++ b/tests/test_noninterleaving_notification_ids.py @@ -7,12 +7,11 @@ NonInterleavingNotificationIDsBaseCase, tmpfile_uris, ) -from eventsourcing.tests.postgres_utils import drop_tables from sqlalchemy.engine.url import URL from eventsourcing_sqlalchemy.datastore import SQLAlchemyDatastore from eventsourcing_sqlalchemy.recorders import SQLAlchemyApplicationRecorder -from tests.utils import drop_mssql_table +from tests.utils import drop_mssql_table, drop_pg_tables class TestNonInterleaving(NonInterleavingNotificationIDsBaseCase): @@ -68,7 +67,7 @@ def tearDown(self) -> None: super().tearDown() def drop_tables(self) -> None: - drop_tables() + drop_pg_tables() @skip("SQL Server not supported yet") diff --git a/tests/test_recorders.py b/tests/test_recorders.py index 540d61f..da76657 100644 --- a/tests/test_recorders.py +++ b/tests/test_recorders.py @@ -8,11 +8,13 @@ ProcessRecorder, StoredEvent, Tracking, + TrackingRecorder, ) from eventsourcing.tests.persistence import ( AggregateRecorderTestCase, ApplicationRecorderTestCase, ProcessRecorderTestCase, + TrackingRecorderTestCase, tmpfile_uris, ) from sqlalchemy.future import create_engine @@ -23,6 +25,7 @@ SQLAlchemyAggregateRecorder, SQLAlchemyApplicationRecorder, SQLAlchemyProcessRecorder, + SQLAlchemyTrackingRecorder, ) @@ -87,12 +90,12 @@ def test_insert_select(self) -> None: self.assertFalse(self.datastore.is_sqlite_wal_mode) super().test_insert_select() - def test_concurrent_no_conflicts(self) -> None: + def test_concurrent_no_conflicts(self, initial_position: int = 0) -> None: self.assertFalse(self.datastore.is_sqlite_wal_mode) self.assertTrue(self.datastore.access_lock) self.assertFalse(self.datastore.write_lock) self.assertIsInstance(self.datastore.access_lock, Semaphore) - super().test_concurrent_no_conflicts() + super().test_concurrent_no_conflicts(initial_position=initial_position) def test_concurrent_no_conflicts_sqlite_filedb(self) -> None: uris = tmpfile_uris() @@ -107,6 +110,19 @@ def test_concurrent_no_conflicts_sqlite_filedb(self) -> None: self.assertTrue(self.datastore.is_sqlite_wal_mode) +class TestSQLAlchemyTrackingRecorder(TrackingRecorderTestCase): + def setUp(self) -> None: + self.datastore = SQLAlchemyDatastore(url="sqlite:///:memory:") + + def create_recorder(self) -> TrackingRecorder: + recorder = SQLAlchemyTrackingRecorder( + datastore=self.datastore, + tracking_table_name="tracking", + ) + recorder.create_table() + return recorder + + class TestSQLAlchemyProcessRecorder(ProcessRecorderTestCase): def setUp(self) -> None: self.datastore = SQLAlchemyDatastore(url="sqlite:///:memory:") @@ -170,3 +186,4 @@ def test_max_tracking_id_query_should_be_filtered_by_application_name(self) -> N del AggregateRecorderTestCase del ApplicationRecorderTestCase del ProcessRecorderTestCase +del TrackingRecorderTestCase diff --git a/tests/utils.py b/tests/utils.py index c17de49..63a0e95 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -2,6 +2,8 @@ import subprocess from pathlib import Path +from sqlalchemy import MetaData, create_engine + BASE_DIR = Path(__file__).parents[1] @@ -9,3 +11,13 @@ def drop_mssql_table(table_name: str) -> None: subprocess.run( ["make", "drop-mssql-table", f"name={table_name}"], check=True, cwd=BASE_DIR ) + + +def drop_pg_tables() -> None: + url = "postgresql://eventsourcing:eventsourcing@localhost:5432/eventsourcing_sqlalchemy" + for schema in ["public", "myschema"]: + engine = create_engine(url=url) + meta = MetaData(schema=schema) + meta.reflect(bind=engine) + with engine.begin() as conn: + meta.drop_all(bind=conn)