|
22 | 22 | import types |
23 | 23 | from dataclasses import dataclass |
24 | 24 | from pathlib import Path |
25 | | -from typing import Any, Callable, Optional, Union, get_args, get_origin |
| 25 | +from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints |
26 | 26 |
|
27 | 27 | import httpx |
28 | 28 | import numpy as np |
@@ -1815,15 +1815,78 @@ def _get_signature_keys(cls, obj): |
1815 | 1815 | @classmethod |
1816 | 1816 | def _get_signature_types(cls): |
1817 | 1817 | signature_types = {} |
1818 | | - for k, v in inspect.signature(cls.__init__).parameters.items(): |
1819 | | - if inspect.isclass(v.annotation): |
1820 | | - signature_types[k] = (v.annotation,) |
1821 | | - elif get_origin(v.annotation) in [Union, types.UnionType]: |
1822 | | - signature_types[k] = get_args(v.annotation) |
1823 | | - elif get_origin(v.annotation) in [list, dict]: |
1824 | | - signature_types[k] = (v.annotation,) |
1825 | | - else: |
1826 | | - logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.") |
| 1818 | + module_globals = sys.modules.get(cls.__module__, {}).__dict__ if cls.__module__ in sys.modules else {} |
| 1819 | + localns = dict(vars(cls)) |
| 1820 | + |
| 1821 | + try: |
| 1822 | + type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns, include_extras=True) |
| 1823 | + except TypeError: |
| 1824 | + type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns) |
| 1825 | + except Exception as exc: |
| 1826 | + logger.debug("Failed to resolve type hints for %s.__init__: %s", cls.__name__, exc) |
| 1827 | + type_hints = {} |
| 1828 | + |
| 1829 | + def _is_union(annotation: Any) -> bool: |
| 1830 | + origin = get_origin(annotation) |
| 1831 | + union_type = getattr(types, "UnionType", None) |
| 1832 | + if origin in (Union, union_type): |
| 1833 | + return True |
| 1834 | + return union_type is not None and isinstance(annotation, union_type) |
| 1835 | + |
| 1836 | + def _normalize_annotation(annotation: Any) -> tuple[type, ...]: |
| 1837 | + if annotation in (inspect._empty, None) or annotation is Any: |
| 1838 | + return () |
| 1839 | + |
| 1840 | + if inspect.isclass(annotation): |
| 1841 | + return (annotation,) |
| 1842 | + |
| 1843 | + if _is_union(annotation): |
| 1844 | + collected: list[type] = [] |
| 1845 | + for arg in get_args(annotation): |
| 1846 | + collected.extend(_normalize_annotation(arg)) |
| 1847 | + # preserve order while removing duplicates |
| 1848 | + unique: list[type] = [] |
| 1849 | + seen: set[type] = set() |
| 1850 | + for item in collected: |
| 1851 | + if item not in seen: |
| 1852 | + seen.add(item) |
| 1853 | + unique.append(item) |
| 1854 | + return tuple(unique) |
| 1855 | + |
| 1856 | + origin = get_origin(annotation) |
| 1857 | + if origin is not None: |
| 1858 | + if getattr(origin, "__qualname__", "") == "Annotated": |
| 1859 | + args = get_args(annotation) |
| 1860 | + return _normalize_annotation(args[0]) if args else () |
| 1861 | + if getattr(origin, "__qualname__", "") == "Literal": |
| 1862 | + return () |
| 1863 | + if inspect.isclass(origin): |
| 1864 | + return (origin,) |
| 1865 | + |
| 1866 | + return () |
| 1867 | + |
| 1868 | + for name, parameter in inspect.signature(cls.__init__).parameters.items(): |
| 1869 | + if name == "self": |
| 1870 | + continue |
| 1871 | + |
| 1872 | + annotation = type_hints.get(name, parameter.annotation) |
| 1873 | + |
| 1874 | + if isinstance(annotation, str): |
| 1875 | + try: |
| 1876 | + annotation = eval(annotation, module_globals, localns) # noqa: S307 |
| 1877 | + except Exception as exc: # noqa: BLE001 |
| 1878 | + logger.debug( |
| 1879 | + "Failed to evaluate forward reference %r on %s.%s: %s", annotation, cls.__name__, name, exc |
| 1880 | + ) |
| 1881 | + annotation = inspect._empty |
| 1882 | + |
| 1883 | + normalized = _normalize_annotation(annotation) |
| 1884 | + |
| 1885 | + if normalized: |
| 1886 | + signature_types[name] = normalized |
| 1887 | + elif annotation not in (inspect._empty, None, Any): |
| 1888 | + logger.warning(f"cannot get type annotation for Parameter {name} of {cls}.") |
| 1889 | + |
1827 | 1890 | return signature_types |
1828 | 1891 |
|
1829 | 1892 | @property |
|
0 commit comments