3535 XarrayDataArrayEncoding ,
3636)
3737from eopf_geozarr .s2_optimization .common import DISTRIBUTED_AVAILABLE
38+ from eopf_geozarr .s2_optimization .s2_band_mapping import BAND_INFO
3839
3940from .s2_resampling import determine_variable_type , downsample_variable
4041
@@ -66,6 +67,101 @@ def get_grid_spacing(ds: xr.DataArray, coords: tuple[Hashable, ...]) -> tuple[fl
6667 return tuple (np .abs (ds .coords [coord ][0 ].data - ds .coords [coord ][1 ].data ) for coord in coords )
6768
6869
def _coarsen_variable(var_name: str, var_data: xr.DataArray, factor: int) -> xr.DataArray:
    """Coarsen a single variable using type-aware resampling.

    The type reported by `determine_variable_type` selects the reduction
    applied to every `factor` x `factor` window along the "x" and "y"
    dimensions:

    - "reflectance" and "probability": window mean
    - "classification": reduced with `subsample_2`
    - "quality_mask": window maximum

    Incomplete windows at the array edge are dropped (`boundary="trim"`).

    Args:
        var_name: Name of the variable, forwarded to `determine_variable_type`.
        var_data: The variable to coarsen; must have "x" and "y" dimensions.
        factor: Window size applied to both "x" and "y".

    Returns:
        The coarsened variable, cast back to the dtype of `var_data` and
        carrying a copy of its encoding.

    Raises:
        ValueError: If `determine_variable_type` returns a type that has no
            associated reduction.
    """
    var_type = determine_variable_type(var_name, var_data)
    coarsened = var_data.coarsen({"x": factor, "y": factor}, boundary="trim")
    if var_type in ("reflectance", "probability"):
        result = coarsened.mean()
    elif var_type == "classification":
        result = coarsened.reduce(subsample_2)
    elif var_type == "quality_mask":
        result = coarsened.max()
    else:
        raise ValueError(f"Unknown variable type {var_type}")

    # Cast first and attach the encoding last: `astype` returns a new object,
    # so an encoding set before the cast is not guaranteed to be present on
    # the value that is actually returned.
    result = result.astype(var_data.dtype)
    result.encoding = dict(var_data.encoding)
    return result
89+
90+
def inject_missing_bands(
    dataset: xr.Dataset,
    dt_input: xr.DataTree,
    target_resolution: int,
    *,
    bands: set[str] | None = None,
) -> xr.Dataset:
    """Inject bands whose native resolution is finer than `target_resolution`.

    For each spectral band defined in `BAND_INFO` whose native resolution is
    finer than `target_resolution`, this function checks whether the band is
    already present in `dataset`. If not, it looks for the band in the
    appropriate source group (e.g. `/measurements/reflectance/r10m`),
    downsamples it to the target grid with `_coarsen_variable`, and adds it
    to `dataset`.

    Bands are silently skipped when they are already present, when their
    source group does not exist in `dt_input`, or when the source group does
    not contain the band.

    The downsampled data is attached without being computed, so a lazily
    loaded (e.g. dask-backed) source band stays lazy.

    Args:
        dataset: The target-resolution dataset (e.g. the r20m or r60m
            reflectance group).
        dt_input: The full input DataTree (used to locate finer-resolution
            source bands).
        target_resolution: Target resolution in metres (e.g. 20 or 60).
        bands: If provided, only inject these band names. If `None`
            (default), inject every eligible band from `BAND_INFO`.

    Returns:
        `dataset` with any missing finer-resolution bands added. The input
        object is not modified; a new dataset is returned whenever a band is
        injected.
    """
    for band_name, info in BAND_INFO.items():
        if bands is not None and band_name not in bands:
            continue
        native_res = info.native_resolution  # type: ignore[attr-defined]
        if native_res >= target_resolution:
            continue
        if band_name in dataset.data_vars:
            continue

        source_path = f"/measurements/reflectance/r{native_res}m"
        if source_path not in dt_input.groups:
            continue

        source_ds = dt_input[source_path].to_dataset()
        if band_name not in source_ds.data_vars:
            continue

        band_src = source_ds[band_name]
        factor = target_resolution // native_res
        coarsened = _coarsen_variable(band_name, band_src, factor)

        # Replace coordinates with the target dataset's coordinates so that
        # xarray.Dataset.assign does not try to align on mismatched values.
        # `.data` (rather than `.values`) hands over the underlying array
        # as-is, so a lazy array is not loaded into memory here.
        band_ds = xr.DataArray(
            coarsened.data,
            dims=coarsened.dims,
            coords={d: dataset.coords[d] for d in coarsened.dims if d in dataset.coords},
            attrs=coarsened.attrs,
            name=band_name,
        )

        # Preserve source encoding so downstream encoding logic can inspect it
        band_ds.encoding = band_src.encoding.copy()

        dataset = dataset.assign({band_name: band_ds})
        log.info(
            "Injected downsampled band from finer resolution",
            band=band_name,
            source=f"r{native_res}m",
            target=f"r{target_resolution}m",
            shape=band_ds.shape,
        )

    return dataset
163+
164+
69165def create_multiscale_from_datatree (
70166 dt_input : xr .DataTree ,
71167 * ,
@@ -121,6 +217,22 @@ def create_multiscale_from_datatree(
121217 )
122218
123219 if is_measurement_group :
220+ # Inject bands whose native resolution is finer than this group's
221+ # (e.g. b08 native at 10m into r20m/r60m) so they propagate through
222+ # the full overview chain (r120m … r720m).
223+ if group_path .startswith ("/measurements/reflectance/" ):
224+ try :
225+ group_resolution = int (group_name [1 :- 1 ])
226+ except ValueError :
227+ group_resolution = 0
228+ if group_resolution > 10 :
229+ dataset = inject_missing_bands (
230+ dataset ,
231+ dt_input ,
232+ group_resolution ,
233+ bands = {"b08" },
234+ )
235+
124236 # Measurement groups: apply custom encoding
125237 encoding = create_measurements_encoding (
126238 dataset ,
@@ -732,24 +844,7 @@ def create_downsampled_resolution_group(source_dataset: xr.Dataset, factor: int)
732844 for var_name , var_data in source_dataset .data_vars .items ():
733845 if var_data .ndim < 2 :
734846 continue
735- var_typ = determine_variable_type (var_name , var_data )
736- if var_typ == "quality_mask" :
737- lazy_downsampled = var_data .coarsen ({"x" : factor , "y" : factor }, boundary = "trim" ).max ()
738- elif var_typ == "reflectance" :
739- lazy_downsampled = var_data .coarsen ({"x" : factor , "y" : factor }, boundary = "trim" ).mean ()
740- elif var_typ == "classification" :
741- lazy_downsampled = var_data .coarsen ({"x" : factor , "y" : factor }, boundary = "trim" ).reduce (
742- subsample_2
743- )
744- elif var_typ == "probability" :
745- lazy_downsampled = var_data .coarsen ({"x" : factor , "y" : factor }, boundary = "trim" ).mean ()
746- else :
747- raise ValueError (f"Unknown variable type { var_typ } " )
748-
749- # preserve encoding
750- lazy_downsampled .encoding = var_data .encoding
751- # Ensure that dtype is preserved
752- lazy_vars [var_name ] = lazy_downsampled .astype (var_data .dtype )
847+ lazy_vars [var_name ] = _coarsen_variable (var_name , var_data , factor )
753848
754849 if not lazy_vars :
755850 return xr .Dataset ()
@@ -760,8 +855,8 @@ def create_downsampled_resolution_group(source_dataset: xr.Dataset, factor: int)
760855
def subsample_2(a: xr.DataArray, axis: int | tuple[int, ...] | None = None) -> xr.DataArray:
    """Reduce `a` by keeping only the first element along the given axes.

    The signature follows the `func(x, axis)` reduction convention so the
    function can be passed to a coarsen `reduce` call: each reduced axis is
    collapsed by taking index 0, all other axes are kept whole.

    Args:
        a: Array-like object supporting `ndim` and tuple indexing.
        axis: Axis or axes to collapse. May be a single integer or a tuple of
            integers; negative values count from the last axis. If `None`,
            every axis is collapsed and the very first element is returned.

    Returns:
        `a` indexed at position 0 along every reduced axis.
    """
    if axis is None:
        return a[(0,) * a.ndim]

    # Accept both a bare integer and an iterable of integers.
    try:
        axes = tuple(axis)
    except TypeError:
        axes = (axis,)

    # Map negative axes onto their non-negative equivalents so that
    # membership tests against range(a.ndim) work.
    reduced = {ax % a.ndim for ax in axes}
    indexer = tuple(0 if i in reduced else slice(None) for i in range(a.ndim))
    return a[indexer]
766861
767862
0 commit comments