Skip to content

Commit d9f2a7a

Browse files
committed
fix: address signature regression in expand_dims
In #354, a regression was introduced which reverted a change to the signature of `expand_dims`. Namely, the `axis` argument should not have been made optional and should not have had a default value. Ref: #331 Ref: #354
1 parent 3a801d3 commit d9f2a7a

5 files changed

Lines changed: 6 additions & 6 deletions

File tree

src/array_api_stubs/_2021_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def concat(
2424
"""
2525

2626

27-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
27+
def expand_dims(x: array, /, axis: int) -> array:
2828
"""
2929
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.
3030

src/array_api_stubs/_2022_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def concat(
5858
"""
5959

6060

61-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
61+
def expand_dims(x: array, /, axis: int) -> array:
6262
"""
6363
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.
6464

src/array_api_stubs/_2023_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def concat(
7676
"""
7777

7878

79-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
79+
def expand_dims(x: array, /, axis: int) -> array:
8080
"""
8181
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.
8282

src/array_api_stubs/_2024_12/manipulation_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def concat(
7979
"""
8080

8181

82-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
82+
def expand_dims(x: array, /, axis: int) -> array:
8383
"""
8484
Expands the shape of an array by inserting a new axis (dimension) of size one at the position specified by ``axis``.
8585

src/array_api_stubs/_draft/manipulation_functions.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def concat(
7979
"""
8080

8181

82-
def expand_dims(x: array, /, *, axis: int = 0) -> array:
82+
def expand_dims(x: array, /, axis: int) -> array:
8383
"""
8484
Expands the shape of an array by inserting a new axis of size one at the position specified by ``axis``.
8585
@@ -88,7 +88,7 @@ def expand_dims(x: array, /, *, axis: int = 0) -> array:
8888
x: array
8989
input array.
9090
axis: int
91-
axis position (zero-based). A valid ``axis`` **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of axes in ``x``. If an axis is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis, the function **must** raise an exception. Default: ``0``.
91+
axis position (zero-based). A valid ``axis`` **must** reside on the closed-interval ``[-N-1, N]``, where ``N`` is the number of axes in ``x``. If an axis is specified as a negative integer, the axis position at which to insert a singleton dimension **must** be computed as ``N + axis + 1``. Hence, if provided ``-1``, the resolved axis position **must** be ``N`` (i.e., a singleton dimension **must** be appended to the input array ``x``). If provided ``-N-1``, the resolved axis position **must** be ``0`` (i.e., a singleton dimension **must** be prepended to the input array ``x``). If provided an invalid axis, the function **must** raise an exception.
9292
9393
Returns
9494
-------

0 commit comments

Comments
 (0)