Skip to content

Commit a94c7d1

Browse files
authored
Revise Sort lowering and sort_actor assumptions (rapidsai#22315)
- Closes rapidsai#22050 - **Updates `Sort` lowering**: - The "rapidsmpf" runtime no longer uses `Sort(ShuffleSorted(Sort(...)))` - We leave most `Sort` nodes alone. - We map `Sort` directly to `sort_actor` - The `sort_actor` performs its own top/bottom-k optimization and performs chunk-wise sorting itself. **Motivation**: The `Sort(ShuffleSorted(Sort(...)))` pattern will be too messy when we start utilizing `OrderScheme` metadata (after rapidsai#22291). The current assumption that the child IR node is always a `Sort` is also not true/safe. **Note**: This PR technically adds code, but it will allow us to *remove* more code once "tasks" is removed (everything `ShuffleSorted`-related). Authors: - Richard (Rick) Zamora (https://github.com/rjzamora) Approvers: - Matthew Roeschke (https://github.com/mroeschke) URL: rapidsai#22315
1 parent 21b2920 commit a94c7d1

9 files changed

Lines changed: 218 additions & 105 deletions

File tree

python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/common.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,12 @@
1010

1111
from rapidsmpf.shuffler import Shuffler
1212

13-
from cudf_polars.dsl.ir import Distinct, GroupBy
13+
from cudf_polars.dsl.ir import Distinct, GroupBy, Sort
1414
from cudf_polars.dsl.traversal import traversal
1515
from cudf_polars.experimental.io import StreamingSink
1616
from cudf_polars.experimental.join import Join
1717
from cudf_polars.experimental.repartition import Repartition
1818
from cudf_polars.experimental.shuffle import Shuffle
19-
from cudf_polars.experimental.sort import ShuffleSorted
2019

2120
if TYPE_CHECKING:
2221
from collections.abc import Iterator
@@ -91,15 +90,15 @@ def __init__(
9190
Join,
9291
Repartition,
9392
StreamingSink,
94-
ShuffleSorted,
93+
Sort,
9594
)
9695
if self.dynamic_planning_enabled:
9796
collective_types = (
9897
Shuffle,
9998
Join,
10099
Repartition,
101100
StreamingSink,
102-
ShuffleSorted,
101+
Sort,
103102
GroupBy,
104103
Distinct,
105104
)
@@ -137,7 +136,7 @@ def __enter__(self) -> dict[IR, list[int]]:
137136
_get_new_collective_id(),
138137
_get_new_collective_id(),
139138
]
140-
elif isinstance(node, ShuffleSorted):
139+
elif isinstance(node, Sort):
141140
if self.dynamic_planning_enabled:
142141
# 3 IDs: size-estimate allgather, boundary allgather, shuffle
143142
self.collective_id_map[node] = [

python/cudf_polars/cudf_polars/experimental/rapidsmpf/collectives/sort.py

Lines changed: 151 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025-2026, NVIDIA CORPORATION & AFFILIATES.
22
# SPDX-License-Identifier: Apache-2.0
3-
"""Sort (ShuffleSorted) logic for the RapidsMPF streaming runtime."""
3+
"""Sort logic for the RapidsMPF streaming runtime."""
44

55
from __future__ import annotations
66

@@ -24,22 +24,29 @@
2424
from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager
2525
from cudf_polars.experimental.rapidsmpf.collectives.shuffle import ShuffleManager
2626
from cudf_polars.experimental.rapidsmpf.dispatch import generate_ir_sub_network
27-
from cudf_polars.experimental.rapidsmpf.nodes import shutdown_on_error
27+
from cudf_polars.experimental.rapidsmpf.nodes import (
28+
default_node_single,
29+
shutdown_on_error,
30+
)
2831
from cudf_polars.experimental.rapidsmpf.utils import (
2932
ChannelManager,
3033
allgather_reduce,
3134
chunk_to_frame,
3235
concat_batch,
3336
empty_table_chunk,
37+
evaluate_batch,
38+
evaluate_chunk,
3439
gather_in_task_group,
3540
names_to_indices,
41+
process_children,
3642
recv_metadata,
3743
replay_buffered_channel,
3844
send_metadata,
3945
)
46+
from cudf_polars.experimental.repartition import Repartition
4047
from cudf_polars.experimental.sort import (
41-
ShuffleSorted,
4248
_get_final_sort_boundaries,
49+
_has_simple_zlice,
4350
_select_local_split_candidates,
4451
find_sort_splits,
4552
)
@@ -76,6 +83,63 @@ def __iter__(self) -> Generator[Message, None, None]:
7683
yield self._store.extract(mid=self._mids.popleft())
7784

7885

86+
async def _simple_top_or_bottom_k(
87+
context: Context,
88+
comm: Communicator,
89+
ch_in: Channel[TableChunk],
90+
ch_out: Channel[TableChunk],
91+
ir: Sort,
92+
ir_context: IRExecutionContext,
93+
metadata_in: ChannelMetadata,
94+
collective_ids: list[int],
95+
tracer: ActorTracer | None,
96+
) -> None:
97+
"""Sort + simple head/tail slice."""
98+
# TODO: We may need to gate this optimization on the slice size.
99+
await send_metadata(
100+
ch_out,
101+
context,
102+
ChannelMetadata(local_count=1, partitioning=None, duplicated=True),
103+
)
104+
105+
chunks: list[TableChunk] = []
106+
while (msg := await ch_in.recv(context)) is not None:
107+
chunks.append(
108+
await evaluate_chunk(
109+
context,
110+
TableChunk.from_message(msg, br=context.br()),
111+
ir,
112+
ir_context=ir_context,
113+
)
114+
)
115+
chunk: TableChunk = await evaluate_batch(chunks, context, ir, ir_context=ir_context)
116+
chunks.clear()
117+
118+
if comm.nranks > 1 and not metadata_in.duplicated:
119+
allgather = AllGatherManager(context, comm, collective_ids.pop())
120+
with allgather.inserting() as inserter:
121+
inserter.insert(comm.rank, chunk)
122+
123+
stream = ir_context.get_cuda_stream()
124+
chunk = await evaluate_chunk(
125+
context,
126+
TableChunk.from_pylibcudf_table(
127+
await allgather.extract_concatenated(stream, ordered=True),
128+
stream,
129+
exclusive_view=True,
130+
br=context.br(),
131+
),
132+
ir,
133+
ir_context=ir_context,
134+
)
135+
136+
if tracer is not None:
137+
tracer.add_chunk(table=chunk.table_view())
138+
await ch_out.send(context, Message(comm.rank, chunk))
139+
140+
await ch_out.drain(context)
141+
142+
79143
def _boundary_schema(by: list[str], by_dtypes: list[DataType]) -> Schema:
80144
"""Schema of boundaries table."""
81145
name_gen = unique_names(by)
@@ -94,7 +158,7 @@ async def _compute_sort_boundaries(
94158
comm: Communicator,
95159
ir_context: IRExecutionContext,
96160
local_candidates_list: list[TableChunk],
97-
ir: ShuffleSorted,
161+
ir: Sort,
98162
by: list[str],
99163
num_partitions: int,
100164
allgather_id: int | None,
@@ -210,10 +274,11 @@ async def _receive_and_buffer_chunks(
210274
context: Context,
211275
ch_in: Channel[TableChunk],
212276
chunk_store: ChunkStore,
213-
sort_ir: Sort,
277+
ir: Sort,
214278
by: list[str],
215279
num_partitions: int,
216280
comm: Communicator,
281+
ir_context: IRExecutionContext,
217282
) -> list[TableChunk]:
218283
"""Receive input chunks, collect local split candidates, and buffer chunks for later insert."""
219284
await recv_metadata(ch_in, context)
@@ -223,10 +288,14 @@ async def _receive_and_buffer_chunks(
223288
while (msg := await ch_in.recv(context)) is not None:
224289
seq_num = msg.sequence_number
225290
df = chunk_to_frame(
226-
TableChunk.from_message(msg, br=context.br()).make_available_and_spill(
227-
context.br(), allow_overbooking=True
291+
# Make sure chunks are pre-sorted
292+
await evaluate_chunk(
293+
context,
294+
TableChunk.from_message(msg, br=context.br()),
295+
ir,
296+
ir_context=ir_context,
228297
),
229-
sort_ir,
298+
ir,
230299
)
231300
local_candidates_list.append(
232301
TableChunk.from_pylibcudf_table(
@@ -238,7 +307,7 @@ async def _receive_and_buffer_chunks(
238307
br=context.br(),
239308
)
240309
)
241-
if sort_ir.stable:
310+
if ir.stable:
242311
nrows = df.table.num_rows()
243312
start = (comm.rank * (1 << 48)) + local_row_offset
244313
seq_id_col = plc.filling.sequence(
@@ -276,7 +345,7 @@ async def _insert_chunks_into_shuffle(
276345
metadata_in: ChannelMetadata,
277346
chunk_store: ChunkStore,
278347
sort_boundaries_df: DataFrame,
279-
ir: ShuffleSorted,
348+
ir: Sort,
280349
ir_context: IRExecutionContext,
281350
by: list[str],
282351
) -> tuple[ShuffleManager, Sort]:
@@ -286,8 +355,6 @@ async def _insert_chunks_into_shuffle(
286355
by_indices = names_to_indices(tuple(by), ir.schema)
287356

288357
skip_insert = metadata_in.duplicated and comm.rank != 0
289-
local_sort_ir = ir.children[0]
290-
assert isinstance(local_sort_ir, Sort), "ShuffleSorted must have a Sort child."
291358

292359
shuffle = ShuffleManager(
293360
context,
@@ -323,21 +390,21 @@ async def _insert_chunks_into_shuffle(
323390
)
324391
inserter.insert_split(available_chunk, splits)
325392

326-
post_sort_ir = local_sort_ir
327-
if local_sort_ir.stable:
328-
assert local_sort_ir.zlice is None
393+
post_sort_ir = ir
394+
if ir.stable:
395+
assert ir.zlice is None
329396
seq_id_name = next(unique_names(ir.schema.keys()))
330397
post_sort_ir = Sort(
331-
local_sort_ir.schema | {seq_id_name: DataType(pl.UInt64())},
398+
ir.schema | {seq_id_name: DataType(pl.UInt64())},
332399
(
333-
*local_sort_ir.by,
400+
*ir.by,
334401
NamedExpr(seq_id_name, Col(DataType(pl.UInt64()), seq_id_name)),
335402
),
336-
(*local_sort_ir.order, plc.types.Order.ASCENDING),
337-
(*local_sort_ir.null_order, plc.types.NullOrder.AFTER),
338-
local_sort_ir.stable,
403+
(*ir.order, plc.types.Order.ASCENDING),
404+
(*ir.null_order, plc.types.NullOrder.AFTER),
405+
ir.stable,
339406
None,
340-
local_sort_ir.children[0],
407+
ir.children[0],
341408
)
342409

343410
return shuffle, post_sort_ir
@@ -369,19 +436,19 @@ async def _extract_partitions_and_send(
369436
),
370437
context=ir_context,
371438
).table
372-
if table.num_columns() > ncols_out:
373-
table = plc.Table(table.columns()[:ncols_out])
374-
if tracer is not None:
375-
tracer.add_chunk(table=table)
376-
await ch_out.send(
377-
context,
378-
Message(
379-
partition_id,
380-
TableChunk.from_pylibcudf_table(
381-
table, stream, exclusive_view=True, br=context.br()
439+
if table.num_columns() > ncols_out:
440+
table = plc.Table(table.columns()[:ncols_out])
441+
if tracer is not None:
442+
tracer.add_chunk(table=table)
443+
await ch_out.send(
444+
context,
445+
Message(
446+
partition_id,
447+
TableChunk.from_pylibcudf_table(
448+
table, stream, exclusive_view=True, br=context.br()
449+
),
382450
),
383-
),
384-
)
451+
)
385452

386453
await ch_out.drain(context)
387454

@@ -390,7 +457,7 @@ async def _extract_partitions_and_send(
390457
async def sort_actor(
391458
context: Context,
392459
comm: Communicator,
393-
ir: ShuffleSorted,
460+
ir: Sort,
394461
ir_context: IRExecutionContext,
395462
ch_in: Channel[TableChunk],
396463
ch_out: Channel[TableChunk],
@@ -400,14 +467,29 @@ async def sort_actor(
400467
collective_ids: list[int],
401468
) -> None:
402469
"""Streaming sort actor."""
403-
local_sort_ir = ir.children[0]
404-
assert isinstance(local_sort_ir, Sort), "ShuffleSorted must have a Sort child."
405470
ch_replay = context.create_channel()
406471
async with shutdown_on_error(
407472
context, ch_in, ch_out, ch_replay, trace_ir=ir, ir_context=ir_context
408473
) as tracer:
409474
metadata_in = await recv_metadata(ch_in, context)
410475

476+
if ir.zlice is not None:
477+
assert _has_simple_zlice(ir.zlice), (
478+
f"This slice not supported in `sort_actor`: {ir.zlice}."
479+
)
480+
await _simple_top_or_bottom_k(
481+
context,
482+
comm,
483+
ch_in,
484+
ch_out,
485+
ir,
486+
ir_context,
487+
metadata_in,
488+
collective_ids,
489+
tracer,
490+
)
491+
return
492+
411493
sampled_chunks, num_partitions = await _sample_chunks_for_size_estimate(
412494
context, comm, ch_in, num_partitions, metadata_in, executor, collective_ids
413495
)
@@ -427,10 +509,11 @@ async def sort_actor(
427509
context,
428510
ch_replay,
429511
chunk_store,
430-
local_sort_ir,
512+
ir,
431513
by,
432514
num_partitions,
433515
comm,
516+
ir_context,
434517
),
435518
)
436519

@@ -460,22 +543,45 @@ async def sort_actor(
460543
)
461544

462545
await _extract_partitions_and_send(
463-
context, ch_out, shuffle, post_sort_ir, ir_context, ir.schema, tracer=tracer
546+
context,
547+
ch_out,
548+
shuffle,
549+
post_sort_ir,
550+
ir_context,
551+
ir.schema,
552+
tracer=tracer,
464553
)
465554

466555

467-
@generate_ir_sub_network.register(ShuffleSorted)
468-
def _shuffle_sorted_network(
469-
ir: ShuffleSorted, rec: SubNetGenerator
470-
) -> tuple[dict, dict]:
556+
@generate_ir_sub_network.register(Sort)
557+
def _sort_rapidsmpf_network(ir: Sort, rec: SubNetGenerator) -> tuple[dict, dict]:
558+
"""Wire multi-partition ``Sort`` to ``sort_actor``; single-partition uses ``default_node_single``."""
559+
executor = rec.state["config_options"].executor
560+
partition_info = rec.state["partition_info"]
561+
dynamic = executor.dynamic_planning is not None
562+
563+
if partition_info[ir].count == 1 and (
564+
not dynamic or isinstance(ir.children[0], Repartition)
565+
):
566+
nodes, channels = process_children(ir, rec)
567+
channels[ir] = ChannelManager(rec.state["context"])
568+
nodes[ir] = [
569+
default_node_single(
570+
rec.state["context"],
571+
ir,
572+
rec.state["ir_context"],
573+
channels[ir].reserve_input_slot(),
574+
channels[ir.children[0]].reserve_output_slot(),
575+
)
576+
]
577+
return nodes, channels
578+
471579
(child,) = ir.children
472580
nodes, channels = rec(child)
473581
by = [ne.value.name for ne in ir.by if isinstance(ne.value, Col)]
474582
if len(by) != len(ir.by):
475583
raise NotImplementedError("Sorting columns must be column names.")
476584

477-
executor = rec.state["config_options"].executor
478-
dynamic = executor.dynamic_planning is not None
479585
collective_ids = list(rec.state["collective_id_map"][ir])
480586
expected_id_count = 3 if dynamic else 2
481587
assert len(collective_ids) == expected_id_count, (
@@ -492,7 +598,7 @@ def _shuffle_sorted_network(
492598
ch_in=channels[child].reserve_output_slot(),
493599
ch_out=channels[ir].reserve_input_slot(),
494600
by=by,
495-
num_partitions=rec.state["partition_info"][ir].count,
601+
num_partitions=partition_info[ir].count,
496602
executor=executor,
497603
collective_ids=collective_ids,
498604
)

0 commit comments

Comments
 (0)