Skip to content

Commit d139016

Browse files
authored
Reserve more collective IDs for "dynamic" IR nodes (rapidsai#21343)
- Contributes to rapidsai#20482 (Needed for dynamic GroupBy, Distinct, and Join support) Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Matthew Murray (https://github.com/Matt711) URL: rapidsai#21343
1 parent 33f3ed7 commit d139016

2 files changed

Lines changed: 37 additions & 7 deletions

File tree

  • python/cudf_polars/cudf_polars/experimental/rapidsmpf

python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/common.py

Lines changed: 36 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
from rapidsmpf.shuffler import Shuffler
1111

12+
from cudf_polars.dsl.ir import Distinct, GroupBy
1213
from cudf_polars.dsl.traversal import traversal
1314
from cudf_polars.experimental.join import Join
1415
from cudf_polars.experimental.repartition import Repartition
@@ -18,6 +19,7 @@
1819
from types import TracebackType
1920

2021
from cudf_polars.dsl.ir import IR
22+
from cudf_polars.utils.config import ConfigOptions
2123

2224

2325
# Set of available collective IDs
@@ -50,11 +52,13 @@ class ReserveOpIDs:
5052
----------
5153
ir : IR
5254
The root IR node of the pipeline.
55+
config_options : ConfigOptions, optional
56+
Configuration options (needed for dynamic planning).
5357
5458
Notes
5559
-----
5660
This context manager:
57-
1. Identifies all Shuffle nodes in the IR
61+
1. Identifies all IR nodes that may require collective operations
5862
2. Reserves collective IDs from the vacancy pool
5963
3. Creates a mapping from IR nodes to their reserved IDs
6064
4. Releases all IDs back to the pool on __exit__
@@ -63,12 +67,24 @@ class ReserveOpIDs:
6367
(e.g., for metadata gathering, shuffling multiple sides of a join).
6468
"""
6569

66-
def __init__(self, ir: IR):
70+
def __init__(self, ir: IR, config_options: ConfigOptions | None = None):
71+
self.config_options = config_options
72+
73+
# Check if dynamic planning is enabled
74+
self.dynamic_planning_enabled = (
75+
config_options is not None
76+
and config_options.executor.name == "streaming"
77+
and config_options.executor.dynamic_planning is not None
78+
)
79+
6780
# Find all collective IR nodes.
81+
collective_types: tuple[type, ...] = (Shuffle, Join, Repartition)
82+
if self.dynamic_planning_enabled:
83+
# Include GroupBy and Distinct when dynamic planning is enabled
84+
collective_types = (Shuffle, Join, Repartition, GroupBy, Distinct)
85+
6886
self.collective_nodes: list[IR] = [
69-
node
70-
for node in traversal([ir])
71-
if isinstance(node, (Shuffle, Join, Repartition))
87+
node for node in traversal([ir]) if isinstance(node, collective_types)
7288
]
7389
self.collective_id_map: dict[IR, list[int]] = {}
7490

@@ -85,7 +101,21 @@ def __enter__(self) -> dict[IR, list[int]]:
85101
"""
86102
# Reserve IDs and map nodes to a list of IDs
87103
for node in self.collective_nodes:
88-
self.collective_id_map[node] = [_get_new_collective_id()]
104+
if isinstance(node, (GroupBy, Distinct)) and self.dynamic_planning_enabled:
105+
# GroupBy/Distinct need 2 IDs: one for size allgather, one for shuffle
106+
self.collective_id_map[node] = [
107+
_get_new_collective_id(),
108+
_get_new_collective_id(),
109+
]
110+
elif isinstance(node, Join) and self.dynamic_planning_enabled:
111+
# Join needs 3 IDs: size allgather, left shuffle/bcast, right shuffle/bcast
112+
self.collective_id_map[node] = [
113+
_get_new_collective_id(),
114+
_get_new_collective_id(),
115+
_get_new_collective_id(),
116+
]
117+
else:
118+
self.collective_id_map[node] = [_get_new_collective_id()]
89119

90120
return self.collective_id_map
91121

python/cudf_polars/cudf_polars/experimental/rapidsmpf/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def evaluate_logical_plan(
101101
log_query_plan(ir)
102102

103103
# Reserve shuffle IDs for the entire pipeline execution
104-
with ReserveOpIDs(ir) as collective_id_map:
104+
with ReserveOpIDs(ir, config_options) as collective_id_map:
105105
# Build and execute the streaming pipeline.
106106
# This must be done on all worker processes
107107
# for cluster == "distributed".

0 commit comments

Comments
 (0)