@@ -398,10 +398,18 @@ def _run_quantile(data, k, module):
398398
399399
def _run_dask_cupy_quantile(data, k):
    """Compute ``k`` quantile break values for a CuPy-backed Dask array.

    Percentile targets are evenly spaced at 100/k steps (clamped so the
    last target never exceeds 100). The finite values are materialised
    and moved from GPU to CPU, and the unique percentile values are
    returned as a numpy array.
    """
    step = 100.0 / k
    targets = np.arange(step, 100 + step, step)
    # float accumulation in arange can overshoot 100 on the last target
    if targets[-1] > 100.0:
        targets[-1] = 100.0

    # keep only finite values, then materialise the lazy selection
    finite_vals = data[da.isfinite(data)].compute()
    # transfer from GPU to CPU
    host_vals = cupy.asnumpy(finite_vals)

    breaks = np.percentile(host_vals, targets)
    return np.unique(breaks)
405413
406414
407415def _quantile (agg , k ):
@@ -575,25 +583,28 @@ def _run_jenks(data, n_classes):
575583 return kclass
576584
577585
578- def _run_natural_break (agg , num_sample , k ):
579- data = agg .data
580- num_data = data .size
581- max_data = np .max (data [np .isfinite (data )])
586+ def _compute_natural_break_bins (data_flat_np , num_sample , k , max_data ):
587+ """Shared helper: compute natural break bins from a flat numpy array.
588+
589+ Returns (bins, uvk) where bins is a numpy array of bin edges
590+ and uvk is the number of unique values.
591+ """
592+ num_data = data_flat_np .size
582593
583594 if num_sample is not None and num_sample < num_data :
584595 # randomly select sample from the whole dataset
585596 # create a pseudo random number generator
586- # Note: cupy and nupy generate different random numbers
597+ # Note: cupy and numpy generate different random numbers
587598 # use numpy.random to ensure the same result
588599 generator = np .random .RandomState (1234567890 )
589600 idx = np .linspace (
590- 0 , data . size , data . size , endpoint = False , dtype = np .uint32
601+ 0 , num_data , num_data , endpoint = False , dtype = np .uint32
591602 )
592603 generator .shuffle (idx )
593604 sample_idx = idx [:num_sample ]
594- sample_data = data . flatten () [sample_idx ]
605+ sample_data = data_flat_np [sample_idx ]
595606 else :
596- sample_data = data . flatten ()
607+ sample_data = data_flat_np
597608
598609 # warning if number of total data points to fit the model bigger than 40k
599610 if sample_data .size >= 40000 :
@@ -627,6 +638,45 @@ def _run_natural_break(agg, num_sample, k):
627638 bins = np .array (centroids [1 :])
628639 bins [- 1 ] = max_data
629640
641+ return bins , uvk
642+
643+
def _run_natural_break(agg, num_sample, k):
    """Natural-breaks (Jenks) classification for a NumPy-backed DataArray.

    Computes the bin edges from the flattened data via
    ``_compute_natural_break_bins`` and reclassifies ``agg`` with them.
    """
    values = agg.data
    finite_max = float(np.max(values[np.isfinite(values)]))
    # NOTE(review): flatten() returns a copy (unlike ravel()), so the
    # downstream fit cannot mutate agg.data in place — presumably
    # deliberate; confirm before switching to ravel()
    flat_values = values.flatten()

    bins, n_unique = _compute_natural_break_bins(
        flat_values, num_sample, k, finite_max
    )
    return _bin(agg, bins, np.arange(n_unique))
652+
653+
def _run_cupy_natural_break(agg, num_sample, k):
    """Natural-breaks (Jenks) classification for a CuPy-backed DataArray.

    The model is fit on the host: the flat data is copied GPU -> CPU
    before the bin edges are computed.
    """
    gpu_values = agg.data
    # .get() pulls the cupy scalar back to the host before float()
    finite_max = float(cupy.max(gpu_values[cupy.isfinite(gpu_values)]).get())
    # asnumpy copies the flat view to a host numpy array
    host_flat = cupy.asnumpy(gpu_values.ravel())

    bins, n_unique = _compute_natural_break_bins(
        host_flat, num_sample, k, finite_max
    )
    return _bin(agg, bins, np.arange(n_unique))
662+
663+
def _run_dask_natural_break(agg, num_sample, k):
    """Natural-breaks (Jenks) classification for a NumPy-backed Dask array.

    The lazy array is fully materialised on the host to fit the model;
    natural breaks has no incremental formulation here.
    """
    lazy_values = agg.data
    finite_max = float(da.max(lazy_values[da.isfinite(lazy_values)]).compute())
    host_flat = np.asarray(lazy_values.ravel().compute())

    bins, n_unique = _compute_natural_break_bins(
        host_flat, num_sample, k, finite_max
    )
    return _bin(agg, bins, np.arange(n_unique))
