diff --git a/pyproject.toml b/pyproject.toml index 9da631b985..5f07203236 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,6 +54,8 @@ httpx = "0.24.1" dirty-equals = "^0.6.0" typer-cli = "^0.0.13" mkdocs-markdownextradata-plugin = ">=0.1.7,<0.3.0" +pytest-asyncio = "0.21.1" +aiosqlite = "0.19.0" [build-system] requires = ["poetry-core"] diff --git a/sqlmodel/ext/asyncio/async_model.py b/sqlmodel/ext/asyncio/async_model.py new file mode 100644 index 0000000000..88c8ee8db2 --- /dev/null +++ b/sqlmodel/ext/asyncio/async_model.py @@ -0,0 +1,68 @@ +from typing import Any, ClassVar, Coroutine, Dict, Tuple, Type + +from sqlalchemy.util.concurrency import greenlet_spawn + +from ... import SQLModel +from ..._compat import Representation, get_annotations +from ...main import SQLModelMetaclass + + +class AwaitableFieldInfo(Representation): + def __init__(self, *, field: str): + self.field = field + + +def AwaitableField(*, field: str) -> Any: + return AwaitableFieldInfo(field=field) + + +class AsyncSQLModelMetaclass(SQLModelMetaclass): + __async_sqlmodel_awaitable_fields__: Dict[str, AwaitableFieldInfo] + + def __new__( + cls, + name: str, + bases: Tuple[Type[Any], ...], + class_dict: Dict[str, Any], + **kwargs: Any, + ) -> Any: + awaitable_fields: Dict[str, AwaitableFieldInfo] = {} + dict_for_sqlmodel = {} + original_annotations = get_annotations(class_dict) + sqlmodel_annotations = {} + awaitable_fields_annotations = {} + for k, v in class_dict.items(): + if isinstance(v, AwaitableFieldInfo): + awaitable_fields[k] = v + else: + dict_for_sqlmodel[k] = v + for k, v in original_annotations.items(): + if k in awaitable_fields: + awaitable_fields_annotations[k] = v + else: + sqlmodel_annotations[k] = v + + dict_used = { + **dict_for_sqlmodel, + "__async_sqlmodel_awaitable_fields__": awaitable_fields, + "__annotations__": sqlmodel_annotations, + } + return super().__new__(cls, name, bases, dict_used, **kwargs) + + def __init__( + cls, classname: str, bases: Tuple[type, ...], dict_: Dict[str, Any], **kw: Any + ) -> None: + for field_name, field_info in cls.__async_sqlmodel_awaitable_fields__.items(): + + def get_awaitable_field( + self: "AsyncSQLModel", field: str = field_info.field + ) -> Coroutine[Any, Any, Any]: + return greenlet_spawn(getattr, self, field) + + setattr(cls, field_name, property(get_awaitable_field)) + + SQLModelMetaclass.__init__(cls, classname, bases, dict_, **kw) + + +class AsyncSQLModel(SQLModel, metaclass=AsyncSQLModelMetaclass): + __async_sqlmodel_awaitable_fields__: ClassVar[Dict[str, AwaitableFieldInfo]] diff --git a/tests/test_advanced/test_async/__init__.py b/tests/test_advanced/test_async/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/test_advanced/test_async/test_awaitable_field.py b/tests/test_advanced/test_async/test_awaitable_field.py new file mode 100644 index 0000000000..9c7fba94ff --- /dev/null +++ b/tests/test_advanced/test_async/test_awaitable_field.py @@ -0,0 +1,75 @@ +from typing import Awaitable, List, Optional + +import pytest +from sqlalchemy.exc import MissingGreenlet +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.util.concurrency import greenlet_spawn +from sqlmodel import Field, Relationship, SQLModel, select +from sqlmodel.ext.asyncio.async_model import AsyncSQLModel, AwaitableField +from sqlmodel.ext.asyncio.session import AsyncSession + + +@pytest.mark.asyncio +async def test_awaitable_nomral_field(clear_sqlmodel): + class Hero(AsyncSQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + secret_name: str + age: Optional[int] = None + awt_name: Awaitable[str] = AwaitableField(field="name") + awt_age: Awaitable[str] = AwaitableField(field="age") + + hero_deadpond = Hero(name="Deadpond", secret_name="Dive Wilson") + + engine = create_async_engine("sqlite+aiosqlite://") + await greenlet_spawn(SQLModel.metadata.create_all, engine.sync_engine) + + async with AsyncSession(engine) as session: + session.add(hero_deadpond) + await session.commit() + + # loading expired attribute will raise MissingGreenlet error + with pytest.raises(MissingGreenlet): + hero_deadpond.name # noqa: B018 + + name = await hero_deadpond.awt_name + assert name == "Deadpond" + + +@pytest.mark.asyncio +async def test_awaitable_relation_field(clear_sqlmodel): + class Team(AsyncSQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str = Field(index=True) + + heroes: List["Hero"] = Relationship() + awt_heroes: Awaitable[List["Hero"]] = AwaitableField(field="heroes") + + class Hero(AsyncSQLModel, table=True): + id: Optional[int] = Field(default=None, primary_key=True) + name: str + + team_id: Optional[int] = Field(default=None, foreign_key="team.id") + team: Optional[Team] = Relationship(back_populates="heroes") + awt_team: Awaitable[Optional[Team]] = AwaitableField(field="team") + + team_preventers = Team(name="Preventers") + hero_rusty_man = Hero(name="Rusty-Man", team=team_preventers) + + engine = create_async_engine("sqlite+aiosqlite://") + await greenlet_spawn(SQLModel.metadata.create_all, engine.sync_engine) + + async with AsyncSession(engine) as session: + session.add(hero_rusty_man) + await session.commit() + + async with AsyncSession(engine) as session: + hero = (await session.exec(select(Hero).where(Hero.name == "Rusty-Man"))).one() + + # loading lazy loading attribute will raise MissingGreenlet error + with pytest.raises(MissingGreenlet): + hero.team # noqa: B018 + + team = await hero.awt_team + assert team + assert team.name == "Preventers"