Skip to content

Commit cb73420

Browse files
committed
Merge remote-tracking branch 'origin/multi-language' into multi-language
# Conflicts: # codeflash/verification/verification_utils.py
2 parents 57b37d8 + 12cded7 commit cb73420

9 files changed

Lines changed: 289 additions & 102 deletions

File tree

codeflash/context/code_context_extractor.py

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -207,6 +207,154 @@ def get_code_optimization_context(
207207
preexisting_objects=preexisting_objects,
208208
)
209209

210+
def get_code_optimization_context_for_language(
211+
function_to_optimize: FunctionToOptimize,
212+
project_root_path: Path,
213+
optim_token_limit: int = OPTIMIZATION_CONTEXT_TOKEN_LIMIT,
214+
testgen_token_limit: int = TESTGEN_CONTEXT_TOKEN_LIMIT,
215+
) -> CodeOptimizationContext:
216+
"""Extract code optimization context for non-Python languages.
217+
218+
Uses the language support abstraction to extract code context and converts
219+
it to the CodeOptimizationContext format expected by the pipeline.
220+
221+
This function supports multi-file context extraction, grouping helpers by file
222+
and creating proper CodeStringsMarkdown with file paths for multi-file replacement.
223+
224+
Args:
225+
function_to_optimize: The function to extract context for.
226+
project_root_path: Root of the project.
227+
optim_token_limit: Token limit for optimization context.
228+
testgen_token_limit: Token limit for testgen context.
229+
230+
Returns:
231+
CodeOptimizationContext with target code and dependencies.
232+
233+
"""
234+
from codeflash.languages import get_language_support
235+
from codeflash.languages.base import FunctionInfo, ParentInfo
236+
237+
# Get language support for this function
238+
language = Language(function_to_optimize.language)
239+
lang_support = get_language_support(language)
240+
241+
# Convert FunctionToOptimize to FunctionInfo for language support
242+
parents = tuple(ParentInfo(name=p.name, type=p.type) for p in function_to_optimize.parents)
243+
func_info = FunctionInfo(
244+
name=function_to_optimize.function_name,
245+
file_path=function_to_optimize.file_path,
246+
start_line=function_to_optimize.starting_line or 1,
247+
end_line=function_to_optimize.ending_line or 1,
248+
parents=parents,
249+
is_async=function_to_optimize.is_async,
250+
is_method=len(function_to_optimize.parents) > 0,
251+
language=language,
252+
)
253+
254+
# Extract code context using language support
255+
code_context = lang_support.extract_code_context(func_info, project_root_path, project_root_path)
256+
257+
# Build imports string if available
258+
imports_code = "\n".join(code_context.imports) if code_context.imports else ""
259+
260+
# Get relative path for target file
261+
try:
262+
target_relative_path = function_to_optimize.file_path.resolve().relative_to(project_root_path.resolve())
263+
except ValueError:
264+
target_relative_path = function_to_optimize.file_path
265+
266+
# Group helpers by file path
267+
helpers_by_file: dict[Path, list] = defaultdict(list)
268+
helper_function_sources = []
269+
270+
for helper in code_context.helper_functions:
271+
helpers_by_file[helper.file_path].append(helper)
272+
273+
# Convert to FunctionSource for pipeline compatibility
274+
helper_function_sources.append(
275+
FunctionSource(
276+
file_path=helper.file_path,
277+
qualified_name=helper.qualified_name,
278+
fully_qualified_name=helper.qualified_name,
279+
only_function_name=helper.name,
280+
source_code=helper.source_code,
281+
jedi_definition=None,
282+
)
283+
)
284+
285+
# Build read-writable code (target file + same-file helpers + global variables)
286+
read_writable_code_strings = []
287+
288+
# Combine target code with same-file helpers
289+
target_file_code = code_context.target_code
290+
same_file_helpers = helpers_by_file.get(function_to_optimize.file_path, [])
291+
if same_file_helpers:
292+
helper_code = "\n\n".join(h.source_code for h in same_file_helpers)
293+
target_file_code = target_file_code + "\n\n" + helper_code
294+
295+
# Add global variables (module-level declarations) referenced by the function and helpers
296+
# These should be included in read-writable context so AI can modify them if needed
297+
if code_context.read_only_context:
298+
target_file_code = code_context.read_only_context + "\n\n" + target_file_code
299+
300+
# Add imports to target file code
301+
if imports_code:
302+
target_file_code = imports_code + "\n\n" + target_file_code
303+
304+
read_writable_code_strings.append(
305+
CodeString(code=target_file_code, file_path=target_relative_path, language=function_to_optimize.language)
306+
)
307+
308+
# Add helper files (cross-file helpers)
309+
for file_path, file_helpers in helpers_by_file.items():
310+
if file_path == function_to_optimize.file_path:
311+
continue # Already included in target file
312+
313+
try:
314+
helper_relative_path = file_path.resolve().relative_to(project_root_path.resolve())
315+
except ValueError:
316+
helper_relative_path = file_path
317+
318+
# Combine all helpers from this file
319+
combined_helper_code = "\n\n".join(h.source_code for h in file_helpers)
320+
321+
read_writable_code_strings.append(
322+
CodeString(
323+
code=combined_helper_code, file_path=helper_relative_path, language=function_to_optimize.language
324+
)
325+
)
326+
327+
read_writable_code = CodeStringsMarkdown(
328+
code_strings=read_writable_code_strings, language=function_to_optimize.language
329+
)
330+
331+
# Build testgen context (same as read_writable for non-Python)
332+
testgen_context = CodeStringsMarkdown(
333+
code_strings=read_writable_code_strings.copy(), language=function_to_optimize.language
334+
)
335+
336+
# Check token limits
337+
read_writable_tokens = encoded_tokens_len(read_writable_code.markdown)
338+
if read_writable_tokens > optim_token_limit:
339+
raise ValueError("Read-writable code has exceeded token limit, cannot proceed")
340+
341+
testgen_tokens = encoded_tokens_len(testgen_context.markdown)
342+
if testgen_tokens > testgen_token_limit:
343+
raise ValueError("Testgen code context has exceeded token limit, cannot proceed")
344+
345+
# Generate code hash from all read-writable code
346+
code_hash = hashlib.sha256(read_writable_code.flat.encode("utf-8")).hexdigest()
347+
348+
return CodeOptimizationContext(
349+
testgen_context=testgen_context,
350+
read_writable_code=read_writable_code,
351+
# Global variables are now included in read-writable code, so don't duplicate in read-only
352+
read_only_context_code="",
353+
hashing_code_context=read_writable_code.flat,
354+
hashing_code_context_hash=code_hash,
355+
helper_functions=helper_function_sources,
356+
preexisting_objects=set(), # Not implemented for non-Python yet
357+
)
210358

