Skip to content

Commit c3e18bd

Browse files
committed
Detect primary key code..
1 parent 3c51207 commit c3e18bd

1 file changed

Lines changed: 32 additions & 6 deletions

File tree

pymfdata/rdb/repository.py

Lines changed: 32 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
from abc import abstractmethod
21
from contextlib import AbstractAsyncContextManager, AbstractContextManager
32
from typing import AsyncIterator, Callable, final, Iterator, Protocol, TypeVar
43
from sqlalchemy.ext.asyncio import AsyncSession
@@ -15,10 +14,21 @@
1514
class AsyncRepository(Protocol[_MT]):
1615
_model: _MT
1716
_session_factory: Callable[..., AbstractAsyncContextManager]
17+
_pk_column: str
18+
19+
async def delete_by_pk(self, pk):
20+
if not await self.is_exists(**{self._pk_column: pk}):
21+
raise NotFoundException()
22+
23+
item = await self.find_by_pk(pk)
24+
25+
session: AsyncSession
26+
async with self._session_factory() as session:
27+
await session.delete(item)
28+
await session.commit()
1829

19-
@abstractmethod
2030
async def find_by_pk(self, pk) -> _MT:
21-
raise NotImplementedError("Required implementation {}".format(self.find_by_pk.__name__))
31+
return await self.find_by_col(**{self._pk_column: pk})
2232

2333
@final
2434
async def find_by_col(self, **kwargs) -> _MT:
@@ -59,15 +69,31 @@ async def save(self, item: Base):
5969
async with self._session_factory() as session:
6070
session.add(item)
6171
await session.commit()
72+
await session.refresh(item)
73+
74+
async def update_by_pk(self, pk, req: dict):
75+
if not await self.is_exists(**{self._pk_column: pk}):
76+
raise NotFoundException()
77+
78+
item = await self.find_by_pk(pk)
79+
80+
session: AsyncSession
81+
async with self._session_factory() as session:
82+
for k, v in req.items():
83+
if v is not None:
84+
setattr(item, k, v)
85+
86+
await session.commit()
87+
await session.refresh(item)
6288

6389

6490
class SyncRepository(Protocol[_MT]):
6591
_model: _MT
6692
_session_factory: Callable[..., AbstractContextManager]
93+
_pk_column: str
6794

68-
@abstractmethod
6995
def find_by_pk(self, pk) -> _MT:
70-
raise NotImplementedError("Required implementation")
96+
return self.find_by_col(**{self._pk_column: pk})
7197

7298
@final
7399
def find_by_col(self, **kwargs) -> _MT:
@@ -83,7 +109,7 @@ def _gen_query_for_param(self, session: Session, **kwargs) -> Query:
83109
query = session.query(self._model)
84110
if kwargs:
85111
for key, value in kwargs.items():
86-
query = query.filter(getattr(self._model.key) == value)
112+
query = query.filter(getattr(self._model, key) == value)
87113
return query
88114

89115
@final

0 commit comments

Comments
 (0)