Skip to content

Commit 4ac2047

Browse files
authored
Fix indexing zero-sized dimension (#761)
1 parent 713d67b commit 4ac2047

2 files changed

Lines changed: 17 additions & 1 deletion

File tree

cubed/core/indexing.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,8 @@ def merged_chunk_len_for_indexer(ia, c):
8787
# empty output case
8888
from cubed.array_api.creation_functions import empty
8989

90-
out = empty(shape, dtype=x.dtype, chunks=x.chunksize, spec=x.spec)
90+
chunks = tuple(c for c in x.chunksize if c > 0)
91+
out = empty(shape, dtype=x.dtype, chunks=chunks, spec=x.spec)
9192
else:
9293
dtype = x.dtype
9394
chunks = tuple(

cubed/tests/test_array_api.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -357,6 +357,21 @@ def test_index_2d_step(spec, shape, chunks, ind, new_chunks_expected):
357357
assert b.chunks == new_chunks_expected
358358

359359

360+
@pytest.mark.parametrize(
361+
("shape", "chunks", "ind"),
362+
[
363+
((0, 2), (0, 2), (slice(None), 0)),
364+
((0, 2), (1, 2), (slice(None), 0)),
365+
((2, 0), (2, 0), (0, slice(None))),
366+
((2, 0), (2, 1), (0, slice(None))),
367+
],
368+
)
369+
def test_index_zero_dim(shape, chunks, ind):
370+
a = xp.ones(shape, chunks=chunks)
371+
b = a[ind]
372+
assert_array_equal(b.compute(), np.ones(shape)[ind])
373+
374+
360375
def test_index_slice_unsupported_step(spec):
361376
a = xp.arange(12, chunks=(4,), spec=spec)
362377
with pytest.raises(NotImplementedError):

0 commit comments

Comments
 (0)