Skip to content

Commit 2e2fc1a

Browse files
committed
Improve #789: Remove cyclical dependency between {driver|runtime} and utils
Rather than having bindings.utils._get_handle.pyx depend on driver and runtime and define the getters there, this flips things so driver and runtime register their own handlers.
1 parent b87c787 commit 2e2fc1a

5 files changed

Lines changed: 252 additions & 235 deletions

File tree

cuda_bindings/cuda/bindings/driver.pyx.in

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ from libc.limits cimport CHAR_MIN
1414
from libcpp.vector cimport vector
1515
from cpython.buffer cimport PyObject_CheckBuffer, PyObject_GetBuffer, PyBuffer_Release, PyBUF_SIMPLE, PyBUF_ANY_CONTIGUOUS
1616
from cpython.bytes cimport PyBytes_FromStringAndSize
17+
from cuda.bindings import utils
1718
import cuda.bindings.driver
1819
from libcpp.map cimport map
1920

@@ -53948,3 +53949,124 @@ def sizeof(objType):
5394853949
if objType == VdpOutputSurface:
5394953950
return sizeof(cydriver.VdpOutputSurface){{endif}}
5395053951
raise TypeError("Unknown type: " + str(objType))
53952+
53953+
def _add_native_handle_getters() -> None:
53954+
_add_cuda_native_handle_getter = utils._add_cuda_native_handle_getter
53955+
{{if 'CUcontext' in found_types}}
53956+
def CUcontext_getter(CUcontext x): return <uintptr_t><void*><cydriver.CUcontext>(x._pvt_ptr[0])
53957+
_add_cuda_native_handle_getter(CUcontext, CUcontext_getter)
53958+
{{endif}}
53959+
{{if 'CUmodule' in found_types}}
53960+
def CUmodule_getter(CUmodule x): return <uintptr_t><void*><cydriver.CUmodule>(x._pvt_ptr[0])
53961+
_add_cuda_native_handle_getter(CUmodule, CUmodule_getter)
53962+
{{endif}}
53963+
{{if 'CUfunction' in found_types}}
53964+
def CUfunction_getter(CUfunction x): return <uintptr_t><void*><cydriver.CUfunction>(x._pvt_ptr[0])
53965+
_add_cuda_native_handle_getter(CUfunction, CUfunction_getter)
53966+
{{endif}}
53967+
{{if 'CUlibrary' in found_types}}
53968+
def CUlibrary_getter(CUlibrary x): return <uintptr_t><void*><cydriver.CUlibrary>(x._pvt_ptr[0])
53969+
_add_cuda_native_handle_getter(CUlibrary, CUlibrary_getter)
53970+
{{endif}}
53971+
{{if 'CUkernel' in found_types}}
53972+
def CUkernel_getter(CUkernel x): return <uintptr_t><void*><cydriver.CUkernel>(x._pvt_ptr[0])
53973+
_add_cuda_native_handle_getter(CUkernel, CUkernel_getter)
53974+
{{endif}}
53975+
{{if 'CUarray' in found_types}}
53976+
def CUarray_getter(CUarray x): return <uintptr_t><void*><cydriver.CUarray>(x._pvt_ptr[0])
53977+
_add_cuda_native_handle_getter(CUarray, CUarray_getter)
53978+
{{endif}}
53979+
{{if 'CUmipmappedArray' in found_types}}
53980+
def CUmipmappedArray_getter(CUmipmappedArray x): return <uintptr_t><void*><cydriver.CUmipmappedArray>(x._pvt_ptr[0])
53981+
_add_cuda_native_handle_getter(CUmipmappedArray, CUmipmappedArray_getter)
53982+
{{endif}}
53983+
{{if 'CUtexref' in found_types}}
53984+
def CUtexref_getter(CUtexref x): return <uintptr_t><void*><cydriver.CUtexref>(x._pvt_ptr[0])
53985+
_add_cuda_native_handle_getter(CUtexref, CUtexref_getter)
53986+
{{endif}}
53987+
{{if 'CUsurfref' in found_types}}
53988+
def CUsurfref_getter(CUsurfref x): return <uintptr_t><void*><cydriver.CUsurfref>(x._pvt_ptr[0])
53989+
_add_cuda_native_handle_getter(CUsurfref, CUsurfref_getter)
53990+
{{endif}}
53991+
{{if 'CUevent' in found_types}}
53992+
def CUevent_getter(CUevent x): return <uintptr_t><void*><cydriver.CUevent>(x._pvt_ptr[0])
53993+
_add_cuda_native_handle_getter(CUevent, CUevent_getter)
53994+
{{endif}}
53995+
{{if 'CUstream' in found_types}}
53996+
def CUstream_getter(CUstream x): return <uintptr_t><void*><cydriver.CUstream>(x._pvt_ptr[0])
53997+
_add_cuda_native_handle_getter(CUstream, CUstream_getter)
53998+
{{endif}}
53999+
{{if 'CUgraphicsResource' in found_types}}
54000+
def CUgraphicsResource_getter(CUgraphicsResource x): return <uintptr_t><void*><cydriver.CUgraphicsResource>(x._pvt_ptr[0])
54001+
_add_cuda_native_handle_getter(CUgraphicsResource, CUgraphicsResource_getter)
54002+
{{endif}}
54003+
{{if 'CUexternalMemory' in found_types}}
54004+
def CUexternalMemory_getter(CUexternalMemory x): return <uintptr_t><void*><cydriver.CUexternalMemory>(x._pvt_ptr[0])
54005+
_add_cuda_native_handle_getter(CUexternalMemory, CUexternalMemory_getter)
54006+
{{endif}}
54007+
{{if 'CUexternalSemaphore' in found_types}}
54008+
def CUexternalSemaphore_getter(CUexternalSemaphore x): return <uintptr_t><void*><cydriver.CUexternalSemaphore>(x._pvt_ptr[0])
54009+
_add_cuda_native_handle_getter(CUexternalSemaphore, CUexternalSemaphore_getter)
54010+
{{endif}}
54011+
{{if 'CUgraph' in found_types}}
54012+
def CUgraph_getter(CUgraph x): return <uintptr_t><void*><cydriver.CUgraph>(x._pvt_ptr[0])
54013+
_add_cuda_native_handle_getter(CUgraph, CUgraph_getter)
54014+
{{endif}}
54015+
{{if 'CUgraphNode' in found_types}}
54016+
def CUgraphNode_getter(CUgraphNode x): return <uintptr_t><void*><cydriver.CUgraphNode>(x._pvt_ptr[0])
54017+
_add_cuda_native_handle_getter(CUgraphNode, CUgraphNode_getter)
54018+
{{endif}}
54019+
{{if 'CUgraphExec' in found_types}}
54020+
def CUgraphExec_getter(CUgraphExec x): return <uintptr_t><void*><cydriver.CUgraphExec>(x._pvt_ptr[0])
54021+
_add_cuda_native_handle_getter(CUgraphExec, CUgraphExec_getter)
54022+
{{endif}}
54023+
{{if 'CUmemoryPool' in found_types}}
54024+
def CUmemoryPool_getter(CUmemoryPool x): return <uintptr_t><void*><cydriver.CUmemoryPool>(x._pvt_ptr[0])
54025+
_add_cuda_native_handle_getter(CUmemoryPool, CUmemoryPool_getter)
54026+
{{endif}}
54027+
{{if 'CUuserObject' in found_types}}
54028+
def CUuserObject_getter(CUuserObject x): return <uintptr_t><void*><cydriver.CUuserObject>(x._pvt_ptr[0])
54029+
_add_cuda_native_handle_getter(CUuserObject, CUuserObject_getter)
54030+
{{endif}}
54031+
{{if 'CUgraphDeviceNode' in found_types}}
54032+
def CUgraphDeviceNode_getter(CUgraphDeviceNode x): return <uintptr_t><void*><cydriver.CUgraphDeviceNode>(x._pvt_ptr[0])
54033+
_add_cuda_native_handle_getter(CUgraphDeviceNode, CUgraphDeviceNode_getter)
54034+
{{endif}}
54035+
{{if 'CUasyncCallbackHandle' in found_types}}
54036+
def CUasyncCallbackHandle_getter(CUasyncCallbackHandle x): return <uintptr_t><void*><cydriver.CUasyncCallbackHandle>(x._pvt_ptr[0])
54037+
_add_cuda_native_handle_getter(CUasyncCallbackHandle, CUasyncCallbackHandle_getter)
54038+
{{endif}}
54039+
{{if 'CUgreenCtx' in found_types}}
54040+
def CUgreenCtx_getter(CUgreenCtx x): return <uintptr_t><void*><cydriver.CUgreenCtx>(x._pvt_ptr[0])
54041+
_add_cuda_native_handle_getter(CUgreenCtx, CUgreenCtx_getter)
54042+
{{endif}}
54043+
{{if 'CUlinkState' in found_types}}
54044+
def CUlinkState_getter(CUlinkState x): return <uintptr_t><void*><cydriver.CUlinkState>(x._pvt_ptr[0])
54045+
_add_cuda_native_handle_getter(CUlinkState, CUlinkState_getter)
54046+
{{endif}}
54047+
{{if 'CUdevResourceDesc' in found_types}}
54048+
def CUdevResourceDesc_getter(CUdevResourceDesc x): return <uintptr_t><void*><cydriver.CUdevResourceDesc>(x._pvt_ptr[0])
54049+
_add_cuda_native_handle_getter(CUdevResourceDesc, CUdevResourceDesc_getter)
54050+
{{endif}}
54051+
{{if 'CUlogsCallbackHandle' in found_types}}
54052+
def CUlogsCallbackHandle_getter(CUlogsCallbackHandle x): return <uintptr_t><void*><cydriver.CUlogsCallbackHandle>(x._pvt_ptr[0])
54053+
_add_cuda_native_handle_getter(CUlogsCallbackHandle, CUlogsCallbackHandle_getter)
54054+
{{endif}}
54055+
{{if True}}
54056+
def CUeglStreamConnection_getter(CUeglStreamConnection x): return <uintptr_t><void*><cydriver.CUeglStreamConnection>(x._pvt_ptr[0])
54057+
_add_cuda_native_handle_getter(CUeglStreamConnection, CUeglStreamConnection_getter)
54058+
{{endif}}
54059+
{{if True}}
54060+
def EGLImageKHR_getter(EGLImageKHR x): return <uintptr_t><void*><cydriver.EGLImageKHR>(x._pvt_ptr[0])
54061+
_add_cuda_native_handle_getter(EGLImageKHR, EGLImageKHR_getter)
54062+
{{endif}}
54063+
{{if True}}
54064+
def EGLStreamKHR_getter(EGLStreamKHR x): return <uintptr_t><void*><cydriver.EGLStreamKHR>(x._pvt_ptr[0])
54065+
_add_cuda_native_handle_getter(EGLStreamKHR, EGLStreamKHR_getter)
54066+
{{endif}}
54067+
{{if True}}
54068+
def EGLSyncKHR_getter(EGLSyncKHR x): return <uintptr_t><void*><cydriver.EGLSyncKHR>(x._pvt_ptr[0])
54069+
_add_cuda_native_handle_getter(EGLSyncKHR, EGLSyncKHR_getter)
54070+
{{endif}}
54071+
_add_native_handle_getters()
54072+

