diff --git a/src/py_avro_schema/_schemas.py b/src/py_avro_schema/_schemas.py index 808baea..e6a579f 100644 --- a/src/py_avro_schema/_schemas.py +++ b/src/py_avro_schema/_schemas.py @@ -1156,8 +1156,13 @@ def __init__(self, py_type: Type, namespace: Optional[str] = None, options: Opti super().__init__(py_type, namespace=namespace, options=options) py_type = _type_from_annotated(py_type) + # Try to get resolved type hints, but fall back to raw annotations if there are unresolved forward refs + try: + type_hints = get_type_hints(py_type, include_extras=True) + except NameError: + type_hints = py_type.__annotations__ self.py_fields: list[tuple[str, type]] = [] - for k, v in py_type.__annotations__.items(): + for k, v in type_hints.items(): self.py_fields.append((k, v)) # We store __init__ parameters with default values. They can be used as defaults for the record. self.signature_fields = { diff --git a/tests/models/__init__.py b/tests/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/models/forward.py b/tests/models/forward.py new file mode 100644 index 0000000..7a8cba9 --- /dev/null +++ b/tests/models/forward.py @@ -0,0 +1,9 @@ +from __future__ import annotations + +Name = str + + +class PyClass: + """For testing imports with future annotations""" + + name: Name diff --git a/tests/test_plain_class.py b/tests/test_plain_class.py index c8d21de..4f9b32f 100644 --- a/tests/test_plain_class.py +++ b/tests/test_plain_class.py @@ -143,3 +143,20 @@ class PyType: expected = {"fields": [{"name": "details", "type": "string"}], "name": "PyType", "type": "record"} assert_schema(PyType, expected) + + +def test_type_aliases(): + Name = str + + class PyClass: + name: Name + + expected = {"fields": [{"name": "name", "type": "string"}], "name": "PyClass", "type": "record"} + assert_schema(PyClass, expected) + + +def test_type_aliases_future(): + from tests.models.forward import PyClass + + expected = {"fields": [{"name": "name", "type": "string"}], "name": "PyClass", "type": "record"} + assert_schema(PyClass, expected)