672+
673+
def _run_dask_cupy_natural_break(agg, num_sample, k):
    """Natural-breaks (Jenks) classification for a CuPy-backed Dask array.

    Materialises the lazy array (yielding a cupy array), moves it
    GPU -> CPU, and fits the model on the host.
    """
    lazy_values = agg.data
    # compute() yields a cupy scalar; .item() transfers it to the host
    finite_max = float(
        da.max(lazy_values[da.isfinite(lazy_values)]).compute().item()
    )
    # compute() yields a cupy array; asnumpy moves it GPU -> CPU
    host_flat = cupy.asnumpy(lazy_values.ravel().compute())

    bins, n_unique = _compute_natural_break_bins(
        host_flat, num_sample, k, finite_max
    )
    return _bin(agg, bins, np.arange(n_unique))
632682
@@ -644,7 +694,8 @@ def natural_breaks(agg: xr.DataArray,
644694 Parameters
645695 ----------
646696 agg : xarray.DataArray
647- 2D NumPy DataArray of values to be reclassified.
697+ 2D NumPy, CuPy, NumPy-backed Dask, or CuPy-backed Dask array
698+ of values to be reclassified.
648699 num_sample : int, default=20000
649700 Number of sample data points used to fit the model.
650701 Natural Breaks (Jenks) classification is indeed O(n²) complexity,
@@ -715,13 +766,10 @@ def natural_breaks(agg: xr.DataArray,
715766 """
716767
717768 mapper = ArrayTypeFunctionMapping (
718- numpy_func = lambda * args : _run_natural_break (* args ),
719- dask_func = lambda * args : not_implemented_func (
720- * args , messages = 'natural_breaks() does not support dask with numpy backed DataArray.' ), # noqa
721- cupy_func = lambda * args : not_implemented_func (
722- * args , messages = 'natural_breaks() does not support cupy backed DataArray.' ), # noqa
723- dask_cupy_func = lambda * args : not_implemented_func (
724- * args , messages = 'natural_breaks() does not support dask with cupy backed DataArray.' ), # noqa
769+ numpy_func = _run_natural_break ,
770+ dask_func = _run_dask_natural_break ,
771+ cupy_func = _run_cupy_natural_break ,
772+ dask_cupy_func = _run_dask_cupy_natural_break ,
725773 )
726774 out = mapper (agg )(agg , num_sample , k )
727775 return xr .DataArray (out ,
@@ -772,6 +820,27 @@ def _run_equal_interval(agg, k, module):
772820 return out
773821
774822
def _run_dask_cupy_equal_interval(agg, k):
    """Equal-interval classification for a CuPy-backed Dask array.

    Computes k equally wide bins between the finite min and max of the
    data and reclassifies ``agg`` with them.

    Parameters
    ----------
    agg : xr.DataArray
        CuPy-backed Dask array of values to be reclassified.
    k : int
        Number of classes.
    """
    data = agg.data.ravel()

    # replace +/-inf with nan so nanmin/nanmax ignore them
    data = da.where(data == np.inf, np.nan, data)
    data = da.where(data == -np.inf, np.nan, data)

    # .item() transfers the cupy scalar to the host
    min_data = float(da.nanmin(data).compute().item())
    max_data = float(da.nanmax(data).compute().item())

    # NOTE(review): if max_data == min_data, width is 0 and np.arange
    # raises ValueError — same as the other equal_interval backends;
    # confirm whether constant inputs need explicit handling
    width = (max_data - min_data) * 1.0 / k
    cuts = np.arange(min_data + width, max_data + width, width)
    # float accumulation can overshoot and yield k + 1 edges; trim to k
    if cuts.shape[0] > k:
        cuts = cuts[0:k]

    # pin the last edge exactly to the data maximum
    cuts[-1] = max_data
    # BUGFIX: labels must match the *trimmed* number of cut edges
    # (previously the pre-trim count was used, so an overshoot produced
    # one more label than bins passed to _bin)
    out = _bin(agg, cuts, np.arange(cuts.shape[0]))
    return out
842+
843+
775844def equal_interval (agg : xr .DataArray ,
776845 k : int = 5 ,
777846 name : Optional [str ] = 'equal_interval' ) -> xr .DataArray :
@@ -832,8 +901,7 @@ def equal_interval(agg: xr.DataArray,
832901 numpy_func = lambda * args : _run_equal_interval (* args , module = np ),
833902 dask_func = lambda * args : _run_equal_interval (* args , module = da ),
834903 cupy_func = lambda * args : _run_equal_interval (* args , module = cupy ),
835- dask_cupy_func = lambda * args : not_implemented_func (
836- * args , messages = 'equal_interval() does support dask with cupy backed DataArray.' ), # noqa
904+ dask_cupy_func = _run_dask_cupy_equal_interval
837905 )
838906 out = mapper (agg )(agg , k )
839907 return xr .DataArray (out ,
0 commit comments