Skip to content

Commit ccbb894

Browse files
committed
fixed cupy and dask cases for the unit mismatch heuristic
1 parent ab30e48 commit ccbb894

File tree

1 file changed

+117
-25
lines changed

1 file changed

+117
-25
lines changed

xrspatial/utils.py

Lines changed: 117 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,6 @@ def _convert_color(c):
452452
return tf.Image(f(agg.data))
453453

454454

455-
456455
def _infer_coord_unit_type(coord: xr.DataArray, cellsize: float) -> str:
457456
"""
458457
Heuristic to classify a spatial coordinate axis as:
@@ -504,50 +503,40 @@ def _infer_coord_unit_type(coord: xr.DataArray, cellsize: float) -> str:
504503
return "unknown"
505504

506505

507-
def _infer_vertical_unit_type(agg: xr.DataArray) -> str:
508-
"""
509-
Heuristic to classify the DataArray values as:
510-
- 'elevation' (meters/feet etc)
511-
- 'angle' (degrees/radians)
512-
- 'unknown'
513-
"""
506+
def _infer_vertical_unit_type(agg):
514507
units = str(agg.attrs.get("units", "")).lower()
515508

516-
# 1) Explicit units
517-
if any(k in units for k in ("degree", "deg")):
518-
return "angle"
519-
if "rad" in units:
509+
# Cheap / reliable first
510+
if any(k in units for k in ("degree", "deg")) or "rad" in units:
520511
return "angle"
521512
if units in ("m", "meter", "metre", "meters", "metres",
522513
"km", "kilometer", "kilometre", "kilometers", "kilometres",
523514
"ft", "foot", "feet"):
524515
return "elevation"
525516

526-
# 2) Numeric heuristics on data range
527-
data = agg.values
528-
if not np.issubdtype(data.dtype, np.number):
517+
# Numeric fallback: sample only (never full compute)
518+
data = agg.data
519+
try:
520+
vmin, vmax = _sample_windows_min_max(data, max_window_elems=65536, windows=5)
521+
except Exception:
529522
return "unknown"
530523

531-
finite = np.isfinite(data)
532-
if not np.any(finite):
524+
if not np.isfinite(vmin) or not np.isfinite(vmax):
533525
return "unknown"
534526

535-
vmin = float(data[finite].min())
536-
vmax = float(data[finite].max())
537527
span = vmax - vmin
538528

539-
# Elevation-like: tens–thousands of units, typical DEM ranges.
529+
# Elevation-ish heuristic
540530
if 10.0 <= span <= 20000.0 and vmin > -500.0:
541531
return "elevation"
542532

543-
# Angle-like: often 0–360, -180–180, or small (-pi, pi)
544-
if -360.0 <= vmin <= 360.0 and -360.0 <= vmax <= 360.0:
545-
# If the span is not huge, treat as angle-ish
546-
if span <= 720.0:
547-
return "angle"
533+
# Angle-ish heuristic
534+
if -360.0 <= vmin <= 360.0 and -360.0 <= vmax <= 360.0 and span <= 720.0:
535+
return "angle"
548536

549537
return "unknown"
550538

539+
551540
def warn_if_unit_mismatch(agg: xr.DataArray) -> None:
552541
"""
553542
Heuristic check for horizontal vs vertical unit mismatch.
@@ -599,3 +588,106 @@ def warn_if_unit_mismatch(agg: xr.DataArray) -> None:
599588
"meter-based coordinates before calling `slope`.",
600589
UserWarning,
601590
)
591+
592+
593+
def _to_float_scalar(x) -> float:
594+
"""Convert numpy/cupy scalar or 0-d array to python float safely."""
595+
if cupy is not None:
596+
# cupy.ndarray scalar
597+
if isinstance(x, cupy.ndarray):
598+
return float(cupy.asnumpy(x).item())
599+
# cupy scalar type
600+
if x.__class__.__module__.startswith("cupy") and hasattr(x, "item"):
601+
return float(x.item())
602+
603+
if hasattr(x, "item"):
604+
return float(x.item())
605+
return float(x)
606+
607+
608+
def _sample_windows_min_max(
609+
data,
610+
*,
611+
max_window_elems: int = 65536, # e.g. 256x256
612+
windows: int = 5, # corners + center default
613+
) -> tuple[float, float]:
614+
"""
615+
Estimate (nanmin, nanmax) from a small sample of windows.
616+
617+
Works for numpy, cupy, dask+numpy, dask+cupy. Only computes on the sampled
618+
windows, not the full array.
619+
"""
620+
# Normalize to last-2D sampling (y,x). For higher dims, sample first index.
621+
if hasattr(data, "ndim") and data.ndim >= 3:
622+
prefix = (0,) * (data.ndim - 2)
623+
else:
624+
prefix = ()
625+
626+
# Determine y/x sizes
627+
shape = data.shape
628+
ny, nx = shape[-2], shape[-1]
629+
630+
if ny == 0 or nx == 0:
631+
return np.nan, np.nan
632+
633+
# Choose a square-ish window size bounded by array shape
634+
w = int(np.sqrt(max_window_elems))
635+
w = max(1, min(w, ny, nx))
636+
637+
# Define window anchor positions: (top-left), (top-right), (bottom-left), (bottom-right), (center)
638+
anchors = [
639+
(0, 0),
640+
(0, max(0, nx - w)),
641+
(max(0, ny - w), 0),
642+
(max(0, ny - w), max(0, nx - w)),
643+
]
644+
if windows >= 5:
645+
anchors.append((max(0, ny // 2 - w // 2), max(0, nx // 2 - w // 2)))
646+
647+
# If windows > 5, sprinkle additional evenly-spaced anchors (optional)
648+
if windows > 5:
649+
extra = windows - 5
650+
ys = np.linspace(0, max(0, ny - w), extra + 2, dtype=int)[1:-1]
651+
xs = np.linspace(0, max(0, nx - w), extra + 2, dtype=int)[1:-1]
652+
for y0, x0 in zip(ys, xs):
653+
anchors.append((int(y0), int(x0)))
654+
655+
# Reduce min/max across sampled windows
656+
mins = []
657+
maxs = []
658+
659+
for y0, x0 in anchors:
660+
sl = prefix + (slice(y0, y0 + w), slice(x0, x0 + w))
661+
win = data[sl]
662+
663+
if da is not None and isinstance(win, da.Array):
664+
# Compute scalars only on this window
665+
mins.append(da.nanmin(win))
666+
maxs.append(da.nanmax(win))
667+
elif cupy is not None and isinstance(win, cupy.ndarray):
668+
mins.append(cupy.nanmin(win))
669+
maxs.append(cupy.nanmax(win))
670+
else:
671+
mins.append(np.nanmin(win))
672+
maxs.append(np.nanmax(win))
673+
674+
# Finalize: if dask, compute the scalar graph now (still tiny)
675+
if da is not None and any(isinstance(m, da.Array) for m in mins):
676+
mn = da.nanmin(da.stack(mins)).compute()
677+
mx = da.nanmax(da.stack(maxs)).compute()
678+
return _to_float_scalar(mn), _to_float_scalar(mx)
679+
680+
# If cupy scalars, convert safely
681+
if cupy is not None and (any(isinstance(m, cupy.ndarray) for m in mins) or
682+
any(getattr(m.__class__, "__module__", "").startswith("cupy") for m in mins)):
683+
mn = mins[0]
684+
mx = maxs[0]
685+
# reduce on device
686+
for m in mins[1:]:
687+
mn = cupy.minimum(mn, m)
688+
for m in maxs[1:]:
689+
mx = cupy.maximum(mx, m)
690+
return _to_float_scalar(mn), _to_float_scalar(mx)
691+
692+
# numpy scalars
693+
return float(np.nanmin(np.array(mins, dtype=float))), float(np.nanmax(np.array(maxs, dtype=float)))

0 commit comments

Comments
 (0)