Skip to content

Commit 677e46e

Browse files
Merge branch 'main' into fix/js-skip-module-conversion-for-ts-jest
2 parents 60ab938 + 3b3adf8 commit 677e46e

3 files changed

Lines changed: 210 additions & 49 deletions

File tree

codeflash/code_utils/code_replacer.py

Lines changed: 131 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -628,19 +628,22 @@ def _add_global_declarations_for_language(
628628
Finds module-level declarations (const, let, var, class, type, interface, enum)
629629
in the optimized code that don't exist in the original source and adds them.
630630
631+
New declarations are inserted after any existing declarations they depend on.
632+
For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
633+
is already declared in the original source, `_has` will be inserted after `FOO`.
634+
631635
Args:
632636
optimized_code: The optimized code that may contain new declarations.
633637
original_source: The original source code.
634638
module_abspath: Path to the module file (for parser selection).
635639
language: The language of the code.
636640
637641
Returns:
638-
Original source with new declarations added after imports.
642+
Original source with new declarations added in dependency order.
639643
640644
"""
641645
from codeflash.languages.base import Language
642646

643-
# Only process JavaScript/TypeScript
644647
if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
645648
return original_source
646649

@@ -649,84 +652,164 @@ def _add_global_declarations_for_language(
649652

650653
analyzer = get_analyzer_for_file(module_abspath)
651654

652-
# Find declarations in both original and optimized code
653655
original_declarations = analyzer.find_module_level_declarations(original_source)
654656
optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
655657

656658
if not optimized_declarations:
657659
return original_source
658660

659-
# Get names of existing declarations
660-
existing_names = {decl.name for decl in original_declarations}
661-
662-
# Also exclude names that are already imported (to avoid duplicating imported types)
663-
original_imports = analyzer.find_imports(original_source)
664-
for imp in original_imports:
665-
# Add default import name
666-
if imp.default_import:
667-
existing_names.add(imp.default_import)
668-
# Add named imports (use alias if present, otherwise use original name)
669-
for name, alias in imp.named_imports:
670-
existing_names.add(alias if alias else name)
671-
# Add namespace import
672-
if imp.namespace_import:
673-
existing_names.add(imp.namespace_import)
674-
675-
# Find new declarations (names that don't exist in original)
676-
new_declarations = []
677-
seen_sources = set() # Track to avoid duplicates from destructuring
678-
for decl in optimized_declarations:
679-
if decl.name not in existing_names and decl.source_code not in seen_sources:
680-
new_declarations.append(decl)
681-
seen_sources.add(decl.source_code)
661+
existing_names = _get_existing_names(original_declarations, analyzer, original_source)
662+
new_declarations = _filter_new_declarations(optimized_declarations, existing_names)
682663

683664
if not new_declarations:
684665
return original_source
685666

686-
# Sort by line number to maintain order
687-
new_declarations.sort(key=lambda d: d.start_line)
688-
689-
# Find insertion point (after imports)
690-
lines = original_source.splitlines(keepends=True)
691-
insertion_line = _find_insertion_line_after_imports_js(lines, analyzer, original_source)
692-
693-
# Build new declarations string
694-
new_decl_code = "\n".join(decl.source_code for decl in new_declarations)
695-
new_decl_code = new_decl_code + "\n\n"
667+
# Build a map of existing declaration names to their end lines (1-indexed)
668+
existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}
696669

697-
# Insert declarations
698-
before = lines[:insertion_line]
699-
after = lines[insertion_line:]
700-
result_lines = [*before, new_decl_code, *after]
670+
# Insert each new declaration after its dependencies
671+
result = original_source
672+
for decl in new_declarations:
673+
result = _insert_declaration_after_dependencies(
674+
result, decl, existing_decl_end_lines, analyzer, module_abspath
675+
)
676+
# Update the map with the newly inserted declaration for subsequent insertions
677+
# Re-parse to get accurate line numbers after insertion
678+
updated_declarations = analyzer.find_module_level_declarations(result)
679+
existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}
701680

702-
return "".join(result_lines)
681+
return result
703682

704683
except Exception as e:
705684
logger.debug(f"Error adding global declarations: {e}")
706685
return original_source
707686

708687

709-
def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
710-
"""Find the line index where new declarations should be inserted (after imports).
688+
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
689+
"""Get all names that already exist in the original source (declarations + imports)."""
690+
existing_names = {decl.name for decl in original_declarations}
691+
692+
original_imports = analyzer.find_imports(original_source)
693+
for imp in original_imports:
694+
if imp.default_import:
695+
existing_names.add(imp.default_import)
696+
for name, alias in imp.named_imports:
697+
existing_names.add(alias if alias else name)
698+
if imp.namespace_import:
699+
existing_names.add(imp.namespace_import)
700+
701+
return existing_names
702+
703+
704+
def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
705+
"""Filter declarations to only those that don't exist in the original source."""
706+
new_declarations = []
707+
seen_sources: set[str] = set()
708+
709+
# Sort by line number to maintain order from optimized code
710+
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)
711+
712+
for decl in sorted_declarations:
713+
if decl.name not in existing_names and decl.source_code not in seen_sources:
714+
new_declarations.append(decl)
715+
seen_sources.add(decl.source_code)
716+
717+
return new_declarations
718+
719+
720+
def _insert_declaration_after_dependencies(
721+
source: str,
722+
declaration,
723+
existing_decl_end_lines: dict[str, int],
724+
analyzer: TreeSitterAnalyzer,
725+
module_abspath: Path,
726+
) -> str:
727+
"""Insert a declaration after the last existing declaration it depends on.
728+
729+
Args:
730+
source: Current source code.
731+
declaration: The declaration to insert.
732+
existing_decl_end_lines: Map of existing declaration names to their end lines.
733+
analyzer: TreeSitter analyzer.
734+
module_abspath: Path to the module file.
735+
736+
Returns:
737+
Source code with the declaration inserted at the correct position.
738+
739+
"""
740+
# Find identifiers referenced in this declaration
741+
referenced_names = analyzer.find_referenced_identifiers(declaration.source_code)
742+
743+
# Find the latest end line among all referenced declarations
744+
insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer)
745+
746+
lines = source.splitlines(keepends=True)
747+
748+
# Ensure proper spacing
749+
decl_code = declaration.source_code
750+
if not decl_code.endswith("\n"):
751+
decl_code += "\n"
752+
753+
# Add blank line before if inserting after content
754+
if insertion_line > 0 and lines[insertion_line - 1].strip():
755+
decl_code = "\n" + decl_code
756+
757+
before = lines[:insertion_line]
758+
after = lines[insertion_line:]
759+
760+
return "".join([*before, decl_code, *after])
761+
762+
763+
def _find_insertion_line_for_declaration(
764+
source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
765+
) -> int:
766+
"""Find the line where a declaration should be inserted based on its dependencies.
767+
768+
Args:
769+
source: Source code.
770+
referenced_names: Names referenced by the declaration.
771+
existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
772+
analyzer: TreeSitter analyzer.
773+
774+
Returns:
775+
Line index (0-based) where the declaration should be inserted.
776+
777+
"""
778+
# Find the maximum end line among referenced declarations
779+
max_dependency_line = 0
780+
for name in referenced_names:
781+
if name in existing_decl_end_lines:
782+
max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name])
783+
784+
if max_dependency_line > 0:
785+
# Insert after the last dependency (end_line is 1-indexed, we need 0-indexed)
786+
return max_dependency_line
787+
788+
# No dependencies found - insert after imports
789+
lines = source.splitlines(keepends=True)
790+
return _find_line_after_imports(lines, analyzer, source)
791+
792+
793+
def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
794+
"""Find the line index after all imports.
711795
712796
Args:
713797
lines: Source lines.
714-
analyzer: TreeSitter analyzer for the file.
798+
analyzer: TreeSitter analyzer.
715799
source: Full source code.
716800
717801
Returns:
718-
Line index (0-based) for insertion.
802+
Line index (0-based) for insertion after imports.
719803
720804
"""
721805
try:
722806
imports = analyzer.find_imports(source)
723807
if imports:
724-
# Find the last import's end line
725808
return max(imp.end_line for imp in imports)
726809
except Exception as exc:
727-
logger.debug(f"Exception occurred in _find_insertion_line_after_imports_js: {exc}")
810+
logger.debug(f"Exception in _find_line_after_imports: {exc}")
728811

