Skip to content

Commit 34a8752

Browse files
authored
Use the same channel for TableChunk and Metadata messages (rapidsai#21182)
Now that rapidsai/rapidsmpf#811 was merged in RapidsMPF, we can drop the confusing `ChannelPair` design in cudf-polars. **NOTE**: After this and rapidsai/rapidsmpf#819 are both in, we can also drop the cudf-polasr `Metadata` definition in favor of the "standard" `ChannelMetadata` defined in RapidsMPF. Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Matthew Murray (https://github.com/Matt711) - Tom Augspurger (https://github.com/TomAugspurger) URL: rapidsai#21182
1 parent c1113e4 commit 34a8752

8 files changed

Lines changed: 213 additions & 255 deletions

File tree

python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/shuffle.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,18 +24,20 @@
2424
ChannelManager,
2525
HashPartitioned,
2626
Metadata,
27+
recv_metadata,
28+
send_metadata,
2729
)
2830
from cudf_polars.experimental.shuffle import Shuffle
2931

3032
if TYPE_CHECKING:
33+
from rapidsmpf.streaming.core.channel import Channel
3134
from rapidsmpf.streaming.core.context import Context
3235

3336
import pylibcudf as plc
3437
from rmm.pylibrmm.stream import Stream
3538

3639
from cudf_polars.dsl.ir import IR, IRExecutionContext
3740
from cudf_polars.experimental.rapidsmpf.core import SubNetGenerator
38-
from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
3941

4042

