Skip to content

Commit d693eeb

Browse files
Merge branch 'main' into fix/dont-extract-object-properties
2 parents b2b4cde + 3b3adf8 commit d693eeb

3 files changed

Lines changed: 210 additions & 49 deletions

File tree

codeflash/code_utils/code_replacer.py

Lines changed: 131 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -628,19 +628,22 @@ def _add_global_declarations_for_language(
628628
Finds module-level declarations (const, let, var, class, type, interface, enum)
629629
in the optimized code that don't exist in the original source and adds them.
630630
631+
New declarations are inserted after any existing declarations they depend on.
632+
For example, if optimized code has `const _has = FOO.bar.bind(FOO)`, and `FOO`
633+
is already declared in the original source, `_has` will be inserted after `FOO`.
634+
631635
Args:
632636
optimized_code: The optimized code that may contain new declarations.
633637
original_source: The original source code.
634638
module_abspath: Path to the module file (for parser selection).
635639
language: The language of the code.
636640
637641
Returns:
638-
Original source with new declarations added after imports.
642+
Original source with new declarations added in dependency order.
639643
640644
"""
641645
from codeflash.languages.base import Language
642646

643-
# Only process JavaScript/TypeScript
644647
if language not in (Language.JAVASCRIPT, Language.TYPESCRIPT):
645648
return original_source
646649

@@ -649,84 +652,164 @@ def _add_global_declarations_for_language(
649652

650653
analyzer = get_analyzer_for_file(module_abspath)
651654

652-
# Find declarations in both original and optimized code
653655
original_declarations = analyzer.find_module_level_declarations(original_source)
654656
optimized_declarations = analyzer.find_module_level_declarations(optimized_code)
655657

656658
if not optimized_declarations:
657659
return original_source
658660

659-
# Get names of existing declarations
660-
existing_names = {decl.name for decl in original_declarations}
661-
662-
# Also exclude names that are already imported (to avoid duplicating imported types)
663-
original_imports = analyzer.find_imports(original_source)
664-
for imp in original_imports:
665-
# Add default import name
666-
if imp.default_import:
667-
existing_names.add(imp.default_import)
668-
# Add named imports (use alias if present, otherwise use original name)
669-
for name, alias in imp.named_imports:
670-
existing_names.add(alias if alias else name)
671-
# Add namespace import
672-
if imp.namespace_import:
673-
existing_names.add(imp.namespace_import)
674-
675-
# Find new declarations (names that don't exist in original)
676-
new_declarations = []
677-
seen_sources = set() # Track to avoid duplicates from destructuring
678-
for decl in optimized_declarations:
679-
if decl.name not in existing_names and decl.source_code not in seen_sources:
680-
new_declarations.append(decl)
681-
seen_sources.add(decl.source_code)
661+
existing_names = _get_existing_names(original_declarations, analyzer, original_source)
662+
new_declarations = _filter_new_declarations(optimized_declarations, existing_names)
682663

683664
if not new_declarations:
684665
return original_source
685666

686-
# Sort by line number to maintain order
687-
new_declarations.sort(key=lambda d: d.start_line)
688-
689-
# Find insertion point (after imports)
690-
lines = original_source.splitlines(keepends=True)
691-
insertion_line = _find_insertion_line_after_imports_js(lines, analyzer, original_source)
692-
693-
# Build new declarations string
694-
new_decl_code = "\n".join(decl.source_code for decl in new_declarations)
695-
new_decl_code = new_decl_code + "\n\n"
667+
# Build a map of existing declaration names to their end lines (1-indexed)
668+
existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations}
696669

697-
# Insert declarations
698-
before = lines[:insertion_line]
699-
after = lines[insertion_line:]
700-
result_lines = [*before, new_decl_code, *after]
670+
# Insert each new declaration after its dependencies
671+
result = original_source
672+
for decl in new_declarations:
673+
result = _insert_declaration_after_dependencies(
674+
result, decl, existing_decl_end_lines, analyzer, module_abspath
675+
)
676+
# Update the map with the newly inserted declaration for subsequent insertions
677+
# Re-parse to get accurate line numbers after insertion
678+
updated_declarations = analyzer.find_module_level_declarations(result)
679+
existing_decl_end_lines = {d.name: d.end_line for d in updated_declarations}
701680

702-
return "".join(result_lines)
681+
return result
703682

704683
except Exception as e:
705684
logger.debug(f"Error adding global declarations: {e}")
706685
return original_source
707686

708687

709-
def _find_insertion_line_after_imports_js(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
710-
"""Find the line index where new declarations should be inserted (after imports).
688+
def _get_existing_names(original_declarations: list, analyzer: TreeSitterAnalyzer, original_source: str) -> set[str]:
689+
"""Get all names that already exist in the original source (declarations + imports)."""
690+
existing_names = {decl.name for decl in original_declarations}
691+
692+
original_imports = analyzer.find_imports(original_source)
693+
for imp in original_imports:
694+
if imp.default_import:
695+
existing_names.add(imp.default_import)
696+
for name, alias in imp.named_imports:
697+
existing_names.add(alias if alias else name)
698+
if imp.namespace_import:
699+
existing_names.add(imp.namespace_import)
700+
701+
return existing_names
702+
703+
704+
def _filter_new_declarations(optimized_declarations: list, existing_names: set[str]) -> list:
705+
"""Filter declarations to only those that don't exist in the original source."""
706+
new_declarations = []
707+
seen_sources: set[str] = set()
708+
709+
# Sort by line number to maintain order from optimized code
710+
sorted_declarations = sorted(optimized_declarations, key=lambda d: d.start_line)
711+
712+
for decl in sorted_declarations:
713+
if decl.name not in existing_names and decl.source_code not in seen_sources:
714+
new_declarations.append(decl)
715+
seen_sources.add(decl.source_code)
716+
717+
return new_declarations
718+
719+
720+
def _insert_declaration_after_dependencies(
721+
source: str,
722+
declaration,
723+
existing_decl_end_lines: dict[str, int],
724+
analyzer: TreeSitterAnalyzer,
725+
module_abspath: Path,
726+
) -> str:
727+
"""Insert a declaration after the last existing declaration it depends on.
728+
729+
Args:
730+
source: Current source code.
731+
declaration: The declaration to insert.
732+
existing_decl_end_lines: Map of existing declaration names to their end lines.
733+
analyzer: TreeSitter analyzer.
734+
module_abspath: Path to the module file.
735+
736+
Returns:
737+
Source code with the declaration inserted at the correct position.
738+
739+
"""
740+
# Find identifiers referenced in this declaration
741+
referenced_names = analyzer.find_referenced_identifiers(declaration.source_code)
742+
743+
# Find the latest end line among all referenced declarations
744+
insertion_line = _find_insertion_line_for_declaration(source, referenced_names, existing_decl_end_lines, analyzer)
745+
746+
lines = source.splitlines(keepends=True)
747+
748+
# Ensure proper spacing
749+
decl_code = declaration.source_code
750+
if not decl_code.endswith("\n"):
751+
decl_code += "\n"
752+
753+
# Add blank line before if inserting after content
754+
if insertion_line > 0 and lines[insertion_line - 1].strip():
755+
decl_code = "\n" + decl_code
756+
757+
before = lines[:insertion_line]
758+
after = lines[insertion_line:]
759+
760+
return "".join([*before, decl_code, *after])
761+
762+
763+
def _find_insertion_line_for_declaration(
764+
source: str, referenced_names: set[str], existing_decl_end_lines: dict[str, int], analyzer: TreeSitterAnalyzer
765+
) -> int:
766+
"""Find the line where a declaration should be inserted based on its dependencies.
767+
768+
Args:
769+
source: Source code.
770+
referenced_names: Names referenced by the declaration.
771+
existing_decl_end_lines: Map of declaration names to their end lines (1-indexed).
772+
analyzer: TreeSitter analyzer.
773+
774+
Returns:
775+
Line index (0-based) where the declaration should be inserted.
776+
777+
"""
778+
# Find the maximum end line among referenced declarations
779+
max_dependency_line = 0
780+
for name in referenced_names:
781+
if name in existing_decl_end_lines:
782+
max_dependency_line = max(max_dependency_line, existing_decl_end_lines[name])
783+
784+
if max_dependency_line > 0:
785+
# Insert after the last dependency (end_line is 1-indexed, we need 0-indexed)
786+
return max_dependency_line
787+
788+
# No dependencies found - insert after imports
789+
lines = source.splitlines(keepends=True)
790+
return _find_line_after_imports(lines, analyzer, source)
791+
792+
793+
def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, source: str) -> int:
794+
"""Find the line index after all imports.
711795
712796
Args:
713797
lines: Source lines.
714-
analyzer: TreeSitter analyzer for the file.
798+
analyzer: TreeSitter analyzer.
715799
source: Full source code.
716800
717801
Returns:
718-
Line index (0-based) for insertion.
802+
Line index (0-based) for insertion after imports.
719803
720804
"""
721805
try:
722806
imports = analyzer.find_imports(source)
723807
if imports:
724-
# Find the last import's end line
725808
return max(imp.end_line for imp in imports)
726809
except Exception as exc:
727-
logger.debug(f"Exception occurred in _find_insertion_line_after_imports_js: {exc}")
810+
logger.debug(f"Exception in _find_line_after_imports: {exc}")
728811

