Skip to content

Commit 88d6e8b

Browse files
authored
Merge branch 'main' into comparator-nn-module
2 parents 555a2f9 + 97531dc commit 88d6e8b

26 files changed

Lines changed: 1228 additions & 404 deletions

MULTI_LANGUAGE_ARCHITECTURE.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ class JavaScriptTransformer:
386386

387387
from pathlib import Path
388388
from codeflash.languages.base import LanguageSupport, FunctionInfo, CodeContext
389-
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
389+
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
390390
from codeflash.languages.javascript.transformer import JavaScriptTransformer
391391

392392
class JavaScriptSupport(LanguageSupport):
@@ -523,7 +523,7 @@ class JavaScriptSupport(LanguageSupport):
523523
# codeflash/languages/javascript/test_discovery.py
524524

525525
from pathlib import Path
526-
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
526+
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
527527

528528
class JestTestDiscovery:
529529
"""Static analysis-based test discovery for Jest."""

codeflash/code_utils/code_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1772,7 +1772,7 @@ def _extract_calling_function_js(source_code: str, function_name: str, ref_line:
17721772
17731773
"""
17741774
try:
1775-
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
1775+
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
17761776

17771777
# Try TypeScript first, fall back to JavaScript
17781778
for lang in [TreeSitterLanguage.TYPESCRIPT, TreeSitterLanguage.TSX, TreeSitterLanguage.JAVASCRIPT]:

codeflash/code_utils/code_replacer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626

2727
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
2828
from codeflash.languages.base import Language, LanguageSupport
29-
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer
29+
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer
3030
from codeflash.models.models import CodeOptimizationContext, CodeStringsMarkdown, OptimizedCandidate, ValidCode
3131

3232
ASTNodeT = TypeVar("ASTNodeT", bound=ast.AST)
@@ -640,7 +640,7 @@ def _add_global_declarations_for_language(
640640
return original_source
641641

642642
try:
643-
from codeflash.languages.treesitter_utils import get_analyzer_for_file
643+
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
644644

645645
analyzer = get_analyzer_for_file(module_abspath)
646646

codeflash/code_utils/normalizers/javascript.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -233,7 +233,7 @@ def normalize(self, code: str) -> str:
233233
234234
"""
235235
try:
236-
from codeflash.languages.treesitter_utils import TreeSitterAnalyzer, TreeSitterLanguage
236+
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
237237

238238
lang_map = {"javascript": TreeSitterLanguage.JAVASCRIPT, "typescript": TreeSitterLanguage.TYPESCRIPT}
239239
lang = lang_map.get(self._get_tree_sitter_language(), TreeSitterLanguage.JAVASCRIPT)

codeflash/discovery/functions_to_optimize.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _is_js_ts_function_exported(file_path: Path, function_name: str) -> tuple[bo
201201
Tuple of (is_exported, export_name). export_name may be 'default' for default exports.
202202
203203
"""
204-
from codeflash.languages.treesitter_utils import get_analyzer_for_file
204+
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
205205

206206
try:
207207
source = file_path.read_text(encoding="utf-8")

codeflash/languages/javascript/find_references.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from tree_sitter import Node
2424

2525
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
26-
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
26+
from codeflash.languages.javascript.treesitter import ImportInfo, TreeSitterAnalyzer
2727

2828
logger = logging.getLogger(__name__)
2929

@@ -112,7 +112,7 @@ def find_references(
112112
List of Reference objects describing each call site.
113113
114114
"""
115-
from codeflash.languages.treesitter_utils import get_analyzer_for_file
115+
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
116116

117117
function_name = function_to_optimize.function_name
118118
source_file = function_to_optimize.file_path
@@ -168,7 +168,7 @@ def find_references(
168168
if import_info:
169169
# Found an import - mark as visited and search for calls
170170
context.visited_files.add(file_path)
171-
import_name, original_import = import_info
171+
import_name, _original_import = import_info
172172
file_refs = self._find_references_in_file(
173173
file_path, file_code, function_name, import_name, file_analyzer, include_self=True
174174
)
@@ -213,7 +213,7 @@ def find_references(
213213
trigger_check = True
214214
if import_info:
215215
context.visited_files.add(file_path)
216-
import_name, original_import = import_info
216+
import_name, _original_import = import_info
217217
file_refs = self._find_references_in_file(
218218
file_path, file_code, reexport_name, import_name, file_analyzer, include_self=True
219219
)
@@ -404,7 +404,7 @@ def _find_identifier_references(
404404
name_node = node.child_by_field_name("name")
405405
if name_node:
406406
new_current_function = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
407-
elif node.type in ("variable_declarator",):
407+
elif node.type == "variable_declarator":
408408
# Arrow function or function expression assigned to variable
409409
name_node = node.child_by_field_name("name")
410410
value_node = node.child_by_field_name("value")
@@ -719,7 +719,7 @@ def _find_reexports_direct(
719719
continue
720720

721721
# Create a fake ImportInfo to resolve the re-export source
722-
from codeflash.languages.treesitter_utils import ImportInfo
722+
from codeflash.languages.javascript.treesitter import ImportInfo
723723

724724
fake_import = ImportInfo(
725725
module_path=exp.reexport_source,

codeflash/languages/javascript/import_resolver.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
if TYPE_CHECKING:
1515
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
1616
from codeflash.languages.base import HelperFunction
17-
from codeflash.languages.treesitter_utils import ImportInfo, TreeSitterAnalyzer
17+
from codeflash.languages.javascript.treesitter import ImportInfo, TreeSitterAnalyzer
1818

1919
logger = logging.getLogger(__name__)
2020

@@ -486,7 +486,7 @@ def _extract_helper_from_file(
486486
487487
"""
488488
from codeflash.languages.base import HelperFunction
489-
from codeflash.languages.treesitter_utils import get_analyzer_for_file
489+
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
490490

491491
try:
492492
source = file_path.read_text(encoding="utf-8")
@@ -558,8 +558,8 @@ def _find_helpers_recursive(
558558
559559
"""
560560
from codeflash.discovery.functions_to_optimize import FunctionToOptimize
561+
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
561562
from codeflash.languages.registry import get_language_support
562-
from codeflash.languages.treesitter_utils import get_analyzer_for_file
563563

564564
if context.current_depth >= context.max_depth:
565565
return {}

codeflash/languages/javascript/instrument.py

Lines changed: 110 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -792,7 +792,7 @@ def validate_and_fix_import_style(test_code: str, source_file_path: Path, functi
792792
Fixed test code with correct import style.
793793
794794
"""
795-
from codeflash.languages.treesitter_utils import get_analyzer_for_file
795+
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
796796

797797
# Read source file to determine export style
798798
try:
@@ -901,6 +901,115 @@ def is_relevant_import(module_path: str) -> bool:
901901
return test_code
902902

903903

904+
def fix_import_path_for_test_location(
905+
test_code: str, source_file_path: Path, test_file_path: Path, module_root: Path
906+
) -> str:
907+
"""Fix import paths in generated test code to be relative to test file location.
908+
909+
The AI may generate tests with import paths that are relative to the module root
910+
(e.g., 'apps/web/app/file') instead of relative to where the test file is located
911+
(e.g., '../../app/file'). This function fixes such imports.
912+
913+
Args:
914+
test_code: The generated test code.
915+
source_file_path: Absolute path to the source file being tested.
916+
test_file_path: Absolute path to where the test file will be written.
917+
module_root: Root directory of the module/project.
918+
919+
Returns:
920+
Test code with corrected import paths.
921+
922+
"""
923+
import os
924+
925+
# Calculate the correct relative import path from test file to source file
926+
test_dir = test_file_path.parent
927+
try:
928+
correct_rel_path = os.path.relpath(source_file_path, test_dir)
929+
correct_rel_path = correct_rel_path.replace("\\", "/")
930+
# Remove file extension for JS/TS imports
931+
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
932+
if correct_rel_path.endswith(ext):
933+
correct_rel_path = correct_rel_path[: -len(ext)]
934+
break
935+
# Ensure it starts with ./ or ../
936+
if not correct_rel_path.startswith("."):
937+
correct_rel_path = "./" + correct_rel_path
938+
except ValueError:
939+
# Can't compute relative path (different drives on Windows)
940+
return test_code
941+
942+
# Try to compute what incorrect path the AI might have generated
943+
# The AI often uses module_root-relative paths like 'apps/web/app/...'
944+
try:
945+
source_rel_to_module = os.path.relpath(source_file_path, module_root)
946+
source_rel_to_module = source_rel_to_module.replace("\\", "/")
947+
# Remove extension
948+
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
949+
if source_rel_to_module.endswith(ext):
950+
source_rel_to_module = source_rel_to_module[: -len(ext)]
951+
break
952+
except ValueError:
953+
return test_code
954+
955+
# Also check for project root-relative paths (including module_root in path)
956+
try:
957+
project_root = module_root.parent if module_root.name in ["src", "lib", "app", "web", "apps"] else module_root
958+
source_rel_to_project = os.path.relpath(source_file_path, project_root)
959+
source_rel_to_project = source_rel_to_project.replace("\\", "/")
960+
for ext in [".tsx", ".ts", ".jsx", ".js", ".mjs", ".cjs"]:
961+
if source_rel_to_project.endswith(ext):
962+
source_rel_to_project = source_rel_to_project[: -len(ext)]
963+
break
964+
except ValueError:
965+
source_rel_to_project = None
966+
967+
# Source file name (for matching module paths that end with the file name)
968+
source_name = source_file_path.stem
969+
970+
# Patterns to find import statements
971+
# ESM: import { func } from 'path' or import func from 'path'
972+
esm_import_pattern = re.compile(r"(import\s+(?:{[^}]+}|\w+)\s+from\s+['\"])([^'\"]+)(['\"])")
973+
# CommonJS: const { func } = require('path') or const func = require('path')
974+
cjs_require_pattern = re.compile(
975+
r"((?:const|let|var)\s+(?:{[^}]+}|\w+)\s*=\s*require\s*\(\s*['\"])([^'\"]+)(['\"])"
976+
)
977+
978+
def should_fix_path(import_path: str) -> bool:
979+
"""Check if this import path looks like it should point to our source file."""
980+
# Skip relative imports that already look correct
981+
if import_path.startswith(("./", "../")):
982+
return False
983+
# Skip package imports (no path separators or start with @)
984+
if "/" not in import_path and "\\" not in import_path:
985+
return False
986+
if import_path.startswith("@") and "/" in import_path:
987+
# Could be an alias like @/utils - skip these
988+
return False
989+
# Check if it looks like it points to our source file
990+
if import_path == source_rel_to_module:
991+
return True
992+
if source_rel_to_project and import_path == source_rel_to_project:
993+
return True
994+
if import_path.endswith((source_name, "/" + source_name)):
995+
return True
996+
return False
997+
998+
def fix_import(match: re.Match[str]) -> str:
999+
"""Replace incorrect import path with correct relative path."""
1000+
prefix = match.group(1)
1001+
import_path = match.group(2)
1002+
suffix = match.group(3)
1003+
1004+
if should_fix_path(import_path):
1005+
logger.debug(f"Fixing import path: {import_path} -> {correct_rel_path}")
1006+
return f"{prefix}{correct_rel_path}{suffix}"
1007+
return match.group(0)
1008+
1009+
test_code = esm_import_pattern.sub(fix_import, test_code)
1010+
return cjs_require_pattern.sub(fix_import, test_code)
1011+
1012+
9041013
def get_instrumented_test_path(original_path: Path, mode: str) -> Path:
9051014
"""Generate path for instrumented test file.
9061015

codeflash/languages/javascript/line_profiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import logging
1212
from typing import TYPE_CHECKING
1313

14-
from codeflash.languages.treesitter_utils import get_analyzer_for_file
14+
from codeflash.languages.javascript.treesitter import get_analyzer_for_file
1515

1616
if TYPE_CHECKING:
1717
from pathlib import Path

0 commit comments

Comments
 (0)