Skip to content

Commit 5b43dc9

Browse files
authored
Merge branch 'main' into fix/js-jest30-loop-runner
2 parents 4dedb9f + 97c1249 commit 5b43dc9

8 files changed

Lines changed: 169 additions & 65 deletions

File tree

codeflash/benchmarking/trace_benchmarks.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,18 @@
11
from __future__ import annotations
22

3-
import os
43
import re
54
import subprocess
65
from pathlib import Path
76

87
from codeflash.cli_cmds.console import logger
98
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
10-
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args
9+
from codeflash.code_utils.shell_utils import get_cross_platform_subprocess_run_args, make_env_with_project_root
1110

1211

1312
def trace_benchmarks_pytest(
1413
benchmarks_root: Path, tests_root: Path, project_root: Path, trace_file: Path, timeout: int = 300
1514
) -> None:
16-
benchmark_env = os.environ.copy()
17-
if "PYTHONPATH" not in benchmark_env:
18-
benchmark_env["PYTHONPATH"] = str(project_root)
19-
else:
20-
benchmark_env["PYTHONPATH"] += os.pathsep + str(project_root)
15+
benchmark_env = make_env_with_project_root(project_root)
2116
run_args = get_cross_platform_subprocess_run_args(
2217
cwd=project_root, env=benchmark_env, timeout=timeout, check=False, text=True, capture_output=True
2318
)

codeflash/code_utils/concolic_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import sentry_sdk
1010

1111
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE, codeflash_temp_dir
12+
from codeflash.code_utils.shell_utils import make_env_with_project_root
1213

1314
# Known CrossHair limitations that produce invalid Python syntax in generated tests:
1415
# - "<locals>" - higher-order functions returning nested functions
@@ -37,6 +38,7 @@ def is_valid_concolic_test(test_code: str, project_root: Optional[str] = None) -
3738
text=True,
3839
cwd=project_root,
3940
timeout=10,
41+
env=make_env_with_project_root(project_root) if project_root else None,
4042
)
4143
except (subprocess.TimeoutExpired, Exception):
4244
return False

codeflash/code_utils/shell_utils.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -238,6 +238,18 @@ def save_api_key_to_rc(api_key: str) -> Result[str, str]:
238238
)
239239

240240

241+
def make_env_with_project_root(project_root: Path | str) -> dict[str, str]:
    """Build a subprocess environment that can import from ``project_root``.

    The current process environment is copied (``os.environ`` itself is never
    modified) and ``project_root`` is placed at the front of ``PYTHONPATH``.

    Args:
        project_root: Directory to put first on ``PYTHONPATH``; converted with ``str()``.

    Returns:
        A new dict of environment variables. ``PYTHONPATH`` is ``project_root``
        alone when the variable was unset or empty, otherwise ``project_root``
        followed by ``os.pathsep`` and the previous value.

    """
    root = str(project_root)
    child_env = os.environ.copy()
    # An empty PYTHONPATH is treated like an unset one so no stray separator is emitted.
    inherited = child_env.get("PYTHONPATH", "")
    child_env["PYTHONPATH"] = os.pathsep.join((root, inherited)) if inherited else root
    return child_env
