1818]
1919
2020
21+ def _is_managed_ptr (ptr ) -> bool :
22+ try :
23+ attr = driver .CUpointer_attribute .CU_POINTER_ATTRIBUTE_IS_MANAGED
24+ return bool (handle_return (driver .cuPointerGetAttribute (attr , ptr )))
25+ except Exception :
26+ return False
27+
28+
29+ def _sync_for_host_managed_access (buffer ) -> None :
30+ if not _is_managed_ptr (buffer .handle ):
31+ return
32+ device = getattr (buffer .memory_resource , "device" , None )
33+ if device is None :
34+ try :
35+ device = Device (int (buffer .device_id ))
36+ except Exception :
37+ return
38+ try :
39+ if not device .properties .concurrent_managed_access :
40+ device .sync ()
41+ except AttributeError :
42+ return
43+
44+
2145class DummyUnifiedMemoryResource (MemoryResource ):
2246 def __init__ (self , device ):
2347 self .device = device
@@ -112,6 +136,7 @@ def verify_buffer(self, buffer, seed=None, value=None):
112136 ptr_expected = self ._ptr (pattern_buffer )
113137 scratch_buffer .copy_from (buffer , stream = self .stream )
114138 self .sync_target .sync ()
139+ _sync_for_host_managed_access (scratch_buffer )
115140 assert libc .memcmp (ptr_test , ptr_expected , self .size ) == 0
116141
117142 @staticmethod
@@ -132,6 +157,7 @@ def _get_pattern_buffer(self, seed, value):
132157 else :
133158 pattern_buffer = DummyUnifiedMemoryResource (self .device ).allocate (self .size )
134159 ptr = self ._ptr (pattern_buffer )
160+ _sync_for_host_managed_access (pattern_buffer )
135161 for i in range (self .size ):
136162 ptr [i ] = (seed + i ) & 0xFF
137163 self .pattern_buffers [key ] = pattern_buffer
@@ -148,11 +174,14 @@ def make_scratch_buffer(device, value, nbytes):
148174def set_buffer (buffer , value ):
149175 assert 0 <= int (value ) < 256
150176 ptr = ctypes .cast (int (buffer .handle ), ctypes .POINTER (ctypes .c_byte ))
177+ _sync_for_host_managed_access (buffer )
151178 ctypes .memset (ptr , value & 0xFF , buffer .size )
152179
153180
154181def compare_equal_buffers (buffer1 , buffer2 ):
155182 """Compare the contents of two host-accessible buffers for bitwise equality."""
183+ _sync_for_host_managed_access (buffer1 )
184+ _sync_for_host_managed_access (buffer2 )
156185 if buffer1 .size != buffer2 .size :
157186 return False
158187 ptr1 = ctypes .cast (int (buffer1 .handle ), ctypes .POINTER (ctypes .c_byte ))
0 commit comments