diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 87422cbd5..63c83149f 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -199,6 +199,8 @@ jobs: run: | uv run ruff check --fix . || true uv run ruff format . + # uv-dynamic-versioning rewrites version.py on every `uv run` — discard those changes + git checkout HEAD -- codeflash/version.py codeflash-benchmark/codeflash_benchmark/version.py 2>/dev/null || true - name: Commit and push fixes run: | diff --git a/benchmarks/__init__.py b/benchmarks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/benchmarks/bench_cli_startup.py b/benchmarks/bench_cli_startup.py new file mode 100644 index 000000000..e6b8e0ad0 --- /dev/null +++ b/benchmarks/bench_cli_startup.py @@ -0,0 +1,72 @@ +"""Benchmark CLI startup latency for codeflash compare --script mode. + +Run from a worktree root. Installs deps via uv sync, then times several +CLI entry points and writes a JSON file mapping command names to median +wall-clock seconds. + +Usage: + codeflash compare main codeflash/optimize \ + --script "python benchmarks/bench_cli_startup.py" \ + --script-output benchmarks/results.json +""" + +from __future__ import annotations + +import json +import os +import subprocess +import time +from pathlib import Path + +WARMUP = 3 +RUNS = 30 +OUTPUT = os.environ.get("BENCH_OUTPUT", "benchmarks/results.json") + +COMMANDS: dict[str, list[str]] = { + "version": ["uv", "run", "codeflash", "--version"], + "help": ["uv", "run", "codeflash", "--help"], + "auth_status": ["uv", "run", "codeflash", "auth", "status"], + "compare_help": ["uv", "run", "codeflash", "compare", "--help"], +} + + +def measure(cmd: list[str], warmup: int = WARMUP, runs: int = RUNS) -> float: + """Return median wall-clock seconds for *cmd* over *runs* iterations.""" + env = {**os.environ, "CODEFLASH_API_KEY": "bench_dummy_key"} + for _ in range(warmup): + subprocess.run(cmd, capture_output=True, check=False, env=env) + + times: list[float] = [] + for _ in range(runs): + t0 = time.perf_counter() + subprocess.run(cmd, capture_output=True, check=False, env=env) + times.append(time.perf_counter() - t0) + + times.sort() + mid = len(times) // 2 + return times[mid] if len(times) % 2 else (times[mid - 1] + times[mid]) / 2 + + +def main() -> None: + # Ensure deps are installed in the worktree + subprocess.run(["uv", "sync"], check=True, capture_output=True) + + results: dict[str, float] = {} + for name, cmd in COMMANDS.items(): + print(f" {name}: ", end="", flush=True) + median = measure(cmd) + results[name] = round(median, 4) + print(f"{median * 1000:.0f} ms") + + # Total = sum of medians (useful for a single summary number) + results["__total__"] = round(sum(results.values()), 4) + + output_path = Path(OUTPUT) + output_path.parent.mkdir(parents=True, exist_ok=True) + with output_path.open("w") as f: + json.dump(results, f, indent=2) + print(f"\nResults written to {OUTPUT}") + + +if __name__ == "__main__": + main() diff --git a/codeflash/benchmarking/instrument_codeflash_trace.py b/codeflash/benchmarking/instrument_codeflash_trace.py index bc779dd9c..c5797ccad 100644 --- a/codeflash/benchmarking/instrument_codeflash_trace.py +++ b/codeflash/benchmarking/instrument_codeflash_trace.py @@ -4,6 +4,7 @@ import libcst as cst +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.code_utils.formatter import sort_imports if TYPE_CHECKING: diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index bbc26ea86..400403843 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -5,15 +5,6 @@ from functools import lru_cache from pathlib import Path -from codeflash.cli_cmds import logging_config -from codeflash.cli_cmds.console import apologize_and_exit, logger -from codeflash.code_utils import env_utils -from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths -from codeflash.code_utils.config_parser import parse_config_file -from codeflash.languages.test_framework import set_current_test_framework -from codeflash.lsp.helpers import is_LSP_enabled -from codeflash.version import __version__ as version - def parse_args() -> Namespace: parser = _build_parser() @@ -30,12 +21,17 @@ def parse_args() -> Namespace: def process_and_validate_cmd_args(args: Namespace) -> Namespace: + from codeflash.cli_cmds import logging_config + from codeflash.cli_cmds.console import logger + from codeflash.code_utils import env_utils + from codeflash.code_utils.code_utils import exit_with_message from codeflash.code_utils.git_utils import ( check_running_in_git_repo, confirm_proceeding_with_no_git_repo, get_repo_owner_and_name, ) from codeflash.code_utils.github_utils import require_github_app_or_exit + from codeflash.version import __version__ as version if args.server: os.environ["CODEFLASH_AIS_SERVER"] = args.server @@ -85,6 +81,12 @@ def process_and_validate_cmd_args(args: Namespace) -> Namespace: def process_pyproject_config(args: Namespace) -> Namespace: + from codeflash.code_utils import env_utils + from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths + from codeflash.code_utils.config_parser import parse_config_file + from codeflash.languages.test_framework import set_current_test_framework + from codeflash.lsp.helpers import is_LSP_enabled + try: pyproject_config, pyproject_file_path = parse_config_file(args.config_file) except ValueError as e: @@ -222,6 +224,9 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace: + from codeflash.cli_cmds.console import apologize_and_exit, logger + from codeflash.code_utils.code_utils import exit_with_message + if hasattr(args, "all") or (hasattr(args, "file") and args.file): no_pr = getattr(args, "no_pr", False) diff --git a/codeflash/cli_cmds/cmd_auth.py b/codeflash/cli_cmds/cmd_auth.py index 96b863fec..148649116 100644 --- a/codeflash/cli_cmds/cmd_auth.py +++ b/codeflash/cli_cmds/cmd_auth.py @@ -2,17 +2,17 @@ import os -import click - -from codeflash.cli_cmds.console import console -from codeflash.cli_cmds.oauth_handler import perform_oauth_signin -from codeflash.code_utils.env_utils import get_codeflash_api_key -from codeflash.code_utils.shell_utils import save_api_key_to_rc -from codeflash.either import is_successful - def auth_login() -> None: """Perform OAuth login and save the API key.""" + import click + + from codeflash.cli_cmds.console import console + from codeflash.cli_cmds.oauth_handler import perform_oauth_signin + from codeflash.code_utils.env_utils import get_codeflash_api_key + from codeflash.code_utils.shell_utils import save_api_key_to_rc + from codeflash.either import is_successful + try: existing_api_key = get_codeflash_api_key() except OSError: @@ -41,6 +41,9 @@ def auth_login() -> None: def auth_status() -> None: """Check and display current authentication status.""" + from codeflash.cli_cmds.console import console + from codeflash.code_utils.env_utils import get_codeflash_api_key + try: api_key = get_codeflash_api_key() except OSError: diff --git a/codeflash/code_utils/_libcst_cache.py b/codeflash/code_utils/_libcst_cache.py new file mode 100644 index 000000000..0db7e258a --- /dev/null +++ b/codeflash/code_utils/_libcst_cache.py @@ -0,0 +1,64 @@ +"""Cache libcst visitor dispatch table construction. + +libcst's ``MatcherDecoratableTransformer`` and +``MatcherDecoratableVisitor`` rebuild visitor dispatch tables on +every instantiation by iterating ``dir(self)`` (~600 attributes) +and calling ``getattr`` + ``inspect.ismethod`` on each. The +results depend only on the class, not the instance, so caching +by ``type(obj)`` is safe. + +Import this module before any libcst visitors are instantiated +to install the cache. +""" + +from __future__ import annotations + +from typing import Any + +import libcst.matchers._visitors as _mv + +_visit_cache: dict[type, Any] = {} +_leave_cache: dict[type, Any] = {} +_matchers_cache: dict[type, Any] = {} + +_original_visit = _mv._gather_constructed_visit_funcs # noqa: SLF001 +_original_leave = _mv._gather_constructed_leave_funcs # noqa: SLF001 +_original_matchers = _mv._gather_matchers # noqa: SLF001 + + +def _cached_visit(obj: object) -> Any: + """Return cached visit-function dispatch table for the object's class.""" + cls = type(obj) + try: + return _visit_cache[cls] + except KeyError: + result = _original_visit(obj) + _visit_cache[cls] = result + return result + + +def _cached_leave(obj: object) -> Any: + """Return cached leave-function dispatch table for the object's class.""" + cls = type(obj) + try: + return _leave_cache[cls] + except KeyError: + result = _original_leave(obj) + _leave_cache[cls] = result + return result + + +def _cached_matchers(obj: object) -> Any: + """Return cached matcher dispatch table for the object's class.""" + cls = type(obj) + try: + return dict(_matchers_cache[cls]) + except KeyError: + result = _original_matchers(obj) + _matchers_cache[cls] = result + return dict(result) + + +_mv._gather_constructed_visit_funcs = _cached_visit # noqa: SLF001 +_mv._gather_constructed_leave_funcs = _cached_leave # noqa: SLF001 +_mv._gather_matchers = _cached_matchers # noqa: SLF001 diff --git a/codeflash/code_utils/env_utils.py b/codeflash/code_utils/env_utils.py index 03c7abef2..eb2cab1d9 100644 --- a/codeflash/code_utils/env_utils.py +++ b/codeflash/code_utils/env_utils.py @@ -9,17 +9,16 @@ from pathlib import Path from typing import Any, Optional -from codeflash.cli_cmds.console import logger -from codeflash.code_utils.code_utils import exit_with_message -from codeflash.code_utils.formatter import format_code from codeflash.code_utils.shell_utils import read_api_key_from_shell_config, save_api_key_to_rc -from codeflash.languages.registry import get_language_support_by_common_formatters -from codeflash.lsp.helpers import is_LSP_enabled def check_formatter_installed( formatter_cmds: list[str], exit_on_failure: bool = True, language: str = "python" ) -> bool: + from codeflash.cli_cmds.console import logger + from codeflash.code_utils.formatter import format_code + from codeflash.languages.registry import get_language_support_by_common_formatters + if not formatter_cmds or formatter_cmds[0] == "disabled": return True first_cmd = formatter_cmds[0] @@ -69,6 +68,8 @@ def check_formatter_installed( @lru_cache(maxsize=1) def get_codeflash_api_key() -> str: + from codeflash.cli_cmds.console import logger + # Check environment variable first env_api_key = os.environ.get("CODEFLASH_API_KEY") shell_api_key = read_api_key_from_shell_config() @@ -96,7 +97,8 @@ def get_codeflash_api_key() -> str: # Prefer the shell configuration over environment variables for lsp, # as the API key may change in the RC file during lsp runtime. Since the LSP client (extension) can restart # within the same process, the environment variable could become outdated. - api_key = shell_api_key or env_api_key if is_LSP_enabled() else env_api_key or shell_api_key + is_lsp = os.getenv("CODEFLASH_LSP", default="false").lower() == "true" + api_key = shell_api_key or env_api_key if is_lsp else env_api_key or shell_api_key api_secret_docs_message = "For more information, refer to the documentation at [https://docs.codeflash.ai/optimizing-with-codeflash/codeflash-github-actions#manual-setup]." # noqa if not api_key: @@ -106,6 +108,8 @@ def get_codeflash_api_key() -> str: f"{api_secret_docs_message}" ) if is_repo_a_fork(): + from codeflash.code_utils.code_utils import exit_with_message + msg = ( "Codeflash API key not detected in your environment. It appears you're running Codeflash from a GitHub fork.\n" "For external contributors, please ensure you've added your own API key to your fork's repository secrets and set it as the CODEFLASH_API_KEY environment variable.\n" @@ -124,6 +128,8 @@ def get_codeflash_api_key() -> str: def ensure_codeflash_api_key() -> bool: + from codeflash.cli_cmds.console import logger + try: get_codeflash_api_key() except OSError: diff --git a/codeflash/code_utils/instrument_existing_tests.py b/codeflash/code_utils/instrument_existing_tests.py index f15b2d56a..02450bb12 100644 --- a/codeflash/code_utils/instrument_existing_tests.py +++ b/codeflash/code_utils/instrument_existing_tests.py @@ -7,6 +7,7 @@ import libcst as cst +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import get_run_tmp_file, module_name_from_file_path from codeflash.code_utils.formatter import sort_imports diff --git a/codeflash/code_utils/shell_utils.py b/codeflash/code_utils/shell_utils.py index 1569b51a1..fa17045bf 100644 --- a/codeflash/code_utils/shell_utils.py +++ b/codeflash/code_utils/shell_utils.py @@ -8,7 +8,6 @@ from pathlib import Path from typing import TYPE_CHECKING, Optional -from codeflash.cli_cmds.console import logger from codeflash.code_utils.compat import LF from codeflash.either import Failure, Success @@ -41,6 +40,8 @@ def is_powershell() -> bool: 2. COMSPEC pointing to powershell.exe 3. TERM_PROGRAM indicating Windows Terminal (often uses PowerShell) """ + from codeflash.cli_cmds.console import logger + if os.name != "nt": return False @@ -72,6 +73,8 @@ def is_powershell() -> bool: def read_api_key_from_shell_config() -> Optional[str]: """Read API key from shell configuration file.""" + from codeflash.cli_cmds.console import logger + shell_rc_path = get_shell_rc_path() # Ensure shell_rc_path is a Path object for consistent handling if not isinstance(shell_rc_path, Path): @@ -127,6 +130,8 @@ def get_api_key_export_line(api_key: str) -> str: def save_api_key_to_rc(api_key: str) -> Result[str, str]: """Save API key to the appropriate shell configuration file.""" + from codeflash.cli_cmds.console import logger + shell_rc_path = get_shell_rc_path() # Ensure shell_rc_path is a Path object for consistent handling if not isinstance(shell_rc_path, Path): diff --git a/codeflash/languages/function_optimizer.py b/codeflash/languages/function_optimizer.py index 7a5322857..d9b4918fd 100644 --- a/codeflash/languages/function_optimizer.py +++ b/codeflash/languages/function_optimizer.py @@ -20,6 +20,7 @@ from rich.text import Text from rich.tree import Tree +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.api.aiservice import AiServiceClient, AIServiceRefinerRequest, LocalAiServiceClient from codeflash.api.cfapi import add_code_context_hash, create_staging, get_cfapi_base_urls, mark_optimization_success from codeflash.benchmarking.utils import process_benchmark_data diff --git a/codeflash/languages/python/context/code_context_extractor.py b/codeflash/languages/python/context/code_context_extractor.py index 4fdbd5291..67abdae6b 100644 --- a/codeflash/languages/python/context/code_context_extractor.py +++ b/codeflash/languages/python/context/code_context_extractor.py @@ -11,6 +11,7 @@ import libcst as cst +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.cli_cmds.console import logger from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages from codeflash.code_utils.config_consts import ( diff --git a/codeflash/languages/python/context/unused_definition_remover.py b/codeflash/languages/python/context/unused_definition_remover.py index 575797b3d..355a23528 100644 --- a/codeflash/languages/python/context/unused_definition_remover.py +++ b/codeflash/languages/python/context/unused_definition_remover.py @@ -9,6 +9,7 @@ import libcst as cst +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.cli_cmds.console import logger from codeflash.languages import current_language from codeflash.languages.base import Language diff --git a/codeflash/languages/python/static_analysis/code_extractor.py b/codeflash/languages/python/static_analysis/code_extractor.py index 899ee438f..132e8f078 100644 --- a/codeflash/languages/python/static_analysis/code_extractor.py +++ b/codeflash/languages/python/static_analysis/code_extractor.py @@ -11,6 +11,7 @@ from libcst.codemod.visitors import AddImportsVisitor, GatherImportsVisitor, RemoveImportsVisitor from libcst.helpers import calculate_module_and_package +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_consts import MAX_CONTEXT_LEN_REVIEW from codeflash.languages.base import Language diff --git a/codeflash/languages/python/static_analysis/code_replacer.py b/codeflash/languages/python/static_analysis/code_replacer.py index 89dc2751e..2383f0930 100644 --- a/codeflash/languages/python/static_analysis/code_replacer.py +++ b/codeflash/languages/python/static_analysis/code_replacer.py @@ -9,6 +9,7 @@ import libcst as cst from libcst.metadata import PositionProvider +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.cli_cmds.console import logger from codeflash.code_utils.config_parser import find_conftest_files from codeflash.code_utils.formatter import sort_imports diff --git a/codeflash/languages/python/static_analysis/edit_generated_tests.py b/codeflash/languages/python/static_analysis/edit_generated_tests.py index c4aed07de..6ee1e06a0 100644 --- a/codeflash/languages/python/static_analysis/edit_generated_tests.py +++ b/codeflash/languages/python/static_analysis/edit_generated_tests.py @@ -10,6 +10,7 @@ from libcst import MetadataWrapper from libcst.metadata import PositionProvider +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.cli_cmds.console import logger from codeflash.code_utils.time_utils import format_perf, format_time from codeflash.models.models import GeneratedTests, GeneratedTestsList diff --git a/codeflash/languages/python/static_analysis/line_profile_utils.py b/codeflash/languages/python/static_analysis/line_profile_utils.py index 93997b2c6..b7857e161 100644 --- a/codeflash/languages/python/static_analysis/line_profile_utils.py +++ b/codeflash/languages/python/static_analysis/line_profile_utils.py @@ -9,6 +9,7 @@ import libcst as cst +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.code_utils.code_utils import get_run_tmp_file from codeflash.code_utils.formatter import sort_imports diff --git a/codeflash/languages/python/support.py b/codeflash/languages/python/support.py index 34f0527b2..356a3d216 100644 --- a/codeflash/languages/python/support.py +++ b/codeflash/languages/python/support.py @@ -9,6 +9,7 @@ import libcst as cst +import codeflash.code_utils._libcst_cache # noqa: F401 from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.base import ( CodeContext, diff --git a/codeflash/main.py b/codeflash/main.py index 0d4ba1ca5..5193fd736 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -29,9 +29,36 @@ def main() -> None: print(f"Codeflash version {__version__}") return + from codeflash.cli_cmds.cli import parse_args + + args = parse_args() + + # Auth commands skip banner, telemetry, and version check entirely + if args.command == "auth": + from codeflash.cli_cmds.cmd_auth import auth_login, auth_status + + if args.auth_command == "login": + auth_login() + elif args.auth_command == "status": + auth_status() + else: + from codeflash.code_utils.code_utils import exit_with_message + + exit_with_message("Usage: codeflash auth {login,status}", error_on_exit=True) + return + + # Compare command only needs its own imports + if args.command == "compare": + print_codeflash_banner() + from codeflash.cli_cmds.cmd_compare import run_compare + + run_compare(args) + return + + # All other commands need the full stack from pathlib import Path - from codeflash.cli_cmds.cli import parse_args, process_pyproject_config + from codeflash.cli_cmds.cli import process_pyproject_config from codeflash.code_utils import env_utils from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions from codeflash.code_utils.config_parser import parse_config_file @@ -39,11 +66,7 @@ def main() -> None: from codeflash.telemetry import posthog_cf from codeflash.telemetry.sentry import init_sentry - args = parse_args() - if args.command != "auth": - print_codeflash_banner() - - # Check for newer version for all commands + print_codeflash_banner() check_for_newer_minor_version() if args.command: @@ -54,18 +77,7 @@ def main() -> None: init_sentry(enabled=not disable_telemetry, exclude_errors=True) posthog_cf.initialize_posthog(enabled=not disable_telemetry) - if args.command == "auth": - from codeflash.cli_cmds.cmd_auth import auth_login, auth_status - - if args.auth_command == "login": - auth_login() - elif args.auth_command == "status": - auth_status() - else: - from codeflash.code_utils.code_utils import exit_with_message - - exit_with_message("Usage: codeflash auth {login,status}", error_on_exit=True) - elif args.command == "init": + if args.command == "init": from codeflash.cli_cmds.cmd_init import init_codeflash init_codeflash() @@ -77,10 +89,6 @@ def main() -> None: from codeflash.cli_cmds.extension import install_vscode_extension install_vscode_extension() - elif args.command == "compare": - from codeflash.cli_cmds.cmd_compare import run_compare - - run_compare(args) elif args.command == "optimize": from codeflash.tracer import main as tracer_main diff --git a/codeflash/models/models.py b/codeflash/models/models.py index 8ac873b70..640e5230a 100644 --- a/codeflash/models/models.py +++ b/codeflash/models/models.py @@ -1,37 +1,26 @@ from __future__ import annotations -from collections import Counter, defaultdict -from functools import lru_cache -from typing import TYPE_CHECKING - -import libcst as cst -from rich.tree import Tree - -from codeflash.cli_cmds.console import DEBUG_MODE, lsp_log -from codeflash.languages.registry import get_language_support -from codeflash.lsp.helpers import is_LSP_enabled, report_to_markdown_table -from codeflash.lsp.lsp_message import LspMarkdownMessage -from codeflash.models.test_type import TestType - -if TYPE_CHECKING: - from collections.abc import Iterator - import enum import re import sys +from collections import Counter, defaultdict from collections.abc import Collection from enum import Enum, IntEnum +from functools import lru_cache from pathlib import Path from re import Pattern -from typing import Any, NamedTuple, Optional, cast +from typing import TYPE_CHECKING, Any, NamedTuple, Optional, cast from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, ValidationError, model_validator from pydantic.dataclasses import dataclass -from codeflash.cli_cmds.console import console, logger -from codeflash.code_utils.code_utils import diff_length, module_name_from_file_path, validate_python_code -from codeflash.code_utils.env_utils import is_end_to_end -from codeflash.verification.comparator import comparator +from codeflash.models.test_type import TestType + +if TYPE_CHECKING: + from collections.abc import Iterator + + import libcst as cst + from rich.tree import Tree @dataclass(frozen=True) @@ -254,6 +243,8 @@ class CodeString(BaseModel): def validate_code_syntax(self) -> CodeString: """Validate code syntax for the specified language.""" if self.language == "python": + from codeflash.code_utils.code_utils import validate_python_code + validate_python_code(self.code) else: from codeflash.languages.registry import get_language_support @@ -267,6 +258,8 @@ def validate_code_syntax(self) -> CodeString: def get_comment_prefix(file_path: Path) -> str: """Get the comment prefix for a given language.""" + from codeflash.languages.registry import get_language_support + support = get_language_support(file_path) return support.comment_prefix @@ -565,6 +558,8 @@ def handle_duplicate_candidate( self.optimizations_post[past_opt_id] = self.ast_code_to_id[normalized_code]["shorter_source_code"].markdown # Update to shorter code if this candidate has a shorter diff + from codeflash.code_utils.code_utils import diff_length + new_diff_len = diff_length(candidate.source_code.flat, original_flat_code) if new_diff_len < self.ast_code_to_id[normalized_code]["diff_len"]: self.ast_code_to_id[normalized_code]["shorter_source_code"] = candidate.source_code @@ -574,6 +569,8 @@ def register_new_candidate( self, normalized_code: str, candidate: OptimizedCandidate, original_flat_code: str ) -> None: """Register a new candidate that hasn't been seen before.""" + from codeflash.code_utils.code_utils import diff_length + self.ast_code_to_id[normalized_code] = { "optimization_id": candidate.optimization_id, "shorter_source_code": candidate.source_code, @@ -670,6 +667,9 @@ def build_message(self) -> str: def log_coverage(self) -> None: from rich.tree import Tree + from codeflash.cli_cmds.console import console, logger + from codeflash.code_utils.env_utils import is_end_to_end + tree = Tree("Test Coverage Results") tree.add(f"Main Function: {self.main_func_coverage.name}: {self.coverage:.2f}%") if self.dependent_func_coverage: @@ -769,12 +769,16 @@ def test_fn_qualified_name(self) -> str: ) def find_func_in_class(self, class_node: cst.ClassDef, func_name: str) -> Optional[cst.FunctionDef]: + import libcst as cst + for stmt in class_node.body.body: if isinstance(stmt, cst.FunctionDef) and stmt.name.value == func_name: return stmt return None def get_src_code(self, test_path: Path) -> Optional[str]: + import libcst as cst + if not test_path.exists(): return None try: @@ -856,6 +860,8 @@ def add(self, function_test_invocation: FunctionTestInvocation) -> None: unique_id = function_test_invocation.unique_invocation_loop_id test_result_idx = self.test_result_idx if unique_id in test_result_idx: + from codeflash.cli_cmds.console import DEBUG_MODE, logger + if DEBUG_MODE: logger.warning(f"Test result with id {unique_id} already exists. SKIPPING") return @@ -876,6 +882,8 @@ def group_by_benchmarks( self, benchmark_keys: list[BenchmarkKey], benchmark_replay_test_dir: Path, project_root: Path ) -> dict[BenchmarkKey, TestResults]: """Group TestResults by benchmark for calculating improvements for each benchmark.""" + from codeflash.code_utils.code_utils import module_name_from_file_path + test_results_by_benchmark = defaultdict(TestResults) benchmark_module_path = {} for benchmark_key in benchmark_keys: @@ -929,9 +937,17 @@ def report_to_string(report: dict[TestType, dict[str, int]]) -> str: @staticmethod def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree: + from rich.tree import Tree + + from codeflash.lsp.helpers import is_LSP_enabled + tree = Tree(title) if is_LSP_enabled(): + from codeflash.cli_cmds.console import lsp_log + from codeflash.lsp.helpers import report_to_markdown_table + from codeflash.lsp.lsp_message import LspMarkdownMessage + # Build markdown table markdown = report_to_markdown_table(report, title) lsp_log(LspMarkdownMessage(markdown=markdown)) @@ -946,6 +962,8 @@ def report_to_tree(report: dict[TestType, dict[str, int]], title: str) -> Tree: return tree def usable_runtime_data_by_test_case(self) -> dict[InvocationId, list[int]]: + from codeflash.cli_cmds.console import logger + # Efficient single traversal, directly accumulating into a dict. # can track mins here and only sums can be return in total_passed_runtime by_id: dict[InvocationId, list[int]] = {} @@ -1025,6 +1043,8 @@ def __bool__(self) -> bool: return bool(self.test_results) def __eq__(self, other: object) -> bool: + from codeflash.verification.comparator import comparator + # Unordered comparison if type(self) is not type(other): return False diff --git a/codeflash/verification/comparator.py b/codeflash/verification/comparator.py index 5b7c18825..36b98e6e2 100644 --- a/codeflash/verification/comparator.py +++ b/codeflash/verification/comparator.py @@ -74,6 +74,27 @@ _DICT_VALUES_TYPE = type({}.values()) _DICT_ITEMS_TYPE = type({}.items()) +_IDENTITY_EQ_TYPES: frozenset[type[Any]] = frozenset( + { + int, + bool, + complex, + type(None), + type(Ellipsis), + decimal.Decimal, + set, + bytes, + bytearray, + memoryview, + frozenset, + type, + range, + slice, + OrderedDict, + types.GenericAlias, + } +) + _EQUALITY_TYPES = ( int, bool, @@ -184,32 +205,61 @@ def comparator(orig: Any, new: Any, superset_obj: bool = False) -> bool: return False - if type(orig) is not type(new): - type_obj = type(orig) - new_type_obj = type(new) + orig_type = type(orig) + if orig_type is not type(new): # distinct type objects are created at runtime, even if the class code is exactly the same, so we can only compare the names - if type_obj.__name__ != new_type_obj.__name__ or type_obj.__qualname__ != new_type_obj.__qualname__: + if orig_type.__name__ != type(new).__name__ or orig_type.__qualname__ != type(new).__qualname__: + return False + + # Fast-path: type identity checks for the most common return-value types. + # `orig_type is T` is a single pointer comparison — cheaper than frozenset hash + # lookup or isinstance MRO traversal — and these 4 types dominate real workloads. + if orig_type is str: + if orig == new: + return True + if _is_temp_path(orig) and _is_temp_path(new): + return _normalize_temp_path(orig) == _normalize_temp_path(new) + return False + if orig_type is list or orig_type is tuple: + if len(orig) != len(new): + return False + return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) + if orig_type is dict: + if superset_obj: + return all(k in new and comparator(v, new[k], superset_obj) for k, v in orig.items()) + if len(orig) != len(new): return False + for key in orig: + if key not in new: + return False + if not comparator(orig[key], new[key], superset_obj): + return False + return True + if orig_type is float: + if math.isnan(orig) and math.isnan(new): + return True + return math.isclose(orig, new) + # O(1) frozenset lookup for remaining common types (int, bool, None, Decimal, etc.) + if orig_type in _IDENTITY_EQ_TYPES: + return orig == new + + # Slower isinstance path for subclasses (deque, ChainMap, etc.) if isinstance(orig, (list, tuple, deque, ChainMap)): if len(orig) != len(new): return False return all(comparator(elem1, elem2, superset_obj) for elem1, elem2 in zip(orig, new)) - # Handle strings separately to normalize temp paths + # Handle string subclasses separately to normalize temp paths if isinstance(orig, str): if orig == new: return True - # If strings differ, check if they're temp paths that differ only in session number if _is_temp_path(orig) and _is_temp_path(new): return _normalize_temp_path(orig) == _normalize_temp_path(new) return False + # enum.Enum subclasses and UnionType fall through from the frozenset fast-path if isinstance(orig, _EQUALITY_TYPES): return orig == new - if isinstance(orig, float): - if math.isnan(orig) and math.isnan(new): - return True - return math.isclose(orig, new) # Handle weak references (e.g., found in torch.nn.LSTM/GRU modules) if isinstance(orig, weakref.ref): diff --git a/tests/benchmarks/__init__.py b/tests/benchmarks/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/benchmarks/test_benchmark_code_extract_code_context.py b/tests/benchmarks/test_benchmark_code_extract_code_context.py index 4fe06b14d..81b4eaa53 100644 --- a/tests/benchmarks/test_benchmark_code_extract_code_context.py +++ b/tests/benchmarks/test_benchmark_code_extract_code_context.py @@ -1,31 +1,18 @@ -from argparse import Namespace from pathlib import Path from codeflash.discovery.functions_to_optimize import FunctionToOptimize from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context from codeflash.models.models import FunctionParent -from codeflash.optimization.optimizer import Optimizer def test_benchmark_extract(benchmark) -> None: - file_path = Path(__file__).parent.parent.parent.resolve() / "codeflash" - opt = Optimizer( - Namespace( - project_root=file_path.resolve(), - disable_telemetry=True, - tests_root=(file_path / "tests").resolve(), - test_framework="pytest", - pytest_cmd="pytest", - experiment_id=None, - test_project_root=Path.cwd(), - ) - ) + project_root = Path(__file__).parent.parent.parent.resolve() / "codeflash" function_to_optimize = FunctionToOptimize( function_name="replace_function_and_helpers_with_optimized_code", - file_path=file_path / "languages" / "function_optimizer.py", + file_path=project_root / "languages" / "function_optimizer.py", parents=[FunctionParent(name="FunctionOptimizer", type="ClassDef")], starting_line=None, ending_line=None, ) - benchmark(get_code_optimization_context, function_to_optimize, opt.args.project_root) + benchmark(get_code_optimization_context, function_to_optimize, project_root) diff --git a/tests/benchmarks/test_benchmark_comparator.py b/tests/benchmarks/test_benchmark_comparator.py new file mode 100644 index 000000000..71576c370 --- /dev/null +++ b/tests/benchmarks/test_benchmark_comparator.py @@ -0,0 +1,133 @@ +"""Benchmark comparator type dispatch performance. + +Exercises the fast-path frozenset lookup vs isinstance MRO traversal +across realistic return value shapes: primitives, nested containers, +and mixed-type structures typical of real optimization verification. +""" + +from __future__ import annotations + +from collections import OrderedDict +from decimal import Decimal + +from codeflash.verification.comparator import comparator + +# --- Test data: realistic return value shapes --- + +# 1. Flat primitives (int, bool, None, str, float, bytes) — the fast-path sweet spot +_PRIMITIVES_A = [ + 42, + True, + None, + 3.14, + "hello", + b"bytes", + 0, + False, + "", + 1.0, + -1, + None, + True, + 99, + "world", + b"\x00\x01", + 2**31, + 0.0, + False, + None, +] +_PRIMITIVES_B = list(_PRIMITIVES_A) + +# 2. Nested dict of lists (common return value shape: API responses, parsed configs) +_NESTED_DICT_A = { + "users": [{"id": i, "name": f"user_{i}", "active": i % 2 == 0, "score": i * 1.5} for i in range(50)], + "metadata": {"total": 50, "page": 1, "has_next": True}, + "tags": [f"tag_{i}" for i in range(20)], + "config": {"timeout": 30, "retries": 3, "debug": False, "threshold": Decimal("0.95")}, +} +_NESTED_DICT_B = { + "users": [{"id": i, "name": f"user_{i}", "active": i % 2 == 0, "score": i * 1.5} for i in range(50)], + "metadata": {"total": 50, "page": 1, "has_next": True}, + "tags": [f"tag_{i}" for i in range(20)], + "config": {"timeout": 30, "retries": 3, "debug": False, "threshold": Decimal("0.95")}, +} + +# 3. List of tuples (common: database rows, CSV data) +_ROWS_A = [(i, f"row_{i}", i * 0.1, i % 3 == 0, None if i % 5 == 0 else i) for i in range(200)] +_ROWS_B = [(i, f"row_{i}", i * 0.1, i % 3 == 0, None if i % 5 == 0 else i) for i in range(200)] + + +# 4. Deeply nested structure (worst case for recursive comparator) +def _make_deep(depth: int) -> dict: + if depth == 0: + return {"leaf": True, "value": 42, "items": [1, 2, 3], "label": "end"} + return {"level": depth, "child": _make_deep(depth - 1), "siblings": list(range(depth))} + + +_DEEP_A = _make_deep(15) +_DEEP_B = _make_deep(15) + +# 5. Mixed identity types (frozenset, range, slice, OrderedDict, bytes, complex) +_IDENTITY_TYPES_A = [ + frozenset({1, 2, 3}), + range(100), + complex(1, 2), + Decimal("3.14"), + OrderedDict(a=1, b=2), + b"binary", + bytearray(b"mutable"), + memoryview(b"view"), + type(None), + True, + 42, + None, +] * 10 +_IDENTITY_TYPES_B = list(_IDENTITY_TYPES_A) + + +def _compare_all_primitives() -> None: + for a, b in zip(_PRIMITIVES_A, _PRIMITIVES_B): + comparator(a, b) + + +def _compare_nested_dict() -> None: + comparator(_NESTED_DICT_A, _NESTED_DICT_B) + + +def _compare_rows() -> None: + comparator(_ROWS_A, _ROWS_B) + + +def _compare_deep() -> None: + comparator(_DEEP_A, _DEEP_B) + + +def _compare_identity_types() -> None: + for a, b in zip(_IDENTITY_TYPES_A, _IDENTITY_TYPES_B): + comparator(a, b) + + +def test_benchmark_comparator_primitives(benchmark) -> None: + """20 flat primitive comparisons (int, bool, None, str, float, bytes).""" + benchmark(_compare_all_primitives) + + +def test_benchmark_comparator_nested_dict(benchmark) -> None: + """Nested dict with 50-element user list, metadata, tags, config.""" + benchmark(_compare_nested_dict) + + +def test_benchmark_comparator_rows(benchmark) -> None: + """200 tuples of (int, str, float, bool, Optional[int]).""" + benchmark(_compare_rows) + + +def test_benchmark_comparator_deep(benchmark) -> None: + """15-level deep nested dict structure.""" + benchmark(_compare_deep) + + +def test_benchmark_comparator_identity_types(benchmark) -> None: + """120 frozenset/range/complex/Decimal/OrderedDict/bytes comparisons.""" + benchmark(_compare_identity_types) diff --git a/tests/benchmarks/test_benchmark_libcst_multi_file.py b/tests/benchmarks/test_benchmark_libcst_multi_file.py new file mode 100644 index 000000000..d9faf1722 --- /dev/null +++ b/tests/benchmarks/test_benchmark_libcst_multi_file.py @@ -0,0 +1,75 @@ +"""Benchmark libcst visitor performance across many files. + +Exercises the visitor-heavy codepaths that benefit from the libcst dispatch +table cache: discover_functions + get_code_optimization_context on multiple +real source files. +""" + +from __future__ import annotations + +from pathlib import Path + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context +from codeflash.languages.python.support import PythonSupport +from codeflash.models.models import FunctionParent + +# Real source files from the codeflash codebase, chosen for size and visitor diversity. +_CODEFLASH_ROOT = Path(__file__).parent.parent.parent.resolve() / "codeflash" + +_SOURCE_FILES: list[Path] = [ + _CODEFLASH_ROOT / "languages" / "function_optimizer.py", + _CODEFLASH_ROOT / "languages" / "python" / "context" / "code_context_extractor.py", + _CODEFLASH_ROOT / "languages" / "python" / "support.py", + _CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_extractor.py", + _CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_replacer.py", + _CODEFLASH_ROOT / "code_utils" / "instrument_existing_tests.py", + _CODEFLASH_ROOT / "benchmarking" / "compare.py", + _CODEFLASH_ROOT / "models" / "models.py", + _CODEFLASH_ROOT / "discovery" / "discover_unit_tests.py", + _CODEFLASH_ROOT / "languages" / "base.py", +] + +# For each file, pick one top-level function to extract context for. +# (class, function_name) — class=None means module-level. +_TARGETS: list[tuple[Path, str | None, str]] = [ + (_SOURCE_FILES[0], "FunctionOptimizer", "replace_function_and_helpers_with_optimized_code"), + (_SOURCE_FILES[1], None, "get_code_optimization_context"), + (_SOURCE_FILES[2], "PythonSupport", "discover_functions"), + (_SOURCE_FILES[3], None, "add_global_assignments"), + (_SOURCE_FILES[4], None, "replace_functions_in_file"), + (_SOURCE_FILES[5], None, "inject_profiling_into_existing_test"), + (_SOURCE_FILES[6], None, "compare_branches"), + (_SOURCE_FILES[7], None, "get_comment_prefix"), + (_SOURCE_FILES[8], None, "discover_unit_tests"), + (_SOURCE_FILES[9], None, "convert_parents_to_tuple"), +] + + +def _discover_all() -> None: + """Run discover_functions on all source files.""" + ps = PythonSupport() + for file_path in _SOURCE_FILES: + source = file_path.read_text(encoding="utf-8") + ps.discover_functions(source=source, file_path=file_path) + + +def _extract_all_contexts() -> None: + """Run get_code_optimization_context on every target function.""" + project_root = _CODEFLASH_ROOT.parent + for file_path, class_name, func_name in _TARGETS: + parents = [FunctionParent(name=class_name, type="ClassDef")] if class_name else [] + fto = FunctionToOptimize( + function_name=func_name, file_path=file_path, parents=parents, starting_line=None, ending_line=None + ) + get_code_optimization_context(fto, project_root) + + +def test_benchmark_discover_functions_multi_file(benchmark) -> None: + """Discover functions across 10 source files.""" + benchmark(_discover_all) + + +def test_benchmark_extract_context_multi_file(benchmark) -> None: + """Extract code optimization context for 10 functions across 10 files.""" + benchmark(_extract_all_contexts) diff --git a/tests/benchmarks/test_benchmark_libcst_pipeline.py b/tests/benchmarks/test_benchmark_libcst_pipeline.py new file mode 100644 index 000000000..eca0a7e3f --- /dev/null +++ b/tests/benchmarks/test_benchmark_libcst_pipeline.py @@ -0,0 +1,56 @@ +"""Benchmark the full libcst-heavy pipeline on a single file. + +Runs discover → extract context → replace functions → add global assignments +in sequence, exercising ~15 distinct visitor/transformer classes in one pass. +""" + +from __future__ import annotations + +from pathlib import Path + +from codeflash.discovery.functions_to_optimize import FunctionToOptimize +from codeflash.languages.python.context.code_context_extractor import get_code_optimization_context +from codeflash.languages.python.static_analysis.code_extractor import add_global_assignments +from codeflash.languages.python.static_analysis.code_replacer import replace_functions_in_file +from codeflash.languages.python.support import PythonSupport + +_CODEFLASH_ROOT = Path(__file__).parent.parent.parent.resolve() / "codeflash" +_PROJECT_ROOT = _CODEFLASH_ROOT.parent + +# Target: a real, non-trivial file with classes and module-level functions. +_TARGET_FILE = _CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_extractor.py" +_TARGET_FUNC = "add_global_assignments" + +# A second file to serve as "optimized" source for replace/merge steps. +_SECOND_FILE = _CODEFLASH_ROOT / "languages" / "python" / "static_analysis" / "code_replacer.py" + + +def _run_pipeline() -> None: + """Simulate a single-file optimization pass through the full visitor pipeline.""" + source = _TARGET_FILE.read_text(encoding="utf-8") + source2 = _SECOND_FILE.read_text(encoding="utf-8") + + # 1. Discover functions (FunctionVisitor + MetadataWrapper) + ps = PythonSupport() + functions = ps.discover_functions(source=source, file_path=_TARGET_FILE) + + # 2. Extract code optimization context (multiple collectors + dependency resolver) + fto = FunctionToOptimize( + function_name=_TARGET_FUNC, file_path=_TARGET_FILE, parents=[], starting_line=None, ending_line=None + ) + get_code_optimization_context(fto, _PROJECT_ROOT) + + # 3. Replace functions (GlobalFunctionCollector + GlobalFunctionTransformer) + # Use a class method from discovered functions if available, else module-level. + func_names = [_TARGET_FUNC] + replace_functions_in_file( + source_code=source, original_function_names=func_names, optimized_code=source2, preexisting_objects=set() + ) + + # 4. Add global assignments (6 visitors/transformers) + add_global_assignments(source2, source) + + +def test_benchmark_full_pipeline(benchmark) -> None: + """Full discover → extract → replace → merge pipeline on one file.""" + benchmark(_run_pipeline) diff --git a/tests/test_cmd_auth.py b/tests/test_cmd_auth.py index d12cecf58..7ad156c0b 100644 --- a/tests/test_cmd_auth.py +++ b/tests/test_cmd_auth.py @@ -9,8 +9,8 @@ class TestAuthLogin: - @patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key") - @patch("codeflash.cli_cmds.cmd_auth.console") + @patch("codeflash.code_utils.env_utils.get_codeflash_api_key") + @patch("codeflash.cli_cmds.console.console") def test_existing_api_key_skips_oauth(self, mock_console: MagicMock, mock_get_key: MagicMock) -> None: mock_get_key.return_value = "cf-test1234abcd" @@ -21,19 +21,19 @@ def test_existing_api_key_skips_oauth(self, mock_console: MagicMock, mock_get_ke "To re-authenticate, unset [bold]CODEFLASH_API_KEY[/bold] and run this command again." ) - @patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key") - @patch("codeflash.cli_cmds.cmd_auth.console") + @patch("codeflash.code_utils.env_utils.get_codeflash_api_key") + @patch("codeflash.cli_cmds.console.console") def test_existing_api_key_oserror_treated_as_missing( self, mock_console: MagicMock, mock_get_key: MagicMock ) -> None: mock_get_key.side_effect = OSError("permission denied") with pytest.raises(SystemExit): - with patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin", return_value=None): + with patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin", return_value=None): auth_login() - @patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin") - @patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="") + @patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin") + @patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="") def test_oauth_failure_exits_with_code_1(self, mock_get_key: MagicMock, mock_oauth: MagicMock) -> None: mock_oauth.return_value = None @@ -41,10 +41,10 @@ def test_oauth_failure_exits_with_code_1(self, mock_get_key: MagicMock, mock_oau auth_login() @patch("codeflash.cli_cmds.cmd_auth.os") - @patch("codeflash.cli_cmds.cmd_auth.save_api_key_to_rc") - @patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin") - @patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="") - @patch("codeflash.cli_cmds.cmd_auth.console") + @patch("codeflash.code_utils.shell_utils.save_api_key_to_rc") + @patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin") + @patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="") + @patch("codeflash.cli_cmds.console.console") def test_successful_oauth_saves_key( self, mock_console: MagicMock, @@ -63,10 +63,10 @@ def test_successful_oauth_saves_key( mock_console.print.assert_called_with("[green]Signed in successfully![/green]") @patch("codeflash.cli_cmds.cmd_auth.os") - @patch("codeflash.cli_cmds.cmd_auth.save_api_key_to_rc") - @patch("codeflash.cli_cmds.cmd_auth.perform_oauth_signin") - @patch("codeflash.cli_cmds.cmd_auth.get_codeflash_api_key", return_value="") - @patch("codeflash.cli_cmds.cmd_auth.console") + @patch("codeflash.code_utils.shell_utils.save_api_key_to_rc") + @patch("codeflash.cli_cmds.oauth_handler.perform_oauth_signin") + @patch("codeflash.code_utils.env_utils.get_codeflash_api_key", return_value="") + @patch("codeflash.cli_cmds.console.console") def test_windows_oauth_saves_key( self, mock_console: MagicMock,