Skip to content

Commit d0ec9b3

Browse files
Merge branch 'main' into fix/get-language-based-on-formatter
2 parents be02b40 + 4803f26 commit d0ec9b3

7 files changed

Lines changed: 429 additions & 30 deletions

File tree

.github/workflows/claude.yml

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,30 @@ jobs:
1919
(github.event_name == 'issues' && (contains(github.event.issue.body, '@claude') || contains(github.event.issue.title, '@claude')))
2020
runs-on: ubuntu-latest
2121
permissions:
22-
contents: read
23-
pull-requests: read
22+
contents: write
23+
pull-requests: write
2424
issues: read
2525
id-token: write
2626
actions: read # Required for Claude to read CI results on PRs
2727
steps:
28+
- name: Get PR head ref
29+
id: pr-ref
30+
env:
31+
GH_TOKEN: ${{ github.token }}
32+
run: |
33+
# For issue_comment events, we need to fetch the PR info
34+
if [ "${{ github.event_name }}" = "issue_comment" ]; then
35+
PR_REF=$(gh api repos/${{ github.repository }}/pulls/${{ github.event.issue.number }} --jq '.head.ref')
36+
echo "ref=$PR_REF" >> $GITHUB_OUTPUT
37+
else
38+
echo "ref=${{ github.event.pull_request.head.ref || github.head_ref }}" >> $GITHUB_OUTPUT
39+
fi
40+
2841
- name: Checkout repository
2942
uses: actions/checkout@v4
3043
with:
31-
fetch-depth: 1
44+
fetch-depth: 0
45+
ref: ${{ steps.pr-ref.outputs.ref }}
3246

3347
- name: Run Claude Code
3448
id: claude

codeflash/cli_cmds/init_javascript.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
from codeflash.code_utils.git_utils import get_git_remotes
2727
from codeflash.code_utils.shell_utils import get_shell_rc_path, is_powershell
2828
from codeflash.telemetry.posthog_cf import ph
29+
from rich.prompt import Confirm
2930

3031

3132
class ProjectLanguage(Enum):
@@ -208,9 +209,7 @@ def init_js_project(language: ProjectLanguage) -> None:
208209

209210
def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
210211
"""Check if package.json has valid codeflash config for JS/TS projects."""
211-
from rich.prompt import Confirm
212-
213-
package_json_path = Path.cwd() / "package.json"
212+
package_json_path = Path("package.json")
214213

215214
if not package_json_path.exists():
216215
click.echo("❌ No package.json found. Please run 'npm init' first.")
@@ -230,6 +229,10 @@ def should_modify_package_json_config() -> tuple[bool, dict[str, Any] | None]:
230229
if not Path(module_root).is_dir():
231230
return True, None
232231

