4141 ChannelManager ,
4242 Metadata ,
4343 opaque_reservation ,
44+ send_metadata ,
4445)
4546
4647if TYPE_CHECKING :
5354 from cudf_polars .experimental .base import ColumnStat , StatsCollector
5455 from cudf_polars .experimental .rapidsmpf .core import SubNetGenerator
5556 from cudf_polars .experimental .rapidsmpf .dispatch import LowerIRTransformer
56- from cudf_polars .experimental .rapidsmpf .utils import ChannelPair
5757 from cudf_polars .utils .config import ParquetOptions
5858
5959
@@ -139,7 +139,7 @@ async def dataframescan_node(
139139 context : Context ,
140140 ir : DataFrameScan ,
141141 ir_context : IRExecutionContext ,
142- ch_out : ChannelPair ,
142+ ch_out : Channel [ TableChunk ] ,
143143 * ,
144144 num_producers : int ,
145145 rows_per_partition : int ,
@@ -157,7 +157,7 @@ async def dataframescan_node(
157157 ir_context
158158 The execution context for the IR node.
159159 ch_out
160- The output ChannelPair .
160+ The output Channel[TableChunk] .
161161 num_producers
162162 The number of producers to use for the DataFrameScan node.
163163 rows_per_partition
@@ -166,7 +166,7 @@ async def dataframescan_node(
166166 Estimated size of each chunk in bytes. Used for memory reservation
167167 with block spilling to avoid thrashing.
168168 """
169- async with shutdown_on_error (context , ch_out . metadata , ch_out . data ):
169+ async with shutdown_on_error (context , ch_out ):
170170 # Find local partition count.
171171 nrows = ir .df .shape ()[0 ]
172172 global_count = math .ceil (nrows / rows_per_partition ) if nrows > 0 else 0
@@ -180,7 +180,8 @@ async def dataframescan_node(
180180 local_offset = local_count * context .comm ().rank
181181
182182 # Send basic metadata
183- await ch_out .send_metadata (
183+ await send_metadata (
184+ ch_out ,
184185 context ,
185186 Metadata (local_count = local_count , global_count = global_count ),
186187 )
@@ -201,7 +202,7 @@ async def dataframescan_node(
201202
202203 # If there are no slices, drain the channel and return
203204 if len (ir_slices ) == 0 :
204- await ch_out .data . drain (context )
205+ await ch_out .drain (context )
205206 return
206207
207208 # If there is only one ir_slices or one producer, we can
@@ -212,16 +213,16 @@ async def dataframescan_node(
212213 context ,
213214 ir_slice ,
214215 seq_num ,
215- ch_out . data ,
216+ ch_out ,
216217 ir_context ,
217218 estimated_chunk_bytes ,
218219 )
219- await ch_out .data . drain (context )
220+ await ch_out .drain (context )
220221 return
221222
222223 # Use Lineariser to ensure ordered delivery
223224 num_producers = min (num_producers , len (ir_slices ))
224- lineariser = Lineariser (context , ch_out . data , num_producers )
225+ lineariser = Lineariser (context , ch_out , num_producers )
225226
226227 # Assign tasks to producers using round-robin
227228 producer_tasks : list [list [tuple [int , DataFrameScan ]]] = [
@@ -365,7 +366,7 @@ async def scan_node(
365366 context : Context ,
366367 ir : Scan ,
367368 ir_context : IRExecutionContext ,
368- ch_out : ChannelPair ,
369+ ch_out : Channel [ TableChunk ] ,
369370 * ,
370371 num_producers : int ,
371372 plan : IOPartitionPlan ,
@@ -384,7 +385,7 @@ async def scan_node(
384385 ir_context
385386 The execution context for the IR node.
386387 ch_out
387- The output ChannelPair .
388+ The output Channel[TableChunk] .
388389 num_producers
389390 The number of producers to use for the scan node.
390391 plan
@@ -395,7 +396,7 @@ async def scan_node(
395396 Estimated size of each chunk in bytes. Used for memory reservation
396397 with block spilling to avoid thrashing.
397398 """
398- async with shutdown_on_error (context , ch_out . metadata , ch_out . data ):
399+ async with shutdown_on_error (context , ch_out ):
399400 # Build a list of local Scan operations
400401 scans : list [Scan | SplitScan ] = []
401402 if plan .flavor == IOPartitionFlavor .SPLIT_FILES :
@@ -464,14 +465,15 @@ async def scan_node(
464465 )
465466
466467 # Send basic metadata
467- await ch_out .send_metadata (
468+ await send_metadata (
469+ ch_out ,
468470 context ,
469471 Metadata (local_count = len (scans ), global_count = count ),
470472 )
471473
472474 # If there is nothing to scan, drain the channel and return
473475 if len (scans ) == 0 :
474- await ch_out .data . drain (context )
476+ await ch_out .drain (context )
475477 return
476478
477479 # If there is only one scan or one producer, we can
@@ -482,16 +484,16 @@ async def scan_node(
482484 context ,
483485 scan ,
484486 seq_num ,
485- ch_out . data ,
487+ ch_out ,
486488 ir_context ,
487489 estimated_chunk_bytes ,
488490 )
489- await ch_out .data . drain (context )
491+ await ch_out .drain (context )
490492 return
491493
492494 # Use Lineariser to ensure ordered delivery
493495 num_producers = min (num_producers , len (scans ))
494- lineariser = Lineariser (context , ch_out . data , num_producers )
496+ lineariser = Lineariser (context , ch_out , num_producers )
495497
496498 # Assign tasks to producers using round-robin
497499 producer_tasks : list [list [tuple [int , Scan | SplitScan ]]] = [
@@ -524,7 +526,7 @@ def make_rapidsmpf_read_parquet_node(
524526 context : Context ,
525527 ir : Scan ,
526528 num_producers : int ,
527- ch_out : ChannelPair ,
529+ ch_out : Channel [ TableChunk ] ,
528530 stats : StatsCollector ,
529531 partition_info : PartitionInfo ,
530532) -> Any | None :
@@ -540,7 +542,7 @@ def make_rapidsmpf_read_parquet_node(
540542 num_producers
541543 The number of producers to use for the scan node.
542544 ch_out
543- The output ChannelPair .
545+ The output Channel[TableChunk] .
544546 stats
545547 The statistics collector.
546548 partition_info
@@ -611,7 +613,7 @@ def make_rapidsmpf_read_parquet_node(
611613 try :
612614 return read_parquet (
613615 context ,
614- ch_out . data ,
616+ ch_out ,
615617 num_producers ,
616618 parquet_reader_options ,
617619 num_rows_per_chunk ,
@@ -652,7 +654,8 @@ def _(
652654 )
653655
654656 # Use rapidsmpf native read_parquet node if possible
655- ch_pair = channels [ir ].reserve_input_slot ()
657+ ch_in : Channel [TableChunk ] | None = None
658+ ch_out = channels [ir ].reserve_input_slot ()
656659 nodes : dict [IR , list [Any ]] = {}
657660 native_node : Any = None
658661 if (
@@ -665,21 +668,24 @@ def _(
665668 and ir .skip_rows == 0
666669 and not distributed_split_files
667670 ):
671+ # Create new channel to so ch_out can be used to add metadata
672+ ch_in = rec .state ["context" ].create_channel ()
668673 native_node = make_rapidsmpf_read_parquet_node (
669674 rec .state ["context" ],
670675 ir ,
671676 num_producers ,
672- ch_pair ,
677+ ch_in ,
673678 rec .state ["stats" ],
674679 partition_info ,
675680 )
676681
677- if native_node is not None :
682+ if native_node is not None and ch_in is not None :
678683 # Need metadata node, because the native read_parquet
679684 # node does not send metadata.
680685 metadata_node = metadata_feeder_node (
681686 rec .state ["context" ],
682- ch_pair ,
687+ ch_in ,
688+ ch_out ,
683689 Metadata (
684690 # partition_info.count is the estimated "global" count.
685691 # Just estimate the local count as well.
@@ -689,7 +695,7 @@ def _(
689695 global_count = partition_info .count ,
690696 ),
691697 )
692- nodes [ir ] = [metadata_node , native_node ]
698+ nodes [ir ] = [native_node , metadata_node ]
693699 else :
694700 # Fall back to scan_node (predicate not convertible, or other constraint)
695701 parquet_options = dataclasses .replace (parquet_options , chunked = False )
@@ -699,7 +705,7 @@ def _(
699705 rec .state ["context" ],
700706 ir ,
701707 rec .state ["ir_context" ],
702- ch_pair ,
708+ ch_out ,
703709 num_producers = num_producers ,
704710 plan = plan ,
705711 parquet_options = parquet_options ,
0 commit comments