Skip to content
This repository was archived by the owner on Apr 15, 2025. It is now read-only.

Commit c74e7cd

Browse files
committed
chore(internal): support rendering enum types from the DMMF
This will be helpful for transaction isolation levels support #878
1 parent 1fa9331 commit c74e7cd

3 files changed

Lines changed: 53 additions & 6 deletions

File tree

src/prisma/generator/models.py

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -353,6 +353,7 @@ def to_params(self) -> Dict[str, Any]:
353353
"""Get the parameters that should be sent to Jinja templates"""
354354
params = vars(self)
355355
params['type_schema'] = Schema.from_data(self)
356+
params['client_types'] = ClientTypes.from_data(self)
356357

357358
# add utility functions
358359
for func in [
@@ -628,11 +629,22 @@ def engine_type_validator(cls, value: EngineType) -> EngineType:
628629
assert_never(value)
629630

630631

632+
class DMMFEnumType(BaseModel):
633+
name: str
634+
values: List[object]
635+
636+
637+
class DMMFEnumTypes(BaseModel):
638+
prisma: List[DMMFEnumType]
639+
640+
641+
class PrismaSchema(BaseModel):
642+
enum_types: DMMFEnumTypes = FieldInfo(alias='enumTypes')
643+
644+
631645
class DMMF(BaseModel):
632646
datamodel: 'Datamodel'
633-
634-
# TODO
635-
prisma_schema: Any = FieldInfo(alias='schema')
647+
prisma_schema: PrismaSchema = FieldInfo(alias='schema')
636648

637649

638650
class Datamodel(BaseModel):
@@ -1182,4 +1194,4 @@ class DefaultData(GenericData[_EmptyModel]):
11821194
TemplateError,
11831195
PartialTypeGeneratorError,
11841196
)
1185-
from .schema import Schema
1197+
from .schema import Schema, ClientTypes

src/prisma/generator/schema.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
from enum import Enum
2-
from typing import Any, Dict, List, Type, Tuple, Union
2+
from typing import Any, Dict, List, Type, Tuple, Union, Optional
33
from typing_extensions import ClassVar
44

55
from pydantic import BaseModel
66

7-
from .models import Model as ModelInfo, AnyData, PrimaryKey
7+
from .utils import to_constant_case
8+
from .models import Model as ModelInfo, AnyData, PrimaryKey, DMMFEnumType
89
from .._compat import (
910
PYDANTIC_V2,
1011
ConfigDict,
@@ -18,6 +19,7 @@ class Kind(str, Enum):
1819
alias = 'alias'
1920
union = 'union'
2021
typeddict = 'typeddict'
22+
enum = 'enum'
2123

2224

2325
class PrismaType(BaseModel):
@@ -45,6 +47,11 @@ class PrismaUnion(PrismaType):
4547
subtypes: List[PrismaType]
4648

4749

50+
class PrismaEnum(PrismaType):
51+
kind: Kind = Kind.enum
52+
members: List[Tuple[str, str]]
53+
54+
4855
class PrismaAlias(PrismaType):
4956
kind: Kind = Kind.alias
5057
to: str
@@ -143,6 +150,29 @@ def order_by(self) -> PrismaType:
143150
return PrismaType.from_subtypes(subtypes, name=f'{model}OrderByInput')
144151

145152

153+
class ClientTypes(BaseModel):
154+
transaction_isolation_level: Optional[PrismaEnum]
155+
156+
@classmethod
157+
def from_data(cls, data: AnyData) -> 'ClientTypes':
158+
enum_types = data.dmmf.prisma_schema.enum_types.prisma
159+
160+
return cls(
161+
transaction_isolation_level=construct_enum_type(enum_types, name='TransactionIsolationLevel'),
162+
)
163+
164+
165+
def construct_enum_type(dmmf_enum_types: List[DMMFEnumType], *, name: str) -> Optional[PrismaEnum]:
166+
enum_type = next((t for t in dmmf_enum_types if t.name == name), None)
167+
if not enum_type:
168+
return None
169+
170+
return PrismaEnum(
171+
name=name,
172+
members=[(to_constant_case(str(value)), str(value)) for value in enum_type.values],
173+
)
174+
175+
146176
model_rebuild(Schema)
147177
model_rebuild(PrismaType)
148178
model_rebuild(PrismaDict)

src/prisma/generator/templates/types.py.jinja

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,11 @@ from .utils import _NoneType
6565
},
6666
total={{ type.total }}
6767
)
68+
{% elif type.kind == 'enum' %}
69+
class {{ type.name }}(StrEnum):
70+
{% for name, value in type.members %}
71+
{{ name }} = "{{ value }}"
72+
{% endfor %}
6873
{% else %}
6974
{{ raise_err('Unhandled type kind: %s' % type.kind) }}
7075
{% endif %}

0 commit comments

Comments
 (0)