Skip to content

Commit 12dc9db

Browse files
rwgkdanielfrg
authored and committed
fix(bindings): keep benchmark discovery from loading NVRTC
The Python benchmark smoke test imported `bench_launch` during discovery, which compiled kernels before the `pyperf` workers started. This made the local-CTK Linux lanes fail when NVRTC was not yet visible to that process. Discover benchmark IDs without importing GPU modules, defer launch setup until a launch benchmark actually runs, and preserve CUDA-related environment variables for worker processes.

Made-with: Cursor
1 parent 5727e60 commit 12dc9db

File tree

3 files changed

+305
-32
lines changed

3 files changed

+305
-32
lines changed

cuda_bindings/benchmarks/benchmarks/bench_launch.py

Lines changed: 51 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -5,13 +5,11 @@
55
import ctypes
66
import time
77

8-
from runner.runtime import alloc_persistent, compile_and_load, ensure_context
8+
from runner.runtime import alloc_persistent, assert_drv, compile_and_load
99

1010
from cuda.bindings import driver as cuda
1111

12-
ensure_context()
13-
14-
# Compile kernels
12+
# Compile kernels lazily so benchmark discovery does not need NVRTC.
1513
KERNEL_SOURCE = """\
1614
extern "C" __global__ void empty_kernel() { return; }
1715
extern "C" __global__ void small_kernel(float *f) { *f = 0.0f; }
@@ -33,28 +31,57 @@
3331
{ *F = 0; }
3432
"""
3533

36-
MODULE = compile_and_load(KERNEL_SOURCE)
34+
MODULE = None
35+
EMPTY_KERNEL = None
36+
SMALL_KERNEL = None
37+
KERNEL_16_ARGS = None
38+
STREAM = None
39+
FLOAT_PTR = None
40+
INT_PTRS = None
41+
_VAL_PS = None
42+
PACKED_16 = None
43+
44+
45+
def _ensure_launch_state() -> None:
46+
global MODULE, EMPTY_KERNEL, SMALL_KERNEL, KERNEL_16_ARGS, STREAM
47+
global FLOAT_PTR, INT_PTRS, _VAL_PS, PACKED_16
48+
49+
if EMPTY_KERNEL is not None:
50+
return
51+
52+
module = compile_and_load(KERNEL_SOURCE)
53+
54+
err, empty_kernel = cuda.cuModuleGetFunction(module, b"empty_kernel")
55+
assert_drv(err)
56+
err, small_kernel = cuda.cuModuleGetFunction(module, b"small_kernel")
57+
assert_drv(err)
58+
err, kernel_16_args = cuda.cuModuleGetFunction(module, b"small_kernel_16_args")
59+
assert_drv(err)
3760

38-
# Get kernel handles
39-
_err, EMPTY_KERNEL = cuda.cuModuleGetFunction(MODULE, b"empty_kernel")
40-
_err, SMALL_KERNEL = cuda.cuModuleGetFunction(MODULE, b"small_kernel")
41-
_err, KERNEL_16_ARGS = cuda.cuModuleGetFunction(MODULE, b"small_kernel_16_args")
61+
err, stream = cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_NON_BLOCKING.value)
62+
assert_drv(err)
4263

43-
# Create a non-blocking stream for launches
44-
_err, STREAM = cuda.cuStreamCreate(cuda.CUstream_flags.CU_STREAM_NON_BLOCKING.value)
64+
float_ptr = alloc_persistent(ctypes.sizeof(ctypes.c_float))
65+
int_ptrs = tuple(alloc_persistent(ctypes.sizeof(ctypes.c_int)) for _ in range(16))
4566

46-
# Allocate device memory for kernel arguments
47-
FLOAT_PTR = alloc_persistent(ctypes.sizeof(ctypes.c_float))
48-
INT_PTRS = [alloc_persistent(ctypes.sizeof(ctypes.c_int)) for _ in range(16)]
67+
val_ps = [ctypes.c_void_p(int(ptr)) for ptr in int_ptrs]
68+
packed_16 = (ctypes.c_void_p * 16)()
69+
for index, value_ptr in enumerate(val_ps):
70+
packed_16[index] = ctypes.addressof(value_ptr)
4971

