@@ -393,7 +393,8 @@ def reclassify(agg: xr.DataArray,
393393 attrs = agg .attrs )
394394
395395
396- def _run_quantile (data , k , module ):
396+ def _run_quantile (data , num_sample , k , module ):
397+ # num_sample ignored for in-memory backends
397398 w = 100.0 / k
398399 p = module .arange (w , 100 + w , w )
399400
@@ -405,27 +406,47 @@ def _run_quantile(data, k, module):
405406 return q
406407
407408
408- def _run_dask_cupy_quantile (data , k ):
409- # Convert dask+cupy chunks to numpy one at a time via map_blocks,
410- # then use dask's streaming approximate percentile (no full materialization).
def _run_dask_quantile(data, num_sample, k):
    """Quantile breakpoints for a dask-backed array.

    Boolean fancy indexing such as ``data[da.isfinite(data)]`` would
    yield unknown dask chunk sizes, so a fixed-size sample is drawn with
    plain integer indexing instead; only that sample is ever
    materialised (#884).
    """
    step = 100.0 / k
    percents = np.arange(step, 100 + step, step)
    if percents[-1] > 100.0:
        # float accumulation in arange can push the last breakpoint past 100
        percents[-1] = 100.0
    finite_only = da.where(da.isinf(data), np.nan, data)
    total = data.size
    if num_sample is None or num_sample >= total:
        num_sample = total
    idx = _generate_sample_indices(total, num_sample)
    sample = np.asarray(finite_only.ravel()[idx].compute())
    sample = sample[np.isfinite(sample)]
    return np.unique(np.percentile(sample, percents))
427+
428+
def _run_dask_cupy_quantile(data, num_sample, k):
    """Quantile breakpoints for a dask+cupy array.

    GPU chunks are converted to numpy lazily via ``map_blocks``; the
    remaining work is shared with the plain-dask path (#884).
    """
    on_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
    return _run_dask_quantile(on_cpu, num_sample, k)
413433
414434
def _quantile(agg, num_sample, k):
    """Dispatch the quantile computation to the backend-specific runner."""
    dispatch = ArrayTypeFunctionMapping(
        numpy_func=lambda *args: _run_quantile(*args, module=np),
        dask_func=_run_dask_quantile,
        cupy_func=lambda *args: _run_quantile(*args, module=cupy),
        dask_cupy_func=_run_dask_cupy_quantile,
    )
    return dispatch(agg)(agg.data, num_sample, k)
