1111
1212import msgspec
1313from msgspec import Struct
14+ from pydantic import BaseModel
1415
1516from zero .codegen .codegen import CodeGen
1617
@@ -63,6 +64,11 @@ class SimpleIntEnum(enum.IntEnum):
6364 TWO = 2
6465
6566
67+ class SimplePydanticModel (BaseModel ):
68+ a : int
69+ b : str
70+
71+
6672def func_none (arg : None ) -> str :
6773 return "Received None"
6874
@@ -213,6 +219,14 @@ def func_take_optional_child_dataclass_return_optional_child_complex_struct(
213219 return None
214220
215221
222+ def func_pydantic_model (arg : SimplePydanticModel ) -> str :
223+ return f"Received Pydantic model: { arg } "
224+
225+
226+ def func_return_pydantic_model () -> SimplePydanticModel :
227+ return SimplePydanticModel (a = 1 , b = "hello" )
228+
229+
216230class TestCodegen (unittest .TestCase ):
217231 def setUp (self ) -> None :
218232 self .maxDiff = None
@@ -250,6 +264,8 @@ def setUp(self) -> None:
250264 "func_msgspec_struct_complex" : (func_msgspec_struct_complex , False ),
251265 "func_child_complex_struct" : (func_child_complex_struct , False ),
252266 "func_return_complex_struct" : (func_return_complex_struct , False ),
267+ "func_pydantic_model" : (func_pydantic_model , False ),
268+ "func_return_pydantic_model" : (func_return_pydantic_model , False ),
253269 }
254270 self ._rpc_input_type_map = {
255271 "func_none" : None ,
@@ -285,6 +301,8 @@ def setUp(self) -> None:
285301 "func_msgspec_struct_complex" : ComplexStruct ,
286302 "func_child_complex_struct" : ChildComplexStruct ,
287303 "func_return_complex_struct" : None ,
304+ "func_pydantic_model" : SimplePydanticModel ,
305+ "func_return_pydantic_model" : None ,
288306 }
289307 self ._rpc_return_type_map = {
290308 "func_none" : str ,
@@ -320,6 +338,8 @@ def setUp(self) -> None:
320338 "func_msgspec_struct_complex" : str ,
321339 "func_child_complex_struct" : str ,
322340 "func_return_complex_struct" : ComplexStruct ,
341+ "func_pydantic_model" : str ,
342+ "func_return_pydantic_model" : SimplePydanticModel ,
323343 }
324344
325345 def test_codegen (self ):
@@ -335,6 +355,7 @@ def test_codegen(self):
335355import enum
336356import msgspec
337357from msgspec import Struct
358+ from pydantic import BaseModel
338359from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Union
339360import uuid
340361
@@ -385,6 +406,11 @@ class ChildComplexStruct(ComplexStruct):
385406 i: str
386407
387408
409+ class SimplePydanticModel(BaseModel):
410+ a: int
411+ b: str
412+
413+
388414
389415class RpcClient:
390416 def __init__(self, zero_client: ZeroClient):
@@ -488,6 +514,12 @@ def func_child_complex_struct(self, arg: ChildComplexStruct) -> str:
488514
489515 def func_return_complex_struct(self) -> ComplexStruct:
490516 return self._zero_client.call("func_return_complex_struct", None)
517+
518+ def func_pydantic_model(self, arg: SimplePydanticModel) -> str:
519+ return self._zero_client.call("func_pydantic_model", arg)
520+
521+ def func_return_pydantic_model(self) -> SimplePydanticModel:
522+ return self._zero_client.call("func_return_pydantic_model", None)
491523"""
492524 self .assertEqual (code , expected_code )
493525
0 commit comments