50-
# Pre-pack ctypes params for the pre-packed benchmark
51-
_val_ps = [ctypes.c_void_p(int(p)) for p in INT_PTRS]
52-
PACKED_16 = (ctypes.c_void_p * 16)()
53-
for _i in range(16):
54-
PACKED_16[_i] = ctypes.addressof(_val_ps[_i])
72+
MODULE = module
73+
EMPTY_KERNEL = empty_kernel
74+
SMALL_KERNEL = small_kernel
75+
KERNEL_16_ARGS = kernel_16_args
76+
STREAM = stream
77+
FLOAT_PTR = float_ptr
78+
INT_PTRS = int_ptrs
79+
_VAL_PS = val_ps
80+
PACKED_16 = packed_16
5581

5682

5783
def bench_launch_empty_kernel(loops: int) -> float:
84+
_ensure_launch_state()
5885
_cuLaunchKernel = cuda.cuLaunchKernel
5986
_kernel = EMPTY_KERNEL
6087
_stream = STREAM
@@ -66,6 +93,7 @@ def bench_launch_empty_kernel(loops: int) -> float:
6693

6794

6895
def bench_launch_small_kernel(loops: int) -> float:
96+
_ensure_launch_state()
6997
_cuLaunchKernel = cuda.cuLaunchKernel
7098
_kernel = SMALL_KERNEL
7199
_stream = STREAM
@@ -79,11 +107,12 @@ def bench_launch_small_kernel(loops: int) -> float:
79107

80108

81109
def bench_launch_16_args(loops: int) -> float:
110+
_ensure_launch_state()
82111
_cuLaunchKernel = cuda.cuLaunchKernel
83112
_kernel = KERNEL_16_ARGS
84113
_stream = STREAM
85-
_args = tuple(INT_PTRS)
86-
_arg_types = tuple([None] * 16)
114+
_args = INT_PTRS
115+
_arg_types = (None,) * 16
87116

88117
t0 = time.perf_counter()
89118
for _ in range(loops):
@@ -92,6 +121,7 @@ def bench_launch_16_args(loops: int) -> float:
92121

93122

94123
def bench_launch_16_args_pre_packed(loops: int) -> float:
124+
_ensure_launch_state()
95125
_cuLaunchKernel = cuda.cuLaunchKernel
96126
_kernel = KERNEL_16_ARGS
97127
_stream = STREAM

cuda_bindings/benchmarks/runner/main.py

Lines changed: 87 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3,8 +3,9 @@
33
# SPDX-License-Identifier: Apache-2.0
44

55
import argparse
6+
import ast
67
import importlib.util
7-
import inspect
8+
import os
89
import sys
910
from collections.abc import Callable
1011
from pathlib import Path
@@ -15,15 +16,29 @@
1516
PROJECT_ROOT = Path(__file__).resolve().parent.parent
1617
BENCH_DIR = PROJECT_ROOT / "benchmarks"
1718
DEFAULT_OUTPUT = PROJECT_ROOT / "results-python.json"
19+
PYPERF_INHERITED_ENV_VARS = (
20+
"CUDA_HOME",
21+
"CUDA_PATH",
22+
"CUDA_VISIBLE_DEVICES",
23+
"LD_LIBRARY_PATH",
24+
"NVIDIA_VISIBLE_DEVICES",
25+
)
26+
_MODULE_CACHE: dict[Path, ModuleType] = {}
1827

1928

2029
def load_module(module_path: Path) -> ModuleType:
30+
module_path = module_path.resolve()
31+
cached_module = _MODULE_CACHE.get(module_path)
32+
if cached_module is not None:
33+
return cached_module
34+
2135
module_name = f"cuda_bindings_bench_{module_path.stem}"
2236
spec = importlib.util.spec_from_file_location(module_name, module_path)
2337
if spec is None or spec.loader is None:
2438
raise RuntimeError(f"Failed to load benchmark module: {module_path}")
2539
module = importlib.util.module_from_spec(spec)
2640
spec.loader.exec_module(module)
41+
_MODULE_CACHE[module_path] = module
2742
return module
2843

2944

@@ -33,6 +48,29 @@ def benchmark_id(module_name: str, function_name: str) -> str:
3348
return f"{module_suffix}.{suffix}"
3449

3550

