|
89 | 89 | .. class:: SerializedContainer |
90 | 90 |
|
91 | 91 | :canonical: arraycontext.SerializedContainer |
| 92 | +
|
| 93 | +References |
| 94 | +---------- |
| 95 | +
|
| 96 | +.. class:: GenericAlias |
| 97 | +
|
| 98 | + See :class:`types.GenericAlias`. |
| 99 | +
|
| 100 | +.. class:: UnionType |
| 101 | +
|
| 102 | + See :class:`types.UnionType`. |
92 | 103 | """ |
93 | 104 |
|
94 | 105 | from __future__ import annotations |
|
120 | 131 |
|
121 | 132 | from collections.abc import Hashable, Sequence |
122 | 133 | from functools import singledispatch |
123 | | -from types import GenericAlias, UnionType |
124 | 134 | from typing import ( |
125 | 135 | TYPE_CHECKING, |
126 | 136 | TypeAlias, |
|
133 | 143 | import numpy as np |
134 | 144 | from typing_extensions import TypeIs |
135 | 145 |
|
136 | | -from pytools.obj_array import ObjectArrayND as ObjectArrayND |
| 146 | +from pytools.obj_array import ObjectArray, ObjectArrayND as ObjectArrayND |
137 | 147 |
|
138 | 148 | from arraycontext.typing import ( |
| 149 | + ArithArrayContainer, |
139 | 150 | ArrayContainer, |
140 | 151 | ArrayContainerT, |
141 | 152 | ArrayOrArithContainer, |
142 | 153 | ArrayOrArithContainerOrScalar as ArrayOrArithContainerOrScalar, |
143 | 154 | ArrayOrContainerOrScalar, |
| 155 | + _UserDefinedArithArrayContainer, |
| 156 | + _UserDefinedArrayContainer, |
144 | 157 | ) |
145 | 158 |
|
146 | 159 |
|
147 | 160 | if TYPE_CHECKING: |
| 161 | + from types import GenericAlias, UnionType |
| 162 | + |
148 | 163 | from pymbolic.geometric_algebra import CoeffT, MultiVector |
149 | 164 |
|
150 | 165 | from arraycontext.context import ArrayContext |
@@ -217,17 +232,21 @@ def is_array_container_type(cls: type | GenericAlias | UnionType) -> bool: |
217 | 232 | function will say that :class:`numpy.ndarray` is an array container |
218 | 233 | type, only object arrays *actually are* array containers. |
219 | 234 | """ |
220 | | - if cls is ArrayContainer: |
| 235 | + if cls is ArrayContainer or cls is ArithArrayContainer: |
221 | 236 | return True |
222 | 237 |
|
223 | | - while isinstance(cls, GenericAlias): |
224 | | - cls = get_origin(cls) |
| 238 | + origin = get_origin(cls) |
| 239 | + if origin is not None: |
| 240 | + cls = origin # pyright: ignore[reportAny] |
225 | 241 |
|
226 | 242 | assert isinstance(cls, type), ( |
227 | 243 | f"must pass a {type!r}, not a '{cls!r}'") |
228 | 244 |
|
229 | 245 | return ( |
230 | | - cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison] |
| 246 | + cls is ObjectArray |
| 247 | + or cls is ArrayContainer # pyright: ignore[reportUnnecessaryComparison] |
| 248 | + or cls is _UserDefinedArrayContainer |
| 249 | + or cls is _UserDefinedArithArrayContainer |
231 | 250 | or (serialize_container.dispatch(cls) |
232 | 251 | is not serialize_container.__wrapped__)) # type:ignore[attr-defined] |
233 | 252 |
|
|
0 commit comments