251+
252+
241253
def get_cross_platform_subprocess_run_args(
242254
cwd: Path | str | None = None,
243255
env: Mapping[str, str] | None = None,

codeflash/discovery/functions_to_optimize.py

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -78,10 +78,23 @@ def __init__(self, file_path: str) -> None:
7878
self.file_path: str = file_path
7979
self.functions: list[FunctionToOptimize] = []
8080

81+
@staticmethod
def is_pytest_fixture(node: cst.FunctionDef) -> bool:
    """Report whether ``node`` carries a decorator that looks like a pytest fixture.

    A decorator matches when, after stripping an optional call
    (``@fixture(...)`` -> ``fixture``), it is either the bare name ``fixture``
    or the attribute access ``pytest.fixture``. The check is purely syntactic:
    imports and aliases are not resolved.

    Args:
        node: The libcst function definition whose decorators are inspected.

    Returns:
        True if any decorator matches one of the forms above, otherwise False.

    """
    for wrapper in node.decorators:
        target = wrapper.decorator
        # Look through `@name(...)` to the expression being called.
        if isinstance(target, cst.Call):
            target = target.func

        # `@fixture` / `@fixture(...)`
        if isinstance(target, cst.Name):
            if target.value == "fixture":
                return True
            continue

        # `@pytest.fixture` / `@pytest.fixture(...)`; other owners (e.g. `foo.fixture`) are ignored.
        if (
            isinstance(target, cst.Attribute)
            and target.attr.value == "fixture"
            and isinstance(target.value, cst.Name)
            and target.value.value == "pytest"
        ):
            return True
    return False
93+
8194
def visit_FunctionDef(self, node: cst.FunctionDef) -> None:
8295
return_visitor: ReturnStatementVisitor = ReturnStatementVisitor()
8396
node.visit(return_visitor)
84-
if return_visitor.has_return_statement:
97+
if return_visitor.has_return_statement and not self.is_pytest_fixture(node):
8598
pos: CodeRange = self.get_metadata(cst.metadata.PositionProvider, node)
8699
parents: CSTNode | None = self.get_metadata(cst.metadata.ParentNodeProvider, node)
87100
ast_parents: list[FunctionParent] = []
@@ -108,14 +121,12 @@ def __init__(self, file_path: Path) -> None:
108121
self.file_path: Path = file_path
109122

110123
def visit_FunctionDef(self, node: FunctionDef) -> None:
111-
# Check if the function has a return statement and add it to the list
112124
if function_has_return_statement(node) and not function_is_a_property(node):
113125
self.functions.append(
114126
FunctionToOptimize(function_name=node.name, file_path=self.file_path, parents=self.ast_path[:])
115127
)
116128

117129
def visit_AsyncFunctionDef(self, node: AsyncFunctionDef) -> None:
118-
# Check if the async function has a return statement and add it to the list
119130
if function_has_return_statement(node) and not function_is_a_property(node):
120131
self.functions.append(
121132
FunctionToOptimize(
@@ -831,22 +842,17 @@ def filter_functions(
831842
test_dir_patterns = (os.sep + "test" + os.sep, os.sep + "tests" + os.sep, os.sep + "__tests__" + os.sep)
832843

833844
def is_test_file(file_path_normalized: str) -> bool:
834-
"""Check if a file is a test file based on patterns."""
835845
if tests_root_overlaps_source:
836-
# Use file pattern matching when tests_root overlaps with source
837846
file_lower = file_path_normalized.lower()
838-
# Check filename patterns (e.g., .test.ts, .spec.ts)
847+
basename = Path(file_lower).name
848+
if basename.startswith("test_") or basename == "conftest.py":
849+
return True
839850
if any(pattern in file_lower for pattern in test_file_name_patterns):
840851
return True
841-
# Check directory patterns, but only within the project root
842-
# to avoid false positives from parent directories (e.g., project at /home/user/tests/myproject)
843852
if project_root_str and file_lower.startswith(project_root_str.lower()):
844853
relative_path = file_lower[len(project_root_str) :]
845854
return any(pattern in relative_path for pattern in test_dir_patterns)
846-
# If we can't compute relative path from project root, don't check directory patterns
847-
# This avoids false positives when project is inside a folder named "tests"
848855
return False
849-
# Use directory-based filtering when tests are in a separate directory
850856
return file_path_normalized.startswith(tests_root_str + os.sep)
851857

852858
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching

codeflash/optimization/function_optimizer.py

Lines changed: 2 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
from typing import TYPE_CHECKING, Callable
1414

1515
import libcst as cst
16-
import sentry_sdk
1716
from rich.console import Group
1817
from rich.panel import Panel
1918
from rich.syntax import Syntax
@@ -70,6 +69,7 @@
7069
from codeflash.code_utils.git_utils import git_root_dir
7170
from codeflash.code_utils.instrument_existing_tests import inject_profiling_into_existing_test
7271
from codeflash.code_utils.line_profile_utils import add_decorator_imports, contains_jit_decorator
72+
from codeflash.code_utils.shell_utils import make_env_with_project_root
7373
from codeflash.code_utils.static_analysis import get_first_top_level_function_or_method_ast
7474
from codeflash.code_utils.time_utils import humanize_runtime
7575
from codeflash.context import code_context_extractor
@@ -686,30 +686,6 @@ def optimize_function(self) -> Result[BestOptimization, str]:
686686
):
687687
console.rule()
688688
new_code_context = code_context
689-
if (
690-
self.is_numerical_code and not self.args.no_jit_opts
691-
): # if the code is numerical in nature (uses numpy/tensorflow/math/pytorch/jax)
692-
jit_compiled_opt_candidate = self.aiservice_client.get_jit_rewritten_code(
693-
code_context.read_writable_code.markdown, self.function_trace_id
694-
)
695-
if jit_compiled_opt_candidate: # jit rewrite was successful
696-
# write files
697-
# Try to replace function with optimized code
698-
self.replace_function_and_helpers_with_optimized_code(
699-
code_context=code_context,
700-
optimized_code=jit_compiled_opt_candidate[0].source_code,
701-
original_helper_code=original_helper_code,
702-
)
703-
# get code context
704-
try:
705-
new_code_context = self.get_code_optimization_context().unwrap()
706-
except Exception as e:
707-
sentry_sdk.capture_exception(e)
708-
logger.debug("!lsp|Getting new code context failed, revert to original one")
709-
# unwrite files
710-
self.write_code_and_helpers(
711-
self.function_to_optimize_source_code, original_helper_code, self.function_to_optimize.file_path
712-
)
713689
# Generate tests and optimizations in parallel
714690
future_tests = self.executor.submit(self.generate_and_instrument_tests, new_code_context)
715691
future_optimizations = self.executor.submit(
@@ -2858,14 +2834,10 @@ def cleanup_generated_files(self) -> None:
28582834
def get_test_env(
28592835
self, codeflash_loop_index: int, codeflash_test_iteration: int, codeflash_tracer_disable: int = 1
28602836
) -> dict:
2861-
test_env = os.environ.copy()
2837+
test_env = make_env_with_project_root(self.args.project_root)
28622838
test_env["CODEFLASH_TEST_ITERATION"] = str(codeflash_test_iteration)
28632839
test_env["CODEFLASH_TRACER_DISABLE"] = str(codeflash_tracer_disable)
28642840
test_env["CODEFLASH_LOOP_INDEX"] = str(codeflash_loop_index)
2865-
if "PYTHONPATH" not in test_env:
2866-
test_env["PYTHONPATH"] = str(self.args.project_root)
2867-
else:
2868-
test_env["PYTHONPATH"] += os.pathsep + str(self.args.project_root)
28692841
return test_env
28702842

28712843
def line_profiler_step(

codeflash/tracer.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from __future__ import annotations
1313

1414
import json
15-
import os
1615
import pickle
1716
import subprocess
1817
import sys
@@ -26,6 +25,7 @@
2625
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
2726
from codeflash.code_utils.config_consts import EffortLevel
2827
from codeflash.code_utils.config_parser import parse_config_file
28+
from codeflash.code_utils.shell_utils import make_env_with_project_root
2929
from codeflash.tracing.pytest_parallelization import pytest_split
3030

3131
if TYPE_CHECKING:
@@ -131,13 +131,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
131131
else:
132132
updated_sys_argv.append(elem)
133133
args_dict["command"] = " ".join(updated_sys_argv)
134-
env = os.environ.copy()
135-
pythonpath = env.get("PYTHONPATH", "")
136-
project_root_str = str(project_root)
137-
if pythonpath:
138-
env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
139-
else:
140-
env["PYTHONPATH"] = project_root_str
134+
env = make_env_with_project_root(project_root)
141135
# Disable JIT compilation to ensure tracing captures all function calls
142136
env["NUMBA_DISABLE_JIT"] = str(1)
143137
env["TORCHDYNAMO_DISABLE"] = str(1)
@@ -174,14 +168,7 @@ def main(args: Namespace | None = None) -> ArgumentParser:
174168
args_dict["result_pickle_file_path"] = str(result_pickle_file_path)
175169
args_dict["command"] = " ".join(sys.argv)
176170

177-
env = os.environ.copy()
178-
# Add project root to PYTHONPATH so imports work correctly
179-
pythonpath = env.get("PYTHONPATH", "")
180-
project_root_str = str(project_root)
181-
if pythonpath:
182-
env["PYTHONPATH"] = f"{project_root_str}{os.pathsep}{pythonpath}"
183-
else:
184-
env["PYTHONPATH"] = project_root_str
171+
env = make_env_with_project_root(project_root)
185172
# Disable JIT compilation to ensure tracing captures all function calls
186173
env["NUMBA_DISABLE_JIT"] = str(1)
187174
env["TORCHDYNAMO_DISABLE"] = str(1)

codeflash/verification/concolic_testing.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from codeflash.cli_cmds.console import console, logger
1111
from codeflash.code_utils.compat import SAFE_SYS_EXECUTABLE
1212
from codeflash.code_utils.concolic_utils import clean_concolic_tests, is_valid_concolic_test
13+
from codeflash.code_utils.shell_utils import make_env_with_project_root
1314
from codeflash.code_utils.static_analysis import has_typed_parameters
1415
from codeflash.discovery.discover_unit_tests import discover_unit_tests
1516
from codeflash.languages import is_python
@@ -63,6 +64,7 @@ def generate_concolic_tests(
6364
logger.info("Generating concolic opcode coverage tests for the original code…")
6465
console.rule()
6566
try:
67+
env = make_env_with_project_root(args.project_root)
6668
cover_result = subprocess.run(
6769
[
6870
SAFE_SYS_EXECUTABLE,
@@ -86,6 +88,7 @@ def generate_concolic_tests(
8688
cwd=args.project_root,
8789
check=False,
8890
timeout=600,
91+
env=env,
8992
)
9093
except subprocess.TimeoutExpired:
9194
logger.debug("CrossHair Cover test generation timed out")

tests/test_function_discovery.py

Lines changed: 128 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1149,4 +1149,131 @@ def test_is_object_empty():
11491149
)
11501150

11511151
# Strict check: exactly 2 functions
1152-
assert count == 2, f"Expected exactly 2 functions, got {count}"
1152+
assert count == 2, f"Expected exactly 2 functions, got {count}"
1153+
1154+
1155+
def test_filter_functions_python_test_prefix_convention():
1156+
"""Test that files following Python's test_*.py naming convention are filtered.
1157+
1158+
Python's standard test file naming uses the test_ prefix (e.g., test_utils.py),
1159+
which was previously not caught by the pattern matching in overlapping mode.
1160+
"""
1161+
with tempfile.TemporaryDirectory() as temp_dir_str:
1162+
temp_dir = Path(temp_dir_str)
1163+
1164+
# Source file that should NOT be filtered
1165+
source_file = temp_dir / "utils.py"
1166+
with source_file.open("w") as f:
1167+
f.write("def process(): return 1")
1168+
1169+
# Python test file with test_ prefix - SHOULD be filtered
1170+
test_prefix_file = temp_dir / "test_utils.py"
1171+
with test_prefix_file.open("w") as f:
1172+
f.write("def test_process(): return 1")
1173+
1174+
# conftest.py - SHOULD be filtered
1175+
conftest_file = temp_dir / "conftest.py"
1176+
with conftest_file.open("w") as f:
1177+
f.write("""
1178+
import pytest
1179+
1180+
@pytest.fixture
1181+
def sample_data():
1182+
return [1, 2, 3]
1183+
""")
1184+
1185+
# File in a test_ prefixed directory - should NOT be filtered by file patterns
1186+
# (directory patterns don't cover test_ prefix dirs, which is fine)
1187+
test_subdir = temp_dir / "test_integration"
1188+
test_subdir.mkdir()
1189+
file_in_test_dir = test_subdir / "helpers.py"
1190+
with file_in_test_dir.open("w") as f:
1191+
f.write("def helper(): return 1")
1192+
1193+
# test_ prefix file inside a subdirectory - SHOULD be filtered
1194+
test_in_subdir = test_subdir / "test_helpers.py"
1195+
with test_in_subdir.open("w") as f:
1196+
f.write("def test_helper(): return 1")
1197+
1198+
all_functions = {}
1199+
for file_path in [source_file, test_prefix_file, conftest_file, file_in_test_dir, test_in_subdir]:
1200+
discovered = find_all_functions_in_file(file_path)
1201+
all_functions.update(discovered)
1202+
1203+
with unittest.mock.patch(
1204+
"codeflash.discovery.functions_to_optimize.get_blocklisted_functions", return_value={}
1205+
):
1206+
filtered, count = filter_functions(
1207+
all_functions,
1208+
tests_root=temp_dir, # Overlapping case
1209+
ignore_paths=[],
1210+
project_root=temp_dir,
1211+
module_root=temp_dir,
1212+
)
1213+
1214+
# source_file and file_in_test_dir should remain
1215+
# test_prefix_file, conftest_file, and test_in_subdir should be filtered
1216+
expected_files = {source_file, file_in_test_dir}
1217+
assert set(filtered.keys()) == expected_files, (
1218+
f"Expected {expected_files}, got {set(filtered.keys())}"
1219+
)
1220+
assert count == 2, f"Expected exactly 2 functions, got {count}"
1221+
1222+
1223+
def test_pytest_fixture_not_discovered():
1224+
"""Test that @pytest.fixture decorated functions are not discovered via libcst path."""
1225+
from codeflash.languages.python.support import PythonSupport
1226+
1227+
with tempfile.TemporaryDirectory() as temp_dir_str:
1228+
temp_dir = Path(temp_dir_str)
1229+
1230+
fixture_file = temp_dir / "conftest.py"
1231+
with fixture_file.open("w") as f:
1232+
f.write("""
1233+
import pytest
1234+
from pytest import fixture
1235+
1236+
def regular_function():
1237+
return 42
1238+
1239+
@pytest.fixture
1240+
def sample_data():
1241+
return [1, 2, 3]
1242+
1243+
@pytest.fixture()
1244+
def sample_config():
1245+
return {"key": "value"}
1246+
1247+
@fixture
1248+
def direct_import_fixture():
1249+
return "data"
1250+
1251+
@fixture()
1252+
def direct_import_fixture_with_parens():
1253+
return "data"
1254+
1255+
@pytest.fixture(scope="session")
1256+
def session_fixture():
1257+
return "session"
1258+
1259+
class TestHelpers:
1260+
@pytest.fixture
1261+
def class_fixture(self):
1262+
return "class_data"
1263+
1264+
def helper_method(self):
1265+
return "helper"
1266+
""")
1267+
1268+
support = PythonSupport()
1269+
functions = support.discover_functions(fixture_file)
1270+
function_names = [fn.function_name for fn in functions]
1271+
1272+
assert "regular_function" in function_names
1273+
assert "helper_method" in function_names
1274+
assert "sample_data" not in function_names
1275+
assert "sample_config" not in function_names
1276+
assert "direct_import_fixture" not in function_names
1277+
assert "direct_import_fixture_with_parens" not in function_names
1278+
assert "session_fixture" not in function_names
1279+
assert "class_fixture" not in function_names

0 commit comments

Comments
 (0)