diff --git a/app/auth/api/endpoints.py b/app/auth/api/endpoints.py index 78a0f88..b1792b7 100644 --- a/app/auth/api/endpoints.py +++ b/app/auth/api/endpoints.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, Header, status, Response, Request from app.auth.use_cases.reset_password_use_case import ResetPasswordUseCase -from app.common.api.dependencies.get_session import SessionDependency +from app.common.api.dependencies import SessionDependency from app.auth.schemas.auth_schema import PasswordResetRequest, UserLogin from app.auth.use_cases.auth_user_use_case import AuthUserUseCase from app.common.exceptions.model_not_found_exception import ( diff --git a/app/celery/tasks/emails.py b/app/celery/tasks/emails.py index 8edc3d0..42b3d2f 100644 --- a/app/celery/tasks/emails.py +++ b/app/celery/tasks/emails.py @@ -1,29 +1,24 @@ -from uuid import UUID +from app.common.domain import PaginationCriteria +from app.core.config import settings +from app.db.session import SessionLocal from app.emails.exceptions.email_client_exception import EmailClientException -from app.common.schemas.pagination_schema import ListFilter from app.emails.services.emails_service import EmailService -from app.db.session import SessionLocal from app.main import celery - - -from app.core.config import get_settings -from app.users.schemas.user_schema import UserInDB -from app.users.services.users_service import UsersService - -settings = get_settings() +from app.users.domain import UserId +from app.users.infrastructure import SQLAlchemyUserRepository @celery.task def send_reminder_email() -> None: - session = SessionLocal() - try: - users = UsersService(session).list(ListFilter(page=1, page_size=100)) - for user in users.data: - EmailService().send_user_remind_email( - UserInDB.model_validate(user) - ) - finally: - session.close() + email_service = EmailService() + + with SessionLocal() as session: + user_repository = SQLAlchemyUserRepository(session) + users = user_repository.where( + PaginationCriteria(page=1, page_size=100) + ) + for user in users: + email_service.send_user_remind_email(user) @celery.task( @@ -32,11 +27,7 @@ def send_reminder_email() -> None: max_retries=settings.SEND_WELCOME_EMAIL_MAX_RETRIES, retry_jitter=False, ) -def send_welcome_email(user_id: UUID) -> None: - session = SessionLocal() - try: - user = UsersService(session).get_by_id(user_id) - if user: - EmailService().send_new_user_email(UserInDB.model_validate(user)) - finally: - session.close() +def send_welcome_email(user_id: UserId) -> None: + with SessionLocal() as session: + user = SQLAlchemyUserRepository(session).find_or_fail(user_id) + EmailService().send_new_user_email(user) diff --git a/app/common/api/dependencies/__init__.py b/app/common/api/dependencies/__init__.py index e69de29..c4aab35 100644 --- a/app/common/api/dependencies/__init__.py +++ b/app/common/api/dependencies/__init__.py @@ -0,0 +1 @@ +from .session_dependency import * diff --git a/app/common/api/dependencies/get_session.py b/app/common/api/dependencies/session_dependency.py similarity index 52% rename from app/common/api/dependencies/get_session.py rename to app/common/api/dependencies/session_dependency.py index 8f04089..7e879e7 100644 --- a/app/common/api/dependencies/get_session.py +++ b/app/common/api/dependencies/session_dependency.py @@ -1,12 +1,17 @@ -from typing import Annotated, Generator +__all__ = ( + "get_session", + "SessionDependency", +) -from fastapi import Depends +import typing as t + +import fastapi from sqlalchemy.orm import Session from app.db.session import SessionLocal -def get_session() -> Generator: +def get_session() -> t.Generator[Session, None, None]: session = SessionLocal() try: yield session @@ -18,4 +23,7 @@ def get_session() -> Generator: session.close() -SessionDependency = Annotated[Session, Depends(get_session)] +SessionDependency = t.Annotated[ + Session, + fastapi.Depends(get_session), +] diff --git a/app/common/domain/__init__.py b/app/common/domain/__init__.py new file mode 100644 index 0000000..deee356 --- /dev/null +++ b/app/common/domain/__init__.py @@ -0,0 +1,4 @@ +from .entities import * +from .criteria import * +from .pagination_criteria import * +from .repository import * diff --git a/app/common/domain/criteria.py b/app/common/domain/criteria.py new file mode 100644 index 0000000..2204183 --- /dev/null +++ b/app/common/domain/criteria.py @@ -0,0 +1,5 @@ +__all__ = ("Criteria",) + + +class Criteria[TEntity]: + pass diff --git a/app/common/domain/entities/__init__.py b/app/common/domain/entities/__init__.py new file mode 100644 index 0000000..ceda3b0 --- /dev/null +++ b/app/common/domain/entities/__init__.py @@ -0,0 +1,3 @@ +__all__ = ("Entity",) + +from .entity import Entity diff --git a/app/common/domain/entities/_exc.py b/app/common/domain/entities/_exc.py new file mode 100644 index 0000000..2de682b --- /dev/null +++ b/app/common/domain/entities/_exc.py @@ -0,0 +1,77 @@ +from __future__ import annotations + +__all__ = ( + "EntityDefinitionError", + "MissingDtoError", + "InvalidDtoError", +) + +import typing as t + +from app.common.exceptions import BaseApplicationError + +from . import _utils + +if t.TYPE_CHECKING: + from .entity import Entity + from .entity import _DtoConstraint + + +class EntityDefinitionError(BaseApplicationError): + pass + + +class MissingDtoError(EntityDefinitionError): + def __init__( + self, + entity_cls: type[Entity], + missing_dto_constraint: _DtoConstraint, + ) -> None: + dto_name, dto_cls = missing_dto_constraint + dto_bases = (dto_cls.__name__,) + model_name, model_bases = _utils.extract_class_info(entity_cls) + + message_lines = ( + f"Missing {dto_name} definition from {model_name} declaration.", + "", + "Example:", + _utils.build_nested_class_definition_example( + outer_class_name=model_name, + outer_class_bases=model_bases, + inner_class_name=dto_name, + inner_class_bases=dto_bases, + inner_class_statements=( + "# fields required to perform operation via repository", + ), + ), + ) + message = "\n".join(message_lines) + super().__init__(message) + + +class InvalidDtoError(EntityDefinitionError): + def __init__( + self, + entity_cls: type[Entity], + invalid_dto_constraint: _DtoConstraint, + ) -> None: + dto_name, dto_cls = invalid_dto_constraint + dto_bases = (dto_cls.__name__,) + model_name, model_bases = _utils.extract_class_info(entity_cls) + + message_lines = ( + f"Invalid {dto_name} schema definition, must be child of {dto_cls.__name__}", + "", + "Example:", + _utils.build_nested_class_definition_example( + outer_class_name=model_name, + outer_class_bases=model_bases, + inner_class_name=dto_name, + inner_class_bases=dto_bases, + inner_class_statements=( + "# fields required to perform operation via repository", + ), + ), + ) + message = "\n".join(message_lines) + super().__init__(message) diff --git a/app/common/domain/entities/_utils.py b/app/common/domain/entities/_utils.py new file mode 100644 index 0000000..1fabfac --- /dev/null +++ b/app/common/domain/entities/_utils.py @@ -0,0 +1,54 @@ +__all__ = ( + "build_nested_class_definition_example", + "extract_class_info", +) + +import typing as t + + +def build_nested_class_definition_example( + *, + outer_class_name: str = "Outer", + outer_class_bases: t.Sequence[str] | None = None, + inner_class_name: str = "Inner", + inner_class_bases: t.Sequence[str] | None = None, + inner_class_statements: t.Sequence[str] = ("# ...",), +) -> str: + output_lines = ( + _build_class_header(outer_class_name, outer_class_bases), + _indent(1, _build_class_header(inner_class_name, inner_class_bases)), + *_indent(2, inner_class_statements), + ) + + return "\n".join(output_lines) + + +def _build_class_header(name: str, bases: t.Sequence[str] | None) -> str: + return f"class {name}{f'({", ".join(bases)})' if bases else ''}:" + + +@t.overload +def _indent(level: int, val: str) -> str: ... +@t.overload +def _indent(level: int, val: t.Sequence[str]) -> t.Sequence[str]: ... + + +def _indent(level: int, val: str | t.Sequence[str]) -> str | t.Sequence[str]: + space = " " * level + + if isinstance(val, str): + return f"{space}{val}" + + return tuple(f"{space}{line}" for line in val) + + +class _ClassInfo(t.NamedTuple): + class_name: str + class_bases: t.Sequence[str] + + +def extract_class_info(cls: type) -> _ClassInfo: + return _ClassInfo( + class_name=cls.__name__, + class_bases=tuple(base.__name__ for base in cls.__bases__), + ) diff --git a/app/common/domain/entities/entity.py b/app/common/domain/entities/entity.py new file mode 100644 index 0000000..51c544b --- /dev/null +++ b/app/common/domain/entities/entity.py @@ -0,0 +1,98 @@ +__all__ = ("Entity",) + +import typing as t + +import pydantic + +from . import _exc + + +class _DtoConstraint(t.NamedTuple): + name: str + parent_cls: type + + +type _DtoConstraints = t.Collection[_DtoConstraint] +type _DtoConstraintEvaluation = _exc.EntityDefinitionError | None +type _EntityDefinitionErrors = t.Sequence[_exc.EntityDefinitionError] + + +# TODO: schema registry and automatic model rebuilding (fix circular imports for good) +# TODO: sentinel for unloaded relationships (fix shape of data problem) +class Entity(pydantic.BaseModel): + """ + Base model of the application, representing an object in the db. + + :var CreateDto: (required) Subschema for creating an object + :var UpdateDto: (required) Subschema for updating an object + + ## Example: + ``` + class User(Entity): + id: UserId + name: str + last_name: str + + class CreateDto(pydantic.BaseModel): + name: str + last_name: str + + class UpdateDto(pydantic.BaseModel): + name: str | None = None + last_name: str | None = None + ``` + """ + + model_config = pydantic.ConfigDict( + from_attributes=True, + frozen=True, + ) + + CreateDto: t.ClassVar[type[pydantic.BaseModel]] + UpdateDto: t.ClassVar[type[pydantic.BaseModel]] + + @classmethod + def __pydantic_init_subclass__(cls, **kw: t.Any) -> None: + super().__pydantic_init_subclass__(**kw) + cls._validate_entity_definition() + + __required_dtos__: t.Final[_DtoConstraints] = ( + _DtoConstraint("CreateDto", pydantic.BaseModel), + _DtoConstraint("UpdateDto", pydantic.BaseModel), + ) + + @classmethod + def _validate_entity_definition(cls) -> None: + errors = cls._collect_entity_definition_errors() + + if errors: + msg = f"Can't create {cls.__name__} entity class" + raise ExceptionGroup(msg, errors) + + @classmethod + def _collect_entity_definition_errors(cls) -> _EntityDefinitionErrors: + errors = ( + cls._evaluate_dto_constraint(dto_constraint) + for dto_constraint in cls.__required_dtos__ + ) + + return tuple(error for error in errors if error is not None) + + @classmethod + def _evaluate_dto_constraint( + cls, + dto_constraint: _DtoConstraint, + ) -> _DtoConstraintEvaluation: + dto_name, dto_parent_cls = dto_constraint + + if not hasattr(cls, dto_name): + return _exc.MissingDtoError(cls, dto_constraint) + + dto = getattr(cls, dto_name) + is_type = isinstance(dto, type) + is_subclass = is_type and issubclass(dto, dto_parent_cls) + + if not is_subclass: + return _exc.InvalidDtoError(cls, dto_constraint) + + return None diff --git a/app/common/domain/pagination_criteria.py b/app/common/domain/pagination_criteria.py new file mode 100644 index 0000000..b197a7f --- /dev/null +++ b/app/common/domain/pagination_criteria.py @@ -0,0 +1,23 @@ +__all__ = ("PaginationCriteria",) + +import typing as t + +import pydantic + +from .criteria import Criteria + +_PageField = t.Annotated[ + int, + pydantic.Field(ge=1), +] +_PageSizeField = t.Annotated[ + int, + pydantic.Field(ge=1, le=100), +] + + +class PaginationCriteria(pydantic.BaseModel, Criteria[t.Any]): + model_config = pydantic.ConfigDict(frozen=True) + + page: _PageField = 1 + page_size: _PageSizeField = 10 diff --git a/app/common/domain/repository.py b/app/common/domain/repository.py new file mode 100644 index 0000000..a04d9a3 --- /dev/null +++ b/app/common/domain/repository.py @@ -0,0 +1,71 @@ +__all__ = ("Repository",) + +import abc +import uuid + +import pydantic + +from .criteria import Criteria +from .entities import Entity + + +class Repository[ + TEntity: Entity, + TEntityId: uuid.UUID, + TCreate: pydantic.BaseModel, + TUpdate: pydantic.BaseModel, +](abc.ABC): + # Single entity operations + @abc.abstractmethod + def find(self, entity_id: TEntityId) -> TEntity | None: + raise NotImplementedError() + + @abc.abstractmethod + def find_or_fail(self, entity_id: TEntityId) -> TEntity: + raise NotImplementedError() + + @abc.abstractmethod + def create(self, dto: TCreate) -> TEntity: + raise NotImplementedError() + + @abc.abstractmethod + def update(self, entity_id: TEntityId, dto: TUpdate) -> TEntity: + raise NotImplementedError() + + @abc.abstractmethod + def delete(self, entity_id: TEntityId) -> None: + raise NotImplementedError() + + # Multiple entity operations + @abc.abstractmethod + def all(self) -> list[TEntity]: + raise NotImplementedError() + + @abc.abstractmethod + def where(self, *criteria: Criteria[TEntity]) -> list[TEntity]: + raise NotImplementedError() + + @abc.abstractmethod + def first(self, *criteria: Criteria[TEntity]) -> TEntity | None: + raise NotImplementedError() + + @abc.abstractmethod + def count(self, *criteria: Criteria[TEntity]) -> int: + raise NotImplementedError() + + @abc.abstractmethod + def exists(self, *criteria: Criteria[TEntity]) -> bool: + raise NotImplementedError() + + # Bulk operations + @abc.abstractmethod + def insert_many(self, dtos: list[TCreate]) -> list[TEntity]: + raise NotImplementedError() + + @abc.abstractmethod + def update_where(self, dto: TUpdate, *criteria: Criteria[TEntity]) -> int: + raise NotImplementedError() + + @abc.abstractmethod + def delete_where(self, *criteria: Criteria[TEntity]) -> int: + raise NotImplementedError() diff --git a/app/common/exceptions/__init__.py b/app/common/exceptions/__init__.py index 55a834e..df56879 100644 --- a/app/common/exceptions/__init__.py +++ b/app/common/exceptions/__init__.py @@ -1,3 +1,9 @@ -from .external_provider_exception import ExternalProviderException +# TODO: rename all exceptions as errors to follow python convention + +from .base_application_error import * +from .external_provider_error import * +from .repository_errors import * + +# TODO: deprecate from .model_not_created_exception import ModelNotCreatedException from .model_not_found_exception import ModelNotFoundException diff --git a/app/common/exceptions/base_application_error.py b/app/common/exceptions/base_application_error.py new file mode 100644 index 0000000..4afe6e6 --- /dev/null +++ b/app/common/exceptions/base_application_error.py @@ -0,0 +1,9 @@ +__all__ = ("BaseApplicationError",) + + +class BaseApplicationError(Exception): + """Base exception for all custom exceptions of the application""" + + def __init__(self, message: str): + self.message = message + super().__init__(message) diff --git a/app/common/exceptions/external_provider_exception.py b/app/common/exceptions/external_provider_error.py similarity index 68% rename from app/common/exceptions/external_provider_exception.py rename to app/common/exceptions/external_provider_error.py index 6f484f6..25de66c 100644 --- a/app/common/exceptions/external_provider_exception.py +++ b/app/common/exceptions/external_provider_error.py @@ -1,4 +1,7 @@ -class ExternalProviderException(Exception): +__all__ = ("ExternalProviderError",) + + +class ExternalProviderError(Exception): def __init__( self, message: str = "Connection with external provider failed." ): diff --git a/app/common/exceptions/repository_errors.py b/app/common/exceptions/repository_errors.py new file mode 100644 index 0000000..6f0e144 --- /dev/null +++ b/app/common/exceptions/repository_errors.py @@ -0,0 +1,60 @@ +__all__ = ( + "RepositoryError", + "EntityNotCreatedError", + "EntityNotDeletedError", + "EntityNotFoundError", + "EntityNotUpdatedError", +) + +import re +import typing as t + +from .base_application_error import BaseApplicationError + + +class RepositoryError(BaseApplicationError): + pass + + +class EntityNotFoundError(RepositoryError): + def __init__(self, entity: type | str): + if isinstance(entity, str): + entity_name = entity + else: + entity_name = _separate_camel_case(entity.__name__) + super().__init__(f"{entity_name} not found") + + +class EntityNotCreatedError(RepositoryError): + def __init__(self, entity: type | str): + if isinstance(entity, str): + entity_name = entity + else: + entity_name = _separate_camel_case(entity.__name__) + super().__init__(f"{entity_name} not created") + + +class EntityNotUpdatedError(RepositoryError): + def __init__(self, entity: type | str): + if isinstance(entity, str): + entity_name = entity + else: + entity_name = _separate_camel_case(entity.__name__) + super().__init__(f"{entity_name} not updated") + + +class EntityNotDeletedError(RepositoryError): + def __init__(self, entity: type | str): + if isinstance(entity, str): + entity_name = entity + else: + entity_name = _separate_camel_case(entity.__name__) + super().__init__(f"{entity_name} not deleted") + + +# helpers +_CAMEL_CASE_PATTERN: t.Final[re.Pattern] = re.compile(r"(?<=[a-z])(?=[A-Z])") + + +def _separate_camel_case(s: str) -> str: + return _CAMEL_CASE_PATTERN.sub(" ", s) diff --git a/app/common/infrastructure/__init__.py b/app/common/infrastructure/__init__.py new file mode 100644 index 0000000..253824c --- /dev/null +++ b/app/common/infrastructure/__init__.py @@ -0,0 +1,3 @@ +from .base_sqlalchemy_model import * +from .base_sqlalchemy_repository import * +from .sqlalchemy_criteria_translator import * diff --git a/app/common/infrastructure/base_sqlalchemy_model.py b/app/common/infrastructure/base_sqlalchemy_model.py new file mode 100644 index 0000000..1fe2be7 --- /dev/null +++ b/app/common/infrastructure/base_sqlalchemy_model.py @@ -0,0 +1,36 @@ +import uuid +from datetime import datetime + +import sqlalchemy as sa +from sqlalchemy import orm + + +class BaseSQLAlchemyModel(orm.DeclarativeBase): + __abstract__ = True + + metadata = sa.MetaData( + naming_convention={ + "ix": "ix_%(column_0_label)s", + "uq": "uq_%(table_name)s_%(column_0_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", + } + ) + + type_annotation_map = { + uuid.UUID: sa.UUID, + datetime: sa.DateTime(timezone=True), + } + + id: orm.Mapped[uuid.UUID] = orm.mapped_column( + default=uuid.uuid4, + primary_key=True, + ) + created_at: orm.Mapped[datetime] = orm.mapped_column( + server_default=sa.func.now(), + ) + updated_at: orm.Mapped[datetime] = orm.mapped_column( + server_default=sa.func.now(), + server_onupdate=sa.func.now(), + ) diff --git a/app/common/infrastructure/base_sqlalchemy_repository.py b/app/common/infrastructure/base_sqlalchemy_repository.py new file mode 100644 index 0000000..dde1474 --- /dev/null +++ b/app/common/infrastructure/base_sqlalchemy_repository.py @@ -0,0 +1,234 @@ +__all__ = ("BaseSQLAlchemyRepository",) + +import abc +import typing as t +import uuid + +import pydantic +import sqlalchemy as sa +import sqlalchemy.exc as sa_exc +from sqlalchemy import orm + +from app.common.domain import Criteria +from app.common.domain import Entity +from app.common.domain import Repository +from app.common.exceptions import repository_errors + +from .base_sqlalchemy_model import BaseSQLAlchemyModel +from .sqlalchemy_criteria_translator import SqlalchemyCriteriaTranslator + + +class BaseSQLAlchemyRepository[ + TEntity: Entity, + TEntityId: uuid.UUID, + TCreate: pydantic.BaseModel, + TUpdate: pydantic.BaseModel, + TModel: BaseSQLAlchemyModel, +](Repository[TEntity, TEntityId, TCreate, TUpdate], abc.ABC): + @property + @abc.abstractmethod + def entity(self) -> type[TEntity]: + raise NotImplementedError() + + @property + @abc.abstractmethod + def model(self) -> type[TModel]: + raise NotImplementedError() + + def __init__(self, session: orm.Session): + self._session = session + + # -------- Public interface -------- + # Single entity operations + def find(self, entity_id: TEntityId) -> TEntity | None: + model = self._get_by_id(entity_id) + if model is None: + return None + + return self._create_entity_from_model(model) + + def find_or_fail(self, entity_id: TEntityId) -> TEntity: + entity = self.find(entity_id) + if entity is None: + self._raise_not_found() + + return entity + + def create(self, dto: TCreate) -> TEntity: + stmt = ( + sa.insert(self.model) + .values(**dto.model_dump()) + .returning(self.model) + ) + + try: + model = self._session.execute(stmt).scalar_one() + except (sa_exc.NoResultFound, sa_exc.IntegrityError): + self._raise_not_created() + + return self._create_entity_from_model(model) + + def update(self, entity_id: TEntityId, dto: TUpdate) -> TEntity: + stmt = ( + sa.update(self.model) + .where(self.model.id == entity_id) + .values(**dto.model_dump(exclude_unset=True)) + .returning(self.model) + ) + + try: + model = self._session.execute(stmt).scalar_one() + except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound): + self._raise_not_found() + except sa_exc.IntegrityError: + self._raise_not_updated() + + return self._create_entity_from_model(model) + + def delete(self, entity_id: TEntityId) -> None: + stmt = ( + sa.delete(self.model) + .where(self.model.id == entity_id) + .returning(self.model.id) + ) + + try: + self._session.execute(stmt).scalar_one() + except (sa_exc.NoResultFound, sa_exc.MultipleResultsFound): + self._raise_not_found() + except sa_exc.IntegrityError: + self._raise_not_deleted() + + # Multiple entity operations + def all(self) -> list[TEntity]: + stmt = sa.select(self.model) + models = self._session.scalars(stmt) + + return self._create_entities_from_models(models) + + def where(self, *criteria: Criteria[TEntity]) -> list[TEntity]: + stmt = sa.select(self.model) + stmt = self._apply_criteria(stmt, criteria) + models = self._session.scalars(stmt) + + return self._create_entities_from_models(models) + + def first(self, *criteria: Criteria[TEntity]) -> TEntity | None: + stmt = sa.select(self.model).limit(1) + stmt = self._apply_criteria(stmt, criteria) + model = self._session.execute(stmt).scalar_one_or_none() + if model is None: + return None + + return self._create_entity_from_model(model) + + def count(self, *criteria: Criteria[TEntity]) -> int: + stmt = sa.select(self.model) + stmt = self._apply_criteria(stmt, criteria) + count_stmt = sa.select(sa.func.count()).select_from(stmt.subquery()) + + return self._session.execute(count_stmt).scalar_one() + + def exists(self, *criteria: Criteria[TEntity]) -> bool: + stmt = sa.select(self.model) + stmt = self._apply_criteria(stmt, criteria) + exists_stmt = sa.select(sa.exists(stmt.subquery())) + + return self._session.execute(exists_stmt).scalar_one() + + # Bulk operations + def insert_many(self, dtos: list[TCreate]) -> list[TEntity]: + if not dtos: + return [] + + stmt = ( + sa.insert(self.model) + .values([dto.model_dump() for dto in dtos]) + .returning(self.model) + ) + + models = self._session.scalars(stmt) + entities = self._create_entities_from_models(models) + + if len(entities) != len(dtos): + self._raise_not_created() + + return entities + + def update_where(self, dto: TUpdate, *criteria: Criteria[TEntity]) -> int: + stmt = sa.update(self.model).values( + **dto.model_dump(exclude_unset=True) + ) + stmt = self._apply_criteria(stmt, criteria) + + result = self._session.execute(stmt) + return self._get_row_count(result) + + def delete_where(self, *criteria: Criteria[TEntity]) -> int: + stmt = sa.delete(self.model) + stmt = self._apply_criteria(stmt, criteria) + + result = self._session.execute(stmt) + return self._get_row_count(result) + + # ------- Private Helpers -------- + def _create_entity_from_model(self, model: TModel) -> TEntity: + # TODO: rework with sentinel values to prevent eager loading on relationships + try: + return self.entity.model_validate(model) + except pydantic.ValidationError as e: + msg = f"Couldn't create entity instance.\n{e}" + raise repository_errors.RepositoryError(msg) + + def _create_entities_from_models( + self, models: t.Iterable[TModel] + ) -> list[TEntity]: + # TODO: rework to aggregate errors into an exception group? + return [self._create_entity_from_model(model) for model in models] + + def _get_by_id(self, entity_id: TEntityId) -> TModel | None: + return self._session.get(self.model, entity_id) + + def _apply_criteria[ + TStatement: (sa.Select[tuple[TModel]], sa.Update, sa.Delete) + ]( + self, + statement: TStatement, + criteria: tuple[Criteria[TEntity], ...], + ) -> TStatement: + for c in criteria: + translator = SqlalchemyCriteriaTranslator.get_for_criteria(c) + statement = translator.translate(statement) + + return statement + + def _raise_not_found(self) -> t.NoReturn: + """Override if a better message is required""" + raise repository_errors.EntityNotFoundError(self.entity) + + def _raise_not_created(self) -> t.NoReturn: + """Override if a better message is required""" + raise repository_errors.EntityNotCreatedError(self.entity) + + def _raise_not_updated(self) -> t.NoReturn: + """Override if a better message is required""" + raise repository_errors.EntityNotUpdatedError(self.entity) + + def _raise_not_deleted(self) -> t.NoReturn: + """Override if a better message is required""" + raise repository_errors.EntityNotDeletedError(self.entity) + + def _get_row_count(self, result: sa.Result) -> int: + # NOTE: in runtime, the Result is a CursorResult which has this prop + # but the type stubs do not reflect this + match result: + case sa.CursorResult(rowcount=row_count): + return row_count + case _: + msg = ( + "Can't get result row count: " + f"expected result to be a CursorResult, got {type(result)} instead. " + "This could indicate a breaking change with sqlalchemy, " + "review sqlalchemy repository implementation." + ) + raise repository_errors.RepositoryError(msg) diff --git a/app/common/infrastructure/sqlalchemy_criteria_translator.py b/app/common/infrastructure/sqlalchemy_criteria_translator.py new file mode 100644 index 0000000..9bdf638 --- /dev/null +++ b/app/common/infrastructure/sqlalchemy_criteria_translator.py @@ -0,0 +1,56 @@ +__all__ = ( + "SqlalchemyCriteriaTranslator", + "Statement", +) + +import abc +import typing as t + +import sqlalchemy as sa + +from app.common.domain import Criteria + +type Statement = sa.Select[t.Any] | sa.Update | sa.Delete +type _Criteria = Criteria[t.Any] +type _Translator = "SqlalchemyCriteriaTranslator[t.Any]" +type _Registry = dict[type[_Criteria], type[_Translator]] + + +class SqlalchemyCriteriaTranslator[TCriteria: _Criteria](abc.ABC): + _registry: t.ClassVar[_Registry] = {} + + def __init_subclass__( + cls, + *, + criteria: type[_Criteria] | None = None, + **kw: t.Any, + ) -> None: + super().__init_subclass__(**kw) + if criteria is not None: + cls._registry[criteria] = cls + + def __init__(self, criteria: TCriteria) -> None: + self._criteria = criteria + + @classmethod + def get_for_criteria(cls, criteria: _Criteria) -> _Translator: + translator_cls = cls._registry.get(type(criteria), None) + if translator_cls is None: + raise NoTranslatorForCriteriaError(criteria) + + return translator_cls(criteria) + + @abc.abstractmethod + def translate(self, stmt: Statement) -> Statement: + raise NotImplementedError() + + +class SqlalchemyCriteriaTranslatorError(Exception): + pass + + +class NoTranslatorForCriteriaError(SqlalchemyCriteriaTranslatorError): + def __init__(self, criteria: _Criteria): + criteria_name = type(criteria).__name__ + msg = f"No translator for {criteria_name}" + super().__init__(msg) diff --git a/app/common/infrastructure/sqlalchemy_paginator_translator.py b/app/common/infrastructure/sqlalchemy_paginator_translator.py new file mode 100644 index 0000000..e6c1d3b --- /dev/null +++ b/app/common/infrastructure/sqlalchemy_paginator_translator.py @@ -0,0 +1,16 @@ +from app.common.domain import PaginationCriteria + +from .sqlalchemy_criteria_translator import SqlalchemyCriteriaTranslator +from .sqlalchemy_criteria_translator import Statement + + +class SQLAlchemyPaginationTranslator( + SqlalchemyCriteriaTranslator[PaginationCriteria], + criteria=PaginationCriteria, +): + def translate(self, stmt: Statement) -> Statement: + page_size = self._criteria.page_size + page = self._criteria.page + offset = (page - 1) * page_size + + return stmt.offset(offset).limit(page_size) diff --git a/app/db/base.py b/app/db/base.py deleted file mode 100644 index 2d41219..0000000 --- a/app/db/base.py +++ /dev/null @@ -1,4 +0,0 @@ -# Import all the models, so that Base has them before being -# imported by Alembic -from app.common.models.base_class import Base # noqa -from app.users.models.user import User # noqa diff --git a/app/emails/clients/mailpit_email_client/_mailpit_email_client.py b/app/emails/clients/mailpit_email_client/_mailpit_email_client.py index edbf745..d13a18c 100644 --- a/app/emails/clients/mailpit_email_client/_mailpit_email_client.py +++ b/app/emails/clients/mailpit_email_client/_mailpit_email_client.py @@ -1,9 +1,8 @@ from typing import ClassVar -from app.common.exceptions import ExternalProviderException from app.common.clients.base_request_client import BaseRequestClient +from app.common.exceptions import ExternalProviderError from app.core.config import settings - from app.emails.clients.base import BaseEmailClient from app.emails.schema.email import Email @@ -36,4 +35,4 @@ def send_email( if not response: message = "Email not sent, see logs for details." - raise ExternalProviderException(message) + raise ExternalProviderError(message) diff --git a/app/emails/services/emails_service.py b/app/emails/services/emails_service.py index 5130552..a9787f8 100644 --- a/app/emails/services/emails_service.py +++ b/app/emails/services/emails_service.py @@ -1,10 +1,10 @@ from enum import Enum from string import Template -from app.users.schemas.user_schema import UserInDB +from app.emails._global_state import get_client from app.emails.clients.base import BaseEmailClient from app.emails.schema.email import Email -from app.emails._global_state import get_client +from app.users.domain import User class Paths(Enum): @@ -36,7 +36,7 @@ def _get_email( def send_new_user_email( self, - user: UserInDB, + user: User, ) -> None: email = self._get_email( user.email, @@ -48,7 +48,7 @@ def send_new_user_email( def send_user_remind_email( self, - user: UserInDB, + user: User, ) -> None: email = self._get_email( user.email, diff --git a/app/two_factor_authentication/api/endpoints.py b/app/two_factor_authentication/api/endpoints.py index f83b036..4f2b6d2 100644 --- a/app/two_factor_authentication/api/endpoints.py +++ b/app/two_factor_authentication/api/endpoints.py @@ -2,7 +2,7 @@ from fastapi import APIRouter, HTTPException, status from slowapi import Limiter -from app.common.api.dependencies.get_session import SessionDependency +from app.common.api.dependencies.session_dependency import SessionDependency from app.common.exceptions.model_not_created_exception import ( ModelNotCreatedException, ) diff --git a/app/users/api/dependencies/__init__.py b/app/users/api/dependencies/__init__.py index e69de29..ae90816 100644 --- a/app/users/api/dependencies/__init__.py +++ b/app/users/api/dependencies/__init__.py @@ -0,0 +1,2 @@ +from .current_user_dependency import * +from .user_repository_dependency import * diff --git a/app/users/api/dependencies/current_user_dependency.py b/app/users/api/dependencies/current_user_dependency.py new file mode 100644 index 0000000..a5cb7d6 --- /dev/null +++ b/app/users/api/dependencies/current_user_dependency.py @@ -0,0 +1,53 @@ +__all__ = ( + "get_current_user", + "CurrentUserDependency", +) + +import typing as t +import uuid + +import fastapi + +from app.auth.api.dependencies.get_token import TokenDep +from app.auth.enums.claims_enum import ClaimsEnum +from app.auth.exceptions.invalid_credentials_exception import ( + InvalidCredentialsException, +) +from app.auth.schemas.token_schema import TokenPayload +from app.auth.utils.security import validate_token +from app.common.exceptions import ModelNotFoundException +from app.users.domain import User +from app.users.domain import UserId + +from .user_repository_dependency import UserRepositoryDependency + + +def get_current_user( + user_repository: UserRepositoryDependency, + token: TokenDep, +) -> User: + try: + token_data = t.cast( + TokenPayload, validate_token(token, ClaimsEnum.USER_ID) + ) + except InvalidCredentialsException as e: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_401_UNAUTHORIZED, detail=e.message + ) + + user_id = UserId(uuid.UUID(token_data.user_id)) + try: + user = user_repository.find_or_fail(user_id) + except ModelNotFoundException as e: + raise fastapi.HTTPException( + status_code=404, + detail=e.message, + ) + + return user + + +CurrentUserDependency = t.Annotated[ + User, + fastapi.Depends(get_current_user), +] diff --git a/app/users/api/dependencies/get_current_user.py b/app/users/api/dependencies/get_current_user.py deleted file mode 100644 index 40ee0bc..0000000 --- a/app/users/api/dependencies/get_current_user.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import Annotated, cast -from uuid import UUID - -from fastapi import Depends, HTTPException -from starlette import status - -from app.auth.enums.claims_enum import ClaimsEnum -from app.auth.schemas.token_schema import TokenPayload -from app.common.api.dependencies.get_session import SessionDependency -from app.auth.api.dependencies.get_token import TokenDep - -from app.auth.utils.security import validate_token -from app.auth.exceptions.invalid_credentials_exception import ( - InvalidCredentialsException, -) -from app.users.schemas.user_schema import UserInDB -from app.users.services.users_service import UsersService - - -def get_current_user(session: SessionDependency, token: TokenDep) -> UserInDB: - try: - token_data = cast( - TokenPayload, validate_token(token, ClaimsEnum.USER_ID) - ) - except InvalidCredentialsException as e: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, detail=e.message - ) - provider = UsersService(session).get_by_id(UUID(token_data.user_id)) - if not provider: - raise HTTPException(status_code=404, detail="Provider not found") - return provider - - -CurrentUser = Annotated[UserInDB, Depends(get_current_user)] diff --git a/app/users/api/dependencies/user_repository_dependency.py b/app/users/api/dependencies/user_repository_dependency.py new file mode 100644 index 0000000..5b93941 --- /dev/null +++ b/app/users/api/dependencies/user_repository_dependency.py @@ -0,0 +1,22 @@ +__all__ = ( + "get_user_repository", + "UserRepositoryDependency", +) + +import typing as t + +import fastapi + +from app.common.api.dependencies import SessionDependency +from app.users.domain import UserRepository +from app.users.infrastructure import SQLAlchemyUserRepository + + +def get_user_repository(session: SessionDependency) -> UserRepository: + return SQLAlchemyUserRepository(session) + + +UserRepositoryDependency = t.Annotated[ + UserRepository, + fastapi.Depends(get_user_repository), +] diff --git a/app/users/api/endpoints.py b/app/users/api/endpoints.py index 630ab77..fa42cda 100644 --- a/app/users/api/endpoints.py +++ b/app/users/api/endpoints.py @@ -1,30 +1,32 @@ -from fastapi import APIRouter, status +import fastapi -from app.core.config import get_settings -from app.users.schemas.user_schema import CreateUserRequest, UserResponse -from app.users.use_cases.create_user_use_case import CreateUserUseCase -from app.users.api.dependencies.get_current_user import CurrentUser -from app.common.api.dependencies.get_session import SessionDependency +import app.users.domain as user_domain +import app.users.errors as user_errors +import app.users.use_cases as use_cases +from app.users.api.dependencies import CurrentUserDependency +from app.users.api.dependencies import UserRepositoryDependency -from slowapi import Limiter -from slowapi.util import get_remote_address +router = fastapi.APIRouter() -router = APIRouter() -settings = get_settings() - -limiter = Limiter(key_func=get_remote_address) +@router.post("", status_code=fastapi.status.HTTP_201_CREATED) +def create_user( + user_repository: UserRepositoryDependency, + create_user_request: user_domain.CreateUserRequest, +) -> user_domain.UserResponse: + use_case = use_cases.CreateUserUseCase(user_repository) -@router.get("/current", status_code=status.HTTP_200_OK) -def get_current_user( - current_user: CurrentUser, -) -> UserResponse: - return UserResponse.model_validate(current_user) + try: + return use_case.execute(create_user_request) + except user_errors.UserEmailCollisionError as e: + raise fastapi.HTTPException( + status_code=fastapi.status.HTTP_409_CONFLICT, + detail=e.message, + ) -@router.post("", status_code=status.HTTP_201_CREATED) -def create_user( - session: SessionDependency, - create_user_request: CreateUserRequest, -) -> UserResponse: - return CreateUserUseCase(session).execute(create_user_request) +@router.get("/current", status_code=fastapi.status.HTTP_200_OK) +def get_current_user( + current_user: CurrentUserDependency, +) -> user_domain.UserResponse: + return user_domain.UserResponse.model_validate(current_user) diff --git a/app/users/domain/__init__.py b/app/users/domain/__init__.py new file mode 100644 index 0000000..5418498 --- /dev/null +++ b/app/users/domain/__init__.py @@ -0,0 +1,6 @@ +from .user import * +from .user_constants import * +from .user_criteria import * +from .user_dtos import * +from .user_repository import * +from .user_types import * diff --git a/app/users/domain/user.py b/app/users/domain/user.py new file mode 100644 index 0000000..9ca2554 --- /dev/null +++ b/app/users/domain/user.py @@ -0,0 +1,27 @@ +__all__ = ("User",) + +import datetime as dt + +import pydantic + +from app.common.domain import Entity + +from .user_types import UserEmailField +from .user_types import UserId + + +class User(Entity): + id: UserId + created_at: dt.datetime + updated_at: dt.datetime + + email: UserEmailField + hashed_password: str + + class CreateDto(pydantic.BaseModel): + email: UserEmailField + hashed_password: str + + class UpdateDto(pydantic.BaseModel): + email: UserEmailField | None = None + hashed_password: str | None = None diff --git a/app/users/domain/user_constants.py b/app/users/domain/user_constants.py new file mode 100644 index 0000000..1ef8a15 --- /dev/null +++ b/app/users/domain/user_constants.py @@ -0,0 +1,5 @@ +__all__ = ("USER_EMAIL_MAX_LENGTH",) + +import typing as t + +USER_EMAIL_MAX_LENGTH: t.Final = 100 diff --git a/app/users/domain/user_criteria.py b/app/users/domain/user_criteria.py new file mode 100644 index 0000000..b8e93dc --- /dev/null +++ b/app/users/domain/user_criteria.py @@ -0,0 +1,19 @@ +__all__ = ( + "UserEmailFilterCriteria", + "UserCriteria", +) + +import dataclasses + +from app.common.domain import Criteria + +from .user import User + + +class UserCriteria(Criteria[User]): + pass + + +@dataclasses.dataclass(frozen=True) +class UserEmailFilterCriteria(UserCriteria): + email: str diff --git a/app/users/domain/user_dtos.py b/app/users/domain/user_dtos.py new file mode 100644 index 0000000..52d3b87 --- /dev/null +++ b/app/users/domain/user_dtos.py @@ -0,0 +1,29 @@ +__all__ = ( + "CreateUserRequest", + "UserAuth", + "UserResponse", +) + +import uuid + +from pydantic import BaseModel +from pydantic import ConfigDict + +from .user_types import UserEmailField +from .user_types import UserId + + +class UserResponse(BaseModel): + id: UserId + email: UserEmailField + + +class CreateUserRequest(BaseModel): + email: UserEmailField + password: str + + +class UserAuth(BaseModel): + model_config = ConfigDict(from_attributes=True) + id: uuid.UUID + hashed_password: str diff --git a/app/users/domain/user_repository.py b/app/users/domain/user_repository.py new file mode 100644 index 0000000..cf384dd --- /dev/null +++ b/app/users/domain/user_repository.py @@ -0,0 +1,17 @@ +__all__ = ("UserRepository",) + +from app.common.domain import Repository + +from .user import User +from .user_types import UserId + + +class UserRepository( + Repository[ + User, + UserId, + User.CreateDto, + User.UpdateDto, + ] +): + pass diff --git a/app/users/domain/user_types.py b/app/users/domain/user_types.py new file mode 100644 index 0000000..f79c26b --- /dev/null +++ b/app/users/domain/user_types.py @@ -0,0 +1,19 @@ +__all__ = ( + "UserEmailField", + "UserId", +) + +import typing as t +import uuid + +import pydantic + +from .user_constants import USER_EMAIL_MAX_LENGTH + +UserId = t.NewType("UserId", uuid.UUID) + +UserEmailField = t.Annotated[ + pydantic.EmailStr, + pydantic.Field(max_length=USER_EMAIL_MAX_LENGTH), + pydantic.AfterValidator(lambda s: s.lower()), +] diff --git a/app/users/errors/__init__.py b/app/users/errors/__init__.py new file mode 100644 index 0000000..1b1e4de --- /dev/null +++ b/app/users/errors/__init__.py @@ -0,0 +1 @@ +from .user_email_collision_error import * diff --git a/app/users/errors/user_email_collision_error.py b/app/users/errors/user_email_collision_error.py new file mode 100644 index 0000000..14bf01c --- /dev/null +++ b/app/users/errors/user_email_collision_error.py @@ -0,0 +1,10 @@ +__all__ = ("UserEmailCollisionError",) + +from app.common.exceptions import BaseApplicationError + + +class UserEmailCollisionError(BaseApplicationError): + def __init__(self, email: str): + self.email = email + msg = "User with that email already registered." + super().__init__(msg) diff --git a/app/users/infrastructure/__init__.py b/app/users/infrastructure/__init__.py new file mode 100644 index 0000000..8abff04 --- /dev/null +++ b/app/users/infrastructure/__init__.py @@ -0,0 +1,3 @@ +from .sqlalchemy_user_model import * +from .sqlalchemy_user_translators import * +from .sqlalchemy_user_repository import * diff --git a/app/users/infrastructure/sqlalchemy_user_model.py b/app/users/infrastructure/sqlalchemy_user_model.py new file mode 100644 index 0000000..e3d4a5e --- /dev/null +++ b/app/users/infrastructure/sqlalchemy_user_model.py @@ -0,0 +1,18 @@ +__all__ = ("SQLAlchemyUserModel",) + +from sqlalchemy.orm import Mapped +from sqlalchemy.orm import mapped_column +from sqlalchemy.types import String + +import app.users.domain as user_domain +from app.common.infrastructure import BaseSQLAlchemyModel + + +class SQLAlchemyUserModel(BaseSQLAlchemyModel): + __tablename__ = "users" + + email: Mapped[str] = mapped_column( + String(user_domain.USER_EMAIL_MAX_LENGTH), + unique=True, + ) + hashed_password: Mapped[str] diff --git a/app/users/infrastructure/sqlalchemy_user_repository.py b/app/users/infrastructure/sqlalchemy_user_repository.py new file mode 100644 index 0000000..f35b985 --- /dev/null +++ b/app/users/infrastructure/sqlalchemy_user_repository.py @@ -0,0 +1,22 @@ +__all__ = ("SQLAlchemyUserRepository",) + +from app.common.infrastructure import BaseSQLAlchemyRepository +from app.users.domain import User +from app.users.domain import UserRepository +from app.users.domain.user_types import UserId + +from .sqlalchemy_user_model import SQLAlchemyUserModel + + +class SQLAlchemyUserRepository( + BaseSQLAlchemyRepository[ + User, + UserId, + User.CreateDto, + User.UpdateDto, + SQLAlchemyUserModel, + ], + UserRepository, +): + entity = User + model = SQLAlchemyUserModel diff --git a/app/users/infrastructure/sqlalchemy_user_translators.py b/app/users/infrastructure/sqlalchemy_user_translators.py new file mode 100644 index 0000000..37fb4a6 --- /dev/null +++ b/app/users/infrastructure/sqlalchemy_user_translators.py @@ -0,0 +1,15 @@ +__all__ = ("SQLAlchemyUserEmailFilterTranslator",) + +from app.common.infrastructure import SqlalchemyCriteriaTranslator +from app.common.infrastructure import Statement +from app.users.domain import UserEmailFilterCriteria + +from .sqlalchemy_user_model import SQLAlchemyUserModel + + +class SQLAlchemyUserEmailFilterTranslator( + SqlalchemyCriteriaTranslator[UserEmailFilterCriteria], + criteria=UserEmailFilterCriteria, +): + def translate(self, stmt: Statement) -> Statement: + return stmt.where(SQLAlchemyUserModel.email == self._criteria.email) diff --git a/app/users/use_cases/__init__.py b/app/users/use_cases/__init__.py new file mode 100644 index 0000000..6c84403 --- /dev/null +++ b/app/users/use_cases/__init__.py @@ -0,0 +1 @@ +from .create_user_use_case import * diff --git a/app/users/use_cases/create_user_use_case.py b/app/users/use_cases/create_user_use_case.py index 41afec2..4f9bbaf 100644 --- a/app/users/use_cases/create_user_use_case.py +++ b/app/users/use_cases/create_user_use_case.py @@ -1,42 +1,37 @@ -from fastapi.exceptions import HTTPException -from sqlalchemy.orm import Session -from fastapi import status +__all__ = ("CreateUserUseCase",) +import app.users.domain as user_domain from app.auth.utils import security -from app.users.schemas.user_schema import ( - CreateUserRequest, - UserCreate, - UserResponse, -) -from app.users.services.users_service import UsersService +from app.users.errors import UserEmailCollisionError class CreateUserUseCase: - def __init__(self, session: Session): - self.session = session + def __init__(self, user_repository: user_domain.UserRepository): + self.user_repository = user_repository - def execute(self, create_user_request: CreateUserRequest) -> UserResponse: + def execute( + self, + create_user_request: user_domain.CreateUserRequest, + ) -> user_domain.UserResponse: from app.celery.tasks.emails import send_welcome_email - users_service = UsersService(self.session) - if users_service.get_by_email(create_user_request.email): - raise HTTPException( - status_code=status.HTTP_409_CONFLICT, - detail="User with that email already registered.", - ) + if self.user_repository.exists( + user_domain.UserEmailFilterCriteria(create_user_request.email), + ): + raise UserEmailCollisionError(create_user_request.email) - created_user = users_service.create_user( - UserCreate( - email=create_user_request.email.lower(), + created_user = self.user_repository.create( + user_domain.User.CreateDto( + email=create_user_request.email, hashed_password=security.get_password_hash( create_user_request.password ), ) ) - send_welcome_email.delay(created_user.id) # type: ignore + send_welcome_email.delay(created_user.id) - return UserResponse( + return user_domain.UserResponse( id=created_user.id, email=created_user.email, ) diff --git a/tests/conftest.py b/tests/conftest.py index b94c9d4..f192b51 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,16 @@ from typing import Generator + import pytest -from sqlalchemy import RootTransaction, event from fastapi.testclient import TestClient +from sqlalchemy import RootTransaction +from sqlalchemy import event +from sqlalchemy.orm import Session -from app.common.api.dependencies.get_session import get_session +from app.common.api.dependencies.session_dependency import get_session from app.core.config import get_settings -from app.db.session import engine, SessionLocal +from app.db.session import SessionLocal +from app.db.session import engine from app.main import app -from sqlalchemy.orm import Session settings = get_settings() TEST_DATABASE_URI = settings.SQLALCHEMY_DATABASE_URI