Skip to content

Commit 732f1e3

Browse files
rparolin and claude committed
Centralize DLPack device classification into classify_dl_device()
Extract the duplicated device-type mapping logic from Buffer.__dlpack_device__(), setup_dl_tensor_device(), and _smv_get_dl_device() into a single classify_dl_device() function in _dlpack.pyx. All three call sites now delegate to it. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 25b3db7 commit 732f1e3

File tree

3 files changed

+26
-45
lines changed

3 files changed

+26
-45
lines changed

cuda_core/cuda/core/_dlpack.pyx

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -88,21 +88,28 @@ cdef inline int setup_dl_tensor_layout(DLTensor* dl_tensor, object buf) except -
8888
return 0
8989

9090

91+
def classify_dl_device(buf) -> tuple[int, int]:
92+
"""Classify a buffer into a DLPack (device_type, device_id) pair.
93+
94+
``buf`` must expose ``is_device_accessible``, ``is_host_accessible``,
95+
``is_managed``, and ``device_id`` attributes.
96+
"""
97+
cdef bint d = buf.is_device_accessible
98+
cdef bint h = buf.is_host_accessible
99+
if d and not h:
100+
return (_kDLCUDA, buf.device_id)
101+
if d and h:
102+
return (_kDLCUDAManaged if buf.is_managed else _kDLCUDAHost, 0)
103+
if not d and h:
104+
return (_kDLCPU, 0)
105+
raise BufferError("buffer is neither device-accessible nor host-accessible")
106+
107+
91108
cdef inline int setup_dl_tensor_device(DLTensor* dl_tensor, object buf) except -1:
92109
cdef DLDevice* device = &dl_tensor.device
93-
# buf should be a Buffer instance
94-
if buf.is_device_accessible and not buf.is_host_accessible:
95-
device.device_type = _kDLCUDA
96-
device.device_id = buf.device_id
97-
elif buf.is_device_accessible and buf.is_host_accessible:
98-
# Keep in sync with Buffer.__dlpack_device__() and _smv_get_dl_device().
99-
device.device_type = _kDLCUDAManaged if buf.is_managed else _kDLCUDAHost
100-
device.device_id = 0
101-
elif not buf.is_device_accessible and buf.is_host_accessible:
102-
device.device_type = _kDLCPU
103-
device.device_id = 0
104-
else: # not buf.is_device_accessible and not buf.is_host_accessible
105-
raise BufferError("invalid buffer")
110+
dev_type, dev_id = classify_dl_device(buf)
111+
device.device_type = <_DLDeviceType>dev_type
112+
device.device_id = <int32_t>dev_id
106113
return 0
107114

108115

cuda_core/cuda/core/_memory/_buffer.pyx

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ if sys.version_info >= (3, 12):
3434
else:
3535
BufferProtocol = object
3636

37-
from cuda.core._dlpack import DLDeviceType, make_py_capsule
37+
from cuda.core._dlpack import DLDeviceType, classify_dl_device, make_py_capsule
3838
from cuda.core._utils.cuda_utils import driver
3939
from cuda.core._device import Device
4040

@@ -323,18 +323,7 @@ cdef class Buffer:
323323
return capsule
324324

325325
def __dlpack_device__(self) -> tuple[int, int]:
326-
cdef bint d = self.is_device_accessible
327-
cdef bint h = self.is_host_accessible
328-
if d and (not h):
329-
return (DLDeviceType.kDLCUDA, self.device_id)
330-
if d and h:
331-
# Keep in sync with setup_dl_tensor_device() and _smv_get_dl_device().
332-
if self.is_managed:
333-
return (DLDeviceType.kDLCUDAManaged, 0)
334-
return (DLDeviceType.kDLCUDAHost, 0)
335-
if (not d) and h:
336-
return (DLDeviceType.kDLCPU, 0)
337-
raise BufferError("buffer is neither device-accessible nor host-accessible")
326+
return classify_dl_device(self)
338327

339328
def __buffer__(self, flags: int, /) -> memoryview:
340329
# Support for Python-level buffer protocol as per PEP 688.

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 4 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
from ._dlpack cimport *
8+
from ._dlpack import classify_dl_device
89
from libc.stdint cimport intptr_t
910
from cuda.core._layout cimport _StridedLayout, get_strides_ptr
1011
from cuda.core._stream import Stream
@@ -590,8 +591,6 @@ cdef inline int _smv_get_dl_device(
590591
cdef _DLDeviceType device_type
591592
cdef int32_t device_id
592593
cdef object buf
593-
cdef bint d
594-
cdef bint h
595594
if view.dl_tensor != NULL:
596595
device_type = view.dl_tensor.device.device_type
597596
if device_type == _kDLCUDA:
@@ -601,23 +600,9 @@ cdef inline int _smv_get_dl_device(
601600
device_id = 0
602601
elif view.is_device_accessible:
603602
buf = view.get_buffer()
604-
d = buf.is_device_accessible
605-
h = buf.is_host_accessible
606-
if d and (not h):
607-
device_type = _kDLCUDA
608-
device_id = buf.device_id
609-
elif d and h:
610-
# Keep in sync with Buffer.__dlpack_device__() and setup_dl_tensor_device().
611-
if buf.is_managed:
612-
device_type = _kDLCUDAManaged
613-
else:
614-
device_type = _kDLCUDAHost
615-
device_id = 0
616-
elif (not d) and h:
617-
device_type = _kDLCPU
618-
device_id = 0
619-
else:
620-
raise BufferError("buffer is neither device-accessible nor host-accessible")
603+
dev_type, dev_id = classify_dl_device(buf)
604+
device_type = <_DLDeviceType>dev_type
605+
device_id = <int32_t>dev_id
621606
else:
622607
device_type = _kDLCPU
623608
device_id = 0

0 commit comments

Comments (0)