Skip to content

Commit 23ab899

Browse files
committed
DEP: deprecate broadcast_shapes, expand_dims
1 parent 3c50efc commit 23ab899

4 files changed

Lines changed: 49 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,8 @@ reportUnusedParameter = false
142142
reportImportCycles = false
143143
# PyRight can't trace types in lambdas
144144
reportUnknownLambdaType = false
145+
# conflicts with https://docs.astral.sh/ruff/rules/explicit-string-concatenation/
146+
reportImplicitStringConcatenation = false
145147

146148
executionEnvironments = [
147149
{ root = "tests", reportPrivateUsage = false, reportUnknownArgumentType = false },

src/array_api_extra/_delegation.py

Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
is_torch_namespace,
1616
)
1717
from ._lib._utils._compat import device as get_device
18-
from ._lib._utils._helpers import asarrays, eager_shape
18+
from ._lib._utils._helpers import asarrays, deprecated, eager_shape
1919
from ._lib._utils._typing import Array, DType
2020

2121
__all__ = [
@@ -83,19 +83,23 @@ def atleast_nd(x: Array, /, *, ndim: int, xp: ModuleType | None = None) -> Array
8383
return _funcs.atleast_nd(x, ndim=ndim, xp=xp)
8484

8585

86+
@deprecated(
87+
"`xpx.broadcast_shapes` is deprecated and will be removed in v1.0.0. "
88+
"`xp.broadcast_shapes` exists in the standard as of v2025.12."
89+
)
8690
def broadcast_shapes(
8791
*shapes: tuple[float | None, ...], xp: ModuleType | None = None
8892
) -> tuple[int | None, ...]:
8993
"""
9094
Compute the shape of the broadcasted arrays.
9195
96+
.. deprecated:: 0.11.0
97+
:func:`broadcast_shapes` is deprecated and will be removed in v1.0.0.
98+
:func:`array_api.broadcast_shapes` exists in the standard as of v2025.12.
99+
92100
Duplicates :func:`numpy.broadcast_shapes`, with additional support for
93101
None and NaN sizes.
94102
95-
This is equivalent to ``xp.broadcast_arrays(arr1, arr2, ...)[0].shape``
96-
without needing to worry about the backend potentially deep copying
97-
the arrays.
98-
99103
Parameters
100104
----------
101105
*shapes : tuple[int | None, ...]
@@ -300,18 +304,25 @@ def create_diagonal(
300304
return _funcs.create_diagonal(x, offset=offset, xp=xp)
301305

302306

307+
@deprecated(
308+
"`xpx.expand_dims` is deprecated and will be removed in v1.0.0. "
309+
"`xp.expand_dims` with support for a tuple of ints in `axis` "
310+
"exists in the standard as of v2025.12."
311+
)
303312
def expand_dims(
304313
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType | None = None
305314
) -> Array:
306315
"""
307316
Expand the shape of an array.
308317
318+
.. deprecated:: 0.11.0
319+
:func:`expand_dims` is deprecated and will be removed in v1.0.0.
320+
:func:`array_api.expand_dims` with support for a tuple of ints in `axis`
321+
exists in the standard as of v2025.12.
322+
309323
Insert (a) new axis/axes that will appear at the position(s) specified by
310324
`axis` in the expanded array shape.
311325
312-
This is ``xp.expand_dims`` for `axis` an int *or a tuple of ints*.
313-
Roughly equivalent to ``numpy.expand_dims`` for NumPy arrays.
314-
315326
Parameters
316327
----------
317328
a : array
@@ -804,7 +815,7 @@ def searchsorted(
804815
Find the indices into a sorted array ``x1`` such that if the elements in ``x2``
805816
were inserted before the indices, the resulting array would remain sorted.
806817
807-
The behavior of this function is similar to that of `array_api.searchsorted`,
818+
The behavior of this function is similar to that of :func:`array_api.searchsorted`,
808819
but it relaxes the requirement that `x1` must be one-dimensional.
809820
This function is vectorized, treating slices along the last axis
810821
as elements and preceding axes as batch (or "loop") dimensions.
@@ -1220,8 +1231,8 @@ def isin(
12201231
"""
12211232
Determine whether each element in `a` is present in `b`.
12221233
1223-
Return a boolean array of the same shape as `a` that is True for elements
1224-
that are in `b` and False otherwise.
1234+
This is :func:`array_api.isin`, with additional `assume_unique`
1235+
and `kind` parameters.
12251236
12261237
Parameters
12271238
----------

src/array_api_extra/_lib/_utils/_helpers.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
from __future__ import annotations
44

5+
import functools
56
import io
67
import math
78
import pickle
89
import types
10+
import warnings
911
from collections.abc import Callable, Generator, Iterable, Iterator
1012
from functools import wraps
1113
from types import ModuleType
@@ -48,6 +50,7 @@ def override(func):
4850
__all__ = [
4951
"asarrays",
5052
"capabilities",
53+
"deprecated",
5154
"eager_shape",
5255
"in1d",
5356
"is_python_scalar",
@@ -58,6 +61,26 @@ def override(func):
5861
]
5962

6063

64+
def deprecated(
65+
msg: str, stacklevel: int = 2
66+
) -> Callable[[Callable[P, T]], Callable[P, T]]: # numpydoc ignore=PR01,RT01
67+
"""Deprecate a function by emitting a warning on use."""
68+
69+
def decorate(func: Callable[P, T]) -> Callable[P, T]: # numpydoc ignore=GL08
70+
@functools.wraps(func)
71+
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T: # numpydoc ignore=GL08
72+
warnings.warn(
73+
msg,
74+
category=DeprecationWarning,
75+
stacklevel=stacklevel,
76+
)
77+
return func(*args, **kwargs)
78+
79+
return wrapper
80+
81+
return decorate
82+
83+
6184
def in1d(
6285
x1: Array,
6386
x2: Array,

tests/test_funcs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -488,6 +488,7 @@ def test_5D_values(self, xp: ModuleType):
488488
assert_equal(y, xp.asarray([[[[[[[[[3.0]], [[2.0]]]]]]]]]))
489489

490490

491+
@pytest.mark.filterwarnings("ignore:.*removed in v1.0.0.*:DeprecationWarning")
491492
class TestBroadcastShapes:
492493
def test_delegates_known_integer_shapes(self, monkeypatch: pytest.MonkeyPatch):
493494
calls = []
@@ -828,6 +829,7 @@ def test_torch(self, torch: ModuleType):
828829
assert default_dtype(xp, "complex floating") == xp.complex64
829830

830831

832+
@pytest.mark.filterwarnings(r"ignore:.*removed in v1.0.0.*:DeprecationWarning")
831833
class TestExpandDims:
832834
def test_single_axis(self, xp: ModuleType):
833835
"""Trivial case where xpx.expand_dims doesn't add anything to xp.expand_dims"""

0 commit comments

Comments
 (0)