
Commit 73d2d89

cuda.core: convert peer_accessible_by to a live MutableSet view
DeviceMemoryResource.peer_accessible_by previously returned a sorted
tuple[int, ...] backed by a Python-level cache, which was prone to
divergence from driver state across multiple wrappers around the same
memory pool. The setter accepted Device | int and emitted a single
batched cuMemPoolSetAccess covering the diff against the cache.

This commit replaces the property with a live driver-backed view:

- Adds PeerAccessibleBySetProxy in _memory/_peer_access_utils.py, a
  collections.abc.MutableSet whose reads call cuMemPoolGetAccess and whose
  writes call cuMemPoolSetAccess. Iteration yields Device objects; add,
  discard, and __contains__ accept either a Device or a device-ordinal int.
  The proxy is constructed fresh on every property access, so there is
  nothing to cache or pickle.

- Drops the _peer_accessible_by cache field (and its initializations in
  __cinit__, _DMR_init, and from_allocation_handle), eliminating the
  owned/non-owned read split. All pools now share the same code path and
  always query the driver.

- All bulk operations on the proxy (update, |=, &=, -=, ^=, clear, pop)
  issue exactly one cuMemPoolSetAccess call. Peer-access transitions can
  take seconds per pool because every existing memory mapping is updated,
  so coalescing into a single driver call lets the toolkit handle the
  mappings in parallel. The property setter (mr.peer_accessible_by = [...])
  preserves its original single-call behavior via the same shared planner
  path.

- Single-element add validates can_access_peer through
  plan_peer_access_update, matching the existing setter contract.

This is a breaking change captured in the v1.0.0 release notes. Callers
comparing against tuples must update to set comparisons
(mr.peer_accessible_by == {Device(0)}). Existing tests are migrated; new
tests for set-interface conformance are intentionally deferred to a
follow-up.

Co-authored-by: Cursor <cursoragent@cursor.com>
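For orientation, a minimal before/after usage sketch of the API change (the import path and the two-device topology are assumptions for illustration, not taken from this commit):

    from cuda.core.experimental import Device, DeviceMemoryResource  # import path assumed

    mr = DeviceMemoryResource(0)

    # Before this commit: a cached, sorted tuple of ints.
    #     mr.peer_accessible_by = [1]
    #     assert mr.peer_accessible_by == (1,)

    # After this commit: a live MutableSet view backed by the driver.
    mr.peer_accessible_by.add(1)                 # one cuMemPoolSetAccess
    assert Device(1) in mr.peer_accessible_by    # cuMemPoolGetAccess per check
    assert mr.peer_accessible_by == {Device(1)}  # compare against sets, not tuples

    # Bulk mutations are coalesced into a single cuMemPoolSetAccess call.
    mr.peer_accessible_by |= {Device(2), 3}      # accepts Device objects or ints
    mr.peer_accessible_by = []                   # revoke all peer access at once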
1 parent 64e2e6a commit 73d2d89

7 files changed: 347 additions & 89 deletions

cuda_core/cuda/core/_memory/_device_memory_resource.pxd
Lines changed: 0 additions & 1 deletion

@@ -9,7 +9,6 @@ from cuda.core._memory._ipc cimport IPCDataForMR
 cdef class DeviceMemoryResource(_MemPool):
     cdef:
         int _dev_id
-        object _peer_accessible_by


 cpdef DMR_mempool_get_access(DeviceMemoryResource, int)

cuda_core/cuda/core/_memory/_device_memory_resource.pyx
Lines changed: 107 additions & 65 deletions

@@ -25,7 +25,11 @@ import multiprocessing
 import platform  # no-cython-lint
 import uuid

-from ._peer_access_utils import plan_peer_access_update
+from ._peer_access_utils import (
+    PeerAccessibleBySetProxy,
+    _resolve_peer_device_id,
+    plan_peer_access_update,
+)
 from cuda.core._utils.cuda_utils import check_multiprocessing_start_method

 __all__ = ['DeviceMemoryResource', 'DeviceMemoryResourceOptions']
@@ -131,7 +135,6 @@ cdef class DeviceMemoryResource(_MemPool):

     def __cinit__(self, *args, **kwargs):
         self._dev_id = cydriver.CU_DEVICE_INVALID
-        self._peer_accessible_by = None

     def __init__(self, device_id: Device | int, options=None):
         _DMR_init(self, device_id, options)
@@ -191,7 +194,6 @@
             _ipc.MP_from_allocation_handle(cls, alloc_handle))
         from .._device import Device
         mr._dev_id = Device(device_id).device_id
-        mr._peer_accessible_by = ()
         return mr

     @property
@@ -213,34 +215,60 @@
     @property
     def peer_accessible_by(self):
         """
-        Get or set the devices that can access allocations from this memory
-        pool. Access can be modified at any time and affects all allocations
+        Live driver-backed set view of the devices that can access allocations
         from this memory pool.

-        Returns a tuple of sorted device IDs that currently have peer access to
-        allocations from this memory pool.
+        Returns a :class:`PeerAccessibleBySetProxy` (a
+        :class:`collections.abc.MutableSet`) whose reads call
+        ``cuMemPoolGetAccess`` and whose writes call ``cuMemPoolSetAccess``.
+        Iteration yields :class:`Device` objects; ``add``, ``discard``, and
+        ``__contains__`` accept either a :class:`Device` or a device-ordinal
+        ``int``. There is no in-memory cache, so the view always reflects the
+        current driver state and stays consistent across multiple wrappers
+        around the same pool.

-        When setting, accepts a sequence of :obj:`~_device.Device` objects or device IDs.
-        Setting to an empty sequence revokes all peer access.
+        When setting, accepts an iterable of :obj:`~_device.Device` objects or
+        device IDs. Setting replaces the full set in a single batched driver call.

-        For non-owned pools (the default or current device pool), the state
-        is always queried from the driver to reflect changes made by other
-        wrappers or direct driver calls.
+        Bulk operations (``update``, ``|=``, ``&=``, ``-=``, ``^=``, ``clear``,
+        and the property setter) each issue exactly one ``cuMemPoolSetAccess``
+        call so the toolkit can update existing memory mappings in parallel.

         Examples
         --------
         >>> dmr = DeviceMemoryResource(0)
-        >>> dmr.peer_accessible_by = [1]  # Grant access to device 1
-        >>> assert dmr.peer_accessible_by == (1,)
-        >>> dmr.peer_accessible_by = []  # Revoke access
+        >>> dmr.peer_accessible_by.add(1)  # grant access to device 1
+        >>> assert dmr.peer_accessible_by == {Device(1)}
+        >>> dmr.peer_accessible_by |= {Device(2)}  # batched grant via |=
+        >>> dmr.peer_accessible_by = []  # revoke all in one call
         """
-        if not self._mempool_owned:
-            _DMR_query_peer_access(self)
-        return self._peer_accessible_by
+        return PeerAccessibleBySetProxy(self)

     @peer_accessible_by.setter
     def peer_accessible_by(self, devices):
-        _DMR_set_peer_accessible_by(self, devices)
+        _DMR_replace_peer_accessible_by(self, devices)
+
+    def _query_peer_access_ids(self):
+        """Return the current peer device IDs as a sorted tuple of ints.
+
+        Always queries the driver via ``cuMemPoolGetAccess`` for every visible
+        device. Used by :class:`PeerAccessibleBySetProxy` for ``__iter__`` and
+        ``__len__``.
+        """
+        return _DMR_query_peer_access_ids(self)
+
+    def _peer_access_includes(self, int dev_id) -> bool:
+        """Return True if peer access from ``dev_id`` is currently granted."""
+        return _DMR_peer_access_includes(self, dev_id)
+
+    def _apply_peer_access_diff(self, to_add, to_remove):
+        """Issue a single ``cuMemPoolSetAccess`` for the given add/remove deltas.
+
+        ``to_add`` and ``to_remove`` are iterables of device-ordinal ints.
+        Both must already be filtered (no owner, no overlap, no duplicates).
+        Used by :class:`PeerAccessibleBySetProxy` for batched writes.
+        """
+        _DMR_apply_peer_access_diff(self, tuple(to_add), tuple(to_remove))

     @property
     def is_device_accessible(self) -> bool:
@@ -253,8 +281,8 @@
         return False


-cdef inline _DMR_query_peer_access(DeviceMemoryResource self):
-    """Query the driver for the actual peer access state of this pool."""
+cdef inline tuple _DMR_query_peer_access_ids(DeviceMemoryResource self):
+    """Return the current peer device IDs as a sorted tuple of ints."""
     cdef int total
     cdef cydriver.CUmemAccess_flags flags
     cdef cydriver.CUmemLocation location
@@ -273,59 +301,74 @@
         if flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE:
             peers.append(dev_id)

-    self._peer_accessible_by = tuple(sorted(peers))
+    return tuple(sorted(peers))


-cdef inline _DMR_set_peer_accessible_by(DeviceMemoryResource self, devices):
-    from .._device import Device
+cdef inline bint _DMR_peer_access_includes(DeviceMemoryResource self, int dev_id):
+    """Return True if peer access from ``dev_id`` is currently granted."""
+    cdef cydriver.CUmemAccess_flags flags
+    cdef cydriver.CUmemLocation location

-    this_dev = Device(self._dev_id)
-    cdef object resolve_device_id = lambda dev: Device(dev).device_id
-    cdef object plan
-    cdef tuple target_ids
-    cdef tuple to_add
-    cdef tuple to_rm
-    if not self._mempool_owned:
-        _DMR_query_peer_access(self)
-    plan = plan_peer_access_update(
-        owner_device_id=self._dev_id,
-        current_peer_ids=self._peer_accessible_by,
-        requested_devices=devices,
-        resolve_device_id=resolve_device_id,
-        can_access_peer=this_dev.can_access_peer,
-    )
-    target_ids = plan.target_ids
-    to_add = plan.to_add
-    to_rm = plan.to_remove
-    cdef size_t count = len(to_add) + len(to_rm)
+    location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
+    location.id = dev_id
+    with nogil:
+        HANDLE_RETURN(cydriver.cuMemPoolGetAccess(&flags, as_cu(self._h_pool), &location))
+    return flags == cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
+
+
+cdef inline _DMR_apply_peer_access_diff(
+    DeviceMemoryResource self, tuple to_add, tuple to_remove
+):
+    """Issue one ``cuMemPoolSetAccess`` for the given add/remove deltas."""
+    cdef size_t count = len(to_add) + len(to_remove)
     cdef cydriver.CUmemAccessDesc* access_desc = NULL
     cdef size_t i = 0

-    if count > 0:
-        access_desc = <cydriver.CUmemAccessDesc*>PyMem_Malloc(count * sizeof(cydriver.CUmemAccessDesc))
-        if access_desc == NULL:
-            raise MemoryError("Failed to allocate memory for access descriptors")
+    if count == 0:
+        return
+
+    access_desc = <cydriver.CUmemAccessDesc*>PyMem_Malloc(count * sizeof(cydriver.CUmemAccessDesc))
+    if access_desc == NULL:
+        raise MemoryError("Failed to allocate memory for access descriptors")
+
+    try:
+        for dev_id in to_add:
+            access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
+            access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
+            access_desc[i].location.id = dev_id
+            i += 1
+        for dev_id in to_remove:
+            access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE
+            access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
+            access_desc[i].location.id = dev_id
+            i += 1
+
+        with nogil:
+            HANDLE_RETURN(cydriver.cuMemPoolSetAccess(as_cu(self._h_pool), access_desc, count))
+    finally:
+        if access_desc != NULL:
+            PyMem_Free(access_desc)

-    try:
-        for dev_id in to_add:
-            access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_READWRITE
-            access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
-            access_desc[i].location.id = dev_id
-            i += 1

-        for dev_id in to_rm:
-            access_desc[i].flags = cydriver.CUmemAccess_flags.CU_MEM_ACCESS_FLAGS_PROT_NONE
-            access_desc[i].location.type = cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE
-            access_desc[i].location.id = dev_id
-            i += 1
+cdef inline _DMR_replace_peer_accessible_by(DeviceMemoryResource self, devices):
+    """Replace the full peer-access set in a single batched driver call.

-    with nogil:
-        HANDLE_RETURN(cydriver.cuMemPoolSetAccess(as_cu(self._h_pool), access_desc, count))
-    finally:
-        if access_desc != NULL:
-            PyMem_Free(access_desc)
+    Backs the ``mr.peer_accessible_by = [...]`` setter. Uses the same planner
+    as the proxy's bulk ops; the only difference is that adds and removes are
+    derived from the symmetric difference between current driver state and the
+    requested target set.
+    """
+    from .._device import Device

-    self._peer_accessible_by = tuple(target_ids)
+    this_dev = Device(self._dev_id)
+    plan = plan_peer_access_update(
+        owner_device_id=self._dev_id,
+        current_peer_ids=_DMR_query_peer_access_ids(self),
+        requested_devices=devices,
+        resolve_device_id=_resolve_peer_device_id,
+        can_access_peer=this_dev.can_access_peer,
+    )
+    _DMR_apply_peer_access_diff(self, plan.to_add, plan.to_remove)


 cdef inline _DMR_init(DeviceMemoryResource self, device_id, options):
@@ -351,7 +394,6 @@ cdef inline _DMR_init(DeviceMemoryResource self, device_id, options):
         self._mempool_owned = False
         MP_raise_release_threshold(self)
     else:
-        self._peer_accessible_by = ()
         MP_init_create_pool(
             self,
             cydriver.CUmemLocationType.CU_MEM_LOCATION_TYPE_DEVICE,
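
The new proxy lives in _memory/_peer_access_utils.py, which is not part of this excerpt. For orientation only, a pure-Python sketch of the shape described in the commit message, built on the three hooks added above (_query_peer_access_ids, _peer_access_includes, _apply_peer_access_diff); the method bodies here are illustrative assumptions, not the actual implementation:

    from collections.abc import MutableSet


    def _resolve_peer_device_id(device):
        # Sketch of the shared helper: accept a Device or a device-ordinal int.
        return device if isinstance(device, int) else device.device_id


    class PeerAccessibleBySetProxy(MutableSet):
        """Stateless live view: every read and write goes to the driver."""

        def __init__(self, mr):
            self._mr = mr  # constructed fresh on each property access

        def __contains__(self, device):
            return self._mr._peer_access_includes(_resolve_peer_device_id(device))

        def __iter__(self):
            from cuda.core._device import Device  # import path assumed
            return (Device(i) for i in self._mr._query_peer_access_ids())

        def __len__(self):
            return len(self._mr._query_peer_access_ids())

        def add(self, device):
            # The real add also validates can_access_peer via plan_peer_access_update.
            dev_id = _resolve_peer_device_id(device)
            if not self._mr._peer_access_includes(dev_id):
                self._mr._apply_peer_access_diff((dev_id,), ())

        def discard(self, device):
            dev_id = _resolve_peer_device_id(device)
            if self._mr._peer_access_includes(dev_id):
                self._mr._apply_peer_access_diff((), (dev_id,))

        def update(self, *iterables):
            # Bulk ops override MutableSet's element-wise defaults so that the
            # whole delta lands in exactly one cuMemPoolSetAccess call.
            current = set(self._mr._query_peer_access_ids())
            wanted = {_resolve_peer_device_id(d) for it in iterables for d in it}
            self._mr._apply_peer_access_diff(tuple(sorted(wanted - current)), ())

Under this sketch, |=, &=, -=, ^=, clear, and pop would be overridden the same way: compute the add/remove delta against the current driver state, then hand both tuples to _apply_peer_access_diff in one call.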
