Skip to content

Commit e54cb5b

Browse files
committed
Moves AllocationHandle serialization to a registration with multiprocessing, since it depends on DupFd.
1 parent 5dda196 commit e54cb5b

File tree

2 files changed

+16
-14
lines changed

2 files changed

+16
-14
lines changed

cuda_core/cuda/core/experimental/_device.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1337,15 +1337,15 @@ def create_graph_builder(self) -> GraphBuilder:
13371337
return GraphBuilder._init(stream=self.create_stream(), is_stream_owner=True)
13381338

13391339

1340+
def _reduce_device(device):
1341+
return _reconstruct_device, (device.device_id,)
1342+
1343+
13401344
def _reconstruct_device(device_id):
13411345
device = Device(device_id)
13421346
if not device._has_inited:
13431347
device.set_current()
13441348
return device
13451349

13461350

1347-
def _reduce_device(device):
1348-
return _reconstruct_device, (device.device_id,)
1349-
1350-
13511351
multiprocessing.reduction.register(Device, _reduce_device)

cuda_core/cuda/core/experimental/_memory.pyx

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ import abc
1616
import array
1717
import contextlib
1818
import cython
19+
import multiprocessing
1920
import os
2021
import platform
2122
import sys
@@ -421,16 +422,6 @@ cdef class IPCAllocationHandle:
421422
"""Close the handle."""
422423
self.close()
423424

424-
def __reduce__(self):
425-
import multiprocessing
426-
multiprocessing.context.assert_spawning(self)
427-
df = multiprocessing.reduction.DupFd(self.handle)
428-
return self._reconstruct, (df, self._uuid)
429-
430-
@classmethod
431-
def _reconstruct(cls, df, uuid):
432-
return cls._init(df.detach(), uuid)
433-
434425
def __int__(self) -> int:
435426
if self._handle < 0:
436427
raise ValueError(
@@ -447,6 +438,17 @@ cdef class IPCAllocationHandle:
447438
return self._uuid
448439

449440

441+
def _reduce_allocation_handle(alloc_handle):
442+
df = multiprocessing.reduction.DupFd(alloc_handle.handle)
443+
return _reconstruct_allocation_handle, (type(alloc_handle), df, alloc_handle.uuid)
444+
445+
def _reconstruct_allocation_handle(cls, df, uuid):
446+
return cls._init(df.detach(), uuid)
447+
448+
449+
multiprocessing.reduction.register(IPCAllocationHandle, _reduce_allocation_handle)
450+
451+
450452
@dataclass
451453
cdef class DeviceMemoryResourceOptions:
452454
"""Customizable :obj:`~_memory.DeviceMemoryResource` options.

0 commit comments

Comments
 (0)