|
27 | 27 | _CASE_SLOTS = frozenset(slot for slots in SLOT_ORDERS.values() for slot in slots) |
28 | 28 |
|
29 | 29 |
|
| 30 | +def _extract_nonindex_coords( |
| 31 | + *dataarrays: xr.DataArray, |
| 32 | +) -> tuple[list[xr.DataArray], dict[str, tuple[str, dict]]]: |
| 33 | + """Extract and merge non-index coords, returning cleaned dataarrays and merged mappings. |
| 34 | +
|
| 35 | + Non-index coords (like `component` on `contributor` dim) cause concat conflicts. |
| 36 | + This extracts them, merges the mappings, and returns dataarrays without them. |
| 37 | + """ |
| 38 | + if not dataarrays: |
| 39 | + return [], {} |
| 40 | + |
| 41 | + # Find non-index coords and collect mappings |
| 42 | + merged: dict[str, tuple[str, dict]] = {} |
| 43 | + coords_to_drop: set[str] = set() |
| 44 | + |
| 45 | + for da in dataarrays: |
| 46 | + for name, coord in da.coords.items(): |
| 47 | + if len(coord.dims) != 1: |
| 48 | + continue |
| 49 | + dim = coord.dims[0] |
| 50 | + if dim == name or dim not in da.coords: |
| 51 | + continue |
| 52 | + |
| 53 | + coords_to_drop.add(name) |
| 54 | + if name not in merged: |
| 55 | + merged[name] = (dim, {}) |
| 56 | + elif merged[name][0] != dim: |
| 57 | + warnings.warn( |
| 58 | + f"Coordinate '{name}' appears on different dims: " |
| 59 | + f"'{merged[name][0]}' vs '{dim}'. Dropping this coordinate.", |
| 60 | + stacklevel=4, |
| 61 | + ) |
| 62 | + continue |
| 63 | + |
| 64 | + for dv, cv in zip(da.coords[dim].values, coord.values, strict=False): |
| 65 | + if dv not in merged[name][1]: |
| 66 | + merged[name][1][dv] = cv |
| 67 | + elif merged[name][1][dv] != cv: |
| 68 | + warnings.warn( |
| 69 | + f"Coordinate '{name}' has conflicting values for dim value '{dv}': " |
| 70 | + f"'{merged[name][1][dv]}' vs '{cv}'. Keeping first value.", |
| 71 | + stacklevel=4, |
| 72 | + ) |
| 73 | + |
| 74 | + # Drop these coords from dataarrays |
| 75 | + result = list(dataarrays) |
| 76 | + if coords_to_drop: |
| 77 | + result = [da.drop_vars(coords_to_drop, errors='ignore') for da in result] |
| 78 | + |
| 79 | + return result, merged |
| 80 | + |
| 81 | + |
| 82 | +def _apply_merged_coords(da: xr.DataArray, merged: dict[str, tuple[str, dict]]) -> xr.DataArray: |
| 83 | + """Apply merged coord mappings to concatenated dataarray.""" |
| 84 | + if not merged: |
| 85 | + return da |
| 86 | + |
| 87 | + new_coords = {} |
| 88 | + for name, (dim, mapping) in merged.items(): |
| 89 | + if dim not in da.dims: |
| 90 | + continue |
| 91 | + new_coords[name] = (dim, [mapping.get(dv, dv) for dv in da.coords[dim].values]) |
| 92 | + |
| 93 | + return da.assign_coords(new_coords) |
| 94 | + |
| 95 | + |
30 | 96 | def _apply_slot_defaults(plotly_kwargs: dict, defaults: dict[str, str | None]) -> None: |
31 | 97 | """Apply default slot assignments to plotly kwargs. |
32 | 98 |
|
@@ -254,12 +320,10 @@ def solution(self) -> xr.Dataset: |
254 | 320 | self._require_solutions() |
255 | 321 | datasets = [fs.solution for fs in self._systems] |
256 | 322 | self._warn_mismatched_dimensions(datasets) |
257 | | - self._solution = xr.concat( |
258 | | - [ds.expand_dims(case=[name]) for ds, name in zip(datasets, self._names, strict=True)], |
259 | | - dim='case', |
260 | | - join='outer', |
261 | | - fill_value=float('nan'), |
262 | | - ) |
| 323 | + expanded = [ds.expand_dims(case=[name]) for ds, name in zip(datasets, self._names, strict=True)] |
| 324 | + expanded, merged_coords = _extract_nonindex_coords(*expanded) |
| 325 | + result = xr.concat(expanded, dim='case', join='outer', coords='minimal', fill_value=float('nan')) |
| 326 | + self._solution = _apply_merged_coords(result, merged_coords) |
263 | 327 | return self._solution |
264 | 328 |
|
265 | 329 | @property |
@@ -322,12 +386,10 @@ def inputs(self) -> xr.Dataset: |
322 | 386 | if self._inputs is None: |
323 | 387 | datasets = [fs.to_dataset(include_solution=False) for fs in self._systems] |
324 | 388 | self._warn_mismatched_dimensions(datasets) |
325 | | - self._inputs = xr.concat( |
326 | | - [ds.expand_dims(case=[name]) for ds, name in zip(datasets, self._names, strict=True)], |
327 | | - dim='case', |
328 | | - join='outer', |
329 | | - fill_value=float('nan'), |
330 | | - ) |
| 389 | + expanded = [ds.expand_dims(case=[name]) for ds, name in zip(datasets, self._names, strict=True)] |
| 390 | + expanded, merged_coords = _extract_nonindex_coords(*expanded) |
| 391 | + result = xr.concat(expanded, dim='case', join='outer', coords='minimal', fill_value=float('nan')) |
| 392 | + self._inputs = _apply_merged_coords(result, merged_coords) |
331 | 393 | return self._inputs |
332 | 394 |
|
333 | 395 |
|
@@ -372,7 +434,11 @@ def _concat_property(self, prop_name: str) -> xr.DataArray: |
372 | 434 | continue |
373 | 435 | if not arrays: |
374 | 436 | return xr.DataArray() |
375 | | - return xr.concat(arrays, dim='case', join='outer', fill_value=float('nan'), coords='minimal', compat='override') |
| 437 | + arrays, merged_coords = _extract_nonindex_coords(*arrays) |
| 438 | + result = xr.concat( |
| 439 | + arrays, dim='case', join='outer', fill_value=float('nan'), coords='minimal', compat='override' |
| 440 | + ) |
| 441 | + return _apply_merged_coords(result, merged_coords) |
376 | 442 |
|
377 | 443 | def _merge_dict_property(self, prop_name: str) -> dict[str, str]: |
378 | 444 | """Merge a dict property from all cases (later cases override).""" |
@@ -526,9 +592,11 @@ def _combine_data(self, method_name: str, *args, **kwargs) -> tuple[xr.DataArray |
526 | 592 | if not arrays: |
527 | 593 | return xr.DataArray(dims=[]), '' |
528 | 594 |
|
529 | | - return xr.concat( |
530 | | - arrays, dim='case', join='outer', fill_value=float('nan'), coords='minimal', compat='override' |
531 | | - ), title |
| 595 | + arrays, merged_coords = _extract_nonindex_coords(*arrays) |
| 596 | + combined = xr.concat( |
| 597 | + arrays, dim='case', join='outer', coords='minimal', fill_value=float('nan'), compat='override' |
| 598 | + ) |
| 599 | + return _apply_merged_coords(combined, merged_coords), title |
532 | 600 |
|
533 | 601 | def _finalize(self, da: xr.DataArray, fig, show: bool | None) -> PlotResult: |
534 | 602 | """Handle show and return PlotResult.""" |
|
0 commit comments