@@ -20,7 +20,7 @@ def __init__(self) -> None:
2020 self .project_root = None
2121 self .benchmark_timings = []
2222
23- def setup (self , trace_path :str , project_root :str ) -> None :
23+ def setup (self , trace_path : str , project_root : str ) -> None :
2424 try :
2525 # Open connection
2626 self .project_root = project_root
@@ -35,7 +35,7 @@ def setup(self, trace_path:str, project_root:str) -> None:
3535 "benchmark_time_ns INTEGER)"
3636 )
3737 self ._connection .commit ()
38- self .close () # Reopen only at the end of pytest session
38+ self .close () # Reopen only at the end of pytest session
3939 except Exception as e :
4040 print (f"Database setup error: { e } " )
4141 if self ._connection :
@@ -55,14 +55,15 @@ def write_benchmark_timings(self) -> None:
5555 # Insert data into the benchmark_timings table
5656 cur .executemany (
5757 "INSERT INTO benchmark_timings (benchmark_module_path, benchmark_function_name, benchmark_line_number, benchmark_time_ns) VALUES (?, ?, ?, ?)" ,
58- self .benchmark_timings
58+ self .benchmark_timings ,
5959 )
6060 self ._connection .commit ()
61- self .benchmark_timings = [] # Clear the benchmark timings list
61+ self .benchmark_timings = [] # Clear the benchmark timings list
6262 except Exception as e :
6363 print (f"Error writing to benchmark timings database: { e } " )
6464 self ._connection .rollback ()
6565 raise
66+
6667 def close (self ) -> None :
6768 if self ._connection :
6869 self ._connection .close ()
@@ -196,12 +197,7 @@ def pytest_sessionfinish(self, session, exitstatus):
196197
197198 @staticmethod
198199 def pytest_addoption (parser ):
199- parser .addoption (
200- "--codeflash-trace" ,
201- action = "store_true" ,
202- default = False ,
203- help = "Enable CodeFlash tracing"
204- )
200+ parser .addoption ("--codeflash-trace" , action = "store_true" , default = False , help = "Enable CodeFlash tracing" )
205201
206202 @staticmethod
207203 def pytest_plugin_registered (plugin , manager ):
@@ -213,9 +209,9 @@ def pytest_plugin_registered(plugin, manager):
213209 def pytest_configure (config ):
214210 """Register the benchmark marker."""
215211 config .addinivalue_line (
216- "markers" ,
217- "benchmark: mark test as a benchmark that should be run with codeflash tracing"
212+ "markers" , "benchmark: mark test as a benchmark that should be run with codeflash tracing"
218213 )
214+
219215 @staticmethod
220216 def pytest_collection_modifyitems (config , items ):
221217 # Skip tests that don't have the benchmark fixture
@@ -248,16 +244,19 @@ def __call__(self, func, *args, **kwargs):
248244 if args or kwargs :
249245 # Used as benchmark(func, *args, **kwargs)
250246 return self ._run_benchmark (func , * args , ** kwargs )
247+
251248 # Used as @benchmark decorator
252249 def wrapped_func (* args , ** kwargs ):
253250 return func (* args , ** kwargs )
251+
254252 result = self ._run_benchmark (func )
255253 return wrapped_func
256254
257255 def _run_benchmark (self , func , * args , ** kwargs ):
258256 """Actual benchmark implementation."""
259- benchmark_module_path = module_name_from_file_path (Path (str (self .request .node .fspath )),
260- Path (codeflash_benchmark_plugin .project_root ))
257+ benchmark_module_path = module_name_from_file_path (
258+ Path (str (self .request .node .fspath )), Path (codeflash_benchmark_plugin .project_root )
259+ )
261260 benchmark_function_name = self .request .node .name
262261 line_number = int (str (sys ._getframe (2 ).f_lineno )) # 2 frames up in the call stack
263262 # Set env vars
@@ -278,7 +277,8 @@ def _run_benchmark(self, func, *args, **kwargs):
278277 codeflash_trace .function_call_count = 0
279278 # Add to the benchmark timings buffer
280279 codeflash_benchmark_plugin .benchmark_timings .append (
281- (benchmark_module_path , benchmark_function_name , line_number , end - start ))
280+ (benchmark_module_path , benchmark_function_name , line_number , end - start )
281+ )
282282
283283 return result
284284
@@ -290,4 +290,5 @@ def benchmark(request):
290290
291291 return CodeFlashBenchmarkPlugin .Benchmark (request )
292292
293+
293294codeflash_benchmark_plugin = CodeFlashBenchmarkPlugin ()
0 commit comments