Skip to content

Commit 33f3ed7

Browse files
authored
Use the new make_table_chunks_available_or_wait API from RapidsMPF (rapidsai#21291)
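Replaces the `opaque_reservation` helper with the new RapidsMPF memory-reservation APIs: `reserve_memory`, `make_table_chunks_available_or_wait` (and `TableChunk.make_available_or_wait`), used together with `opaque_memory_usage`.

Authors:
- Mads R. B. Kristensen (https://github.com/madsbk)
- Richard (Rick) Zamora (https://github.com/rjzamora)

Approvers:
- Tom Augspurger (https://github.com/TomAugspurger)
- Richard (Rick) Zamora (https://github.com/rjzamora)

URL: rapidsai#21291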
1 parent dd02e83 commit 33f3ed7

6 files changed

Lines changed: 125 additions & 201 deletions

File tree

python/cudf_polars/cudf_polars/experimental/rapidsmpf/io.py

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,10 @@
99
import math
1010
from typing import TYPE_CHECKING, Any
1111

12+
from rapidsmpf.memory.memory_reservation import opaque_memory_usage
13+
from rapidsmpf.streaming.core.memory_reserve_or_wait import (
14+
reserve_memory,
15+
)
1216
from rapidsmpf.streaming.core.message import Message
1317
from rapidsmpf.streaming.cudf.channel_metadata import ChannelMetadata
1418
from rapidsmpf.streaming.cudf.table_chunk import TableChunk
@@ -37,7 +41,6 @@
3741
)
3842
from cudf_polars.experimental.rapidsmpf.utils import (
3943
ChannelManager,
40-
opaque_reservation,
4144
send_metadata,
4245
)
4346

@@ -345,25 +348,29 @@ async def read_chunk(
345348
tracer
346349
The actor tracer for collecting runtime statistics.
347350
"""
348-
with opaque_reservation(context, estimated_chunk_bytes):
351+
with opaque_memory_usage(
352+
await reserve_memory(
353+
context, size=estimated_chunk_bytes, net_memory_delta=estimated_chunk_bytes
354+
)
355+
):
349356
df = await asyncio.to_thread(
350357
scan.do_evaluate,
351358
*scan._non_child_args,
352359
context=ir_context,
353360
)
354-
if tracer is not None:
355-
tracer.add_chunk(table=df.table)
356-
await ch_out.send(
357-
context,
358-
Message(
359-
seq_num,
360-
TableChunk.from_pylibcudf_table(
361-
df.table,
362-
df.stream,
363-
exclusive_view=True,
364-
),
361+
if tracer is not None:
362+
tracer.add_chunk(table=df.table)
363+
await ch_out.send(
364+
context,
365+
Message(
366+
seq_num,
367+
TableChunk.from_pylibcudf_table(
368+
df.table,
369+
df.stream,
370+
exclusive_view=True,
365371
),
366-
)
372+
),
373+
)
367374

368375

369376
@define_py_node()

python/cudf_polars/cudf_polars/experimental/rapidsmpf/join.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,11 @@
88
from typing import TYPE_CHECKING, Any, Literal
99

1010
from rapidsmpf.memory.buffer import MemoryType
11+
from rapidsmpf.memory.memory_reservation import opaque_memory_usage
12+
from rapidsmpf.streaming.core.memory_reserve_or_wait import (
13+
missing_net_memory_delta,
14+
reserve_memory,
15+
)
1116
from rapidsmpf.streaming.core.message import Message
1217
from rapidsmpf.streaming.cudf.channel_metadata import ChannelMetadata
1318
from rapidsmpf.streaming.cudf.table_chunk import TableChunk
@@ -27,7 +32,6 @@
2732
ChannelManager,
2833
chunk_to_frame,
2934
empty_table_chunk,
30-
opaque_reservation,
3135
process_children,
3236
recv_metadata,
3337
send_metadata,
@@ -191,28 +195,31 @@ async def broadcast_join_node(
191195
return
192196
else:
193197
large_chunk_processed = True
194-
large_chunk = TableChunk.from_message(msg).make_available_and_spill(
195-
context.br(), allow_overbooking=True
198+
large_chunk = await TableChunk.from_message(msg).make_available_or_wait(
199+
context,
200+
net_memory_delta=missing_net_memory_delta,
196201
)
197202
seq_num = msg.sequence_number
198-
del msg
199203

200204
large_df = DataFrame.from_table(
201205
large_chunk.table_view(),
202206
list(large_child.schema.keys()),
203207
list(large_child.schema.values()),
204208
large_chunk.stream,
205209
)
210+
large_chunk_size = large_chunk.data_alloc_size(MemoryType.DEVICE)
211+
del large_chunk # `large_df` keeps `large_chunk` alive.
206212

207213
# Lazily create empty small table if small_dfs is empty
208214
if not small_dfs:
209215
stream = ir_context.get_cuda_stream()
210216
empty_small_chunk = empty_table_chunk(small_child, context, stream)
211217
small_dfs = [chunk_to_frame(empty_small_chunk, small_child)]
212218

213-
large_chunk_size = large_chunk.data_alloc_size(MemoryType.DEVICE)
214219
input_bytes = large_chunk_size + small_size
215-
with opaque_reservation(context, input_bytes):
220+
with opaque_memory_usage(
221+
await reserve_memory(context, size=input_bytes, net_memory_delta=0)
222+
):
216223
df = _concat(
217224
*[
218225
await asyncio.to_thread(
@@ -229,18 +236,18 @@ async def broadcast_join_node(
229236
],
230237
context=ir_context,
231238
)
239+
del large_df
232240

233-
# Send output chunk
234-
await ch_out.send(
235-
context,
236-
Message(
237-
seq_num,
238-
TableChunk.from_pylibcudf_table(
239-
df.table, df.stream, exclusive_view=True
240-
),
241+
# Send output chunk
242+
await ch_out.send(
243+
context,
244+
Message(
245+
seq_num,
246+
TableChunk.from_pylibcudf_table(
247+
df.table, df.stream, exclusive_view=True
241248
),
242-
)
243-
del df, large_df, large_chunk
249+
),
250+
)
244251

245252
del small_dfs, small_chunks
246253
await ch_out.drain(context)

python/cudf_polars/cudf_polars/experimental/rapidsmpf/nodes.py

Lines changed: 52 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
from typing import TYPE_CHECKING, Any, cast
99

1010
from rapidsmpf.memory.buffer import MemoryType
11+
from rapidsmpf.memory.memory_reservation import opaque_memory_usage
1112
from rapidsmpf.streaming.core.message import Message
1213
from rapidsmpf.streaming.core.node import define_py_node
1314
from rapidsmpf.streaming.core.spillable_messages import SpillableMessages
1415
from rapidsmpf.streaming.cudf.channel_metadata import ChannelMetadata
15-
from rapidsmpf.streaming.cudf.table_chunk import TableChunk
16+
from rapidsmpf.streaming.cudf.table_chunk import (
17+
TableChunk,
18+
make_table_chunks_available_or_wait,
19+
)
1620

1721
from cudf_polars.containers import DataFrame
1822
from cudf_polars.dsl.ir import IR, Cache, Empty, Filter, Projection
@@ -23,7 +27,6 @@
2327
ChannelManager,
2428
empty_table_chunk,
2529
make_spill_function,
26-
opaque_reservation,
2730
process_children,
2831
recv_metadata,
2932
remap_partitioning,
@@ -102,18 +105,21 @@ async def default_node_single(
102105
else:
103106
# Make sure we have an empty chunk in case do_evaluate
104107
# always produces rows (e.g. aggregation)
105-
stream = ir_context.get_cuda_stream()
106-
chunk = empty_table_chunk(ir.children[0], context, stream)
108+
chunk = empty_table_chunk(
109+
ir.children[0], context, ir_context.get_cuda_stream()
110+
)
107111
else:
108112
received_any = True
109-
chunk = TableChunk.from_message(msg).make_available_and_spill(
110-
context.br(), allow_overbooking=True
111-
)
113+
chunk = TableChunk.from_message(msg)
112114
seq_num = msg.sequence_number
113-
del msg
114115

115-
input_bytes = chunk.data_alloc_size(MemoryType.DEVICE)
116-
with opaque_reservation(context, input_bytes):
116+
chunk, extra = await make_table_chunks_available_or_wait(
117+
context,
118+
chunk,
119+
reserve_extra=chunk.data_alloc_size(),
120+
net_memory_delta=0,
121+
)
122+
with opaque_memory_usage(extra):
117123
df = await asyncio.to_thread(
118124
ir.do_evaluate,
119125
*ir._non_child_args,
@@ -125,18 +131,18 @@ async def default_node_single(
125131
),
126132
context=ir_context,
127133
)
128-
if tracer is not None:
129-
tracer.add_chunk(table=df.table)
130-
await ch_out.send(
131-
context,
132-
Message(
133-
seq_num,
134-
TableChunk.from_pylibcudf_table(
135-
df.table, chunk.stream, exclusive_view=True
136-
),
134+
if tracer is not None:
135+
tracer.add_chunk(table=df.table)
136+
await ch_out.send(
137+
context,
138+
Message(
139+
seq_num,
140+
TableChunk.from_pylibcudf_table(
141+
df.table, chunk.stream, exclusive_view=True
137142
),
138-
)
139-
del df, chunk
143+
),
144+
)
145+
del df, chunk
140146

141147
await ch_out.drain(context)
142148

@@ -234,10 +240,15 @@ async def default_node_multi(
234240
ready_chunks[ch_idx] = empty_table_chunk(child, context, stream)
235241

236242
# Ensure all table chunks are unspilled and available.
237-
ready_chunks = [
238-
chunk.make_available_and_spill(context.br(), allow_overbooking=True)
239-
for chunk in cast(list[TableChunk], ready_chunks)
240-
]
243+
ready_chunks, extra = await make_table_chunks_available_or_wait(
244+
context,
245+
ready_chunks,
246+
reserve_extra=sum(
247+
chunk.data_alloc_size()
248+
for chunk in cast(list[TableChunk], ready_chunks)
249+
),
250+
net_memory_delta=0,
251+
)
241252
dfs = [
242253
DataFrame.from_table(
243254
chunk.table_view(), # type: ignore[union-attr]
@@ -247,33 +258,29 @@ async def default_node_multi(
247258
)
248259
for chunk, child in zip(ready_chunks, ir.children, strict=True)
249260
]
250-
251-
input_bytes = sum(
252-
chunk.data_alloc_size(MemoryType.DEVICE)
253-
for chunk in cast(list[TableChunk], ready_chunks)
254-
)
255-
with opaque_reservation(context, input_bytes):
261+
with opaque_memory_usage(extra):
256262
df = await asyncio.to_thread(
257263
ir.do_evaluate,
258264
*ir._non_child_args,
259265
*dfs,
260266
context=ir_context,
261267
)
262-
if tracer is not None:
263-
tracer.add_chunk(table=df.table)
264-
await ch_out.send(
265-
context,
266-
Message(
267-
seq_num,
268-
TableChunk.from_pylibcudf_table(
269-
df.table,
270-
df.stream,
271-
exclusive_view=True,
272-
),
268+
del dfs
269+
if tracer is not None:
270+
tracer.add_chunk(table=df.table)
271+
await ch_out.send(
272+
context,
273+
Message(
274+
seq_num,
275+
TableChunk.from_pylibcudf_table(
276+
df.table,
277+
df.stream,
278+
exclusive_view=True,
273279
),
274-
)
275-
seq_num += 1
276-
del df, dfs
280+
),
281+
)
282+
seq_num += 1
283+
del df
277284

278285
# Drain the output channel
279286
del ready_chunks

python/cudf_polars/cudf_polars/experimental/rapidsmpf/repartition.py

Lines changed: 25 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,14 @@
77
import math
88
from typing import TYPE_CHECKING, Any
99

10-
from rapidsmpf.memory.buffer import MemoryType
10+
from rapidsmpf.memory.memory_reservation import opaque_memory_usage
1111
from rapidsmpf.streaming.core.message import Message
1212
from rapidsmpf.streaming.core.node import define_py_node
1313
from rapidsmpf.streaming.cudf.channel_metadata import ChannelMetadata
14-
from rapidsmpf.streaming.cudf.table_chunk import TableChunk
14+
from rapidsmpf.streaming.cudf.table_chunk import (
15+
TableChunk,
16+
make_table_chunks_available_or_wait,
17+
)
1518

1619
from cudf_polars.containers import DataFrame
1720
from cudf_polars.experimental.rapidsmpf.collectives.allgather import AllGatherManager
@@ -20,7 +23,6 @@
2023
from cudf_polars.experimental.rapidsmpf.utils import (
2124
ChannelManager,
2225
empty_table_chunk,
23-
opaque_reservation,
2426
recv_metadata,
2527
send_metadata,
2628
)
@@ -183,18 +185,16 @@ async def concatenate_node(
183185
if msg is None:
184186
done_receiving = True
185187
break
186-
chunks.append(
187-
TableChunk.from_message(msg).make_available_and_spill(
188-
context.br(), allow_overbooking=True
189-
)
190-
)
191-
del msg
188+
chunks.append(TableChunk.from_message(msg))
192189

193190
if chunks:
194-
input_bytes = sum(
195-
chunk.data_alloc_size(MemoryType.DEVICE) for chunk in chunks
191+
chunks, extra = await make_table_chunks_available_or_wait(
192+
context,
193+
chunks,
194+
reserve_extra=sum(chunk.data_alloc_size() for chunk in chunks),
195+
net_memory_delta=0,
196196
)
197-
with opaque_reservation(context, input_bytes):
197+
with opaque_memory_usage(extra):
198198
df = _concat(
199199
*(
200200
DataFrame.from_table(
@@ -207,19 +207,20 @@ async def concatenate_node(
207207
),
208208
context=ir_context,
209209
)
210-
if tracer is not None:
211-
tracer.add_chunk(table=df.table)
212-
await ch_out.send(
213-
context,
214-
Message(
215-
seq_num,
216-
TableChunk.from_pylibcudf_table(
217-
df.table, df.stream, exclusive_view=True
218-
),
210+
del chunks
211+
if tracer is not None:
212+
tracer.add_chunk(table=df.table)
213+
await ch_out.send(
214+
context,
215+
Message(
216+
seq_num,
217+
TableChunk.from_pylibcudf_table(
218+
df.table, df.stream, exclusive_view=True
219219
),
220-
)
221-
seq_num += 1
222-
del df, chunks
220+
),
221+
)
222+
seq_num += 1
223+
del df
223224

224225
# Break if we reached end of stream
225226
if done_receiving:

0 commit comments

Comments
 (0)