Skip to content

Commit 19c50f8

Browse files
committed
Change <void*><uintptr_t>win32api.GetProcAddress back to intptr_t. Changing load_nvidia_dynamic_library() to also use to-intptr_t conversion, for compatibility with win32api.GetProcAddress. Document that CDLL behaves differently (it uses to-uintptr_t`).
1 parent f14d76b commit 19c50f8

3 files changed

Lines changed: 66 additions & 64 deletions

File tree

cuda_bindings/cuda/bindings/_internal/nvjitlink_windows.pyx

Lines changed: 33 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
# This code was automatically generated across versions from 12.0.1 to 12.6.2. Do not modify it directly.
66

7-
from libc.stdint cimport uintptr_t
7+
from libc.stdint cimport intptr_t
88

99
from .utils import FunctionNotFoundError, NotSupportedError
1010

@@ -41,7 +41,7 @@ cdef void* __nvJitLinkVersion = NULL
4141

4242

4343
cdef void* load_library(int driver_ver) except* with gil:
44-
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvJitLink")
44+
cdef intptr_t handle = path_finder.load_nvidia_dynamic_library("nvJitLink")
4545
return <void*>handle
4646

4747

@@ -51,7 +51,7 @@ cdef int _check_or_init_nvjitlink() except -1 nogil:
5151
return 0
5252

5353
cdef int err, driver_ver
54-
cdef uintptr_t handle
54+
cdef intptr_t handle
5555
with gil:
5656
# Load driver to check version
5757
try:
@@ -60,98 +60,98 @@ cdef int _check_or_init_nvjitlink() except -1 nogil:
6060
raise NotSupportedError(f'CUDA driver is not found ({e})')
6161
global __cuDriverGetVersion
6262
if __cuDriverGetVersion == NULL:
63-
__cuDriverGetVersion = <void*><uintptr_t>win32api.GetProcAddress(nvcuda_handle, 'cuDriverGetVersion')
63+
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(nvcuda_handle, 'cuDriverGetVersion')
6464
if __cuDriverGetVersion == NULL:
6565
raise RuntimeError('something went wrong')
6666
err = (<int (*)(int*) nogil>__cuDriverGetVersion)(&driver_ver)
6767
if err != 0:
6868
raise RuntimeError('something went wrong')
6969

7070
# Load library
71-
handle = <uintptr_t>load_library(driver_ver)
71+
handle = <intptr_t>load_library(driver_ver)
7272

7373
# Load function
7474
global __nvJitLinkCreate
7575
try:
76-
__nvJitLinkCreate = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkCreate')
76+
__nvJitLinkCreate = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkCreate')
7777
except:
7878
pass
7979

8080
global __nvJitLinkDestroy
8181
try:
82-
__nvJitLinkDestroy = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkDestroy')
82+
__nvJitLinkDestroy = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkDestroy')
8383
except:
8484
pass
8585

8686
global __nvJitLinkAddData
8787
try:
88-
__nvJitLinkAddData = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkAddData')
88+
__nvJitLinkAddData = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkAddData')
8989
except:
9090
pass
9191

9292
global __nvJitLinkAddFile
9393
try:
94-
__nvJitLinkAddFile = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkAddFile')
94+
__nvJitLinkAddFile = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkAddFile')
9595
except:
9696
pass
9797

9898
global __nvJitLinkComplete
9999
try:
100-
__nvJitLinkComplete = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkComplete')
100+
__nvJitLinkComplete = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkComplete')
101101
except:
102102
pass
103103

104104
global __nvJitLinkGetLinkedCubinSize
105105
try:
106-
__nvJitLinkGetLinkedCubinSize = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetLinkedCubinSize')
106+
__nvJitLinkGetLinkedCubinSize = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetLinkedCubinSize')
107107
except:
108108
pass
109109

110110
global __nvJitLinkGetLinkedCubin
111111
try:
112-
__nvJitLinkGetLinkedCubin = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetLinkedCubin')
112+
__nvJitLinkGetLinkedCubin = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetLinkedCubin')
113113
except:
114114
pass
115115

116116
global __nvJitLinkGetLinkedPtxSize
117117
try:
118-
__nvJitLinkGetLinkedPtxSize = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetLinkedPtxSize')
118+
__nvJitLinkGetLinkedPtxSize = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetLinkedPtxSize')
119119
except:
120120
pass
121121

122122
global __nvJitLinkGetLinkedPtx
123123
try:
124-
__nvJitLinkGetLinkedPtx = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetLinkedPtx')
124+
__nvJitLinkGetLinkedPtx = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetLinkedPtx')
125125
except:
126126
pass
127127

128128
global __nvJitLinkGetErrorLogSize
129129
try:
130-
__nvJitLinkGetErrorLogSize = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetErrorLogSize')
130+
__nvJitLinkGetErrorLogSize = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetErrorLogSize')
131131
except:
132132
pass
133133

134134
global __nvJitLinkGetErrorLog
135135
try:
136-
__nvJitLinkGetErrorLog = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetErrorLog')
136+
__nvJitLinkGetErrorLog = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetErrorLog')
137137
except:
138138
pass
139139

140140
global __nvJitLinkGetInfoLogSize
141141
try:
142-
__nvJitLinkGetInfoLogSize = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetInfoLogSize')
142+
__nvJitLinkGetInfoLogSize = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetInfoLogSize')
143143
except:
144144
pass
145145

146146
global __nvJitLinkGetInfoLog
147147
try:
148-
__nvJitLinkGetInfoLog = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetInfoLog')
148+
__nvJitLinkGetInfoLog = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkGetInfoLog')
149149
except:
150150
pass
151151

152152
global __nvJitLinkVersion
153153
try:
154-
__nvJitLinkVersion = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvJitLinkVersion')
154+
__nvJitLinkVersion = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvJitLinkVersion')
155155
except:
156156
pass
157157

@@ -171,46 +171,46 @@ cpdef dict _inspect_function_pointers():
171171
cdef dict data = {}
172172

173173
global __nvJitLinkCreate
174-
data["__nvJitLinkCreate"] = <uintptr_t>__nvJitLinkCreate
174+
data["__nvJitLinkCreate"] = <intptr_t>__nvJitLinkCreate
175175

176176
global __nvJitLinkDestroy
177-
data["__nvJitLinkDestroy"] = <uintptr_t>__nvJitLinkDestroy
177+
data["__nvJitLinkDestroy"] = <intptr_t>__nvJitLinkDestroy
178178

179179
global __nvJitLinkAddData
180-
data["__nvJitLinkAddData"] = <uintptr_t>__nvJitLinkAddData
180+
data["__nvJitLinkAddData"] = <intptr_t>__nvJitLinkAddData
181181

182182
global __nvJitLinkAddFile
183-
data["__nvJitLinkAddFile"] = <uintptr_t>__nvJitLinkAddFile
183+
data["__nvJitLinkAddFile"] = <intptr_t>__nvJitLinkAddFile
184184

185185
global __nvJitLinkComplete
186-
data["__nvJitLinkComplete"] = <uintptr_t>__nvJitLinkComplete
186+
data["__nvJitLinkComplete"] = <intptr_t>__nvJitLinkComplete
187187

188188
global __nvJitLinkGetLinkedCubinSize
189-
data["__nvJitLinkGetLinkedCubinSize"] = <uintptr_t>__nvJitLinkGetLinkedCubinSize
189+
data["__nvJitLinkGetLinkedCubinSize"] = <intptr_t>__nvJitLinkGetLinkedCubinSize
190190

191191
global __nvJitLinkGetLinkedCubin
192-
data["__nvJitLinkGetLinkedCubin"] = <uintptr_t>__nvJitLinkGetLinkedCubin
192+
data["__nvJitLinkGetLinkedCubin"] = <intptr_t>__nvJitLinkGetLinkedCubin
193193

194194
global __nvJitLinkGetLinkedPtxSize
195-
data["__nvJitLinkGetLinkedPtxSize"] = <uintptr_t>__nvJitLinkGetLinkedPtxSize
195+
data["__nvJitLinkGetLinkedPtxSize"] = <intptr_t>__nvJitLinkGetLinkedPtxSize
196196

197197
global __nvJitLinkGetLinkedPtx
198-
data["__nvJitLinkGetLinkedPtx"] = <uintptr_t>__nvJitLinkGetLinkedPtx
198+
data["__nvJitLinkGetLinkedPtx"] = <intptr_t>__nvJitLinkGetLinkedPtx
199199

200200
global __nvJitLinkGetErrorLogSize
201-
data["__nvJitLinkGetErrorLogSize"] = <uintptr_t>__nvJitLinkGetErrorLogSize
201+
data["__nvJitLinkGetErrorLogSize"] = <intptr_t>__nvJitLinkGetErrorLogSize
202202

203203
global __nvJitLinkGetErrorLog
204-
data["__nvJitLinkGetErrorLog"] = <uintptr_t>__nvJitLinkGetErrorLog
204+
data["__nvJitLinkGetErrorLog"] = <intptr_t>__nvJitLinkGetErrorLog
205205

206206
global __nvJitLinkGetInfoLogSize
207-
data["__nvJitLinkGetInfoLogSize"] = <uintptr_t>__nvJitLinkGetInfoLogSize
207+
data["__nvJitLinkGetInfoLogSize"] = <intptr_t>__nvJitLinkGetInfoLogSize
208208

209209
global __nvJitLinkGetInfoLog
210-
data["__nvJitLinkGetInfoLog"] = <uintptr_t>__nvJitLinkGetInfoLog
210+
data["__nvJitLinkGetInfoLog"] = <intptr_t>__nvJitLinkGetInfoLog
211211

212212
global __nvJitLinkVersion
213-
data["__nvJitLinkVersion"] = <uintptr_t>__nvJitLinkVersion
213+
data["__nvJitLinkVersion"] = <intptr_t>__nvJitLinkVersion
214214

215215
func_ptrs = data
216216
return data

cuda_bindings/cuda/bindings/_internal/nvvm_windows.pyx

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#
55
# This code was automatically generated across versions from 11.0.3 to 12.8.0. Do not modify it directly.
66

7-
from libc.stdint cimport uintptr_t
7+
from libc.stdint cimport intptr_t
88

99
from .utils import FunctionNotFoundError, NotSupportedError
1010

@@ -39,7 +39,7 @@ cdef void* __nvvmGetProgramLog = NULL
3939

4040

4141
cdef void* load_library(int driver_ver) except* with gil:
42-
cdef uintptr_t handle = path_finder.load_nvidia_dynamic_library("nvvm")
42+
cdef intptr_t handle = path_finder.load_nvidia_dynamic_library("nvvm")
4343
return <void*>handle
4444

4545

@@ -49,7 +49,7 @@ cdef int _check_or_init_nvvm() except -1 nogil:
4949
return 0
5050

5151
cdef int err, driver_ver
52-
cdef uintptr_t handle
52+
cdef intptr_t handle
5353
with gil:
5454
# Load driver to check version
5555
try:
@@ -58,86 +58,86 @@ cdef int _check_or_init_nvvm() except -1 nogil:
5858
raise NotSupportedError(f'CUDA driver is not found ({e})')
5959
global __cuDriverGetVersion
6060
if __cuDriverGetVersion == NULL:
61-
__cuDriverGetVersion = <void*><uintptr_t>win32api.GetProcAddress(nvcuda_handle, 'cuDriverGetVersion')
61+
__cuDriverGetVersion = <void*><intptr_t>win32api.GetProcAddress(nvcuda_handle, 'cuDriverGetVersion')
6262
if __cuDriverGetVersion == NULL:
6363
raise RuntimeError('something went wrong')
6464
err = (<int (*)(int*) nogil>__cuDriverGetVersion)(&driver_ver)
6565
if err != 0:
6666
raise RuntimeError('something went wrong')
6767

6868
# Load library
69-
handle = <uintptr_t>load_library(driver_ver)
69+
handle = <intptr_t>load_library(driver_ver)
7070

7171
# Load function
7272
global __nvvmVersion
7373
try:
74-
__nvvmVersion = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmVersion')
74+
__nvvmVersion = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmVersion')
7575
except:
7676
pass
7777

7878
global __nvvmIRVersion
7979
try:
80-
__nvvmIRVersion = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmIRVersion')
80+
__nvvmIRVersion = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmIRVersion')
8181
except:
8282
pass
8383

8484
global __nvvmCreateProgram
8585
try:
86-
__nvvmCreateProgram = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmCreateProgram')
86+
__nvvmCreateProgram = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmCreateProgram')
8787
except:
8888
pass
8989

9090
global __nvvmDestroyProgram
9191
try:
92-
__nvvmDestroyProgram = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmDestroyProgram')
92+
__nvvmDestroyProgram = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmDestroyProgram')
9393
except:
9494
pass
9595

9696
global __nvvmAddModuleToProgram
9797
try:
98-
__nvvmAddModuleToProgram = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmAddModuleToProgram')
98+
__nvvmAddModuleToProgram = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmAddModuleToProgram')
9999
except:
100100
pass
101101

102102
global __nvvmLazyAddModuleToProgram
103103
try:
104-
__nvvmLazyAddModuleToProgram = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmLazyAddModuleToProgram')
104+
__nvvmLazyAddModuleToProgram = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmLazyAddModuleToProgram')
105105
except:
106106
pass
107107

108108
global __nvvmCompileProgram
109109
try:
110-
__nvvmCompileProgram = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmCompileProgram')
110+
__nvvmCompileProgram = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmCompileProgram')
111111
except:
112112
pass
113113

114114
global __nvvmVerifyProgram
115115
try:
116-
__nvvmVerifyProgram = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmVerifyProgram')
116+
__nvvmVerifyProgram = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmVerifyProgram')
117117
except:
118118
pass
119119

120120
global __nvvmGetCompiledResultSize
121121
try:
122-
__nvvmGetCompiledResultSize = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmGetCompiledResultSize')
122+
__nvvmGetCompiledResultSize = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmGetCompiledResultSize')
123123
except:
124124
pass
125125

126126
global __nvvmGetCompiledResult
127127
try:
128-
__nvvmGetCompiledResult = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmGetCompiledResult')
128+
__nvvmGetCompiledResult = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmGetCompiledResult')
129129
except:
130130
pass
131131

132132
global __nvvmGetProgramLogSize
133133
try:
134-
__nvvmGetProgramLogSize = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmGetProgramLogSize')
134+
__nvvmGetProgramLogSize = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmGetProgramLogSize')
135135
except:
136136
pass
137137

138138
global __nvvmGetProgramLog
139139
try:
140-
__nvvmGetProgramLog = <void*><uintptr_t>win32api.GetProcAddress(handle, 'nvvmGetProgramLog')
140+
__nvvmGetProgramLog = <void*><intptr_t>win32api.GetProcAddress(handle, 'nvvmGetProgramLog')
141141
except:
142142
pass
143143

@@ -157,40 +157,40 @@ cpdef dict _inspect_function_pointers():
157157
cdef dict data = {}
158158

159159
global __nvvmVersion
160-
data["__nvvmVersion"] = <uintptr_t>__nvvmVersion
160+
data["__nvvmVersion"] = <intptr_t>__nvvmVersion
161161

162162
global __nvvmIRVersion
163-
data["__nvvmIRVersion"] = <uintptr_t>__nvvmIRVersion
163+
data["__nvvmIRVersion"] = <intptr_t>__nvvmIRVersion
164164

165165
global __nvvmCreateProgram
166-
data["__nvvmCreateProgram"] = <uintptr_t>__nvvmCreateProgram
166+
data["__nvvmCreateProgram"] = <intptr_t>__nvvmCreateProgram
167167

168168
global __nvvmDestroyProgram
169-
data["__nvvmDestroyProgram"] = <uintptr_t>__nvvmDestroyProgram
169+
data["__nvvmDestroyProgram"] = <intptr_t>__nvvmDestroyProgram
170170

171171
global __nvvmAddModuleToProgram
172-
data["__nvvmAddModuleToProgram"] = <uintptr_t>__nvvmAddModuleToProgram
172+
data["__nvvmAddModuleToProgram"] = <intptr_t>__nvvmAddModuleToProgram
173173

174174
global __nvvmLazyAddModuleToProgram
175-
data["__nvvmLazyAddModuleToProgram"] = <uintptr_t>__nvvmLazyAddModuleToProgram
175+
data["__nvvmLazyAddModuleToProgram"] = <intptr_t>__nvvmLazyAddModuleToProgram
176176

177177
global __nvvmCompileProgram
178-
data["__nvvmCompileProgram"] = <uintptr_t>__nvvmCompileProgram
178+
data["__nvvmCompileProgram"] = <intptr_t>__nvvmCompileProgram
179179

180180
global __nvvmVerifyProgram
181-
data["__nvvmVerifyProgram"] = <uintptr_t>__nvvmVerifyProgram
181+
data["__nvvmVerifyProgram"] = <intptr_t>__nvvmVerifyProgram
182182

183183
global __nvvmGetCompiledResultSize
184-
data["__nvvmGetCompiledResultSize"] = <uintptr_t>__nvvmGetCompiledResultSize
184+
data["__nvvmGetCompiledResultSize"] = <intptr_t>__nvvmGetCompiledResultSize
185185

186186
global __nvvmGetCompiledResult
187-
data["__nvvmGetCompiledResult"] = <uintptr_t>__nvvmGetCompiledResult
187+
data["__nvvmGetCompiledResult"] = <intptr_t>__nvvmGetCompiledResult
188188

189189
global __nvvmGetProgramLogSize
190-
data["__nvvmGetProgramLogSize"] = <uintptr_t>__nvvmGetProgramLogSize
190+
data["__nvvmGetProgramLogSize"] = <intptr_t>__nvvmGetProgramLogSize
191191

192192
global __nvvmGetProgramLog
193-
data["__nvvmGetProgramLog"] = <uintptr_t>__nvvmGetProgramLog
193+
data["__nvvmGetProgramLog"] = <intptr_t>__nvvmGetProgramLog
194194

195195
func_ptrs = data
196196
return data

cuda_bindings/cuda/bindings/_path_finder_utils/load_nvidia_dynamic_library.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ def load_nvidia_dynamic_library(name: str) -> int:
1616
raise ctypes.WinError(ctypes.get_last_error())
1717
except Exception as e:
1818
raise RuntimeError(f"Failed to load DLL at {dl_path}: {e}") from e
19-
return ctypes.c_size_t(handle).value # Ensures unsigned result
19+
# Use `cdef void* ptr = <void*><intptr_t>` in cython to convert back to void*
20+
return handle # C signed int, matches win32api.GetProcAddress
2021
else:
2122
try:
2223
handle = ctypes.CDLL(dl_path, mode=os.RTLD_NOW | os.RTLD_GLOBAL)
23-
return handle._handle # Raw void* as unsigned int
2424
except OSError as e:
2525
raise RuntimeError(f"Failed to dlopen {dl_path}: {e}") from e
26+
# Use `cdef void* ptr = <void*><uintptr_t>` in cython to convert back to void*
27+
return handle._handle # C unsigned int

0 commit comments

Comments
 (0)