232+
tests_root = config.get("testsRoot", None)
233+
if tests_root and not Path(tests_root).is_dir():
234+
return True, None
235+
233236
# Config is valid - ask if user wants to reconfigure
234237
return Confirm.ask(
235238
"✅ A valid Codeflash config already exists in package.json. Do you want to re-configure it?",

codeflash/code_utils/config_js.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -251,6 +251,9 @@ def parse_package_json_config(package_json_path: Path) -> tuple[dict[str, Any],
251251
detected_module_root = detect_module_root(project_root, package_data)
252252
config["module_root"] = str((project_root / Path(detected_module_root)).resolve())
253253

254+
if codeflash_config.get("testsRoot"):
255+
config["tests_root"] = str(project_root / Path(codeflash_config["testsRoot"]).resolve())
256+
254257
# Auto-detect test runner
255258
config["test_runner"] = detect_test_runner(project_root, package_data)
256259
# Keep pytest_cmd for backwards compatibility with existing code

codeflash/discovery/functions_to_optimize.py

Lines changed: 53 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@
4141

4242
from codeflash.models.models import CodeOptimizationContext
4343
from codeflash.verification.verification_utils import TestConfig
44+
import contextlib
45+
4446
from rich.text import Text
4547

4648
_property_id = "property"
@@ -595,9 +597,10 @@ def get_all_replay_test_functions(
595597
except Exception as e:
596598
logger.warning(f"Error parsing replay test file {replay_test_file}: {e}")
597599

598-
if not trace_file_path:
600+
if trace_file_path is None:
599601
logger.error("Could not find trace_file_path in replay test files.")
600602
exit_with_message("Could not find trace_file_path in replay test files.")
603+
raise AssertionError("Unreachable") # exit_with_message never returns
601604

602605
if not trace_file_path.exists():
603606
logger.error(f"Trace file not found: {trace_file_path}")
@@ -652,7 +655,7 @@ def get_all_replay_test_functions(
652655
if filtered_list:
653656
filtered_valid_functions[file_path] = filtered_list
654657

655-
return filtered_valid_functions, trace_file_path
658+
return dict(filtered_valid_functions), trace_file_path
656659

657660

658661
def is_git_repo(file_path: str) -> bool:
@@ -664,11 +667,13 @@ def is_git_repo(file_path: str) -> bool:
664667

665668

666669
@cache
667-
def ignored_submodule_paths(module_root: str) -> list[str]:
670+
def ignored_submodule_paths(module_root: str) -> list[Path]:
668671
if is_git_repo(module_root):
669672
git_repo = git.Repo(module_root, search_parent_directories=True)
670673
try:
671-
return [Path(git_repo.working_tree_dir, submodule.path).resolve() for submodule in git_repo.submodules]
674+
working_dir = git_repo.working_tree_dir
675+
if working_dir is not None:
676+
return [Path(working_dir, submodule.path).resolve() for submodule in git_repo.submodules]
672677
except Exception as e:
673678
logger.warning(f"Error getting submodule paths: {e}")
674679
return []
@@ -682,7 +687,7 @@ def __init__(
682687
self.class_name = class_name
683688
self.function_name = function_or_method_name
684689
self.is_top_level = False
685-
self.function_has_args = None
690+
self.function_has_args: bool | None = None
686691
self.line_no = line_no
687692
self.is_staticmethod = False
688693
self.is_classmethod = False
@@ -796,31 +801,28 @@ def was_function_previously_optimized(
796801

797802
# Check optimization status if repository info is provided
798803
# already_optimized_count = 0
799-
try:
804+
805+
# Check optimization status if repository info is provided
806+
# already_optimized_count = 0
807+
owner = None
808+
repo = None
809+
with contextlib.suppress(git.exc.InvalidGitRepositoryError):
800810
owner, repo = get_repo_owner_and_name()
801-
except git.exc.InvalidGitRepositoryError:
802-
logger.warning("No git repository found")
803-
owner, repo = None, None
811+
804812
pr_number = get_pr_number()
805813

806814
if not owner or not repo or pr_number is None or getattr(args, "no_pr", False):
807815
return False
808816

809-
code_contexts = []
810-
811817
func_hash = code_context.hashing_code_context_hash
812-
# Use a unique path identifier that includes function info
813818

814-
code_contexts.append(
819+
code_contexts = [
815820
{
816-
"file_path": function_to_optimize.file_path,
821+
"file_path": str(function_to_optimize.file_path),
817822
"function_name": function_to_optimize.qualified_name,
818823
"code_hash": func_hash,
819824
}
820-
)
821-
822-
if not code_contexts:
823-
return False
825+
]
824826

825827
try:
826828
result = is_function_being_optimized_again(owner, repo, pr_number, code_contexts)
@@ -839,7 +841,7 @@ def filter_functions(
839841
ignore_paths: list[Path],
840842
project_root: Path,
841843
module_root: Path,
842-
previous_checkpoint_functions: dict[Path, dict[str, Any]] | None = None,
844+
previous_checkpoint_functions: dict[str, dict[str, Any]] | None = None,
843845
*,
844846
disable_logs: bool = False,
845847
) -> tuple[dict[Path, list[FunctionToOptimize]], int]:
@@ -864,21 +866,49 @@ def filter_functions(
864866
# Normalize paths for case-insensitive comparison on Windows
865867
tests_root_str = os.path.normcase(str(tests_root))
866868
module_root_str = os.path.normcase(str(module_root))
869+
project_root_str = os.path.normcase(str(project_root))
870+
871+
# Check if tests_root overlaps with module_root or project_root
872+
# In this case, we need to use file pattern matching instead of directory matching
873+
tests_root_overlaps_source = tests_root_str in (module_root_str, project_root_str) or module_root_str.startswith(
874+
tests_root_str + os.sep
875+
)
876+
877+
# Test file patterns for when tests_root overlaps with source
878+
test_file_name_patterns = (".test.", ".spec.", "_test.", "_spec.")
879+
test_dir_patterns = (os.sep + "test" + os.sep, os.sep + "tests" + os.sep, os.sep + "__tests__" + os.sep)
880+
881+
def is_test_file(file_path_normalized: str) -> bool:
882+
"""Check if a file is a test file based on patterns."""
883+
if tests_root_overlaps_source:
884+
# Use file pattern matching when tests_root overlaps with source
885+
file_lower = file_path_normalized.lower()
886+
# Check filename patterns (e.g., .test.ts, .spec.ts)
887+
if any(pattern in file_lower for pattern in test_file_name_patterns):
888+
return True
889+
# Check directory patterns, but only within the project root
890+
# to avoid false positives from parent directories
891+
relative_path = file_lower
892+
if project_root_str and file_lower.startswith(project_root_str.lower()):
893+
relative_path = file_lower[len(project_root_str) :]
894+
return any(pattern in relative_path for pattern in test_dir_patterns)
895+
# Use directory-based filtering when tests are in a separate directory
896+
return file_path_normalized.startswith(tests_root_str + os.sep)
867897

868898
# We desperately need Python 3.10+ only support to make this code readable with structural pattern matching
869899
for file_path_path, functions in modified_functions.items():
870900
_functions = functions
871901
file_path = str(file_path_path)
872902
file_path_normalized = os.path.normcase(file_path)
873-
if file_path_normalized.startswith(tests_root_str + os.sep):
903+
if is_test_file(file_path_normalized):
874904
test_functions_removed_count += len(_functions)
875905
continue
876-
if file_path in ignore_paths or any(
906+
if file_path_path in ignore_paths or any(
877907
file_path_normalized.startswith(os.path.normcase(str(ignore_path)) + os.sep) for ignore_path in ignore_paths
878908
):
879909
ignore_paths_removed_count += 1
880910
continue
881-
if file_path in submodule_paths or any(
911+
if file_path_path in submodule_paths or any(
882912
file_path_normalized.startswith(os.path.normcase(str(submodule_path)) + os.sep)
883913
for submodule_path in submodule_paths
884914
):
@@ -970,7 +1000,7 @@ def filter_files_optimized(file_path: Path, tests_root: Path, ignore_paths: list
9701000

9711001
def function_has_return_statement(function_node: FunctionDef | AsyncFunctionDef) -> bool:
9721002
# Custom DFS, return True as soon as a Return node is found
973-
stack = [function_node]
1003+
stack: list[ast.AST] = [function_node]
9741004
while stack:
9751005
node = stack.pop()
9761006
if isinstance(node, ast.Return):

codeflash/setup/config_schema.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,9 @@ def to_package_json_dict(self) -> dict[str, Any]:
103103
if self.module_root and self.module_root not in (".", "src"):
104104
config["moduleRoot"] = self.module_root
105105

106+
if self.tests_root:
107+
config["testsRoot"] = self.tests_root
108+
106109
# Formatter (only if explicitly set)
107110
if self.formatter_cmds:
108111
config["formatterCmds"] = self.formatter_cmds

packages/codeflash/scripts/postinstall.js

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -115,7 +115,7 @@ function installCodeflash(uvBin) {
115115
try {
116116
// Use uv tool install to install codeflash in an isolated environment
117117
// This avoids conflicts with any existing Python environments
118-
execSync(`"${uvBin}" tool install codeflash --force`, {
118+
execSync(`"${uvBin}" tool install --force --python python3.12 codeflash`, {
119119
stdio: 'inherit',
120120
shell: true,
121121
});

0 commit comments

Comments
 (0)