11# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
22# SPDX-License-Identifier: Apache-2.0
3- """Sort (ShuffleSorted) logic for the RapidsMPF streaming runtime."""
3+ """Sort logic for the RapidsMPF streaming runtime."""
44
55from __future__ import annotations
66
2424from cudf_polars .experimental .rapidsmpf .collectives .allgather import AllGatherManager
2525from cudf_polars .experimental .rapidsmpf .collectives .shuffle import ShuffleManager
2626from cudf_polars .experimental .rapidsmpf .dispatch import generate_ir_sub_network
27- from cudf_polars .experimental .rapidsmpf .nodes import shutdown_on_error
27+ from cudf_polars .experimental .rapidsmpf .nodes import (
28+ default_node_single ,
29+ shutdown_on_error ,
30+ )
2831from cudf_polars .experimental .rapidsmpf .utils import (
2932 ChannelManager ,
3033 allgather_reduce ,
3134 chunk_to_frame ,
3235 concat_batch ,
3336 empty_table_chunk ,
37+ evaluate_batch ,
38+ evaluate_chunk ,
3439 gather_in_task_group ,
3540 names_to_indices ,
41+ process_children ,
3642 recv_metadata ,
3743 replay_buffered_channel ,
3844 send_metadata ,
3945)
46+ from cudf_polars .experimental .repartition import Repartition
4047from cudf_polars .experimental .sort import (
41- ShuffleSorted ,
4248 _get_final_sort_boundaries ,
49+ _has_simple_zlice ,
4350 _select_local_split_candidates ,
4451 find_sort_splits ,
4552)
@@ -76,6 +83,63 @@ def __iter__(self) -> Generator[Message, None, None]:
7683 yield self ._store .extract (mid = self ._mids .popleft ())
7784
7885
async def _simple_top_or_bottom_k(
    context: Context,
    comm: Communicator,
    ch_in: Channel[TableChunk],
    ch_out: Channel[TableChunk],
    ir: Sort,
    ir_context: IRExecutionContext,
    metadata_in: ChannelMetadata,
    collective_ids: list[int],
    tracer: ActorTracer | None,
) -> None:
    """Sort + simple head/tail slice.

    Announces a single, duplicated output chunk on ``ch_out``, evaluates
    ``ir`` on every chunk received from ``ch_in`` and then on the collected
    batch. When there is more than one rank and the input is not already
    duplicated, the per-rank result is all-gathered and ``ir`` is evaluated
    once more on the concatenated table. Exactly one message is sent on
    ``ch_out`` before it is drained.

    Parameters
    ----------
    context
        Streaming context used for channel operations and buffer resources.
    comm
        Communicator; ``nranks`` and ``rank`` are read from it.
    ch_in
        Input channel; consumed until it yields ``None``.
    ch_out
        Output channel; receives metadata, one message, then is drained.
    ir
        The ``Sort`` node to evaluate (the caller is expected to only route
        nodes with a simple slice here -- confirm against ``sort_actor``).
    ir_context
        Execution context forwarded to the evaluation helpers.
    metadata_in
        Metadata of the input channel; only ``duplicated`` is inspected.
    collective_ids
        Collective operation IDs. One ID is popped (the list is mutated)
        when an all-gather is performed.
    tracer
        Optional tracer that records the output chunk.
    """
    # TODO: We may need to gate this optimization on the slice size.
    out_metadata = ChannelMetadata(local_count=1, partitioning=None, duplicated=True)
    await send_metadata(ch_out, context, out_metadata)

    # Evaluate each incoming chunk as it arrives, then evaluate the batch.
    evaluated: list[TableChunk] = []
    while True:
        msg = await ch_in.recv(context)
        if msg is None:
            break
        evaluated.append(
            await evaluate_chunk(
                context,
                TableChunk.from_message(msg, br=context.br()),
                ir,
                ir_context=ir_context,
            )
        )
    result: TableChunk = await evaluate_batch(
        evaluated, context, ir, ir_context=ir_context
    )
    # Drop the per-chunk references now that the batch result exists.
    evaluated.clear()

    needs_allgather = comm.nranks > 1 and not metadata_in.duplicated
    if needs_allgather:
        # Share every rank's local result, then evaluate the combined table.
        gatherer = AllGatherManager(context, comm, collective_ids.pop())
        with gatherer.inserting() as inserter:
            inserter.insert(comm.rank, result)

        cuda_stream = ir_context.get_cuda_stream()
        result = await evaluate_chunk(
            context,
            TableChunk.from_pylibcudf_table(
                await gatherer.extract_concatenated(cuda_stream, ordered=True),
                cuda_stream,
                exclusive_view=True,
                br=context.br(),
            ),
            ir,
            ir_context=ir_context,
        )

    if tracer is not None:
        tracer.add_chunk(table=result.table_view())
    await ch_out.send(context, Message(comm.rank, result))

    await ch_out.drain(context)
