Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions codeflash/languages/javascript/normalizer.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""JavaScript/TypeScript code normalizer using tree-sitter.

Not currently wired into JavaScriptSupport.normalize_code — kept as a
ready-to-use upgrade path when AST-based JS deduplication is needed.
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.

The old CodeNormalizer ABC (deleted from base.py) is preserved below for reference.
"""
Expand Down Expand Up @@ -236,8 +235,7 @@ def normalize_js_code(code: str, typescript: bool = False) -> str:
Uses tree-sitter to parse and normalize variable names. Falls back to
basic comment/whitespace stripping if tree-sitter is unavailable or parsing fails.

Not currently wired into JavaScriptSupport.normalize_code — kept as a
ready-to-use upgrade path when AST-based JS deduplication is needed.
Wired into JavaScriptSupport.normalize_code for AST-based JS deduplication.
"""
try:
from codeflash.languages.javascript.treesitter import TreeSitterAnalyzer, TreeSitterLanguage
Expand Down
57 changes: 31 additions & 26 deletions codeflash/languages/javascript/support.py
Original file line number Diff line number Diff line change
Expand Up @@ -1207,20 +1207,29 @@ def find_function_node(node, target_name: str):
return node

# Check function declarations
if node.type in ("function_declaration", "function"):
if node.type in (
"function_declaration",
"function",
"generator_function_declaration",
"generator_function",
):
name_node = node.child_by_field_name("name")
if name_node:
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return node

# Check arrow functions assigned to variables
if node.type == "lexical_declaration":
# Check arrow functions and function expressions assigned to variables
if node.type in ("lexical_declaration", "variable_declaration"):
for child in node.children:
if child.type == "variable_declarator":
name_node = child.child_by_field_name("name")
value_node = child.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
if (
name_node
and value_node
and value_node.type in ("arrow_function", "function_expression", "generator_function")
):
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name:
return value_node
Expand All @@ -1235,6 +1244,7 @@ def find_function_node(node, target_name: str):

func_node = find_function_node(tree.root_node, function_name)
if not func_node:
logger.debug("Could not find function '%s' in optimized code for body extraction", function_name)
return None

# Find the body node
Expand Down Expand Up @@ -1295,14 +1305,21 @@ def find_function_at_line(node, target_name: str, target_line: int):
if name == target_name and (node.start_point[0] + 1) == target_line:
return node

if node.type == "lexical_declaration":
if node.type in ("lexical_declaration", "variable_declaration"):
for child in node.children:
if child.type == "variable_declarator":
name_node = child.child_by_field_name("name")
value_node = child.child_by_field_name("value")
if name_node and value_node and value_node.type == "arrow_function":
if (
name_node
and value_node
and value_node.type in ("arrow_function", "function_expression", "generator_function")
):
name = source_bytes[name_node.start_byte : name_node.end_byte].decode("utf8")
if name == target_name and (node.start_point[0] + 1) == target_line:
if name == target_name and (
(node.start_point[0] + 1) == target_line
or (value_node.start_point[0] + 1) == target_line
):
return value_node

for child in node.children:
Expand Down Expand Up @@ -1686,26 +1703,14 @@ def validate_syntax(self, source: str) -> bool:
return False

def normalize_code(self, source: str) -> str:
"""Normalize JavaScript code for deduplication.

Removes comments and normalizes whitespace.

Args:
source: Source code to normalize.

Returns:
Normalized source code.
"""Normalize JavaScript code for deduplication using tree-sitter."""
from codeflash.languages.javascript.normalizer import normalize_js_code

"""
# Simple normalization: remove extra whitespace
# A full implementation would use tree-sitter to strip comments
lines = source.splitlines()
normalized_lines = []
for line in lines:
stripped = line.strip()
if stripped and not stripped.startswith("//"):
normalized_lines.append(stripped)
return "\n".join(normalized_lines)
try:
is_ts = self.treesitter_language == TreeSitterLanguage.TYPESCRIPT
return normalize_js_code(source, typescript=is_ts)
except Exception:
return source

