Skip to content

Commit 0bd9436

Browse files
Unify and cleanup code, make Python 3.14 ready
1 parent a3d4b22 commit 0bd9436

6 files changed

Lines changed: 335 additions & 193 deletions

File tree

insardev/insardev/Batch.py

Lines changed: 130 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -77,14 +77,12 @@ def _apply_velocity_block(data_block, times_years, min_valid, device):
7777
Returns
7878
-------
7979
np.ndarray
80-
3D array (2, chunk_y, chunk_x) where [0] is velocity, [1] is intercept
80+
3D array (2, chunk_y, chunk_x) where [0] is velocity, [1] is RMSE
8181
"""
8282
import numpy as np
83-
# Convert times_years back to numpy array (dask serialization may convert to list)
8483
times_years = np.asarray(times_years, dtype=np.float32)
85-
vel, intercept = Batch._velocity_torch(data_block, times_years, min_valid=min_valid, device=device)
86-
# Stack velocity and intercept along a new first dimension
87-
return np.stack([vel, intercept], axis=0).astype(np.float32)
84+
vel, rmse = Batch._velocity_torch(data_block, times_years, min_valid=min_valid, device=device)
85+
return np.stack([vel, rmse], axis=0).astype(np.float32)
8886

8987

9088
class Batch(BatchCore):
@@ -423,58 +421,63 @@ def _velocity_torch(data, times_years, min_valid=3, device='auto', debug=False):
423421
except Exception:
424422
coeffs[px] = float('nan')
425423

426-
intercept = coeffs[:, 0] # c0
427424
velocity = coeffs[:, 1] # c1 = velocity (unbiased)
428425

426+
# Compute RMSE: residuals = y - A @ coeffs
427+
# A @ coeffs per pixel: sum over basis functions
428+
predicted = torch.zeros_like(y_filled)
429+
for i, c in enumerate(cols):
430+
predicted += c * coeffs[:, i].unsqueeze(0) # (n_times, n_pixels)
431+
residuals = (y_filled - predicted) * w # zero out NaN positions
432+
rmse = torch.sqrt((residuals ** 2).sum(dim=0) / valid_count.clamp(min=1))
433+
429434
# Mask pixels with insufficient valid points
430435
valid_mask = valid_count >= min_valid
431436
velocity = torch.where(valid_mask, velocity, torch.tensor(float('nan'), device=dev))
432-
intercept = torch.where(valid_mask, intercept, torch.tensor(float('nan'), device=dev))
437+
rmse = torch.where(valid_mask, rmse, torch.tensor(float('nan'), device=dev))
433438

434439
# Reshape back
435440
vel_np = velocity.cpu().numpy()
436-
int_np = intercept.cpu().numpy()
441+
rmse_np = rmse.cpu().numpy()
437442
if len(original_shape) == 3:
438443
vel_np = vel_np.reshape(original_shape[1], original_shape[2])
439-
int_np = int_np.reshape(original_shape[1], original_shape[2])
444+
rmse_np = rmse_np.reshape(original_shape[1], original_shape[2])
440445

441446
# Cleanup GPU memory
442447
if dev.type == 'mps':
443448
torch.mps.empty_cache()
444449
elif dev.type == 'cuda':
445450
torch.cuda.empty_cache()
446451

447-
return vel_np, int_np
452+
return vel_np, rmse_np
448453

449-
def velocity(self, min_valid=5, device='auto', debug=False) -> tuple["Batch", "Batch"]:
454+
def velocity(self, min_valid=5, device='auto', debug=False) -> "Batches":
450455
"""
451-
Compute velocity (linear trend) and intercept from time series.
456+
Compute velocity and RMSE from time series.
452457
453-
Calculates the slope per year and intercept for each pixel using linear
454-
regression on the 'date' dimension. Uses PyTorch for GPU acceleration.
458+
Harmonic regression (linear + seasonal) per pixel on the 'date' dimension.
459+
Uses PyTorch for GPU acceleration.
455460
456461
Parameters
457462
----------
458463
min_valid : int, optional
459-
Minimum number of valid (non-NaN) data points required to compute
460-
velocity. Pixels with fewer valid points will be set to NaN.
461-
Default is 3.
464+
Minimum number of valid (non-NaN) data points required.
465+
Default is 5.
462466
device : str, optional
463467
PyTorch device: 'auto' (default), 'cuda', 'mps', or 'cpu'.
464-
'auto' uses GPU if Dask client has resources={'gpu': 1}.
465468
debug : bool, optional
466469
Print debug information. Default False.
467470
468471
Returns
469472
-------
470-
tuple[Batch, Batch]
471-
(velocity, intercept) - velocity is slope per year, intercept is
472-
the y-value at t=0 (first date). Both are lazy Batch objects.
473+
Batches[Batch, Batch]
474+
(velocity, rmse) — velocity is slope per year, RMSE is residual
475+
root-mean-square error. Both are lazy Batch objects.
473476
474477
Examples
475478
--------
476-
>>> displacement = stack.lstsq(detrend, corr)
477-
>>> velocity, intercept = displacement.velocity()
479+
>>> velocity, rmse = displacement.velocity()
480+
>>> vel, rmse = detrend0.velocity().displacement_los(transform).compute()
478481
"""
479482
import dask
480483
import dask.array as da
@@ -497,10 +500,10 @@ def velocity(self, min_valid=5, device='auto', debug=False) -> tuple["Batch", "B
497500
crs = self.crs
498501

499502
vel_results = {}
500-
int_results = {}
503+
rmse_results = {}
501504
for key, ds in self.items():
502505
vel_vars = {}
503-
int_vars = {}
506+
rmse_vars = {}
504507
# Filter for spatial variables (with y, x dims) - excludes converted attributes
505508
for var in [v for v in ds.data_vars
506509
if 'y' in ds[v].dims and 'x' in ds[v].dims]:
@@ -542,33 +545,28 @@ def process_block(data_block):
542545
chunks=(2,) + data_dask.chunks[1:],
543546
)
544547

545-
# Unpack velocity (index 0) and intercept (index 1)
546548
vel_da = xr.DataArray(
547-
result_dask[0],
548-
dims=['y', 'x'],
549+
result_dask[0], dims=['y', 'x'],
549550
coords={'y': data_arr.y, 'x': data_arr.x}
550551
)
551-
int_da = xr.DataArray(
552-
result_dask[1],
553-
dims=['y', 'x'],
552+
rmse_da = xr.DataArray(
553+
result_dask[1], dims=['y', 'x'],
554554
coords={'y': data_arr.y, 'x': data_arr.x}
555555
)
556-
557556
vel_vars[var] = vel_da
558-
int_vars[var] = int_da
557+
rmse_vars[var] = rmse_da
559558

560559
vel_ds = xr.Dataset(vel_vars)
561560
vel_ds.attrs = ds.attrs
562-
int_ds = xr.Dataset(int_vars)
563-
int_ds.attrs = ds.attrs
564-
# Preserve CRS
561+
rmse_ds = xr.Dataset(rmse_vars)
562+
rmse_ds.attrs = ds.attrs
565563
if crs is not None:
566564
vel_ds = vel_ds.rio.write_crs(crs)
567-
int_ds = int_ds.rio.write_crs(crs)
565+
rmse_ds = rmse_ds.rio.write_crs(crs)
568566
vel_results[key] = vel_ds
569-
int_results[key] = int_ds
567+
rmse_results[key] = rmse_ds
570568

571-
return Batch(vel_results), Batch(int_results)
569+
return Batches((Batch(vel_results), Batch(rmse_results)))
572570

573571
def incidence(self) -> "Batch":
574572
"""Compute incidence angle from azi, rng, ele, and radar geometry parameters.
@@ -1752,7 +1750,7 @@ def _velocity_block(data_block, weight_block=None,
17521750
vel_results[burst_id] = vel_ds
17531751
rmse_results[burst_id] = rmse_ds
17541752

1755-
return Batch(vel_results), Batch(rmse_results)
1753+
return Batches((Batch(vel_results), Batch(rmse_results)))
17561754

17571755
def backscatter(self, *args, **kwargs):
17581756
"""
@@ -1781,6 +1779,43 @@ def conj(self, **kwargs):
17811779
"""intfs.iexp().conj() for np.exp(-1j * intfs)"""
17821780
return self.map_da(lambda da: xr.ufuncs.conj(da), **kwargs)
17831781

1782+
def pairs(self, pairs):
1783+
"""Select date pairs from per-date data, returning ref and rep stacks.
1784+
1785+
Parameters
1786+
----------
1787+
pairs : array-like (n_pairs, 2)
1788+
Pairs as [[ref_date, rep_date], ...]. Dates as datetime64 or indices.
1789+
1790+
Returns
1791+
-------
1792+
tuple (ref, rep)
1793+
Two BatchComplex with 'pair' dimension instead of 'date'.
1794+
"""
1795+
import numpy as np
1796+
pairs = np.asarray(pairs)
1797+
ref_dates = pairs[:, 0]
1798+
rep_dates = pairs[:, 1]
1799+
1800+
# Map dates to integer indices (match by day to handle precision differences)
1801+
key0 = list(self.keys())[0]
1802+
date_coords = self[key0].coords['date'].values
1803+
# Truncate to day precision for matching
1804+
date_days = np.array(date_coords, dtype='datetime64[D]')
1805+
date_to_idx = {d: i for i, d in enumerate(date_days)}
1806+
ref_idx = [date_to_idx[np.datetime64(d, 'D')] for d in ref_dates]
1807+
rep_idx = [date_to_idx[np.datetime64(d, 'D')] for d in rep_dates]
1808+
1809+
# Select, rename date→pair, and assign pair coords matching the caller
1810+
n_pairs = len(ref_idx)
1811+
pair_coords = np.arange(n_pairs)
1812+
screen_ref = self.isel(date=ref_idx).rename(date='pair').map(
1813+
lambda ds: ds.assign_coords(pair=pair_coords))
1814+
screen_rep = self.isel(date=rep_idx).rename(date='pair').map(
1815+
lambda ds: ds.assign_coords(pair=pair_coords))
1816+
1817+
return screen_ref, screen_rep
1818+
17841819
def lstsq_baseline(self, weight=None, baseline='BPR', stride=1, debug=False):
17851820
"""
17861821
Decompose per-pair complex trend into network-consistent per-date model.
@@ -1824,13 +1859,21 @@ def lstsq_baseline(self, weight=None, baseline='BPR', stride=1, debug=False):
18241859
has_bpr = baseline is not None and baseline in ds.coords
18251860
bpr_values = ds.coords[baseline].values.astype(np.float64) if has_bpr else None
18261861

1862+
# Compute unique dates (same logic as lstsq_baseline_array)
1863+
ns_per_day = 86400 * 1e9
1864+
ref_days = ref_values.astype(np.float64) / ns_per_day
1865+
rep_days = rep_values.astype(np.float64) / ns_per_day
1866+
unique_days = np.unique(np.concatenate([ref_days, rep_days]))
1867+
n_dates = len(unique_days)
1868+
# Map unique_days back to datetime64
1869+
date_coords = (unique_days * ns_per_day).astype('datetime64[ns]')
1870+
18271871
weight_ds = weight[key] if weight is not None else None
18281872

18291873
result_ds = {}
18301874
for pol in pols:
18311875
data_da = ds[pol]
18321876
data_dask = data_da.data
1833-
n_pairs = data_dask.shape[0]
18341877

18351878
weight_da = weight_ds[pol] if weight_ds is not None else None
18361879

@@ -1846,26 +1889,31 @@ def _block(data_block, weight_block=None,
18461889
if weight_da is not None:
18471890
weight_dask = weight_da.data
18481891
result_dask = da.blockwise(
1849-
_block, 'pyx',
1892+
_block, 'dyx',
18501893
data_dask, 'pyx',
18511894
weight_dask, 'pyx',
1852-
new_axes={'p': n_pairs},
1895+
new_axes={'d': n_dates},
18531896
concatenate=True,
18541897
dtype=np.complex64,
18551898
meta=np.empty((0, 0, 0), dtype=np.complex64),
18561899
)
18571900
else:
18581901
result_dask = da.blockwise(
1859-
_block, 'pyx',
1902+
_block, 'dyx',
18601903
data_dask, 'pyx',
1861-
new_axes={'p': n_pairs},
1904+
new_axes={'d': n_dates},
18621905
concatenate=True,
18631906
dtype=np.complex64,
18641907
meta=np.empty((0, 0, 0), dtype=np.complex64),
18651908
)
18661909

1910+
# Per-date output with date coordinates
18671911
result_ds[pol] = xr.DataArray(
1868-
result_dask, dims=data_da.dims, coords=data_da.coords
1912+
result_dask,
1913+
dims=['date', 'y', 'x'],
1914+
coords={'date': date_coords,
1915+
'y': data_da.coords['y'],
1916+
'x': data_da.coords['x']}
18691917
)
18701918

18711919
result[key] = xr.Dataset(result_ds, attrs=ds.attrs)
@@ -2100,6 +2148,23 @@ def goldstein(self, corr: BatchUnit, window: int | dict[str, int] = 32, threshol
21002148
return type(self)(result)
21012149

21022150

2151+
def _subtract_date_from_pair(first, second):
2152+
"""Subtract per-date atmospheric screens from per-pair data.
2153+
2154+
Uses BatchComplex.pairs() to select ref/rep screens,
2155+
then: result = data * conj(screen_ref) * screen_rep
2156+
"""
2157+
import numpy as np
2158+
2159+
key0 = list(first.keys())[0]
2160+
ref_dates = first[key0].coords['ref'].values
2161+
rep_dates = first[key0].coords['rep'].values
2162+
pairs = np.column_stack([ref_dates, rep_dates])
2163+
2164+
screen_ref, screen_rep = second.pairs(pairs)
2165+
return first * screen_ref.conj() * screen_rep
2166+
2167+
21032168
class Batches(tuple):
21042169
"""
21052170
A tuple-like container for multiple Batch objects that allows chained operations.
@@ -2164,7 +2229,7 @@ def correlation(self) -> 'BatchUnit | None':
21642229

21652230
def snapshot(self, store: str | None = None, storage_options: dict[str, str] | None = None,
21662231
caption: str | None = None,
2167-
n_chunks: int = 1, debug: bool = False, **kwargs):
2232+
debug: bool = False, **kwargs):
21682233
"""Save or open a Batches snapshot.
21692234
21702235
When called on a Batches with data, saves all batches to Zarr store.
@@ -2178,8 +2243,6 @@ def snapshot(self, store: str | None = None, storage_options: dict[str, str] | N
21782243
Storage options for cloud stores.
21792244
caption : str, optional
21802245
Progress bar caption.
2181-
n_chunks : int
2182-
Spatial chunks per worker per batch. Default 4.
21832246
debug : bool
21842247
Print debug information.
21852248
@@ -2200,11 +2263,11 @@ def snapshot(self, store: str | None = None, storage_options: dict[str, str] | N
22002263
if len(self) == 0:
22012264
result = utils_io.snapshot(store=store, storage_options=storage_options,
22022265
caption=caption or 'Opening...',
2203-
n_chunks=n_chunks, debug=debug)
2266+
debug=debug)
22042267
else:
22052268
result = utils_io.snapshot(*self, store=store, storage_options=storage_options,
22062269
caption=caption or 'Snapshotting...',
2207-
n_chunks=n_chunks, debug=debug, wrapper=Batches)
2270+
debug=debug, wrapper=Batches)
22082271

22092272
if isinstance(result, Batches):
22102273
return result
@@ -3070,7 +3133,20 @@ def subtract(self):
30703133
first = self[0]
30713134
second = self[second_idx]
30723135

3073-
if isinstance(first, BatchComplex):
3136+
# Check if second is per-date (from lstsq_baseline) and first is per-pair
3137+
is_date_to_pair = False
3138+
for key in first.keys():
3139+
first_ds = first[key]
3140+
second_ds = second[key]
3141+
first_pol = [v for v in first_ds.data_vars if 'y' in first_ds[v].dims][0]
3142+
second_pol = [v for v in second_ds.data_vars if 'y' in second_ds[v].dims][0]
3143+
if 'pair' in first_ds[first_pol].dims and 'date' in second_ds[second_pol].dims:
3144+
is_date_to_pair = True
3145+
break
3146+
3147+
if is_date_to_pair:
3148+
result = _subtract_date_from_pair(first, second)
3149+
elif isinstance(first, BatchComplex):
30743150
result = first * second.conj()
30753151
else:
30763152
result = first - second
@@ -3325,11 +3401,9 @@ def velocity(self, max_refine=3, **kwargs):
33253401
weight = self[1] if len(self) >= 2 and isinstance(self[1], BatchUnit) else None
33263402

33273403
if isinstance(phase, BatchComplex):
3328-
vel, rmse = phase.velocity(weight=weight, max_refine=max_refine, **kwargs)
3329-
return Batches((vel, rmse))
3404+
return phase.velocity(weight=weight, max_refine=max_refine, **kwargs)
33303405
else:
3331-
vel, intercept = phase.velocity(**kwargs)
3332-
return Batches((vel, intercept))
3406+
return phase.velocity(**kwargs)
33333407

33343408
def rmse(self, solution):
33353409
"""RMSE of phase vs solution, using correlation weight if present.

insardev/insardev/BatchCore.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3438,9 +3438,6 @@ def chunk1d(self, budget=None, chunks=-1, p2p=False):
34383438
var_chunks = {'y': optimal['y'], 'x': optimal['x']}
34393439
rechunked = arr.chunk(var_chunks)
34403440
if hasattr(rechunked.data, 'dask'):
3441-
# Full fusion only when graph has enough layers for linear chains.
3442-
# Rechunk-only graphs (≤3 layers) have no fusible chains —
3443-
# skip expensive ensure_dict + fuse_linear (22s on 1M keys, 0% reduction).
34443441
n_layers = len(rechunked.data.__dask_graph__().layers)
34453442
with dask.config.set({'optimization.fuse.active': n_layers > 3}):
34463443
(rechunked.data,) = dask.optimize(rechunked.data)
@@ -3753,8 +3750,8 @@ def downsample(self, new_spacing: tuple[float, float] | float | int, debug: bool
37533750
return result
37543751

37553752
def save(self, store: str, storage_options: dict[str, str] | None = None,
3756-
caption: str | None = 'Saving...', n_chunks: int = 1, debug=False):
3757-
return utils_io.save(self, store=store, storage_options=storage_options, caption=caption, n_chunks=n_chunks, debug=debug)
3753+
caption: str | None = 'Saving...', debug=False):
3754+
return utils_io.save(self, store=store, storage_options=storage_options, caption=caption, debug=debug)
37583755

37593756
def open(self, store: str, storage_options: dict[str, str] | None = None, n_jobs: int = -1, debug=False):
37603757
data = utils_io.open(store=store, storage_options=storage_options, n_jobs=n_jobs, debug=debug)
@@ -3763,11 +3760,11 @@ def open(self, store: str, storage_options: dict[str, str] | None = None, n_jobs
37633760
return data
37643761

37653762
def snapshot(self, store: str | None = None, storage_options: dict[str, str] | None = None,
3766-
caption: str | None = 'Snapshotting...', n_chunks: int = 1,
3763+
caption: str | None = 'Snapshotting...',
37673764
debug=False, **kwargs):
37683765
# Only save if this batch has data; otherwise just open existing store
37693766
if len(self) > 0:
3770-
utils_io.save(self, store=store, storage_options=storage_options, caption=caption, n_chunks=n_chunks,
3767+
utils_io.save(self, store=store, storage_options=storage_options, caption=caption,
37713768
debug=debug)
37723769
return utils_io.open(store=store, storage_options=storage_options,
37733770
n_jobs=-1, debug=debug)

0 commit comments

Comments
 (0)