@@ -13,13 +13,22 @@ class SchemaData:
1313 _type_string : ClassVar [str ] = "Any"
1414
1515 def get_type_string (self , include_constraints : bool = True ) -> str :
16- """Get schema typing string in any place"""
16+ """Get schema typing string in any place.
17+
18+ Args:
19+ include_constraints (bool):
20+ whether to include field constraints by Annotated.
21+ """
1722 if include_constraints and (args := self ._get_field_args ()):
1823 return f"Annotated[{ self ._type_string } , { self ._get_field_string (args )} ]"
1924 return self ._type_string
2025
2126 def get_param_type_string (self ) -> str :
22- """Get type string used by client codegen"""
27+ """Get type string used by client request codegen"""
28+ return self ._type_string
29+
30+ def get_response_type_string (self ) -> str :
31+ """Get type string used by client response codegen"""
2332 return self ._type_string
2433
2534 def get_model_imports (self ) -> set [str ]:
@@ -40,7 +49,11 @@ def get_param_imports(self) -> set[str]:
4049 return set ()
4150
4251 def get_using_imports (self ) -> set [str ]:
43- """Get schema needed imports for client request codegen"""
52+ """Get schema needed imports for client request body codegen"""
53+ return set ()
54+
55+ def get_response_imports (self ) -> set [str ]:
56+ """Get schema needed imports for client response codegen"""
4457 return set ()
4558
4659 def _get_field_string (self , args : dict [str , str ]) -> str :
@@ -66,7 +79,7 @@ class Property:
6679 required : bool
6780 schema_data : SchemaData
6881
69- def get_type_string (self , include_constraints : bool = True ) -> str :
82+ def get_type_string (self , include_constraints : bool = False ) -> str :
7083 """Get schema typing string in any place"""
7184 type_string = self .schema_data .get_type_string (
7285 include_constraints = include_constraints
@@ -78,24 +91,35 @@ def get_param_type_string(self) -> str:
7891 type_string = self .schema_data .get_param_type_string ()
7992 return type_string if self .required else f"Missing[{ type_string } ]"
8093
81- def get_model_defination (self ) -> str :
82- """Get defination used by model codegen"""
94+ def get_response_type_string (self ) -> str :
95+ type_string = self .schema_data .get_response_type_string ()
96+ return type_string if self .required else f"Missing[{ type_string } ]"
97+
98+ def get_model_definition (self ) -> str :
99+ """Get definition used by model codegen"""
83100 # extract the outermost type constraints to the field
84101 type_ = self .get_type_string (include_constraints = False )
85102 args = self .schema_data ._get_field_args ()
86103 args .update (self ._get_field_args ())
87104 default = self ._get_field_string (args )
88105 return f"{ self .prop_name } : { type_ } = { default } "
89106
90- def get_type_defination (self ) -> str :
91- """Get defination used by types codegen"""
107+ def get_type_definition (self ) -> str :
108+ """Get definition used by types codegen"""
92109 type_ = self .schema_data .get_param_type_string ()
93110 return (
94111 f"{ self .prop_name } : { type_ if self .required else f'NotRequired[{ type_ } ]' } "
95112 )
96113
97- def get_param_defination (self ) -> str :
98- """Get defination used by client codegen"""
114+ def get_response_type_definition (self ) -> str :
115+ """Get definition usede by response types codegen"""
116+ type_ = self .schema_data .get_response_type_string ()
117+ return (
118+ f"{ self .prop_name } : { type_ if self .required else f'NotRequired[{ type_ } ]' } "
119+ )
120+
121+ def get_param_definition (self ) -> str :
122+ """Get definition used by client codegen"""
99123 type_ = self .get_param_type_string ()
100124 return (
101125 (
@@ -177,6 +201,12 @@ def get_using_imports(self) -> set[str]:
177201 imports .add ("from typing import Any" )
178202 return imports
179203
204+ @override
205+ def get_response_imports (self ) -> set [str ]:
206+ imports = super ().get_response_imports ()
207+ imports .add ("from typing import Any" )
208+ return imports
209+
180210
181211@dataclass (kw_only = True )
182212class NoneSchema (SchemaData ):
@@ -264,6 +294,12 @@ def _get_field_args(self) -> dict[str, str]:
264294class DateTimeSchema (SchemaData ):
265295 _type_string : ClassVar [str ] = "datetime"
266296
297+ @override
298+ def get_response_type_string (self ) -> str :
299+ # datetime field is ISO string in response
300+ # https://github.com/yanyongyu/githubkit/issues/246
301+ return "str"
302+
267303 @override
268304 def get_model_imports (self ) -> set [str ]:
269305 imports = super ().get_model_imports ()
@@ -288,11 +324,23 @@ def get_using_imports(self) -> set[str]:
288324 imports .add ("from datetime import datetime" )
289325 return imports
290326
327+ @override
328+ def get_response_imports (self ) -> set [str ]:
329+ imports = super ().get_response_imports ()
330+ imports .add ("from datetime import datetime" )
331+ return imports
332+
291333
292334@dataclass (kw_only = True )
293335class DateSchema (SchemaData ):
294336 _type_string : ClassVar [str ] = "date"
295337
338+ @override
339+ def get_response_type_string (self ) -> str :
340+ # date field is ISO string in response
341+ # https://github.com/yanyongyu/githubkit/issues/246
342+ return "str"
343+
296344 @override
297345 def get_model_imports (self ) -> set [str ]:
298346 imports = super ().get_model_imports ()
@@ -317,6 +365,12 @@ def get_using_imports(self) -> set[str]:
317365 imports .add ("from datetime import date" )
318366 return imports
319367
368+ @override
369+ def get_response_imports (self ) -> set [str ]:
370+ imports = super ().get_response_imports ()
371+ imports .add ("from datetime import date" )
372+ return imports
373+
320374
321375@dataclass (kw_only = True )
322376class FileSchema (SchemaData ):
@@ -346,6 +400,12 @@ def get_using_imports(self) -> set[str]:
346400 imports .add ("from githubkit.typing import FileTypes" )
347401 return imports
348402
403+ @override
404+ def get_response_imports (self ) -> set [str ]:
405+ imports = super ().get_response_imports ()
406+ imports .add ("from githubkit.typing import FileTypes" )
407+ return imports
408+
349409
350410@dataclass (kw_only = True )
351411class ListSchema (SchemaData ):
@@ -366,6 +426,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
366426 def get_param_type_string (self ) -> str :
367427 return f"list[{ self .item_schema .get_param_type_string ()} ]"
368428
429+ @override
430+ def get_response_type_string (self ) -> str :
431+ return f"list[{ self .item_schema .get_response_type_string ()} ]"
432+
369433 @override
370434 def get_model_imports (self ) -> set [str ]:
371435 imports = super ().get_model_imports ()
@@ -392,6 +456,13 @@ def get_using_imports(self) -> set[str]:
392456 imports .update (self .item_schema .get_using_imports ())
393457 return imports
394458
459+ @override
460+ def get_response_imports (self ) -> set [str ]:
461+ imports = super ().get_response_imports ()
462+ imports .add ("from githubkit.compat import PYDANTIC_V2" )
463+ imports .update (self .item_schema .get_response_imports ())
464+ return imports
465+
395466 @override
396467 def _get_field_args (self ) -> dict [str , str ]:
397468 args = super ()._get_field_args ()
@@ -433,6 +504,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
433504 def get_param_type_string (self ) -> str :
434505 return f"UniqueList[{ self .item_schema .get_param_type_string ()} ]"
435506
507+ @override
508+ def get_response_type_string (self ) -> str :
509+ return f"UniqueList[{ self .item_schema .get_response_type_string ()} ]"
510+
436511 @override
437512 def get_model_imports (self ) -> set [str ]:
438513 imports = super ().get_model_imports ()
@@ -462,6 +537,13 @@ def get_using_imports(self) -> set[str]:
462537 imports .update (self .item_schema .get_using_imports ())
463538 return imports
464539
540+ @override
541+ def get_response_imports (self ) -> set [str ]:
542+ # imports = super().get_response_imports()
543+ imports = {"from githubkit.typing import UniqueList" }
544+ imports .update (self .item_schema .get_response_imports ())
545+ return imports
546+
465547 @override
466548 def _get_field_args (self ) -> dict [str , str ]:
467549 args = super ()._get_field_args ()
@@ -511,6 +593,10 @@ def get_type_string(self, include_constraints: bool = True) -> str:
511593 def get_param_type_string (self ) -> str :
512594 return f"Literal[{ ', ' .join (repr (value ) for value in self .values )} ]"
513595
596+ @override
597+ def get_response_type_string (self ) -> str :
598+ return self .get_param_type_string ()
599+
514600 @override
515601 def get_model_imports (self ) -> set [str ]:
516602 imports = super ().get_model_imports ()
@@ -535,6 +621,12 @@ def get_using_imports(self) -> set[str]:
535621 imports .add ("from typing import Literal" )
536622 return imports
537623
624+ @override
625+ def get_response_imports (self ) -> set [str ]:
626+ imports = super ().get_response_imports ()
627+ imports .add ("from typing import Literal" )
628+ return imports
629+
538630
539631@dataclass (kw_only = True )
540632class ModelSchema (SchemaData ):
@@ -552,8 +644,46 @@ def get_type_string(self, include_constraints: bool = True) -> str:
552644
553645 @override
554646 def get_param_type_string (self ) -> str :
647+ """Get type string used by model type class name and client request codegen.
648+
649+ Example:
650+
651+ ```python
652+ class ModelType(TypedDict):
653+ ...
654+
655+ class Client:
656+ def create_xxx(
657+ *,
658+ data: ModelType,
659+ ) -> Response[Model, ModelResponseType]:
660+ ...
661+ ```
662+ """
555663 return f"{ self .class_name } Type"
556664
665+ @override
666+ def get_response_type_string (self ) -> str :
667+ """Get type string used by model resposne type class name
668+ and client response codegen.
669+
670+ Example:
671+
672+ ```python
673+ class ModelResponseType(TypedDict):
674+ ...
675+
676+ class Client:
677+ def create_xxx(
678+ *,
679+ data: ModelType,
680+ ) -> Response[Model, ModelResponseType]:
681+ ...
682+ ```
683+ """
684+ # `XXXResponseType` has name conflicts in definition
685+ return f"{ self .class_name } TypeForResponse"
686+
557687 @override
558688 def get_model_imports (self ) -> set [str ]:
559689 imports = super ().get_model_imports ()
@@ -583,6 +713,10 @@ def get_param_imports(self) -> set[str]:
583713 def get_using_imports (self ) -> set [str ]:
584714 return {f"from ..models import { self .class_name } " }
585715
716+ @override
717+ def get_response_imports (self ) -> set [str ]:
718+ return {f"from ..types import { self .get_response_type_string ()} " }
719+
586720 @override
587721 def get_model_dependencies (self ) -> list ["ModelSchema" ]:
588722 result : list [ModelSchema ] = []
@@ -624,6 +758,15 @@ def get_param_type_string(self) -> str:
624758 types = ", " .join (schema .get_param_type_string () for schema in self .schemas )
625759 return f"Union[{ types } ]"
626760
761+ @override
762+ def get_response_type_string (self ) -> str :
763+ if len (self .schemas ) == 0 :
764+ return "Any"
765+ elif len (self .schemas ) == 1 :
766+ return self .schemas [0 ].get_response_type_string ()
767+ types = ", " .join (schema .get_response_type_string () for schema in self .schemas )
768+ return f"Union[{ types } ]"
769+
627770 @override
628771 def get_model_imports (self ) -> set [str ]:
629772 imports = super ().get_model_imports ()
@@ -656,6 +799,15 @@ def get_using_imports(self) -> set[str]:
656799 imports .update (schema .get_using_imports ())
657800 return imports
658801
802+ @override
803+ def get_response_imports (self ) -> set [str ]:
804+ imports = super ().get_response_imports ()
805+ imports .add ("from typing import Union" )
806+ for schema in self .schemas :
807+ imports .update (schema .get_response_imports ())
808+ return imports
809+
810+ @override
659811 def _get_field_args (self ) -> dict [str , str ]:
660812 args = super ()._get_field_args ()
661813 if self .discriminator :
0 commit comments