@@ -966,6 +966,24 @@ def __init__(self, mapping: dict[str, xr.Dataset] | Stack | None = None, wrap: b
966966 def wrap (data ):
967967 return np .mod (data + np .pi , 2 * np .pi ) - np .pi
968968
969+ def trend1d (self , * args , ** kwargs ):
970+ raise TypeError (
971+ "trend1d() does not support wrapped phase (BatchWrap). "
972+ "Use BatchComplex for complex phase fitting, or unwrap first for real polynomial fitting."
973+ )
974+
975+ def trend2d (self , * args , ** kwargs ):
976+ raise TypeError (
977+ "trend2d() does not support wrapped phase (BatchWrap). "
978+ "Use BatchComplex for complex phase fitting, or unwrap first for real polynomial fitting."
979+ )
980+
981+ def trend1d_pairs (self , * args , ** kwargs ):
982+ raise TypeError (
983+ "trend1d_pairs() does not support wrapped phase (BatchWrap). "
984+ "Use BatchComplex for complex phase fitting, or unwrap first."
985+ )
986+
969987 def __add__ (self , other : Batch ):
970988 keys = self .keys ()
971989 return type (self )({k : (self [k ] + other [k ] if k in other else self [k ]) for k in keys })
@@ -1775,6 +1793,24 @@ class Batches(tuple):
17751793 def __new__ (cls , batches = ()):
17761794 return super ().__new__ (cls , batches )
17771795
1796+ @staticmethod
1797+ def _preserve_nonspatial (source , target ):
1798+ """Copy non-spatial variables (e.g. BPR) from source to target batch."""
1799+ import dask .array as da
1800+ for key in source :
1801+ src_ds = source [key ]
1802+ tgt_ds = target [key ]
1803+ extra = {}
1804+ for v in src_ds .data_vars :
1805+ if v not in tgt_ds .data_vars :
1806+ var = src_ds [v ]
1807+ if not isinstance (var .data , da .Array ):
1808+ var = var .chunk ()
1809+ extra [v ] = var
1810+ if extra :
1811+ target [key ] = tgt_ds .assign (extra )
1812+ return target
1813+
17781814 def snapshot (self , store : str | None = None , storage_options : dict [str , str ] | None = None ,
17791815 caption : str | None = None , allow_rechunk : bool = False ,
17801816 n_jobs : int = 1 , debug : bool = False ):
@@ -2334,11 +2370,14 @@ def unwrap1d(self, device='auto', debug=False, **kwargs):
23342370 # Delegate to BatchWrap.unwrap1d
23352371 return phase .unwrap1d (weight = weight , device = device , debug = debug , ** kwargs )
23362372
2337- def trend2d (self , transform , degree = 1 , device = 'auto' , debug = False ):
2373+ def detrend2d (self , transform , degree = 1 , device = 'auto' , debug = False ):
23382374 """
2339- Compute 2D polynomial trend (ramp) from phase .
2375+ Detrend 2D polynomial trend and return Batches with detrended data .
23402376
2341- Expects Batches with [Batch/BatchWrap (phase), BatchUnit (weight, optional)].
2377+ Expects Batches from interferogram(): [BatchComplex (intf), BatchUnit (corr)].
2378+
2379+ For complex input (BatchComplex): detrended = intf * trend.conj()
2380+ For real input (Batch): detrended = phase - trend
23422381
23432382 Parameters
23442383 ----------
@@ -2353,44 +2392,53 @@ def trend2d(self, transform, degree=1, device='auto', debug=False):
23532392
23542393 Returns
23552394 -------
2356- Batch
2357- Trend surface .
2395+ Batches
2396+ Batches with [detrended_phase, weight] preserving original types .
23582397
23592398 Examples
23602399 --------
2361- >>> phase, corr = stack.pairs(baseline.tolist()).phasediff(wavelength=30).angle()
2362- >>> trend = Batches([phase, corr]).trend2d(stack.transform())
2363- >>> detrended = phase - trend
2400+ >>> # Complex interferogram detrending (chained)
2401+ >>> intf, corr = stack.pairs(baseline).interferogram(wavelength=30).detrend2d(transform)
23642402 """
23652403 if len (self ) < 1 :
2366- raise ValueError ("trend2d () requires Batches with at least 1 element: [phase]" )
2404+ raise ValueError ("detrend2d () requires Batches with at least 1 element: [phase]" )
23672405
23682406 phase = self [0 ]
23692407 weight = self [1 ] if len (self ) >= 2 and isinstance (self [1 ], BatchUnit ) else None
23702408
2371- if not isinstance (phase , (Batch , BatchWrap , BatchComplex )):
2372- raise TypeError (f"First element must be Batch, BatchWrap, or BatchComplex, got { type (phase ).__name__ } " )
2409+ if not isinstance (phase , (Batch , BatchComplex )):
2410+ raise TypeError (f"First element must be Batch or BatchComplex, got { type (phase ).__name__ } " )
23732411
2374- # Delegate to BatchCore.trend2d
2375- return phase .trend2d (transform , weight = weight , degree = degree , device = device , debug = debug )
2412+ trend = phase .trend2d (transform , weight = weight , degree = degree , device = device , debug = debug )
23762413
2377- def detrend2d (self , transform , degree = 1 , device = 'auto' , debug = False ):
2414+ if isinstance (phase , BatchComplex ):
2415+ detrended = phase * trend .conj ()
2416+ else :
2417+ detrended = phase - trend
2418+
2419+ # Preserve non-spatial variables (e.g. BPR) that may be dropped by arithmetic
2420+ detrended = Batches ._preserve_nonspatial (phase , detrended )
2421+
2422+ # Rebuild Batches preserving all original elements except first
2423+ elements = [detrended ] + list (self [1 :])
2424+ return Batches (elements )
2425+
2426+ def detrend1d (self , baseline = 'BPR' , degree = 1 , device = 'auto' , debug = False ):
23782427 """
2379- Detrend 2D polynomial trend and return Batches with detrended data .
2428+ Detrend 1D polynomial trend along perpendicular baseline and return Batches .
23802429
2381- Expects Batches from interferogram(): [BatchComplex (intf), BatchUnit (corr)]
2382- or from angle(): [BatchWrap (phase), BatchUnit (corr)] .
2430+ Removes DEM residual phase proportional to perpendicular baseline at each pixel.
2431+ Mirrors detrend2d() pattern: auto-detects input type .
23832432
2384- For complex input: detrended = intf * trend.conj()
2385- For wrapped input: detrended = wrap(phase - trend)
2386- For real input: detrended = phase - trend
2433+ For complex input (BatchComplex): unit-circle fitting, detrend multiplicatively (phase * trend.conj())
2434+ For real input (Batch): standard polynomial, subtract
23872435
23882436 Parameters
23892437 ----------
2390- transform : Batch
2391- Coordinate transform from stack.transform() containing 'azi' and 'rng' .
2438+ baseline : str
2439+ Variable name to regress against (default 'BPR' for perpendicular baseline) .
23922440 degree : int
2393- Polynomial degree (1=plane , 2=quadratic). Default 1.
2441+ Polynomial degree (1=linear , 2=quadratic). Default 1.
23942442 device : str
23952443 PyTorch device: 'auto', 'cuda', 'mps', 'cpu'.
23962444 debug : bool
@@ -2403,27 +2451,31 @@ def detrend2d(self, transform, degree=1, device='auto', debug=False):
24032451
24042452 Examples
24052453 --------
2406- >>> # Complex interferogram detrending (chained )
2407- >>> intf, corr = stack.pairs(baseline).interferogram(wavelength=30).detrend2d(transform )
2408- >>> # Wrapped phase detrending
2409- >>> phase, corr = stack.pairs(baseline).phasediff (wavelength=30).angle().detrend2d(transform )
2454+ >>> # Complex interferogram: multiplicative detrend (like detrend2d )
2455+ >>> intf, corr = stack.pairs(baseline).interferogram(wavelength=30).detrend1d( )
2456+ >>> # Full pipeline: detrend2d (spatial) then detrend1d (baseline)
2457+ >>> result = stack.pairs(bl).interferogram (wavelength=30).detrend2d(transform).detrend1d( )
24102458 """
24112459 if len (self ) < 1 :
2412- raise ValueError ("detrend2d () requires Batches with at least 1 element: [phase]" )
2460+ raise ValueError ("detrend1d () requires Batches with at least 1 element: [phase]" )
24132461
24142462 phase = self [0 ]
24152463 weight = self [1 ] if len (self ) >= 2 and isinstance (self [1 ], BatchUnit ) else None
24162464
2417- if not isinstance (phase , (Batch , BatchWrap , BatchComplex )):
2418- raise TypeError (f"First element must be Batch, BatchWrap, or BatchComplex, got { type (phase ).__name__ } " )
2465+ if not isinstance (phase , (Batch , BatchComplex )):
2466+ raise TypeError (f"First element must be Batch or BatchComplex, got { type (phase ).__name__ } " )
24192467
2420- trend = phase .trend2d (transform , weight = weight , degree = degree , device = device , debug = debug )
2468+ trend = phase .trend1d (weight = weight , baseline = baseline , degree = degree ,
2469+ device = device , debug = debug )
24212470
24222471 if isinstance (phase , BatchComplex ):
24232472 detrended = phase * trend .conj ()
24242473 else :
24252474 detrended = phase - trend
24262475
2476+ # Preserve non-spatial variables (e.g. BPR) that may be dropped by arithmetic
2477+ detrended = Batches ._preserve_nonspatial (phase , detrended )
2478+
24272479 # Rebuild Batches preserving all original elements except first
24282480 elements = [detrended ] + list (self [1 :])
24292481 return Batches (elements )
@@ -2465,93 +2517,76 @@ def lstsq(self, device='auto', cumsum=True, debug=False):
24652517 # Delegate to Batch.lstsq
24662518 return data .lstsq (weight = weight , device = device , cumsum = cumsum , debug = debug )
24672519
2468- def regression1d_baseline (self , baseline = 'BPR' , degree = 1 , wrap = False , iterations = 1 ,
2469- device = 'auto' , debug = False ):
2520+ def regression1d_baseline (self , * args , ** kwargs ):
2521+ raise NotImplementedError ("Batches.regression1d_baseline() is removed. Use Batches.detrend1d() or Batch.trend1d() instead." )
2522+
2523+ def detrend1d_pairs (self , degree = 0 , days = 90 , count = None ,
2524+ device = 'auto' , debug = False ):
24702525 """
2471- Fit 1D polynomial trend along perpendicular baseline at each (y, x) pixel.
2526+ Detrend 1D polynomial trend along temporal pairs and return Batches.
2527+
2528+ Removes per-date temporal trend from pair data. Mirrors detrend1d()/detrend2d()
2529+ pattern: auto-detects input type.
24722530
2473- Expects Batches with [Batch/BatchWrap (phase), BatchUnit (weight, optional)].
2531+ For complex input (BatchComplex): unit-circle fitting, detrend multiplicatively (phase * trend.conj())
2532+ For real input (Batch): standard polynomial, subtract (phase - trend)
24742533
24752534 Parameters
24762535 ----------
2477- baseline : str
2478- Variable name to regress against (default 'BPR' for perpendicular baseline).
24792536 degree : int
2480- Polynomial degree (1=linear, 2=quadratic). Default 1.
2481- wrap : bool
2482- If True, use circular (sin/cos) fitting for wrapped phase.
2483- iterations : int
2484- Number of fitting iterations (default 1).
2537+ Polynomial degree (0=mean, 1=linear). Default 0.
2538+ days : int
2539+ Maximum time interval (±days symmetric window) to include.
2540+ Default 90. Controls temporal high-pass: only pairs within ±days
2541+ contribute to the per-date atmospheric estimate.
2542+ count : int or None
2543+ Maximum number of pairs per date to use. Default None (all pairs).
24852544 device : str
24862545 PyTorch device: 'auto', 'cuda', 'mps', 'cpu'.
24872546 debug : bool
24882547 Print diagnostic information.
24892548
24902549 Returns
24912550 -------
2492- Batch
2493- Fitted trend values, same shape as input data .
2551+ Batches
2552+ Batches with [detrended_phase, weight] preserving original types .
24942553
24952554 Examples
24962555 --------
2497- >>> trend = Batches([intf, corr]).regression1d_baseline()
2498- >>> detrended = intf - trend
2556+ >>> # Complex interferogram detrending (chained)
2557+ >>> intf, corr = stack.pairs(baseline).interferogram(wavelength=30).detrend1d_pairs()
2558+ >>> # Real detrending
2559+ >>> result = Batches([unwrapped, corr]).detrend1d_pairs(degree=0, days=100)
24992560 """
25002561 if len (self ) < 1 :
2501- raise ValueError ("regression1d_baseline () requires Batches with at least 1 element: [data ]" )
2562+ raise ValueError ("detrend1d_pairs () requires Batches with at least 1 element: [phase ]" )
25022563
2503- data = self [0 ]
2564+ phase = self [0 ]
25042565 weight = self [1 ] if len (self ) >= 2 and isinstance (self [1 ], BatchUnit ) else None
25052566
2506- # Delegate to BatchCore.regression1d_baseline
2507- return data .regression1d_baseline (weight = weight , baseline = baseline , degree = degree ,
2508- wrap = wrap , iterations = iterations ,
2509- device = device , debug = debug )
2510-
2511- def regression1d_pairs (self , degree = 0 , days = None , count = None , wrap = False , iterations = 1 ,
2512- device = 'auto' , debug = False ):
2513- """
2514- Fit 1D polynomial trend along temporal pairs for each date.
2567+ if not isinstance (phase , (Batch , BatchComplex )):
2568+ raise TypeError (f"First element must be Batch or BatchComplex, got { type (phase ).__name__ } " )
25152569
2516- Expects Batches with [Batch/BatchWrap (phase), BatchUnit (weight, optional)].
2570+ trend = phase .trend1d_pairs (weight = weight , degree = degree , days = days , count = count ,
2571+ device = device , debug = debug )
25172572
2518- Parameters
2519- ----------
2520- degree : int
2521- Polynomial degree (0=mean, 1=linear). Default 0.
2522- days : int or None
2523- Maximum time interval (in days) to include. Default None (all pairs).
2524- count : int or None
2525- Maximum number of pairs per date to use. Default None (all pairs).
2526- wrap : bool
2527- If True, use circular (sin/cos) fitting for wrapped phase.
2528- iterations : int
2529- Number of fitting iterations (default 1).
2530- device : str
2531- PyTorch device: 'auto', 'cuda', 'mps', 'cpu'.
2532- debug : bool
2533- Print diagnostic information.
2573+ if isinstance (phase , BatchComplex ):
2574+ detrended = phase * trend .conj ()
2575+ else :
2576+ detrended = phase - trend
25342577
2535- Returns
2536- -------
2537- Batch
2538- Fitted trend values, same shape as input data.
2578+ # Preserve non-spatial variables (e.g. BPR) that may be dropped by arithmetic
2579+ detrended = Batches ._preserve_nonspatial (phase , detrended )
25392580
2540- Examples
2541- --------
2542- >>> trend = Batches([intf, corr]).regression1d_pairs(degree=0, days=100)
2543- >>> detrended = intf - trend
2544- """
2545- if len (self ) < 1 :
2546- raise ValueError ("regression1d_pairs() requires Batches with at least 1 element: [data]" )
2581+ # Rebuild Batches preserving all original elements except first
2582+ elements = [detrended ] + list (self [1 :])
2583+ return Batches (elements )
25472584
2548- data = self [ 0 ]
2549- weight = self [ 1 ] if len ( self ) >= 2 and isinstance ( self [ 1 ], BatchUnit ) else None
2585+ def regression1d_pairs ( self , * args , ** kwargs ):
2586+ raise NotImplementedError ( "Batches.regression1d_pairs() is removed. Use Batches.detrend1d_pairs() instead." )
25502587
2551- # Delegate to BatchCore.regression1d_pairs
2552- return data .regression1d_pairs (weight = weight , degree = degree , days = days , count = count ,
2553- wrap = wrap , iterations = iterations ,
2554- device = device , debug = debug )
2588+ def trend1d_pairs (self , * args , ** kwargs ):
2589+ raise NotImplementedError ("Batches.trend1d_pairs() is removed. Use Batches.detrend1d_pairs() instead." )
25552590
25562591 def stl (self , freq = 'W' , periods = 52 , robust = False ):
25572592 """
0 commit comments