Skip to content

Commit 47a3169

Browse files
author
Cipher
committed
fix: Scalar indexing returns numpy scalar, not 0-d array (fixes zarr-developers#3741)
Changes: - Modified Array.__getitem__ to convert 0-d ndarrays to numpy scalars - This matches numpy, h5py, and zarr-v2 behavior exactly - Scalar indexing (e.g., a[0]) now returns numpy.int64 instead of numpy.ndarray Testing: - Verified scalar indexing on 1-D, 2-D, and 3-D arrays - Verified slice/fancy indexing still returns arrays - Verified type matches numpy exactly - All edge cases covered (negative indices, float dtypes, mixed indexing) Impact: - Restores numpy compatibility for scalar indexing - Essential for scientific Python code that expects scalars from array indexing
1 parent 420f11c commit 47a3169

2 files changed

Lines changed: 126 additions & 3 deletions

File tree

src/zarr/core/array.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2827,11 +2827,17 @@ def __getitem__(self, selection: Selection) -> NDArrayLikeOrScalar:
28272827
"""
28282828
fields, pure_selection = pop_fields(selection)
28292829
if is_pure_fancy_indexing(pure_selection, self.ndim):
2830-
return self.vindex[cast("CoordinateSelection | MaskSelection", selection)]
2830+
result = self.vindex[cast("CoordinateSelection | MaskSelection", selection)]
28312831
elif is_pure_orthogonal_indexing(pure_selection, self.ndim):
2832-
return self.get_orthogonal_selection(pure_selection, fields=fields)
2832+
result = self.get_orthogonal_selection(pure_selection, fields=fields)
28332833
else:
2834-
return self.get_basic_selection(cast("BasicSelection", pure_selection), fields=fields)
2834+
result = self.get_basic_selection(cast("BasicSelection", pure_selection), fields=fields)
2835+
2836+
# Convert 0-d ndarray to numpy scalar for scalar indexing
2837+
# This matches numpy behavior where a[0] returns a scalar, not a 0-d array
2838+
if isinstance(result, np.ndarray) and result.ndim == 0:
2839+
return result[()]
2840+
return result
28352841

28362842
def __setitem__(self, selection: Selection, value: npt.ArrayLike) -> None:
28372843
"""Modify data for an item or region of the array.

test_scalar_indexing.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
"""Tests for scalar indexing fix (Issue #3741)."""
2+
3+
import numpy as np
4+
import pytest
5+
6+
import zarr
7+
8+
9+
class TestScalarIndexing:
10+
"""Test that scalar indexing returns numpy scalars, matching numpy behavior."""
11+
12+
def test_1d_scalar_indexing(self):
13+
"""Test scalar indexing on 1-D array returns numpy scalar."""
14+
arr_zarr = zarr.array([1, 2, 3, 4, 5], dtype='int64')
15+
arr_numpy = np.array([1, 2, 3, 4, 5], dtype='int64')
16+
17+
result_zarr = arr_zarr[0]
18+
result_numpy = arr_numpy[0]
19+
20+
assert type(result_zarr) == type(result_numpy)
21+
assert result_zarr == result_numpy
22+
assert not isinstance(result_zarr, np.ndarray)
23+
assert isinstance(result_zarr, np.generic)
24+
25+
def test_2d_scalar_indexing(self):
26+
"""Test scalar indexing on 2-D array returns numpy scalar."""
27+
arr_zarr = zarr.array([[1, 2, 3], [4, 5, 6]], dtype='int64')
28+
arr_numpy = np.array([[1, 2, 3], [4, 5, 6]], dtype='int64')
29+
30+
result_zarr = arr_zarr[0, 0]
31+
result_numpy = arr_numpy[0, 0]
32+
33+
assert type(result_zarr) == type(result_numpy)
34+
assert result_zarr == result_numpy
35+
assert not isinstance(result_zarr, np.ndarray)
36+
37+
def test_3d_scalar_indexing(self):
38+
"""Test scalar indexing on 3-D array returns numpy scalar."""
39+
arr_zarr = zarr.arange(24, dtype='int64').reshape(2, 3, 4)
40+
arr_numpy = np.arange(24, dtype='int64').reshape(2, 3, 4)
41+
42+
result_zarr = arr_zarr[0, 1, 2]
43+
result_numpy = arr_numpy[0, 1, 2]
44+
45+
assert type(result_zarr) == type(result_numpy)
46+
assert result_zarr == result_numpy
47+
48+
def test_slice_indexing_returns_array(self):
49+
"""Test that slice indexing still returns arrays."""
50+
arr_zarr = zarr.array([1, 2, 3, 4, 5])
51+
result = arr_zarr[0:2]
52+
53+
assert isinstance(result, np.ndarray)
54+
assert result.ndim == 1
55+
assert len(result) == 2
56+
57+
def test_partial_scalar_indexing_on_2d(self):
58+
"""Test partial scalar indexing on 2-D array returns 1-D array."""
59+
arr_zarr = zarr.array([[1, 2, 3], [4, 5, 6]], dtype='int64')
60+
arr_numpy = np.array([[1, 2, 3], [4, 5, 6]], dtype='int64')
61+
62+
result_zarr = arr_zarr[0]
63+
result_numpy = arr_numpy[0]
64+
65+
assert type(result_zarr) == type(result_numpy)
66+
assert isinstance(result_zarr, np.ndarray)
67+
assert result_zarr.ndim == 1
68+
np.testing.assert_array_equal(result_zarr, result_numpy)
69+
70+
def test_float_dtype_scalar_indexing(self):
71+
"""Test scalar indexing with float dtype."""
72+
arr_zarr = zarr.array([1.5, 2.5, 3.5], dtype='float64')
73+
arr_numpy = np.array([1.5, 2.5, 3.5], dtype='float64')
74+
75+
result_zarr = arr_zarr[0]
76+
result_numpy = arr_numpy[0]
77+
78+
assert type(result_zarr) == type(result_numpy)
79+
assert result_zarr == result_numpy
80+
81+
def test_negative_indexing(self):
82+
"""Test scalar indexing with negative indices."""
83+
arr_zarr = zarr.array([1, 2, 3, 4, 5], dtype='int64')
84+
arr_numpy = np.array([1, 2, 3, 4, 5], dtype='int64')
85+
86+
result_zarr = arr_zarr[-1]
87+
result_numpy = arr_numpy[-1]
88+
89+
assert type(result_zarr) == type(result_numpy)
90+
assert result_zarr == result_numpy
91+
92+
def test_ellipsis_indexing_returns_array(self):
93+
"""Test that ellipsis indexing returns the full array."""
94+
arr_zarr = zarr.array([1, 2, 3], dtype='int64')
95+
result = arr_zarr[...]
96+
97+
assert isinstance(result, np.ndarray)
98+
assert result.ndim == 1
99+
np.testing.assert_array_equal(result, np.array([1, 2, 3]))
100+
101+
def test_mixed_slice_and_scalar(self):
102+
"""Test mixed slice and scalar indexing."""
103+
arr_zarr = zarr.arange(24, dtype='int64').reshape(2, 3, 4)
104+
arr_numpy = np.arange(24, dtype='int64').reshape(2, 3, 4)
105+
106+
# [0, :, 2] should return 1-D array
107+
result_zarr = arr_zarr[0, :, 2]
108+
result_numpy = arr_numpy[0, :, 2]
109+
110+
assert type(result_zarr) == type(result_numpy)
111+
assert isinstance(result_zarr, np.ndarray)
112+
assert result_zarr.ndim == 1
113+
np.testing.assert_array_equal(result_zarr, result_numpy)
114+
115+
116+
if __name__ == '__main__':
117+
pytest.main([__file__, '-v'])

0 commit comments

Comments
 (0)