|
2 | 2 | from warnings import warn |
3 | 3 |
|
4 | 4 | import numpy as np |
| 5 | +import os |
5 | 6 |
|
6 | 7 | from kernel_tuner.backends.backend import GPUBackend |
7 | 8 | from kernel_tuner.observers.nvcuda import CudaRuntimeObserver |
8 | 9 | from kernel_tuner.util import SkippableFailure |
9 | | -from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc |
| 10 | +from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc, find_cuda_home |
10 | 11 |
|
11 | 12 | # embedded in try block to be able to generate documentation |
12 | 13 | # and run tests without cuda-python installed |
@@ -74,9 +75,6 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None |
74 | 75 | self.current_module = None |
75 | 76 | self.func = None |
76 | 77 | self.compiler_options = compiler_options or [] |
77 | | - self.compiler_options_bytes = [] |
78 | | - for option in self.compiler_options: |
79 | | - self.compiler_options_bytes.append(str(option).encode("UTF-8")) |
80 | 78 |
|
81 | 79 | # create a stream and events |
82 | 80 | err, self.stream = driver.cuStreamCreate(0) |
@@ -154,37 +152,60 @@ def compile(self, kernel_instance): |
154 | 152 | """ |
155 | 153 | kernel_string = kernel_instance.kernel_string |
156 | 154 | kernel_name = kernel_instance.name |
| 155 | + expression_name = str.encode(kernel_name) |
| 156 | + compiler_options = list(self.compiler_options) |
157 | 157 |
|
158 | | - # mimic pycuda behavior to wrap kernel_string in extern "C" if not in kernel_string already |
159 | | - if 'extern "C"' not in kernel_string: |
160 | | - kernel_string = 'extern "C" {\n' + kernel_string + "\n}" |
| 158 | + # Default to --std=c++11 unless the user already passed a -std=/--std= option |
| 159 | + if not any(opt.startswith(("-std=", "--std=")) for opt in self.compiler_options): |
| 160 | + compiler_options.append("--std=c++11") |
161 | 161 |
|
162 | | - compiler_options = self.compiler_options_bytes |
163 | | - if not any([b"--std=" in opt for opt in compiler_options]): |
164 | | - compiler_options.append(b"--std=c++11") |
165 | | - if not any(["--std=" in opt for opt in self.compiler_options]): |
166 | | - self.compiler_options.append("--std=c++11") |
167 | | - if not any([b"--gpu-architecture=" in opt or b"-arch" in opt for opt in compiler_options]): |
168 | | - compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}".encode("UTF-8")) |
169 | | - if not any(["--gpu-architecture=" in opt or "-arch" in opt for opt in self.compiler_options]): |
170 | | - self.compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}") |
| 162 | + # Add --gpu-architecture unless the user already passed an -arch/--arch/--gpu-architecture= option |
| 163 | + if not any(opt.startswith(("-arch", "--arch", "--gpu-architecture=")) for opt in self.compiler_options): |
| 164 | + arch_val = to_valid_nvrtc_gpu_arch_cc(self.cc) |
| 165 | + compiler_options.append(f"--gpu-architecture=compute_{arch_val}") |
| 166 | + |
| 167 | + # Add the CUDA home's include directory to the include path (only if a CUDA home was found) |
| 168 | + cuda_home = find_cuda_home() |
| 169 | + if cuda_home: |
| 170 | + cuda_include = os.path.join(cuda_home, "include") |
| 171 | + compiler_options.append(f"-I{cuda_include}") |
| 172 | + |
| 173 | + # nvrtcCompileProgram requires bytes instead of str |
| 174 | + compiler_options = [str(opt).encode("UTF-8") for opt in compiler_options] |
171 | 175 |
|
172 | 176 | err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], []) |
173 | 177 | try: |
| 178 | + # Add the kernel as an expression. This is necessary for templated kernels to ensure that the |
| 179 | + # compiler actually instantiates the kernel that we want to compile. |
| 180 | + cuda_error_check(err) |
| 181 | + err = nvrtc.nvrtcAddNameExpression(program, expression_name) |
| 182 | + |
| 183 | + # Compile the program |
174 | 184 | cuda_error_check(err) |
175 | 185 | err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options) |
| 186 | + |
| 187 | + # Get the PTX |
176 | 188 | cuda_error_check(err) |
177 | 189 | err, size = nvrtc.nvrtcGetPTXSize(program) |
178 | 190 | cuda_error_check(err) |
179 | 191 | buff = b" " * size |
180 | 192 | err = nvrtc.nvrtcGetPTX(program, buff) |
181 | 193 | cuda_error_check(err) |
| 194 | + |
| 195 | + # Load the module |
182 | 196 | err, self.current_module = driver.cuModuleLoadData(np.char.array(buff)) |
183 | 197 | if err == driver.CUresult.CUDA_ERROR_INVALID_PTX: |
184 | 198 | raise SkippableFailure("uses too much shared data") |
185 | 199 | else: |
186 | 200 | cuda_error_check(err) |
187 | | - err, self.func = driver.cuModuleGetFunction(self.current_module, str.encode(kernel_name)) |
| 201 | + |
| 202 | + # First, get the "lowered" name of the kernel (i.e., the name inside the PTX). |
| 203 | + # Afterwards, we can use the lowered name to look up the kernel in the module. |
| 204 | + err, lowered_name = nvrtc.nvrtcGetLoweredName(program, expression_name) |
| 205 | + cuda_error_check(err) |
| 206 | + err, self.func = driver.cuModuleGetFunction( |
| 207 | + self.current_module, lowered_name |
| 208 | + ) |
188 | 209 | cuda_error_check(err) |
189 | 210 |
|
190 | 211 | # get the number of registers per thread used in this kernel |
|
0 commit comments