Skip to content

Commit 51c08ba

Browse files
committed
refactor: replace is_python()/is_javascript() guards with protocol dispatch
- Optimizer: push dispatch into LanguageSupport protocol methods - Verification: move pytest execution into PythonSupport, callers invoke protocol directly instead of test_runner.py wrappers - Simple sites: concolic_testing, unused_definition_remover, test_framework, deduplicate_code
1 parent 3e72ebc commit 51c08ba

13 files changed

Lines changed: 346 additions & 385 deletions

File tree

codeflash/code_utils/deduplicate_code.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import re
1111

1212
from codeflash.code_utils.normalizers import get_normalizer
13-
from codeflash.languages import current_language, is_python
13+
from codeflash.languages import current_language
14+
from codeflash.languages.base import Language
1415

1516

1617
def normalize_code(
@@ -37,7 +38,7 @@ def normalize_code(
3738
normalizer = get_normalizer(language)
3839

3940
# Python has additional options
40-
if is_python():
41+
if language == Language.PYTHON:
4142
if return_ast_dump:
4243
return normalizer.normalize_for_hash(code)
4344
return normalizer.normalize(code, remove_docstrings=remove_docstrings)
@@ -49,7 +50,7 @@ def normalize_code(
4950
return _basic_normalize(code)
5051
except Exception:
5152
# Parsing error - try other languages or fall back
52-
if is_python():
53+
if language == Language.PYTHON:
5354
# Try JavaScript as fallback
5455
try:
5556
js_normalizer = get_normalizer("javascript")

codeflash/languages/base.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
1212

1313
if TYPE_CHECKING:
14+
import ast
1415
from collections.abc import Callable, Iterable, Sequence
1516
from pathlib import Path
1617

1718
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
18-
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
19+
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
20+
from codeflash.verification.verification_utils import TestConfig
1921

2022
from codeflash.languages.language_enum import Language
2123
from codeflash.models.function_types import FunctionParent
@@ -637,6 +639,22 @@ def compare_test_results(
637639
"""
638640
...
639641

642+
@property
643+
def function_optimizer_class(self) -> type:
644+
"""Return the FunctionOptimizer subclass for this language."""
645+
from codeflash.optimization.function_optimizer import FunctionOptimizer
646+
647+
return FunctionOptimizer
648+
649+
def prepare_module(
650+
self, module_code: str, module_path: Path, project_root: Path
651+
) -> tuple[dict[Path, ValidCode], ast.Module | None] | None:
652+
"""Parse/validate a module before optimization."""
653+
...
654+
655+
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
656+
"""One-time project setup after language detection. Default: no-op."""
657+
640658
# === Configuration ===
641659

642660
def get_test_file_suffix(self) -> str:
@@ -788,6 +806,31 @@ def run_benchmarking_tests(
788806
"""
789807
...
790808

809+
def run_line_profile_tests(
810+
self,
811+
test_paths: Any,
812+
test_env: dict[str, str],
813+
cwd: Path,
814+
timeout: int | None = None,
815+
project_root: Path | None = None,
816+
line_profile_output_file: Path | None = None,
817+
) -> tuple[Path, Any]:
818+
"""Run tests for line profiling.
819+
820+
Args:
821+
test_paths: TestFiles object containing test file information.
822+
test_env: Environment variables for the test run.
823+
cwd: Working directory for running tests.
824+
timeout: Optional timeout in seconds.
825+
project_root: Project root directory.
826+
line_profile_output_file: Path where line profile results will be written.
827+
828+
Returns:
829+
Tuple of (result_file_path, subprocess_result).
830+
831+
"""
832+
...
833+
791834

792835
def convert_parents_to_tuple(parents: list | tuple) -> tuple[FunctionParent, ...]:
793836
"""Convert a list of parent objects to a tuple of FunctionParent.

codeflash/languages/javascript/support.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
from codeflash.languages.base import ReferenceInfo
2525
from codeflash.languages.javascript.treesitter import TypeDefinition
26-
from codeflash.models.models import GeneratedTestsList, InvocationId
26+
from codeflash.models.models import GeneratedTestsList, InvocationId, ValidCode
27+
from codeflash.verification.verification_utils import TestConfig
2728

2829
logger = logging.getLogger(__name__)
2930

@@ -1909,6 +1910,25 @@ def compare_test_results(
19091910

19101911
return compare_test_results(original_results_path, candidate_results_path, project_root=project_root)
19111912

1913+
@property
1914+
def function_optimizer_class(self) -> type:
1915+
from codeflash.languages.javascript.function_optimizer import JavaScriptFunctionOptimizer
1916+
1917+
return JavaScriptFunctionOptimizer
1918+
1919+
def prepare_module(
1920+
self, module_code: str, module_path: Path, project_root: Path
1921+
) -> tuple[dict[Path, ValidCode], None]:
1922+
from codeflash.languages.javascript.optimizer import prepare_javascript_module
1923+
1924+
return prepare_javascript_module(module_code, module_path)
1925+
1926+
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
1927+
from codeflash.languages.javascript.optimizer import find_js_project_root, verify_js_requirements
1928+
1929+
test_cfg.js_project_root = find_js_project_root(file_path)
1930+
verify_js_requirements(test_cfg)
1931+
19121932
# === Configuration ===
19131933

19141934
def get_test_file_suffix(self) -> str:

codeflash/languages/python/context/unused_definition_remover.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,8 @@
1010
import libcst as cst
1111

1212
from codeflash.cli_cmds.console import logger
13-
from codeflash.languages import is_python
13+
from codeflash.languages import current_language
14+
from codeflash.languages.base import Language
1415
from codeflash.languages.python.static_analysis.code_replacer import replace_function_definitions_in_module
1516
from codeflash.models.models import CodeString, CodeStringsMarkdown
1617

@@ -747,7 +748,7 @@ def detect_unused_helper_functions(
747748
748749
"""
749750
# Skip this analysis for non-Python languages since we use Python's ast module
750-
if not is_python():
751+
if current_language() != Language.PYTHON:
751752
logger.debug("Skipping unused helper function detection for non-Python languages")
752753
return []
753754

codeflash/languages/python/function_optimizer.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def _resolve_function_ast(
5757
original_module_ast = ast.parse(source_code)
5858
return resolve_python_function_ast(function_name, parents, original_module_ast)
5959

60+
def requires_function_ast(self) -> bool:
61+
return True
62+
6063
def analyze_code_characteristics(self, code_context: CodeOptimizationContext) -> None:
6164
self.is_numerical_code = is_numerical_code(code_string=code_context.read_writable_code.flat)
6265

codeflash/languages/python/support.py

Lines changed: 216 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,12 @@
1919
from codeflash.languages.registry import register_language
2020

2121
if TYPE_CHECKING:
22+
import ast
2223
from collections.abc import Sequence
2324

2425
from codeflash.languages.base import DependencyResolver
25-
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId
26+
from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode
27+
from codeflash.verification.verification_utils import TestConfig
2628

2729
logger = logging.getLogger(__name__)
2830

@@ -861,8 +863,217 @@ def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict:
861863
# Python uses line_profiler which has its own output format
862864
return {"timings": {}, "unit": 0, "str_out": ""}
863865

866+
@property
867+
def function_optimizer_class(self) -> type:
868+
from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer
869+
870+
return PythonFunctionOptimizer
871+
872+
def prepare_module(
873+
self, module_code: str, module_path: Path, project_root: Path
874+
) -> tuple[dict[Path, ValidCode], ast.Module] | None:
875+
from codeflash.languages.python.optimizer import prepare_python_module
876+
877+
return prepare_python_module(module_code, module_path, project_root)
878+
879+
def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None:
880+
pass
881+
864882
# === Test Execution (Full Protocol) ===
865-
# Note: For Python, test execution is handled by the main test_runner.py
866-
# which has special Python-specific logic. These methods are not called
867-
# for Python as the test_runner checks is_python() and uses the existing path.
868-
# They are defined here only for protocol compliance.
883+
884+
def run_behavioral_tests(
885+
self,
886+
test_paths: Any,
887+
test_env: dict[str, str],
888+
cwd: Path,
889+
timeout: int | None = None,
890+
project_root: Path | None = None,
891+
enable_coverage: bool = False,
892+
candidate_index: int = 0,
893+
) -> tuple[Path, Any, Path | None, Path | None]:
894+
import contextlib
895+
import shlex
896+
import sys
897+
898+
from codeflash.code_utils.code_utils import get_run_tmp_file
899+
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
900+
from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files
901+
from codeflash.models.models import TestType
902+
from codeflash.verification.test_runner import execute_test_subprocess
903+
904+
blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"]
905+
906+
test_files: list[str] = []
907+
for file in test_paths.test_files:
908+
if file.test_type == TestType.REPLAY_TEST:
909+
if file.tests_in_file:
910+
test_files.extend(
911+
[
912+
str(file.instrumented_behavior_file_path) + "::" + test.test_function
913+
for test in file.tests_in_file
914+
]
915+
)
916+
else:
917+
test_files.append(str(file.instrumented_behavior_file_path))
918+
919+
pytest_cmd_list = shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
920+
test_files = list(set(test_files))
921+
922+
common_pytest_args = [
923+
"--capture=tee-sys",
924+
"-q",
925+
"--codeflash_loops_scope=session",
926+
"--codeflash_min_loops=1",
927+
"--codeflash_max_loops=1",
928+
"--codeflash_seconds=10.0",
929+
]
930+
if timeout is not None:
931+
common_pytest_args.append(f"--timeout={timeout}")
932+
933+
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
934+
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
935+
936+
pytest_test_env = test_env.copy()
937+
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
938+
939+
coverage_database_file: Path | None = None
940+
coverage_config_file: Path | None = None
941+
942+
if enable_coverage:
943+
coverage_database_file, coverage_config_file = prepare_coverage_files()
944+
pytest_test_env["NUMBA_DISABLE_JIT"] = str(1)
945+
pytest_test_env["TORCHDYNAMO_DISABLE"] = str(1)
946+
pytest_test_env["PYTORCH_JIT"] = str(0)
947+
pytest_test_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0"
948+
pytest_test_env["TF_ENABLE_ONEDNN_OPTS"] = str(0)
949+
pytest_test_env["JAX_DISABLE_JIT"] = str(0)
950+
951+
is_windows = sys.platform == "win32"
952+
if is_windows:
953+
if coverage_database_file.exists():
954+
with contextlib.suppress(PermissionError, OSError):
955+
coverage_database_file.unlink()
956+
else:
957+
cov_erase = execute_test_subprocess(
958+
shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env, timeout=30
959+
)
960+
logger.debug(cov_erase)
961+
coverage_cmd = [
962+
SAFE_SYS_EXECUTABLE,
963+
"-m",
964+
"coverage",
965+
"run",
966+
f"--rcfile={coverage_config_file.as_posix()}",
967+
"-m",
968+
"pytest",
969+
]
970+
971+
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins if plugin != "cov"]
972+
results = execute_test_subprocess(
973+
coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files,
974+
cwd=cwd,
975+
env=pytest_test_env,
976+
timeout=600,
977+
)
978+
logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "")
979+
else:
980+
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
981+
982+
results = execute_test_subprocess(
983+
pytest_cmd_list + common_pytest_args + blocklist_args + result_args + test_files,
984+
cwd=cwd,
985+
env=pytest_test_env,
986+
timeout=600,
987+
)
988+
logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "")
989+
990+
return result_file_path, results, coverage_database_file, coverage_config_file
991+
992+
def run_benchmarking_tests(
993+
self,
994+
test_paths: Any,
995+
test_env: dict[str, str],
996+
cwd: Path,
997+
timeout: int | None = None,
998+
project_root: Path | None = None,
999+
min_loops: int = 5,
1000+
max_loops: int = 100_000,
1001+
target_duration_seconds: float = 10.0,
1002+
) -> tuple[Path, Any]:
1003+
import shlex
1004+
1005+
from codeflash.code_utils.code_utils import get_run_tmp_file
1006+
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
1007+
from codeflash.verification.test_runner import execute_test_subprocess
1008+
1009+
blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
1010+
1011+
pytest_cmd_list = shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
1012+
test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files})
1013+
pytest_args = [
1014+
"--capture=tee-sys",
1015+
"-q",
1016+
"--codeflash_loops_scope=session",
1017+
f"--codeflash_min_loops={min_loops}",
1018+
f"--codeflash_max_loops={max_loops}",
1019+
f"--codeflash_seconds={target_duration_seconds}",
1020+
"--codeflash_stability_check=true",
1021+
]
1022+
if timeout is not None:
1023+
pytest_args.append(f"--timeout={timeout}")
1024+
1025+
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
1026+
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
1027+
pytest_test_env = test_env.copy()
1028+
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
1029+
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
1030+
results = execute_test_subprocess(
1031+
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
1032+
cwd=cwd,
1033+
env=pytest_test_env,
1034+
timeout=600,
1035+
)
1036+
return result_file_path, results
1037+
1038+
def run_line_profile_tests(
1039+
self,
1040+
test_paths: Any,
1041+
test_env: dict[str, str],
1042+
cwd: Path,
1043+
timeout: int | None = None,
1044+
project_root: Path | None = None,
1045+
line_profile_output_file: Path | None = None,
1046+
) -> tuple[Path, Any]:
1047+
import shlex
1048+
1049+
from codeflash.code_utils.code_utils import get_run_tmp_file
1050+
from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE
1051+
from codeflash.verification.test_runner import execute_test_subprocess
1052+
1053+
blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"]
1054+
1055+
pytest_cmd_list = shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX)
1056+
test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files})
1057+
pytest_args = [
1058+
"--capture=tee-sys",
1059+
"-q",
1060+
"--codeflash_loops_scope=session",
1061+
"--codeflash_min_loops=1",
1062+
"--codeflash_max_loops=1",
1063+
"--codeflash_seconds=10.0",
1064+
]
1065+
if timeout is not None:
1066+
pytest_args.append(f"--timeout={timeout}")
1067+
result_file_path = get_run_tmp_file(Path("pytest_results.xml"))
1068+
result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"]
1069+
pytest_test_env = test_env.copy()
1070+
pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin"
1071+
blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins]
1072+
pytest_test_env["LINE_PROFILE"] = "1"
1073+
results = execute_test_subprocess(
1074+
pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files,
1075+
cwd=cwd,
1076+
env=pytest_test_env,
1077+
timeout=600,
1078+
)
1079+
return result_file_path, results

0 commit comments

Comments
 (0)