Skip to content

Commit adc8bcc

Browse files
authored
Add GenericEncoder and Pydantic support (#63)
* Add GenericEncoder and Pydantic support * Change default encoder clients * Update type_util.py * Add tests for pydantic models & increase coverage * Fix mypy issues * Add codegen for pydantic models * Add is_allowed_type to encoder * Add pydantic model to test_codegen * Create test_pydantic_v1.py * Add GenericEncoder test with pydantic v1 * Update Makefile
1 parent d19eaac commit adc8bcc

18 files changed

Lines changed: 217 additions & 41 deletions

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ format:
2929
install-lint:
3030
python -m pip install --upgrade pip
3131
pip install -r requirements.txt # needed for pytype
32-
pip install black isort flake8 pylint pytype mypy
32+
pip install black isort flake8 pylint pytype mypy pydantic>=2.0
3333

3434
lint:
3535
flake8 ./zero

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
* Zero uses messages for communication and traditional **client-server** or **request-reply** pattern is supported.
3232
* Support for both **async** and **sync**.
3333
* The base server (ZeroServer) **utilizes all cpu cores**.
34+
* Built-in support for Pydantic.
3435
* **Code generation**! See [example](https://github.com/Ananto30/zero#code-generation-) 👇
3536

3637
**Philosophy** behind Zero:

setup.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,7 @@
2222
python_requires=">=3.8",
2323
package_dir={"": "."},
2424
install_requires=["pyzmq", "msgspec"],
25+
extras_require={
26+
"pydantic": ["pydantic"], # Optional dependency
27+
},
2528
)

tests/functional/codegen/test_codegen.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
import msgspec
1313
from msgspec import Struct
14+
from pydantic import BaseModel
1415

1516
from zero.codegen.codegen import CodeGen
1617

@@ -63,6 +64,11 @@ class SimpleIntEnum(enum.IntEnum):
6364
TWO = 2
6465

6566

67+
class SimplePydanticModel(BaseModel):
68+
a: int
69+
b: str
70+
71+
6672
def func_none(arg: None) -> str:
6773
return "Received None"
6874

@@ -213,6 +219,14 @@ def func_take_optional_child_dataclass_return_optional_child_complex_struct(
213219
return None
214220

215221

222+
def func_pydantic_model(arg: SimplePydanticModel) -> str:
223+
return f"Received Pydantic model: {arg}"
224+
225+
226+
def func_return_pydantic_model() -> SimplePydanticModel:
227+
return SimplePydanticModel(a=1, b="hello")
228+
229+
216230
class TestCodegen(unittest.TestCase):
217231
def setUp(self) -> None:
218232
self.maxDiff = None
@@ -250,6 +264,8 @@ def setUp(self) -> None:
250264
"func_msgspec_struct_complex": (func_msgspec_struct_complex, False),
251265
"func_child_complex_struct": (func_child_complex_struct, False),
252266
"func_return_complex_struct": (func_return_complex_struct, False),
267+
"func_pydantic_model": (func_pydantic_model, False),
268+
"func_return_pydantic_model": (func_return_pydantic_model, False),
253269
}
254270
self._rpc_input_type_map = {
255271
"func_none": None,
@@ -285,6 +301,8 @@ def setUp(self) -> None:
285301
"func_msgspec_struct_complex": ComplexStruct,
286302
"func_child_complex_struct": ChildComplexStruct,
287303
"func_return_complex_struct": None,
304+
"func_pydantic_model": SimplePydanticModel,
305+
"func_return_pydantic_model": None,
288306
}
289307
self._rpc_return_type_map = {
290308
"func_none": str,
@@ -320,6 +338,8 @@ def setUp(self) -> None:
320338
"func_msgspec_struct_complex": str,
321339
"func_child_complex_struct": str,
322340
"func_return_complex_struct": ComplexStruct,
341+
"func_pydantic_model": str,
342+
"func_return_pydantic_model": SimplePydanticModel,
323343
}
324344

325345
def test_codegen(self):
@@ -335,6 +355,7 @@ def test_codegen(self):
335355
import enum
336356
import msgspec
337357
from msgspec import Struct
358+
from pydantic import BaseModel
338359
from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Union
339360
import uuid
340361
@@ -385,6 +406,11 @@ class ChildComplexStruct(ComplexStruct):
385406
i: str
386407
387408
409+
class SimplePydanticModel(BaseModel):
410+
a: int
411+
b: str
412+
413+
388414
389415
class RpcClient:
390416
def __init__(self, zero_client: ZeroClient):
@@ -488,6 +514,12 @@ def func_child_complex_struct(self, arg: ChildComplexStruct) -> str:
488514
489515
def func_return_complex_struct(self) -> ComplexStruct:
490516
return self._zero_client.call("func_return_complex_struct", None)
517+
518+
def func_pydantic_model(self, arg: SimplePydanticModel) -> str:
519+
return self._zero_client.call("func_pydantic_model", arg)
520+
521+
def func_return_pydantic_model(self) -> SimplePydanticModel:
522+
return self._zero_client.call("func_return_pydantic_model", None)
491523
"""
492524
self.assertEqual(code, expected_code)
493525

tests/functional/single_server/client_generation_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def test_codegeneration():
2323
import decimal
2424
import enum
2525
import msgspec
26+
from pydantic import BaseModel
2627
from typing import Dict, FrozenSet, List, Optional, Set, Tuple, Union
2728
import uuid
2829
@@ -49,6 +50,11 @@ class Dataclass:
4950
age: int
5051
5152
53+
class PydanticModel(BaseModel):
54+
name: str
55+
age: int
56+
57+
5258
class Message(msgspec.Struct):
5359
msg: str
5460
start_time: datetime
@@ -116,6 +122,9 @@ def echo_enum_int(self, msg: ColorInt) -> ColorInt:
116122
def echo_dataclass(self, msg: Dataclass) -> Dataclass:
117123
return self._zero_client.call("echo_dataclass", msg)
118124
125+
def echo_pydantic(self, msg: PydanticModel) -> PydanticModel:
126+
return self._zero_client.call("echo_pydantic", msg)
127+
119128
def echo_typing_tuple(self, msg: Tuple[int, str]) -> Tuple[int, str]:
120129
return self._zero_client.call("echo_typing_tuple", msg)
121130

tests/functional/single_server/client_server_test.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,13 @@ def test_echo_dataclass(zero_client):
131131
assert result == data
132132

133133

134+
# pydantic input
135+
def test_echo_pydantic(zero_client):
136+
data = server.PydanticModel(name="John", age=30)
137+
result = zero_client.call("echo_pydantic", data, return_type=server.PydanticModel)
138+
assert result == data
139+
140+
134141
# typing.Tuple input
135142
def test_echo_typing_tuple(zero_client):
136143
assert zero_client.call(

tests/functional/single_server/server.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import jwt
1111
import msgspec
12+
from pydantic import BaseModel
1213

1314
from zero import ZeroServer
1415

@@ -155,6 +156,17 @@ def echo_dataclass(msg: Dataclass) -> Dataclass:
155156
return msg
156157

157158

159+
# pydantic input
160+
class PydanticModel(BaseModel):
161+
name: str
162+
age: int
163+
164+
165+
@app.register_rpc
166+
def echo_pydantic(msg: PydanticModel) -> PydanticModel:
167+
return msg
168+
169+
158170
# typing.Tuple input
159171
@app.register_rpc
160172
def echo_typing_tuple(msg: typing.Tuple[int, str]) -> typing.Tuple[int, str]:
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
import sys
2+
import importlib
3+
from typing import Iterator
4+
import pytest
5+
6+
7+
@pytest.fixture
8+
def patch_pydantic_to_v1(monkeypatch: pytest.MonkeyPatch) -> Iterator[None]:
9+
import pydantic.v1
10+
11+
# Patch sys.modules so any `import pydantic` gives you `pydantic.v1`
12+
monkeypatch.setitem(sys.modules, "pydantic", pydantic.v1)
13+
importlib.invalidate_caches()
14+
15+
yield
16+
17+
# Clean up after test
18+
importlib.invalidate_caches()
19+
20+
21+
def test_module_with_pydantic_v1(patch_pydantic_to_v1: None) -> None:
22+
# Re-import your module so it sees `pydantic` as v1
23+
from zero.encoder import generic
24+
25+
importlib.reload(generic)
26+
27+
# Now run assertions that rely on v1 behavior
28+
assert not generic.IS_PYDANTIC_V2
29+
30+
from pydantic import BaseModel
31+
32+
class TestModel(BaseModel):
33+
name: str
34+
age: int
35+
36+
encoder = generic.GenericEncoder()
37+
model_instance = TestModel(name="Alice", age=30)
38+
encoded_data = encoder.encode(model_instance)
39+
decoded_instance = encoder.decode_type(encoded_data, TestModel)
40+
assert decoded_instance.name == "Alice"
41+
assert decoded_instance.age == 30

tests/requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
pyzmq
22
msgspec
3+
pydantic>=2.0
34
pytest
45
pytest-cov
56
PyJWT
67
pytest-asyncio
78
tornado>=6.1
89
requests
9-
pytest-timeout
10+
pytest-timeout

tests/unit/test_server.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22
import unittest
3-
from typing import Any, Tuple
3+
from typing import Any, Tuple, Type
44
from unittest.mock import patch
55

66
# import pytest
@@ -71,9 +71,12 @@ def encode(self, message: Any) -> bytes:
7171
def decode(self, message: bytes) -> Any:
7272
return message
7373

74-
def decode_type(self, message: bytes, typ: Any) -> Any:
74+
def decode_type(self, message: bytes, typ: Type[Any]) -> Any:
7575
return message
7676

77+
def is_allowed_type(self, typ: Type) -> bool:
78+
return True
79+
7780
encoder = CustomEncoder()
7881

7982
server = ZeroServer(encoder=encoder)
@@ -94,9 +97,12 @@ def encode(self, message: Any) -> bytes:
9497
def decode(self, message: bytes) -> Any:
9598
return message
9699

97-
def decode_type(self, message: bytes, typ: Any) -> Any:
100+
def decode_type(self, message: bytes, typ: Type[Any]) -> Any:
98101
return message
99102

103+
def is_allowed_type(self, typ: Type) -> bool:
104+
return True
105+
100106
encoder = CustomEncoder()
101107
port = 5562
102108

@@ -118,9 +124,12 @@ def encode(self, message: Any) -> bytes:
118124
def decode(self, message: bytes) -> Any:
119125
return message
120126

121-
def decode_type(self, message: bytes, typ: Any) -> Any:
127+
def decode_type(self, message: bytes, typ: Type[Any]) -> Any:
122128
return message
123129

130+
def is_allowed_type(self, typ: Type) -> bool:
131+
return True
132+
124133
encoder = CustomEncoder()
125134
host = "123.0.0.123"
126135

@@ -142,9 +151,12 @@ def encode(self, message: Any) -> bytes:
142151
def decode(self, message: bytes) -> Any:
143152
return message
144153

145-
def decode_type(self, message: bytes, typ: Any) -> Any:
154+
def decode_type(self, message: bytes, typ: Type[Any]) -> Any:
146155
return message
147156

157+
def is_allowed_type(self, typ: Type) -> bool:
158+
return True
159+
148160
encoder = CustomEncoder()
149161
host = "123.0.0.123"
150162
port = 5563
@@ -236,7 +248,7 @@ def test_register_rpc_with_long_name(self):
236248

237249
@server.register_rpc
238250
def add_this_is_a_very_long_name_for_a_function_more_than_120_characters_ffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffff(
239-
msg: Tuple[int, int]
251+
msg: Tuple[int, int],
240252
) -> int:
241253
return msg[0] + msg[1]
242254

0 commit comments

Comments
 (0)