Skip to content

Commit acfe5ee

Browse files
Dask refactoring. Code cleanup.
1 parent 97eed20 commit acfe5ee

17 files changed

Lines changed: 3164 additions & 1537 deletions

File tree

insardev/insardev/Batch.py

Lines changed: 273 additions & 215 deletions
Large diffs are not rendered by default.

insardev/insardev/BatchCore.py

Lines changed: 1249 additions & 655 deletions
Large diffs are not rendered by default.

insardev/insardev/Stack.py

Lines changed: 101 additions & 69 deletions
Large diffs are not rendered by default.

insardev/insardev/Stack_ps.py

Lines changed: 10 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,6 @@ def psfunction(self, device='auto', allow_rechunk=True, debug=False):
149149
# Disable automatic rechunking (manual chunk control)
150150
psf = stack.chunk({'y': 2048, 'x': 2048}).psfunction(allow_rechunk=False)
151151
"""
152-
import dask
153152
import dask.array
154153
import numpy as np
155154
import torch
@@ -184,62 +183,11 @@ def psfunction(self, device='auto', allow_rechunk=True, debug=False):
184183
if not isinstance(slc_data.data, dask.array.Array):
185184
slc_data = slc_data.chunk({'y': 512, 'x': 512})
186185

187-
# Save original spatial chunks for restoring output
188-
original_y_chunks = slc_data.chunks[1] if len(slc_data.chunks) > 1 else None
189-
original_x_chunks = slc_data.chunks[2] if len(slc_data.chunks) > 2 else None
190-
191186
if debug:
192187
print(f'DEBUG: psfunction for {key}: shape={slc_data.shape}, chunks={slc_data.chunks}')
193188

194-
# Calculate target chunks for memory efficiency
195-
from .utils_dask import compute_aligned_chunks_3d
196-
dask_chunk_bytes = dask_chunk_mb * 1024 * 1024
197-
element_bytes = slc_data.dtype.itemsize
198-
target_chunks = compute_aligned_chunks_3d(
199-
slc_data.shape, slc_data.chunks, dask_chunk_bytes, element_bytes,
200-
min_chunk=256, keep_first_dim=True
201-
)
202-
203-
# Get current spatial chunk sizes (use first chunk as representative)
204-
orig_y = slc_data.chunks[1][0] if len(slc_data.chunks) > 1 and slc_data.chunks[1] else slc_data.shape[1]
205-
orig_x = slc_data.chunks[2][0] if len(slc_data.chunks) > 2 and slc_data.chunks[2] else slc_data.shape[2]
206-
207-
# Calculate chunk memory sizes
208-
n_dates = slc_data.shape[0]
209-
orig_chunk_mb = n_dates * orig_y * orig_x * element_bytes / (1024 * 1024)
210-
# target_chunks is tuple of tuples: ((n,), (y1, y2, ...), (x1, x2, ...))
211-
target_y = max(target_chunks[1]) if target_chunks[1] else orig_y
212-
target_x = max(target_chunks[2]) if target_chunks[2] else orig_x
213-
target_chunk_mb = n_dates * target_y * target_x * element_bytes / (1024 * 1024)
214-
215-
# Determine if rechunking is needed (only if original is larger than target)
216-
needs_rechunk = (orig_y > target_y) or (orig_x > target_x)
217-
218-
# Print NOTE about chunks (only once for first burst)
219-
if not note_printed:
220-
chunk_size_note = f"dask.config['array.chunk-size']={dask_chunk_mb} MB"
221-
if needs_rechunk:
222-
if allow_rechunk:
223-
print(f'NOTE psfunction: rechunking from ({n_dates}, {orig_y}, {orig_x}) [{orig_chunk_mb:.0f} MB] '
224-
f'to ({n_dates}, {target_chunks[1]}, {target_chunks[2]}) [{target_chunk_mb:.0f} MB] for {chunk_size_note}')
225-
else:
226-
print(f'NOTE psfunction: chunks ({n_dates}, {orig_y}, {orig_x}) [{orig_chunk_mb:.0f} MB] exceed '
227-
f'{chunk_size_note}, recommended: ({n_dates}, {target_chunks[1]}, {target_chunks[2]}) [{target_chunk_mb:.0f} MB]. '
228-
f'Use allow_rechunk=True or .chunk() manually.')
229-
else:
230-
# Chunks already fit - just confirm, no "optimal" claim
231-
print(f'NOTE psfunction: chunks ({n_dates}, {orig_y}, {orig_x}) [{orig_chunk_mb:.0f} MB] '
232-
f'fit {chunk_size_note}')
233-
note_printed = True
234-
235-
# Apply rechunking if needed and allowed (pass exact chunk tuples)
236-
if needs_rechunk and allow_rechunk:
237-
slc_data = slc_data.chunk({'date': -1, 'y': target_chunks[1], 'x': target_chunks[2]})
238-
if debug:
239-
print(f'DEBUG: after rechunk: chunks={slc_data.chunks}')
240-
else:
241-
# Just ensure date is single chunk
242-
slc_data = slc_data.chunk({'date': -1})
189+
# Merge dates dim, keep input spatial chunks as-is.
190+
slc_data = slc_data.chunk({'date': -1})
243191

244192
# Create wrapper that captures device and debug
245193
def make_wrapper(dev, dbg):
@@ -262,24 +210,16 @@ def process_wrapper(slc_chunk):
262210
# Use xr.apply_ufunc with dask='parallelized' for lazy execution
263211
# Core dim is 'date' (reduction), chunked dims are y, x
264212
# Note: input_core_dims moves 'date' to last axis, wrapper transposes back
265-
# Use GPU annotation to prevent MPS command buffer conflicts
266213
# Provide explicit meta to avoid ComplexWarning when dask infers
267214
# output type from complex input (we intentionally convert to real)
268-
with dask.annotate(resources={'gpu': 1} if device != 'cpu' else {}):
269-
psf_da = xr.apply_ufunc(
270-
wrapper,
271-
slc_data,
272-
input_core_dims=[['date']],
273-
output_core_dims=[[]],
274-
dask='parallelized',
275-
dask_gufunc_kwargs={'meta': np.array((), dtype=np.float32)},
276-
)
277-
278-
# Restore original spatial chunks if we rechunked (pass full tuples)
279-
if allow_rechunk and original_y_chunks is not None:
280-
psf_da = psf_da.chunk({'y': original_y_chunks, 'x': original_x_chunks})
281-
if debug:
282-
print(f'DEBUG: restored output chunks: {psf_da.chunks}')
215+
psf_da = xr.apply_ufunc(
216+
wrapper,
217+
slc_data,
218+
input_core_dims=[['date']],
219+
output_core_dims=[[]],
220+
dask='parallelized',
221+
dask_gufunc_kwargs={'meta': np.array((), dtype=np.float32)},
222+
)
283223

284224
# Assign name to match SLC variable
285225
psf_da.name = var_name

insardev/insardev/Stack_stl.py

Lines changed: 61 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -151,41 +151,71 @@ def _stl(self, data, freq='W', periods=52, robust=False):
151151
n_dates_out = len(dt_periodic)
152152
n_dates_in = data.date.size
153153

154-
# Use rechunk2d for uniform chunk sizes based on memory
155-
# Multiplier accounts for STL internal memory (trend, seasonal, resid, weights, etc.)
156-
# Effective memory: 8 * n_dates_in * 16 bytes (complex128) per pixel
157-
from .utils_dask import rechunk2d
158-
mem_per_pixel = 8 * n_dates_in * 16 # complex128 = 16 bytes
159-
optimal = rechunk2d((data.y.size, data.x.size), element_bytes=mem_per_pixel)
160-
chunks_y, chunks_x = optimal['y'], optimal['x']
161-
162-
# Rechunk: all dates together (-1), auto-chunked y,x
163-
first_dim = data.dims[0]
164-
data = data.chunk({first_dim: -1, 'y': chunks_y, 'x': chunks_x})
165-
166-
# Use blockwise to avoid embedding large arrays in the graph
167-
def process_block(data_block):
168-
# data_block: (n_dates, y_chunk, x_chunk)
169-
# transpose to (y, x, n_dates) for vectorized STL
170-
data_transposed = data_block.transpose(1, 2, 0)
154+
# No rechunk on dim 0 — pass per-date delayed lists to kernel.
155+
data_dask = data.data
156+
157+
y_chunks = data_dask.chunks[1]
158+
x_chunks = data_dask.chunks[2]
159+
y_breaks = [0] + list(np.cumsum(y_chunks))
160+
x_breaks = [0] + list(np.cumsum(x_chunks))
161+
162+
def process_chunks(data_chunks):
163+
import math
164+
from .utils_dask import get_dask_chunk_size_mb
165+
chunks = [np.asarray(c) for c in data_chunks]
166+
ny, nx = chunks[0].shape[1], chunks[0].shape[2]
167+
n_dates_in_local = sum(c.shape[0] for c in chunks)
168+
result = np.empty((3, n_dates_out, ny, nx), dtype=np.float32)
171169
vec_stl = np.vectorize(
172170
lambda ts: utils_stl.stl1d(ts, dt, dt_periodic, periods, robust),
173171
signature='(n)->(m),(m),(m)'
174172
)
175-
# result: (3, y, x, n_dates_out) after asarray
176-
block = np.asarray(vec_stl(data_transposed))
177-
del vec_stl, data_transposed
178-
# transpose to (3, n_dates_out, y, x)
179-
return block.transpose(0, 3, 1, 2).astype(np.float32)
180-
181-
data_dask = data.data
182-
models = dask.array.map_blocks(
183-
process_block, data_dask,
184-
dtype=np.float32,
185-
drop_axis=0,
186-
new_axis=[0, 1],
187-
chunks=(3, n_dates_out) + data_dask.chunks[1:],
188-
)
173+
# Calculate sub-tile size from dask chunk budget.
174+
# Per sub-tile memory: input (n_dates_in × sub_pixels × 4) + output (3 × n_dates_out × sub_pixels × 4)
175+
per_pixel_bytes = (n_dates_in_local + 3 * n_dates_out) * 4
176+
budget_bytes = int(get_dask_chunk_size_mb() * 1024 * 1024)
177+
max_sub_pixels = max(256, budget_bytes // max(1, per_pixel_bytes))
178+
sub_side = int(math.sqrt(max_sub_pixels))
179+
sub_h = min(sub_side, ny)
180+
sub_w = min(sub_side, nx)
181+
for ty0 in range(0, ny, sub_h):
182+
ty1 = min(ty0 + sub_h, ny)
183+
for tx0 in range(0, nx, sub_w):
184+
tx1 = min(tx0 + sub_w, nx)
185+
if len(chunks) == 1:
186+
tile = chunks[0][:, ty0:ty1, tx0:tx1]
187+
else:
188+
tile = np.concatenate(
189+
[c[:, ty0:ty1, tx0:tx1] for c in chunks], axis=0
190+
)
191+
# (n_dates, sub_h, sub_w) -> (sub_h, sub_w, n_dates)
192+
tile_t = tile.transpose(1, 2, 0)
193+
del tile
194+
# result: (3, sub_h, sub_w, n_dates_out) after asarray
195+
block = np.asarray(vec_stl(tile_t))
196+
del tile_t
197+
# (3, sub_h, sub_w, n_dates_out) -> (3, n_dates_out, sub_h, sub_w)
198+
result[:, :, ty0:ty1, tx0:tx1] = block.transpose(0, 3, 1, 2)
199+
del block
200+
del vec_stl
201+
return result
202+
203+
blocks_rows = []
204+
for bj in range(len(y_breaks) - 1):
205+
y0, y1 = y_breaks[bj], y_breaks[bj + 1]
206+
blocks_row = []
207+
for bk in range(len(x_breaks) - 1):
208+
x0, x1 = x_breaks[bk], x_breaks[bk + 1]
209+
td_list = data_dask[:, y0:y1, x0:x1].to_delayed().ravel().tolist()
210+
block = dask.array.from_delayed(
211+
dask.delayed(process_chunks)(td_list),
212+
shape=(3, n_dates_out, y1 - y0, x1 - x0),
213+
dtype=np.float32,
214+
)
215+
blocks_row.append(block)
216+
blocks_rows.append(dask.array.concatenate(blocks_row, axis=3))
217+
218+
models = dask.array.concatenate(blocks_rows, axis=2)
189219

190220
coords = {'date': dt_periodic.astype('datetime64[ns]'), 'y': data.y, 'x': data.x}
191221

@@ -194,9 +224,6 @@ def process_block(data_block):
194224
keys_vars = {}
195225
for varidx, varname in enumerate(varnames):
196226
var_data = models[varidx]
197-
# Rechunk to date=1 for efficient per-slice downstream operations (preserve spatial chunks)
198-
if hasattr(var_data, 'rechunk'):
199-
var_data = var_data.rechunk({0: 1})
200227
keys_vars[varname] = xr.DataArray(var_data, coords=coords)
201228
model = xr.Dataset({**keys_vars})
202229
del models

0 commit comments

Comments
 (0)