|
1 | 1 | import logging |
| 2 | +from collections import OrderedDict |
| 3 | +from dataclasses import dataclass |
| 4 | +from datetime import date, datetime, time, timezone |
| 5 | +from decimal import Decimal |
| 6 | +from enum import Enum, IntEnum |
| 7 | +from uuid import UUID |
2 | 8 |
|
3 | 9 | import pytest |
4 | 10 |
|
|
16 | 22 | not _AVRO_AVAILABLE, reason="avro package not installed" |
17 | 23 | ) |
18 | 24 |
|
| 25 | +try: |
| 26 | + from enum import StrEnum |
| 27 | +except ImportError: # pragma: no cover - Python < 3.11 compatibility |
| 28 | + StrEnum = None # type: ignore[assignment,misc] |
| 29 | + |
| 30 | + |
| 31 | +@dataclass |
| 32 | +class SerializerDataclass: |
| 33 | + name: str |
| 34 | + count: int |
| 35 | + |
| 36 | + |
| 37 | +class SerializerEnum(Enum): |
| 38 | + PENDING = "pending" |
| 39 | + |
| 40 | + |
| 41 | +class SerializerIntEnum(IntEnum): |
| 42 | + LOW = 1 |
| 43 | + |
| 44 | + |
| 45 | +if StrEnum is not None: |
| 46 | + |
| 47 | + class SerializerStrEnum(StrEnum): |
| 48 | + HIGH = "high" |
| 49 | +else: |
| 50 | + SerializerStrEnum = None |
| 51 | + |
19 | 52 |
|
20 | 53 | class TestEncode: |
21 | 54 | def test_list(self) -> None: |
@@ -170,6 +203,40 @@ def test_round_trip_containers(self) -> None: |
170 | 203 | blob = serializer.encode(value, codec="avro") |
171 | 204 | assert serializer.decode(blob, codec="avro") == value |
172 | 205 |
|
| 206 | + @pytest.mark.parametrize( |
| 207 | + "value", |
| 208 | + [ |
| 209 | + SerializerDataclass(name="Ada", count=2), |
| 210 | + datetime(2026, 4, 21, 10, 30, tzinfo=timezone.utc), |
| 211 | + date(2026, 4, 21), |
| 212 | + time(10, 30, tzinfo=timezone.utc), |
| 213 | + UUID("12345678-1234-5678-1234-567812345678"), |
| 214 | + Decimal("10.25"), |
| 215 | + SerializerEnum.PENDING, |
| 216 | + ], |
| 217 | + ) |
| 218 | + def test_json_unsupported_user_types_fail_encode(self, value: object) -> None: |
| 219 | + with pytest.raises(TypeError, match="not JSON serializable"): |
| 220 | + serializer.encode(value, codec="avro") |
| 221 | + |
| 222 | + def test_ordered_dict_decodes_as_plain_dict(self) -> None: |
| 223 | + value = OrderedDict([("first", 1), ("second", 2)]) |
| 224 | + decoded = serializer.decode(serializer.encode(value, codec="avro"), codec="avro") |
| 225 | + assert decoded == {"first": 1, "second": 2} |
| 226 | + assert type(decoded) is dict |
| 227 | + |
| 228 | + def test_int_enum_decodes_as_int(self) -> None: |
| 229 | + decoded = serializer.decode(serializer.encode(SerializerIntEnum.LOW, codec="avro"), codec="avro") |
| 230 | + assert decoded == 1 |
| 231 | + assert type(decoded) is int |
| 232 | + |
| 233 | + @pytest.mark.skipif(StrEnum is None, reason="StrEnum requires Python 3.11+") |
| 234 | + def test_str_enum_decodes_as_str(self) -> None: |
| 235 | + assert SerializerStrEnum is not None |
| 236 | + decoded = serializer.decode(serializer.encode(SerializerStrEnum.HIGH, codec="avro"), codec="avro") |
| 237 | + assert decoded == "high" |
| 238 | + assert type(decoded) is str |
| 239 | + |
173 | 240 | @pytest.mark.parametrize( |
174 | 241 | "name,blob,expected", |
175 | 242 | [(name, blob, expected) for name, (blob, expected) in _PHP_AVRO_FIXTURES.items()], |
|
0 commit comments