Skip to content

Commit 77770fb

Browse files
committed
Deprecate np.ndarray annotations in dataclass_array_container
1 parent da79de8 commit 77770fb

1 file changed

Lines changed: 15 additions & 7 deletions

File tree

arraycontext/container/dataclass.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,6 @@
1212
"""
1313
from __future__ import annotations
1414

15-
from pytools.obj_array import ObjectArray
16-
17-
from arraycontext.context import ArrayOrContainer, ArrayOrContainerOrScalar
18-
1915

2016
__copyright__ = """
2117
Copyright (C) 2020-1 University of Illinois Board of Trustees
@@ -54,10 +50,14 @@
5450
get_args,
5551
get_origin,
5652
)
53+
from warnings import warn
5754

5855
import numpy as np
5956

57+
from pytools.obj_array import ObjectArray
58+
6059
from arraycontext.container import ArrayContainer, is_array_container_type
60+
from arraycontext.context import ArrayOrContainer, ArrayOrContainerOrScalar
6161

6262

6363
if TYPE_CHECKING:
@@ -77,7 +77,15 @@ class _Field(NamedTuple):
7777
type: type
7878

7979

80-
def is_array_type(tp: type, /) -> bool:
80+
def _is_array_or_container_type(tp: type, /) -> bool:
81+
if tp is np.ndarray:
82+
warn("Encountered 'numpy.ndarray' in a dataclass_array_container. "
83+
"This is deprecated and will stop working in 2026. "
84+
"If you meant an object array, use pytools.obj_array.ObjectArray. "
85+
"For other uses, file an issue to discuss.",
86+
DeprecationWarning, stacklevel=3)
87+
return True
88+
8189
from arraycontext import Array
8290
return tp is Array or is_array_container_type(tp)
8391

@@ -151,7 +159,7 @@ def is_array_field(f: _Field) -> bool:
151159
if origin in (Union, UnionType): # pyright: ignore[reportDeprecated]
152160
for arg in get_args(field_type): # pyright: ignore[reportAny]
153161
if not (
154-
is_array_type(cast("type", arg))
162+
_is_array_or_container_type(cast("type", arg))
155163
or is_scalar_type(cast("type", arg))):
156164
raise TypeError(
157165
f"Field '{f.name}' union contains non-array container "
@@ -188,7 +196,7 @@ def is_array_field(f: _Field) -> bool:
188196
f"Field '{f.name}' not an instance of 'type': "
189197
f"'{field_type!r}'")
190198

191-
return is_array_type(field_type)
199+
return _is_array_or_container_type(field_type)
192200

193201
from pytools import partition
194202

0 commit comments

Comments
 (0)