33# SPDX-License-Identifier: Apache-2.0
44
55import argparse
6+ import ast
67import importlib .util
7- import inspect
8+ import os
89import sys
910from collections .abc import Callable
1011from pathlib import Path
# Parent of the directory that contains this file.
PROJECT_ROOT = Path(__file__).resolve().parent.parent
# Directory scanned for bench_*.py benchmark modules.
BENCH_DIR = PROJECT_ROOT / "benchmarks"
# Output file used when the caller does not supply one.
DEFAULT_OUTPUT = PROJECT_ROOT / "results-python.json"
# Environment variable names that ensure_pyperf_worker_env adds to pyperf's
# --inherit-environ list whenever they are set in the current process.
PYPERF_INHERITED_ENV_VARS = (
    "CUDA_HOME",
    "CUDA_PATH",
    "CUDA_VISIBLE_DEVICES",
    "LD_LIBRARY_PATH",
    "NVIDIA_VISIBLE_DEVICES",
)
# Modules already executed by load_module, keyed by resolved file path.
_MODULE_CACHE: dict[Path, ModuleType] = {}


def load_module(module_path: Path) -> ModuleType:
    """Import a benchmark source file as a module, executing it at most once.

    The path is resolved first, so different spellings of the same file share
    one cache entry. The module is named ``cuda_bindings_bench_<file stem>``.

    Args:
        module_path: Path to the Python source file to load.

    Returns:
        The executed module (the cached instance on repeated calls).

    Raises:
        RuntimeError: If importlib cannot build a spec or loader for the path.
        Exception: Whatever the module's own top-level code raises.
    """
    module_path = module_path.resolve()
    cached_module = _MODULE_CACHE.get(module_path)
    if cached_module is not None:
        return cached_module

    module_name = f"cuda_bindings_bench_{module_path.stem}"
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise RuntimeError(f"Failed to load benchmark module: {module_path}")
    module = importlib.util.module_from_spec(spec)
    # Register before executing (the importlib "import a source file" recipe)
    # so code that looks the module up by name while it runs can find it.
    sys.modules[module_name] = module
    try:
        spec.loader.exec_module(module)
    except BaseException:
        # Do not leave a half-initialised module importable by name.
        sys.modules.pop(module_name, None)
        raise
    _MODULE_CACHE[module_path] = module
    return module
2843
2944
@@ -33,6 +48,29 @@ def benchmark_id(module_name: str, function_name: str) -> str:
3348 return f"{ module_suffix } .{ suffix } "
3449
3550
def _discover_module_functions(module_path: Path) -> list[str]:
    """Return the names of top-level ``bench_`` functions in a source file.

    The file is read as UTF-8 and parsed with ``ast``; it is not imported, so
    none of its code runs. Only ``def`` / ``async def`` statements directly in
    the module body are considered, and names come back in source order.

    Args:
        module_path: Path to the Python source file to scan.

    Returns:
        Function names that start with ``bench_``.
    """
    tree = ast.parse(module_path.read_text(encoding="utf-8"), filename=str(module_path))
    return [
        node.name
        for node in tree.body
        if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)) and node.name.startswith("bench_")
    ]
58+
59+
def _lazy_benchmark(module_path: Path, function_name: str) -> Callable[[int], float]:
    """Wrap a benchmark function so its module is loaded on first call only.

    Creating the wrapper does not touch the file. The first invocation loads
    the module via ``load_module`` and looks up ``function_name``; subsequent
    invocations reuse the resolved function.

    Args:
        module_path: Source file that defines the benchmark.
        function_name: Name of the function to look up in that module.

    Returns:
        A callable taking ``loops`` and returning whatever the underlying
        function returns; its ``__name__`` is set to ``function_name``.
    """
    loaded_function: Callable[[int], float] | None = None

    def run(loops: int) -> float:
        nonlocal loaded_function
        if loaded_function is None:
            # Deferred import: only pay for the module when it is really run.
            module = load_module(module_path)
            loaded_function = getattr(module, function_name)
        return loaded_function(loops)

    run.__name__ = function_name
    return run
