Skip to content

Commit 1c05c1a

Browse files
authored
Set copy=False in reshape operation (#3649)
* set copy=False in reshape operation
* add compat reshape with conditional
* fix docstring
* fix mypy
* remove tuple unpacking to improve readability
1 parent 7166a8d commit 1c05c1a

File tree

2 files changed

+44
-3
lines changed

2 files changed

+44
-3
lines changed

src/zarr/_compat.py

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,16 @@
22
from collections.abc import Callable
33
from functools import wraps
44
from inspect import Parameter, signature
5-
from typing import Any, TypeVar
5+
from typing import TYPE_CHECKING, Any, TypeVar
6+
7+
import numpy as np
8+
from packaging.version import Version
69

710
from zarr.errors import ZarrFutureWarning
811

12+
if TYPE_CHECKING:
13+
from numpy.typing import NDArray
14+
915
T = TypeVar("T")
1016

1117
# Based off https://github.com/scikit-learn/scikit-learn/blob/e87b32a81c70abed8f2e97483758eb64df8255e9/sklearn/utils/validation.py#L63
@@ -68,3 +74,37 @@ def inner_f(*args: Any, **kwargs: Any) -> T:
6874
return _inner_deprecate_positional_args(func)
6975

7076
return _inner_deprecate_positional_args # type: ignore[return-value]
77+
78+
79+
def _reshape_view(arr: "NDArray[Any]", shape: tuple[int, ...]) -> "NDArray[Any]":
    """Reshape an array without copying data.

    This function provides compatibility across NumPy versions for reshaping arrays
    as views. On NumPy >= 2.1, it uses ``reshape(copy=False)`` which explicitly
    fails if a view cannot be created. On older versions, it assigns the new shape
    to a fresh view of ``arr``, which has the same no-copy behavior (direct shape
    assignment is deprecated in NumPy 2.5+, so it is only used on the legacy path).

    On every NumPy version the input array is left unmodified; the returned
    array is a new view object with the requested shape.

    Parameters
    ----------
    arr : NDArray
        The array to reshape. Its own ``shape`` is not changed.
    shape : tuple of int
        The new shape.

    Returns
    -------
    NDArray
        A reshaped view of the array.

    Raises
    ------
    AttributeError
        If a view cannot be created (the array is not contiguous) on NumPy < 2.1.
    ValueError
        If a view cannot be created (the array is not contiguous) on NumPy >= 2.1.
    """
    if Version(np.__version__) >= Version("2.1"):
        return arr.reshape(shape, copy=False)  # type: ignore[call-overload, no-any-return]

    # Legacy path: set the shape on a new view rather than on ``arr`` itself, so
    # the caller's array is not mutated as a side effect. This matches the
    # behavior of ``reshape(copy=False)`` above.
    view = arr.view()
    view.shape = shape
    return view

src/zarr/codecs/vlen_utf8.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import numpy as np
77
from numcodecs.vlen import VLenBytes, VLenUTF8
88

9+
from zarr._compat import _reshape_view
910
from zarr.abc.codec import ArrayBytesCodec
1011
from zarr.core.buffer import Buffer, NDBuffer
1112
from zarr.core.common import JSON, parse_named_configuration
@@ -50,7 +51,7 @@ async def _decode_single(
5051
raw_bytes = chunk_bytes.as_array_like()
5152
decoded = _vlen_utf8_codec.decode(raw_bytes)
5253
assert decoded.dtype == np.object_
53-
decoded = decoded.reshape(chunk_spec.shape)
54+
decoded = _reshape_view(decoded, chunk_spec.shape)
5455
as_string_dtype = decoded.astype(chunk_spec.dtype.to_native_dtype(), copy=False)
5556
return chunk_spec.prototype.nd_buffer.from_numpy_array(as_string_dtype)
5657

@@ -95,7 +96,7 @@ async def _decode_single(
9596
raw_bytes = chunk_bytes.as_array_like()
9697
decoded = _vlen_bytes_codec.decode(raw_bytes)
9798
assert decoded.dtype == np.object_
98-
decoded = decoded.reshape(chunk_spec.shape)
99+
decoded = _reshape_view(decoded, chunk_spec.shape)
99100
return chunk_spec.prototype.nd_buffer.from_numpy_array(decoded)
100101

101102
async def _encode_single(

0 commit comments

Comments
 (0)