-
Notifications
You must be signed in to change notification settings - Fork 1.1k
Expand file tree
/
Copy pathunit_of_work.py
More file actions
68 lines (50 loc) · 1.64 KB
/
unit_of_work.py
File metadata and controls
68 lines (50 loc) · 1.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# pylint: disable=attribute-defined-outside-init
from __future__ import annotations
import abc
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker
from sqlalchemy.orm.session import Session
from allocation import config
from allocation.adapters import repository
from . import messagebus
class AbstractUnitOfWork(abc.ABC):
    """Transactional boundary around the ``products`` repository.

    Meant to be used as a context manager. ``commit()`` persists the work via
    the subclass hook ``_commit()`` and then dispatches, through the message
    bus, every pending event found on the products the repository has seen.
    Leaving the ``with`` block always calls ``rollback()``; exceptions raised
    inside the block are not suppressed.
    """

    # Concrete subclasses are expected to assign this before it is used.
    products: repository.AbstractRepository

    def __enter__(self) -> AbstractUnitOfWork:
        return self

    def __exit__(self, *exc_info):
        # Roll back unconditionally so nothing that was not explicitly
        # committed survives the block. After a successful commit this is
        # presumably a no-op -- NOTE(review): confirm for each subclass.
        self.rollback()

    def commit(self):
        """Persist the unit of work, then publish the events it produced."""
        self._commit()
        self.publish_events()

    def publish_events(self):
        """Drain each seen product's event list, front first, into the bus."""
        for aggregate in self.products.seen:
            # Re-check the list each time: a handler may append new events.
            while aggregate.events:
                next_event = aggregate.events.pop(0)
                messagebus.handle(next_event)

    @abc.abstractmethod
    def _commit(self):
        """Persist pending changes (implemented by concrete subclasses)."""
        raise NotImplementedError

    @abc.abstractmethod
    def rollback(self):
        """Undo uncommitted changes (implemented by concrete subclasses)."""
        raise NotImplementedError
# Session factory used by SqlAlchemyUnitOfWork when none is injected.
# The engine is created once, when this module is imported, from the URI
# returned by config.get_postgres_uri(); later config changes are not seen.
DEFAULT_SESSION_FACTORY = sessionmaker(
    bind=create_engine(config.get_postgres_uri())
)
class SqlAlchemyUnitOfWork(AbstractUnitOfWork):
    """Unit of work backed by a SQLAlchemy session.

    Each ``with`` block opens a fresh session from ``session_factory`` and
    exposes it through a tracking repository as ``products`` (its ``seen``
    collection is what ``commit()`` publishes events from). The session is
    always closed when the block exits, even if the rollback fails.

    Args:
        session_factory: zero-argument callable returning a new session;
            defaults to the module-level DEFAULT_SESSION_FACTORY.
    """

    def __init__(self, session_factory=DEFAULT_SESSION_FACTORY):
        self.session_factory = session_factory

    def __enter__(self):
        self.session = self.session_factory()  # type: Session
        try:
            self.products = repository.TrackingRepository(
                repository.SqlAlchemyRepository(self.session)
            )
            return super().__enter__()
        except BaseException:
            # __exit__ is not called when __enter__ raises, so release the
            # session here instead of leaking it; the error still propagates.
            self.session.close()
            raise

    def __exit__(self, *args):
        try:
            # Base class rolls back anything left uncommitted.
            super().__exit__(*args)
        finally:
            # Close even if the rollback raised, so the session is never leaked.
            self.session.close()

    def _commit(self):
        """Commit the current session transaction."""
        self.session.commit()

    def rollback(self):
        """Roll back the current session transaction."""
        self.session.rollback()