def generate_concolic_tests(
self, test_cfg: Any, project_root: Any, function_to_optimize: Any, function_to_optimize_ast: Any
Expand Down
72 changes: 72 additions & 0 deletions tests/test_code_deduplication.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from codeflash.languages.javascript.normalizer import normalize_js_code
from codeflash.languages.python.normalizer import normalize_python_code as normalize_code


Expand Down Expand Up @@ -133,3 +134,74 @@ def safe_divide(a, b):
assert normalize_code(code9) == normalize_code(code10)

assert normalize_code(code9) != normalize_code(code8)


# === JavaScript deduplication tests ===


def test_js_deduplicate_same_logic_different_vars():
    """Two JS functions that differ only in local variable names normalize identically."""
    original_variant = """
function process(items) {
const result = [];
for (const item of items) {
result.push(item * 2);
}
return result;
}
"""
    renamed_variant = """
function process(items) {
const output = [];
for (const val of items) {
output.push(val * 2);
}
return output;
}
"""
    # Normalize each snippet separately so a failure shows both forms.
    normalized_original = normalize_js_code(original_variant)
    normalized_renamed = normalize_js_code(renamed_variant)
    assert normalized_original == normalized_renamed


def test_js_different_logic_not_deduplicated():
    """Two JS functions with different return expressions must not normalize to the same text."""
    increment_variant = """
function compute(x) {
return x + 1;
}
"""
    doubling_variant = """
function compute(x) {
return x * 2;
}
"""
    normalized_increment = normalize_js_code(increment_variant)
    normalized_doubling = normalize_js_code(doubling_variant)
    assert normalized_increment != normalized_doubling


def test_js_deduplicate_whitespace_and_comments():
    """A line comment versus a block comment must not affect the normalized form."""
    line_comment_variant = """
function add(a, b) {
// fast path
return a + b;
}
"""
    block_comment_variant = """
function add(a, b) {
/* optimized */
return a + b;
}
"""
    normalized_line_comment = normalize_js_code(line_comment_variant)
    normalized_block_comment = normalize_js_code(block_comment_variant)
    assert normalized_line_comment == normalized_block_comment


def test_ts_normalize():
    """With typescript=True, renaming a local in a typed function keeps the normalized form equal."""
    msg_variant = """
function greet(name: string): string {
const msg = "hello " + name;
return msg;
}
"""
    result_variant = """
function greet(name: string): string {
const result = "hello " + name;
return result;
}
"""
    # Both snippets are normalized with typescript=True.
    normalized_msg = normalize_js_code(msg_variant, typescript=True)
    normalized_result = normalize_js_code(result_variant, typescript=True)
    assert normalized_msg == normalized_result
48 changes: 36 additions & 12 deletions tests/test_languages/test_javascript_support.py
Original file line number Diff line number Diff line change
Expand Up @@ -443,10 +443,10 @@ def test_syntax_error_types(self, js_support):


class TestNormalizeCode:
"""Tests for normalize_code method."""
"""Tests for normalize_code method using tree-sitter normalizer."""

def test_removes_comments(self, js_support):
"""Test that single-line comments are removed."""
"""Test that comments are absent from normalized output."""
code = """
function add(a, b) {
// Add two numbers
Expand All @@ -455,19 +455,43 @@ def test_removes_comments(self, js_support):
"""
normalized = js_support.normalize_code(code)
assert "// Add two numbers" not in normalized
assert "return a + b" in normalized
assert "Add two numbers" not in normalized

def test_preserves_functionality(self, js_support):
"""Test that code functionality is preserved."""
code = """
function add(a, b) {
// Comment
return a + b;
def test_same_logic_different_vars_are_equal(self, js_support):
"""Test that two functions with same logic but different variable names normalize identically."""
code1 = """
function process(items) {
const result = [];
for (const item of items) {
result.push(item * 2);
}
return result;
}
"""
normalized = js_support.normalize_code(code)
assert "function add" in normalized
assert "return" in normalized
code2 = """
function process(items) {
const output = [];
for (const val of items) {
output.push(val * 2);
}
return output;
}
"""
assert js_support.normalize_code(code1) == js_support.normalize_code(code2)

def test_different_logic_not_equal(self, js_support):
"""Test that two functions with different logic produce different normalized forms."""
code1 = """
function compute(x) {
return x + 1;
}
"""
code2 = """
function compute(x) {
return x * 2;
}
"""
assert js_support.normalize_code(code1) != js_support.normalize_code(code2)


class TestExtractCodeContext:
Expand Down
Loading
Loading