Skip to content

Commit 97eed20

Browse files
Refactor detrending functions
1 parent 954be39 commit 97eed20

6 files changed

Lines changed: 1047 additions & 1154 deletions

File tree

insardev/insardev/Batch.py

Lines changed: 126 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -966,6 +966,24 @@ def __init__(self, mapping: dict[str, xr.Dataset] | Stack | None = None, wrap: b
966966
def wrap(data):
967967
return np.mod(data + np.pi, 2 * np.pi) - np.pi
968968

969+
def trend1d(self, *args, **kwargs):
970+
raise TypeError(
971+
"trend1d() does not support wrapped phase (BatchWrap). "
972+
"Use BatchComplex for complex phase fitting, or unwrap first for real polynomial fitting."
973+
)
974+
975+
def trend2d(self, *args, **kwargs):
976+
raise TypeError(
977+
"trend2d() does not support wrapped phase (BatchWrap). "
978+
"Use BatchComplex for complex phase fitting, or unwrap first for real polynomial fitting."
979+
)
980+
981+
def trend1d_pairs(self, *args, **kwargs):
982+
raise TypeError(
983+
"trend1d_pairs() does not support wrapped phase (BatchWrap). "
984+
"Use BatchComplex for complex phase fitting, or unwrap first."
985+
)
986+
969987
def __add__(self, other: Batch):
970988
keys = self.keys()
971989
return type(self)({k: (self[k] + other[k] if k in other else self[k]) for k in keys})
@@ -1775,6 +1793,24 @@ class Batches(tuple):
17751793
def __new__(cls, batches=()):
17761794
return super().__new__(cls, batches)
17771795

1796+
@staticmethod
1797+
def _preserve_nonspatial(source, target):
1798+
"""Copy non-spatial variables (e.g. BPR) from source to target batch."""
1799+
import dask.array as da
1800+
for key in source:
1801+
src_ds = source[key]
1802+
tgt_ds = target[key]
1803+
extra = {}
1804+
for v in src_ds.data_vars:
1805+
if v not in tgt_ds.data_vars:
1806+
var = src_ds[v]
1807+
if not isinstance(var.data, da.Array):
1808+
var = var.chunk()
1809+
extra[v] = var
1810+
if extra:
1811+
target[key] = tgt_ds.assign(extra)
1812+
return target
1813+
17781814
def snapshot(self, store: str | None = None, storage_options: dict[str, str] | None = None,
17791815
caption: str | None = None, allow_rechunk: bool = False,
17801816
n_jobs: int = 1, debug: bool = False):
@@ -2334,11 +2370,14 @@ def unwrap1d(self, device='auto', debug=False, **kwargs):
23342370
# Delegate to BatchWrap.unwrap1d
23352371
return phase.unwrap1d(weight=weight, device=device, debug=debug, **kwargs)
23362372

2337-
def trend2d(self, transform, degree=1, device='auto', debug=False):
2373+
def detrend2d(self, transform, degree=1, device='auto', debug=False):
23382374
"""
2339-
Compute 2D polynomial trend (ramp) from phase.
2375+
Detrend 2D polynomial trend and return Batches with detrended data.
23402376
2341-
Expects Batches with [Batch/BatchWrap (phase), BatchUnit (weight, optional)].
2377+
Expects Batches from interferogram(): [BatchComplex (intf), BatchUnit (corr)].
2378+
2379+
For complex input (BatchComplex): detrended = intf * trend.conj()
2380+
For real input (Batch): detrended = phase - trend
23422381
23432382
Parameters
23442383
----------
@@ -2353,44 +2392,53 @@ def trend2d(self, transform, degree=1, device='auto', debug=False):
23532392
23542393
Returns
23552394
-------
2356-
Batch
2357-
Trend surface.
2395+
Batches
2396+
Batches with [detrended_phase, weight] preserving original types.
23582397
23592398
Examples
23602399
--------
2361-
>>> phase, corr = stack.pairs(baseline.tolist()).phasediff(wavelength=30).angle()
2362-
>>> trend = Batches([phase, corr]).trend2d(stack.transform())
2363-
>>> detrended = phase - trend
2400+
>>> # Complex interferogram detrending (chained)
2401+
>>> intf, corr = stack.pairs(baseline).interferogram(wavelength=30).detrend2d(transform)
23642402
"""
23652403
if len(self) < 1:
2366-
raise ValueError("trend2d() requires Batches with at least 1 element: [phase]")
2404+
raise ValueError("detrend2d() requires Batches with at least 1 element: [phase]")
23672405

23682406
phase = self[0]
23692407
weight = self[1] if len(self) >= 2 and isinstance(self[1], BatchUnit) else None
23702408

2371-
if not isinstance(phase, (Batch, BatchWrap, BatchComplex)):
2372-
raise TypeError(f"First element must be Batch, BatchWrap, or BatchComplex, got {type(phase).__name__}")
2409+
if not isinstance(phase, (Batch, BatchComplex)):
2410+
raise TypeError(f"First element must be Batch or BatchComplex, got {type(phase).__name__}")
23732411

2374-
# Delegate to BatchCore.trend2d
2375-
return phase.trend2d(transform, weight=weight, degree=degree, device=device, debug=debug)
2412+
trend = phase.trend2d(transform, weight=weight, degree=degree, device=device, debug=debug)
23762413

2377-
def detrend2d(self, transform, degree=1, device='auto', debug=False):
2414+
if isinstance(phase, BatchComplex):
2415+
detrended = phase * trend.conj()
2416+
else:
2417+
detrended = phase - trend
2418+
2419+
# Preserve non-spatial variables (e.g. BPR) that may be dropped by arithmetic
2420+
detrended = Batches._preserve_nonspatial(phase, detrended)
2421+
2422+
# Rebuild Batches preserving all original elements except first
2423+
elements = [detrended] + list(self[1:])
2424+
return Batches(elements)
2425+
2426+
def detrend1d(self, baseline='BPR', degree=1, device='auto', debug=False):
23782427
"""
2379-
Detrend 2D polynomial trend and return Batches with detrended data.
2428+
Detrend 1D polynomial trend along perpendicular baseline and return Batches.
23802429
2381-
Expects Batches from interferogram(): [BatchComplex (intf), BatchUnit (corr)]
2382-
or from angle(): [BatchWrap (phase), BatchUnit (corr)].
2430+
Removes DEM residual phase proportional to perpendicular baseline at each pixel.
2431+
Mirrors detrend2d() pattern: auto-detects input type.
23832432
2384-
For complex input: detrended = intf * trend.conj()
2385-
For wrapped input: detrended = wrap(phase - trend)
2386-
For real input: detrended = phase - trend
2433+
For complex input (BatchComplex): unit-circle fitting, detrend multiplicatively (phase * trend.conj())
2434+
For real input (Batch): standard polynomial, subtract
23872435
23882436
Parameters
23892437
----------
2390-
transform : Batch
2391-
Coordinate transform from stack.transform() containing 'azi' and 'rng'.
2438+
baseline : str
2439+
Variable name to regress against (default 'BPR' for perpendicular baseline).
23922440
degree : int
2393-
Polynomial degree (1=plane, 2=quadratic). Default 1.
2441+
Polynomial degree (1=linear, 2=quadratic). Default 1.
23942442
device : str
23952443
PyTorch device: 'auto', 'cuda', 'mps', 'cpu'.
23962444
debug : bool
@@ -2403,27 +2451,31 @@ def detrend2d(self, transform, degree=1, device='auto', debug=False):
24032451
24042452
Examples
24052453
--------
2406-
>>> # Complex interferogram detrending (chained)
2407-
>>> intf, corr = stack.pairs(baseline).interferogram(wavelength=30).detrend2d(transform)
2408-
>>> # Wrapped phase detrending
2409-
>>> phase, corr = stack.pairs(baseline).phasediff(wavelength=30).angle().detrend2d(transform)
2454+
>>> # Complex interferogram: multiplicative detrend (like detrend2d)
2455+
>>> intf, corr = stack.pairs(baseline).interferogram(wavelength=30).detrend1d()
2456+
>>> # Full pipeline: detrend2d (spatial) then detrend1d (baseline)
2457+
>>> result = stack.pairs(bl).interferogram(wavelength=30).detrend2d(transform).detrend1d()
24102458
"""
24112459
if len(self) < 1:
2412-
raise ValueError("detrend2d() requires Batches with at least 1 element: [phase]")
2460+
raise ValueError("detrend1d() requires Batches with at least 1 element: [phase]")
24132461

24142462
phase = self[0]
24152463
weight = self[1] if len(self) >= 2 and isinstance(self[1], BatchUnit) else None
24162464

2417-
if not isinstance(phase, (Batch, BatchWrap, BatchComplex)):
2418-
raise TypeError(f"First element must be Batch, BatchWrap, or BatchComplex, got {type(phase).__name__}")
2465+
if not isinstance(phase, (Batch, BatchComplex)):
2466+
raise TypeError(f"First element must be Batch or BatchComplex, got {type(phase).__name__}")
24192467

2420-
trend = phase.trend2d(transform, weight=weight, degree=degree, device=device, debug=debug)
2468+
trend = phase.trend1d(weight=weight, baseline=baseline, degree=degree,
2469+
device=device, debug=debug)
24212470

24222471
if isinstance(phase, BatchComplex):
24232472
detrended = phase * trend.conj()
24242473
else:
24252474
detrended = phase - trend
24262475

2476+
# Preserve non-spatial variables (e.g. BPR) that may be dropped by arithmetic
2477+
detrended = Batches._preserve_nonspatial(phase, detrended)
2478+
24272479
# Rebuild Batches preserving all original elements except first
24282480
elements = [detrended] + list(self[1:])
24292481
return Batches(elements)
@@ -2465,93 +2517,76 @@ def lstsq(self, device='auto', cumsum=True, debug=False):
24652517
# Delegate to Batch.lstsq
24662518
return data.lstsq(weight=weight, device=device, cumsum=cumsum, debug=debug)
24672519

2468-
def regression1d_baseline(self, baseline='BPR', degree=1, wrap=False, iterations=1,
2469-
device='auto', debug=False):
2520+
def regression1d_baseline(self, *args, **kwargs):
2521+
raise NotImplementedError("Batches.regression1d_baseline() is removed. Use Batches.detrend1d() or Batch.trend1d() instead.")
2522+
2523+
def detrend1d_pairs(self, degree=0, days=90, count=None,
2524+
device='auto', debug=False):
24702525
"""
2471-
Fit 1D polynomial trend along perpendicular baseline at each (y, x) pixel.
2526+
Detrend 1D polynomial trend along temporal pairs and return Batches.
2527+
2528+
Removes per-date temporal trend from pair data. Mirrors detrend1d()/detrend2d()
2529+
pattern: auto-detects input type.
24722530
2473-
Expects Batches with [Batch/BatchWrap (phase), BatchUnit (weight, optional)].
2531+
For complex input (BatchComplex): unit-circle fitting, detrend multiplicatively (phase * trend.conj())
2532+
For real input (Batch): standard polynomial, subtract (phase - trend)
24742533
24752534
Parameters
24762535
----------
2477-
baseline : str
2478-
Variable name to regress against (default 'BPR' for perpendicular baseline).
24792536
degree : int
2480-
Polynomial degree (1=linear, 2=quadratic). Default 1.
2481-
wrap : bool
2482-
If True, use circular (sin/cos) fitting for wrapped phase.
2483-
iterations : int
2484-
Number of fitting iterations (default 1).
2537+
Polynomial degree (0=mean, 1=linear). Default 0.
2538+
days : int
2539+
Maximum time interval (±days symmetric window) to include.
2540+
Default 90. Controls temporal high-pass: only pairs within ±days
2541+
contribute to the per-date atmospheric estimate.
2542+
count : int or None
2543+
Maximum number of pairs per date to use. Default None (all pairs).
24852544
device : str
24862545
PyTorch device: 'auto', 'cuda', 'mps', 'cpu'.
24872546
debug : bool
24882547
Print diagnostic information.
24892548
24902549
Returns
24912550
-------
2492-
Batch
2493-
Fitted trend values, same shape as input data.
2551+
Batches
2552+
Batches with [detrended_phase, weight] preserving original types.
24942553
24952554
Examples
24962555
--------
2497-
>>> trend = Batches([intf, corr]).regression1d_baseline()
2498-
>>> detrended = intf - trend
2556+
>>> # Complex interferogram detrending (chained)
2557+
>>> intf, corr = stack.pairs(baseline).interferogram(wavelength=30).detrend1d_pairs()
2558+
>>> # Real detrending
2559+
>>> result = Batches([unwrapped, corr]).detrend1d_pairs(degree=0, days=100)
24992560
"""
25002561
if len(self) < 1:
2501-
raise ValueError("regression1d_baseline() requires Batches with at least 1 element: [data]")
2562+
raise ValueError("detrend1d_pairs() requires Batches with at least 1 element: [phase]")
25022563

2503-
data = self[0]
2564+
phase = self[0]
25042565
weight = self[1] if len(self) >= 2 and isinstance(self[1], BatchUnit) else None
25052566

2506-
# Delegate to BatchCore.regression1d_baseline
2507-
return data.regression1d_baseline(weight=weight, baseline=baseline, degree=degree,
2508-
wrap=wrap, iterations=iterations,
2509-
device=device, debug=debug)
2510-
2511-
def regression1d_pairs(self, degree=0, days=None, count=None, wrap=False, iterations=1,
2512-
device='auto', debug=False):
2513-
"""
2514-
Fit 1D polynomial trend along temporal pairs for each date.
2567+
if not isinstance(phase, (Batch, BatchComplex)):
2568+
raise TypeError(f"First element must be Batch or BatchComplex, got {type(phase).__name__}")
25152569

2516-
Expects Batches with [Batch/BatchWrap (phase), BatchUnit (weight, optional)].
2570+
trend = phase.trend1d_pairs(weight=weight, degree=degree, days=days, count=count,
2571+
device=device, debug=debug)
25172572

2518-
Parameters
2519-
----------
2520-
degree : int
2521-
Polynomial degree (0=mean, 1=linear). Default 0.
2522-
days : int or None
2523-
Maximum time interval (in days) to include. Default None (all pairs).
2524-
count : int or None
2525-
Maximum number of pairs per date to use. Default None (all pairs).
2526-
wrap : bool
2527-
If True, use circular (sin/cos) fitting for wrapped phase.
2528-
iterations : int
2529-
Number of fitting iterations (default 1).
2530-
device : str
2531-
PyTorch device: 'auto', 'cuda', 'mps', 'cpu'.
2532-
debug : bool
2533-
Print diagnostic information.
2573+
if isinstance(phase, BatchComplex):
2574+
detrended = phase * trend.conj()
2575+
else:
2576+
detrended = phase - trend
25342577

2535-
Returns
2536-
-------
2537-
Batch
2538-
Fitted trend values, same shape as input data.
2578+
# Preserve non-spatial variables (e.g. BPR) that may be dropped by arithmetic
2579+
detrended = Batches._preserve_nonspatial(phase, detrended)
25392580

2540-
Examples
2541-
--------
2542-
>>> trend = Batches([intf, corr]).regression1d_pairs(degree=0, days=100)
2543-
>>> detrended = intf - trend
2544-
"""
2545-
if len(self) < 1:
2546-
raise ValueError("regression1d_pairs() requires Batches with at least 1 element: [data]")
2581+
# Rebuild Batches preserving all original elements except first
2582+
elements = [detrended] + list(self[1:])
2583+
return Batches(elements)
25472584

2548-
data = self[0]
2549-
weight = self[1] if len(self) >= 2 and isinstance(self[1], BatchUnit) else None
2585+
def regression1d_pairs(self, *args, **kwargs):
2586+
raise NotImplementedError("Batches.regression1d_pairs() is removed. Use Batches.detrend1d_pairs() instead.")
25502587

2551-
# Delegate to BatchCore.regression1d_pairs
2552-
return data.regression1d_pairs(weight=weight, degree=degree, days=days, count=count,
2553-
wrap=wrap, iterations=iterations,
2554-
device=device, debug=debug)
2588+
def trend1d_pairs(self, *args, **kwargs):
2589+
raise NotImplementedError("Batches.trend1d_pairs() is removed. Use Batches.detrend1d_pairs() instead.")
25552590

25562591
def stl(self, freq='W', periods=52, robust=False):
25572592
"""

0 commit comments

Comments
 (0)