Skip to content

Commit ec8f4b4

Browse files
committed
add Forward Ref optional typing
1 parent b97265f commit ec8f4b4

2 files changed

Lines changed: 64 additions & 2 deletions

File tree

src/py_avro_schema/_testing.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@
1616

1717
import dataclasses
1818
import difflib
19-
from typing import Dict, Type, Union
19+
from typing import Dict, List, Type, Union
2020

2121
import avro.schema # type: ignore
2222
import orjson
2323

2424
import py_avro_schema._schemas
2525

2626

27-
def assert_schema(py_type: Type, expected_schema: Union[str, Dict[str, str]], **kwargs) -> None:
27+
def assert_schema(py_type: Type, expected_schema: Union[str, Dict[str, str], List[str]], **kwargs) -> None:
2828
"""Test that the given Python type results in the correct Avro schema"""
2929
if not kwargs.pop("do_auto_namespace", False):
3030
kwargs["options"] = kwargs.get("options", py_avro_schema.Option(0)) | py_avro_schema.Option.NO_AUTO_NAMESPACE

tests/test_primitives.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
import py_avro_schema as pas
3131
import py_avro_schema._schemas
32+
from py_avro_schema._testing import PyType as TestPyType
3233
from py_avro_schema._testing import assert_schema
3334

3435

@@ -418,6 +419,67 @@ def test_optional_str_py310():
418419
assert_schema(py_type, expected)
419420

420421

422+
@pytest.mark.skipif(sys.version_info < (3, 10), reason="Requires Python 3.10+")
423+
def test_optional_forward_ref_py310():
424+
class PyType:
425+
forward_ref: "TestPyType | None"
426+
427+
expected = {
428+
"fields": [
429+
{
430+
"name": "forward_ref",
431+
"type": [
432+
"PyType",
433+
"null",
434+
],
435+
},
436+
],
437+
"name": "PyType",
438+
"type": "record",
439+
}
440+
assert_schema(PyType, expected)
441+
442+
443+
def test_optional_forward_ref_with_union():
444+
class PyType:
445+
forward_ref: Union["TestPyType", None]
446+
447+
expected = {
448+
"fields": [
449+
{
450+
"name": "forward_ref",
451+
"type": [
452+
"PyType",
453+
"null",
454+
],
455+
},
456+
],
457+
"name": "PyType",
458+
"type": "record",
459+
}
460+
assert_schema(PyType, expected)
461+
462+
463+
def test_optional_forward_ref():
464+
class PyType:
465+
forward_ref: Optional["TestPyType"]
466+
467+
expected = {
468+
"fields": [
469+
{
470+
"name": "forward_ref",
471+
"type": [
472+
"PyType",
473+
"null",
474+
],
475+
},
476+
],
477+
"name": "PyType",
478+
"type": "record",
479+
}
480+
assert_schema(PyType, expected)
481+
482+
421483
def test_enum():
422484
class PyType(enum.Enum):
423485
RED = "RED"

0 commit comments

Comments
 (0)