Skip to content

Commit b611a87

Browse files
committed
Sync device before host access to managed buffers
Guard host-side memset/memcmp in test helpers on CMA=0 by syncing the device before touching managed allocations. Made-with: Cursor
1 parent 3ed5217 commit b611a87

1 file changed

Lines changed: 29 additions & 0 deletions

File tree

cuda_core/tests/helpers/buffers.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,30 @@
1818
]
1919

2020

21+
def _is_managed_ptr(ptr) -> bool:
22+
try:
23+
attr = driver.CUpointer_attribute.CU_POINTER_ATTRIBUTE_IS_MANAGED
24+
return bool(handle_return(driver.cuPointerGetAttribute(attr, ptr)))
25+
except Exception:
26+
return False
27+
28+
29+
def _sync_for_host_managed_access(buffer) -> None:
30+
if not _is_managed_ptr(buffer.handle):
31+
return
32+
device = getattr(buffer.memory_resource, "device", None)
33+
if device is None:
34+
try:
35+
device = Device(int(buffer.device_id))
36+
except Exception:
37+
return
38+
try:
39+
if not device.properties.concurrent_managed_access:
40+
device.sync()
41+
except AttributeError:
42+
return
43+
44+
2145
class DummyUnifiedMemoryResource(MemoryResource):
2246
def __init__(self, device):
2347
self.device = device
@@ -112,6 +136,7 @@ def verify_buffer(self, buffer, seed=None, value=None):
112136
ptr_expected = self._ptr(pattern_buffer)
113137
scratch_buffer.copy_from(buffer, stream=self.stream)
114138
self.sync_target.sync()
139+
_sync_for_host_managed_access(scratch_buffer)
115140
assert libc.memcmp(ptr_test, ptr_expected, self.size) == 0
116141

117142
@staticmethod
@@ -132,6 +157,7 @@ def _get_pattern_buffer(self, seed, value):
132157
else:
133158
pattern_buffer = DummyUnifiedMemoryResource(self.device).allocate(self.size)
134159
ptr = self._ptr(pattern_buffer)
160+
_sync_for_host_managed_access(pattern_buffer)
135161
for i in range(self.size):
136162
ptr[i] = (seed + i) & 0xFF
137163
self.pattern_buffers[key] = pattern_buffer
@@ -148,11 +174,14 @@ def make_scratch_buffer(device, value, nbytes):
148174
def set_buffer(buffer, value):
149175
assert 0 <= int(value) < 256
150176
ptr = ctypes.cast(int(buffer.handle), ctypes.POINTER(ctypes.c_byte))
177+
_sync_for_host_managed_access(buffer)
151178
ctypes.memset(ptr, value & 0xFF, buffer.size)
152179

153180

154181
def compare_equal_buffers(buffer1, buffer2):
155182
"""Compare the contents of two host-accessible buffers for bitwise equality."""
183+
_sync_for_host_managed_access(buffer1)
184+
_sync_for_host_managed_access(buffer2)
156185
if buffer1.size != buffer2.size:
157186
return False
158187
ptr1 = ctypes.cast(int(buffer1.handle), ctypes.POINTER(ctypes.c_byte))

0 commit comments

Comments
 (0)