Skip to content

Commit 0904a04

Browse files
rparolinclaude
andcommitted
Update tests for managed memory DLPack device classification
- Fix test_buffer_dunder_dlpack_device_success to expect kDLCUDAManaged for unified memory instead of the old buggy kDLCUDAHost. - Fix test_buffer_dlpack_failure_clean_up error message to match the unified classify_dl_device error. - Add test_managed_buffer_dlpack_roundtrip_device_type to cover the Buffer -> DLPack capsule -> StridedMemoryView end-to-end path. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 22c6583 commit 0904a04

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

cuda_core/tests/test_memory.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -556,7 +556,7 @@ def test_buffer_dunder_dlpack():
556556
[
557557
(DummyDeviceMemoryResource, (DLDeviceType.kDLCUDA, 0)),
558558
(DummyHostMemoryResource, (DLDeviceType.kDLCPU, 0)),
559-
(DummyUnifiedMemoryResource, (DLDeviceType.kDLCUDAHost, 0)),
559+
(DummyUnifiedMemoryResource, (DLDeviceType.kDLCUDAManaged, 0)),
560560
(DummyPinnedMemoryResource, (DLDeviceType.kDLCUDAHost, 0)),
561561
],
562562
)
@@ -579,7 +579,7 @@ def test_buffer_dlpack_failure_clean_up():
579579
dummy_mr = NullMemoryResource()
580580
buffer = dummy_mr.allocate(size=1024)
581581
before = sys.getrefcount(buffer)
582-
with pytest.raises(BufferError, match="invalid buffer"):
582+
with pytest.raises(BufferError, match="buffer is neither device-accessible nor host-accessible"):
583583
buffer.__dlpack__()
584584
after = sys.getrefcount(buffer)
585585
# we use the buffer refcount as sentinel for proper clean-up here,
@@ -588,6 +588,23 @@ def test_buffer_dlpack_failure_clean_up():
588588
assert after == before
589589

590590

591+
def test_managed_buffer_dlpack_roundtrip_device_type():
592+
"""Verify that a managed Buffer round-trips through DLPack with kDLCUDAManaged."""
593+
device = Device()
594+
device.set_current()
595+
skip_if_managed_memory_unsupported(device)
596+
mr = DummyUnifiedMemoryResource(device)
597+
buf = mr.allocate(size=1024)
598+
599+
# Buffer-level classification should report managed.
600+
assert buf.__dlpack_device__() == (DLDeviceType.kDLCUDAManaged, 0)
601+
602+
# The end-to-end path: Buffer -> DLPack capsule -> StridedMemoryView
603+
# must preserve kDLCUDAManaged rather than downgrading to kDLCUDAHost.
604+
view = StridedMemoryView.from_any_interface(buf, stream_ptr=-1)
605+
assert view.__dlpack_device__() == (int(DLDeviceType.kDLCUDAManaged), 0)
606+
607+
591608
@pytest.mark.parametrize("use_device_object", [True, False])
592609
def test_device_memory_resource_initialization(use_device_object):
593610
"""Test that DeviceMemoryResource can be initialized successfully.

0 commit comments

Comments
 (0)