|
19 | 19 | from codeflash.languages.registry import register_language |
20 | 20 |
|
21 | 21 | if TYPE_CHECKING: |
| 22 | + import ast |
22 | 23 | from collections.abc import Sequence |
23 | 24 |
|
24 | 25 | from codeflash.languages.base import DependencyResolver |
25 | | - from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId |
| 26 | + from codeflash.models.models import FunctionSource, GeneratedTestsList, InvocationId, ValidCode |
| 27 | + from codeflash.verification.verification_utils import TestConfig |
26 | 28 |
|
27 | 29 | logger = logging.getLogger(__name__) |
28 | 30 |
|
@@ -861,8 +863,217 @@ def parse_line_profile_results(self, line_profiler_output_file: Path) -> dict: |
861 | 863 | # Python uses line_profiler which has its own output format |
862 | 864 | return {"timings": {}, "unit": 0, "str_out": ""} |
863 | 865 |
|
| 866 | + @property |
| 867 | + def function_optimizer_class(self) -> type: |
| 868 | + from codeflash.languages.python.function_optimizer import PythonFunctionOptimizer |
| 869 | + |
| 870 | + return PythonFunctionOptimizer |
| 871 | + |
| 872 | + def prepare_module( |
| 873 | + self, module_code: str, module_path: Path, project_root: Path |
| 874 | + ) -> tuple[dict[Path, ValidCode], ast.Module] | None: |
| 875 | + from codeflash.languages.python.optimizer import prepare_python_module |
| 876 | + |
| 877 | + return prepare_python_module(module_code, module_path, project_root) |
| 878 | + |
| 879 | + def setup_test_config(self, test_cfg: TestConfig, file_path: Path) -> None: |
| 880 | + pass |
| 881 | + |
864 | 882 | # === Test Execution (Full Protocol) === |
865 | | - # Note: For Python, test execution is handled by the main test_runner.py |
866 | | - # which has special Python-specific logic. These methods are not called |
867 | | - # for Python as the test_runner checks is_python() and uses the existing path. |
868 | | - # They are defined here only for protocol compliance. |
| 883 | + |
| 884 | + def run_behavioral_tests( |
| 885 | + self, |
| 886 | + test_paths: Any, |
| 887 | + test_env: dict[str, str], |
| 888 | + cwd: Path, |
| 889 | + timeout: int | None = None, |
| 890 | + project_root: Path | None = None, |
| 891 | + enable_coverage: bool = False, |
| 892 | + candidate_index: int = 0, |
| 893 | + ) -> tuple[Path, Any, Path | None, Path | None]: |
| 894 | + import contextlib |
| 895 | + import shlex |
| 896 | + import sys |
| 897 | + |
| 898 | + from codeflash.code_utils.code_utils import get_run_tmp_file |
| 899 | + from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE |
| 900 | + from codeflash.languages.python.static_analysis.coverage_utils import prepare_coverage_files |
| 901 | + from codeflash.models.models import TestType |
| 902 | + from codeflash.verification.test_runner import execute_test_subprocess |
| 903 | + |
| 904 | + blocklisted_plugins = ["benchmark", "codspeed", "xdist", "sugar"] |
| 905 | + |
| 906 | + test_files: list[str] = [] |
| 907 | + for file in test_paths.test_files: |
| 908 | + if file.test_type == TestType.REPLAY_TEST: |
| 909 | + if file.tests_in_file: |
| 910 | + test_files.extend( |
| 911 | + [ |
| 912 | + str(file.instrumented_behavior_file_path) + "::" + test.test_function |
| 913 | + for test in file.tests_in_file |
| 914 | + ] |
| 915 | + ) |
| 916 | + else: |
| 917 | + test_files.append(str(file.instrumented_behavior_file_path)) |
| 918 | + |
| 919 | + pytest_cmd_list = shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX) |
| 920 | + test_files = list(set(test_files)) |
| 921 | + |
| 922 | + common_pytest_args = [ |
| 923 | + "--capture=tee-sys", |
| 924 | + "-q", |
| 925 | + "--codeflash_loops_scope=session", |
| 926 | + "--codeflash_min_loops=1", |
| 927 | + "--codeflash_max_loops=1", |
| 928 | + "--codeflash_seconds=10.0", |
| 929 | + ] |
| 930 | + if timeout is not None: |
| 931 | + common_pytest_args.append(f"--timeout={timeout}") |
| 932 | + |
| 933 | + result_file_path = get_run_tmp_file(Path("pytest_results.xml")) |
| 934 | + result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] |
| 935 | + |
| 936 | + pytest_test_env = test_env.copy() |
| 937 | + pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin" |
| 938 | + |
| 939 | + coverage_database_file: Path | None = None |
| 940 | + coverage_config_file: Path | None = None |
| 941 | + |
| 942 | + if enable_coverage: |
| 943 | + coverage_database_file, coverage_config_file = prepare_coverage_files() |
| 944 | + pytest_test_env["NUMBA_DISABLE_JIT"] = str(1) |
| 945 | + pytest_test_env["TORCHDYNAMO_DISABLE"] = str(1) |
| 946 | + pytest_test_env["PYTORCH_JIT"] = str(0) |
| 947 | + pytest_test_env["TF_XLA_FLAGS"] = "--tf_xla_auto_jit=0" |
| 948 | + pytest_test_env["TF_ENABLE_ONEDNN_OPTS"] = str(0) |
| 949 | + pytest_test_env["JAX_DISABLE_JIT"] = str(0) |
| 950 | + |
| 951 | + is_windows = sys.platform == "win32" |
| 952 | + if is_windows: |
| 953 | + if coverage_database_file.exists(): |
| 954 | + with contextlib.suppress(PermissionError, OSError): |
| 955 | + coverage_database_file.unlink() |
| 956 | + else: |
| 957 | + cov_erase = execute_test_subprocess( |
| 958 | + shlex.split(f"{SAFE_SYS_EXECUTABLE} -m coverage erase"), cwd=cwd, env=pytest_test_env, timeout=30 |
| 959 | + ) |
| 960 | + logger.debug(cov_erase) |
| 961 | + coverage_cmd = [ |
| 962 | + SAFE_SYS_EXECUTABLE, |
| 963 | + "-m", |
| 964 | + "coverage", |
| 965 | + "run", |
| 966 | + f"--rcfile={coverage_config_file.as_posix()}", |
| 967 | + "-m", |
| 968 | + "pytest", |
| 969 | + ] |
| 970 | + |
| 971 | + blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins if plugin != "cov"] |
| 972 | + results = execute_test_subprocess( |
| 973 | + coverage_cmd + common_pytest_args + blocklist_args + result_args + test_files, |
| 974 | + cwd=cwd, |
| 975 | + env=pytest_test_env, |
| 976 | + timeout=600, |
| 977 | + ) |
| 978 | + logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "") |
| 979 | + else: |
| 980 | + blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins] |
| 981 | + |
| 982 | + results = execute_test_subprocess( |
| 983 | + pytest_cmd_list + common_pytest_args + blocklist_args + result_args + test_files, |
| 984 | + cwd=cwd, |
| 985 | + env=pytest_test_env, |
| 986 | + timeout=600, |
| 987 | + ) |
| 988 | + logger.debug("Result return code: %s, %s", results.returncode, results.stderr or "") |
| 989 | + |
| 990 | + return result_file_path, results, coverage_database_file, coverage_config_file |
| 991 | + |
| 992 | + def run_benchmarking_tests( |
| 993 | + self, |
| 994 | + test_paths: Any, |
| 995 | + test_env: dict[str, str], |
| 996 | + cwd: Path, |
| 997 | + timeout: int | None = None, |
| 998 | + project_root: Path | None = None, |
| 999 | + min_loops: int = 5, |
| 1000 | + max_loops: int = 100_000, |
| 1001 | + target_duration_seconds: float = 10.0, |
| 1002 | + ) -> tuple[Path, Any]: |
| 1003 | + import shlex |
| 1004 | + |
| 1005 | + from codeflash.code_utils.code_utils import get_run_tmp_file |
| 1006 | + from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE |
| 1007 | + from codeflash.verification.test_runner import execute_test_subprocess |
| 1008 | + |
| 1009 | + blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"] |
| 1010 | + |
| 1011 | + pytest_cmd_list = shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX) |
| 1012 | + test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files}) |
| 1013 | + pytest_args = [ |
| 1014 | + "--capture=tee-sys", |
| 1015 | + "-q", |
| 1016 | + "--codeflash_loops_scope=session", |
| 1017 | + f"--codeflash_min_loops={min_loops}", |
| 1018 | + f"--codeflash_max_loops={max_loops}", |
| 1019 | + f"--codeflash_seconds={target_duration_seconds}", |
| 1020 | + "--codeflash_stability_check=true", |
| 1021 | + ] |
| 1022 | + if timeout is not None: |
| 1023 | + pytest_args.append(f"--timeout={timeout}") |
| 1024 | + |
| 1025 | + result_file_path = get_run_tmp_file(Path("pytest_results.xml")) |
| 1026 | + result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] |
| 1027 | + pytest_test_env = test_env.copy() |
| 1028 | + pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin" |
| 1029 | + blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins] |
| 1030 | + results = execute_test_subprocess( |
| 1031 | + pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files, |
| 1032 | + cwd=cwd, |
| 1033 | + env=pytest_test_env, |
| 1034 | + timeout=600, |
| 1035 | + ) |
| 1036 | + return result_file_path, results |
| 1037 | + |
| 1038 | + def run_line_profile_tests( |
| 1039 | + self, |
| 1040 | + test_paths: Any, |
| 1041 | + test_env: dict[str, str], |
| 1042 | + cwd: Path, |
| 1043 | + timeout: int | None = None, |
| 1044 | + project_root: Path | None = None, |
| 1045 | + line_profile_output_file: Path | None = None, |
| 1046 | + ) -> tuple[Path, Any]: |
| 1047 | + import shlex |
| 1048 | + |
| 1049 | + from codeflash.code_utils.code_utils import get_run_tmp_file |
| 1050 | + from codeflash.code_utils.compat import IS_POSIX, SAFE_SYS_EXECUTABLE |
| 1051 | + from codeflash.verification.test_runner import execute_test_subprocess |
| 1052 | + |
| 1053 | + blocklisted_plugins = ["codspeed", "cov", "benchmark", "profiling", "xdist", "sugar"] |
| 1054 | + |
| 1055 | + pytest_cmd_list = shlex.split(f"{SAFE_SYS_EXECUTABLE} -m pytest", posix=IS_POSIX) |
| 1056 | + test_files: list[str] = list({str(file.benchmarking_file_path) for file in test_paths.test_files}) |
| 1057 | + pytest_args = [ |
| 1058 | + "--capture=tee-sys", |
| 1059 | + "-q", |
| 1060 | + "--codeflash_loops_scope=session", |
| 1061 | + "--codeflash_min_loops=1", |
| 1062 | + "--codeflash_max_loops=1", |
| 1063 | + "--codeflash_seconds=10.0", |
| 1064 | + ] |
| 1065 | + if timeout is not None: |
| 1066 | + pytest_args.append(f"--timeout={timeout}") |
| 1067 | + result_file_path = get_run_tmp_file(Path("pytest_results.xml")) |
| 1068 | + result_args = [f"--junitxml={result_file_path.as_posix()}", "-o", "junit_logging=all"] |
| 1069 | + pytest_test_env = test_env.copy() |
| 1070 | + pytest_test_env["PYTEST_PLUGINS"] = "codeflash.verification.pytest_plugin" |
| 1071 | + blocklist_args = [f"-p no:{plugin}" for plugin in blocklisted_plugins] |
| 1072 | + pytest_test_env["LINE_PROFILE"] = "1" |
| 1073 | + results = execute_test_subprocess( |
| 1074 | + pytest_cmd_list + pytest_args + blocklist_args + result_args + test_files, |
| 1075 | + cwd=cwd, |
| 1076 | + env=pytest_test_env, |
| 1077 | + timeout=600, |
| 1078 | + ) |
| 1079 | + return result_file_path, results |
0 commit comments