141+
142+
79143def _boundary_schema (by : list [str ], by_dtypes : list [DataType ]) -> Schema :
80144 """Schema of boundaries table."""
81145 name_gen = unique_names (by )
@@ -94,7 +158,7 @@ async def _compute_sort_boundaries(
94158 comm : Communicator ,
95159 ir_context : IRExecutionContext ,
96160 local_candidates_list : list [TableChunk ],
97- ir : ShuffleSorted ,
161+ ir : Sort ,
98162 by : list [str ],
99163 num_partitions : int ,
100164 allgather_id : int | None ,
@@ -210,10 +274,11 @@ async def _receive_and_buffer_chunks(
210274 context : Context ,
211275 ch_in : Channel [TableChunk ],
212276 chunk_store : ChunkStore ,
213- sort_ir : Sort ,
277+ ir : Sort ,
214278 by : list [str ],
215279 num_partitions : int ,
216280 comm : Communicator ,
281+ ir_context : IRExecutionContext ,
217282) -> list [TableChunk ]:
218283 """Receive input chunks, collect local split candidates, and buffer chunks for later insert."""
219284 await recv_metadata (ch_in , context )
@@ -223,10 +288,14 @@ async def _receive_and_buffer_chunks(
223288 while (msg := await ch_in .recv (context )) is not None :
224289 seq_num = msg .sequence_number
225290 df = chunk_to_frame (
226- TableChunk .from_message (msg , br = context .br ()).make_available_and_spill (
227- context .br (), allow_overbooking = True
291+ # Make sure chunks are pre-sorted
292+ await evaluate_chunk (
293+ context ,
294+ TableChunk .from_message (msg , br = context .br ()),
295+ ir ,
296+ ir_context = ir_context ,
228297 ),
229- sort_ir ,
298+ ir ,
230299 )
231300 local_candidates_list .append (
232301 TableChunk .from_pylibcudf_table (
@@ -238,7 +307,7 @@ async def _receive_and_buffer_chunks(
238307 br = context .br (),
239308 )
240309 )
241- if sort_ir .stable :
310+ if ir .stable :
242311 nrows = df .table .num_rows ()
243312 start = (comm .rank * (1 << 48 )) + local_row_offset
244313 seq_id_col = plc .filling .sequence (
@@ -276,7 +345,7 @@ async def _insert_chunks_into_shuffle(
276345 metadata_in : ChannelMetadata ,
277346 chunk_store : ChunkStore ,
278347 sort_boundaries_df : DataFrame ,
279- ir : ShuffleSorted ,
348+ ir : Sort ,
280349 ir_context : IRExecutionContext ,
281350 by : list [str ],
282351) -> tuple [ShuffleManager , Sort ]:
@@ -286,8 +355,6 @@ async def _insert_chunks_into_shuffle(
286355 by_indices = names_to_indices (tuple (by ), ir .schema )
287356
288357 skip_insert = metadata_in .duplicated and comm .rank != 0
289- local_sort_ir = ir .children [0 ]
290- assert isinstance (local_sort_ir , Sort ), "ShuffleSorted must have a Sort child."
291358
292359 shuffle = ShuffleManager (
293360 context ,
@@ -323,21 +390,21 @@ async def _insert_chunks_into_shuffle(
323390 )
324391 inserter .insert_split (available_chunk , splits )
325392
326- post_sort_ir = local_sort_ir
327- if local_sort_ir .stable :
328- assert local_sort_ir .zlice is None
393+ post_sort_ir = ir
394+ if ir .stable :
395+ assert ir .zlice is None
329396 seq_id_name = next (unique_names (ir .schema .keys ()))
330397 post_sort_ir = Sort (
331- local_sort_ir .schema | {seq_id_name : DataType (pl .UInt64 ())},
398+ ir .schema | {seq_id_name : DataType (pl .UInt64 ())},
332399 (
333- * local_sort_ir .by ,
400+ * ir .by ,
334401 NamedExpr (seq_id_name , Col (DataType (pl .UInt64 ()), seq_id_name )),
335402 ),
336- (* local_sort_ir .order , plc .types .Order .ASCENDING ),
337- (* local_sort_ir .null_order , plc .types .NullOrder .AFTER ),
338- local_sort_ir .stable ,
403+ (* ir .order , plc .types .Order .ASCENDING ),
404+ (* ir .null_order , plc .types .NullOrder .AFTER ),
405+ ir .stable ,
339406 None ,
340- local_sort_ir .children [0 ],
407+ ir .children [0 ],
341408 )
342409
343410 return shuffle , post_sort_ir
@@ -369,19 +436,19 @@ async def _extract_partitions_and_send(
369436 ),
370437 context = ir_context ,
371438 ).table
372- if table .num_columns () > ncols_out :
373- table = plc .Table (table .columns ()[:ncols_out ])
374- if tracer is not None :
375- tracer .add_chunk (table = table )
376- await ch_out .send (
377- context ,
378- Message (
379- partition_id ,
380- TableChunk .from_pylibcudf_table (
381- table , stream , exclusive_view = True , br = context .br ()
439+ if table .num_columns () > ncols_out :
440+ table = plc .Table (table .columns ()[:ncols_out ])
441+ if tracer is not None :
442+ tracer .add_chunk (table = table )
443+ await ch_out .send (
444+ context ,
445+ Message (
446+ partition_id ,
447+ TableChunk .from_pylibcudf_table (
448+ table , stream , exclusive_view = True , br = context .br ()
449+ ),
382450 ),
383- ),
384- )
451+ )
385452
386453 await ch_out .drain (context )
387454
@@ -390,7 +457,7 @@ async def _extract_partitions_and_send(
390457async def sort_actor (
391458 context : Context ,
392459 comm : Communicator ,
393- ir : ShuffleSorted ,
460+ ir : Sort ,
394461 ir_context : IRExecutionContext ,
395462 ch_in : Channel [TableChunk ],
396463 ch_out : Channel [TableChunk ],
@@ -400,14 +467,29 @@ async def sort_actor(
400467 collective_ids : list [int ],
401468) -> None :
402469 """Streaming sort actor."""
403- local_sort_ir = ir .children [0 ]
404- assert isinstance (local_sort_ir , Sort ), "ShuffleSorted must have a Sort child."
405470 ch_replay = context .create_channel ()
406471 async with shutdown_on_error (
407472 context , ch_in , ch_out , ch_replay , trace_ir = ir , ir_context = ir_context
408473 ) as tracer :
409474 metadata_in = await recv_metadata (ch_in , context )
410475
476+ if ir .zlice is not None :
477+ assert _has_simple_zlice (ir .zlice ), (
478+ f"This slice not supported in `sort_actor`: { ir .zlice } ."
479+ )
480+ await _simple_top_or_bottom_k (
481+ context ,
482+ comm ,
483+ ch_in ,
484+ ch_out ,
485+ ir ,
486+ ir_context ,
487+ metadata_in ,
488+ collective_ids ,
489+ tracer ,
490+ )
491+ return
492+
411493 sampled_chunks , num_partitions = await _sample_chunks_for_size_estimate (
412494 context , comm , ch_in , num_partitions , metadata_in , executor , collective_ids
413495 )
@@ -427,10 +509,11 @@ async def sort_actor(
427509 context ,
428510 ch_replay ,
429511 chunk_store ,
430- local_sort_ir ,
512+ ir ,
431513 by ,
432514 num_partitions ,
433515 comm ,
516+ ir_context ,
434517 ),
435518 )
436519
@@ -460,22 +543,45 @@ async def sort_actor(
460543 )
461544
462545 await _extract_partitions_and_send (
463- context , ch_out , shuffle , post_sort_ir , ir_context , ir .schema , tracer = tracer
546+ context ,
547+ ch_out ,
548+ shuffle ,
549+ post_sort_ir ,
550+ ir_context ,
551+ ir .schema ,
552+ tracer = tracer ,
464553 )
465554
466555
467- @generate_ir_sub_network .register (ShuffleSorted )
468- def _shuffle_sorted_network (
469- ir : ShuffleSorted , rec : SubNetGenerator
470- ) -> tuple [dict , dict ]:
556+ @generate_ir_sub_network .register (Sort )
557+ def _sort_rapidsmpf_network (ir : Sort , rec : SubNetGenerator ) -> tuple [dict , dict ]:
558+ """Wire multi-partition ``Sort`` to ``sort_actor``; single-partition uses ``default_node_single``."""
559+ executor = rec .state ["config_options" ].executor
560+ partition_info = rec .state ["partition_info" ]
561+ dynamic = executor .dynamic_planning is not None
562+
563+ if partition_info [ir ].count == 1 and (
564+ not dynamic or isinstance (ir .children [0 ], Repartition )
565+ ):
566+ nodes , channels = process_children (ir , rec )
567+ channels [ir ] = ChannelManager (rec .state ["context" ])
568+ nodes [ir ] = [
569+ default_node_single (
570+ rec .state ["context" ],
571+ ir ,
572+ rec .state ["ir_context" ],
573+ channels [ir ].reserve_input_slot (),
574+ channels [ir .children [0 ]].reserve_output_slot (),
575+ )
576+ ]
577+ return nodes , channels
578+
471579 (child ,) = ir .children
472580 nodes , channels = rec (child )
473581 by = [ne .value .name for ne in ir .by if isinstance (ne .value , Col )]
474582 if len (by ) != len (ir .by ):
475583 raise NotImplementedError ("Sorting columns must be column names." )
476584
477- executor = rec .state ["config_options" ].executor
478- dynamic = executor .dynamic_planning is not None
479585 collective_ids = list (rec .state ["collective_id_map" ][ir ])
480586 expected_id_count = 3 if dynamic else 2
481587 assert len (collective_ids ) == expected_id_count , (
@@ -492,7 +598,7 @@ def _shuffle_sorted_network(
492598 ch_in = channels [child ].reserve_output_slot (),
493599 ch_out = channels [ir ].reserve_input_slot (),
494600 by = by ,
495- num_partitions = rec . state [ " partition_info" ] [ir ].count ,
601+ num_partitions = partition_info [ir ].count ,
496602 executor = executor ,
497603 collective_ids = collective_ids ,
498604 )
0 commit comments