729-
# Default: insert at beginning (after any shebang/directive comments)
812+
# Default: insert at beginning (after shebang/directive comments)
730813
for i, line in enumerate(lines):
731814
stripped = line.strip()
732815
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):

codeflash/result/create_pr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ def check_create_pr(
281281
function_trace_id: str,
282282
coverage_message: str,
283283
replay_tests: str,
284-
concolic_tests: str,
285284
root_dir: Path,
285+
concolic_tests: str = "",
286286
git_remote: Optional[str] = None,
287287
optimization_review: str = "",
288288
original_line_profiler: str | None = None,

tests/test_languages/test_js_code_replacer.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2016,6 +2016,84 @@ class DataProcessor<T> {
20162016

20172017

20182018

2019+
class TestNewVariableFromOptimizedCode:
2020+
"""Tests for handling new variables introduced in optimized code."""
2021+
2022+
def test_new_bound_method_variable_added_after_referenced_constant(self, ts_support, temp_project):
2023+
"""Test that a new variable binding a method is added after the constant it references.
2024+
2025+
When optimized code introduces a new module-level variable (like `_has`) that
2026+
references an existing constant (like `CODEFLASH_EMPLOYEE_GITHUB_IDS`), the
2027+
replacement should:
2028+
1. Add the new variable after the constant it references
2029+
2. Replace the function with the optimized version
2030+
"""
2031+
from codeflash.models.models import CodeStringsMarkdown, CodeString
2032+
2033+
original_source = '''\
2034+
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
2035+
"1234",
2036+
]);
2037+
2038+
export function isCodeflashEmployee(userId: string): boolean {
2039+
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
2040+
}
2041+
'''
2042+
file_path = temp_project / "auth.ts"
2043+
file_path.write_text(original_source, encoding="utf-8")
2044+
2045+
# Optimized code introduces a bound method variable for performance
2046+
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
2047+
CODEFLASH_EMPLOYEE_GITHUB_IDS
2048+
);
2049+
2050+
export function isCodeflashEmployee(userId: string): boolean {
2051+
return _has(userId);
2052+
}
2053+
'''
2054+
2055+
code_markdown = CodeStringsMarkdown(
2056+
code_strings=[
2057+
CodeString(
2058+
code=optimized_code,
2059+
file_path=Path("auth.ts"),
2060+
language="typescript"
2061+
)
2062+
],
2063+
language="typescript"
2064+
)
2065+
2066+
replaced = replace_function_definitions_for_language(
2067+
["isCodeflashEmployee"],
2068+
code_markdown,
2069+
file_path,
2070+
temp_project,
2071+
)
2072+
2073+
assert replaced
2074+
result = file_path.read_text()
2075+
2076+
# Expected result for strict equality check
2077+
expected_result = '''\
2078+
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
2079+
"1234",
2080+
]);
2081+
2082+
const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
2083+
CODEFLASH_EMPLOYEE_GITHUB_IDS
2084+
);
2085+
2086+
export function isCodeflashEmployee(userId: string): boolean {
2087+
return _has(userId);
2088+
}
2089+
'''
2090+
assert result == expected_result, (
2091+
f"Result does not match expected output.\n"
2092+
f"Expected:\n{expected_result}\n\n"
2093+
f"Got:\n{result}"
2094+
)
2095+
2096+
20192097
class TestImportedTypeNotDuplicated:
20202098
"""Tests to ensure imported types are not duplicated during code replacement.
20212099

0 commit comments

Comments
 (0)