Skip to content

Commit 14b21c0

Browse files
Merge branch 'master' into hip-template-kernels
2 parents 897ec55 + 3cfcb9f commit 14b21c0

15 files changed

Lines changed: 847 additions & 50 deletions

doc/requirements.txt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ domdf-python-tools==3.10.0 ; python_version >= "3.9" and python_version < "3.15"
1616
exceptiongroup==1.2.2 ; python_version >= "3.9" and python_version < "3.11"
1717
executing==2.2.0 ; python_version >= "3.9" and python_version < "3.15"
1818
fastjsonschema==2.21.1 ; python_version >= "3.9" and python_version < "3.15"
19-
idna==3.10 ; python_version >= "3.9" and python_version < "3.15"
19+
idna==3.15 ; python_version >= "3.9" and python_version < "3.15"
2020
imagesize==1.4.1 ; python_version >= "3.9" and python_version < "3.15"
2121
importlib-metadata==8.6.1 ; python_version >= "3.9" and python_version < "3.10"
2222
iniconfig==2.0.0 ; python_version >= "3.9" and python_version < "3.15"
@@ -31,10 +31,10 @@ jupyter-core==5.7.2 ; python_version >= "3.9" and python_version < "3.15"
3131
jupyterlab-pygments==0.3.0 ; python_version >= "3.9" and python_version < "3.15"
3232
markupsafe==2.1.5 ; python_version >= "3.9" and python_version < "3.15"
3333
matplotlib-inline==0.1.7 ; python_version >= "3.9" and python_version < "3.15"
34-
mistune==3.1.2 ; python_version >= "3.9" and python_version < "3.15"
34+
mistune==3.2.1 ; python_version >= "3.9" and python_version < "3.15"
3535
natsort==8.4.0 ; python_version >= "3.9" and python_version < "3.15"
3636
nbclient==0.10.2 ; python_version >= "3.9" and python_version < "3.15"
37-
nbconvert==7.17.0 ; python_version >= "3.9" and python_version < "3.15"
37+
nbconvert==7.17.1 ; python_version >= "3.9" and python_version < "3.15"
3838
nbformat==5.10.4 ; python_version >= "3.9" and python_version < "3.15"
3939
nbsphinx==0.9.7 ; python_version >= "3.9" and python_version < "3.15"
4040
numpy==1.26.4 ; python_version >= "3.9" and python_version < "3.15"
@@ -49,15 +49,15 @@ prompt-toolkit==3.0.50 ; python_version >= "3.9" and python_version < "3.15"
4949
ptyprocess==0.7.0 ; python_version >= "3.9" and python_version < "3.15" and sys_platform != "win32"
5050
pure-eval==0.2.3 ; python_version >= "3.9" and python_version < "3.15"
5151
pycparser==2.22 ; python_version >= "3.9" and python_version < "3.15" and implementation_name == "pypy"
52-
pygments==2.19.1 ; python_version >= "3.9" and python_version < "3.15"
53-
pytest==8.3.5 ; python_version >= "3.9" and python_version < "3.15"
52+
pygments==2.20.0 ; python_version >= "3.9" and python_version < "3.15"
53+
pytest==9.0.3 ; python_version >= "3.9" and python_version < "3.15"
5454
python-constraint2==2.1.0 ; python_version >= "3.9" and python_version < "3.15"
5555
python-dateutil==2.9.0.post0 ; python_version >= "3.9" and python_version < "3.15"
5656
pytz==2025.1 ; python_version >= "3.9" and python_version < "3.15"
5757
pywin32==308 ; sys_platform == "win32" and platform_python_implementation != "PyPy" and python_version >= "3.9" and python_version < "3.15"
5858
pyzmq==26.2.1 ; python_version >= "3.9" and python_version < "3.15"
5959
referencing==0.36.2 ; python_version >= "3.9" and python_version < "3.15"
60-
requests==2.32.4 ; python_version >= "3.9" and python_version < "3.15"
60+
requests==2.33.0 ; python_version >= "3.9" and python_version < "3.15"
6161
rpds-py==0.23.1 ; python_version >= "3.9" and python_version < "3.15"
6262
scikit-learn==1.6.1 ; python_version >= "3.9" and python_version < "3.15"
6363
scipy==1.13.1 ; python_version >= "3.9" and python_version < "3.15"
@@ -78,11 +78,11 @@ stack-data==0.6.3 ; python_version >= "3.9" and python_version < "3.15"
7878
threadpoolctl==3.5.0 ; python_version >= "3.9" and python_version < "3.15"
7979
tinycss2==1.4.0 ; python_version >= "3.9" and python_version < "3.15"
8080
tomli==2.2.1 ; python_version >= "3.9" and python_version < "3.15"
81-
tornado==6.5.1 ; python_version >= "3.9" and python_version < "3.15"
81+
tornado==6.5.5 ; python_version >= "3.9" and python_version < "3.15"
8282
traitlets==5.14.3 ; python_version >= "3.9" and python_version < "3.15"
8383
typing-extensions==4.12.2 ; python_version >= "3.9" and python_version < "3.15"
8484
tzdata==2025.1 ; python_version >= "3.9" and python_version < "3.15"
85-
urllib3==2.6.3 ; python_version >= "3.9" and python_version < "3.15"
85+
urllib3==2.7.0 ; python_version >= "3.9" and python_version < "3.15"
8686
wcwidth==0.2.13 ; python_version >= "3.9" and python_version < "3.15"
8787
webencodings==0.5.1 ; python_version >= "3.9" and python_version < "3.15"
8888
xmltodict==0.14.2 ; python_version >= "3.9" and python_version < "3.15"

