Skip to content

Commit a64bfd3

Browse files
authored
SQLAlchemy 모델을 제네릭에서 지정하여 사용할 수 있도록 구현 (#7)
1 parent c3dcec4 commit a64bfd3

2 files changed

Lines changed: 65 additions & 55 deletions

File tree

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pymfdata/rdb/repository.py

Lines changed: 62 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from contextlib import AbstractAsyncContextManager, AbstractContextManager
2-
from typing import Callable, final, Iterator, List, Protocol, TypeVar, Optional
2+
from typing import Callable, final, Iterator, get_args, List, Protocol, Optional, Type, TypeVar
33
from sqlalchemy.ext.asyncio import AsyncSession
44
from sqlalchemy.future import select
5+
from sqlalchemy.inspection import inspect
56
from sqlalchemy.orm import Session, Query
67
from sqlalchemy.sql.selectable import Select
78

@@ -12,33 +13,35 @@
1213

1314

1415
class AsyncRepository(Protocol[_MT, _T]):
15-
_model: _MT
1616
_session_factory: Callable[..., AbstractAsyncContextManager]
17-
_pk_column: str
17+
18+
@property
19+
def _model(self):
20+
return get_args(self.__orig_bases__[0])[0]
21+
22+
@property
23+
def _pk_column(self) -> str:
24+
return inspect(self._model).primary_key[0].name
1825

1926
async def delete_by_pk(self, pk: _T) -> bool:
20-
item = await self.find_by_pk(pk)
21-
if item is not None:
22-
session: AsyncSession
23-
async with self._session_factory() as session:
27+
session: AsyncSession
28+
async with self._session_factory() as session:
29+
item = await self.find_by_pk(session, pk)
30+
if item is not None:
2431
await session.delete(item)
2532
await session.commit()
2633

27-
return True
28-
return False
34+
return True
2935

30-
async def find_by_pk(self, pk: _T) -> Optional[_MT]:
31-
return await self.find_by_col(**{self._pk_column: pk})
36+
return False
3237

33-
@final
34-
async def find_by_col(self, **kwargs) -> Optional[_MT]:
35-
if not await self.is_exists(**kwargs):
36-
return None
38+
async def find_by_pk(self, session: AsyncSession, pk: _T) -> Optional[_MT]:
39+
return await self.find_by_col(session, **{self._pk_column: pk})
3740

38-
session: AsyncSession
39-
async with self._session_factory() as session:
40-
item = await session.execute(self._gen_stmt_for_param(**kwargs))
41-
return item.unique().scalars().one()
41+
@final
42+
async def find_by_col(self, session: AsyncSession, **kwargs) -> Optional[_MT]:
43+
item = await session.execute(self._gen_stmt_for_param(**kwargs))
44+
return item.unique().scalars().one_or_none()
4245

4346
@final
4447
def _gen_stmt_for_param(self, **kwargs) -> Select:
@@ -61,59 +64,66 @@ async def find_all(self, **kwargs) -> List[_MT]:
6164
async def is_exists(self, **kwargs) -> bool:
6265
session: AsyncSession
6366
async with self._session_factory() as session:
64-
return await session.execute(self._gen_stmt_for_param(**kwargs).exists().select())
67+
result = await session.execute(self._gen_stmt_for_param(**kwargs).exists().select())
68+
return result.scalar()
6569

6670
@final
67-
async def save(self, item: Base):
71+
async def save(self, item: _MT):
6872
session: AsyncSession
6973
async with self._session_factory() as session:
7074
session.add(item)
7175
await session.commit()
7276
await session.refresh(item)
7377

7478
async def update_by_pk(self, pk: _T, req: dict) -> bool:
75-
item = await self.find_by_pk(pk)
76-
if item is not None:
77-
session: AsyncSession
78-
async with self._session_factory() as session:
79+
session: AsyncSession
80+
async with self._session_factory() as session:
81+
item = await self.find_by_pk(session, pk)
82+
if item is not None:
7983
for k, v in req.items():
8084
if v is not None:
8185
setattr(item, k, v)
8286

8387
await session.commit()
8488
await session.refresh(item)
8589

86-
return True
87-
return False
90+
return True
91+
return False
8892

8993

9094
class SyncRepository(Protocol[_MT, _T]):
91-
_model: _MT
9295
_session_factory: Callable[..., AbstractContextManager]
93-
_pk_column: str
96+
97+
@property
98+
def _model(self):
99+
return get_args(self.__orig_bases__[0])[0]
100+
101+
@property
102+
def _pk_column(self) -> str:
103+
return inspect(self._model).primary_key[0].name
104+
105+
@final
106+
def count(self, **kwargs) -> int:
107+
return self._gen_query_for_param(**kwargs).count()
94108

95109
def delete_by_pk(self, pk: _T) -> bool:
96-
item = self.find_by_pk(pk)
97-
if item is not None:
98-
session: Session
99-
with self._session_factory() as session:
110+
session: Session
111+
with self._session_factory() as session:
112+
item = self.find_by_pk(session, pk)
113+
if item is not None:
100114
session.delete(item)
101115
session.commit()
102116

103-
return True
104-
return False
117+
return True
118+
return False
105119

106-
def find_by_pk(self, pk: _T) -> Optional[_MT]:
107-
return self.find_by_col(**{self._pk_column: pk})
120+
def find_by_pk(self, session: Session, pk: _T) -> Optional[_MT]:
121+
return self.find_by_col(session, **{self._pk_column: pk})
108122

109123
@final
110-
def find_by_col(self, **kwargs) -> Optional[_MT]:
111-
if not self.is_exists(**kwargs):
112-
return None
113-
114-
with self._session_factory() as session:
115-
query = self._gen_query_for_param(session, **kwargs)
116-
return query.one()
124+
def find_by_col(self, session: Session, **kwargs) -> Optional[_MT]:
125+
query = self._gen_query_for_param(session, **kwargs)
126+
return query.one_or_none()
117127

118128
@final
119129
def _gen_query_for_param(self, session: Session, **kwargs) -> Query:
@@ -144,16 +154,16 @@ def save(self, item: Base):
144154
session.commit()
145155

146156
def update_by_pk(self, pk: _T, req: dict) -> bool:
147-
item = self.find_by_pk(pk)
148-
if item is not None:
149-
session: Session
150-
with self._session_factory() as session:
157+
session: Session
158+
with self._session_factory() as session:
159+
item = self.find_by_pk(session, pk)
160+
if item is not None:
151161
for k, v in req.items():
152162
if v is not None:
153163
setattr(item, k, v)
154164

155-
await session.commit()
156-
await session.refresh(item)
165+
session.commit()
166+
session.refresh(item)
157167

158-
return True
159-
return False
168+
return True
169+
return False

0 commit comments

Comments
 (0)