@@ -695,16 +695,11 @@ def _run_dask_natural_break(agg, num_sample, k):
695695 max_data = float (da .nanmax (da .where (da .isinf (data ), np .nan , data )).compute ())
696696
697697 num_data = data .size
698- if num_sample is not None and num_sample < num_data :
699- # Sample lazily from dask array; only materialize the sample
700- sample_idx = _generate_sample_indices (num_data , num_sample )
701- sample_data_np = np .asarray (data .ravel ()[sample_idx ].compute ())
702- bins , uvk = _compute_natural_break_bins (
703- sample_data_np , None , k , max_data )
704- else :
705- data_flat_np = np .asarray (data .ravel ().compute ())
706- bins , uvk = _compute_natural_break_bins (
707- data_flat_np , None , k , max_data )
698+ if num_sample is None or num_sample >= num_data :
699+ num_sample = num_data # cap: still uses indexed access, never .compute() all
700+ sample_idx = _generate_sample_indices (num_data , num_sample )
701+ sample_data_np = np .asarray (data .ravel ()[sample_idx ].compute ())
702+ bins , uvk = _compute_natural_break_bins (sample_data_np , None , k , max_data )
708703
709704 out = _bin (agg , bins , np .arange (uvk ))
710705 return out
@@ -717,17 +712,12 @@ def _run_dask_cupy_natural_break(agg, num_sample, k):
717712 max_data = float (da .nanmax (da .where (da .isinf (data ), np .nan , data )).compute ().item ())
718713
719714 num_data = data .size
720- if num_sample is not None and num_sample < num_data :
721- # Sample lazily from dask array; only materialize the sample
722- sample_idx = _generate_sample_indices (num_data , num_sample )
723- sample_data = data .ravel ()[sample_idx ].compute ()
724- sample_data_np = cupy .asnumpy (sample_data )
725- bins , uvk = _compute_natural_break_bins (
726- sample_data_np , None , k , max_data )
727- else :
728- data_flat_np = cupy .asnumpy (data .ravel ().compute ())
729- bins , uvk = _compute_natural_break_bins (
730- data_flat_np , None , k , max_data )
715+ if num_sample is None or num_sample >= num_data :
716+ num_sample = num_data # cap: still uses indexed access, never .compute() all
717+ sample_idx = _generate_sample_indices (num_data , num_sample )
718+ sample_data = data .ravel ()[sample_idx ].compute ()
719+ sample_data_np = cupy .asnumpy (sample_data )
720+ bins , uvk = _compute_natural_break_bins (sample_data_np , None , k , max_data )
731721
732722 out = _bin (agg , bins , np .arange (uvk ))
733723 return out
@@ -1204,7 +1194,8 @@ def _compute_maximum_break_bins(values_np, k):
12041194 return bins
12051195
12061196
1207- def _run_maximum_breaks (agg , k , module ):
1197+ def _run_maximum_breaks (agg , num_sample , k , module ):
1198+ # num_sample ignored for in-memory backends
12081199 data = agg .data
12091200 if module == cupy :
12101201 finite = data [cupy .isfinite (data )]
@@ -1218,21 +1209,29 @@ def _run_maximum_breaks(agg, k, module):
12181209 return out
12191210
12201211
1221- def _run_dask_maximum_breaks (agg , k ):
1212+ def _run_dask_maximum_breaks (agg , num_sample , k ):
12221213 data = agg .data
12231214 data_clean = da .where (da .isinf (data ), np .nan , data )
1224- values_np = np .asarray (data_clean .ravel ().compute ())
1215+ num_data = data .size
1216+ if num_sample is None or num_sample >= num_data :
1217+ num_sample = num_data
1218+ sample_idx = _generate_sample_indices (num_data , num_sample )
1219+ values_np = np .asarray (data_clean .ravel ()[sample_idx ].compute ())
12251220 values_np = values_np [np .isfinite (values_np )]
12261221 bins = _compute_maximum_break_bins (values_np , k )
12271222 out = _bin (agg , bins , np .arange (len (bins )))
12281223 return out
12291224
12301225
1231- def _run_dask_cupy_maximum_breaks (agg , k ):
1226+ def _run_dask_cupy_maximum_breaks (agg , num_sample , k ):
12321227 data = agg .data
12331228 data_clean = da .where (da .isinf (data ), np .nan , data )
12341229 data_cpu = data_clean .map_blocks (cupy .asnumpy , dtype = data .dtype , meta = np .array (()))
1235- values_np = np .asarray (data_cpu .ravel ().compute ())
1230+ num_data = data .size
1231+ if num_sample is None or num_sample >= num_data :
1232+ num_sample = num_data
1233+ sample_idx = _generate_sample_indices (num_data , num_sample )
1234+ values_np = np .asarray (data_cpu .ravel ()[sample_idx ].compute ())
12361235 values_np = values_np [np .isfinite (values_np )]
12371236 bins = _compute_maximum_break_bins (values_np , k )
12381237 out = _bin (agg , bins , np .arange (len (bins )))
@@ -1242,6 +1241,7 @@ def _run_dask_cupy_maximum_breaks(agg, k):
12421241@supports_dataset
12431242def maximum_breaks (agg : xr .DataArray ,
12441243 k : int = 5 ,
1244+ num_sample : Optional [int ] = 20_000 ,
12451245 name : Optional [str ] = 'maximum_breaks' ) -> xr .DataArray :
12461246 """
12471247 Classify data using the Maximum Breaks algorithm.
@@ -1256,6 +1256,11 @@ def maximum_breaks(agg: xr.DataArray,
12561256 of values to be classified.
12571257 k : int, default=5
12581258 Number of classes to be produced.
1259+ num_sample : int or None, default=20000
1260+ Number of sample data points used to fit the model.
1261+ For dask-backed arrays the sample is drawn lazily to avoid
1262+ materialising the entire array into RAM. ``None`` means use
1263+ all data (safe for numpy/cupy, automatically capped for dask).
12591264 name : str, default='maximum_breaks'
12601265 Name of output aggregate array.
12611266
@@ -1277,7 +1282,7 @@ def maximum_breaks(agg: xr.DataArray,
12771282 cupy_func = lambda * args : _run_maximum_breaks (* args , module = cupy ),
12781283 dask_cupy_func = _run_dask_cupy_maximum_breaks ,
12791284 )
1280- out = mapper (agg )(agg , k )
1285+ out = mapper (agg )(agg , num_sample , k )
12811286 return xr .DataArray (out ,
12821287 name = name ,
12831288 dims = agg .dims ,