@@ -27,6 +27,8 @@ from cuda.core.experimental._utils.cuda_utils import (
2727)
2828
2929
30+ # TODO: Ideally these would be typed as "cdef object" so they cannot be accessed from Python,
31+ # but keeping them exposed is very convenient for testing purposes.
3032_tls = threading.local()
3133_lock = threading.Lock()
3234cdef bint _is_cuInit = False
@@ -55,7 +57,8 @@ cdef class DeviceProperties:
5557 cdef inline _get_attribute(self , cydriver.CUdevice_attribute attr):
5658 """ Retrieve the attribute value directly from the driver."""
5759 cdef int val
58- HANDLE_RETURN(cydriver.cuDeviceGetAttribute(& val, attr, self ._handle))
60+ with nogil:
61+ HANDLE_RETURN(cydriver.cuDeviceGetAttribute(& val, attr, self ._handle))
5962 return val
6063
6164 cdef _get_cached_attribute(self , attr):
@@ -912,7 +915,8 @@ cdef cydriver.CUcontext _get_primary_context(int dev_id) except?NULL:
912915 primary_ctxs = _tls.primary_ctxs = [0 ] * total
913916 cdef cydriver.CUcontext ctx = < cydriver.CUcontext>< uintptr_t> (primary_ctxs[dev_id])
914917 if ctx == NULL:
915- HANDLE_RETURN(cydriver.cuDevicePrimaryCtxRetain(&ctx , dev_id ))
918+ with nogil:
919+ HANDLE_RETURN(cydriver.cuDevicePrimaryCtxRetain(&ctx , dev_id ))
916920 primary_ctxs[dev_id] = <uintptr_t>(ctx )
917921 return ctx
918922
@@ -948,19 +952,21 @@ class Device:
948952 def __new__(cls , device_id: Optional[int] = None ):
949953 global _is_cuInit
950954 if _is_cuInit is False :
951- with _lock:
955+ with _lock, nogil :
952956 HANDLE_RETURN(cydriver.cuInit(0 ))
953957 _is_cuInit = True
954958
955959 # important: creating a Device instance does not initialize the GPU!
956960 cdef cydriver.CUdevice dev
957961 cdef cydriver.CUcontext ctx
958962 if device_id is None :
959- err = cydriver.cuCtxGetDevice(& dev)
963+ with nogil:
964+ err = cydriver.cuCtxGetDevice(& dev)
960965 if err == cydriver.CUresult.CUDA_SUCCESS:
961966 device_id = int (dev)
962967 elif err == cydriver.CUresult.CUDA_ERROR_INVALID_CONTEXT:
963- HANDLE_RETURN(cydriver.cuCtxGetCurrent(& ctx))
968+ with nogil:
969+ HANDLE_RETURN(cydriver.cuCtxGetCurrent(& ctx))
964970 assert < void * > (ctx) == NULL
965971 device_id = 0 # cudart behavior
966972 else :
@@ -973,18 +979,20 @@ class Device:
973979 try :
974980 devices = _tls.devices
975981 except AttributeError :
976- HANDLE_RETURN(cydriver.cuDeviceGetCount(& total))
982+ with nogil:
983+ HANDLE_RETURN(cydriver.cuDeviceGetCount(& total))
977984 devices = _tls.devices = []
978985 for dev_id in range (total):
979986 device = super ().__new__(cls )
980987 device._id = dev_id
981988 # If the device is in TCC mode, or does not support memory pools for some other reason,
982989 # use the SynchronousMemoryResource which does not use memory pools.
983- HANDLE_RETURN(
984- cydriver.cuDeviceGetAttribute(
985- & attr, cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
990+ with nogil:
991+ HANDLE_RETURN(
992+ cydriver.cuDeviceGetAttribute(
993+ & attr, cydriver.CUdevice_attribute.CU_DEVICE_ATTRIBUTE_MEMORY_POOLS_SUPPORTED, dev_id
994+ )
986995 )
987- )
988996 if attr == 1 :
989997 device._mr = DeviceMemoryResource(dev_id)
990998 else :
@@ -1005,16 +1013,18 @@ class Device:
10051013 f" Device {self._id} is not yet initialized, perhaps you forgot to call .set_current() first?"
10061014 )
10071015
1008- def _get_current_context (self , check_consistency = False ) -> driver.CUcontext:
1016+ def _get_current_context (self , bint check_consistency = False ) -> driver.CUcontext:
10091017 cdef cydriver.CUcontext ctx
1010- HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx ))
1011- if ctx == NULL:
1012- raise CUDAError("No context is bound to the calling CPU thread.")
10131018 cdef cydriver.CUdevice dev
1014- if check_consistency:
1015- HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev ))
1016- if <int>(dev ) != self._id:
1017- raise CUDAError("Internal error (current device is not equal to Device.device_id )")
1019+ cdef cydriver.CUdevice this_dev = self ._id
1020+ with nogil:
1021+ HANDLE_RETURN(cydriver.cuCtxGetCurrent(&ctx ))
1022+ if ctx == NULL:
1023+ raise CUDAError("No context is bound to the calling CPU thread.")
1024+ if check_consistency:
1025+ HANDLE_RETURN(cydriver.cuCtxGetDevice(&dev ))
1026+ if dev != this_dev:
1027+ raise CUDAError("Internal error (current device is not equal to Device.device_id )")
10181028 return driver.CUcontext(<uintptr_t>ctx )
10191029
10201030 @property
@@ -1043,10 +1053,12 @@ class Device:
10431053
10441054 """
10451055 cdef cydriver.CUuuid uuid
1046- IF CUDA_CORE_BUILD_MAJOR == "12":
1047- HANDLE_RETURN(cydriver.cuDeviceGetUuid_v2(&uuid , self._id ))
1048- ELSE: # 13.0+
1049- HANDLE_RETURN(cydriver.cuDeviceGetUuid(&uuid , self._id ))
1056+ cdef cydriver.CUdevice this_dev = self ._id
1057+ with nogil:
1058+ IF CUDA_CORE_BUILD_MAJOR == "12":
1059+ HANDLE_RETURN(cydriver.cuDeviceGetUuid_v2(&uuid , this_dev ))
1060+ ELSE: # 13.0+
1061+ HANDLE_RETURN(cydriver.cuDeviceGetUuid(&uuid , this_dev ))
10501062 cdef bytes uuid_b = cpython.PyBytes_FromStringAndSize(uuid.bytes, sizeof(uuid.bytes))
10511063 cdef str uuid_hex = uuid_b.hex()
10521064 # 8-4-4-4-12
@@ -1058,7 +1070,10 @@ class Device:
10581070 # Use 256 characters to be consistent with CUDA Runtime
10591071 cdef int LENGTH = 256
10601072 cdef bytes name = bytes(LENGTH)
1061- HANDLE_RETURN(cydriver.cuDeviceGetName(<char*>name , LENGTH , self._id ))
1073+ cdef char* name_ptr = name
1074+ cdef cydriver.CUdevice this_dev = self ._id
1075+ with nogil:
1076+ HANDLE_RETURN(cydriver.cuDeviceGetName(name_ptr , LENGTH , this_dev ))
10621077 name = name.split(b" \0" )[0 ]
10631078 return name.decode()
10641079
@@ -1161,7 +1176,8 @@ class Device:
11611176 >>> # ... do work on device 0 ...
11621177
11631178 """
1164- cdef cydriver.CUcontext _ctx
1179+ cdef cydriver.CUcontext prev_ctx
1180+ cdef cydriver.CUcontext curr_ctx
11651181 if ctx is not None:
11661182 # TODO: revisit once Context is cythonized
11671183 assert_type(ctx , Context )
@@ -1170,16 +1186,19 @@ class Device:
11701186 "the provided context was created on the device with"
11711187 f" id = {ctx._id}, which is different from the target id = {self ._id}"
11721188 )
1173- # _ctx is the previous context
1174- HANDLE_RETURN(cydriver.cuCtxPopCurrent(&_ctx ))
1175- HANDLE_RETURN(cydriver.cuCtxPushCurrent(<cydriver.CUcontext>(ctx._handle )))
1189+ # prev_ctx is the previous context
1190+ curr_ctx = < cydriver.CUcontext> (ctx._handle)
1191+ with nogil:
1192+ HANDLE_RETURN(cydriver.cuCtxPopCurrent(&prev_ctx ))
1193+ HANDLE_RETURN(cydriver.cuCtxPushCurrent(curr_ctx ))
11761194 self._has_inited = True
1177- if _ctx != NULL:
1178- return Context._from_ctx(<uintptr_t>(_ctx ), self._id )
1195+ if prev_ctx != NULL:
1196+ return Context._from_ctx(<uintptr_t>(prev_ctx ), self._id )
11791197 else:
11801198 # use primary ctx
1181- _ctx = _get_primary_context(self ._id)
1182- HANDLE_RETURN(cydriver.cuCtxSetCurrent(_ctx ))
1199+ curr_ctx = _get_primary_context(self ._id)
1200+ with nogil:
1201+ HANDLE_RETURN(cydriver.cuCtxSetCurrent(curr_ctx ))
11831202 self._has_inited = True
11841203
11851204 def create_context(self , options: ContextOptions = None ) -> Context:
0 commit comments