Skip to content

Commit 41f9880

Browse files
authored
feat: add support for specifying a tuple of axis positions in expand_dims
PR-URL: #988 Closes: #760 Reviewed-by: Lucas Colley Reviewed-by: Evgeni Burovski Reviewed-by: Jake Vanderplas
1 parent ffaf7fa commit 41f9880

File tree

5 files changed

+25
-9
lines changed

5 files changed

+25
-9
lines changed

src/array_api_stubs/_2021_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def concat(
2424
"""
2525

2626

27-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
27+
def expand_dims(x: array, /, axis: int) -> array:
2828
"""
2929
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.
3030

src/array_api_stubs/_2022_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def concat(
5858
"""
5959

6060

61-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
61+
def expand_dims(x: array, /, axis: int) -> array:
6262
"""
6363
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.
6464

src/array_api_stubs/_2023_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def concat(
7676
"""
7777

7878

79-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
79+
def expand_dims(x: array, /, axis: int) -> array:
8080
"""
8181
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.
8282

src/array_api_stubs/_2024_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def concat(
7979
"""
8080

8181

82-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
82+
def expand_dims(x: array, /, axis: int) -> array:
8383
"""
8484
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.
8585

src/array_api_stubs/_draft/manipulation_functions.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -79,26 +79,42 @@ def concat(
7979
"""
8080

8181

82-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
82+
def expand_dims(x: array, /, axis: Union[int, Tuple[int, ...]]) -> array:
8383
"""
84-
Expands the shape of an array by inserting a new axis of size one at the position specified by ``axis``.
84+
Expands the shape of an array by inserting a new axis of size one at the position (or positions) specified by ``axis``.
8585
8686
Parameters
8787
----------
8888
x: array
8989
input array.
90-
axis: int
91-
axis position (zero-based). A valid ``axis`` **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of axes in ``x``. If an axis is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis, the function **must** raise an exception. Default: ``0``.
90+
axis: Union[int, Tuple[int, ...]]
91+
axis position(s) (zero-based). If ``axis`` is an integer, ``axis`` **must** be equivalent to the tuple ``(axis,)``. If ``axis`` is a tuple,
92+
93+
- a valid axis position **must** reside on the half-open interval ``[-M, M)``, where ``M = N + len(axis)`` and ``N`` is the number of dimensions in ``x``.
94+
- if the i-th entry is a negative integer, the axis position of the inserted singleton dimension in the output array **must** be computed as ``M + axis[i]``.
95+
- each entry of ``axis`` must resolve to a unique positive axis position.
96+
- for each entry of ``axis``, the corresponding dimension in the expanded output array **must** be a singleton dimension.
97+
- for the remaining dimensions of the expanded output array, the output array dimensions **must** correspond to the dimensions of ``x`` in order.
98+
- if provided an invalid axis position, the function **must** raise an exception.
9299
93100
Returns
94101
-------
95102
out: array
96-
an expanded output array. **Must** have the same data type as ``x``.
103+
an expanded output array. **Must** have the same data type as ``x``. If ``axis`` is an integer, the output array must have ``N + 1`` dimensions. If ``axis`` is a tuple, the output array must have ``N + len(axis)`` dimensions.
97104
98105
Raises
99106
------
100107
IndexError
101108
If provided an invalid ``axis``, an ``IndexError`` **should** be raised.
109+
110+
Notes
111+
-----
112+
113+
- Calling this function with a tuple of axis positions **must** be semantically equivalent to calling this function repeatedly with a single axis position only when the following three conditions are met:
114+
115+
- each entry of the tuple is normalized to positive axis positions according to the number of dimensions in the expanded output array.
116+
- the normalized positive axis positions are sorted in ascending order.
117+
- the normalized positive axis positions are unique.
102118
"""
103119

104120

0 commit comments

Comments
 (0)