@@ -1079,19 +1079,31 @@ def _run_head_tail_breaks(agg, module):
10791079
10801080def _run_dask_head_tail_breaks (agg ):
10811081 data = agg .data
1082- data_clean = da .where (da .isinf (data ), np .nan , data )
1082+ # Persist once so the iterative loop does not re-read from storage on
1083+ # every scan. Fuse the three reductions per iteration into a single
1084+ # dask.compute() to cut graph traversals from 3N+1 to N+1.
1085+ data_clean = da .where (da .isinf (data ), np .nan , data ).persist ()
10831086 bins = []
10841087 mask = da .isfinite (data_clean )
1088+ total_count_lazy = mask .sum ()
1089+ total_count = int (total_count_lazy .compute ())
1090+ if total_count == 0 :
1091+ max_v = float (da .nanmax (data_clean ).compute ())
1092+ return _bin (agg , np .array ([max_v ]), np .arange (1 ))
1093+
1094+ current_total = total_count
10851095 while True :
10861096 current = da .where (mask , data_clean , np .nan )
1087- mean_v = float (da .nanmean (current ).compute ())
1097+ new_mask = mask & (data_clean > da .nanmean (current ))
1098+ # Fuse mean and head-count into one graph evaluation.
1099+ mean_v , head_count = dask .compute (da .nanmean (current ), new_mask .sum ())
1100+ mean_v = float (mean_v )
1101+ head_count = int (head_count )
10881102 bins .append (mean_v )
1089- new_mask = mask & (data_clean > mean_v )
1090- head_count = int (new_mask .sum ().compute ())
1091- total_count = int (mask .sum ().compute ())
1092- if head_count == 0 or head_count / total_count > 0.40 :
1103+ if head_count == 0 or head_count / current_total > 0.40 :
10931104 break
10941105 mask = new_mask
1106+ current_total = head_count
10951107 max_v = float (da .nanmax (data_clean ).compute ())
10961108 bins .append (max_v )
10971109 bins = np .array (bins )
@@ -1366,6 +1378,22 @@ def maximum_breaks(agg: xr.DataArray,
13661378 attrs = agg .attrs )
13671379
13681380
1381+ _BOX_PLOT_DEFAULT_SAMPLE = 200_000
1382+
1383+
1384+ def _box_plot_bins_from_sample (finite_np , hinge , max_v ):
1385+ q1 = float (np .percentile (finite_np , 25 ))
1386+ q2 = float (np .percentile (finite_np , 50 ))
1387+ q3 = float (np .percentile (finite_np , 75 ))
1388+ iqr = q3 - q1
1389+ raw_bins = [q1 - hinge * iqr , q1 , q2 , q3 , q3 + hinge * iqr , max_v ]
1390+ bins = np .sort (np .unique (raw_bins ))
1391+ bins = bins [bins <= max_v ]
1392+ if bins [- 1 ] < max_v :
1393+ bins = np .append (bins , max_v )
1394+ return bins
1395+
1396+
13691397def _run_box_plot (agg , hinge , module ):
13701398 data = agg .data
13711399 data_clean = module .where (module .isinf (data ), np .nan , data )
@@ -1376,51 +1404,63 @@ def _run_box_plot(agg, hinge, module):
13761404 q2 = float (cupy .percentile (finite_data , 50 ).get ())
13771405 q3 = float (cupy .percentile (finite_data , 75 ).get ())
13781406 max_v = float (cupy .nanmax (finite_data ).get ())
1379- elif module == da :
1380- q1_l = da .percentile (finite_data , 25 )
1381- q2_l = da .percentile (finite_data , 50 )
1382- q3_l = da .percentile (finite_data , 75 )
1383- max_l = da .nanmax (data_clean )
1384- q1 , q2 , q3 , max_v = dask .compute (q1_l , q2_l , q3_l , max_l )
1385- q1 , q2 , q3 , max_v = q1 .item (), q2 .item (), q3 .item (), max_v .item ()
1386- else :
1387- q1 = float (np .percentile (finite_data , 25 ))
1388- q2 = float (np .percentile (finite_data , 50 ))
1389- q3 = float (np .percentile (finite_data , 75 ))
1390- max_v = float (np .nanmax (finite_data ))
1407+ iqr = q3 - q1
1408+ raw_bins = [q1 - hinge * iqr , q1 , q2 , q3 , q3 + hinge * iqr , max_v ]
1409+ bins = np .sort (np .unique (raw_bins ))
1410+ bins = bins [bins <= max_v ]
1411+ if bins [- 1 ] < max_v :
1412+ bins = np .append (bins , max_v )
1413+ else : # numpy
1414+ finite_np = np .asarray (finite_data )
1415+ max_v = float (np .nanmax (finite_np ))
1416+ bins = _box_plot_bins_from_sample (finite_np , hinge , max_v )
13911417
1392- iqr = q3 - q1
1393- raw_bins = [q1 - hinge * iqr , q1 , q2 , q3 , q3 + hinge * iqr , max_v ]
1394- bins = np .sort (np .unique (raw_bins ))
1395- # Remove bins above max (they'd create empty classes)
1396- bins = bins [bins <= max_v ]
1397- if bins [- 1 ] < max_v :
1398- bins = np .append (bins , max_v )
13991418 out = _bin (agg , bins , np .arange (len (bins )))
14001419 return out
14011420
14021421
1422+ def _run_dask_box_plot (agg , hinge ):
1423+ """Dask+numpy box_plot.
1424+
1425+ Avoids boolean fancy indexing on a dask array (which produces unknown
1426+ chunk sizes and forces a chunk-size compute pass). Samples the data
1427+ via the same seeded index sampler used by natural_breaks, then
1428+ computes percentiles on the finite portion of the sample in numpy.
1429+ """
1430+ data = agg .data
1431+ data_clean = da .where (da .isinf (data ), np .nan , data )
1432+ num_data = data_clean .size
1433+ num_sample = min (_BOX_PLOT_DEFAULT_SAMPLE , num_data )
1434+ sample_idx = _generate_sample_indices (num_data , num_sample )
1435+
1436+ sample_lazy = data_clean .ravel ()[sample_idx ]
1437+ max_lazy = da .nanmax (data_clean )
1438+ sample_np , max_v = dask .compute (sample_lazy , max_lazy )
1439+ sample_np = np .asarray (sample_np )
1440+ finite_np = sample_np [np .isfinite (sample_np )]
1441+ max_v = float (max_v )
1442+
1443+ bins = _box_plot_bins_from_sample (finite_np , hinge , max_v )
1444+ return _bin (agg , bins , np .arange (len (bins )))
1445+
1446+
14031447def _run_dask_cupy_box_plot (agg , hinge ):
1448+ """Dask+cupy box_plot: sample on-device, transfer only the sample."""
14041449 data = agg .data
1405- data_cpu = data .map_blocks (cupy .asnumpy , dtype = data .dtype , meta = np .array (()))
1406- data_clean = da .where (da .isinf (data_cpu ), np .nan , data_cpu )
1407- finite_data = data_clean [da .isfinite (data_clean )]
1450+ data_clean = da .where (da .isinf (data ), np .nan , data )
1451+ num_data = data_clean .size
1452+ num_sample = min (_BOX_PLOT_DEFAULT_SAMPLE , num_data )
1453+ sample_idx = _generate_sample_indices (num_data , num_sample )
14081454
1409- q1_l = da . percentile ( finite_data , 25 )
1410- q2_l = da .percentile ( finite_data , 50 )
1411- q3_l = da . percentile ( finite_data , 75 )
1412- max_l = da . nanmax ( data_clean )
1413- q1 , q2 , q3 , max_v = dask . compute ( q1_l , q2_l , q3_l , max_l )
1414- q1 , q2 , q3 , max_v = q1 . item (), q2 .item (), q3 . item (), max_v . item ( )
1455+ sample_lazy = data_clean . ravel ()[ sample_idx ]
1456+ max_lazy = da .nanmax ( data_clean )
1457+ sample_cp , max_v = dask . compute ( sample_lazy , max_lazy )
1458+ sample_np = cupy . asnumpy ( sample_cp )
1459+ finite_np = sample_np [ np . isfinite ( sample_np )]
1460+ max_v = float ( cupy . asnumpy ( max_v ) .item ()) if hasattr ( max_v , 'get' ) else float ( max_v )
14151461
1416- iqr = q3 - q1
1417- raw_bins = [q1 - hinge * iqr , q1 , q2 , q3 , q3 + hinge * iqr , max_v ]
1418- bins = np .sort (np .unique (raw_bins ))
1419- bins = bins [bins <= max_v ]
1420- if bins [- 1 ] < max_v :
1421- bins = np .append (bins , max_v )
1422- out = _bin (agg , bins , np .arange (len (bins )))
1423- return out
1462+ bins = _box_plot_bins_from_sample (finite_np , hinge , max_v )
1463+ return _bin (agg , bins , np .arange (len (bins )))
14241464
14251465
14261466@supports_dataset
@@ -1459,7 +1499,7 @@ def box_plot(agg: xr.DataArray,
14591499
14601500 mapper = ArrayTypeFunctionMapping (
14611501 numpy_func = lambda * args : _run_box_plot (* args , module = np ),
1462- dask_func = lambda * args : _run_box_plot ( * args , module = da ) ,
1502+ dask_func = _run_dask_box_plot ,
14631503 cupy_func = lambda * args : _run_box_plot (* args , module = cupy ),
14641504 dask_cupy_func = _run_dask_cupy_box_plot ,
14651505 )
0 commit comments