99
1010from rapidsmpf .shuffler import Shuffler
1111
12+ from cudf_polars .dsl .ir import Distinct , GroupBy
1213from cudf_polars .dsl .traversal import traversal
1314from cudf_polars .experimental .join import Join
1415from cudf_polars .experimental .repartition import Repartition
1819 from types import TracebackType
1920
2021 from cudf_polars .dsl .ir import IR
22+ from cudf_polars .utils .config import ConfigOptions
2123
2224
2325# Set of available collective IDs
@@ -50,11 +52,13 @@ class ReserveOpIDs:
5052 ----------
5153 ir : IR
5254 The root IR node of the pipeline.
55+ config_options : ConfigOptions, optional
56+ Configuration options (needed for dynamic planning).
5357
5458 Notes
5559 -----
5660 This context manager:
57- 1. Identifies all Shuffle nodes in the IR
61+ 1. Identifies all IR nodes that may require collective operations
5862 2. Reserves collective IDs from the vacancy pool
5963 3. Creates a mapping from IR nodes to their reserved IDs
6064 4. Releases all IDs back to the pool on __exit__
@@ -63,12 +67,24 @@ class ReserveOpIDs:
6367 (e.g., for metadata gathering, shuffling multiple sides of a join).
6468 """
6569
66- def __init__ (self , ir : IR ):
70+ def __init__ (self , ir : IR , config_options : ConfigOptions | None = None ):
71+ self .config_options = config_options
72+
73+ # Check if dynamic planning is enabled
74+ self .dynamic_planning_enabled = (
75+ config_options is not None
76+ and config_options .executor .name == "streaming"
77+ and config_options .executor .dynamic_planning is not None
78+ )
79+
6780 # Find all collective IR nodes.
81+ collective_types : tuple [type , ...] = (Shuffle , Join , Repartition )
82+ if self .dynamic_planning_enabled :
83+ # Include GroupBy and Distinct when dynamic planning is enabled
84+ collective_types = (Shuffle , Join , Repartition , GroupBy , Distinct )
85+
6886 self .collective_nodes : list [IR ] = [
69- node
70- for node in traversal ([ir ])
71- if isinstance (node , (Shuffle , Join , Repartition ))
87+ node for node in traversal ([ir ]) if isinstance (node , collective_types )
7288 ]
7389 self .collective_id_map : dict [IR , list [int ]] = {}
7490
@@ -85,7 +101,21 @@ def __enter__(self) -> dict[IR, list[int]]:
85101 """
86102 # Reserve IDs and map nodes to a list of IDs
87103 for node in self .collective_nodes :
88- self .collective_id_map [node ] = [_get_new_collective_id ()]
104+ if isinstance (node , (GroupBy , Distinct )) and self .dynamic_planning_enabled :
105+ # GroupBy/Distinct need 2 IDs: one for size allgather, one for shuffle
106+ self .collective_id_map [node ] = [
107+ _get_new_collective_id (),
108+ _get_new_collective_id (),
109+ ]
110+ elif isinstance (node , Join ) and self .dynamic_planning_enabled :
111+ # Join needs 3 IDs: size allgather, left shuffle/bcast, right shuffle/bcast
112+ self .collective_id_map [node ] = [
113+ _get_new_collective_id (),
114+ _get_new_collective_id (),
115+ _get_new_collective_id (),
116+ ]
117+ else :
118+ self .collective_id_map [node ] = [_get_new_collective_id ()]
89119
90120 return self .collective_id_map
91121
0 commit comments