|
1 | 1 | # SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
2 | 2 | # SPDX-License-Identifier: LicenseRef-NVIDIA-SOFTWARE-LICENSE |
3 | 3 |
|
| 4 | +from contextlib import contextmanager |
| 5 | + |
4 | 6 | import pytest |
5 | 7 |
|
6 | 8 | from cuda.bindings import nvjitlink, nvrtc |
7 | 9 |
|
| 10 | + |
| 11 | +@contextmanager |
| 12 | +def nvjitlink_session(num_options, options): |
| 13 | + """Create an nvJitLink handle and always destroy it (including on test failure).""" |
| 14 | + handle = nvjitlink.create(num_options, options) |
| 15 | + try: |
| 16 | + yield handle |
| 17 | + finally: |
| 18 | + if handle != 0: |
| 19 | + nvjitlink.destroy(handle) |
| 20 | + |
| 21 | + |
8 | 22 | # Establish a handful of compatible architectures and PTX versions to test with |
9 | 23 | ARCHITECTURES = ["sm_75", "sm_80", "sm_90", "sm_100"] |
10 | 24 | PTX_VERSIONS = ["6.4", "7.0", "8.5", "8.8"] |
@@ -95,87 +109,78 @@ def test_invalid_arch_error(): |
95 | 109 |
|
96 | 110 | @pytest.mark.parametrize("option", ARCHITECTURES) |
97 | 111 | def test_create_and_destroy(option): |
98 | | - handle = nvjitlink.create(1, [f"-arch={option}"]) |
99 | | - assert handle != 0 |
100 | | - nvjitlink.destroy(handle) |
| 112 | + with nvjitlink_session(1, [f"-arch={option}"]) as handle: |
| 113 | + assert handle != 0 |
101 | 114 |
|
102 | 115 |
|
103 | 116 | def test_create_and_destroy_bytes_options(): |
104 | | - handle = nvjitlink.create(1, [b"-arch=sm_80"]) |
105 | | - assert handle != 0 |
106 | | - nvjitlink.destroy(handle) |
| 117 | + with nvjitlink_session(1, [b"-arch=sm_80"]) as handle: |
| 118 | + assert handle != 0 |
107 | 119 |
|
108 | 120 |
|
109 | 121 | @pytest.mark.parametrize("option", ARCHITECTURES) |
110 | 122 | def test_complete_empty(option): |
111 | | - handle = nvjitlink.create(1, [f"-arch={option}"]) |
112 | | - nvjitlink.complete(handle) |
113 | | - nvjitlink.destroy(handle) |
| 123 | + with nvjitlink_session(1, [f"-arch={option}"]) as handle: |
| 124 | + nvjitlink.complete(handle) |
114 | 125 |
|
115 | 126 |
|
116 | 127 | @arch_ptx_parametrized |
117 | 128 | def test_add_data(arch, ptx_bytes): |
118 | | - handle = nvjitlink.create(1, [f"-arch={arch}"]) |
119 | | - nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") |
120 | | - nvjitlink.complete(handle) |
121 | | - nvjitlink.destroy(handle) |
| 129 | + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: |
| 130 | + nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") |
| 131 | + nvjitlink.complete(handle) |
122 | 132 |
|
123 | 133 |
|
124 | 134 | @arch_ptx_parametrized |
125 | 135 | def test_add_file(arch, ptx_bytes, tmp_path): |
126 | | - handle = nvjitlink.create(1, [f"-arch={arch}"]) |
127 | | - file_path = tmp_path / "test_file.cubin" |
128 | | - file_path.write_bytes(ptx_bytes) |
129 | | - nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path)) |
130 | | - nvjitlink.complete(handle) |
131 | | - nvjitlink.destroy(handle) |
| 136 | + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: |
| 137 | + file_path = tmp_path / "test_file.cubin" |
| 138 | + file_path.write_bytes(ptx_bytes) |
| 139 | + nvjitlink.add_file(handle, nvjitlink.InputType.ANY, str(file_path)) |
| 140 | + nvjitlink.complete(handle) |
132 | 141 |
|
133 | 142 |
|
134 | 143 | @pytest.mark.parametrize("arch", ARCHITECTURES) |
135 | 144 | def test_get_error_log(arch): |
136 | | - handle = nvjitlink.create(1, [f"-arch={arch}"]) |
137 | | - nvjitlink.complete(handle) |
138 | | - log_size = nvjitlink.get_error_log_size(handle) |
139 | | - log = bytearray(log_size) |
140 | | - nvjitlink.get_error_log(handle, log) |
141 | | - assert len(log) == log_size |
142 | | - nvjitlink.destroy(handle) |
| 145 | + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: |
| 146 | + nvjitlink.complete(handle) |
| 147 | + log_size = nvjitlink.get_error_log_size(handle) |
| 148 | + log = bytearray(log_size) |
| 149 | + nvjitlink.get_error_log(handle, log) |
| 150 | + assert len(log) == log_size |
143 | 151 |
|
144 | 152 |
|
145 | 153 | @arch_ptx_parametrized |
146 | 154 | def test_get_info_log(arch, ptx_bytes): |
147 | | - handle = nvjitlink.create(1, [f"-arch={arch}"]) |
148 | | - nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") |
149 | | - nvjitlink.complete(handle) |
150 | | - log_size = nvjitlink.get_info_log_size(handle) |
151 | | - log = bytearray(log_size) |
152 | | - nvjitlink.get_info_log(handle, log) |
153 | | - assert len(log) == log_size |
154 | | - nvjitlink.destroy(handle) |
| 155 | + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: |
| 156 | + nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") |
| 157 | + nvjitlink.complete(handle) |
| 158 | + log_size = nvjitlink.get_info_log_size(handle) |
| 159 | + log = bytearray(log_size) |
| 160 | + nvjitlink.get_info_log(handle, log) |
| 161 | + assert len(log) == log_size |
155 | 162 |
|
156 | 163 |
|
157 | 164 | @arch_ptx_parametrized |
158 | 165 | def test_get_linked_cubin(arch, ptx_bytes): |
159 | | - handle = nvjitlink.create(1, [f"-arch={arch}"]) |
160 | | - nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") |
161 | | - nvjitlink.complete(handle) |
162 | | - cubin_size = nvjitlink.get_linked_cubin_size(handle) |
163 | | - cubin = bytearray(cubin_size) |
164 | | - nvjitlink.get_linked_cubin(handle, cubin) |
165 | | - assert len(cubin) == cubin_size |
166 | | - nvjitlink.destroy(handle) |
| 166 | + with nvjitlink_session(1, [f"-arch={arch}"]) as handle: |
| 167 | + nvjitlink.add_data(handle, nvjitlink.InputType.ANY, ptx_bytes, len(ptx_bytes), "test_data") |
| 168 | + nvjitlink.complete(handle) |
| 169 | + cubin_size = nvjitlink.get_linked_cubin_size(handle) |
| 170 | + cubin = bytearray(cubin_size) |
| 171 | + nvjitlink.get_linked_cubin(handle, cubin) |
| 172 | + assert len(cubin) == cubin_size |
167 | 173 |
|
168 | 174 |
|
169 | 175 | @pytest.mark.parametrize("arch", ARCHITECTURES) |
170 | 176 | def test_get_linked_ptx(arch, get_dummy_ltoir): |
171 | | - handle = nvjitlink.create(3, [f"-arch={arch}", "-lto", "-ptx"]) |
172 | | - nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data") |
173 | | - nvjitlink.complete(handle) |
174 | | - ptx_size = nvjitlink.get_linked_ptx_size(handle) |
175 | | - ptx = bytearray(ptx_size) |
176 | | - nvjitlink.get_linked_ptx(handle, ptx) |
177 | | - assert len(ptx) == ptx_size |
178 | | - nvjitlink.destroy(handle) |
| 177 | + with nvjitlink_session(3, [f"-arch={arch}", "-lto", "-ptx"]) as handle: |
| 178 | + nvjitlink.add_data(handle, nvjitlink.InputType.LTOIR, get_dummy_ltoir, len(get_dummy_ltoir), "test_data") |
| 179 | + nvjitlink.complete(handle) |
| 180 | + ptx_size = nvjitlink.get_linked_ptx_size(handle) |
| 181 | + ptx = bytearray(ptx_size) |
| 182 | + nvjitlink.get_linked_ptx(handle, ptx) |
| 183 | + assert len(ptx) == ptx_size |
179 | 184 |
|
180 | 185 |
|
181 | 186 | def test_package_version(): |
|
0 commit comments