|
2 | 2 | from collections.abc import Callable |
3 | 3 | from functools import wraps |
4 | 4 | from inspect import Parameter, signature |
5 | | -from typing import Any, TypeVar |
| 5 | +from typing import TYPE_CHECKING, Any, TypeVar |
| 6 | + |
| 7 | +import numpy as np |
| 8 | +from packaging.version import Version |
6 | 9 |
|
7 | 10 | from zarr.errors import ZarrFutureWarning |
8 | 11 |
|
| 12 | +if TYPE_CHECKING: |
| 13 | + from numpy.typing import NDArray |
| 14 | + |
9 | 15 | T = TypeVar("T") |
10 | 16 |
|
11 | 17 | # Based off https://github.com/scikit-learn/scikit-learn/blob/e87b32a81c70abed8f2e97483758eb64df8255e9/sklearn/utils/validation.py#L63 |
@@ -68,3 +74,37 @@ def inner_f(*args: Any, **kwargs: Any) -> T: |
68 | 74 | return _inner_deprecate_positional_args(func) |
69 | 75 |
|
70 | 76 | return _inner_deprecate_positional_args # type: ignore[return-value] |
| 77 | + |
| 78 | + |
| 79 | +def _reshape_view(arr: "NDArray[Any]", shape: tuple[int, ...]) -> "NDArray[Any]": |
| 80 | + """Reshape an array without copying data. |
| 81 | +
|
| 82 | + This function provides compatibility across NumPy versions for reshaping arrays |
| 83 | + as views. On NumPy >= 2.1, it uses ``reshape(copy=False)`` which explicitly |
| 84 | + fails if a view cannot be created. On older versions, it uses direct shape |
| 85 | + assignment which has the same behavior but is deprecated in 2.5+. |
| 86 | +
|
| 87 | + Parameters |
| 88 | + ---------- |
| 89 | + arr : NDArray |
| 90 | + The array to reshape. |
| 91 | + shape : tuple of int |
| 92 | + The new shape. |
| 93 | +
|
| 94 | + Returns |
| 95 | + ------- |
| 96 | + NDArray |
| 97 | + A reshaped view of the array. |
| 98 | +
|
| 99 | + Raises |
| 100 | + ------ |
| 101 | + AttributeError |
| 102 | + If a view cannot be created (the array is not contiguous) on NumPy < 2.1. |
| 103 | + ValueError |
| 104 | + If a view cannot be created (the array is not contiguous) on NumPy >= 2.1. |
| 105 | + """ |
| 106 | + if Version(np.__version__) >= Version("2.1"): |
| 107 | + return arr.reshape(shape, copy=False) # type: ignore[call-overload, no-any-return] |
| 108 | + else: |
| 109 | + arr.shape = shape |
| 110 | + return arr |
0 commit comments