diff --git a/src/prisma/generator/models.py b/src/prisma/generator/models.py index c445ab01f..99a0d9878 100644 --- a/src/prisma/generator/models.py +++ b/src/prisma/generator/models.py @@ -353,6 +353,7 @@ def to_params(self) -> Dict[str, Any]: """Get the parameters that should be sent to Jinja templates""" params = vars(self) params['type_schema'] = Schema.from_data(self) + params['client_types'] = ClientTypes.from_data(self) # add utility functions for func in [ @@ -628,11 +629,22 @@ def engine_type_validator(cls, value: EngineType) -> EngineType: assert_never(value) +class DMMFEnumType(BaseModel): + name: str + values: List[object] + + +class DMMFEnumTypes(BaseModel): + prisma: List[DMMFEnumType] + + +class PrismaSchema(BaseModel): + enum_types: DMMFEnumTypes = FieldInfo(alias='enumTypes') + + class DMMF(BaseModel): datamodel: 'Datamodel' - - # TODO - prisma_schema: Any = FieldInfo(alias='schema') + prisma_schema: PrismaSchema = FieldInfo(alias='schema') class Datamodel(BaseModel): @@ -1182,4 +1194,4 @@ class DefaultData(GenericData[_EmptyModel]): TemplateError, PartialTypeGeneratorError, ) -from .schema import Schema +from .schema import Schema, ClientTypes diff --git a/src/prisma/generator/schema.py b/src/prisma/generator/schema.py index d04a572e0..0dd6a6b6a 100644 --- a/src/prisma/generator/schema.py +++ b/src/prisma/generator/schema.py @@ -1,10 +1,11 @@ from enum import Enum -from typing import Any, Dict, List, Type, Tuple, Union +from typing import Any, Dict, List, Type, Tuple, Union, Optional from typing_extensions import ClassVar from pydantic import BaseModel -from .models import Model as ModelInfo, AnyData, PrimaryKey +from .utils import to_constant_case +from .models import Model as ModelInfo, AnyData, PrimaryKey, DMMFEnumType from .._compat import ( PYDANTIC_V2, ConfigDict, @@ -18,6 +19,7 @@ class Kind(str, Enum): alias = 'alias' union = 'union' typeddict = 'typeddict' + enum = 'enum' class PrismaType(BaseModel): @@ -45,6 +47,11 @@ class PrismaUnion(PrismaType): subtypes: List[PrismaType] +class PrismaEnum(PrismaType): + kind: Kind = Kind.enum + members: List[Tuple[str, str]] + + class PrismaAlias(PrismaType): kind: Kind = Kind.alias to: str @@ -143,6 +150,29 @@ def order_by(self) -> PrismaType: return PrismaType.from_subtypes(subtypes, name=f'{model}OrderByInput') +class ClientTypes(BaseModel): + transaction_isolation_level: Optional[PrismaEnum] + + @classmethod + def from_data(cls, data: AnyData) -> 'ClientTypes': + enum_types = data.dmmf.prisma_schema.enum_types.prisma + + return cls( + transaction_isolation_level=construct_enum_type(enum_types, name='TransactionIsolationLevel'), + ) + + +def construct_enum_type(dmmf_enum_types: List[DMMFEnumType], *, name: str) -> Optional[PrismaEnum]: + enum_type = next((t for t in dmmf_enum_types if t.name == name), None) + if not enum_type: + return None + + return PrismaEnum( + name=name, + members=[(to_constant_case(str(value)), str(value)) for value in enum_type.values], + ) + + model_rebuild(Schema) model_rebuild(PrismaType) model_rebuild(PrismaDict) diff --git a/src/prisma/generator/templates/types.py.jinja b/src/prisma/generator/templates/types.py.jinja index 666bbc84f..cc9ac46a9 100644 --- a/src/prisma/generator/templates/types.py.jinja +++ b/src/prisma/generator/templates/types.py.jinja @@ -65,6 +65,11 @@ from .utils import _NoneType }, total={{ type.total }} ) +{% elif type.kind == 'enum' %} +class {{ type.name }}(StrEnum): + {% for name, value in type.members %} + {{ name }} = "{{ value }}" + {% endfor %} {% else %} {{ raise_err('Unhandled type kind: %s' % type.kind) }} {% endif %}