cuda_bindings/cuda/bindings/nvrtc.pyx.in

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ from libc.limits cimport CHAR_MIN
1414
from libcpp.vector cimport vector
1515
from cpython.buffer cimport PyObject_CheckBuffer, PyObject_GetBuffer, PyBuffer_Release, PyBUF_SIMPLE, PyBUF_ANY_CONTIGUOUS
1616
from cpython.bytes cimport PyBytes_FromStringAndSize
17+
from cuda.bindings import utils
1718

1819
ctypedef unsigned long long signed_char_ptr
1920
ctypedef unsigned long long unsigned_char_ptr

cuda_bindings/cuda/bindings/runtime.pyx.in

Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ from libc.limits cimport CHAR_MIN
1414
from libcpp.vector cimport vector
1515
from cpython.buffer cimport PyObject_CheckBuffer, PyObject_GetBuffer, PyBuffer_Release, PyBUF_SIMPLE, PyBUF_ANY_CONTIGUOUS
1616
from cpython.bytes cimport PyBytes_FromStringAndSize
17+
from cuda.bindings import utils
1718
import cuda.bindings.driver
1819
from libcpp.map cimport map
1920

@@ -37912,3 +37913,104 @@ def sizeof(objType):
3791237913
if objType == cudaEglStreamConnection:
3791337914
return sizeof(cyruntime.cudaEglStreamConnection){{endif}}
3791437915
raise TypeError("Unknown type: " + str(objType))
37916+
37917+
def _add_native_handle_getters() -> None:
37918+
_add_cuda_native_handle_getter = utils._add_cuda_native_handle_getter
37919+
{{if 'cudaArray_t' in found_types}}
37920+
def cudaArray_t_getter(cudaArray_t x): return <uintptr_t><void*><cyruntime.cudaArray_t>(x._pvt_ptr[0])
37921+
_add_cuda_native_handle_getter(cudaArray_t, cudaArray_t_getter)
37922+
{{endif}}
37923+
{{if 'cudaArray_const_t' in found_types}}
37924+
def cudaArray_const_t_getter(cudaArray_const_t x): return <uintptr_t><void*><cyruntime.cudaArray_const_t>(x._pvt_ptr[0])
37925+
_add_cuda_native_handle_getter(cudaArray_const_t, cudaArray_const_t_getter)
37926+
{{endif}}
37927+
{{if 'cudaMipmappedArray_t' in found_types}}
37928+
def cudaMipmappedArray_t_getter(cudaMipmappedArray_t x): return <uintptr_t><void*><cyruntime.cudaMipmappedArray_t>(x._pvt_ptr[0])
37929+
_add_cuda_native_handle_getter(cudaMipmappedArray_t, cudaMipmappedArray_t_getter)
37930+
{{endif}}
37931+
{{if 'cudaMipmappedArray_const_t' in found_types}}
37932+
def cudaMipmappedArray_const_t_getter(cudaMipmappedArray_const_t x): return <uintptr_t><void*><cyruntime.cudaMipmappedArray_const_t>(x._pvt_ptr[0])
37933+
_add_cuda_native_handle_getter(cudaMipmappedArray_const_t, cudaMipmappedArray_const_t_getter)
37934+
{{endif}}
37935+
{{if 'cudaStream_t' in found_types}}
37936+
def cudaStream_t_getter(cudaStream_t x): return <uintptr_t><void*><cyruntime.cudaStream_t>(x._pvt_ptr[0])
37937+
_add_cuda_native_handle_getter(cudaStream_t, cudaStream_t_getter)
37938+
{{endif}}
37939+
{{if 'cudaEvent_t' in found_types}}
37940+
def cudaEvent_t_getter(cudaEvent_t x): return <uintptr_t><void*><cyruntime.cudaEvent_t>(x._pvt_ptr[0])
37941+
_add_cuda_native_handle_getter(cudaEvent_t, cudaEvent_t_getter)
37942+
{{endif}}
37943+
{{if 'cudaGraphicsResource_t' in found_types}}
37944+
def cudaGraphicsResource_t_getter(cudaGraphicsResource_t x): return <uintptr_t><void*><cyruntime.cudaGraphicsResource_t>(x._pvt_ptr[0])
37945+
_add_cuda_native_handle_getter(cudaGraphicsResource_t, cudaGraphicsResource_t_getter)
37946+
{{endif}}
37947+
{{if 'cudaExternalMemory_t' in found_types}}
37948+
def cudaExternalMemory_t_getter(cudaExternalMemory_t x): return <uintptr_t><void*><cyruntime.cudaExternalMemory_t>(x._pvt_ptr[0])
37949+
_add_cuda_native_handle_getter(cudaExternalMemory_t, cudaExternalMemory_t_getter)
37950+
{{endif}}
37951+
{{if 'cudaExternalSemaphore_t' in found_types}}
37952+
def cudaExternalSemaphore_t_getter(cudaExternalSemaphore_t x): return <uintptr_t><void*><cyruntime.cudaExternalSemaphore_t>(x._pvt_ptr[0])
37953+
_add_cuda_native_handle_getter(cudaExternalSemaphore_t, cudaExternalSemaphore_t_getter)
37954+
{{endif}}
37955+
{{if 'cudaGraph_t' in found_types}}
37956+
def cudaGraph_t_getter(cudaGraph_t x): return <uintptr_t><void*><cyruntime.cudaGraph_t>(x._pvt_ptr[0])
37957+
_add_cuda_native_handle_getter(cudaGraph_t, cudaGraph_t_getter)
37958+
{{endif}}
37959+
{{if 'cudaGraphNode_t' in found_types}}
37960+
def cudaGraphNode_t_getter(cudaGraphNode_t x): return <uintptr_t><void*><cyruntime.cudaGraphNode_t>(x._pvt_ptr[0])
37961+
_add_cuda_native_handle_getter(cudaGraphNode_t, cudaGraphNode_t_getter)
37962+
{{endif}}
37963+
{{if 'cudaUserObject_t' in found_types}}
37964+
def cudaUserObject_t_getter(cudaUserObject_t x): return <uintptr_t><void*><cyruntime.cudaUserObject_t>(x._pvt_ptr[0])
37965+
_add_cuda_native_handle_getter(cudaUserObject_t, cudaUserObject_t_getter)
37966+
{{endif}}
37967+
{{if 'cudaFunction_t' in found_types}}
37968+
def cudaFunction_t_getter(cudaFunction_t x): return <uintptr_t><void*><cyruntime.cudaFunction_t>(x._pvt_ptr[0])
37969+
_add_cuda_native_handle_getter(cudaFunction_t, cudaFunction_t_getter)
37970+
{{endif}}
37971+
{{if 'cudaKernel_t' in found_types}}
37972+
def cudaKernel_t_getter(cudaKernel_t x): return <uintptr_t><void*><cyruntime.cudaKernel_t>(x._pvt_ptr[0])
37973+
_add_cuda_native_handle_getter(cudaKernel_t, cudaKernel_t_getter)
37974+
{{endif}}
37975+
{{if 'cudaLibrary_t' in found_types}}
37976+
def cudaLibrary_t_getter(cudaLibrary_t x): return <uintptr_t><void*><cyruntime.cudaLibrary_t>(x._pvt_ptr[0])
37977+
_add_cuda_native_handle_getter(cudaLibrary_t, cudaLibrary_t_getter)
37978+
{{endif}}
37979+
{{if 'cudaMemPool_t' in found_types}}
37980+
def cudaMemPool_t_getter(cudaMemPool_t x): return <uintptr_t><void*><cyruntime.cudaMemPool_t>(x._pvt_ptr[0])
37981+
_add_cuda_native_handle_getter(cudaMemPool_t, cudaMemPool_t_getter)
37982+
{{endif}}
37983+
{{if 'cudaGraphExec_t' in found_types}}
37984+
def cudaGraphExec_t_getter(cudaGraphExec_t x): return <uintptr_t><void*><cyruntime.cudaGraphExec_t>(x._pvt_ptr[0])
37985+
_add_cuda_native_handle_getter(cudaGraphExec_t, cudaGraphExec_t_getter)
37986+
{{endif}}
37987+
{{if 'cudaGraphDeviceNode_t' in found_types}}
37988+
def cudaGraphDeviceNode_t_getter(cudaGraphDeviceNode_t x): return <uintptr_t><void*><cyruntime.cudaGraphDeviceNode_t>(x._pvt_ptr[0])
37989+
_add_cuda_native_handle_getter(cudaGraphDeviceNode_t, cudaGraphDeviceNode_t_getter)
37990+
{{endif}}
37991+
{{if 'cudaAsyncCallbackHandle_t' in found_types}}
37992+
def cudaAsyncCallbackHandle_t_getter(cudaAsyncCallbackHandle_t x): return <uintptr_t><void*><cyruntime.cudaAsyncCallbackHandle_t>(x._pvt_ptr[0])
37993+
_add_cuda_native_handle_getter(cudaAsyncCallbackHandle_t, cudaAsyncCallbackHandle_t_getter)
37994+
{{endif}}
37995+
{{if 'cudaLogsCallbackHandle' in found_types}}
37996+
def cudaLogsCallbackHandle_getter(cudaLogsCallbackHandle x): return <uintptr_t><void*><cyruntime.cudaLogsCallbackHandle>(x._pvt_ptr[0])
37997+
_add_cuda_native_handle_getter(cudaLogsCallbackHandle, cudaLogsCallbackHandle_getter)
37998+
{{endif}}
37999+
{{if True}}
38000+
def EGLImageKHR_getter(EGLImageKHR x): return <uintptr_t><void*><cyruntime.EGLImageKHR>(x._pvt_ptr[0])
38001+
_add_cuda_native_handle_getter(EGLImageKHR, EGLImageKHR_getter)
38002+
{{endif}}
38003+
{{if True}}
38004+
def EGLStreamKHR_getter(EGLStreamKHR x): return <uintptr_t><void*><cyruntime.EGLStreamKHR>(x._pvt_ptr[0])
38005+
_add_cuda_native_handle_getter(EGLStreamKHR, EGLStreamKHR_getter)
38006+
{{endif}}
38007+
{{if True}}
38008+
def EGLSyncKHR_getter(EGLSyncKHR x): return <uintptr_t><void*><cyruntime.EGLSyncKHR>(x._pvt_ptr[0])
38009+
_add_cuda_native_handle_getter(EGLSyncKHR, EGLSyncKHR_getter)
38010+
{{endif}}
38011+
{{if True}}
38012+
def cudaEglStreamConnection_getter(cudaEglStreamConnection x): return <uintptr_t><void*><cyruntime.cudaEglStreamConnection>(x._pvt_ptr[0])
38013+
_add_cuda_native_handle_getter(cudaEglStreamConnection, cudaEglStreamConnection_getter)
38014+
{{endif}}
38015+
_add_native_handle_getters()
38016+
Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,31 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4-
from ._get_handle import get_cuda_native_handle
4+
from typing import Any, Callable
5+
56
from ._ptx_utils import get_minimal_required_cuda_ver_from_ptx_ver, get_ptx_ver
7+
8+
_handle_getters: dict[type, Callable[[Any], int]] = {}
9+
10+
11+
def _add_cuda_native_handle_getter(t: type, getter: Callable[[Any], int]) -> None:
12+
_handle_getters[t] = getter
13+
14+
15+
def get_cuda_native_handle(obj: Any) -> int:
16+
"""Returns the address of the provided CUDA Python object as a Python int.
17+
18+
Parameters
19+
----------
20+
obj : Any
21+
CUDA Python object
22+
23+
Returns
24+
-------
25+
int : The object address.
26+
"""
27+
obj_type = type(obj)
28+
try:
29+
return _handle_getters[obj_type](obj)
30+
except KeyError:
31+
raise TypeError("Unknown type: " + str(obj_type)) from None

0 commit comments

Comments
 (0)