Skip to content

Commit 3c51207

Browse files
committed
Change ABC -> Protocol..
1 parent f65969f commit 3c51207

1 file changed

Lines changed: 20 additions & 22 deletions

File tree

pymfdata/rdb/repository.py

Lines changed: 20 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
from abc import ABC, abstractmethod
1+
from abc import abstractmethod
22
from contextlib import AbstractAsyncContextManager, AbstractContextManager
3-
from typing import AsyncIterator, Callable, final, Generic, Iterator, TypeVar
3+
from typing import AsyncIterator, Callable, final, Iterator, Protocol, TypeVar
44
from sqlalchemy.ext.asyncio import AsyncSession
55
from sqlalchemy.future import select
66
from sqlalchemy.orm import Session, Query
@@ -9,20 +9,19 @@
99
from pymfdata.rdb.connection import Base
1010
from pymfdata.common.errors import NotFoundException
1111

12-
MT = TypeVar("MT", bound=Base)
12+
_MT = TypeVar("_MT", bound=Base)
1313

1414

15-
class AsyncRepository(ABC, Generic[MT]):
16-
def __init__(self, model: MT, session_factory: Callable[..., AbstractAsyncContextManager]) -> None:
17-
self.model = model
18-
self._session_factory = session_factory
15+
class AsyncRepository(Protocol[_MT]):
16+
_model: _MT
17+
_session_factory: Callable[..., AbstractAsyncContextManager]
1918

2019
@abstractmethod
21-
async def find_by_pk(self, pk) -> MT:
22-
raise NotImplementedError("Required implementation")
20+
async def find_by_pk(self, pk) -> _MT:
21+
raise NotImplementedError("Required implementation {}".format(self.find_by_pk.__name__))
2322

2423
@final
25-
async def find_by_col(self, **kwargs) -> MT:
24+
async def find_by_col(self, **kwargs) -> _MT:
2625
if not await self.is_exists(**kwargs):
2726
raise NotFoundException()
2827

@@ -33,14 +32,14 @@ async def find_by_col(self, **kwargs) -> MT:
3332

3433
@final
3534
def _gen_stmt_for_param(self, **kwargs) -> Select:
36-
stmt = select(self.model)
35+
stmt = select(self._model)
3736
if kwargs:
3837
for key, value in kwargs.items():
39-
stmt = stmt.where(getattr(self.model, key) == value)
38+
stmt = stmt.where(getattr(self._model, key) == value)
4039
return stmt
4140

4241
@final
43-
async def find_all(self, **kwargs) -> AsyncIterator[MT]:
42+
async def find_all(self, **kwargs) -> AsyncIterator[_MT]:
4443
session: AsyncSession
4544
async with self._session_factory() as session:
4645
stmt = self._gen_stmt_for_param(**kwargs)
@@ -62,17 +61,16 @@ async def save(self, item: Base):
6261
await session.commit()
6362

6463

65-
class SyncRepository(ABC, Generic[MT]):
66-
def __init__(self, model: MT, session_factory: Callable[..., AbstractContextManager]) -> None:
67-
self.model = model
68-
self._session_factory = session_factory
64+
class SyncRepository(Protocol[_MT]):
65+
_model: _MT
66+
_session_factory: Callable[..., AbstractContextManager]
6967

7068
@abstractmethod
71-
def find_by_pk(self, pk) -> MT:
69+
def find_by_pk(self, pk) -> _MT:
7270
raise NotImplementedError("Required implementation")
7371

7472
@final
75-
def find_by_col(self, **kwargs) -> MT:
73+
def find_by_col(self, **kwargs) -> _MT:
7674
if not self.is_exists(**kwargs):
7775
raise NotFoundException()
7876

@@ -82,14 +80,14 @@ def find_by_col(self, **kwargs) -> MT:
8280

8381
@final
8482
def _gen_query_for_param(self, session: Session, **kwargs) -> Query:
85-
query = session.query(self.model)
83+
query = session.query(self._model)
8684
if kwargs:
8785
for key, value in kwargs.items():
88-
query = query.filter(getattr(self.model.key) == value)
86+
query = query.filter(getattr(self._model.key) == value)
8987
return query
9088

9189
@final
92-
def find_all(self, **kwargs) -> Iterator[MT]:
90+
def find_all(self, **kwargs) -> Iterator[_MT]:
9391
session: Session
9492
with self._session_factory() as session:
9593
query = self._gen_query_for_param(session, **kwargs)

0 commit comments

Comments
 (0)