424444
425445
426446@supports_dataset
427447def quantile (agg : xr .DataArray ,
428448 k : int = 4 ,
449+ num_sample : Optional [int ] = 20_000 ,
429450 name : Optional [str ] = 'quantile' ) -> xr .DataArray :
430451 """
431452 Reclassifies data for array `agg` into new values based on quantile
@@ -438,6 +459,12 @@ def quantile(agg: xr.DataArray,
438459 of values to be reclassified.
439460 k : int, default=4
440461 Number of quantiles to be produced.
462+ num_sample : int or None, default=20000
463+ Number of sample data points used to compute percentile
464+ breakpoints. For dask-backed arrays the sample is drawn
465+ lazily to avoid materialising the entire array into RAM.
466+ ``None`` means use all data (safe for numpy/cupy,
467+ automatically capped for dask).
441468 name : str, default='quantile'
442469 Name of the output aggregate array.
443470
@@ -489,7 +516,7 @@ def quantile(agg: xr.DataArray,
489516 res: (10.0, 10.0)
490517 """
491518
492- q = _quantile (agg , k )
519+ q = _quantile (agg , num_sample , k )
493520 k_q = q .shape [0 ]
494521 if k_q < k :
495522 print ("Quantile Warning: Not enough unique values"
@@ -695,16 +722,11 @@ def _run_dask_natural_break(agg, num_sample, k):
695722 max_data = float (da .nanmax (da .where (da .isinf (data ), np .nan , data )).compute ())
696723
697724 num_data = data .size
698- if num_sample is not None and num_sample < num_data :
699- # Sample lazily from dask array; only materialize the sample
700- sample_idx = _generate_sample_indices (num_data , num_sample )
701- sample_data_np = np .asarray (data .ravel ()[sample_idx ].compute ())
702- bins , uvk = _compute_natural_break_bins (
703- sample_data_np , None , k , max_data )
704- else :
705- data_flat_np = np .asarray (data .ravel ().compute ())
706- bins , uvk = _compute_natural_break_bins (
707- data_flat_np , None , k , max_data )
725+ if num_sample is None or num_sample >= num_data :
726+ num_sample = num_data # cap: still uses indexed access, never .compute() all
727+ sample_idx = _generate_sample_indices (num_data , num_sample )
728+ sample_data_np = np .asarray (data .ravel ()[sample_idx ].compute ())
729+ bins , uvk = _compute_natural_break_bins (sample_data_np , None , k , max_data )
708730
709731 out = _bin (agg , bins , np .arange (uvk ))
710732 return out
@@ -717,17 +739,12 @@ def _run_dask_cupy_natural_break(agg, num_sample, k):
717739 max_data = float (da .nanmax (da .where (da .isinf (data ), np .nan , data )).compute ().item ())
718740
719741 num_data = data .size
720- if num_sample is not None and num_sample < num_data :
721- # Sample lazily from dask array; only materialize the sample
722- sample_idx = _generate_sample_indices (num_data , num_sample )
723- sample_data = data .ravel ()[sample_idx ].compute ()
724- sample_data_np = cupy .asnumpy (sample_data )
725- bins , uvk = _compute_natural_break_bins (
726- sample_data_np , None , k , max_data )
727- else :
728- data_flat_np = cupy .asnumpy (data .ravel ().compute ())
729- bins , uvk = _compute_natural_break_bins (
730- data_flat_np , None , k , max_data )
742+ if num_sample is None or num_sample >= num_data :
743+ num_sample = num_data # cap: still uses indexed access, never .compute() all
744+ sample_idx = _generate_sample_indices (num_data , num_sample )
745+ sample_data = data .ravel ()[sample_idx ].compute ()
746+ sample_data_np = cupy .asnumpy (sample_data )
747+ bins , uvk = _compute_natural_break_bins (sample_data_np , None , k , max_data )
731748
732749 out = _bin (agg , bins , np .arange (uvk ))
733750 return out
@@ -1109,20 +1126,39 @@ def head_tail_breaks(agg: xr.DataArray,
11091126 attrs = agg .attrs )
11101127
11111128
1112- def _run_percentiles (data , pct , module ):
1129+ def _run_percentiles (data , num_sample , pct , module ):
1130+ # num_sample ignored for in-memory backends
11131131 q = module .percentile (data [module .isfinite (data )], pct )
11141132 q = module .unique (q )
11151133 return q
11161134
11171135
1118- def _run_dask_cupy_percentiles (data , pct ):
def _run_dask_percentiles(data, num_sample, pct):
    """Percentile breakpoints for a dask-backed array.

    Boolean fancy indexing such as ``data[da.isfinite(data)]`` would
    yield unknown dask chunk sizes, so a fixed-size sample is drawn with
    plain integer indexing instead; only that sample is ever
    materialised (#884).
    """
    finite_only = da.where(da.isinf(data), np.nan, data)
    total = data.size
    if num_sample is None or num_sample >= total:
        num_sample = total
    idx = _generate_sample_indices(total, num_sample)
    sample = np.asarray(finite_only.ravel()[idx].compute())
    sample = sample[np.isfinite(sample)]
    return np.unique(np.percentile(sample, pct))
1150+
1151+
def _run_dask_cupy_percentiles(data, num_sample, pct):
    """Percentile breakpoints for a dask+cupy array.

    GPU chunks are converted to numpy lazily via ``map_blocks``; the
    remaining work is shared with the plain-dask path (#884).
    """
    on_cpu = data.map_blocks(cupy.asnumpy, dtype=data.dtype, meta=np.array(()))
    return _run_dask_percentiles(on_cpu, num_sample, pct)
11211156
11221157
11231158@supports_dataset
11241159def percentiles (agg : xr .DataArray ,
11251160 pct : Optional [List ] = None ,
1161+ num_sample : Optional [int ] = 20_000 ,
11261162 name : Optional [str ] = 'percentiles' ) -> xr .DataArray :
11271163 """
11281164 Classify data based on percentile breakpoints.
@@ -1134,6 +1170,12 @@ def percentiles(agg: xr.DataArray,
11341170 of values to be classified.
11351171 pct : list of float, default=[1, 10, 50, 90, 99]
11361172 Percentile values to use as breakpoints.
1173+ num_sample : int or None, default=20000
1174+ Number of sample data points used to compute percentile
1175+ breakpoints. For dask-backed arrays the sample is drawn
1176+ lazily to avoid materialising the entire array into RAM.
1177+ ``None`` means use all data (safe for numpy/cupy,
1178+ automatically capped for dask).
11371179 name : str, default='percentiles'
11381180 Name of output aggregate array.
11391181
@@ -1154,11 +1196,11 @@ def percentiles(agg: xr.DataArray,
11541196
11551197 mapper = ArrayTypeFunctionMapping (
11561198 numpy_func = lambda * args : _run_percentiles (* args , module = np ),
1157- dask_func = lambda * args : _run_percentiles ( * args , module = da ) ,
1199+ dask_func = _run_dask_percentiles ,
11581200 cupy_func = lambda * args : _run_percentiles (* args , module = cupy ),
11591201 dask_cupy_func = _run_dask_cupy_percentiles ,
11601202 )
1161- q = mapper (agg )(agg .data , pct )
1203+ q = mapper (agg )(agg .data , num_sample , pct )
11621204
11631205 # Materialize bin edges to numpy
11641206 if hasattr (q , 'compute' ):
@@ -1204,7 +1246,8 @@ def _compute_maximum_break_bins(values_np, k):
12041246 return bins
12051247
12061248
1207- def _run_maximum_breaks (agg , k , module ):
1249+ def _run_maximum_breaks (agg , num_sample , k , module ):
1250+ # num_sample ignored for in-memory backends
12081251 data = agg .data
12091252 if module == cupy :
12101253 finite = data [cupy .isfinite (data )]
@@ -1218,21 +1261,29 @@ def _run_maximum_breaks(agg, k, module):
12181261 return out
12191262
12201263
def _run_dask_maximum_breaks(agg, num_sample, k):
    """Maximum-breaks classification for a dask-backed aggregate.

    Break points are fitted on a fixed-size sample drawn with integer
    indexing (never boolean masks, which yield unknown chunk sizes), so
    only the sample is materialised.
    """
    raw = agg.data
    finite_only = da.where(da.isinf(raw), np.nan, raw)
    total = raw.size
    if num_sample is None or num_sample >= total:
        num_sample = total
    idx = _generate_sample_indices(total, num_sample)
    sample = np.asarray(finite_only.ravel()[idx].compute())
    sample = sample[np.isfinite(sample)]
    bins = _compute_maximum_break_bins(sample, k)
    return _bin(agg, bins, np.arange(len(bins)))
12291276
12301277
def _run_dask_cupy_maximum_breaks(agg, num_sample, k):
    """Maximum-breaks classification for a dask+cupy aggregate.

    GPU chunks are moved to host memory lazily via ``map_blocks``, then
    break points are fitted on a fixed-size sample drawn with integer
    indexing (never boolean masks), so only the sample is materialised.
    """
    raw = agg.data
    finite_only = da.where(da.isinf(raw), np.nan, raw)
    on_cpu = finite_only.map_blocks(cupy.asnumpy, dtype=raw.dtype, meta=np.array(()))
    total = raw.size
    if num_sample is None or num_sample >= total:
        num_sample = total
    idx = _generate_sample_indices(total, num_sample)
    sample = np.asarray(on_cpu.ravel()[idx].compute())
    sample = sample[np.isfinite(sample)]
    bins = _compute_maximum_break_bins(sample, k)
    return _bin(agg, bins, np.arange(len(bins)))
@@ -1242,6 +1293,7 @@ def _run_dask_cupy_maximum_breaks(agg, k):
12421293@supports_dataset
12431294def maximum_breaks (agg : xr .DataArray ,
12441295 k : int = 5 ,
1296+ num_sample : Optional [int ] = 20_000 ,
12451297 name : Optional [str ] = 'maximum_breaks' ) -> xr .DataArray :
12461298 """
12471299 Classify data using the Maximum Breaks algorithm.
@@ -1256,6 +1308,11 @@ def maximum_breaks(agg: xr.DataArray,
12561308 of values to be classified.
12571309 k : int, default=5
12581310 Number of classes to be produced.
1311+ num_sample : int or None, default=20000
1312+ Number of sample data points used to fit the model.
1313+ For dask-backed arrays the sample is drawn lazily to avoid
1314+ materialising the entire array into RAM. ``None`` means use
1315+ all data (safe for numpy/cupy, automatically capped for dask).
12591316 name : str, default='maximum_breaks'
12601317 Name of output aggregate array.
12611318
@@ -1277,7 +1334,7 @@ def maximum_breaks(agg: xr.DataArray,
12771334 cupy_func = lambda * args : _run_maximum_breaks (* args , module = cupy ),
12781335 dask_cupy_func = _run_dask_cupy_maximum_breaks ,
12791336 )
1280- out = mapper (agg )(agg , k )
1337+ out = mapper (agg )(agg , num_sample , k )
12811338 return xr .DataArray (out ,
12821339 name = name ,
12831340 dims = agg .dims ,
0 commit comments