Skip to content

Commit a06318e

Browse files
authored
allow numpy ints in shapelike (#3706)
* allow numpy ints in shapelike * changelog * type checker
1 parent 3e7d24d commit a06318e

File tree

3 files changed

+22
-9
lines changed

3 files changed

+22
-9
lines changed

changes/3706.misc.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Allow NumPy ints as input when declaring a shape.

src/zarr/core/common.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
overload,
2222
)
2323

24+
import numpy as np
2425
from typing_extensions import ReadOnly
2526

2627
from zarr.core.config import config as zarr_config
@@ -37,7 +38,7 @@
3738
ZMETADATA_V2_JSON = ".zmetadata"
3839

3940
BytesLike = bytes | bytearray | memoryview
40-
ShapeLike = Iterable[int] | int
41+
ShapeLike = Iterable[int | np.integer[Any]] | int | np.integer[Any]
4142
# For backwards compatibility
4243
ChunkCoords = tuple[int, ...]
4344
ZarrFormat = Literal[2, 3]
@@ -185,23 +186,28 @@ def parse_named_configuration(
185186

186187

187188
def parse_shapelike(data: ShapeLike) -> tuple[int, ...]:
188-
if isinstance(data, int):
189+
"""
190+
Parse a shape-like input into an explicit shape.
191+
"""
192+
if isinstance(data, int | np.integer):
189193
if data < 0:
190194
raise ValueError(f"Expected a non-negative integer. Got {data} instead")
191-
return (data,)
195+
return (int(data),)
192196
try:
193197
data_tuple = tuple(data)
194198
except TypeError as e:
195199
msg = f"Expected an integer or an iterable of integers. Got {data} instead."
196200
raise TypeError(msg) from e
197201

198-
if not all(isinstance(v, int) for v in data_tuple):
202+
if not all(isinstance(v, int | np.integer) for v in data_tuple):
199203
msg = f"Expected an iterable of integers. Got {data} instead."
200204
raise TypeError(msg)
201205
if not all(v > -1 for v in data_tuple):
202206
msg = f"Expected all values to be non-negative. Got {data} instead."
203207
raise ValueError(msg)
204-
return data_tuple
208+
209+
# cast NumPy scalars to plain python ints
210+
return tuple(int(x) for x in data_tuple)
205211

206212

207213
def parse_fill_value(data: Any) -> Any:

tests/test_common.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
from collections.abc import Iterable
34
from typing import TYPE_CHECKING, get_args
45

56
import numpy as np
@@ -15,7 +16,6 @@
1516
from zarr.core.config import parse_indexing_order
1617

1718
if TYPE_CHECKING:
18-
from collections.abc import Iterable
1919
from typing import Any, Literal
2020

2121

@@ -115,9 +115,15 @@ def test_parse_shapelike_invalid_iterable_values(data: Any) -> None:
115115
parse_shapelike(data)
116116

117117

118-
@pytest.mark.parametrize("data", [range(10), [0, 1, 2, 3], (3, 4, 5), ()])
119-
def test_parse_shapelike_valid(data: Iterable[int]) -> None:
120-
assert parse_shapelike(data) == tuple(data)
118+
@pytest.mark.parametrize(
119+
"data", [range(10), [0, 1, 2, np.uint64(3)], (3, 4, 5), (), 1, np.uint8(1)]
120+
)
121+
def test_parse_shapelike_valid(data: Iterable[int] | int) -> None:
122+
if isinstance(data, Iterable):
123+
expected = tuple(data)
124+
else:
125+
expected = (data,)
126+
assert parse_shapelike(data) == expected
121127

122128

123129
# todo: more dtypes

0 commit comments

Comments
 (0)