doc/requirements_test.txt

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,9 @@ ptyprocess==0.7.0 ; python_version >= "3.10" and python_version < "4" and (os_na
311311
pure-eval==0.2.3 ; python_version >= "3.10" and python_version < "4" \
312312
--hash=sha256:1db8e35b67b3d218d818ae653e27f06c3aa420901fa7b081ca98cbedc874e0d0 \
313313
--hash=sha256:5f4e983f40564c576c7c8635ae88db5956bb2229d7e9237d03b3c0b0190eaf42
314-
pygments==2.19.1 ; python_version >= "3.10" and python_version < "4" \
315-
--hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \
316-
--hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c
314+
pygments==2.20.0 ; python_version >= "3.10" and python_version < "4" \
315+
--hash=sha256:6757cd03768053ff99f3039c1a36d6c0aa0b263438fcab17520b30a303a82b5f \
316+
--hash=sha256:81a9e26dd42fd28a23a2d169d86d7ac03b46e2f8b59ed4698fb4785f946d0176
317317
pyproject-hooks==1.2.0 ; python_version >= "3.10" and python_version < "4" \
318318
--hash=sha256:1e859bd5c40fae9448642dd871adf459e5e2084186e8d2c2a79a824c970da1f8 \
319319
--hash=sha256:9e5c6bfa8dcc30091c74b0cf803c81fdd29d94f01992a7707bc97babb1141913
@@ -323,9 +323,9 @@ pytest-cov==5.0.0 ; python_version >= "3.10" and python_version < "4" \
323323
pytest-timeout==2.3.1 ; python_version >= "3.10" and python_version < "4" \
324324
--hash=sha256:12397729125c6ecbdaca01035b9e5239d4db97352320af155b3f5de1ba5165d9 \
325325
--hash=sha256:68188cb703edfc6a18fad98dc25a3c61e9f24d644b0b70f33af545219fc7813e
326-
pytest==8.3.5 ; python_version >= "3.10" and python_version < "4" \
327-
--hash=sha256:c69214aa47deac29fad6c2a4f590b9c4a9fdb16a403176fe154b79c0b4d4d820 \
328-
--hash=sha256:f4efe70cc14e511565ac476b57c279e12a855b11f48f212af1080ef2263d3845
326+
pytest==9.0.3 ; python_version >= "3.10" and python_version < "4" \
327+
--hash=sha256:2c5efc453d45394fdd706ade797c0a81091eccd1d6e4bccfcd476e2b8e0ab5d9 \
328+
--hash=sha256:b86ada508af81d19edeb213c681b1d48246c1a91d304c6c81a427674c17eb91c
329329
python-constraint2==2.2.2 ; python_version >= "3.10" and python_version < "4" \
330330
--hash=sha256:02dcdf6d6f2d403b6304dddb242ef1b3db791600c7b8f8cd895dc3f87509bc6e \
331331
--hash=sha256:0951ff7ee0d549037ed078ecf828f33003730531a7231f9773c3674553362efa \

kernel_tuner/accuracy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def _find_bfloat16_if_available():
9696
+ "please install either the package `ml_dtypes`, `jax`, or `tensorflow`"
9797
)
9898

