Skip to content

Commit c9f8c91

Browse files
committed
Updates register function to return registered object. Avoids possible early deregistration.
1 parent 4fb3d47 commit c9f8c91

File tree

1 file changed

+8
-5
lines changed

1 file changed

+8
-5
lines changed

cuda_core/cuda/core/experimental/_memory.pyx

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -666,10 +666,13 @@ class DeviceMemoryResource(MemoryResource):
666666
raise RuntimeError(f"Memory resource {uuid} was not found") from None
667667

668668
def register(self, uuid: uuid.UUID):
669-
if uuid not in _ipc_registry:
670-
assert self._uuid is None or self._uuid == uuid
671-
_ipc_registry[uuid] = self
672-
self._uuid = uuid
669+
existing = _ipc_registry.get(uuid)
670+
if existing is not None:
671+
return existing
672+
assert self._uuid is None or self._uuid == uuid
673+
_ipc_registry[uuid] = self
674+
self._uuid = uuid
675+
return self
673676

674677
def unregister(self):
675678
with contextlib.suppress(KeyError):
@@ -716,7 +719,7 @@ class DeviceMemoryResource(MemoryResource):
716719
raise_if_driver_error(err)
717720
uuid = getattr(alloc_handle, 'uuid', None)
718721
if uuid is not None:
719-
self.register(uuid)
722+
self = self.register(uuid)
720723
return self
721724

722725
def get_allocation_handle(self) -> IPCAllocationHandle:

0 commit comments

Comments
 (0)