@@ -77,14 +77,12 @@ def _apply_velocity_block(data_block, times_years, min_valid, device):
     Returns
     -------
     np.ndarray
-        3D array (2, chunk_y, chunk_x) where [0] is velocity, [1] is intercept
+        3D array (2, chunk_y, chunk_x) where [0] is velocity, [1] is RMSE
     """
     import numpy as np
-    # Convert times_years back to numpy array (dask serialization may convert to list)
     times_years = np.asarray(times_years, dtype=np.float32)
-    vel, intercept = Batch._velocity_torch(data_block, times_years, min_valid=min_valid, device=device)
-    # Stack velocity and intercept along a new first dimension
-    return np.stack([vel, intercept], axis=0).astype(np.float32)
+    vel, rmse = Batch._velocity_torch(data_block, times_years, min_valid=min_valid, device=device)
+    return np.stack([vel, rmse], axis=0).astype(np.float32)


 class Batch(BatchCore):
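
The helper's contract after this change: the velocity and RMSE maps are stacked along a new leading axis so downstream code can split them by plain indexing. A minimal NumPy sketch with stand-in arrays (not the real data):

import numpy as np

vel = np.random.rand(8, 8).astype(np.float32)    # stand-in velocity map
rmse = np.random.rand(8, 8).astype(np.float32)   # stand-in RMSE map
block = np.stack([vel, rmse], axis=0)            # shape (2, chunk_y, chunk_x), as returned above
vel_back, rmse_back = block[0], block[1]         # how the caller unpacks the two layers
assert np.array_equal(vel_back, vel) and np.array_equal(rmse_back, rmse)
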
@@ -423,58 +421,63 @@ def _velocity_torch(data, times_years, min_valid=3, device='auto', debug=False):
             except Exception:
                 coeffs[px] = float('nan')

-        intercept = coeffs[:, 0]  # c0
         velocity = coeffs[:, 1]  # c1 = velocity (unbiased)

+        # Compute RMSE: residuals = y - A @ coeffs
+        # A @ coeffs per pixel: sum over basis functions
+        predicted = torch.zeros_like(y_filled)
+        for i, c in enumerate(cols):
+            predicted += c * coeffs[:, i].unsqueeze(0)  # (n_times, n_pixels)
+        residuals = (y_filled - predicted) * w  # zero out NaN positions
+        rmse = torch.sqrt((residuals ** 2).sum(dim=0) / valid_count.clamp(min=1))
+
         # Mask pixels with insufficient valid points
         valid_mask = valid_count >= min_valid
         velocity = torch.where(valid_mask, velocity, torch.tensor(float('nan'), device=dev))
-        intercept = torch.where(valid_mask, intercept, torch.tensor(float('nan'), device=dev))
+        rmse = torch.where(valid_mask, rmse, torch.tensor(float('nan'), device=dev))

         # Reshape back
         vel_np = velocity.cpu().numpy()
-        int_np = intercept.cpu().numpy()
+        rmse_np = rmse.cpu().numpy()
         if len(original_shape) == 3:
             vel_np = vel_np.reshape(original_shape[1], original_shape[2])
-            int_np = int_np.reshape(original_shape[1], original_shape[2])
+            rmse_np = rmse_np.reshape(original_shape[1], original_shape[2])

         # Cleanup GPU memory
         if dev.type == 'mps':
             torch.mps.empty_cache()
         elif dev.type == 'cuda':
             torch.cuda.empty_cache()

-        return vel_np, int_np
+        return vel_np, rmse_np

-    def velocity(self, min_valid=5, device='auto', debug=False) -> tuple["Batch", "Batch"]:
+    def velocity(self, min_valid=5, device='auto', debug=False) -> "Batches":
         """
-        Compute velocity (linear trend) and intercept from time series.
+        Compute velocity and RMSE from the time series.

-        Calculates the slope per year and intercept for each pixel using linear
-        regression on the 'date' dimension. Uses PyTorch for GPU acceleration.
+        Fits a harmonic regression (linear trend plus seasonal terms) per pixel
+        along the 'date' dimension. Uses PyTorch for GPU acceleration.

         Parameters
         ----------
         min_valid : int, optional
-            Minimum number of valid (non-NaN) data points required to compute
-            velocity. Pixels with fewer valid points will be set to NaN.
-            Default is 3.
+            Minimum number of valid (non-NaN) data points required; pixels with
+            fewer valid points are set to NaN. Default is 5.
         device : str, optional
             PyTorch device: 'auto' (default), 'cuda', 'mps', or 'cpu'.
-            'auto' uses GPU if Dask client has resources={'gpu': 1}.
         debug : bool, optional
             Print debug information. Default False.

         Returns
         -------
-        tuple[Batch, Batch]
-            (velocity, intercept) - velocity is slope per year, intercept is
-            the y-value at t=0 (first date). Both are lazy Batch objects.
+        Batches (Batch, Batch)
+            (velocity, rmse): velocity is the slope per year; rmse is the
+            root-mean-square error of the fit residuals. Both are lazy Batch objects.

         Examples
         --------
-        >>> displacement = stack.lstsq(detrend, corr)
-        >>> velocity, intercept = displacement.velocity()
+        >>> velocity, rmse = displacement.velocity()
+        >>> vel, rmse = detrend0.velocity().displacement_los(transform).compute()
         """
         import dask
         import dask.array as da
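
The RMSE added in this hunk is the weighted root-mean-square of the fit residuals per pixel. A small NumPy sketch of the same formula with made-up shapes and data; the real code runs the equivalent in PyTorch and accumulates predicted = A @ coeffs one basis column at a time:

import numpy as np

rng = np.random.default_rng(0)
n_times, n_basis, n_pixels = 10, 2, 5
A = np.column_stack([np.ones(n_times), np.linspace(0, 1, n_times)])  # basis columns (intercept, trend)
coeffs = rng.normal(size=(n_pixels, n_basis))                        # per-pixel fit coefficients
y = A @ coeffs.T + 0.1 * rng.normal(size=(n_times, n_pixels))        # noisy observations
w = (rng.random((n_times, n_pixels)) > 0.2).astype(np.float64)       # 1 = valid, 0 = NaN position

predicted = A @ coeffs.T
residuals = (y - predicted) * w                                      # zero out invalid samples
valid_count = w.sum(axis=0)
rmse = np.sqrt((residuals ** 2).sum(axis=0) / np.clip(valid_count, 1, None))
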
@@ -497,10 +500,10 @@ def velocity(self, min_valid=5, device='auto', debug=False) -> tuple["Batch", "B
         crs = self.crs

         vel_results = {}
-        int_results = {}
+        rmse_results = {}
         for key, ds in self.items():
             vel_vars = {}
-            int_vars = {}
+            rmse_vars = {}
             # Filter for spatial variables (with y, x dims) - excludes converted attributes
             for var in [v for v in ds.data_vars
                         if 'y' in ds[v].dims and 'x' in ds[v].dims]:
@@ -542,33 +545,28 @@ def process_block(data_block):
                     chunks=(2,) + data_dask.chunks[1:],
                 )

-                # Unpack velocity (index 0) and intercept (index 1)
                 vel_da = xr.DataArray(
-                    result_dask[0],
-                    dims=['y', 'x'],
+                    result_dask[0], dims=['y', 'x'],
                     coords={'y': data_arr.y, 'x': data_arr.x}
                 )
-                int_da = xr.DataArray(
-                    result_dask[1],
-                    dims=['y', 'x'],
+                rmse_da = xr.DataArray(
+                    result_dask[1], dims=['y', 'x'],
                     coords={'y': data_arr.y, 'x': data_arr.x}
                 )
-
                 vel_vars[var] = vel_da
-                int_vars[var] = int_da
+                rmse_vars[var] = rmse_da

             vel_ds = xr.Dataset(vel_vars)
             vel_ds.attrs = ds.attrs
-            int_ds = xr.Dataset(int_vars)
-            int_ds.attrs = ds.attrs
-            # Preserve CRS
+            rmse_ds = xr.Dataset(rmse_vars)
+            rmse_ds.attrs = ds.attrs
             if crs is not None:
                 vel_ds = vel_ds.rio.write_crs(crs)
-                int_ds = int_ds.rio.write_crs(crs)
+                rmse_ds = rmse_ds.rio.write_crs(crs)
             vel_results[key] = vel_ds
-            int_results[key] = int_ds
+            rmse_results[key] = rmse_ds

-        return Batch(vel_results), Batch(int_results)
+        return Batches((Batch(vel_results), Batch(rmse_results)))

     def incidence(self) -> "Batch":
         """Compute incidence angle from azi, rng, ele, and radar geometry parameters.
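
A short usage sketch of the new return type; the displacement object and the store path are assumed for illustration. Because Batches subclasses tuple, the result still unpacks as before, but it can also be kept whole and chained:

>>> velocity, rmse = displacement.velocity()          # unpacks like a tuple
>>> results = displacement.velocity()                 # or keep the Batches intact
>>> results.snapshot('velocity.zarr')                 # snapshot both lazy results together
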
@@ -1752,7 +1750,7 @@ def _velocity_block(data_block, weight_block=None,
             vel_results[burst_id] = vel_ds
             rmse_results[burst_id] = rmse_ds

-        return Batch(vel_results), Batch(rmse_results)
+        return Batches((Batch(vel_results), Batch(rmse_results)))

     def backscatter(self, *args, **kwargs):
         """
@@ -1781,6 +1779,43 @@ def conj(self, **kwargs):
         """intfs.iexp().conj() for np.exp(-1j * intfs)"""
         return self.map_da(lambda da: xr.ufuncs.conj(da), **kwargs)

+    def pairs(self, pairs):
+        """Select date pairs from per-date data, returning ref and rep stacks.
+
+        Parameters
+        ----------
+        pairs : array-like (n_pairs, 2)
+            Pairs as [[ref_date, rep_date], ...]; dates as datetime64, matched at day precision.
+
+        Returns
+        -------
+        tuple (ref, rep)
+            Two BatchComplex objects with a 'pair' dimension instead of 'date'.
+        """
+        import numpy as np
+        pairs = np.asarray(pairs)
+        ref_dates = pairs[:, 0]
+        rep_dates = pairs[:, 1]
+
+        # Map dates to integer indices (match by day to handle precision differences)
+        key0 = list(self.keys())[0]
+        date_coords = self[key0].coords['date'].values
+        # Truncate to day precision for matching
+        date_days = np.array(date_coords, dtype='datetime64[D]')
+        date_to_idx = {d: i for i, d in enumerate(date_days)}
+        ref_idx = [date_to_idx[np.datetime64(d, 'D')] for d in ref_dates]
+        rep_idx = [date_to_idx[np.datetime64(d, 'D')] for d in rep_dates]
+
+        # Select, rename date -> pair, and assign pair coords matching the caller
+        n_pairs = len(ref_idx)
+        pair_coords = np.arange(n_pairs)
+        screen_ref = self.isel(date=ref_idx).rename(date='pair').map(
+            lambda ds: ds.assign_coords(pair=pair_coords))
+        screen_rep = self.isel(date=rep_idx).rename(date='pair').map(
+            lambda ds: ds.assign_coords(pair=pair_coords))
+
+        return screen_ref, screen_rep
+
     def lstsq_baseline(self, weight=None, baseline='BPR', stride=1, debug=False):
         """
         Decompose per-pair complex trend into network-consistent per-date model.
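
To make the day-precision matching in pairs() concrete, here is a standalone sketch with made-up acquisition dates; it mirrors the dictionary lookup above:

import numpy as np

date_coords = np.array(['2024-01-01T12:00', '2024-01-13T12:00', '2024-01-25T12:00'],
                       dtype='datetime64[ns]')
date_to_idx = {d: i for i, d in enumerate(date_coords.astype('datetime64[D]'))}
pairs = np.array([['2024-01-01', '2024-01-13'],
                  ['2024-01-13', '2024-01-25']], dtype='datetime64[D]')
ref_idx = [date_to_idx[np.datetime64(d, 'D')] for d in pairs[:, 0]]
rep_idx = [date_to_idx[np.datetime64(d, 'D')] for d in pairs[:, 1]]
assert ref_idx == [0, 1] and rep_idx == [1, 2]
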
@@ -1824,13 +1859,21 @@ def lstsq_baseline(self, weight=None, baseline='BPR', stride=1, debug=False):
             has_bpr = baseline is not None and baseline in ds.coords
             bpr_values = ds.coords[baseline].values.astype(np.float64) if has_bpr else None

+            # Compute unique dates (same logic as lstsq_baseline_array)
+            ns_per_day = 86400 * 1e9
+            ref_days = ref_values.astype(np.float64) / ns_per_day
+            rep_days = rep_values.astype(np.float64) / ns_per_day
+            unique_days = np.unique(np.concatenate([ref_days, rep_days]))
+            n_dates = len(unique_days)
+            # Map unique_days back to datetime64
+            date_coords = (unique_days * ns_per_day).astype('datetime64[ns]')
+
             weight_ds = weight[key] if weight is not None else None

             result_ds = {}
             for pol in pols:
                 data_da = ds[pol]
                 data_dask = data_da.data
-                n_pairs = data_dask.shape[0]

                 weight_da = weight_ds[pol] if weight_ds is not None else None

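
The added block rebuilds the per-date axis from the pair coordinates by round-tripping datetime64[ns] values through float days. A standalone sketch of the same arithmetic with made-up pairs:

import numpy as np

ref_values = np.array(['2024-01-01', '2024-01-01', '2024-01-13'], dtype='datetime64[ns]')
rep_values = np.array(['2024-01-13', '2024-01-25', '2024-01-25'], dtype='datetime64[ns]')
ns_per_day = 86400 * 1e9
ref_days = ref_values.astype(np.float64) / ns_per_day
rep_days = rep_values.astype(np.float64) / ns_per_day
unique_days = np.unique(np.concatenate([ref_days, rep_days]))
date_coords = (unique_days * ns_per_day).astype('datetime64[ns]')
print(len(unique_days))   # 3 unique dates: 2024-01-01, 2024-01-13, 2024-01-25
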
@@ -1846,26 +1889,31 @@ def _block(data_block, weight_block=None,
                 if weight_da is not None:
                     weight_dask = weight_da.data
                     result_dask = da.blockwise(
-                        _block, 'pyx',
+                        _block, 'dyx',
                         data_dask, 'pyx',
                         weight_dask, 'pyx',
-                        new_axes={'p': n_pairs},
+                        new_axes={'d': n_dates},
                         concatenate=True,
                         dtype=np.complex64,
                         meta=np.empty((0, 0, 0), dtype=np.complex64),
                     )
                 else:
                     result_dask = da.blockwise(
-                        _block, 'pyx',
+                        _block, 'dyx',
                         data_dask, 'pyx',
-                        new_axes={'p': n_pairs},
+                        new_axes={'d': n_dates},
                         concatenate=True,
                         dtype=np.complex64,
                         meta=np.empty((0, 0, 0), dtype=np.complex64),
                     )

+                # Per-date output with date coordinates
                 result_ds[pol] = xr.DataArray(
-                    result_dask, dims=data_da.dims, coords=data_da.coords
+                    result_dask,
+                    dims=['date', 'y', 'x'],
+                    coords={'date': date_coords,
+                            'y': data_da.coords['y'],
+                            'x': data_da.coords['x']}
                 )

             result[key] = xr.Dataset(result_ds, attrs=ds.attrs)
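
The lstsq output now carries one layer per date instead of per pair, so the blockwise call declares a new output index 'd' via new_axes. A toy da.blockwise sketch of the same pattern; the kernel, sizes, and data here are made up:

import numpy as np
import dask.array as da

n_pairs, n_dates, ny, nx = 4, 3, 6, 6
stack = da.random.random((n_pairs, ny, nx), chunks=(n_pairs, 3, 3))

def _toy_block(block):
    # stand-in kernel: consume the per-pair stack, emit n_dates per-date layers
    return np.repeat(block.mean(axis=0, keepdims=True), n_dates, axis=0)

out = da.blockwise(
    _toy_block, 'dyx',
    stack, 'pyx',
    new_axes={'d': n_dates},
    concatenate=True,
    dtype=np.float64,
    meta=np.empty((0, 0, 0), dtype=np.float64),
)
assert out.shape == (n_dates, ny, nx)
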
@@ -2100,6 +2148,23 @@ def goldstein(self, corr: BatchUnit, window: int | dict[str, int] = 32, threshol
         return type(self)(result)


+def _subtract_date_from_pair(first, second):
+    """Subtract per-date atmospheric screens from per-pair data.
+
+    Uses BatchComplex.pairs() to select the ref/rep screens, then computes
+    result = data * conj(screen_ref) * screen_rep.
+    """
+    import numpy as np
+
+    key0 = list(first.keys())[0]
+    ref_dates = first[key0].coords['ref'].values
+    rep_dates = first[key0].coords['rep'].values
+    pairs = np.column_stack([ref_dates, rep_dates])
+
+    screen_ref, screen_rep = second.pairs(pairs)
+    return first * screen_ref.conj() * screen_rep
+
+
 class Batches(tuple):
     """
     A tuple-like container for multiple Batch objects that allows chained operations.
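
Why the helper multiplies by conj(screen_ref) and screen_rep: in the complex domain this subtracts the per-date screen phases from the pair phase. A scalar sketch; the ref-minus-rep pair phase convention is assumed here to match the formula, and all values are made up:

import numpy as np

signal = 0.3                      # per-pair phase of interest, radians
phi_ref, phi_rep = 0.7, -0.2      # per-date atmospheric screen phases
pair = np.exp(1j * (signal + phi_ref - phi_rep))        # assumed pair phase convention
screen_ref, screen_rep = np.exp(1j * phi_ref), np.exp(1j * phi_rep)
corrected = pair * np.conj(screen_ref) * screen_rep     # screens cancel, signal remains
assert np.isclose(np.angle(corrected), signal)
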
@@ -2164,7 +2229,7 @@ def correlation(self) -> 'BatchUnit | None':

     def snapshot(self, store: str | None = None, storage_options: dict[str, str] | None = None,
                  caption: str | None = None,
-                 n_chunks: int = 1, debug: bool = False, **kwargs):
+                 debug: bool = False, **kwargs):
         """Save or open a Batches snapshot.

         When called on a Batches with data, saves all batches to Zarr store.
@@ -2178,8 +2243,6 @@ def snapshot(self, store: str | None = None, storage_options: dict[str, str] | N
             Storage options for cloud stores.
         caption : str, optional
             Progress bar caption.
-        n_chunks : int
-            Spatial chunks per worker per batch. Default 4.
         debug : bool
             Print debug information.

@@ -2200,11 +2263,11 @@ def snapshot(self, store: str | None = None, storage_options: dict[str, str] | N
         if len(self) == 0:
             result = utils_io.snapshot(store=store, storage_options=storage_options,
                                        caption=caption or 'Opening...',
-                                       n_chunks=n_chunks, debug=debug)
+                                       debug=debug)
         else:
             result = utils_io.snapshot(*self, store=store, storage_options=storage_options,
                                        caption=caption or 'Snapshotting...',
-                                       n_chunks=n_chunks, debug=debug, wrapper=Batches)
+                                       debug=debug, wrapper=Batches)

         if isinstance(result, Batches):
             return result
@@ -3070,7 +3133,20 @@ def subtract(self):
         first = self[0]
         second = self[second_idx]

-        if isinstance(first, BatchComplex):
+        # Check if second is per-date (from lstsq_baseline) and first is per-pair
+        is_date_to_pair = False
+        for key in first.keys():
+            first_ds = first[key]
+            second_ds = second[key]
+            first_pol = [v for v in first_ds.data_vars if 'y' in first_ds[v].dims][0]
+            second_pol = [v for v in second_ds.data_vars if 'y' in second_ds[v].dims][0]
+            if 'pair' in first_ds[first_pol].dims and 'date' in second_ds[second_pol].dims:
+                is_date_to_pair = True
+                break
+
+        if is_date_to_pair:
+            result = _subtract_date_from_pair(first, second)
+        elif isinstance(first, BatchComplex):
             result = first * second.conj()
         else:
             result = first - second
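
A usage sketch of the new dispatch; the names are hypothetical (intfs is a per-pair BatchComplex, screens is the per-date result of lstsq_baseline) and the member order assumes subtract() takes the second batch as the subtrahend:

>>> corrected = Batches((intfs, screens)).subtract()      # per-date screens detected via the 'date' dim
>>> residual = Batches((intfs, model_intfs)).subtract()   # per-pair case still uses first * second.conj()
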
@@ -3325,11 +3401,9 @@ def velocity(self, max_refine=3, **kwargs):
         weight = self[1] if len(self) >= 2 and isinstance(self[1], BatchUnit) else None

         if isinstance(phase, BatchComplex):
-            vel, rmse = phase.velocity(weight=weight, max_refine=max_refine, **kwargs)
-            return Batches((vel, rmse))
+            return phase.velocity(weight=weight, max_refine=max_refine, **kwargs)
         else:
-            vel, intercept = phase.velocity(**kwargs)
-            return Batches((vel, intercept))
+            return phase.velocity(**kwargs)

     def rmse(self, solution):
         """RMSE of phase vs solution, using correlation weight if present.