Skip to content

Commit 25b3db7

Browse files
rparolinclaude
andcommitted
Fix managed memory DLPack device type on buffer-side export paths
Update setup_dl_tensor_device() and Buffer.__dlpack_device__() to emit kDLCUDAManaged for managed memory, closing the gap where the Buffer -> DLPack capsule -> StridedMemoryView path still misclassified managed buffers as kDLCUDAHost. Add cross-reference comments to keep the three classification sites aligned. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 0650f4b commit 25b3db7

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

cuda_core/cuda/core/_dlpack.pyx

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,8 @@ cdef inline int setup_dl_tensor_device(DLTensor* dl_tensor, object buf) except -
9595
device.device_type = _kDLCUDA
9696
device.device_id = buf.device_id
9797
elif buf.is_device_accessible and buf.is_host_accessible:
98-
device.device_type = _kDLCUDAHost
98+
# Keep in sync with Buffer.__dlpack_device__() and _smv_get_dl_device().
99+
device.device_type = _kDLCUDAManaged if buf.is_managed else _kDLCUDAHost
99100
device.device_id = 0
100101
elif not buf.is_device_accessible and buf.is_host_accessible:
101102
device.device_type = _kDLCPU

cuda_core/cuda/core/_memory/_buffer.pyx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -328,7 +328,9 @@ cdef class Buffer:
328328
if d and (not h):
329329
return (DLDeviceType.kDLCUDA, self.device_id)
330330
if d and h:
331-
# TODO: this can also be kDLCUDAManaged, we need more fine-grained checks
331+
# Keep in sync with setup_dl_tensor_device() and _smv_get_dl_device().
332+
if self.is_managed:
333+
return (DLDeviceType.kDLCUDAManaged, 0)
332334
return (DLDeviceType.kDLCUDAHost, 0)
333335
if (not d) and h:
334336
return (DLDeviceType.kDLCPU, 0)

cuda_core/cuda/core/_memoryview.pyx

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -607,6 +607,7 @@ cdef inline int _smv_get_dl_device(
607607
device_type = _kDLCUDA
608608
device_id = buf.device_id
609609
elif d and h:
610+
# Keep in sync with Buffer.__dlpack_device__() and setup_dl_tensor_device().
610611
if buf.is_managed:
611612
device_type = _kDLCUDAManaged
612613
else:

0 commit comments

Comments
 (0)