Skip to content

Commit 2767a43

Browse files
authored
🐛 Fix: datetime type hint in response json (#256)
1 parent a4adedc commit 2767a43

File tree

2,794 files changed

+201257
-21010
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

2,794 files changed

+201257
-21010
lines changed

codegen/parser/data.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,7 @@ def get_raw_definition(self) -> str:
8686
required=self.required,
8787
schema_data=self.body_schema,
8888
)
89-
return prop.get_param_defination()
89+
return prop.get_param_definition()
9090

9191
def get_endpoint_definition(self) -> str:
9292
prop = Property(
@@ -95,7 +95,7 @@ def get_endpoint_definition(self) -> str:
9595
required=not bool(self.allowed_models),
9696
schema_data=self.body_schema,
9797
)
98-
return prop.get_param_defination()
98+
return prop.get_param_definition()
9999

100100

101101
@dataclass(kw_only=True)

codegen/parser/schemas/schema.py

Lines changed: 162 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,22 @@ class SchemaData:
1313
_type_string: ClassVar[str] = "Any"
1414

1515
def get_type_string(self, include_constraints: bool = True) -> str:
16-
"""Get schema typing string in any place"""
16+
"""Get schema typing string in any place.
17+
18+
Args:
19+
include_constraints (bool):
20+
whether to include field constraints by Annotated.
21+
"""
1722
if include_constraints and (args := self._get_field_args()):
1823
return f"Annotated[{self._type_string}, {self._get_field_string(args)}]"
1924
return self._type_string
2025

2126
def get_param_type_string(self) -> str:
22-
"""Get type string used by client codegen"""
27+
"""Get type string used by client request codegen"""
28+
return self._type_string
29+
30+
def get_response_type_string(self) -> str:
31+
"""Get type string used by client response codegen"""
2332
return self._type_string
2433

2534
def get_model_imports(self) -> set[str]:
@@ -40,7 +49,11 @@ def get_param_imports(self) -> set[str]:
4049
return set()
4150

4251
def get_using_imports(self) -> set[str]:
43-
"""Get schema needed imports for client request codegen"""
52+
"""Get schema needed imports for client request body codegen"""
53+
return set()
54+
55+
def get_response_imports(self) -> set[str]:
56+
"""Get schema needed imports for client response codegen"""
4457
return set()
4558

4659
def _get_field_string(self, args: dict[str, str]) -> str:
@@ -66,7 +79,7 @@ class Property:
6679
required: bool
6780
schema_data: SchemaData
6881