211359
def extract_code_markdown_context_from_files(
212360
helpers_of_fto: dict[Path, set[FunctionSource]],

codeflash/discovery/functions_to_optimize.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ class FunctionToOptimize:
155155
starting_line: Optional[int] = None
156156
ending_line: Optional[int] = None
157157
starting_col: Optional[int] = None # Column offset for precise location
158-
ending_col: Optional[int] = None # Column offset for precise location
158+
ending_col: Optional[int] = None # Column offset for precise location
159159
is_async: bool = False
160160
language: str = "python" # Language identifier for multi-language support
161161

@@ -186,11 +186,14 @@ def qualified_name_with_modules_from_root(self, project_root_path: Path) -> str:
186186
# =============================================================================
187187

188188

189-
def get_files_for_language(module_root_path: Path, language: Language | None = None) -> list[Path]:
189+
def get_files_for_language(
190+
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
191+
) -> list[Path]:
190192
"""Get all source files for supported languages.
191193
192194
Args:
193195
module_root_path: Root path to search for source files.
196+
ignore_paths: List of paths to ignore (can be files or directories).
194197
language: Optional specific language to filter for. If None, includes all supported languages.
195198
196199
Returns:
@@ -206,7 +209,10 @@ def get_files_for_language(module_root_path: Path, language: Language | None = N
206209
files = []
207210
for ext in extensions:
208211
pattern = f"*{ext}"
209-
files.extend(module_root_path.rglob(pattern))
212+
for file_path in module_root_path.rglob(pattern):
213+
if any(file_path.is_relative_to(ignore_path) for ignore_path in ignore_paths):
214+
continue
215+
files.append(file_path)
210216
return files
211217

212218

@@ -289,7 +295,7 @@ def get_functions_to_optimize(
289295
if optimize_all:
290296
logger.info("!lsp|Finding all functions in the module '%s'…", optimize_all)
291297
console.rule()
292-
functions = get_all_files_and_functions(Path(optimize_all))
298+
functions = get_all_files_and_functions(Path(optimize_all), ignore_paths)
293299
elif replay_test:
294300
functions, trace_file_path = get_all_replay_test_functions(
295301
replay_test=replay_test, test_cfg=test_cfg, project_root_path=project_root
@@ -452,20 +458,21 @@ def get_functions_within_lines(modified_lines: dict[str, list[int]]) -> dict[str
452458

453459

454460
def get_all_files_and_functions(
455-
module_root_path: Path, language: Language | None = None
461+
module_root_path: Path, ignore_paths: list[Path], language: Language | None = None
456462
) -> dict[str, list[FunctionToOptimize]]:
457463
"""Get all optimizable functions from files in the module root.
458464
459465
Args:
460466
module_root_path: Root path to search for source files.
467+
ignore_paths: List of paths to ignore.
461468
language: Optional specific language to filter for. If None, includes all supported languages.
462469
463470
Returns:
464471
Dictionary mapping file paths to lists of FunctionToOptimize.
465472
466473
"""
467474
functions: dict[str, list[FunctionToOptimize]] = {}
468-
for file_path in get_files_for_language(module_root_path, language):
475+
for file_path in get_files_for_language(module_root_path, ignore_paths, language):
469476
# Find all the functions in the file
470477
functions.update(find_all_functions_in_file(file_path).items())
471478
# Randomize the order of the files to optimize to avoid optimizing the same file in the same order every time.

codeflash/languages/javascript/edit_tests.py

Lines changed: 6 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
from __future__ import annotations
88

99
import re
10-
from pathlib import Path
1110

1211
from codeflash.cli_cmds.console import logger
1312
from codeflash.code_utils.time_utils import format_perf, format_time
@@ -23,6 +22,7 @@ def format_runtime_comment(original_time: int, optimized_time: int) -> str:
2322
2423
Returns:
2524
Formatted comment string with // prefix.
25+
2626
"""
2727
perf_gain = format_perf(
2828
abs(performance_gain(original_runtime_ns=original_time, optimized_runtime_ns=optimized_time) * 100)
@@ -31,11 +31,7 @@ def format_runtime_comment(original_time: int, optimized_time: int) -> str:
3131
return f"// {format_time(original_time)} -> {format_time(optimized_time)} ({perf_gain}% {status})"
3232

3333

34-
def add_runtime_comments(
35-
source: str,
36-
original_runtimes: dict[str, int],
37-
optimized_runtimes: dict[str, int],
38-
) -> str:
34+
def add_runtime_comments(source: str, original_runtimes: dict[str, int], optimized_runtimes: dict[str, int]) -> str:
3935
"""Add runtime comments to JavaScript test source code.
4036
4137
For JavaScript, we match timing data by test function name and add comments
@@ -48,6 +44,7 @@ def add_runtime_comments(
4844
4945
Returns:
5046
Source code with runtime comments added.
47+
5148
"""
5249
logger.debug(f"[js-annotations] original_runtimes has {len(original_runtimes)} entries")
5350
logger.debug(f"[js-annotations] optimized_runtimes has {len(optimized_runtimes)} entries")
@@ -144,6 +141,7 @@ def remove_test_functions(source: str, functions_to_remove: list[str]) -> str:
144141
145142
Returns:
146143
Source code with specified functions removed.
144+
147145
"""
148146
if not functions_to_remove:
149147
return source
@@ -152,8 +150,7 @@ def remove_test_functions(source: str, functions_to_remove: list[str]) -> str:
152150
# Pattern to match test('name', ...) or it('name', ...) blocks
153151
# This handles nested callbacks and multi-line test bodies
154152
test_pattern = re.compile(
155-
rf"(?:test|it)\s*\(\s*['\"]" + re.escape(func_name) + rf"['\"].*?\)\s*;?\s*\n?",
156-
re.DOTALL,
153+
r"(?:test|it)\s*\(\s*['\"]" + re.escape(func_name) + r"['\"].*?\)\s*;?\s*\n?", re.DOTALL
157154
)
158155

159156
# Try to find and remove matching test blocks
@@ -180,6 +177,7 @@ def _find_block_end(source: str, start: int) -> int:
180177
181178
Returns:
182179
Position after the closing brace, or start if not found.
180+
183181
"""
184182
# Find the opening brace
185183
brace_pos = source.find("{", start)
@@ -230,21 +228,3 @@ def _find_block_end(source: str, start: int) -> int:
230228
i += 1
231229

232230
return start
233-
234-
235-
def get_comment_prefix() -> str:
236-
"""Get the comment prefix for JavaScript.
237-
238-
Returns:
239-
The JavaScript single-line comment prefix.
240-
"""
241-
return "//"
242-
243-
244-
def get_test_file_suffix() -> str:
245-
"""Get the test file suffix for JavaScript.
246-
247-
Returns:
248-
The Jest test file suffix.
249-
"""
250-
return ".test.js"

codeflash/languages/javascript/instrument.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -119,7 +119,7 @@ def _instrument_js_test_code(code: str, func_name: str, test_file_path: str, mod
119119
"""
120120
# Add codeflash helper import if not already present
121121
# Support both npm package (codeflash) and legacy local file (codeflash-jest-helper)
122-
has_codeflash_import = "codeflash" in code or "codeflash-jest-helper" in code
122+
has_codeflash_import = "codeflash" in code
123123
if not has_codeflash_import:
124124
# Detect module system: ESM uses "import ... from", CommonJS uses "require()"
125125
is_esm = bool(re.search(r"^\s*import\s+.+\s+from\s+['\"]", code, re.MULTILINE))
@@ -128,11 +128,7 @@ def _instrument_js_test_code(code: str, func_name: str, test_file_path: str, mod
128128
# ESM: Use import statement at the top of the file (after any other imports)
129129
helper_import = "import codeflash from 'codeflash';\n"
130130
# Find the last import statement to add after
131-
import_matches = list(re.finditer(
132-
r"^import\s+.+\s+from\s+['\"][^'\"]+['\"]\s*;?\s*\n",
133-
code,
134-
re.MULTILINE,
135-
))
131+
import_matches = list(re.finditer(r"^import\s+.+\s+from\s+['\"][^'\"]+['\"]\s*;?\s*\n", code, re.MULTILINE))
136132
if import_matches:
137133
# Add after the last import
138134
last_import = import_matches[-1]
@@ -145,11 +141,7 @@ def _instrument_js_test_code(code: str, func_name: str, test_file_path: str, mod
145141
# CommonJS: Use require statement
146142
helper_require = "const codeflash = require('codeflash');\n"
147143
# Find the first require statement to add after
148-
import_match = re.search(
149-
r"^((?:const|let|var)\s+.+?require\([^)]+\).*;?\s*\n)",
150-
code,
151-
re.MULTILINE,
152-
)
144+
import_match = re.search(r"^((?:const|let|var)\s+.+?require\([^)]+\).*;?\s*\n)", code, re.MULTILINE)
153145
if import_match:
154146
insert_pos = import_match.end()
155147
code = code[:insert_pos] + helper_require + code[insert_pos:]

0 commit comments

Comments
 (0)