|
2 | 2 |
|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | | -import re |
6 | 5 | from typing import TYPE_CHECKING |
7 | 6 |
|
8 | 7 | from codeflash.cli_cmds.console import logger |
@@ -65,7 +64,6 @@ def _add_global_declarations_for_language( |
65 | 64 | # Build a map of existing declaration names to their end lines (1-indexed) |
66 | 65 | existing_decl_end_lines = {decl.name: decl.end_line for decl in original_declarations} |
67 | 66 |
|
68 | | - # Insert each new declaration after its dependencies |
69 | 67 | for decl in new_declarations: |
70 | 68 | result = _insert_declaration_after_dependencies( |
71 | 69 | result, decl, existing_decl_end_lines, analyzer, module_abspath |
@@ -220,63 +218,76 @@ def _find_line_after_imports(lines: list[str], analyzer: TreeSitterAnalyzer, sou |
220 | 218 | return 0 |
221 | 219 |
|
222 | 220 |
|
223 | | -def _merge_imports(original_source: str, optimized_code: str, analyzer: TreeSitterAnalyzer) -> str: |
224 | | - """Merge imports from optimized code into original source. |
| 221 | +def _merge_imports(source: str, new_source: str, analyzer: TreeSitterAnalyzer) -> str: |
| 222 | + """Merge imports from new_source into source. |
225 | 223 |
|
226 | | - For each import in the optimized code that shares a module path with an existing |
227 | | - import in the original source, adds any new named imports to the original import line. |
| 224 | + For imports from the same module, merges named imports so that any new |
| 225 | + named imports from new_source are added to the existing import in source. |
| 226 | + Also merges default imports and namespace imports when the source import |
| 227 | + is missing them. |
228 | 228 | """ |
229 | 229 | try: |
230 | | - original_imports = analyzer.find_imports(original_source) |
231 | | - optimized_imports = analyzer.find_imports(optimized_code) |
| 230 | + source_imports = analyzer.find_imports(source) |
| 231 | + new_imports = analyzer.find_imports(new_source) |
232 | 232 | except Exception: |
233 | | - return original_source |
| 233 | + return source |
234 | 234 |
|
235 | | - if not optimized_imports: |
236 | | - return original_source |
| 235 | + if not new_imports: |
| 236 | + return source |
237 | 237 |
|
238 | | - # Build a map of module_path -> ImportInfo for original imports |
239 | | - original_import_map: dict[str, list] = {} |
240 | | - for imp in original_imports: |
241 | | - original_import_map.setdefault(imp.module_path, []).append(imp) |
| 238 | + source_import_map: dict[str, list] = {} |
| 239 | + for imp in source_imports: |
| 240 | + source_import_map.setdefault(imp.module_path, []).append(imp) |
242 | 241 |
|
243 | | - result = original_source |
244 | | - for opt_imp in optimized_imports: |
245 | | - if opt_imp.module_path not in original_import_map: |
| 242 | + lines = source.splitlines(keepends=True) |
| 243 | + |
| 244 | + replacements: list[tuple[int, int, str]] = [] |
| 245 | + for new_imp in new_imports: |
| 246 | + matching = source_import_map.get(new_imp.module_path) |
| 247 | + if not matching: |
246 | 248 | continue |
247 | 249 |
|
248 | | - # Get new named imports that don't exist in the original |
249 | | - for orig_imp in original_import_map[opt_imp.module_path]: |
250 | | - existing_names = {name for name, _ in orig_imp.named_imports} |
251 | | - new_names = [(name, alias) for name, alias in opt_imp.named_imports if name not in existing_names] |
| 250 | + for src_imp in matching: |
| 251 | + existing_names = {name for name, _ in src_imp.named_imports} |
| 252 | + new_names = [(name, alias) for name, alias in new_imp.named_imports if name not in existing_names] |
252 | 253 |
|
253 | | - if not new_names: |
| 254 | + new_default = new_imp.default_import if not src_imp.default_import and new_imp.default_import else None |
| 255 | + new_namespace = ( |
| 256 | + new_imp.namespace_import if not src_imp.namespace_import and new_imp.namespace_import else None |
| 257 | + ) |
| 258 | + |
| 259 | + if not new_names and not new_default and not new_namespace: |
254 | 260 | continue |
255 | 261 |
|
256 | | - # Find the original import line and add new named imports |
257 | | - lines = result.splitlines(keepends=True) |
258 | | - if orig_imp.start_line <= len(lines): |
259 | | - # Reconstruct the import statement lines |
260 | | - import_text = "".join(lines[orig_imp.start_line - 1 : orig_imp.end_line]) |
261 | | - |
262 | | - # Find the closing brace of named imports and insert new names before it |
263 | | - brace_match = re.search(r"\}", import_text) |
264 | | - if brace_match: |
265 | | - insert_pos = brace_match.start() |
266 | | - new_imports_str = ", ".join( |
267 | | - f"{name} as {alias}" if alias else name for name, alias in new_names |
268 | | - ) |
269 | | - # Check if there's already content before the brace |
270 | | - before_brace = import_text[:insert_pos].rstrip() |
271 | | - if before_brace and not before_brace.endswith(","): |
272 | | - new_imports_str = ", " + new_imports_str |
273 | | - else: |
274 | | - new_imports_str = " " + new_imports_str |
275 | | - |
276 | | - updated_import = import_text[:insert_pos] + new_imports_str + " " + import_text[insert_pos:] |
277 | | - lines[orig_imp.start_line - 1 : orig_imp.end_line] = [updated_import] |
278 | | - result = "".join(lines) |
279 | | - |
280 | | - logger.debug(f"Merged imports for {opt_imp.module_path}: added {[n for n, _ in new_names]}") |
281 | | - |
282 | | - return result |
| 262 | + merged_named = list(src_imp.named_imports) + new_names |
| 263 | + default_part = new_default or src_imp.default_import |
| 264 | + namespace_part = new_namespace or src_imp.namespace_import |
| 265 | + |
| 266 | + parts = [] |
| 267 | + if default_part: |
| 268 | + parts.append(default_part) |
| 269 | + if namespace_part: |
| 270 | + parts.append(f"* as {namespace_part}") |
| 271 | + if merged_named: |
| 272 | + named_str = ", ".join(f"{name} as {alias}" if alias else name for name, alias in merged_named) |
| 273 | + parts.append("{ " + named_str + " }") |
| 274 | + |
| 275 | + orig_line_idx = src_imp.start_line - 1 |
| 276 | + orig_line = lines[orig_line_idx] if orig_line_idx < len(lines) else "" |
| 277 | + quote = "'" if "'" in orig_line else '"' |
| 278 | + semicolon = ";" if orig_line.rstrip().endswith(";") else "" |
| 279 | + type_prefix = "type " if src_imp.is_type_only else "" |
| 280 | + |
| 281 | + merged_line = ( |
| 282 | + f"import {type_prefix}{', '.join(parts)} from {quote}{src_imp.module_path}{quote}{semicolon}\n" |
| 283 | + ) |
| 284 | + replacements.append((src_imp.start_line, src_imp.end_line, merged_line)) |
| 285 | + |
| 286 | + if not replacements: |
| 287 | + return source |
| 288 | + |
| 289 | + replacements.sort(key=lambda r: r[0], reverse=True) |
| 290 | + for start_line, end_line, new_line in replacements: |
| 291 | + lines[start_line - 1 : end_line] = [new_line] |
| 292 | + |
| 293 | + return "".join(lines) |
0 commit comments