1+ from __future__ import annotations
2+
13import functools
24import sys
35from abc import ABC , abstractmethod
1012from xarray .core import utils
1113from xarray .core .parallelcompat import ChunkManagerEntrypoint
1214from xarray .core .pycompat import is_chunked_array , is_duck_dask_array
13- from xarray .core .types import T_Chunks , T_NormalizedChunks
14-
15- T_ChunkedArray = TypeVar ("T_ChunkedArray" )
1615
17- CHUNK_MANAGERS : dict [str , type ["ChunkManagerEntrypoint" ]] = {}
1816
1917if TYPE_CHECKING :
20- from xarray .core .types import CubedArray , ZarrArray
18+ from xarray .core .types import T_Chunks , T_NormalizedChunks
19+ from cubed import Array as CubedArray
2120
2221
2322class CubedManager (ChunkManagerEntrypoint ["CubedArray" ]):
2423 array_cls : type ["CubedArray" ]
2524
26- def __init__ (self ):
25+ def __init__ (self ) -> None :
2726 from cubed import Array
2827
2928 self .array_cls = Array
@@ -33,15 +32,21 @@ def chunks(self, data: "CubedArray") -> T_NormalizedChunks:
3332
3433 def normalize_chunks (
3534 self ,
36- chunks : T_Chunks ,
37- shape : Union [ tuple [int ], None ] = None ,
38- limit : Union [ int , None ] = None ,
39- dtype : Union [ np .dtype , None ] = None ,
40- previous_chunks : T_NormalizedChunks = None ,
41- ) -> tuple [ tuple [ int , ...], ...] :
35+ chunks : T_Chunks | T_NormalizedChunks ,
36+ shape : tuple [int , ...] | None = None ,
37+ limit : int | None = None ,
38+ dtype : np .dtype | None = None ,
39+ previous_chunks : T_NormalizedChunks | None = None ,
40+ ) -> T_NormalizedChunks :
4241 from cubed .vendor .dask .array .core import normalize_chunks
4342
44- return normalize_chunks (chunks , shape = shape , limit = limit , dtype = dtype , previous_chunks = previous_chunks )
43+ return normalize_chunks (
44+ chunks ,
45+ shape = shape ,
46+ limit = limit ,
47+ dtype = dtype ,
48+ previous_chunks = previous_chunks ,
49+ )
4550
4651 def from_array (self , data : np .ndarray , chunks , ** kwargs ) -> "CubedArray" :
4752 from cubed import from_array
@@ -58,10 +63,7 @@ def from_array(self, data: np.ndarray, chunks, **kwargs) -> "CubedArray":
5863 spec = spec ,
5964 )
6065
61- def rechunk (self , data : "CubedArray" , chunks , ** kwargs ) -> "CubedArray" :
62- return data .rechunk (chunks , ** kwargs )
63-
64- def compute (self , * data : "CubedArray" , ** kwargs ) -> np .ndarray :
66+ def compute (self , * data : "CubedArray" , ** kwargs ) -> tuple [np .ndarray , ...]:
6567 from cubed import compute
6668
6769 return compute (* data , ** kwargs )
@@ -74,14 +76,14 @@ def array_api(self) -> Any:
7476
7577 def reduction (
7678 self ,
77- arr : T_ChunkedArray ,
79+ arr : "CubedArray" ,
7880 func : Callable ,
79- combine_func : Optional [ Callable ] = None ,
80- aggregate_func : Optional [ Callable ] = None ,
81- axis : Optional [ Union [ int , Sequence [int ]]] = None ,
82- dtype : Optional [ np .dtype ] = None ,
81+ combine_func : Callable | None = None ,
82+ aggregate_func : Callable | None = None ,
83+ axis : int | Sequence [int ] | None = None ,
84+ dtype : np .dtype | None = None ,
8385 keepdims : bool = False ,
84- ) -> T_ChunkedArray :
86+ ) -> "CubedArray" :
8587 from cubed .core .ops import reduction
8688
8789 return reduction (
@@ -96,16 +98,21 @@ def reduction(
9698
9799 def map_blocks (
98100 self ,
99- func ,
100- * args ,
101- dtype = None ,
102- chunks = None ,
103- drop_axis = [] ,
104- new_axis = None ,
101+ func : Callable ,
102+ * args : Any ,
103+ dtype : np . typing . DTypeLike | None = None ,
104+ chunks : tuple [ int , ...] | None = None ,
105+ drop_axis : int | Sequence [ int ] | None = None ,
106+ new_axis : int | Sequence [ int ] | None = None ,
105107 ** kwargs ,
106108 ):
107109 from cubed .core .ops import map_blocks
108110
111+ if drop_axis is None :
112+ # TODO should fix this upstream in cubed to match dask
113+ # see https://github.com/pydata/xarray/pull/7019#discussion_r1196729489
114+ drop_axis = []
115+
109116 return map_blocks (
110117 func ,
111118 * args ,
@@ -118,14 +125,14 @@ def map_blocks(
118125
119126 def blockwise (
120127 self ,
121- func ,
122- out_ind ,
128+ func : Callable ,
129+ out_ind : Iterable ,
123130 * args : Any ,
124131 # can't type this as mypy assumes args are all same type, but blockwise args alternate types
125- dtype = None ,
126- adjust_chunks = None ,
127- new_axes = None ,
128- align_arrays = True ,
132+ dtype : np . dtype | None = None ,
133+ adjust_chunks : dict [ Any , Callable ] | None = None ,
134+ new_axes : dict [ Any , int ] | None = None ,
135+ align_arrays : bool = True ,
129136 target_store = None ,
130137 ** kwargs ,
131138 ):
@@ -147,16 +154,16 @@ def blockwise(
147154
148155 def apply_gufunc (
149156 self ,
150- func ,
151- signature ,
152- * args ,
153- axes = None ,
154- axis = None ,
155- keepdims = False ,
156- output_dtypes = None ,
157- output_sizes = None ,
158- vectorize = None ,
159- allow_rechunk = False ,
157+ func : Callable ,
158+ signature : str ,
159+ * args : Any ,
160+ axes : Sequence [ tuple [ int , ...]] | None = None ,
161+ axis : int | None = None ,
162+ keepdims : bool = False ,
163+ output_dtypes : Sequence [ np . typing . DTypeLike ] | None = None ,
164+ output_sizes : dict [ str , int ] | None = None ,
165+ vectorize : bool | None = None ,
166+ allow_rechunk : bool = False ,
160167 ** kwargs ,
161168 ):
162169 if allow_rechunk :
@@ -181,17 +188,19 @@ def apply_gufunc(
181188 )
182189
183190 def unify_chunks (
184- self , * args , ** kwargs
185- ) -> tuple [dict [str , T_Chunks ], list ["CubedArray" ]]:
191+ self ,
192+ * args : Any , # can't type this as mypy assumes args are all same type, but dask unify_chunks args alternate types
193+ ** kwargs ,
194+ ) -> tuple [dict [str , T_NormalizedChunks ], list ["CubedArray" ]]:
186195 from cubed .core import unify_chunks
187196
188197 return unify_chunks (* args , ** kwargs )
189198
190199 def store (
191200 self ,
192201 sources : Union ["CubedArray" , Sequence ["CubedArray" ]],
193- targets : Union [ "ZarrArray" , Sequence [ "ZarrArray" ]] ,
194- ** kwargs : dict [ str , Any ] ,
202+ targets : Any ,
203+ ** kwargs ,
195204 ):
196205 """Used when writing to any backend."""
197206 from cubed .core .ops import store
0 commit comments