Skip to content

Commit d3be043

Browse files
authored
test(bindings): nvjitlink_session for reliable teardown (#1852)
Made-with: Cursor
1 parent ffde926 commit d3be043

File tree

1 file changed

+55
-50
lines changed

1 file changed

+55
-50
lines changed

cuda_bindings/tests/test_nvjitlink.py

Lines changed: 55 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,24 @@
11
# SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
22
# SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE
33

4+
from contextlib import contextmanager
5+
46
import pytest
57

68
from cuda.bindings import nvjitlink, nvrtc
79

10+
11+
@contextmanager
12+
def nvjitlink_session(num_options, options):
13+
"""Create an nvJitLink handle and always destroy it (including on test failure)."""
14+
handle = nvjitlink.create(num_options, options)
15+
try:
16+
yield handle
17+
finally:
18+
if handle != 0:
19+
nvjitlink.destroy(handle)
20+
21+
822
# Establish a handful of compatible architectures and PTX versions to test with
923
ARCHITECTURES = ["sm_75", "sm_80", "sm_90", "sm_100"]
1024
PTX_VERSIONS = ["6.4", "7.0", "8.5", "8.8"]
@@ -95,87 +109,78 @@ def test_invalid_arch_error():
95109

96110
@pytest.mark.parametrize("option", ARCHITECTURES)
97111
def test_create_and_destroy(option):
98-
handle = nvjitlink.create(1, [f"-arch={option}"])
99-
assert handle != 0
100-
nvjitlink.destroy(handle)
112+
with nvjitlink_session(1, [f"-arch={option}"]) as handle:
113+
assert handle != 0
101114

102115

103116
def test_create_and_destroy_bytes_options():
104-
handle = nvjitlink.create(1, [b"-arch=sm_80"])
105-
assert handle != 0
106-
nvjitlink.destroy(handle)
117+
with nvjitlink_session(1, [b"-arch=sm_80"]) as handle:
118+
assert handle != 0
107119

108120

109121
@pytest.mark.parametrize("option", ARCHITECTURES)
110122
def test_complete_empty(option):
111-
handle = nvjitlink.create(1, [f"-arch={option}"])
112-
nvjitlink.complete(handle)
113-
nvjitlink.destroy(handle)
123+
with nvjitlink_session(1, [f"-arch={option}"]) as handle:
124+
nvjitlink.complete(handle)
114125

115126

116127
@arch_ptx_parametrized
117128
def test_add_data(arch, ptx_bytes):
118-
handle = nvjitlink.create(1, [f"-arch={arch}"])
119-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
120-
nvjitlink.complete(handle)
121-
nvjitlink.destroy(handle)
129+
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
130+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
131+
nvjitlink.complete(handle)
122132

123133

124134
@arch_ptx_parametrized
125135
def test_add_file(arch, ptx_bytes, tmp_path):
126-
handle = nvjitlink.create(1, [f"-arch={arch}"])
127-
file_path = tmp_path / "test_file.cubin"
128-
file_path.write_bytes(ptx_bytes)
129-
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
130-
nvjitlink.complete(handle)
131-
nvjitlink.destroy(handle)
136+
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
137+
file_path = tmp_path / "test_file.cubin"
138+
file_path.write_bytes(ptx_bytes)
139+
nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path))
140+
nvjitlink.complete(handle)
132141

133142

134143
@pytest.mark.parametrize("arch", ARCHITECTURES)
135144
def test_get_error_log(arch):
136-
handle = nvjitlink.create(1, [f"-arch={arch}"])
137-
nvjitlink.complete(handle)
138-
log_size = nvjitlink.get_error_log_size(handle)
139-
log = bytearray(log_size)
140-
nvjitlink.get_error_log(handle, log)
141-
assert len(log) == log_size
142-
nvjitlink.destroy(handle)
145+
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
146+
nvjitlink.complete(handle)
147+
log_size = nvjitlink.get_error_log_size(handle)
148+
log = bytearray(log_size)
149+
nvjitlink.get_error_log(handle, log)
150+
assert len(log) == log_size
143151

144152

145153
@arch_ptx_parametrized
146154
def test_get_info_log(arch, ptx_bytes):
147-
handle = nvjitlink.create(1, [f"-arch={arch}"])
148-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
149-
nvjitlink.complete(handle)
150-
log_size = nvjitlink.get_info_log_size(handle)
151-
log = bytearray(log_size)
152-
nvjitlink.get_info_log(handle, log)
153-
assert len(log) == log_size
154-
nvjitlink.destroy(handle)
155+
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
156+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
157+
nvjitlink.complete(handle)
158+
log_size = nvjitlink.get_info_log_size(handle)
159+
log = bytearray(log_size)
160+
nvjitlink.get_info_log(handle, log)
161+
assert len(log) == log_size
155162

156163

157164
@arch_ptx_parametrized
158165
def test_get_linked_cubin(arch, ptx_bytes):
159-
handle = nvjitlink.create(1, [f"-arch={arch}"])
160-
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
161-
nvjitlink.complete(handle)
162-
cubin_size = nvjitlink.get_linked_cubin_size(handle)
163-
cubin = bytearray(cubin_size)
164-
nvjitlink.get_linked_cubin(handle, cubin)
165-
assert len(cubin) == cubin_size
166-
nvjitlink.destroy(handle)
166+
with nvjitlink_session(1, [f"-arch={arch}"]) as handle:
167+
nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data")
168+
nvjitlink.complete(handle)
169+
cubin_size = nvjitlink.get_linked_cubin_size(handle)
170+
cubin = bytearray(cubin_size)
171+
nvjitlink.get_linked_cubin(handle, cubin)
172+
assert len(cubin) == cubin_size
167173

168174

169175
@pytest.mark.parametrize("arch", ARCHITECTURES)
170176
def test_get_linked_ptx(arch, get_dummy_ltoir):
171-
handle = nvjitlink.create(3, [f"-arch={arch}", "-lto", "-ptx"])
172-
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data")
173-
nvjitlink.complete(handle)
174-
ptx_size = nvjitlink.get_linked_ptx_size(handle)
175-
ptx = bytearray(ptx_size)
176-
nvjitlink.get_linked_ptx(handle, ptx)
177-
assert len(ptx) == ptx_size
178-
nvjitlink.destroy(handle)
177+
with nvjitlink_session(3, [f"-arch={arch}", "-lto", "-ptx"]) as handle:
178+
nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data")
179+
nvjitlink.complete(handle)
180+
ptx_size = nvjitlink.get_linked_ptx_size(handle)
181+
ptx = bytearray(ptx_size)
182+
nvjitlink.get_linked_ptx(handle, ptx)
183+
assert len(ptx) == ptx_size
179184

180185

181186
def test_package_version():

0 commit comments

Comments
 (0)