Skip to content

Commit f65969f

Browse files
committed
Add generic..
1 parent 8791b5d commit f65969f

1 file changed

Lines changed: 12 additions & 12 deletions

File tree

pymfdata/rdb/repository.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from abc import ABC, abstractmethod
22
from contextlib import AbstractAsyncContextManager, AbstractContextManager
3-
from typing import AsyncIterator, Callable, final, Iterator, Type, TypeVar
3+
from typing import AsyncIterator, Callable, final, Generic, Iterator, TypeVar
44
from sqlalchemy.ext.asyncio import AsyncSession
55
from sqlalchemy.future import select
66
from sqlalchemy.orm import Session, Query
@@ -9,20 +9,20 @@
99
from pymfdata.rdb.connection import Base
1010
from pymfdata.common.errors import NotFoundException
1111

12-
ModelType = TypeVar("ModelType", bound=Base)
12+
MT = TypeVar("MT", bound=Base)
1313

1414

15-
class AsyncRepository(ABC):
16-
def __init__(self, model: Type[ModelType], session_factory: Callable[..., AbstractAsyncContextManager]) -> None:
15+
class AsyncRepository(ABC, Generic[MT]):
16+
def __init__(self, model: MT, session_factory: Callable[..., AbstractAsyncContextManager]) -> None:
1717
self.model = model
1818
self._session_factory = session_factory
1919

2020
@abstractmethod
21-
async def find_by_pk(self, pk) -> Type[ModelType]:
21+
async def find_by_pk(self, pk) -> MT:
2222
raise NotImplementedError("Required implementation")
2323

2424
@final
25-
async def find_by_col(self, **kwargs) -> Type[ModelType]:
25+
async def find_by_col(self, **kwargs) -> MT:
2626
if not await self.is_exists(**kwargs):
2727
raise NotFoundException()
2828

@@ -40,7 +40,7 @@ def _gen_stmt_for_param(self, **kwargs) -> Select:
4040
return stmt
4141

4242
@final
43-
async def find_all(self, **kwargs) -> AsyncIterator[Type[ModelType]]:
43+
async def find_all(self, **kwargs) -> AsyncIterator[MT]:
4444
session: AsyncSession
4545
async with self._session_factory() as session:
4646
stmt = self._gen_stmt_for_param(**kwargs)
@@ -62,17 +62,17 @@ async def save(self, item: Base):
6262
await session.commit()
6363

6464

65-
class SyncRepository(ABC):
66-
def __init__(self, model: Type[ModelType], session_factory: Callable[..., AbstractContextManager]) -> None:
65+
class SyncRepository(ABC, Generic[MT]):
66+
def __init__(self, model: MT, session_factory: Callable[..., AbstractContextManager]) -> None:
6767
self.model = model
6868
self._session_factory = session_factory
6969

7070
@abstractmethod
71-
def find_by_pk(self, pk) -> Type[ModelType]:
71+
def find_by_pk(self, pk) -> MT:
7272
raise NotImplementedError("Required implementation")
7373

7474
@final
75-
def find_by_col(self, **kwargs) -> Type[ModelType]:
75+
def find_by_col(self, **kwargs) -> MT:
7676
if not self.is_exists(**kwargs):
7777
raise NotFoundException()
7878

@@ -89,7 +89,7 @@ def _gen_query_for_param(self, session: Session, **kwargs) -> Query:
8989
return query
9090

9191
@final
92-
def find_all(self, **kwargs) -> Iterator[Type[ModelType]]:
92+
def find_all(self, **kwargs) -> Iterator[MT]:
9393
session: Session
9494
with self._session_factory() as session:
9595
query = self._gen_query_for_param(session, **kwargs)

0 commit comments

Comments
 (0)