@@ -452,7 +452,6 @@ def _convert_color(c):
452452 return tf .Image (f (agg .data ))
453453
454454
455-
456455def _infer_coord_unit_type (coord : xr .DataArray , cellsize : float ) -> str :
457456 """
458457 Heuristic to classify a spatial coordinate axis as:
@@ -504,50 +503,40 @@ def _infer_coord_unit_type(coord: xr.DataArray, cellsize: float) -> str:
504503 return "unknown"
505504
506505
507- def _infer_vertical_unit_type (agg : xr .DataArray ) -> str :
508- """
509- Heuristic to classify the DataArray values as:
510- - 'elevation' (meters/feet etc)
511- - 'angle' (degrees/radians)
512- - 'unknown'
513- """
506+ def _infer_vertical_unit_type (agg ):
514507 units = str (agg .attrs .get ("units" , "" )).lower ()
515508
516- # 1) Explicit units
517- if any (k in units for k in ("degree" , "deg" )):
518- return "angle"
519- if "rad" in units :
509+ # Cheap / reliable first
510+ if any (k in units for k in ("degree" , "deg" )) or "rad" in units :
520511 return "angle"
521512 if units in ("m" , "meter" , "metre" , "meters" , "metres" ,
522513 "km" , "kilometer" , "kilometre" , "kilometers" , "kilometres" ,
523514 "ft" , "foot" , "feet" ):
524515 return "elevation"
525516
526- # 2) Numeric heuristics on data range
527- data = agg .values
528- if not np .issubdtype (data .dtype , np .number ):
517+ # Numeric fallback: sample only (never full compute)
518+ data = agg .data
519+ try :
520+ vmin , vmax = _sample_windows_min_max (data , max_window_elems = 65536 , windows = 5 )
521+ except Exception :
529522 return "unknown"
530523
531- finite = np .isfinite (data )
532- if not np .any (finite ):
524+ if not np .isfinite (vmin ) or not np .isfinite (vmax ):
533525 return "unknown"
534526
535- vmin = float (data [finite ].min ())
536- vmax = float (data [finite ].max ())
537527 span = vmax - vmin
538528
539- # Elevation-like: tens–thousands of units, typical DEM ranges.
529+ # Elevation-ish heuristic
540530 if 10.0 <= span <= 20000.0 and vmin > - 500.0 :
541531 return "elevation"
542532
543- # Angle-like: often 0–360, -180–180, or small (-pi, pi)
544- if - 360.0 <= vmin <= 360.0 and - 360.0 <= vmax <= 360.0 :
545- # If the span is not huge, treat as angle-ish
546- if span <= 720.0 :
547- return "angle"
533+ # Angle-ish heuristic
534+ if - 360.0 <= vmin <= 360.0 and - 360.0 <= vmax <= 360.0 and span <= 720.0 :
535+ return "angle"
548536
549537 return "unknown"
550538
539+
551540def warn_if_unit_mismatch (agg : xr .DataArray ) -> None :
552541 """
553542 Heuristic check for horizontal vs vertical unit mismatch.
@@ -599,3 +588,106 @@ def warn_if_unit_mismatch(agg: xr.DataArray) -> None:
599588 "meter-based coordinates before calling `slope`." ,
600589 UserWarning ,
601590 )
591+
592+
593+ def _to_float_scalar (x ) -> float :
594+ """Convert numpy/cupy scalar or 0-d array to python float safely."""
595+ if cupy is not None :
596+ # cupy.ndarray scalar
597+ if isinstance (x , cupy .ndarray ):
598+ return float (cupy .asnumpy (x ).item ())
599+ # cupy scalar type
600+ if x .__class__ .__module__ .startswith ("cupy" ) and hasattr (x , "item" ):
601+ return float (x .item ())
602+
603+ if hasattr (x , "item" ):
604+ return float (x .item ())
605+ return float (x )
606+
607+
608+ def _sample_windows_min_max (
609+ data ,
610+ * ,
611+ max_window_elems : int = 65536 , # e.g. 256x256
612+ windows : int = 5 , # corners + center default
613+ ) -> tuple [float , float ]:
614+ """
615+ Estimate (nanmin, nanmax) from a small sample of windows.
616+
617+ Works for numpy, cupy, dask+numpy, dask+cupy. Only computes on the sampled
618+ windows, not the full array.
619+ """
620+ # Normalize to last-2D sampling (y,x). For higher dims, sample first index.
621+ if hasattr (data , "ndim" ) and data .ndim >= 3 :
622+ prefix = (0 ,) * (data .ndim - 2 )
623+ else :
624+ prefix = ()
625+
626+ # Determine y/x sizes
627+ shape = data .shape
628+ ny , nx = shape [- 2 ], shape [- 1 ]
629+
630+ if ny == 0 or nx == 0 :
631+ return np .nan , np .nan
632+
633+ # Choose a square-ish window size bounded by array shape
634+ w = int (np .sqrt (max_window_elems ))
635+ w = max (1 , min (w , ny , nx ))
636+
637+ # Define window anchor positions: (top-left), (top-right), (bottom-left), (bottom-right), (center)
638+ anchors = [
639+ (0 , 0 ),
640+ (0 , max (0 , nx - w )),
641+ (max (0 , ny - w ), 0 ),
642+ (max (0 , ny - w ), max (0 , nx - w )),
643+ ]
644+ if windows >= 5 :
645+ anchors .append ((max (0 , ny // 2 - w // 2 ), max (0 , nx // 2 - w // 2 )))
646+
647+ # If windows > 5, sprinkle additional evenly-spaced anchors (optional)
648+ if windows > 5 :
649+ extra = windows - 5
650+ ys = np .linspace (0 , max (0 , ny - w ), extra + 2 , dtype = int )[1 :- 1 ]
651+ xs = np .linspace (0 , max (0 , nx - w ), extra + 2 , dtype = int )[1 :- 1 ]
652+ for y0 , x0 in zip (ys , xs ):
653+ anchors .append ((int (y0 ), int (x0 )))
654+
655+ # Reduce min/max across sampled windows
656+ mins = []
657+ maxs = []
658+
659+ for y0 , x0 in anchors :
660+ sl = prefix + (slice (y0 , y0 + w ), slice (x0 , x0 + w ))
661+ win = data [sl ]
662+
663+ if da is not None and isinstance (win , da .Array ):
664+ # Compute scalars only on this window
665+ mins .append (da .nanmin (win ))
666+ maxs .append (da .nanmax (win ))
667+ elif cupy is not None and isinstance (win , cupy .ndarray ):
668+ mins .append (cupy .nanmin (win ))
669+ maxs .append (cupy .nanmax (win ))
670+ else :
671+ mins .append (np .nanmin (win ))
672+ maxs .append (np .nanmax (win ))
673+
674+ # Finalize: if dask, compute the scalar graph now (still tiny)
675+ if da is not None and any (isinstance (m , da .Array ) for m in mins ):
676+ mn = da .nanmin (da .stack (mins )).compute ()
677+ mx = da .nanmax (da .stack (maxs )).compute ()
678+ return _to_float_scalar (mn ), _to_float_scalar (mx )
679+
680+ # If cupy scalars, convert safely
681+ if cupy is not None and (any (isinstance (m , cupy .ndarray ) for m in mins ) or
682+ any (getattr (m .__class__ , "__module__" , "" ).startswith ("cupy" ) for m in mins )):
683+ mn = mins [0 ]
684+ mx = maxs [0 ]
685+ # reduce on device
686+ for m in mins [1 :]:
687+ mn = cupy .minimum (mn , m )
688+ for m in maxs [1 :]:
689+ mx = cupy .maximum (mx , m )
690+ return _to_float_scalar (mn ), _to_float_scalar (mx )
691+
692+ # numpy scalars
693+ return float (np .nanmin (np .array (mins , dtype = float ))), float (np .nanmax (np .array (maxs , dtype = float )))
0 commit comments