69-
def get_type_string(self, include_constraints: bool = True) -> str:
82+
def get_type_string(self, include_constraints: bool = False) -> str:
7083
"""Get schema typing string in any place"""
7184
type_string = self.schema_data.get_type_string(
7285
include_constraints=include_constraints
@@ -78,24 +91,35 @@ def get_param_type_string(self) -> str:
7891
type_string = self.schema_data.get_param_type_string()
7992
return type_string if self.required else f"Missing[{type_string}]"
8093

81-
def get_model_defination(self) -> str:
82-
"""Get defination used by model codegen"""
94+
def get_response_type_string(self) -> str:
95+
type_string = self.schema_data.get_response_type_string()
96+
return type_string if self.required else f"Missing[{type_string}]"
97+
98+
def get_model_definition(self) -> str:
99+
"""Get definition used by model codegen"""
83100
# extract the outermost type constraints to the field
84101
type_ = self.get_type_string(include_constraints=False)
85102
args = self.schema_data._get_field_args()
86103
args.update(self._get_field_args())
87104
default = self._get_field_string(args)
88105
return f"{self.prop_name}: {type_} = {default}"
89106

90-
def get_type_defination(self) -> str:
91-
"""Get defination used by types codegen"""
107+
def get_type_definition(self) -> str:
108+
"""Get definition used by types codegen"""
92109
type_ = self.schema_data.get_param_type_string()
93110
return (
94111
f"{self.prop_name}: {type_ if self.required else f'NotRequired[{type_}]'}"
95112
)
96113

97-
def get_param_defination(self) -> str:
98-
"""Get defination used by client codegen"""
114+
def get_response_type_definition(self) -> str:
115+
"""Get definition usede by response types codegen"""
116+
type_ = self.schema_data.get_response_type_string()
117+
return (
118+
f"{self.prop_name}: {type_ if self.required else f'NotRequired[{type_}]'}"
119+
)
120+
121+
def get_param_definition(self) -> str:
122+
"""Get definition used by client codegen"""
99123
type_ = self.get_param_type_string()
100124
return (
101125
(
@@ -177,6 +201,12 @@ def get_using_imports(self) -> set[str]:
177201
imports.add("from typing import Any")
178202
return imports
179203

204+
@override
205+
def get_response_imports(self) -> set[str]:
206+
imports = super().get_response_imports()
207+
imports.add("from typing import Any")
208+
return imports
209+
180210

181211
@dataclass(kw_only=True)
182212
class NoneSchema(SchemaData):
@@ -264,6 +294,12 @@ def _get_field_args(self) -> dict[str, str]:
264294
class DateTimeSchema(SchemaData):
265295
_type_string: ClassVar[str] = "datetime"
266296

297+
@override
298+
def get_response_type_string(self) -> str:
299+
# datetime field is ISO string in response
300+
# https://github.com/yanyongyu/githubkit/issues/246
301+
return "str"
302+
267303
@override
268304
def get_model_imports(self) -> set[str]:
269305
imports = super().get_model_imports()
@@ -288,11 +324,23 @@ def get_using_imports(self) -> set[str]:
288324
imports.add("from datetime import datetime")
289325
return imports
290326

327+
@override
328+
def get_response_imports(self) -> set[str]:
329+
imports = super().get_response_imports()
330+
imports.add("from datetime import datetime")
331+
return imports
332+
291333

292334
@dataclass(kw_only=True)
293335
class DateSchema(SchemaData):
294336
_type_string: ClassVar[str] = "date"
295337

338+
@override
339+
def get_response_type_string(self) -> str:
340+
# date field is ISO string in response
341+
# https://github.com/yanyongyu/githubkit/issues/246
342+
return "str"
343+
296344
@override
297345
def get_model_imports(self) -> set[str]:
298346
imports = super().get_model_imports()
@@ -317,6 +365,12 @@ def get_using_imports(self) -> set[str]:
317365
imports.add("from datetime import date")
318366
return imports
319367

368+
@override
369+
def get_response_imports(self) -> set[str]:
370+
imports = super().get_response_imports()
371+
imports.add("from datetime import date")
372+
return imports
373+
320374

321375
@dataclass(kw_only=True)
322376
class FileSchema(SchemaData):
@@ -346,6 +400,12 @@ def get_using_imports(self) -> set[str]:
346400
imports.add("from githubkit.typing import FileTypes")
347401
return imports
348402

403+
@override
404+
def get_response_imports(self) -> set[str]:
405+
imports = super().get_response_imports()
406+
imports.add("from githubkit.typing import FileTypes")
407+
return imports
408+
349409

350410
@dataclass(kw_only=True)
351411
class ListSchema(SchemaData):
@@ -366,6 +426,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
366426
def get_param_type_string(self) -> str:
367427
return f"list[{self.item_schema.get_param_type_string()}]"
368428

429+
@override
430+
def get_response_type_string(self) -> str:
431+
return f"list[{self.item_schema.get_response_type_string()}]"
432+
369433
@override
370434
def get_model_imports(self) -> set[str]:
371435
imports = super().get_model_imports()
@@ -392,6 +456,13 @@ def get_using_imports(self) -> set[str]:
392456
imports.update(self.item_schema.get_using_imports())
393457
return imports
394458

459+
@override
460+
def get_response_imports(self) -> set[str]:
461+
imports = super().get_response_imports()
462+
imports.add("from githubkit.compat import PYDANTIC_V2")
463+
imports.update(self.item_schema.get_response_imports())
464+
return imports
465+
395466
@override
396467
def _get_field_args(self) -> dict[str, str]:
397468
args = super()._get_field_args()
@@ -433,6 +504,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
433504
def get_param_type_string(self) -> str:
434505
return f"UniqueList[{self.item_schema.get_param_type_string()}]"
435506

507+
@override
508+
def get_response_type_string(self) -> str:
509+
return f"UniqueList[{self.item_schema.get_response_type_string()}]"
510+
436511
@override
437512
def get_model_imports(self) -> set[str]:
438513
imports = super().get_model_imports()
@@ -462,6 +537,13 @@ def get_using_imports(self) -> set[str]:
462537
imports.update(self.item_schema.get_using_imports())
463538
return imports
464539

540+
@override
541+
def get_response_imports(self) -> set[str]:
542+
# imports = super().get_response_imports()
543+
imports = {"from githubkit.typing import UniqueList"}
544+
imports.update(self.item_schema.get_response_imports())
545+
return imports
546+
465547
@override
466548
def _get_field_args(self) -> dict[str, str]:
467549
args = super()._get_field_args()
@@ -511,6 +593,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
511593
def get_param_type_string(self) -> str:
512594
return f"Literal[{', '.join(repr(value) for value in self.values)}]"
513595

596+
@override
597+
def get_response_type_string(self) -> str:
598+
return self.get_param_type_string()
599+
514600
@override
515601
def get_model_imports(self) -> set[str]:
516602
imports = super().get_model_imports()
@@ -535,6 +621,12 @@ def get_using_imports(self) -> set[str]:
535621
imports.add("from typing import Literal")
536622
return imports
537623

624+
@override
625+
def get_response_imports(self) -> set[str]:
626+
imports = super().get_response_imports()
627+
imports.add("from typing import Literal")
628+
return imports
629+
538630

539631
@dataclass(kw_only=True)
540632
class ModelSchema(SchemaData):
@@ -552,8 +644,46 @@ def get_type_string(self, include_constraints: bool = True) -> str:
552644

553645
@override
554646
def get_param_type_string(self) -> str:
647+
"""Get type string used by model type class name and client request codegen.
648+
649+
Example:
650+
651+
```python
652+
class ModelType(TypedDict):
653+
...
654+
655+
class Client:
656+
def create_xxx(
657+
*,
658+
data: ModelType,
659+
) -> Response[Model, ModelResponseType]:
660+
...
661+
```
662+
"""
555663
return f"{self.class_name}Type"
556664

665+
@override
666+
def get_response_type_string(self) -> str:
667+
"""Get type string used by model resposne type class name
668+
and client response codegen.
669+
670+
Example:
671+
672+
```python
673+
class ModelResponseType(TypedDict):
674+
...
675+
676+
class Client:
677+
def create_xxx(
678+
*,
679+
data: ModelType,
680+
) -> Response[Model, ModelResponseType]:
681+
...
682+
```
683+
"""
684+
# `XXXResponseType` has name conflicts in definition
685+
return f"{self.class_name}TypeForResponse"
686+
557687
@override
558688
def get_model_imports(self) -> set[str]:
559689
imports = super().get_model_imports()
@@ -583,6 +713,10 @@ def get_param_imports(self) -> set[str]:
583713
def get_using_imports(self) -> set[str]:
584714
return {f"from ..models import {self.class_name}"}
585715

716+
@override
717+
def get_response_imports(self) -> set[str]:
718+
return {f"from ..types import {self.get_response_type_string()}"}
719+
586720
@override
587721
def get_model_dependencies(self) -> list["ModelSchema"]:
588722
result: list[ModelSchema] = []
@@ -624,6 +758,15 @@ def get_param_type_string(self) -> str:
624758
types = ", ".join(schema.get_param_type_string() for schema in self.schemas)
625759
return f"Union[{types}]"
626760

761+
@override
762+
def get_response_type_string(self) -> str:
763+
if len(self.schemas) == 0:
764+
return "Any"
765+
elif len(self.schemas) == 1:
766+
return self.schemas[0].get_response_type_string()
767+
types = ", ".join(schema.get_response_type_string() for schema in self.schemas)
768+
return f"Union[{types}]"
769+
627770
@override
628771
def get_model_imports(self) -> set[str]:
629772
imports = super().get_model_imports()
@@ -656,6 +799,15 @@ def get_using_imports(self) -> set[str]:
656799
imports.update(schema.get_using_imports())
657800
return imports
658801

802+
@override
803+
def get_response_imports(self) -> set[str]:
804+
imports = super().get_response_imports()
805+
imports.add("from typing import Union")
806+
for schema in self.schemas:
807+
imports.update(schema.get_response_imports())
808+
return imports
809+
810+
@override
659811
def _get_field_args(self) -> dict[str, str]:
660812
args = super()._get_field_args()
661813
if self.discriminator:

codegen/templates/models/group.py.jinja

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ class {{ model.class_name }}({{ "ExtraGitHubModel" if model.allow_extra else "Gi
2525
{{ build_model_docstring(model) | indent(4) }}
2626

2727
{% for prop in model.properties %}
28-
{{ prop.get_model_defination() }}
28+
{{ prop.get_model_definition() }}
2929
{% endfor %}
3030

3131
{% endfor %}

0 commit comments

Comments
 (0)