Skip to content

Commit 97b9257

Browse files
Apply suggestions from code review
Co-authored-by: Lucas Colley <lucas.colley8@gmail.com>
1 parent cade0f1 commit 97b9257

3 files changed

Lines changed: 12 additions & 12 deletions

File tree

src/array_api_extra/_delegation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,7 @@ def diag_indices(
264264
Returns
265265
-------
266266
tuple of array
267-
``ndim`` 1-D integer arrays of length ``n`` that together index
267+
1-D integer arrays of length ``n`` that together index
268268
the main diagonal of an array of shape ``(n,) * ndim``.
269269
270270
Examples

src/array_api_extra/_lib/_funcs.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -350,7 +350,7 @@ def create_diagonal(
350350

351351

352352
def diag_indices(
353-
n: int, /, *, ndim: int = 2, device: Device | None = None, xp: ModuleType
353+
n: int, /, *, ndim: int, device: Device | None, xp: ModuleType
354354
) -> tuple[Array, ...]: # numpydoc ignore=PR01,RT01
355355
"""See docstring in array_api_extra._delegation."""
356356
idx = xp.arange(n, device=device)
@@ -368,8 +368,8 @@ def _tri_indices(
368368
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
369369
"""Shared implementation for `tril_indices` and `triu_indices`."""
370370
cols = n if m is None else m
371-
rows = xp.arange(n, device=device)[:, None]
372-
cols_a = xp.arange(cols, device=device)[None, :]
371+
rows = xp.arange(n, device=device)[:, xp.newaxis]
372+
cols_a = xp.arange(cols, device=device)[xp.newaxis, :]
373373
delta = cols_a - rows
374374
mask = delta >= offset if upper else delta <= offset
375375
r, c = xp.nonzero(mask)
@@ -380,9 +380,9 @@ def tril_indices(
380380
n: int,
381381
/,
382382
*,
383-
offset: int = 0,
384-
m: int | None = None,
385-
device: Device | None = None,
383+
offset: int,
384+
m: int | None,
385+
device: Device | None,
386386
xp: ModuleType,
387387
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
388388
"""See docstring in array_api_extra._delegation."""
@@ -393,9 +393,9 @@ def triu_indices(
393393
n: int,
394394
/,
395395
*,
396-
offset: int = 0,
397-
m: int | None = None,
398-
device: Device | None = None,
396+
offset: int,
397+
m: int | None,
398+
device: Device | None,
399399
xp: ModuleType,
400400
) -> tuple[Array, Array]: # numpydoc ignore=PR01,RT01
401401
"""See docstring in array_api_extra._delegation."""

tests/test_funcs.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -812,7 +812,7 @@ def test_torch(self, torch: ModuleType):
812812
@pytest.mark.xfail_xp_backend(Backend.SPARSE, reason="no arange", strict=False)
813813
class TestDiagIndices:
814814
def test_basic(self, xp: ModuleType):
815-
rows, cols = diag_indices(5, xp=xp)
815+
rows, cols = diag_indices(5)
816816
ref_rows, ref_cols = np.diag_indices(5)
817817
xp_assert_equal(rows, xp.asarray(ref_rows))
818818
xp_assert_equal(cols, xp.asarray(ref_cols))
@@ -865,7 +865,7 @@ def test_basic(
865865
xpx_fn: Callable[..., tuple[Array, Array]],
866866
np_fn: Callable[..., tuple[Array, Array]],
867867
):
868-
rows, cols = xpx_fn(4, xp=xp)
868+
rows, cols = xpx_fn(4)
869869
ref_rows, ref_cols = np_fn(4)
870870
xp_assert_equal(rows, xp.asarray(ref_rows))
871871
xp_assert_equal(cols, xp.asarray(ref_cols))

0 commit comments

Comments
 (0)