99-
return None
99+
return dtype
100100

101101

102102
def _to_float_dtype(x: str) -> np.dtype:

kernel_tuner/backends/cupy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
# and run tests without cupy installed
1111
try:
1212
import cupy as cp
13+
from cupyx import get_runtime_info
1314
except ImportError:
1415
cp = None
1516

@@ -68,7 +69,7 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
6869

6970
# collect environment information
7071
env = dict()
71-
cupy_info = str(cp._cupyx.get_runtime_info()).split("\n")[:-1]
72+
cupy_info = str(get_runtime_info()).split("\n")[:-1]
7273
info_dict = {
7374
s.split(":")[0].strip(): s.split(":")[1].strip() for s in cupy_info
7475
}

kernel_tuner/backends/nvcuda.py

Lines changed: 38 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,12 @@
22
from warnings import warn
33

44
import numpy as np
5+
import os
56

67
from kernel_tuner.backends.backend import GPUBackend
78
from kernel_tuner.observers.nvcuda import CudaRuntimeObserver
89
from kernel_tuner.util import SkippableFailure
9-
from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc
10+
from kernel_tuner.utils.nvcuda import cuda_error_check, to_valid_nvrtc_gpu_arch_cc, find_cuda_home
1011

1112
# embedded in try block to be able to generate documentation
1213
# and run tests without cuda-python installed
@@ -74,9 +75,6 @@ def __init__(self, device=0, iterations=7, compiler_options=None, observers=None
7475
self.current_module = None
7576
self.func = None
7677
self.compiler_options = compiler_options or []
77-
self.compiler_options_bytes = []
78-
for option in self.compiler_options:
79-
self.compiler_options_bytes.append(str(option).encode("UTF-8"))
8078

8179
# create a stream and events
8280
err, self.stream = driver.cuStreamCreate(0)
@@ -154,37 +152,60 @@ def compile(self, kernel_instance):
154152
"""
155153
kernel_string = kernel_instance.kernel_string
156154
kernel_name = kernel_instance.name
155+
expression_name = str.encode(kernel_name)
156+
compiler_options = list(self.compiler_options)
157157

158-
# mimic pycuda behavior to wrap kernel_string in extern "C" if not in kernel_string already
159-
if 'extern "C"' not in kernel_string:
160-
kernel_string = 'extern "C" {\n' + kernel_string + "\n}"
158+
# Default to --std=c++11 unless the user already passed a -std=/--std= option
159+
if not any(opt.startswith(("-std=", "--std=")) for opt in self.compiler_options):
160+
compiler_options.append("--std=c++11")
161161

162-
compiler_options = self.compiler_options_bytes
163-
if not any([b"--std=" in opt for opt in compiler_options]):
164-
compiler_options.append(b"--std=c++11")
165-
if not any(["--std=" in opt for opt in self.compiler_options]):
166-
self.compiler_options.append("--std=c++11")
167-
if not any([b"--gpu-architecture=" in opt or b"-arch" in opt for opt in compiler_options]):
168-
compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}".encode("UTF-8"))
169-
if not any(["--gpu-architecture=" in opt or "-arch" in opt for opt in self.compiler_options]):
170-
self.compiler_options.append(f"--gpu-architecture=compute_{to_valid_nvrtc_gpu_arch_cc(self.cc)}")
162+
# Default --gpu-architecture (derived from self.cc) unless the user already passed an arch option
163+
if not any(opt.startswith(("-arch", "--arch", "--gpu-architecture=")) for opt in self.compiler_options):
164+
arch_val = to_valid_nvrtc_gpu_arch_cc(self.cc)
165+
compiler_options.append(f"--gpu-architecture=compute_{arch_val}")
166+
167+
# Add CUDA home to include path
168+
cuda_home = find_cuda_home()
169+
if cuda_home:
170+
cuda_include = os.path.join(cuda_home, "include")
171+
compiler_options.append(f"-I{cuda_include}")
172+
173+
# nvrtcCompileProgram requires bytes instead of str
174+
compiler_options = [str(opt).encode("UTF-8") for opt in compiler_options]
171175

172176
err, program = nvrtc.nvrtcCreateProgram(str.encode(kernel_string), b"CUDAProgram", 0, [], [])
173177
try:
178+
# Add the kernel as an expression. This is necessary for templated kernels to ensure that the
179+
# compiler actually instantiates the kernel that we want to compile.
180+
cuda_error_check(err)
181+
err = nvrtc.nvrtcAddNameExpression(program, expression_name)
182+
183+
# Compile the program
174184
cuda_error_check(err)
175185
err = nvrtc.nvrtcCompileProgram(program, len(compiler_options), compiler_options)
186+
187+
# Get the PTX
176188
cuda_error_check(err)
177189
err, size = nvrtc.nvrtcGetPTXSize(program)
178190
cuda_error_check(err)
179191
buff = b" " * size
180192
err = nvrtc.nvrtcGetPTX(program, buff)
181193
cuda_error_check(err)
194+
195+
# Load the module
182196
err, self.current_module = driver.cuModuleLoadData(np.char.array(buff))
183197
if err == driver.CUresult.CUDA_ERROR_INVALID_PTX:
184198
raise SkippableFailure("uses too much shared data")
185199
else:
186200
cuda_error_check(err)
187-
err, self.func = driver.cuModuleGetFunction(self.current_module, str.encode(kernel_name))
201+
202+
# First, get the "lowered" name of the kernel (i.e., the name inside the PTX).
203+
# Afterwards, we can use the lowered name to look up the kernel in the module.
204+
err, lowered_name = nvrtc.nvrtcGetLoweredName(program, expression_name)
205+
cuda_error_check(err)
206+
err, self.func = driver.cuModuleGetFunction(
207+
self.current_module, lowered_name
208+
)
188209
cuda_error_check(err)
189210

190211
# get the number of registers per thread used in this kernel

kernel_tuner/interface.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,9 @@
6565
pyatf_strategies,
6666
random_sample,
6767
simulated_annealing,
68-
skopt
68+
skopt,
69+
gen_hybrid_vndx,
70+
gen_adaptive_tabu_greywolf,
6971
)
7072
from kernel_tuner.strategies.wrapper import OptAlgWrapper
7173

@@ -87,6 +89,8 @@
8789
"firefly_algorithm": firefly_algorithm,
8890
"bayes_opt": bayes_opt,
8991
"pyatf_strategies": pyatf_strategies,
92+
"hybrid_vndx": gen_hybrid_vndx,
93+
"adaptive_tabu_greywolf": gen_adaptive_tabu_greywolf,
9094
}
9195

9296

@@ -397,6 +401,8 @@ def __deepcopy__(self, _):
397401
* "random_sample" takes a random sample of the search space
398402
* "simulated_annealing" simulated annealing strategy
399403
* "skopt" uses the minimization methods from `skopt`
404+
* "hybrid_vndx" a hybrid variable neighborhood descent strategy
405+
* "adaptive_tabu_greywolf" an adaptive tabu-guided grey wolf optimization strategy
400406
401407
Strategy-specific parameters and options are explained under strategy_options.
402408

0 commit comments

Comments
 (0)