Skip to content

Commit 3a00e23

Browse files
committed
up
1 parent 19fe631 commit 3a00e23

1 file changed

Lines changed: 73 additions & 10 deletions

File tree

src/diffusers/pipelines/pipeline_utils.py

Lines changed: 73 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import types
2323
from dataclasses import dataclass
2424
from pathlib import Path
25-
from typing import Any, Callable, Optional, Union, get_args, get_origin
25+
from typing import Any, Callable, Optional, Union, get_args, get_origin, get_type_hints
2626

2727
import httpx
2828
import numpy as np
@@ -1815,15 +1815,78 @@ def _get_signature_keys(cls, obj):
18151815
@classmethod
18161816
def _get_signature_types(cls):
18171817
signature_types = {}
1818-
for k, v in inspect.signature(cls.__init__).parameters.items():
1819-
if inspect.isclass(v.annotation):
1820-
signature_types[k] = (v.annotation,)
1821-
elif get_origin(v.annotation) in [Union, types.UnionType]:
1822-
signature_types[k] = get_args(v.annotation)
1823-
elif get_origin(v.annotation) in [list, dict]:
1824-
signature_types[k] = (v.annotation,)
1825-
else:
1826-
logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
1818+
module_globals = sys.modules.get(cls.__module__, {}).__dict__ if cls.__module__ in sys.modules else {}
1819+
localns = dict(vars(cls))
1820+
1821+
try:
1822+
type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns, include_extras=True)
1823+
except TypeError:
1824+
type_hints = get_type_hints(cls.__init__, globalns=module_globals, localns=localns)
1825+
except Exception as exc:
1826+
logger.debug("Failed to resolve type hints for %s.__init__: %s", cls.__name__, exc)
1827+
type_hints = {}
1828+
1829+
def _is_union(annotation: Any) -> bool:
1830+
origin = get_origin(annotation)
1831+
union_type = getattr(types, "UnionType", None)
1832+
if origin in (Union, union_type):
1833+
return True
1834+
return union_type is not None and isinstance(annotation, union_type)
1835+
1836+
def _normalize_annotation(annotation: Any) -> tuple[type, ...]:
1837+
if annotation in (inspect._empty, None) or annotation is Any:
1838+
return ()
1839+
1840+
if inspect.isclass(annotation):
1841+
return (annotation,)
1842+
1843+
if _is_union(annotation):
1844+
collected: list[type] = []
1845+
for arg in get_args(annotation):
1846+
collected.extend(_normalize_annotation(arg))
1847+
# preserve order while removing duplicates
1848+
unique: list[type] = []
1849+
seen: set[type] = set()
1850+
for item in collected:
1851+
if item not in seen:
1852+
seen.add(item)
1853+
unique.append(item)
1854+
return tuple(unique)
1855+
1856+
origin = get_origin(annotation)
1857+
if origin is not None:
1858+
if getattr(origin, "__qualname__", "") == "Annotated":
1859+
args = get_args(annotation)
1860+
return _normalize_annotation(args[0]) if args else ()
1861+
if getattr(origin, "__qualname__", "") == "Literal":
1862+
return ()
1863+
if inspect.isclass(origin):
1864+
return (origin,)
1865+
1866+
return ()
1867+
1868+
for name, parameter in inspect.signature(cls.__init__).parameters.items():
1869+
if name == "self":
1870+
continue
1871+
1872+
annotation = type_hints.get(name, parameter.annotation)
1873+
1874+
if isinstance(annotation, str):
1875+
try:
1876+
annotation = eval(annotation, module_globals, localns) # noqa: S307
1877+
except Exception as exc: # noqa: BLE001
1878+
logger.debug(
1879+
"Failed to evaluate forward reference %r on %s.%s: %s", annotation, cls.__name__, name, exc
1880+
)
1881+
annotation = inspect._empty
1882+
1883+
normalized = _normalize_annotation(annotation)
1884+
1885+
if normalized:
1886+
signature_types[name] = normalized
1887+
elif annotation not in (inspect._empty, None, Any):
1888+
logger.warning(f"cannot get type annotation for Parameter {name} of {cls}.")
1889+
18271890
return signature_types
18281891

18291892
@property

0 commit comments

Comments
 (0)