Skip to content

Commit c7f00c0

Browse files
author
Vahid Tavanashad
committed
set a default value for axis parameter in dpnp.take_along_axis
1 parent a37fadb commit c7f00c0

File tree

3 files changed

+7
-4
lines changed

3 files changed

+7
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ This release achieves 100% compliance with Python Array API specification (revis
2828
* Updated Python Array API specification version supported to `2024.12` [#2416](https://github.com/IntelPython/dpnp/pull/2416)
2929
* Removed `einsum_call` keyword from `dpnp.einsum_path` signature [#2421](https://github.com/IntelPython/dpnp/pull/2421)
3030
* Changed `"max dimensions"` to `None` in array API capabilities [#2432](https://github.com/IntelPython/dpnp/pull/2432)
31+
* The parameter `axis` in `dpnp.take_along_axis` function has now a default value of `-1` []()
3132

3233
### Fixed
3334

dpnp/dpnp_iface_indexing.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2205,7 +2205,7 @@ def take(a, indices, /, *, axis=None, out=None, mode="wrap"):
22052205
return dpnp.get_result_array(usm_res, out=out)
22062206

22072207

2208-
def take_along_axis(a, indices, axis, mode="wrap"):
2208+
def take_along_axis(a, indices, axis=-1, mode="wrap"):
22092209
"""
22102210
Take values from the input array by matching 1d index and data slices.
22112211
@@ -2230,6 +2230,8 @@ def take_along_axis(a, indices, axis, mode="wrap"):
22302230
The axis to take 1d slices along. If axis is ``None``, the input
22312231
array is treated as if it had first been flattened to 1d,
22322232
for consistency with :obj:`dpnp.sort` and :obj:`dpnp.argsort`.
2233+
2234+
Default: ``-1``.
22332235
mode : {"wrap", "clip"}, optional
22342236
Specifies how out-of-bounds indices will be handled. Possible values
22352237
are:
@@ -2274,8 +2276,8 @@ def take_along_axis(a, indices, axis, mode="wrap"):
22742276
array([[10, 20, 30],
22752277
[40, 50, 60]])
22762278
2277-
The same works for max and min, if you maintain the trivial dimension
2278-
with ``keepdims``:
2279+
The same works for :obj:`dpnp.max` and :obj:`dpnp.min`, if you maintain
2280+
the trivial dimension with ``keepdims``:
22792281
22802282
>>> np.max(a, axis=1, keepdims=True)
22812283
array([[30],

dpnp/tests/test_indexing.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -804,7 +804,7 @@ def test_argequivalent(self, func, argfunc, kwargs):
804804
# a = dpnp.random.random(size=(3, 4, 5))
805805
a = dpnp.asarray(numpy.random.random(size=(3, 4, 5)))
806806

807-
for axis in list(range(a.ndim)) + [None]:
807+
for axis in list(range(a.ndim)) + [None, -1]:
808808
a_func = func(a, axis=axis, **kwargs)
809809
ai_func = argfunc(a, axis=axis, **kwargs)
810810
assert_array_equal(

0 commit comments

Comments
 (0)