51+
def _discover_module_functions(module_path: Path) -> list[str]:
52+
tree = ast.parse(module_path.read_text(encoding="utf-8"), filename=str(module_path))
53+
return [
54+
node.name
55+
for node in tree.body
56+
if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name.startswith("bench_")
57+
]
58+
59+
60+
def _lazy_benchmark(module_path: Path, function_name: str) -> Callable[[int], float]:
61+
loaded_function: Callable[[int], float] | None = None
62+
63+
def run(loops: int) -> float:
64+
nonlocal loaded_function
65+
if loaded_function is None:
66+
module = load_module(module_path)
67+
loaded_function = getattr(module, function_name)
68+
return loaded_function(loops)
69+
70+
run.__name__ = function_name
71+
return run
72+
73+
3674
def discover_benchmarks() -> dict[str, Callable[[int], float]]:
3775
"""Discover bench_ functions.
3876
@@ -42,24 +80,19 @@ def discover_benchmarks() -> dict[str, Callable[[int], float]]:
4280
"""
4381
registry: dict[str, Callable[[int], float]] = {}
4482
for module_path in sorted(BENCH_DIR.glob("bench_*.py")):
45-
module = load_module(module_path)
4683
module_name = module_path.stem
47-
for function_name, function in inspect.getmembers(module, inspect.isfunction):
48-
if not function_name.startswith("bench_"):
49-
continue
50-
if function.__module__ != module.__name__:
51-
continue
84+
for function_name in _discover_module_functions(module_path):
5285
bench_id = benchmark_id(module_name, function_name)
5386
if bench_id in registry:
5487
raise ValueError(f"Duplicate benchmark ID discovered: {bench_id}")
55-
registry[bench_id] = function
88+
registry[bench_id] = _lazy_benchmark(module_path, function_name)
5689
return registry
5790

5891

5992
def strip_pyperf_output_args(argv: list[str]) -> list[str]:
6093
cleaned: list[str] = []
6194
skip_next = False
62-
for i, arg in enumerate(argv):
95+
for arg in argv:
6396
if skip_next:
6497
skip_next = False
6598
continue
@@ -72,6 +105,48 @@ def strip_pyperf_output_args(argv: list[str]) -> list[str]:
72105
return cleaned
73106

74107

108+
def _split_env_vars(arg_value: str) -> list[str]:
109+
return [env_var for env_var in arg_value.split(",") if env_var]
110+
111+
112+
def ensure_pyperf_worker_env(argv: list[str]) -> list[str]:
113+
if "--copy-env" in argv:
114+
return list(argv)
115+
116+
inherited_env: list[str] = []
117+
cleaned: list[str] = []
118+
skip_next = False
119+
for arg in argv:
120+
if skip_next:
121+
inherited_env.extend(_split_env_vars(arg))
122+
skip_next = False
123+
continue
124+
if arg == "--inherit-environ":
125+
skip_next = True
126+
continue
127+
if arg.startswith("--inherit-environ="):
128+
inherited_env.extend(_split_env_vars(arg.partition("=")[2]))
129+
continue
130+
cleaned.append(arg)
131+
132+
if skip_next:
133+
raise ValueError("Missing value for --inherit-environ")
134+
135+
for env_var in PYPERF_INHERITED_ENV_VARS:
136+
if env_var in os.environ:
137+
inherited_env.append(env_var)
138+
139+
deduped_env: list[str] = []
140+
for env_var in inherited_env:
141+
if env_var not in deduped_env:
142+
deduped_env.append(env_var)
143+
144+
if deduped_env:
145+
cleaned.extend(["--inherit-environ", ",".join(deduped_env)])
146+
147+
return cleaned
148+
149+
75150
def parse_args(argv: list[str]) -> tuple[argparse.Namespace, list[str]]:
76151
parser = argparse.ArgumentParser(add_help=False)
77152
parser.add_argument(
@@ -118,12 +193,13 @@ def main() -> None:
118193
else:
119194
benchmark_ids = sorted(registry)
120195

121-
# Strip any --output args to avoid conflicts with our output handling
196+
# Strip any --output args to avoid conflicts with our output handling.
122197
output_path = parsed.output.resolve()
123198
remaining_argv = strip_pyperf_output_args(remaining_argv)
199+
remaining_argv = ensure_pyperf_worker_env(remaining_argv)
124200
is_worker = "--worker" in remaining_argv
125201

126-
# Delete the file so this run starts fresh
202+
# Delete the file so this run starts fresh.
127203
if not is_worker:
128204
output_path.unlink(missing_ok=True)
129205

0 commit comments

Comments
 (0)