Skip to content

Commit ea0cab7

Browse files
committed
freeze the class
1 parent 08c6a1a commit ea0cab7

3 files changed

Lines changed: 14 additions & 26 deletions

File tree

codeflash/benchmarking/plugin/plugin.py

Lines changed: 6 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,8 @@
1212
from codeflash.benchmarking.codeflash_trace import codeflash_trace
1313
from codeflash.code_utils.code_utils import module_name_from_file_path
1414

15-
# from codeflash.models.models import BenchmarkKey
1615

17-
18-
@dataclass
16+
@dataclass(frozen=True)
1917
class BenchmarkKey:
2018
module_path: str
2119
function_name: str
@@ -192,6 +190,7 @@ def get_benchmark_timings(trace_path: Path) -> dict[BenchmarkKey, int]:
192190

193191
# Create the benchmark key (file::function::line)
194192
benchmark_key = BenchmarkKey(module_path=benchmark_file, function_name=benchmark_func)
193+
print(f"XXX Processing benchmark: {benchmark_key}")
195194
# Subtract overhead from total time
196195
overhead = overhead_by_benchmark.get(benchmark_key, 0)
197196
result[benchmark_key] = time_ns - overhead
@@ -232,26 +231,16 @@ def pytest_configure(config: pytest.Config) -> None:
232231

233232
@staticmethod
234233
def pytest_collection_modifyitems(config: pytest.Config, items: list[pytest.Item]) -> None:
235-
# Skip tests that don't have the benchmark fixture
234+
# Skip tests that don't have the benchmark marker
236235
if not config.getoption("--codeflash-trace"):
237236
return
238237

239-
skip_no_benchmark = pytest.mark.skip(reason="Test requires benchmark fixture")
240238
for item in items:
241-
# Check for direct benchmark fixture usage
242-
has_fixture = hasattr(item, "fixturenames") and "benchmark" in item.fixturenames
243-
244239
# Check for @pytest.mark.benchmark marker
245-
has_marker = False
246240
if hasattr(item, "get_closest_marker"):
247241
marker = item.get_closest_marker("benchmark")
248-
if marker is not None:
249-
has_marker = True
250-
print("XXX FOUND THE BENCHMARK MARKER")
251-
252-
# Skip if neither fixture nor marker is present
253-
if not (has_fixture or has_marker):
254-
item.add_marker(skip_no_benchmark)
242+
if marker is None:
243+
item.add_marker(pytest.mark.skip(reason="Test requires benchmark marker"))
255244

256245
# Benchmark fixture
257246
class Benchmark: # noqa: D106
@@ -307,11 +296,10 @@ def _run_benchmark(self, func, *args, **kwargs): # noqa: ANN001, ANN002, ANN003
307296
@staticmethod
308297
@pytest.fixture
309298
def benchmark(request: pytest.FixtureRequest) -> object:
299+
"""Fixture to provide the benchmark functionality."""
310300
if not request.config.getoption("--codeflash-trace"):
311-
print("XXX NOOOO PLS")
312301
return None
313302
print("XXX BENCHMARK PLUGIN INITIATED")
314-
315303
return CodeFlashBenchmarkPlugin.Benchmark(request)
316304

317305

codeflash/benchmarking/pytest_new_process_trace_benchmarks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@
2828
"-p",
2929
"no:profiling",
3030
"-s",
31+
"-vv",
3132
"-o",
3233
"addopts=",
3334
],
3435
plugins=[codeflash_benchmark_plugin],
3536
) # Errors will be printed to stdout, not stderr
36-
3737
except Exception as e:
3838
print(f"Failed to collect tests: {e!s}", file=sys.stderr)
3939
exitcode = -1

codeflash/optimization/optimizer.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -89,13 +89,13 @@ def run_benchmarks(
8989
logger.info(
9090
f"No valid benchmarks found in {self.args.benchmarks_root} for functions to optimize, continuing optimization"
9191
)
92-
else:
93-
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
94-
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
95-
function_to_results = validate_and_format_benchmark_table(
96-
function_benchmark_timings, total_benchmark_timings
97-
)
98-
print_benchmark_table(function_to_results)
92+
raise SystemExit # noqa: TRY301
93+
function_benchmark_timings = CodeFlashBenchmarkPlugin.get_function_benchmark_timings(trace_file)
94+
total_benchmark_timings = CodeFlashBenchmarkPlugin.get_benchmark_timings(trace_file)
95+
function_to_results = validate_and_format_benchmark_table(
96+
function_benchmark_timings, total_benchmark_timings
97+
)
98+
print_benchmark_table(function_to_results)
9999
except Exception as e:
100100
logger.info(f"Error while tracing existing benchmarks: {e}")
101101
logger.info("Information on existing benchmarks will not be available for this run.")

0 commit comments

Comments
 (0)