Skip to content

Commit 393dc86

Browse files
committed
ENH: add broadcast_shapes
1 parent c303adc commit 393dc86

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

array_api_strict/__init__.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,11 @@
9292
"result_type",
9393
]
9494

95+
# TODO: only add for 2025.12
96+
from ._data_type_functions import broadcast_shapes
97+
__all__ += ["broadcast_shapes"]
98+
99+
95100
from ._dtypes import (
96101
int8,
97102
int16,

array_api_strict/_data_type_functions.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,16 @@ def broadcast_arrays(*arrays: Array) -> list[Array]:
6262
]
6363

6464

65+
def broadcast_shapes(*shapes: tuple[int, ...]) -> tuple[int, ...]:
66+
"""
67+
Array API compatible wrapper for :py:func:`np.broadcast_shapes <numpy.broadcast_shapes>`.
68+
69+
See its docstring for more information.
70+
"""
71+
# TODO: only define for 2025.12
72+
return np.broadcast_shapes(*shapes)
73+
74+
6575
def broadcast_to(x: Array, /, shape: tuple[int, ...]) -> Array:
6676
"""
6777
Array API compatible wrapper for :py:func:`np.broadcast_to <numpy.broadcast_to>`.

0 commit comments

Comments
 (0)