4143
class ShuffleManager:
@@ -125,8 +127,8 @@ async def shuffle_node(
125127
context: Context,
126128
ir: Shuffle,
127129
ir_context: IRExecutionContext,
128-
ch_in: ChannelPair,
129-
ch_out: ChannelPair,
130+
ch_in: Channel[TableChunk],
131+
ch_out: Channel[TableChunk],
130132
columns_to_hash: tuple[int, ...],
131133
num_partitions: int,
132134
collective_id: int,
@@ -147,21 +149,19 @@ async def shuffle_node(
147149
ir_context
148150
The execution context for the IR node.
149151
ch_in
150-
Input ChannelPair with metadata and data channels.
152+
Input Channel[TableChunk] with metadata and data channels.
151153
ch_out
152-
Output ChannelPair with metadata and data channels.
154+
Output Channel[TableChunk] with metadata and data channels.
153155
columns_to_hash
154156
Tuple of column indices to use for hashing.
155157
num_partitions
156158
Number of partitions to shuffle into.
157159
collective_id
158160
The collective ID.
159161
"""
160-
async with shutdown_on_error(
161-
context, ch_in.metadata, ch_in.data, ch_out.metadata, ch_out.data
162-
):
162+
async with shutdown_on_error(context, ch_in, ch_out):
163163
# Receive and send updated metadata.
164-
_ = await ch_in.recv_metadata(context)
164+
_ = await recv_metadata(ch_in, context)
165165
column_names = list(ir.schema.keys())
166166
partitioned_on = tuple(column_names[i] for i in columns_to_hash)
167167
output_metadata = Metadata(
@@ -173,15 +173,15 @@ async def shuffle_node(
173173
count=num_partitions,
174174
),
175175
)
176-
await ch_out.send_metadata(context, output_metadata)
176+
await send_metadata(ch_out, context, output_metadata)
177177

178178
# Create ShuffleManager instance
179179
shuffle = ShuffleManager(
180180
context, num_partitions, columns_to_hash, collective_id
181181
)
182182

183183
# Process input chunks
184-
while (msg := await ch_in.data.recv(context)) is not None:
184+
while (msg := await ch_in.recv(context)) is not None:
185185
# Extract TableChunk from message and insert into shuffler
186186
shuffle.insert_chunk(
187187
TableChunk.from_message(msg).make_available_and_spill(
@@ -202,7 +202,7 @@ async def shuffle_node(
202202
context.comm().nranks,
203203
):
204204
# Extract and send the output chunk
205-
await ch_out.data.send(
205+
await ch_out.send(
206206
context,
207207
Message(
208208
partition_id,
@@ -214,7 +214,7 @@ async def shuffle_node(
214214
),
215215
)
216216

217-
await ch_out.data.drain(context)
217+
await ch_out.drain(context)
218218

219219

220220
@generate_ir_sub_network.register(Shuffle)

python/cudf_polars/cudf_polars/experimental/rapidsmpf/core.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,8 +477,8 @@ def generate_network(
477477
nodes_dict, channels = mapper(ir)
478478
ch_out = channels[ir].reserve_output_slot()
479479

480-
# Add node to drain metadata channel before pull_from_channel
481-
# (since pull_from_channel doesn't accept a ChannelPair)
480+
# Add node to drain metadata before pull_from_channel
481+
# (since pull_from_channel doesn't handle metadata messages)
482482
ch_final_data: Channel[TableChunk] = context.create_channel()
483483
drain_node = metadata_drain_node(
484484
context,

python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py

Lines changed: 32 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
ChannelManager,
4242
Metadata,
4343
opaque_reservation,
44+
send_metadata,
4445
)
4546

4647
if TYPE_CHECKING:
@@ -53,7 +54,6 @@
5354
from cudf_polars.experimental.base import ColumnStat, StatsCollector
5455
from cudf_polars.experimental.rapidsmpf.core import SubNetGenerator
5556
from cudf_polars.experimental.rapidsmpf.dispatch import LowerIRTransformer
56-
from cudf_polars.experimental.rapidsmpf.utils import ChannelPair
5757
from cudf_polars.utils.config import ParquetOptions
5858

5959

@@ -139,7 +139,7 @@ async def dataframescan_node(
139139
context: Context,
140140
ir: DataFrameScan,
141141
ir_context: IRExecutionContext,
142-
ch_out: ChannelPair,
142+
ch_out: Channel[TableChunk],
143143
*,
144144
num_producers: int,
145145
rows_per_partition: int,
@@ -157,7 +157,7 @@ async def dataframescan_node(
157157
ir_context
158158
The execution context for the IR node.
159159
ch_out
160-
The output ChannelPair.
160+
The output Channel[TableChunk].
161161
num_producers
162162
The number of producers to use for the DataFrameScan node.
163163
rows_per_partition
@@ -166,7 +166,7 @@ async def dataframescan_node(
166166
Estimated size of each chunk in bytes. Used for memory reservation
167167
with block spilling to avoid thrashing.
168168
"""
169-
async with shutdown_on_error(context, ch_out.metadata, ch_out.data):
169+
async with shutdown_on_error(context, ch_out):
170170
# Find local partition count.
171171
nrows = ir.df.shape()[0]
172172
global_count = math.ceil(nrows / rows_per_partition) if nrows > 0 else 0
@@ -180,7 +180,8 @@ async def dataframescan_node(
180180
local_offset = local_count * context.comm().rank
181181

182182
# Send basic metadata
183-
await ch_out.send_metadata(
183+
await send_metadata(
184+
ch_out,
184185
context,
185186
Metadata(local_count=local_count, global_count=global_count),
186187
)
@@ -201,7 +202,7 @@ async def dataframescan_node(
201202

202203
# If there are no slices, drain the channel and return
203204
if len(ir_slices) == 0:
204-
await ch_out.data.drain(context)
205+
await ch_out.drain(context)
205206
return
206207

207208
# If there is only one ir_slices or one producer, we can
@@ -212,16 +213,16 @@ async def dataframescan_node(
212213
context,
213214
ir_slice,
214215
seq_num,
215-
ch_out.data,
216+
ch_out,
216217
ir_context,
217218
estimated_chunk_bytes,
218219
)
219-
await ch_out.data.drain(context)
220+
await ch_out.drain(context)
220221
return
221222

222223
# Use Lineariser to ensure ordered delivery
223224
num_producers = min(num_producers, len(ir_slices))
224-
lineariser = Lineariser(context, ch_out.data, num_producers)
225+
lineariser = Lineariser(context, ch_out, num_producers)
225226

226227
# Assign tasks to producers using round-robin
227228
producer_tasks: list[list[tuple[int, DataFrameScan]]] = [
@@ -365,7 +366,7 @@ async def scan_node(
365366
context: Context,
366367
ir: Scan,
367368
ir_context: IRExecutionContext,
368-
ch_out: ChannelPair,
369+
ch_out: Channel[TableChunk],
369370
*,
370371
num_producers: int,
371372
plan: IOPartitionPlan,
@@ -384,7 +385,7 @@ async def scan_node(
384385
ir_context
385386
The execution context for the IR node.
386387
ch_out
387-
The output ChannelPair.
388+
The output Channel[TableChunk].
388389
num_producers
389390
The number of producers to use for the scan node.
390391
plan
@@ -395,7 +396,7 @@ async def scan_node(
395396
Estimated size of each chunk in bytes. Used for memory reservation
396397
with block spilling to avoid thrashing.
397398
"""
398-
async with shutdown_on_error(context, ch_out.metadata, ch_out.data):
399+
async with shutdown_on_error(context, ch_out):
399400
# Build a list of local Scan operations
400401
scans: list[Scan | SplitScan] = []
401402
if plan.flavor == IOPartitionFlavor.SPLIT_FILES:
@@ -464,14 +465,15 @@ async def scan_node(
464465
)
465466

466467
# Send basic metadata
467-
await ch_out.send_metadata(
468+
await send_metadata(
469+
ch_out,
468470
context,
469471
Metadata(local_count=len(scans), global_count=count),
470472
)
471473

472474
# If there is nothing to scan, drain the channel and return
473475
if len(scans) == 0:
474-
await ch_out.data.drain(context)
476+
await ch_out.drain(context)
475477
return
476478

477479
# If there is only one scan or one producer, we can
@@ -482,16 +484,16 @@ async def scan_node(
482484
context,
483485
scan,
484486
seq_num,
485-
ch_out.data,
487+
ch_out,
486488
ir_context,
487489
estimated_chunk_bytes,
488490
)
489-
await ch_out.data.drain(context)
491+
await ch_out.drain(context)
490492
return
491493

492494
# Use Lineariser to ensure ordered delivery
493495
num_producers = min(num_producers, len(scans))
494-
lineariser = Lineariser(context, ch_out.data, num_producers)
496+
lineariser = Lineariser(context, ch_out, num_producers)
495497

496498
# Assign tasks to producers using round-robin
497499
producer_tasks: list[list[tuple[int, Scan | SplitScan]]] = [
@@ -524,7 +526,7 @@ def make_rapidsmpf_read_parquet_node(
524526
context: Context,
525527
ir: Scan,
526528
num_producers: int,
527-
ch_out: ChannelPair,
529+
ch_out: Channel[TableChunk],
528530
stats: StatsCollector,
529531
partition_info: PartitionInfo,
530532
) -> Any | None:
@@ -540,7 +542,7 @@ def make_rapidsmpf_read_parquet_node(
540542
num_producers
541543
The number of producers to use for the scan node.
542544
ch_out
543-
The output ChannelPair.
545+
The output Channel[TableChunk].
544546
stats
545547
The statistics collector.
546548
partition_info
@@ -611,7 +613,7 @@ def make_rapidsmpf_read_parquet_node(
611613
try:
612614
return read_parquet(
613615
context,
614-
ch_out.data,
616+
ch_out,
615617
num_producers,
616618
parquet_reader_options,
617619
num_rows_per_chunk,
@@ -652,7 +654,8 @@ def _(
652654
)
653655

654656
# Use rapidsmpf native read_parquet node if possible
655-
ch_pair = channels[ir].reserve_input_slot()
657+
ch_in: Channel[TableChunk] | None = None
658+
ch_out = channels[ir].reserve_input_slot()
656659
nodes: dict[IR, list[Any]] = {}
657660
native_node: Any = None
658661
if (
@@ -665,21 +668,24 @@ def _(
665668
and ir.skip_rows == 0
666669
and not distributed_split_files
667670
):
671+
# Create new channel to so ch_out can be used to add metadata
672+
ch_in = rec.state["context"].create_channel()
668673
native_node = make_rapidsmpf_read_parquet_node(
669674
rec.state["context"],
670675
ir,
671676
num_producers,
672-
ch_pair,
677+
ch_in,
673678
rec.state["stats"],
674679
partition_info,
675680
)
676681

677-
if native_node is not None:
682+
if native_node is not None and ch_in is not None:
678683
# Need metadata node, because the native read_parquet
679684
# node does not send metadata.
680685
metadata_node = metadata_feeder_node(
681686
rec.state["context"],
682-
ch_pair,
687+
ch_in,
688+
ch_out,
683689
Metadata(
684690
# partition_info.count is the estimated "global" count.
685691
# Just estimate the local count as well.
@@ -689,7 +695,7 @@ def _(
689695
global_count=partition_info.count,
690696
),
691697
)
692-
nodes[ir] = [metadata_node, native_node]
698+
nodes[ir] = [native_node, metadata_node]
693699
else:
694700
# Fall back to scan_node (predicate not convertible, or other constraint)
695701
parquet_options = dataclasses.replace(parquet_options, chunked=False)
@@ -699,7 +705,7 @@ def _(
699705
rec.state["context"],
700706
ir,
701707
rec.state["ir_context"],
702-
ch_pair,
708+
ch_out,
703709
num_producers=num_producers,
704710
plan=plan,
705711
parquet_options=parquet_options,

0 commit comments

Comments
 (0)