72+
73+
def discover_benchmarks() -> dict[str, Callable[[int], float]]:
    """Discover bench_ functions.

    Every ``bench_*.py`` file directly under ``BENCH_DIR`` is visited in sorted
    path order. Function names come from ``_discover_module_functions`` (a
    source scan, not an import), and each one is registered as a lazy wrapper
    from ``_lazy_benchmark`` under the ID produced by
    ``benchmark_id(<module stem>, <function name>)``.

    Returns:
        Mapping of benchmark ID to a callable taking the loop count.

    Raises:
        ValueError: If two functions map to the same benchmark ID.
    """
    registry: dict[str, Callable[[int], float]] = {}
    for module_path in sorted(BENCH_DIR.glob("bench_*.py")):
        module_name = module_path.stem
        for function_name in _discover_module_functions(module_path):
            bench_id = benchmark_id(module_name, function_name)
            if bench_id in registry:
                raise ValueError(f"Duplicate benchmark ID discovered: {bench_id}")
            registry[bench_id] = _lazy_benchmark(module_path, function_name)
    return registry
5790
5891
5992def strip_pyperf_output_args (argv : list [str ]) -> list [str ]:
6093 cleaned : list [str ] = []
6194 skip_next = False
62- for i , arg in enumerate ( argv ) :
95+ for arg in argv :
6396 if skip_next :
6497 skip_next = False
6598 continue
@@ -72,6 +105,48 @@ def strip_pyperf_output_args(argv: list[str]) -> list[str]:
72105 return cleaned
73106
74107
def _split_env_vars(arg_value: str) -> list[str]:
    """Split a comma-separated ``--inherit-environ`` value into variable names.

    Whitespace around each name is stripped and empty entries are dropped, so
    ``"A, B"`` and ``"A,B"`` yield the same names and de-duplicate correctly.

    Args:
        arg_value: The raw option value, e.g. ``"CUDA_HOME,LD_LIBRARY_PATH"``.

    Returns:
        The variable names in their original order.
    """
    stripped = (piece.strip() for piece in arg_value.split(","))
    return [name for name in stripped if name]
110+
111+
def ensure_pyperf_worker_env(argv: list[str]) -> list[str]:
    """Return a copy of *argv* whose ``--inherit-environ`` covers the CUDA variables.

    If ``--copy-env`` is present the arguments are returned unchanged (as a new
    list). Otherwise every ``--inherit-environ VALUE`` / ``--inherit-environ=VALUE``
    occurrence is removed, its names are merged with those entries of
    ``PYPERF_INHERITED_ENV_VARS`` that are set in ``os.environ``, and a single
    de-duplicated ``--inherit-environ`` option is appended at the end.

    Args:
        argv: Command-line arguments destined for pyperf.

    Returns:
        A new argument list; *argv* itself is never modified.

    Raises:
        ValueError: If ``--inherit-environ`` is the last argument and so has
            no value.
    """
    if "--copy-env" in argv:
        return list(argv)

    inherited_env: list[str] = []
    cleaned: list[str] = []
    remaining = iter(argv)
    for arg in remaining:
        if arg == "--inherit-environ":
            # The value is the next argument; consume it from the same iterator.
            value = next(remaining, None)
            if value is None:
                raise ValueError("Missing value for --inherit-environ")
            inherited_env.extend(_split_env_vars(value))
        elif arg.startswith("--inherit-environ="):
            inherited_env.extend(_split_env_vars(arg.partition("=")[2]))
        else:
            cleaned.append(arg)

    # User-supplied names come first, then the defaults that are actually set.
    inherited_env.extend(name for name in PYPERF_INHERITED_ENV_VARS if name in os.environ)

    # dict.fromkeys keeps the first occurrence of each name, in order.
    deduped_env = list(dict.fromkeys(inherited_env))
    if deduped_env:
        cleaned.extend(["--inherit-environ", ",".join(deduped_env)])

    return cleaned
148+
149+
75150def parse_args (argv : list [str ]) -> tuple [argparse .Namespace , list [str ]]:
76151 parser = argparse .ArgumentParser (add_help = False )
77152 parser .add_argument (
@@ -118,12 +193,13 @@ def main() -> None:
118193 else :
119194 benchmark_ids = sorted (registry )
120195
121- # Strip any --output args to avoid conflicts with our output handling
196+ # Strip any --output args to avoid conflicts with our output handling.
122197 output_path = parsed .output .resolve ()
123198 remaining_argv = strip_pyperf_output_args (remaining_argv )
199+ remaining_argv = ensure_pyperf_worker_env (remaining_argv )
124200 is_worker = "--worker" in remaining_argv
125201
126- # Delete the file so this run starts fresh
202+ # Delete the file so this run starts fresh.
127203 if not is_worker :
128204 output_path .unlink (missing_ok = True )
129205
0 commit comments