Skip to content

Commit e92524a

Browse files
authored
[Multi-GPU Polars] Use current rmm resource in SPMD mode (rapidsai#21842)
Fix `spmd_execution()` MR lifecycle: wrap, set as current, restore.

Previously, `spmd_execution()` wrapped the current device resource in `RmmResourceAdaptor`, but never installed it as the _current_ resource. As a result, libcudf temporary allocations bypassed the adaptor.

* Adds a `set_memory_resource(mr)` context manager in `rapidsmpf/utils.py` that saves, sets, and restores the current device resource.
* Uses it in `spmd_execution()` so both the RapidsMPF `Context` and libcudf share the same `RmmResourceAdaptor`.

Authors:
- Mads R. B. Kristensen (https://github.com/madsbk)

Approvers:
- Tom Augspurger (https://github.com/TomAugspurger)

URL: rapidsai#21842
1 parent 2b799cd commit e92524a

6 files changed

Lines changed: 120 additions & 59 deletions

File tree

python/cudf_polars/cudf_polars/experimental/benchmarks/utils.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1864,19 +1864,23 @@ def run_polars_spmd(
18641864
"""Run benchmark queries using SPMD execution via the ``rrun`` launcher."""
18651865
if run_config.collect_traces:
18661866
raise NotImplementedError(
1867-
"--collect-traces is not yet supported with --cluster spmd."
1867+
"--collect-traces is not yet supported with --cluster spmd"
18681868
)
1869+
if run_config.rmm_async:
1870+
raise NotImplementedError("--rmm-async is not supported with --cluster spmd")
18691871
executor_options = get_executor_options(run_config, benchmark=benchmark)
18701872
# "runtime" and "cluster" are reserved — spmd_execution sets them
18711873
executor_options.pop("runtime", None)
18721874
executor_options.pop("cluster", None)
1875+
rmm.mr.set_current_device_resource(
1876+
rmm.mr.CudaAsyncMemoryResource(release_threshold=args.rmm_release_threshold)
1877+
)
18731878
with spmd_execution(
1874-
mr=rmm.mr.CudaAsyncMemoryResource(release_threshold=args.rmm_release_threshold)
1875-
if run_config.rmm_async
1876-
else None,
18771879
executor_options=executor_options,
1878-
parquet_options=parquet_options,
1879-
cuda_stream_policy=run_config.stream_policy,
1880+
engine_options={
1881+
"parquet_options": parquet_options,
1882+
"cuda_stream_policy": run_config.stream_policy,
1883+
},
18801884
) as (comm, ctx, engine):
18811885
from cudf_polars.experimental.rapidsmpf.collectives.common import reserve_op_id
18821886
from cudf_polars.experimental.rapidsmpf.frontend.spmd import (

python/cudf_polars/cudf_polars/experimental/rapidsmpf/frontend/spmd.py

Lines changed: 50 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,10 @@
2828
from cudf_polars.containers import DataFrame
2929
from cudf_polars.dsl.ir import IRExecutionContext
3030
from cudf_polars.experimental.rapidsmpf.core import generate_network
31-
from cudf_polars.experimental.rapidsmpf.utils import empty_table_chunk
31+
from cudf_polars.experimental.rapidsmpf.utils import (
32+
empty_table_chunk,
33+
set_memory_resource,
34+
)
3235
from cudf_polars.experimental.utils import _concat
3336
from cudf_polars.utils.config import SPMDContext
3437

@@ -87,11 +90,11 @@ def evaluate_pipeline_spmd_mode(
8790
"""
8891
if config_options.executor.runtime != "rapidsmpf":
8992
raise RuntimeError("Runtime must be rapidsmpf")
90-
if config_options.executor.spmd is None:
91-
raise RuntimeError("spmd must be set for SPMD mode")
92-
comm = config_options.executor.spmd.comm
93-
context = config_options.executor.spmd.context
94-
py_executor = config_options.executor.spmd.py_executor
93+
if config_options.executor.spmd_context is None:
94+
raise RuntimeError("spmd_context must be set for SPMD mode")
95+
comm = config_options.executor.spmd_context.comm
96+
context = config_options.executor.spmd_context.context
97+
py_executor = config_options.executor.spmd_context.py_executor
9598

9699
ir_context = IRExecutionContext(get_cuda_stream=context.get_stream_from_pool)
97100

@@ -210,10 +213,9 @@ def allgather_polars_dataframe(
210213
@contextmanager
211214
def spmd_execution(
212215
*,
213-
mr: rmm.mr.DeviceMemoryResource | None = None,
214216
rapidsmpf_options: Options | None = None,
215217
executor_options: dict[str, object] | None = None,
216-
**engine_kwargs: Any,
218+
engine_options: dict[str, Any] | None = None,
217219
) -> Iterator[tuple[Communicator, Context, pl.GPUEngine]]:
218220
"""
219221
Context manager that bootstraps a RapidsMPF SPMD context and a matching GPUEngine.
@@ -239,6 +241,27 @@ def spmd_execution(
239241
240242
All resources (communicator, stream pool, thread-pool) are released on exit.
241243
244+
**Memory resource**
245+
246+
``spmd_execution`` captures ``rmm.mr.get_current_device_resource()`` at entry,
247+
wraps it in ``RmmResourceAdaptor`` (so libcudf temporary allocations and the
248+
RapidsMPF ``Context`` share the same resource), sets the wrapped resource as
249+
current, and restores the original on exit.
250+
251+
To use a custom allocator, call ``rmm.mr.set_current_device_resource(your_mr)``
252+
before entering ``spmd_execution()``. Do not pre-wrap it in ``RmmResourceAdaptor``.
253+
254+
.. code-block:: python
255+
256+
import rmm
257+
258+
# Optional: install a pool allocator before entering spmd_execution.
259+
# rmm.mr.set_current_device_resource(
260+
# rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource())
261+
# )
262+
with spmd_execution(...) as (comm, ctx, engine):
263+
...
264+
242265
**DataFrame and LazyFrame semantics**
243266
244267
Because every rank runs an independent Python process, a :class:`~polars.DataFrame`
@@ -283,9 +306,6 @@ def spmd_execution(
283306
284307
Parameters
285308
----------
286-
mr
287-
RMM device memory resource to use. Defaults to
288-
``rmm.mr.CudaAsyncMemoryResource()`` when ``None``.
289309
rapidsmpf_options
290310
RapidsMPF options. Defaults to ``Options(get_environment_variables())``
291311
when ``None``.
@@ -294,20 +314,20 @@ def spmd_execution(
294314
:class:`~polars.lazyframe.engine_config.GPUEngine`. The keys
295315
``"runtime"``, ``"cluster"``, and ``"spmd_context"`` are reserved and may not
296316
be overridden.
297-
**engine_kwargs
317+
engine_options
298318
Extra keyword arguments forwarded directly to
299319
:class:`~polars.lazyframe.engine_config.GPUEngine`. For example,
300-
pass ``parquet_options={"use_rapidsmpf_native": True}`` to enable
301-
native Parquet reads. The keys ``"memory_resource"`` and
320+
pass ``engine_options={"parquet_options": {"use_rapidsmpf_native": True}}``
321+
to enable native Parquet reads. The keys ``"memory_resource"`` and
302322
``"executor"`` are reserved and may not be overridden.
303323
304324
Yields
305325
------
306-
comm : Communicator
326+
comm
307327
The active RapidsMPF communicator.
308-
ctx : Context
328+
ctx
309329
The active RapidsMPF context.
310-
engine : pl.GPUEngine
330+
engine
311331
A Polars GPU engine wired to ``comm`` and ``ctx``. Pass it to
312332
``LazyFrame.collect(engine=engine)`` on each rank.
313333
@@ -320,8 +340,8 @@ def spmd_execution(
320340
If ``executor_options`` contains any of the reserved keys
321341
``"runtime"``, ``"cluster"``, or ``"spmd_context"``.
322342
TypeError
323-
If ``engine_kwargs`` contains any of the reserved keys
324-
``"raise_on_fail"``, ``"memory_resource"``, or ``"executor"``.
343+
If ``engine_options`` contains any of the reserved keys
344+
``"memory_resource"`` or ``"executor"``.
325345
326346
Examples
327347
--------
@@ -341,20 +361,20 @@ def spmd_execution(
341361
)
342362

343363
executor_options = executor_options or {}
344-
engine_kwargs = engine_kwargs or {}
364+
engine_options = engine_options or {}
345365

346366
# Check for reserved keys.
347-
if bad := {"runtime", "cluster", "spmd"} & executor_options.keys():
367+
if bad := {"runtime", "cluster", "spmd_context"} & executor_options.keys():
348368
raise TypeError(f"executor_options may not contain reserved keys: {bad}")
349-
if bad := {"memory_resource", "executor"} & engine_kwargs.keys():
350-
raise TypeError(f"engine_kwargs may not contain reserved keys: {bad}")
369+
if bad := {"memory_resource", "executor"} & engine_options.keys():
370+
raise TypeError(f"engine_options may not contain reserved keys: {bad}")
351371

352372
rapidsmpf_options = (
353373
rapidsmpf_options
354374
if rapidsmpf_options is not None
355375
else Options(get_environment_variables())
356376
)
357-
mr = RmmResourceAdaptor(mr if mr is not None else rmm.mr.CudaAsyncMemoryResource())
377+
mr = RmmResourceAdaptor(rmm.mr.get_current_device_resource())
358378
comm = bootstrap.create_ucxx_comm(
359379
progress_thread=ProgressThread(),
360380
type=bootstrap.BackendType.AUTO,
@@ -367,19 +387,22 @@ def spmd_execution(
367387
thread_name_prefix="spmd-executor",
368388
)
369389
try:
370-
with Context.from_options(comm.logger, mr, rapidsmpf_options) as ctx:
390+
with (
391+
set_memory_resource(mr),
392+
Context.from_options(comm.logger, mr, rapidsmpf_options) as ctx,
393+
):
371394
engine = pl.GPUEngine(
372395
memory_resource=ctx.br().device_mr,
373396
executor="streaming",
374397
executor_options={
375398
**executor_options,
376399
"runtime": "rapidsmpf",
377400
"cluster": "spmd",
378-
"spmd": SPMDContext(
401+
"spmd_context": SPMDContext(
379402
comm=comm, context=ctx, py_executor=py_executor
380403
),
381404
},
382-
**engine_kwargs,
405+
**engine_options,
383406
)
384407
yield comm, ctx, engine
385408
finally:

python/cudf_polars/cudf_polars/experimental/rapidsmpf/utils.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,16 @@
55
from __future__ import annotations
66

77
import asyncio
8+
import contextlib
89
import operator
910
import struct
1011
from contextlib import asynccontextmanager
1112
from dataclasses import dataclass
1213
from functools import reduce
1314
from typing import TYPE_CHECKING, Any
1415

16+
import rmm.mr
17+
1518
from cudf_polars.dsl.tracing import LOG_TRACES, Scope
1619

1720
try:
@@ -42,7 +45,7 @@
4245
from cudf_polars.experimental.utils import _concat
4346

4447
if TYPE_CHECKING:
45-
from collections.abc import AsyncIterator, Callable
48+
from collections.abc import AsyncIterator, Callable, Iterator
4649

4750
from rapidsmpf.communicator.communicator import Communicator
4851
from rapidsmpf.streaming.core.channel import Channel
@@ -57,6 +60,27 @@
5760
from cudf_polars.typing import DataType, Schema
5861

5962

63+
@contextlib.contextmanager
64+
def set_memory_resource(mr: rmm.mr.DeviceMemoryResource) -> Iterator[None]:
65+
"""
66+
Context manager that temporarily sets ``mr`` as the current device resource.
67+
68+
On entry, ``mr`` is installed via ``rmm.mr.set_current_device_resource(mr)``.
69+
On exit, the previously active resource is restored unconditionally.
70+
71+
Parameters
72+
----------
73+
mr
74+
The memory resource to activate for the duration of the block.
75+
"""
76+
old = rmm.mr.get_current_device_resource()
77+
rmm.mr.set_current_device_resource(mr)
78+
try:
79+
yield
80+
finally:
81+
rmm.mr.set_current_device_resource(old)
82+
83+
6084
@asynccontextmanager
6185
async def shutdown_on_error(
6286
context: Context,

python/cudf_polars/cudf_polars/utils/config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -872,7 +872,7 @@ class StreamingExecutor:
872872
f"{_env_prefix}__RAPIDSMPF_PY_EXECUTOR_MAX_WORKERS", int, default=None
873873
)
874874
)
875-
spmd: SPMDContext | None = None
875+
spmd_context: SPMDContext | None = None
876876
ray_context: RayContext | None = None
877877

878878
def __post_init__(self) -> None: # noqa: D105

python/cudf_polars/docs/cudf-polars-mp.md

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -356,35 +356,41 @@ The result is guaranteed to be a `pl.DataFrame` containing rows from all ranks i
356356

357357
### Passing options
358358

359-
`mr`, `rapidsmpf_options`, `executor_options`, and `engine_kwargs` accept pass-through
359+
`rapidsmpf_options`, `executor_options`, and `engine_options` accept pass-through
360360
arguments:
361361

362362
```python
363363
import rmm
364364
from rapidsmpf.integrations.cudf_polars import Options
365365

366366
with spmd_execution(
367-
mr=rmm.mr.PoolMemoryResource(rmm.mr.CudaMemoryResource()),
368367
rapidsmpf_options=Options(num_streaming_threads=8),
369368
executor_options={
370369
"max_rows_per_partition": 500_000,
371370
"rapidsmpf_spill": True,
372371
"rapidsmpf_py_executor_max_workers": 2,
373372
},
374-
parquet_options={"use_rapidsmpf_native": True},
373+
engine_options={"parquet_options": {"use_rapidsmpf_native": True}},
375374
) as (comm, ctx, engine):
376375
...
377376
```
378377

379-
`mr` is an `rmm.mr.DeviceMemoryResource` used as the GPU memory resource for the
380-
RapidsMPF `Context`. Defaults to `None` (uses the current device resource).
378+
**Memory resource:** `spmd_execution` captures `rmm.mr.get_current_device_resource()`
379+
at entry, wraps it in `RmmResourceAdaptor` (so libcudf temporary allocations and the
380+
RapidsMPF `Context` share the same resource), sets the wrapped resource as current, and
381+
restores the original resource on exit. To use a custom allocator, call
382+
`rmm.mr.set_current_device_resource(your_mr)` **before** entering `spmd_execution()`.
383+
Do not pre-wrap it in `RmmResourceAdaptor`.
381384

382385
`rapidsmpf_options` is an `Options` object passed to the RapidsMPF `Context`. Defaults
383386
to `None` (uses RapidsMPF defaults).
384387

385388
`executor_options` is forwarded directly to `pl.GPUEngine` as its `executor_options`
386389
argument; user-supplied keys are merged with reserved entries set by `spmd_execution()`.
387-
Any additional keyword arguments to `spmd_execution()` are also forwarded to `pl.GPUEngine`.
390+
391+
`engine_options` is forwarded as keyword arguments to `pl.GPUEngine`. For example,
392+
pass `engine_options={"parquet_options": {"use_rapidsmpf_native": True}}` to enable
393+
native Parquet reads.
388394

389395
Notable `executor_options` keys:
390396

python/cudf_polars/tests/experimental/rapidsmpf/test_spmd.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,9 @@
44

55
from __future__ import annotations
66

7-
from typing import Any
8-
97
import pytest
108
from rapidsmpf.bootstrap import is_running_with_rrun
9+
from rapidsmpf.rmm_resource_adaptor import RmmResourceAdaptor
1110

1211
import polars as pl
1312

@@ -35,39 +34,30 @@ def test_yields_context_and_engine() -> None:
3534

3635
def test_reserved_keys() -> None:
3736
"""executor_options rejects reserved keys."""
38-
for key in ("runtime", "cluster", "spmd"):
37+
for key in ("runtime", "cluster", "spmd_context"):
3938
with (
4039
pytest.raises(TypeError, match="reserved"),
4140
spmd_execution(executor_options={key: "anything"}),
4241
):
4342
pass
4443

4544

46-
def test_engine_kwargs_reserved_keys() -> None:
47-
"""engine_kwargs rejects keys that are set explicitly by spmd_execution."""
45+
def test_engine_options_reserved_keys() -> None:
46+
"""engine_options rejects keys that are set explicitly by spmd_execution."""
4847
for key in ("memory_resource", "executor"):
49-
kwargs: dict[str, Any] = {key: "anything"}
5048
with (
5149
pytest.raises(TypeError, match="reserved"),
52-
spmd_execution(**kwargs),
50+
spmd_execution(engine_options={key: "anything"}),
5351
):
5452
pass
5553

5654

57-
def test_engine_kwargs_parquet_options() -> None:
58-
"""engine_kwargs forwards parquet_options to GPUEngine without error."""
59-
with spmd_execution(parquet_options={}) as (comm, ctx, engine):
55+
def test_engine_options_parquet_options() -> None:
56+
"""engine_options forwards parquet_options to GPUEngine without error."""
57+
with spmd_execution(engine_options={"parquet_options": {}}) as (comm, ctx, engine):
6058
assert isinstance(engine, pl.GPUEngine)
6159

6260

63-
def test_custom_mr() -> None:
64-
"""spmd_execution accepts a custom memory resource."""
65-
mr = rmm.mr.CudaMemoryResource()
66-
with spmd_execution(mr=mr) as (comm, ctx, engine):
67-
result = pl.LazyFrame({"a": [1, 2, 3]}).collect(engine=engine)
68-
assert result.shape == (3, 1)
69-
70-
7161
def test_scan() -> None:
7262
"""Each rank scans its own single-row LazyFrame and gets that row back."""
7363
with spmd_execution() as (comm, ctx, engine):
@@ -158,6 +148,20 @@ def test_allgather_polars_dataframe_empty() -> None:
158148
assert result.dtypes == [pl.Int32, pl.Float64]
159149

160150

151+
def test_mr_wrapped_as_current_inside_context() -> None:
152+
"""Inside spmd_execution the current device resource is RmmResourceAdaptor."""
153+
with spmd_execution() as (comm, ctx, engine):
154+
assert isinstance(rmm.mr.get_current_device_resource(), RmmResourceAdaptor)
155+
156+
157+
def test_mr_restored_after_context() -> None:
158+
"""After spmd_execution exits the original device resource is restored."""
159+
original = rmm.mr.get_current_device_resource()
160+
with spmd_execution() as (comm, ctx, engine):
161+
pass
162+
assert rmm.mr.get_current_device_resource() is original
163+
164+
161165
def test_allgather_polars_dataframe_multi_column() -> None:
162166
"""allgather preserves column names, count, and dtypes for multi-column DataFrames."""
163167
with spmd_execution() as (comm, ctx, _):

0 commit comments

Comments
 (0)