@@ -15,8 +15,10 @@ class cupy(object):
1515 ndarray = False
1616
1717try :
18+ import dask
1819 import dask .array as da
1920except ImportError :
21+ dask = None
2022 da = None
2123
2224import numba as nb
@@ -232,9 +234,8 @@ def _run_gpu_bin(data, bins, new_values, out):
232234
233235
234236def _run_cupy_bin (data , bins , new_values ):
235- # replace inf by nan to avoid classify these values as we want to treat them as outliers
236- data = cupy .where (data == cupy .inf , cupy .nan , data )
237- data = cupy .where (data == - cupy .inf , cupy .nan , data )
237+ # replace ±inf with nan in a single pass so these values fall outside all bins and are treated as outliers
238+ data = cupy .where (cupy .isinf (data ), cupy .nan , data )
238239
239240 bins_cupy = cupy .asarray (bins )
240241 new_values_cupy = cupy .asarray (new_values )
@@ -679,7 +680,9 @@ def _generate_sample_indices(num_data, num_sample, seed=1234567890):
679680
680681def _run_dask_natural_break (agg , num_sample , k ):
681682 data = agg .data
682- max_data = float (da .max (data [da .isfinite (data )]).compute ())
683+ # Avoid boolean fancy indexing which flattens dask arrays and
684+ # produces chunks of unknown size; use element-wise where instead
685+ max_data = float (da .nanmax (da .where (da .isinf (data ), np .nan , data )).compute ())
683686
684687 num_data = data .size
685688 if num_sample is not None and num_sample < num_data :
@@ -699,7 +702,9 @@ def _run_dask_natural_break(agg, num_sample, k):
699702
700703def _run_dask_cupy_natural_break (agg , num_sample , k ):
701704 data = agg .data
702- max_data = float (da .max (data [da .isfinite (data )]).compute ().item ())
705+ # Avoid boolean fancy indexing which flattens dask arrays and
706+ # produces chunks of unknown size; use element-wise where instead
707+ max_data = float (da .nanmax (da .where (da .isinf (data ), np .nan , data )).compute ().item ())
703708
704709 num_data = data .size
705710 if num_sample is not None and num_sample < num_data :
@@ -817,57 +822,28 @@ def natural_breaks(agg: xr.DataArray,
817822
818823
819824def _run_equal_interval (agg , k , module ):
820- data = agg .data .ravel ()
821- if module == cupy :
822- nan = cupy .nan
823- inf = cupy .inf
824- else :
825- nan = np .nan
826- inf = np .inf
825+ data = agg .data
827826
828- data = module . where ( data == inf , nan , data )
829- data = module .where (data == - inf , nan , data )
827+ # Replace ±inf with nan in a single pass (no ravel needed)
828+ data_clean = module .where (module . isinf ( data ), np . nan , data )
830829
831- max_data = module .nanmax ( data )
832- min_data = module .nanmin ( data )
830+ min_lazy = module .nanmin ( data_clean )
831+ max_lazy = module .nanmax ( data_clean )
833832
834833 if module == cupy :
835- min_data = min_data .get ()
836- max_data = max_data .get ()
837-
838- if module == da :
839- min_data = min_data .compute ()
840- max_data = max_data .compute ()
841-
842- width = (max_data - min_data ) * 1.0 / k
843- cuts = module .arange (min_data + width , max_data + width , width )
844- l_cuts = cuts .shape [0 ]
845- if l_cuts > k :
846- # handle overshooting
847- cuts = cuts [0 :k ]
848-
849- if module == da :
850- # work around to assign cuts[-1] = max_data
851- bins = da .concatenate ([cuts [:k - 1 ], [max_data ]])
852- out = _bin (agg , bins , np .arange (l_cuts ))
834+ min_data = float (min_lazy .get ())
835+ max_data = float (max_lazy .get ())
836+ elif module == da :
837+ # Compute both in a single pass over the data
838+ min_data , max_data = dask .compute (min_lazy , max_lazy )
839+ min_data = float (min_data )
840+ max_data = float (max_data )
853841 else :
854- cuts [- 1 ] = max_data
855- out = _bin (agg , cuts , np .arange (l_cuts ))
856-
857- return out
858-
859-
860- def _run_dask_cupy_equal_interval (agg , k ):
861- data = agg .data .ravel ()
862-
863- # replace inf with nan
864- data = da .where (data == np .inf , np .nan , data )
865- data = da .where (data == - np .inf , np .nan , data )
866-
867- min_data = float (da .nanmin (data ).compute ().item ())
868- max_data = float (da .nanmax (data ).compute ().item ())
842+ min_data = float (min_lazy )
843+ max_data = float (max_lazy )
869844
870- width = (max_data - min_data ) * 1.0 / k
845+ width = (max_data - min_data ) / k
846+ # Build cuts as numpy — only k elements, no need for dask/cupy overhead
871847 cuts = np .arange (min_data + width , max_data + width , width )
872848 l_cuts = cuts .shape [0 ]
873849 if l_cuts > k :
@@ -938,7 +914,7 @@ def equal_interval(agg: xr.DataArray,
938914 numpy_func = lambda * args : _run_equal_interval (* args , module = np ),
939915 dask_func = lambda * args : _run_equal_interval (* args , module = da ),
940916 cupy_func = lambda * args : _run_equal_interval (* args , module = cupy ),
941- dask_cupy_func = _run_dask_cupy_equal_interval
917+ dask_cupy_func = lambda * args : _run_equal_interval ( * args , module = da )
942918 )
943919 out = mapper (agg )(agg , k )
944920 return xr .DataArray (out ,
0 commit comments