|
4 | 4 | from typing import Generic |
5 | 5 | from typing import Optional |
6 | 6 | from typing import TypeVar |
7 | | -from typing import Union |
8 | 7 | from typing import get_args |
9 | 8 | from typing import get_origin |
10 | 9 |
|
|
24 | 23 | from ..base import Uniqueness |
25 | 24 | from ..base import URIReference |
26 | 25 | from ..base import is_complex_attribute |
| 26 | +from ..utils import UNION_TYPES |
27 | 27 | from ..utils import normalize_attribute_name |
28 | 28 |
|
29 | 29 |
|
@@ -117,7 +117,7 @@ def __new__(cls, name, bases, attrs, **kwargs): |
117 | 117 | extensions = kwargs["__pydantic_generic_metadata__"]["args"][0] |
118 | 118 | extensions = ( |
119 | 119 | get_args(extensions) |
120 | | - if get_origin(extensions) == Union |
| 120 | + if get_origin(extensions) in UNION_TYPES |
121 | 121 | else [extensions] |
122 | 122 | ) |
123 | 123 | for extension in extensions: |
@@ -183,7 +183,8 @@ def get_extension_models(cls) -> dict[str, type[Extension]]: |
183 | 183 | extension_models = cls.__pydantic_generic_metadata__.get("args", []) |
184 | 184 | extension_models = ( |
185 | 185 | get_args(extension_models[0]) |
186 | | - if len(extension_models) == 1 and get_origin(extension_models[0]) == Union |
| 186 | + if len(extension_models) == 1 |
| 187 | + and get_origin(extension_models[0]) in UNION_TYPES |
187 | 188 | else extension_models |
188 | 189 | ) |
189 | 190 |
|
@@ -301,7 +302,7 @@ def model_to_schema(model: type[BaseModel]): |
301 | 302 |
|
302 | 303 | def get_reference_types(type) -> list[str]: |
303 | 304 | first_arg = get_args(type)[0] |
304 | | - types = get_args(first_arg) if get_origin(first_arg) == Union else [first_arg] |
| 305 | + types = get_args(first_arg) if get_origin(first_arg) in UNION_TYPES else [first_arg] |
305 | 306 |
|
306 | 307 | def serialize_ref_type(ref_type): |
307 | 308 | if ref_type == URIReference: |
|
0 commit comments