729-
# Default: insert at beginning (after any shebang/directive comments)
812+
# Default: insert at beginning (after shebang/directive comments)
730813
for i, line in enumerate(lines):
731814
stripped = line.strip()
732815
if stripped and not stripped.startswith("//") and not stripped.startswith("#!"):

codeflash/result/create_pr.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,8 +281,8 @@ def check_create_pr(
281281
function_trace_id: str,
282282
coverage_message: str,
283283
replay_tests: str,
284-
concolic_tests: str,
285284
root_dir: Path,
285+
concolic_tests: str = "",
286286
git_remote: Optional[str] = None,
287287
optimization_review: str = "",
288288
original_line_profiler: str | None = None,

tests/test_languages/test_js_code_replacer.py

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1895,6 +1895,84 @@ class DataProcessor<T> {
18951895

18961896

18971897

1898+
class TestNewVariableFromOptimizedCode:
1899+
"""Tests for handling new variables introduced in optimized code."""
1900+
1901+
def test_new_bound_method_variable_added_after_referenced_constant(self, ts_support, temp_project):
1902+
"""Test that a new variable binding a method is added after the constant it references.
1903+
1904+
When optimized code introduces a new module-level variable (like `_has`) that
1905+
references an existing constant (like `CODEFLASH_EMPLOYEE_GITHUB_IDS`), the
1906+
replacement should:
1907+
1. Add the new variable after the constant it references
1908+
2. Replace the function with the optimized version
1909+
"""
1910+
from codeflash.models.models import CodeStringsMarkdown, CodeString
1911+
1912+
original_source = '''\
1913+
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
1914+
"1234",
1915+
]);
1916+
1917+
export function isCodeflashEmployee(userId: string): boolean {
1918+
return CODEFLASH_EMPLOYEE_GITHUB_IDS.has(userId);
1919+
}
1920+
'''
1921+
file_path = temp_project / "auth.ts"
1922+
file_path.write_text(original_source, encoding="utf-8")
1923+
1924+
# Optimized code introduces a bound method variable for performance
1925+
optimized_code = '''const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
1926+
CODEFLASH_EMPLOYEE_GITHUB_IDS
1927+
);
1928+
1929+
export function isCodeflashEmployee(userId: string): boolean {
1930+
return _has(userId);
1931+
}
1932+
'''
1933+
1934+
code_markdown = CodeStringsMarkdown(
1935+
code_strings=[
1936+
CodeString(
1937+
code=optimized_code,
1938+
file_path=Path("auth.ts"),
1939+
language="typescript"
1940+
)
1941+
],
1942+
language="typescript"
1943+
)
1944+
1945+
replaced = replace_function_definitions_for_language(
1946+
["isCodeflashEmployee"],
1947+
code_markdown,
1948+
file_path,
1949+
temp_project,
1950+
)
1951+
1952+
assert replaced
1953+
result = file_path.read_text()
1954+
1955+
# Expected result for strict equality check
1956+
expected_result = '''\
1957+
const CODEFLASH_EMPLOYEE_GITHUB_IDS = new Set([
1958+
"1234",
1959+
]);
1960+
1961+
const _has: (id: string) => boolean = CODEFLASH_EMPLOYEE_GITHUB_IDS.has.bind(
1962+
CODEFLASH_EMPLOYEE_GITHUB_IDS
1963+
);
1964+
1965+
export function isCodeflashEmployee(userId: string): boolean {
1966+
return _has(userId);
1967+
}
1968+
'''
1969+
assert result == expected_result, (
1970+
f"Result does not match expected output.\n"
1971+
f"Expected:\n{expected_result}\n\n"
1972+
f"Got:\n{result}"
1973+
)
1974+
1975+
18981976
class TestImportedTypeNotDuplicated:
18991977
"""Tests to ensure imported types are not duplicated during code replacement.
19001978

0 commit comments

Comments
 (0)