Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .claude/rules/architecture.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ Core protocol in `languages/base.py`. Each language (`PythonSupport`, `JavaScrip
|----------|----------------|---------|
| Identity | `language`, `file_extensions`, `default_file_extension` | Language identification |
| Identity | `comment_prefix`, `dir_excludes` | Language conventions |
| AI service | `default_language_version` | Language version for API payloads (`None` for Python, `"ES2022"` for JS) |
| AI service | `language_version` | Detected language version for API payloads (e.g., `"3.11.0"` for Python, `"17"` for Java) |
| AI service | `valid_test_frameworks` | Allowed test frameworks for validation |
| Discovery | `discover_functions`, `discover_tests` | Find optimizable functions and their tests |
| Discovery | `adjust_test_config_for_discovery` | Pre-discovery config adjustment (no-op default) |
Expand Down
63 changes: 26 additions & 37 deletions codeflash/api/aiservice.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ def optimize_code(
experiment_metadata: ExperimentMetadata | None = None,
*,
language: str = "python",
language_version: str
| None = None, # TODO:{claude} add language version to the language support and it should be cached
language_version: str | None = None,
module_system: str | None = None,
is_async: bool = False,
n_candidates: int = 5,
Expand Down Expand Up @@ -177,16 +176,12 @@ def optimize_code(
"is_numerical_code": is_numerical_code,
}

# Add language-specific version fields
# Always include python_version for backward compatibility with older backend
payload["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
elif is_java():
payload["language_version"] = language_version or "17" # Default Java version
else:
payload["language_version"] = language_version or "ES2022"
# Add module system for JavaScript/TypeScript (esm or commonjs)
# Add language version (canonical for all languages)
payload["language_version"] = language_version
# Backward compat: backend still expects python_version
payload["python_version"] = language_version if is_python() else platform.python_version()

if not is_python():
if module_system:
payload["module_system"] = module_system

Expand Down Expand Up @@ -262,7 +257,8 @@ def get_jit_rewritten_code(self, source_code: str, trace_id: str) -> list[Optimi
"source_code": source_code,
"trace_id": trace_id,
"dependency_code": "", # dummy value to please the api endpoint
"python_version": "3.12.1", # dummy value to please the api endpoint
"language_version": platform.python_version(),
"python_version": platform.python_version(), # backward compat
"current_username": get_last_commit_author_if_pr_exists(None),
"repo_owner": git_repo_owner,
"repo_name": git_repo_name,
Expand Down Expand Up @@ -329,18 +325,15 @@ def optimize_python_code_line_profiler(
logger.info("Generating optimized candidates with line profiler…")
console.rule()

# Set python_version for backward compatibility with Python, or use language_version
python_version = language_version if language_version else platform.python_version()

payload = {
"source_code": source_code,
"dependency_code": dependency_code,
"n_candidates": n_candidates,
"line_profiler_results": line_profiler_results,
"trace_id": trace_id,
"python_version": python_version,
"language": language,
"language_version": language_version,
"python_version": language_version if is_python() else platform.python_version(), # backward compat
"experiment_metadata": experiment_metadata,
"codeflash_version": codeflash_version,
"call_sequence": self.get_next_sequence(),
Expand Down Expand Up @@ -434,14 +427,10 @@ def optimize_code_refinement(self, request: list[AIServiceRefinerRequest]) -> li
"language": opt.language,
}

# Add language version - always include python_version for backward compatibility
item["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
elif opt.language_version:
item["language_version"] = opt.language_version
else:
item["language_version"] = "ES2022" # Default for JS/TS
# Add language version (canonical for all languages)
item["language_version"] = opt.language_version
# Backward compat: backend still expects python_version
item["python_version"] = opt.language_version if is_python() else platform.python_version()

# Add multi-file context if provided
if opt.additional_context_files:
Expand Down Expand Up @@ -649,7 +638,8 @@ def generate_ranking(
"diffs": diffs,
"speedups": speedups,
"optimization_ids": optimization_ids,
"python_version": platform.python_version(),
"language_version": platform.python_version(),
"python_version": platform.python_version(), # backward compat
"function_references": function_references,
}
logger.info("loading|Generating ranking")
Expand Down Expand Up @@ -785,18 +775,16 @@ def generate_regression_tests(
"is_async": function_to_optimize.is_async,
"call_sequence": self.get_next_sequence(),
"is_numerical_code": is_numerical_code,
"class_name": function_to_optimize.class_name,
"qualified_name": function_to_optimize.qualified_name,
}

# Add language-specific version fields
# Always include python_version for backward compatibility with older backend
payload["python_version"] = platform.python_version()
if is_python():
pass # python_version already set
elif is_java():
payload["language_version"] = language_version or "17" # Default Java version
else:
payload["language_version"] = language_version or "ES2022"
# Add module system for JavaScript/TypeScript (esm or commonjs)
# Add language version (canonical for all languages)
payload["language_version"] = language_version
# Backward compat: backend still expects python_version
payload["python_version"] = language_version if is_python() else platform.python_version()

if not is_python():
if module_system:
payload["module_system"] = module_system

Expand Down Expand Up @@ -884,7 +872,8 @@ def get_optimization_review(
"codeflash_version": codeflash_version,
"calling_fn_details": calling_fn_details,
"language": language,
"python_version": platform.python_version() if is_python() else None,
"language_version": platform.python_version() if is_python() else None,
"python_version": platform.python_version() if is_python() else None, # backward compat
"call_sequence": self.get_next_sequence(),
}
console.rule()
Expand Down
22 changes: 12 additions & 10 deletions codeflash/api/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,14 +137,15 @@ def to_payload(self) -> dict[str, Any]:
"is_numerical_code": self.is_numerical_code,
}

# Add language-specific fields
if self.language_info.version:
payload["language_version"] = self.language_info.version
# Add language version (canonical for all languages)
payload["language_version"] = self.language_info.version

# Backward compat: always include python_version
# Backward compat: backend still expects python_version
import platform

payload["python_version"] = platform.python_version()
payload["python_version"] = (
self.language_info.version if self.language_info.name == "python" else platform.python_version()
)

# Module system for JS/TS
if self.language_info.module_system != ModuleSystem.UNKNOWN:
Expand Down Expand Up @@ -205,14 +206,15 @@ def to_payload(self) -> dict[str, Any]:
"is_numerical_code": self.is_numerical_code,
}

# Add language version
if self.language_info.version:
payload["language_version"] = self.language_info.version
# Add language version (canonical for all languages)
payload["language_version"] = self.language_info.version

# Backward compat: always include python_version
# Backward compat: backend still expects python_version
import platform

payload["python_version"] = platform.python_version()
payload["python_version"] = (
self.language_info.version if self.language_info.name == "python" else platform.python_version()
)

# Module system for JS/TS
if self.language_info.module_system != ModuleSystem.UNKNOWN:
Expand Down
11 changes: 4 additions & 7 deletions codeflash/languages/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,13 +325,9 @@ def dir_excludes(self) -> frozenset[str]:
...

@property
def default_language_version(self) -> str | None:
"""Default language version string sent to AI service.

Returns None for languages where the runtime version is auto-detected (e.g. Python).
Returns a version string (e.g. "ES2022") for languages that need an explicit default.
"""
return None
def language_version(self) -> str | None:
    """The detected language version, or None when it has not been detected.

    Examples: "17" for Java (major version only), "20.11.0" for JavaScript
    (the Node.js runtime version without its leading "v").
    """
    ...

@property
def valid_test_frameworks(self) -> tuple[str, ...]:
Expand Down Expand Up @@ -863,6 +859,7 @@ def run_line_profile_tests(

Returns:
Tuple of (result_file_path, subprocess_result).

"""
...

Expand Down
31 changes: 29 additions & 2 deletions codeflash/languages/java/instrumentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ def wrap_target_calls_with_treesitter(
precise_call_timing: bool = False,
class_name: str = "",
test_method_name: str = "",
target_return_type: str = "",
) -> tuple[list[str], int]:
"""Replace target method calls in body_lines with capture + serialize using tree-sitter.

Expand Down Expand Up @@ -327,6 +328,8 @@ def wrap_target_calls_with_treesitter(
call_counter += 1
var_name = f"_cf_result{iter_id}_{call_counter}"
cast_type = _infer_array_cast_type(body_line)
if not cast_type and target_return_type and target_return_type != "void":
cast_type = target_return_type
var_with_cast = f"({cast_type}){var_name}" if cast_type else var_name

# Use per-call unique variables (with call_counter suffix) for behavior mode
Expand Down Expand Up @@ -524,6 +527,26 @@ def _infer_array_cast_type(line: str) -> str | None:
return None


def _extract_return_type(function_to_optimize: Any) -> str:
    """Extract the declared return type of a Java function from its source file.

    The file referenced by ``function_to_optimize.file_path`` is parsed with the
    tree-sitter based Java analyzer. The first method whose name equals the target
    function name and that reports a non-empty return type is used; matching is by
    name only, so overloads resolve to whichever one the analyzer lists first.

    Args:
        function_to_optimize: Object describing the target function. Only its
            optional ``file_path`` attribute and the name resolved by
            ``_get_function_name`` are read.

    Returns:
        The return type string reported by the analyzer, or ``""`` when the file
        path is missing, the file does not exist, no matching method is found, or
        extraction fails for any reason.
    """
    from pathlib import Path

    file_path = getattr(function_to_optimize, "file_path", None)
    func_name = _get_function_name(function_to_optimize)
    if not file_path:
        return ""
    try:
        from codeflash.languages.java.parser import get_java_analyzer

        # Coerce inside the try so a plain string path (or any unexpected value)
        # degrades to "" instead of raising from the existence check.
        source_path = Path(file_path)
        if not source_path.exists():
            return ""

        analyzer = get_java_analyzer()
        source_text = source_path.read_text(encoding="utf-8")
        for method in analyzer.find_methods(source_text):
            if method.name == func_name and method.return_type:
                return method.return_type
    except Exception:
        # Best-effort lookup: callers fall back to no cast, but keep the traceback
        # in the debug log so failures are diagnosable.
        logger.debug("Could not extract return type for %s", func_name, exc_info=True)
    return ""


def _get_qualified_name(func: Any) -> str:
"""Get the qualified name from FunctionToOptimize."""
if hasattr(func, "qualified_name"):
Expand Down Expand Up @@ -617,6 +640,7 @@ def instrument_existing_test(
"""
source = test_string
func_name = _get_function_name(function_to_optimize)
target_return_type = _extract_return_type(function_to_optimize)

# Get the original class name from the file name
if test_path:
Expand Down Expand Up @@ -648,14 +672,16 @@ def instrument_existing_test(
)
else:
# Behavior mode: add timing instrumentation that also writes to SQLite
modified_source = _add_behavior_instrumentation(modified_source, original_class_name, func_name)
modified_source = _add_behavior_instrumentation(
modified_source, original_class_name, func_name, target_return_type
)

logger.debug("Java %s testing for %s: renamed class %s -> %s", mode, func_name, original_class_name, new_class_name)
# Why return True here?
return True, modified_source


def _add_behavior_instrumentation(source: str, class_name: str, func_name: str) -> str:
def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, target_return_type: str = "") -> str:
"""Add behavior instrumentation to test methods.

For behavior mode, this adds:
Expand Down Expand Up @@ -796,6 +822,7 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str)
precise_call_timing=True,
class_name=class_name,
test_method_name=test_method_name,
target_return_type=target_return_type,
)

# Add behavior instrumentation setup code (shared variables for all calls in the method)
Expand Down
31 changes: 31 additions & 0 deletions codeflash/languages/java/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ def __init__(self) -> None:
self._analyzer = get_java_analyzer()
self.line_profiler_agent_arg: str | None = None
self.line_profiler_warmup_iterations: int = 0
self._language_version: str | None = None

@property
def language(self) -> Language:
Expand Down Expand Up @@ -93,6 +94,10 @@ def default_file_extension(self) -> str:
def dir_excludes(self) -> frozenset[str]:
return frozenset({"target", "build", ".gradle", ".mvn", ".idea", "apidocs", "javadoc"})

@property
def language_version(self) -> str | None:
    """Return the cached Java version, or None if it has not been determined.

    The value is populated by ``ensure_runtime_environment`` (from the project
    config, falling back to ``_detect_java_version``).
    """
    return self._language_version

def postprocess_generated_tests(
self, generated_tests: GeneratedTestsList, test_framework: str, project_root: Path, source_file_path: Path
) -> GeneratedTestsList:
Expand Down Expand Up @@ -364,10 +369,36 @@ def ensure_runtime_environment(self, project_root: Path) -> bool:
if config is None:
return False

self._language_version = config.java_version
if self._language_version is None:
self._detect_java_version()

# For now, assume the runtime is available
# A full implementation would check/install the JAR
return True

def _detect_java_version(self) -> None:
    """Detect and cache the major version of the ``java`` runtime.

    Runs ``java -version`` and stores the major version as a string in
    ``self._language_version``. Legacy ``1.x`` version strings map to ``x``
    (``"1.8.0_292"`` -> ``"8"``, ``"1.7.0_80"`` -> ``"7"``) and suffixes are
    dropped (``"17.0.2"`` -> ``"17"``, ``"21-ea"`` -> ``"21"``).

    Detection is best-effort: if the command cannot be run or no version string
    is found, ``self._language_version`` is left unchanged.
    """
    import re
    import subprocess

    try:
        result = subprocess.run(["java", "-version"], check=False, capture_output=True, text=True, timeout=10)
    except (OSError, subprocess.SubprocessError, ValueError):
        # java missing / not executable, timed out, or produced undecodable output.
        return

    # java -version outputs to stderr, e.g. 'openjdk version "17.0.2"'
    output = result.stderr or result.stdout
    for line in output.splitlines():
        if "version" not in line:
            continue
        # The version is the first double-quoted token on the line.
        start = line.find('"')
        end = line.find('"', start + 1)
        if start == -1 or end == -1:
            continue
        full_version = line[start + 1 : end]
        # Skip an optional legacy "1." prefix, then take the leading digits only.
        match = re.match(r"(?:1\.)?(\d+)", full_version)
        if match:
            self._language_version = match.group(1)
            return

def instrument_existing_test(
self,
test_string: str,
Expand Down
18 changes: 18 additions & 0 deletions codeflash/languages/javascript/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,9 @@ class JavaScriptSupport:
using tree-sitter for code analysis and Jest for test execution.
"""

def __init__(self) -> None:
    """Initialize the support object with no detected runtime version."""
    # Filled in by _detect_node_version(); stays None until detection succeeds.
    self._language_version: str | None = None

# === Properties ===

@property
Expand Down Expand Up @@ -68,6 +71,10 @@ def comment_prefix(self) -> str:
def dir_excludes(self) -> frozenset[str]:
return frozenset({"node_modules", "dist", "build", ".next", ".nuxt", "coverage", ".cache", ".turbo", ".vercel"})

@property
def language_version(self) -> str | None:
    """Return the cached Node.js version, or None if it has not been detected.

    The value is set by ``_detect_node_version``.
    """
    return self._language_version

# === Discovery ===

def discover_functions(
Expand Down Expand Up @@ -2077,6 +2084,15 @@ def verify_requirements(self, project_root: Path, test_framework: str = "jest")

return len(errors) == 0, errors

def _detect_node_version(self) -> None:
    """Probe ``node --version`` and remember what it reports.

    On a zero exit code with non-empty output, the reported version (with any
    leading ``v`` characters stripped) is stored in ``self._language_version``.
    Every failure is ignored and the attribute is left as it was.
    """
    try:
        completed = subprocess.run(["node", "--version"], check=False, capture_output=True, text=True, timeout=10)
        if completed.returncode != 0:
            return
        reported = completed.stdout.strip()
        if not reported:
            return
        self._language_version = reported.lstrip("v")
    except Exception:
        pass

def ensure_runtime_environment(self, project_root: Path) -> bool:
"""Ensure codeflash npm package is installed.

Expand All @@ -2091,6 +2107,8 @@ def ensure_runtime_environment(self, project_root: Path) -> bool:
"""
from codeflash.cli_cmds.console import logger

self._detect_node_version()

node_modules_pkg = project_root / "node_modules" / "codeflash"
if node_modules_pkg.exists():
logger.debug("codeflash already installed")
Expand Down
Loading
Loading