@@ -151,41 +151,71 @@ def _stl(self, data, freq='W', periods=52, robust=False):
151151 n_dates_out = len (dt_periodic )
152152 n_dates_in = data .date .size
153153
154- # Use rechunk2d for uniform chunk sizes based on memory
155- # Multiplier accounts for STL internal memory (trend, seasonal, resid, weights, etc.)
156- # Effective memory: 8 * n_dates_in * 16 bytes (complex128) per pixel
157- from .utils_dask import rechunk2d
158- mem_per_pixel = 8 * n_dates_in * 16 # complex128 = 16 bytes
159- optimal = rechunk2d ((data .y .size , data .x .size ), element_bytes = mem_per_pixel )
160- chunks_y , chunks_x = optimal ['y' ], optimal ['x' ]
161-
162- # Rechunk: all dates together (-1), auto-chunked y,x
163- first_dim = data .dims [0 ]
164- data = data .chunk ({first_dim : - 1 , 'y' : chunks_y , 'x' : chunks_x })
165-
166- # Use blockwise to avoid embedding large arrays in the graph
167- def process_block (data_block ):
168- # data_block: (n_dates, y_chunk, x_chunk)
169- # transpose to (y, x, n_dates) for vectorized STL
170- data_transposed = data_block .transpose (1 , 2 , 0 )
154+ # No rechunk on dim 0 — pass the list of delayed date-axis chunks of each (y, x) block to the kernel.
155+ data_dask = data .data
156+
157+ y_chunks = data_dask .chunks [1 ]
158+ x_chunks = data_dask .chunks [2 ]
159+ y_breaks = [0 ] + list (np .cumsum (y_chunks ))
160+ x_breaks = [0 ] + list (np .cumsum (x_chunks ))
161+
162+ def process_chunks (data_chunks ):
163+ import math
164+ from .utils_dask import get_dask_chunk_size_mb
165+ chunks = [np .asarray (c ) for c in data_chunks ]
166+ ny , nx = chunks [0 ].shape [1 ], chunks [0 ].shape [2 ]
167+ n_dates_in_local = sum (c .shape [0 ] for c in chunks )
168+ result = np .empty ((3 , n_dates_out , ny , nx ), dtype = np .float32 )
171169 vec_stl = np .vectorize (
172170 lambda ts : utils_stl .stl1d (ts , dt , dt_periodic , periods , robust ),
173171 signature = '(n)->(m),(m),(m)'
174172 )
175- # result: (3, y, x, n_dates_out) after asarray
176- block = np .asarray (vec_stl (data_transposed ))
177- del vec_stl , data_transposed
178- # transpose to (3, n_dates_out, y, x)
179- return block .transpose (0 , 3 , 1 , 2 ).astype (np .float32 )
180-
181- data_dask = data .data
182- models = dask .array .map_blocks (
183- process_block , data_dask ,
184- dtype = np .float32 ,
185- drop_axis = 0 ,
186- new_axis = [0 , 1 ],
187- chunks = (3 , n_dates_out ) + data_dask .chunks [1 :],
188- )
173+ # Calculate sub-tile size from dask chunk budget.
174+ # Per sub-tile memory: input (n_dates_in × sub_pixels × 4) + output (3 × n_dates_out × sub_pixels × 4)
175+ per_pixel_bytes = (n_dates_in_local + 3 * n_dates_out ) * 4
176+ budget_bytes = int (get_dask_chunk_size_mb () * 1024 * 1024 )
177+ max_sub_pixels = max (256 , budget_bytes // max (1 , per_pixel_bytes ))
178+ sub_side = int (math .sqrt (max_sub_pixels ))
179+ sub_h = min (sub_side , ny )
180+ sub_w = min (sub_side , nx )
181+ for ty0 in range (0 , ny , sub_h ):
182+ ty1 = min (ty0 + sub_h , ny )
183+ for tx0 in range (0 , nx , sub_w ):
184+ tx1 = min (tx0 + sub_w , nx )
185+ if len (chunks ) == 1 :
186+ tile = chunks [0 ][:, ty0 :ty1 , tx0 :tx1 ]
187+ else :
188+ tile = np .concatenate (
189+ [c [:, ty0 :ty1 , tx0 :tx1 ] for c in chunks ], axis = 0
190+ )
191+ # (n_dates, sub_h, sub_w) -> (sub_h, sub_w, n_dates)
192+ tile_t = tile .transpose (1 , 2 , 0 )
193+ del tile
194+ # block: (3, sub_h, sub_w, n_dates_out) after asarray
195+ block = np .asarray (vec_stl (tile_t ))
196+ del tile_t
197+ # (3, sub_h, sub_w, n_dates_out) -> (3, n_dates_out, sub_h, sub_w)
198+ result [:, :, ty0 :ty1 , tx0 :tx1 ] = block .transpose (0 , 3 , 1 , 2 )
199+ del block
200+ del vec_stl
201+ return result
202+
203+ blocks_rows = []
204+ for bj in range (len (y_breaks ) - 1 ):
205+ y0 , y1 = y_breaks [bj ], y_breaks [bj + 1 ]
206+ blocks_row = []
207+ for bk in range (len (x_breaks ) - 1 ):
208+ x0 , x1 = x_breaks [bk ], x_breaks [bk + 1 ]
209+ td_list = data_dask [:, y0 :y1 , x0 :x1 ].to_delayed ().ravel ().tolist ()
210+ block = dask .array .from_delayed (
211+ dask .delayed (process_chunks )(td_list ),
212+ shape = (3 , n_dates_out , y1 - y0 , x1 - x0 ),
213+ dtype = np .float32 ,
214+ )
215+ blocks_row .append (block )
216+ blocks_rows .append (dask .array .concatenate (blocks_row , axis = 3 ))
217+
218+ models = dask .array .concatenate (blocks_rows , axis = 2 )
189219
190220 coords = {'date' : dt_periodic .astype ('datetime64[ns]' ), 'y' : data .y , 'x' : data .x }
191221
@@ -194,9 +224,6 @@ def process_block(data_block):
194224 keys_vars = {}
195225 for varidx , varname in enumerate (varnames ):
196226 var_data = models [varidx ]
197- # Rechunk to date=1 for efficient per-slice downstream operations (preserve spatial chunks)
198- if hasattr (var_data , 'rechunk' ):
199- var_data = var_data .rechunk ({0 : 1 })
200227 keys_vars [varname ] = xr .DataArray (var_data , coords = coords )
201228 model = xr .Dataset ({** keys_vars })
202229 del models
0 commit comments