Skip to content

Commit a81fd07

Browse files
authored
Improve error message when default pool lacks managed allocation support (#1835)
When ManagedMemoryResource() is called without options on a platform where the default memory pool does not support managed allocations (e.g. WSL2), the error from cuMemGetMemPool is now caught and re-raised as a RuntimeError with actionable guidance. Made-with: Cursor
1 parent fc1ff27 commit a81fd07

3 files changed

Lines changed: 30 additions & 7 deletions

File tree

cuda_core/cuda/core/_memory/_managed_memory_resource.pyx

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ from cuda.core._utils.cuda_utils cimport (
1111
HANDLE_RETURN,
1212
check_or_create_options,
1313
)
14+
from cuda.core._utils.cuda_utils import CUDAError
1415

1516
from dataclasses import dataclass
1617
import threading
@@ -226,12 +227,25 @@ cdef inline _MMR_init(ManagedMemoryResource self, options):
226227
)
227228

228229
if opts is None:
229-
MP_init_current_pool(
230-
self,
231-
loc_type,
232-
loc_id,
233-
cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED,
234-
)
230+
try:
231+
MP_init_current_pool(
232+
self,
233+
loc_type,
234+
loc_id,
235+
cydriver.CUmemAllocationType.CU_MEM_ALLOCATION_TYPE_MANAGED,
236+
)
237+
except CUDAError as e:
238+
if "CUDA_ERROR_NOT_SUPPORTED" in str(e):
239+
from .._device import Device
240+
if not Device().properties.concurrent_managed_access:
241+
raise RuntimeError(
242+
"The default memory pool on this device does not support "
243+
"managed allocations (concurrent managed access is not "
244+
"available). Use "
245+
"ManagedMemoryResource(options=ManagedMemoryResourceOptions(...)) "
246+
"to create a dedicated managed pool."
247+
) from e
248+
raise
235249
else:
236250
MP_init_create_pool(
237251
self,

cuda_core/cuda/core/_memory/_memory_pool.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,9 @@ cdef int MP_init_current_pool(
257257
self._h_pool = create_mempool_handle_ref(pool)
258258
self._mempool_owned = False
259259
ELSE:
260-
raise RuntimeError("not supported")
260+
raise RuntimeError(
261+
"Getting the current memory pool requires CUDA 13.0 or later"
262+
)
261263
return 0
262264

263265

cuda_core/tests/test_managed_memory_warning.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,13 @@ def device_without_concurrent_managed_access(init_cuda):
4444
return device
4545

4646

47+
@requires_cuda_13
48+
def test_default_pool_error_without_concurrent_access(device_without_concurrent_managed_access):
49+
"""ManagedMemoryResource() raises RuntimeError when the default pool doesn't support managed."""
50+
with pytest.raises(RuntimeError, match="does not support managed allocations"):
51+
ManagedMemoryResource()
52+
53+
4754
@requires_cuda_13
4855
def test_warning_emitted(device_without_concurrent_managed_access):
4956
"""ManagedMemoryResource emits a warning when concurrent managed access is unsupported."""

0 commit comments

Comments
 (0)