
⚡️ Speed up method JavaAssertTransformer._find_balanced_braces by 121% in PR #1199 (omni-java) #1707

Closed
codeflash-ai[bot] wants to merge 85 commits into omni-java from codeflash/optimize-pr1199-2026-03-02T05.57.20

Conversation

@codeflash-ai
Contributor

@codeflash-ai codeflash-ai Bot commented Mar 2, 2026

⚡️ This pull request contains optimizations for PR #1199

If you approve this dependent PR, these changes will be merged into the original PR branch omni-java.

This PR will be automatically closed if the original PR is merged.


📄 121% (1.21x) speedup for JavaAssertTransformer._find_balanced_braces in codeflash/languages/java/remove_asserts.py

⏱️ Runtime : 1.76 milliseconds → 798 microseconds (best of 250 runs)

📝 Explanation and details

The hot loop repeatedly called `re.search` to locate the next special character (`"`, `'`, `{`, `}`, `(`), then called `m.start()` and `m.group()` on every iteration; the profiler shows these three lines consumed ~37% of total runtime. The optimized version replaces regex scanning with a direct character-by-character walk (`ch = s[pos]; pos += 1`), eliminating the match-object overhead and achieving a 121% speedup (1.76 ms → 798 µs). The original's per-iteration cost of ~2.4 µs (regex + extraction) drops to ~1.2 µs with simple indexing, and the deep-nesting test case improves from 646 µs to 200 µs. Trade-off: two test cases regressed by ~4–21% in wall time. Both are failure paths (a missing closing brace) that return almost immediately, so the absolute difference is negligible (<1 µs).
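To make the before/after difference concrete, here is a minimal sketch of the two scanning strategies. It is not the code from `remove_asserts.py`: the function names, the `_SPECIAL_RE` pattern, the `_skip_quoted` helper, and the `(None, -1)` failure value are all assumptions made for illustration.

```python
import re

# Hypothetical stand-in for the transformer's "special character" pattern.
_SPECIAL_RE = re.compile('["\'{}(]')


def _skip_quoted(s, pos, quote):
    """Return the index just past the closing quote, or -1 if unterminated.

    `pos` is the index right after the opening quote.
    """
    n = len(s)
    while pos < n:
        ch = s[pos]
        if ch == "\\":
            pos += 2  # skip the backslash and the character it escapes
        elif ch == quote:
            return pos + 1
        else:
            pos += 1
    return -1


def find_balanced_regex(s, open_pos):
    """Regex-driven variant: one search plus match-object calls per special char."""
    if open_pos >= len(s) or s[open_pos] != "{":
        return None, -1  # assumed failure value
    depth, pos = 1, open_pos + 1
    while True:
        m = _SPECIAL_RE.search(s, pos)
        if m is None:
            return None, -1
        ch, pos = m.group(), m.start() + 1
        if ch in "\"'":
            pos = _skip_quoted(s, pos, ch)
            if pos < 0:
                return None, -1
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return s[open_pos + 1 : pos - 1], pos


def find_balanced_walk(s, open_pos):
    """Character-walk variant: plain indexing, no match objects."""
    n = len(s)
    if open_pos >= n or s[open_pos] != "{":
        return None, -1  # assumed failure value
    depth, pos = 1, open_pos + 1
    while pos < n:
        ch = s[pos]
        pos += 1
        if ch == '"' or ch == "'":
            pos = _skip_quoted(s, pos, ch)
            if pos < 0:
                return None, -1
        elif ch == "{":
            depth += 1
        elif ch == "}":
            depth -= 1
            if depth == 0:
                return s[open_pos + 1 : pos - 1], pos
    return None, -1
```

In this sketch, both variants return the text between the outer braces and the index just past the closing brace. On an input such as `"{abc"`, the regex variant gives up after a single failed search while the walk visits every remaining character, which is one plausible reading of the small regressions on the missing-brace tests.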

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 87 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 97.3% |
🌀 Click to see Generated Regression Tests
import pytest  # used for our unit tests
from codeflash.languages.java.remove_asserts import JavaAssertTransformer

def test_basic_simple_balanced():
    # Create a real JavaAssertTransformer instance with a dummy function name.
    t = JavaAssertTransformer("dummy")
    # A minimal balanced-brace example: single outer braces with "abc" inside.
    code = "{abc}"
    # Call the instance method with the position of the opening brace (0).
    content, pos = t._find_balanced_braces(code, 0) # 3.54μs -> 1.91μs (84.7% faster)

def test_nested_braces():
    # Real instance
    t = JavaAssertTransformer("dummy")
    # A string containing nested braces; find the first outer brace.
    code = "prefix {a {b} c} suffix"
    open_pos = code.index("{")  # location of the outer brace
    # Extract content and position
    content, pos = t._find_balanced_braces(code, open_pos) # 4.70μs -> 2.65μs (77.7% faster)
    # pos should point to the character after the matching outer closing brace,
    # which is the last "}" in the string (index() would find the inner one first).
    expected_pos = code.rindex("}") + 1
    assert pos == expected_pos

def test_brace_inside_double_quoted_string_is_ignored():
    # Real instance
    t = JavaAssertTransformer("dummy")
    # The inner '}' is inside a quoted string, so it must be ignored when matching braces.
    # The Java snippet: { "}" }  (note the double-quoted string contains a brace)
    code = '{ "}" }'
    # Opening brace at position 0
    content, pos = t._find_balanced_braces(code, 0) # 4.58μs -> 2.86μs (60.3% faster)

def test_brace_inside_single_quoted_char_literal_is_ignored():
    # Real instance
    t = JavaAssertTransformer("dummy")
    # Single-quoted character containing a brace should be skipped by the parser logic.
    # Java snippet: { '}' }
    code = "{ '}' }"
    content, pos = t._find_balanced_braces(code, 0) # 4.39μs -> 2.69μs (62.8% faster)

def test_unbalanced_missing_closing_brace_returns_error():
    # Real instance
    t = JavaAssertTransformer("dummy")
    # Missing the closing brace -> unbalanced
    code = "{abc"
    content, pos = t._find_balanced_braces(code, 0) # 1.44μs -> 1.50μs (3.93% slower)

def test_unterminated_double_quote_inside_braces_returns_error():
    # Real instance
    t = JavaAssertTransformer("dummy")
    # The string literal inside the braces is not terminated: { "unterminated }
    # This should cause the routine to fail because it cannot find the closing quote.
    code = '{ "unterminated }'
    content, pos = t._find_balanced_braces(code, 0) # 3.39μs -> 1.88μs (79.8% faster)

def test_open_position_not_brace_or_out_of_range_returns_error():
    # Real instance
    t = JavaAssertTransformer("dummy")
    code = "no braces here"
    # Position 0 is not a '{' -> should fail.
    content, pos = t._find_balanced_braces(code, 0) # 651ns -> 561ns (16.0% faster)

    # Position beyond the string length -> should also fail.
    content2, pos2 = t._find_balanced_braces(code, len(code) + 10) # 290ns -> 290ns (0.000% faster)

def test_parentheses_inside_do_not_affect_brace_matching():
    # Real instance
    t = JavaAssertTransformer("dummy")
    # Parentheses should be treated as "special" but not change brace depth; content should include them.
    code = "{ (call(arg1, arg2)) }"
    content, pos = t._find_balanced_braces(code, 0) # 5.61μs -> 3.55μs (58.2% faster)

def test_escaped_quotes_inside_strings_are_handled_properly():
    # Real instance
    t = JavaAssertTransformer("dummy")
    # A string literal that contains an escaped quote: "a \" b"
    # The algorithm should skip from the opening double-quote to its matching closing double-quote,
    # correctly handling the escaped internal quote.
    code = '{ "a \\" inner \\" b" }'
    # Opening at 0
    content, pos = t._find_balanced_braces(code, 0) # 5.16μs -> 3.47μs (48.8% faster)

def test_large_scale_nested_braces_depth_1000():
    # Real instance
    t = JavaAssertTransformer("dummy")
    # Construct a deeply nested braces string with depth n = 1000 to test scalability.
    n = 1000
    # Outer brace plus (n-1) inner opens then (n-1) closes and final outer close:
    code = "{" + ("{" * (n - 1)) + ("}" * (n - 1)) + "}"
    # Call the method on the outermost opening brace at position 0.
    content, pos = t._find_balanced_braces(code, 0) # 646μs -> 200μs (223% faster)
    # The content should be the (n-1) opens followed by (n-1) closes.
    expected_content = "{" * (n - 1) + "}" * (n - 1)
    assert content == expected_content
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import pytest
from codeflash.languages.java.parser import get_java_analyzer
from codeflash.languages.java.remove_asserts import JavaAssertTransformer

def test_find_balanced_braces_simple_braces():
    """Test finding content within simple balanced braces."""
    # Arrange: Create a transformer and prepare code with balanced braces
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ content }"
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 3.34μs -> 2.60μs (28.6% faster)

def test_find_balanced_braces_nested_braces():
    """Test finding content with nested braces."""
    # Arrange: Create a transformer and prepare code with nested braces
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ outer { inner } }"
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 4.53μs -> 3.34μs (35.8% faster)

def test_find_balanced_braces_empty_braces():
    """Test finding content within empty braces."""
    # Arrange: Create a transformer and prepare code with empty braces
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{}"
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 3.02μs -> 1.46μs (106% faster)

def test_find_balanced_braces_with_string_literal():
    """Test handling of string literals containing braces."""
    # Arrange: Create a transformer and prepare code with string containing braces
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = '{ content "{ brace }" }'
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 4.61μs -> 3.67μs (25.7% faster)

def test_find_balanced_braces_with_single_quotes():
    """Test handling of single-quoted strings containing braces."""
    # Arrange: Create a transformer and prepare code with single-quoted string
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ content '{ brace }' }"
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 4.47μs -> 3.52μs (27.1% faster)

def test_find_balanced_braces_multiple_levels_of_nesting():
    """Test handling of deeply nested braces."""
    # Arrange: Create a transformer and prepare code with multiple nesting levels
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ level1 { level2 { level3 } } }"
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 5.43μs -> 4.54μs (19.7% faster)

def test_find_balanced_braces_with_escaped_quotes_in_string():
    """Test handling of escaped quotes within string literals."""
    # Arrange: Create a transformer and prepare code with escaped quotes
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = '{ content "\\"escaped\\"" }'
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 5.15μs -> 4.08μs (26.3% faster)

def test_find_balanced_braces_invalid_start_position_out_of_bounds():
    """Test with start position beyond string length."""
    # Arrange: Create a transformer and prepare code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ content }"
    
    # Act: Try to find balanced braces at position beyond string length
    content, end_pos = transformer._find_balanced_braces(code, 100) # 501ns -> 501ns (0.000% faster)

def test_find_balanced_braces_invalid_start_not_brace():
    """Test when start position does not point to an opening brace."""
    # Arrange: Create a transformer and prepare code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "x { content }"
    
    # Act: Try to find balanced braces starting from non-brace character
    content, end_pos = transformer._find_balanced_braces(code, 0) # 651ns -> 621ns (4.83% faster)

def test_find_balanced_braces_unbalanced_missing_closing():
    """Test with unbalanced braces (missing closing brace)."""
    # Arrange: Create a transformer and prepare code with unbalanced braces
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ content"
    
    # Act: Try to find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 1.58μs -> 1.99μs (20.6% slower)

def test_find_balanced_braces_unbalanced_extra_closing():
    """Test with extra closing brace after balanced content."""
    # Arrange: Create a transformer and prepare code with extra closing brace
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ content }}"
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 3.44μs -> 2.54μs (35.6% faster)

def test_find_balanced_braces_unclosed_string():
    """Test with unclosed string literal."""
    # Arrange: Create a transformer and prepare code with unclosed string
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = '{ content "unclosed'
    
    # Act: Try to find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 3.19μs -> 2.81μs (13.5% faster)

def test_find_balanced_braces_unclosed_single_quote():
    """Test with unclosed single quote."""
    # Arrange: Create a transformer and prepare code with unclosed single quote
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ content 'unclosed"
    
    # Act: Try to find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 3.17μs -> 2.69μs (17.9% faster)

def test_find_balanced_braces_only_opening_brace_no_closing():
    """Test with single opening brace and no closing brace."""
    # Arrange: Create a transformer and prepare code with only opening brace
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{"
    
    # Act: Try to find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 912ns -> 902ns (1.11% faster)

def test_find_balanced_braces_closing_brace_inside_string():
    """Test with closing brace hidden inside a string literal."""
    # Arrange: Create a transformer and prepare code with closing brace in string
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = '{ content "}" extra'
    
    # Act: Try to find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 3.78μs -> 3.66μs (3.31% faster)

def test_find_balanced_braces_with_escaped_backslash_before_quote():
    """Test with escaped backslash before quote in string."""
    # Arrange: Create a transformer and prepare code with escaped backslash
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = '{ content "\\\\\\"" }'
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 4.98μs -> 3.84μs (29.8% faster)

def test_find_balanced_braces_start_at_exact_brace_position():
    """Test starting at exact position of opening brace."""
    # Arrange: Create a transformer and prepare code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "prefix { content } suffix"
    
    # Act: Find balanced braces starting from position of opening brace
    content, end_pos = transformer._find_balanced_braces(code, 7) # 3.39μs -> 2.60μs (30.0% faster)

def test_find_balanced_braces_empty_string():
    """Test with empty string."""
    # Arrange: Create a transformer and prepare code with empty string
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = ""
    
    # Act: Try to find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 531ns -> 461ns (15.2% faster)

def test_find_balanced_braces_position_at_closing_brace():
    """Test when start position points to a closing brace instead of opening."""
    # Arrange: Create a transformer and prepare code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ content }"
    
    # Act: Try to find balanced braces starting from closing brace position
    content, end_pos = transformer._find_balanced_braces(code, 10) # 641ns -> 592ns (8.28% faster)

def test_find_balanced_braces_adjacent_braces():
    """Test with adjacent braces at different depths."""
    # Arrange: Create a transformer and prepare code with adjacent braces
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ { } }"
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 4.94μs -> 2.40μs (105% faster)

def test_find_balanced_braces_multiple_strings_with_quotes():
    """Test with multiple string literals in sequence."""
    # Arrange: Create a transformer and prepare code with multiple strings
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = '{ "first" \'second\' }'
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 5.59μs -> 3.53μs (58.5% faster)

def test_find_balanced_braces_string_with_backslash_at_end():
    """Test string ending with backslash followed by quote."""
    # Arrange: Create a transformer and prepare code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = '{ content "\\\\" }'
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 3.43μs -> 3.02μs (13.7% faster)

def test_find_balanced_braces_single_quote_with_backslash():
    """Test single-quoted string with escaped character."""
    # Arrange: Create a transformer and prepare code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    code = "{ '\\n' }"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 4.61μs -> 2.73μs (68.5% faster)

def test_find_balanced_braces_deeply_nested_large_scale():
    """Test with deeply nested braces (large scale)."""
    # Arrange: Create a transformer and prepare deeply nested code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build code with 100 levels of nesting
    code = "{" + "{ " * 100 + "content" + " }" * 100 + "}"
    
    # Act: Find balanced braces starting from position 0
    content, end_pos = transformer._find_balanced_braces(code, 0) # 70.6μs -> 39.9μs (77.0% faster)

def test_find_balanced_braces_large_string_literal():
    """Test with very large string literal inside braces."""
    # Arrange: Create a transformer and prepare code with large string
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    large_string = "x" * 1000
    code = "{ \"" + large_string + "\" }"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 5.11μs -> 3.61μs (41.7% faster)

def test_find_balanced_braces_many_nested_strings():
    """Test with many string literals nested in braces."""
    # Arrange: Create a transformer and prepare code with many strings
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build code with 100 strings inside braces
    code = "{ " + " ".join(['"str' + str(i) + '"' for i in range(100)]) + " }"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 54.4μs -> 39.1μs (38.9% faster)

def test_find_balanced_braces_complex_nested_structure():
    """Test with complex mix of nested braces and strings."""
    # Arrange: Create a transformer and prepare complex code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build complex nested structure with 50 levels of braces and strings
    code = "{"
    for i in range(50):
        code += "{ \"string" + str(i) + "\" "
    code += "content"
    for i in range(50):
        code += " }"
    code += " }"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 64.3μs -> 41.1μs (56.6% faster)

def test_find_balanced_braces_long_string_with_many_escapes():
    """Test with long string containing many escaped characters."""
    # Arrange: Create a transformer and prepare code with escaped characters
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build string with 500 escaped sequences
    escaped_content = "\\\\" * 500
    code = "{ \"" + escaped_content + "\" }"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 3.77μs -> 2.48μs (51.6% faster)

def test_find_balanced_braces_alternating_quotes_large_scale():
    """Test with alternating double and single quotes (large scale)."""
    # Arrange: Create a transformer and prepare code with alternating quotes
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build code with alternating quote types
    code = "{ " + " ".join([('"s' + str(i) + '"' if i % 2 == 0 else "'s" + str(i) + "'") for i in range(200)]) + " }"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 109μs -> 79.0μs (39.2% faster)

def test_find_balanced_braces_very_long_code():
    """Test with very long code string (1000+ characters)."""
    # Arrange: Create a transformer and prepare very long code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build code with nested structures repeated many times
    inner = "{ \"content\" { } }"
    code = "{ " + inner * 50 + " }"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 99.7μs -> 55.5μs (79.8% faster)

def test_find_balanced_braces_performance_deeply_nested():
    """Test performance with extremely deep nesting (500 levels)."""
    # Arrange: Create a transformer and prepare extremely deeply nested code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build code with 500 levels of nesting
    code = "{" + "{" * 500 + "x" + "}" * 500 + "}"
    
    # Act: Find balanced braces (should complete in reasonable time)
    content, end_pos = transformer._find_balanced_braces(code, 0) # 345μs -> 100μs (243% faster)

def test_find_balanced_braces_large_alternating_nested_braces():
    """Test with large number of alternating nested braces."""
    # Arrange: Create a transformer and prepare code with many nested levels
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build alternating pattern of nested braces
    code = "{"
    for i in range(200):
        if i % 2 == 0:
            code += " {"
        else:
            code += " }"
    code += " }"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 71.3μs -> 39.8μs (79.1% faster)

def test_find_balanced_braces_mixed_content_stress():
    """Stress test with mixed strings, braces, and quotes."""
    # Arrange: Create a transformer and prepare stress test code
    transformer = JavaAssertTransformer("test", analyzer=get_java_analyzer())
    # Build complex mixed content
    code = "{ "
    for i in range(100):
        code += "{ \"str" + str(i) + "\" '" + chr(65 + (i % 26)) + "' } "
    code += "}"
    
    # Act: Find balanced braces
    content, end_pos = transformer._find_balanced_braces(code, 0) # 182μs -> 117μs (54.9% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To edit these changes git checkout codeflash/optimize-pr1199-2026-03-02T05.57.20 and push.


aseembits93 and others added 30 commits February 4, 2026 16:48
Remove safe_relative_to, resolve_classes_from_modules,
extract_classes_from_type_hint, resolve_transitive_type_deps,
extract_init_stub, _is_project_module_cached, is_project_path,
_is_project_module, extract_imports_for_class,
collect_names_from_annotation, is_dunder_method, _qualified_name,
and _validate_classdef. Inline trivial helpers into prune_cst and
clean up enrich_testgen_context and get_function_sources_from_jedi.
Remove corresponding tests.
Add enrichment step that parses FTO parameter type annotations, resolves
types via jedi (following re-exports), and extracts full __init__ source
to give the LLM constructor context for typed parameters.
Fix 10 failing tests: remove wrong assertions expecting import statements
inside extracted class code, use substring matching for UserDict class
signature, and rewrite click-dependent tests as project-local equivalents.
Add tests for resolve_instance_class_name, enhanced extract_init_stub_from_class,
and enrich_testgen_context instance resolution.
Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com>
The optimized code achieves a **70% runtime speedup** (from 7.02ms to 4.13ms) through three key improvements:

## 1. **Faster Class Discovery via Deque-Based BFS (Primary Speedup)**
The original code uses `ast.walk()`, which visits every node in the AST and was not stopped once the target class had been found. The line profiler shows this taking 20.5ms (71% of time).

The optimized version replaces this with an explicit BFS using `collections.deque`, which stops immediately upon finding the target class. The profiler shows this reduces traversal time to 9.95ms - **cutting the search overhead by >50%**.

This is especially impactful when:
- The target class appears early in the module (eliminates unnecessary traversal)
- The module contains many classes (test shows 7-10% faster on modules with 100-1000 classes)
- The function is called frequently (shown by the 108% speedup on 1000 repeated calls)
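As a rough illustration of the early-exit search described above, the following sketch walks an AST breadth-first with `collections.deque` and returns at the first matching class. The helper name `find_class_bfs` is hypothetical and is not taken from the codebase.

```python
import ast
from collections import deque


def find_class_bfs(tree, class_name):
    """Breadth-first search for a ClassDef named `class_name`.

    Returns the first matching node, or None. Hypothetical helper for
    illustration only.
    """
    queue = deque([tree])
    while queue:
        node = queue.popleft()
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return node  # early exit: nodes still in the queue are never visited
        queue.extend(ast.iter_child_nodes(node))
    return None
```

The saving in a loop like this comes from returning at the first match, so it is largest when the target class sits near the top of the module.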

## 2. **Explicit Loops Replace Generator Overhead**
The original code uses `any()` with a generator expression and `min()` with a generator to check decorators and find minimum line numbers. These create function call and generator overhead.

The optimized version uses explicit `for` loops with early breaks:
- Decorator checking: Directly iterates and breaks on first match
- Min line number: Uses explicit comparison instead of `min()` generator

The profiler shows decorator processing time reduced from ~1.4ms to ~0.3ms, and min line calculation from 69μs to 28μs.
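The loop rewrites described here might look roughly like the sketch below. The helper names and the set of decorator names passed in are assumptions; the commit message does not show the actual code.

```python
import ast


def has_decorator(node, wanted):
    """Explicit loop in place of any(...): return on the first matching decorator."""
    for dec in node.decorator_list:
        if isinstance(dec, ast.Name) and dec.id in wanted:
            return True
    return False


def first_line(node):
    """Explicit comparison in place of min(...): earliest line of the def or its decorators."""
    start = node.lineno
    for dec in node.decorator_list:
        if dec.lineno < start:
            start = dec.lineno
    return start
```

Both helpers avoid creating a generator object per call, which is the overhead the commit message attributes to `any()` and `min()`.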

## 3. **Conditional Flag Pattern for Relevance Checking**
Instead of evaluating both conditions in a compound expression, the optimized version uses an `is_relevant` flag with early exits, reducing redundant checks.

## Impact on Workloads
Based on `function_references`, this function is called from:
- `enrich_testgen_context`: Used in test generation workflows where it may process many classes
- Benchmark tests: Indicates this is in a performance-critical path

The optimization particularly benefits:
- **Large codebases**: 89-90% faster on classes with 100+ methods or 50+ properties
- **Repeated calls**: 108% faster when called 1000 times in sequence
- **Early matches**: Up to 88% faster when target class is found quickly
- **Deep nesting**: 57% faster for nested classes

The annotated tests show consistent 50-108% speedups across most scenarios, with minimal gains (6-10%) only when processing very large files where string slicing dominates runtime.
Add --agent CLI flag for AI agent integrations that skips all
interactive prompts. In agent mode, checkpoint resume is skipped
entirely so each run starts fresh. Also gates the existing checkpoint
prompt behind --yes.
In agent mode, disable all Rich output (panels, spinners, progress bars,
syntax highlighting) and use a plain StreamHandler for logging. Optimization
results with explanation and unified diff are written to stdout. A log
filter strips LSP prefixes and drops noisy test/file-path messages.
Also skip checkpoint creation and suppress Python warnings in agent mode.
…oncolic tests

- --agent now implies --no-pr and --worktree so source files stay clean
- Output uses structured XML (codeflash-optimization) with optimized-code
  for the consuming agent to apply via Edit/Write
- Skip concolic test generation in agent mode
- Skip patch file creation in worktree + agent mode
fix: resolve test file paths in discover_tests_pytest to fix path com…
The comparator did not recognize `types.UnionType` (Python 3.10+ `X | Y`
syntax), causing it to fall through to "Unknown comparator input type".
Conditionally include it in the equality-checked types tuple.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
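A conditional inclusion of `types.UnionType` could be sketched as follows. The tuple name, its other members, and the `compare_simple` helper are hypothetical; only `types.UnionType` and the "Unknown comparator input type" message come from the commit text.

```python
import types

# Hypothetical tuple of types that are compared with plain ==.
_equality_types = [int, str, bool, type(None), type]

# types.UnionType only exists on Python 3.10+, so look it up defensively.
_union_type = getattr(types, "UnionType", None)
if _union_type is not None:
    _equality_types.append(_union_type)

EQUALITY_TYPES = tuple(_equality_types)


def compare_simple(a, b):
    """Compare two values with == when both are of an equality-checked type."""
    if isinstance(a, EQUALITY_TYPES) and isinstance(b, EQUALITY_TYPES):
        return a == b
    # Assumed fall-through behaviour for this sketch.
    raise TypeError("Unknown comparator input type")
```

Using `getattr` keeps the module importable on Python 3.9, where the `X | Y` syntax on types is not available at runtime.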
Rename the CLI flag, env var (CODEFLASH_SUBAGENT_MODE), helper
(is_subagent_mode), and related symbols to avoid confusion with
CodeFlash's own agent terminology.
Avoids computing the full summary (callee counts, string formatting)
only to discard it when running in subagent mode.
KRRT7 and others added 14 commits February 27, 2026 15:14
perf: optimize tracer hot path with string-based path ops and caching
Remove the AST-based discovery path for Python, routing all languages
through the unified CST-based `_find_all_functions_via_language_support`.
Delete dead code: `find_functions_with_return_statement`,
`_find_all_functions_in_python_file`, `function_has_return_statement`,
`function_is_a_property`, and associated constants. Fix FunctionVisitor
to skip nested functions and exclude @property/@cached_property, and let
parse errors propagate for correct empty-dict behavior on invalid files.
Nested functions are now skipped by FunctionVisitor, and
discover_functions no longer swallows parse/IO errors — callers
handle them. Update test expectations accordingly.
The comparator had no handler for itertools.count (an infinite iterator),
causing it to fall through all type checks and return False even for
equal objects. Use repr() comparison which reliably reflects internal
state and avoids the __reduce__ deprecation coming in Python 3.14.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
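The `repr()`-based comparison can be sketched in a few lines. The function name `counts_equal` is hypothetical; the point is only that `repr()` of an `itertools.count` reflects its current value and step without consuming any elements.

```python
import itertools


def counts_equal(a, b):
    """Compare two itertools.count objects by their repr, without advancing them."""
    if type(a) is not itertools.count or type(b) is not itertools.count:
        return False
    return repr(a) == repr(b)
```

For example, two fresh `itertools.count(3)` objects compare equal, and they stop comparing equal as soon as only one of them has been advanced with `next()`.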
itertools.repeat uses repr() comparison (same approach as count).
itertools.cycle uses __reduce__() to extract internal state (saved items,
remaining items, and first-pass flag) since repr() only shows a memory
address. The __reduce__ approach is deprecated in 3.14 but is the only
way to access cycle state without consuming elements.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- Change FunctionVisitor.file_path from str to Path
- Unify dict keys to Path across discovery functions (get_all_files_and_functions,
  get_functions_within_lines, get_functions_within_git_diff, etc.)
- Remove redundant isinstance check in discover_functions
- Add assert for found_function narrowing after exit_with_message
- Fix closest_matching_file_function_name return type narrowing
Add a catch-all handler for itertools iterators (chain, islice, product,
permutations, combinations, starmap, accumulate, compress, dropwhile,
takewhile, filterfalse, zip_longest, groupby, pairwise, batched, tee).
Uses module check (type.__module__ == "itertools") so it automatically
covers any itertools type without version-specific enumeration. groupby
gets special handling to also materialize its group iterators.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
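A module-based catch-all of this kind might be sketched as below. The function names are hypothetical, and comparing by materializing both iterators into lists is an assumption suggested by the mention of materializing `groupby` groups; it consumes the iterators and is only safe for finite ones.

```python
import itertools


def is_itertools_iterator(obj):
    """True for any object whose type is defined in the itertools module."""
    return type(obj).__module__ == "itertools"


def itertools_equal(a, b):
    """Compare two finite itertools iterators of the same type by their elements.

    Note: this consumes both iterators.
    """
    if type(a) is not type(b) or not is_itertools_iterator(a):
        return False
    return list(a) == list(b)
```

Checking `__module__` rather than listing each type means newer additions such as `pairwise` or `batched` are covered without version-specific code.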
Handle itertools.cycle on Python 3.14 where __reduce__ was removed by
falling back to element-by-element sampling. Add version guards for
pairwise (3.10+) and batched (3.12+) tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
refactor: consolidate Python function discovery to CST path only
feat: surface subagent optimization diffs in IDE's native diff view
…ount

fix: handle itertools types in comparator with Python 3.9-3.14 support
# Conflicts:
#	codeflash/cli_cmds/console.py
#	codeflash/cli_cmds/logging_config.py
#	codeflash/code_utils/time_utils.py
#	codeflash/optimization/function_optimizer.py
#	codeflash/optimization/optimizer.py
#	codeflash/verification/parse_test_output.py
#	pyproject.toml
The hot loop repeatedly called `re.search` to locate the next special character (`"`, `'`, `{`, `}`, `(`), then called `m.start()` and `m.group()` on every iteration; the profiler shows these three lines consumed ~37% of total runtime. The optimized version replaces regex scanning with a direct character-by-character walk (`ch = s[pos]; pos += 1`), eliminating the match-object overhead and achieving a 121% speedup (1.76 ms → 798 µs). The original's per-iteration cost of ~2.4 µs (regex + extraction) drops to ~1.2 µs with simple indexing, and the deep-nesting test case improves from 646 µs to 200 µs. Trade-off: two test cases regressed by ~4–21% in wall time. Both are failure paths (a missing closing brace) that return almost immediately, so the absolute difference is negligible (<1 µs).
@codeflash-ai codeflash-ai Bot added the ⚡️ codeflash (Optimization PR opened by Codeflash AI) and 🎯 Quality: High (Optimization Quality according to Codeflash) labels Mar 2, 2026
@codeflash-ai codeflash-ai Bot mentioned this pull request Mar 2, 2026
@claude
Contributor

claude Bot commented Mar 2, 2026

PR Review Summary

Prek Checks

All checks pass. No formatting or linting issues found.

Mypy

No type errors found in codeflash/languages/java/remove_asserts.py.

Code Review

No critical issues found. This is a clean optimization of JavaAssertTransformer._find_balanced_braces that replaces regex scanning with direct character-by-character iteration. The change:

  • Replaces self._special_re.search(code, pos) with ch = s[pos] per-character walking
  • Maintains identical escape-handling logic for single and double quotes
  • Preserves the same brace depth tracking and error handling
  • Correctly handles all edge cases (unbalanced braces, unterminated strings, nested braces)
  • All 74 existing tests in test_remove_asserts.py pass

Test Coverage

| File | Stmts | Miss | Coverage |
|---|---|---|---|
| codeflash/languages/java/remove_asserts.py | 457 | 63 | 86% |

Changed method analysis (_find_balanced_braces, lines 848-893):

  • 17 lines in the changed range are uncovered; they are edge-case branches (quote literal handling, nested brace tracking, error returns)
  • These are pre-existing coverage gaps - the same code paths were uncovered before this optimization
  • The core loop structure and happy-path return are covered
  • No coverage regression introduced by this PR

Last updated: 2026-03-02T06:15:00Z

@CLAassistant

CLA assistant check
Thank you for your submission! We really appreciate it. Like many open source projects, we ask that you sign our Contributor License Agreement before we can accept your contribution.
You have signed the CLA already but the status is still pending? Let us recheck it.

@github-actions github-actions Bot added the workflow-modified This PR modifies GitHub Actions workflows label Mar 3, 2026
@aseembits93
Contributor

Closing as it is completely out of sync with omni-java; doing the merge the other way is not working.

@aseembits93 aseembits93 closed this Mar 3, 2026
@codeflash-ai codeflash-ai Bot deleted the codeflash/optimize-pr1199-2026-03-02T05.57.20 branch March 3, 2026 02:46
@codeflash-ai
Contributor Author

codeflash-ai Bot commented Mar 3, 2026

⚡️ Codeflash found optimizations for this PR

📄 223% (2.23x) speedup for set_level in codeflash/cli_cmds/logging_config.py

⏱️ Runtime : 360 milliseconds 112 milliseconds (best of 26 runs)

A new Optimization Review has been created.

🔗 Review here


Comment on lines +748 to +761
    if isinstance(node, ast.Name):
        return {node.id}
    if isinstance(node, ast.Subscript):
        names = collect_type_names_from_annotation(node.value)
        names |= collect_type_names_from_annotation(node.slice)
        return names
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return collect_type_names_from_annotation(node.left) | collect_type_names_from_annotation(node.right)
    if isinstance(node, ast.Tuple):
        names = set[str]()
        for elt in node.elts:
            names |= collect_type_names_from_annotation(elt)
        return names
    return set()
Contributor Author


⚡️Codeflash found 290% (2.90x) speedup for collect_type_names_from_annotation in codeflash/languages/python/context/code_context_extractor.py

⏱️ Runtime : 5.34 milliseconds 1.37 milliseconds (best of 106 runs)

📝 Explanation and details

The recursive approach was replaced with an explicit stack-based loop, eliminating Python function-call overhead and intermediate set allocations. Each recursive call in the original incurred ~1.5 μs of overhead (visible in the line profiler: 24% of time in BinOp union calls, 22% in Tuple element unions) plus small set-merge costs; the iterative version instead adds every name to a single names set. Test results confirm a 290% overall speedup, with the largest gains on deeply nested structures (759% faster for 1000-type union chains, 216% faster for 100-type unions) where recursion overhead compounds. Simple single-node cases regress ~30–50% (now ~1 μs vs ~500 ns) due to the added stack setup cost, but these contribute negligibly to real workloads where annotations contain multiple nested types.

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 35 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Click to see Generated Regression Tests
import ast
from typing import Set

# imports
import pytest
from codeflash.languages.python.context.code_context_extractor import \
    collect_type_names_from_annotation

def test_none_input_returns_empty_set():
    """Test that None input returns an empty set."""
    codeflash_output = collect_type_names_from_annotation(None); result = codeflash_output # 441ns -> 451ns (2.22% slower)

def test_simple_name_node():
    """Test that a simple Name node returns the identifier in a set."""
    # Create a simple Name node representing 'int'
    node = ast.Name(id='int', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 601ns -> 1.09μs (45.0% slower)

def test_simple_name_node_string_type():
    """Test that a simple Name node for 'str' type works correctly."""
    node = ast.Name(id='str', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 551ns -> 1.03μs (46.6% slower)

def test_simple_name_node_list_type():
    """Test that a simple Name node for 'list' type works correctly."""
    node = ast.Name(id='list', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 501ns -> 1.01μs (50.5% slower)

def test_subscript_with_single_base_type():
    """Test that subscript annotation like List[int] extracts both names."""
    # Represents List[int]
    base = ast.Name(id='List', ctx=ast.Load())
    subscript_value = ast.Name(id='int', ctx=ast.Load())
    node = ast.Subscript(value=base, slice=subscript_value, ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 1.59μs -> 1.96μs (18.8% slower)

def test_subscript_with_dict_types():
    """Test that Dict[str, int] extracts both key and value types."""
    # Represents Dict[str, int]
    base = ast.Name(id='Dict', ctx=ast.Load())
    # Create a Tuple for the slice (str, int)
    key_type = ast.Name(id='str', ctx=ast.Load())
    value_type = ast.Name(id='int', ctx=ast.Load())
    slice_tuple = ast.Tuple(elts=[key_type, value_type], ctx=ast.Load())
    node = ast.Subscript(value=base, slice=slice_tuple, ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 4.05μs -> 2.85μs (42.2% faster)

def test_tuple_with_single_element():
    """Test that a Tuple with one element works correctly."""
    # Represents (int,)
    elt = ast.Name(id='int', ctx=ast.Load())
    node = ast.Tuple(elts=[elt], ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 2.67μs -> 1.49μs (78.5% faster)

def test_tuple_with_multiple_elements():
    """Test that a Tuple with multiple elements returns all names."""
    # Represents (int, str, float)
    elts = [
        ast.Name(id='int', ctx=ast.Load()),
        ast.Name(id='str', ctx=ast.Load()),
        ast.Name(id='float', ctx=ast.Load()),
    ]
    node = ast.Tuple(elts=elts, ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 3.31μs -> 2.27μs (45.4% faster)

def test_union_type_with_binary_or():
    """Test that union type int | str extracts both types."""
    # Represents int | str
    left = ast.Name(id='int', ctx=ast.Load())
    right = ast.Name(id='str', ctx=ast.Load())
    node = ast.BinOp(left=left, op=ast.BitOr(), right=right)
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 1.75μs -> 1.90μs (7.83% slower)

def test_union_type_with_three_types():
    """Test that union type int | str | float extracts all three types."""
    # Represents (int | str) | float
    left_inner = ast.Name(id='int', ctx=ast.Load())
    right_inner = ast.Name(id='str', ctx=ast.Load())
    left_outer = ast.BinOp(left=left_inner, op=ast.BitOr(), right=right_inner)
    right_outer = ast.Name(id='float', ctx=ast.Load())
    node = ast.BinOp(left=left_outer, op=ast.BitOr(), right=right_outer)
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 2.22μs -> 2.48μs (10.1% slower)

def test_nested_subscript():
    """Test nested subscript like Optional[List[int]]."""
    # Represents Optional[List[int]]
    # Build List[int]
    list_base = ast.Name(id='List', ctx=ast.Load())
    int_type = ast.Name(id='int', ctx=ast.Load())
    list_subscript = ast.Subscript(
        value=list_base, slice=int_type, ctx=ast.Load()
    )
    # Build Optional[List[int]]
    optional_base = ast.Name(id='Optional', ctx=ast.Load())
    node = ast.Subscript(
        value=optional_base, slice=list_subscript, ctx=ast.Load()
    )
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 1.91μs -> 2.48μs (22.7% slower)

def test_unsupported_node_type_returns_empty_set():
    """Test that unsupported AST node types return empty set."""
    # Create an unsupported node type (Constant)
    node = ast.Constant(value=42)
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 811ns -> 1.16μs (30.2% slower)

def test_attribute_node_returns_empty_set():
    """Test that Attribute nodes (not directly supported) return empty set."""
    # Represents typing.List (attribute access)
    value = ast.Name(id='typing', ctx=ast.Load())
    node = ast.Attribute(value=value, attr='List', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 1.03μs -> 1.40μs (26.4% slower)

def test_empty_tuple():
    """Test that an empty Tuple returns empty set."""
    node = ast.Tuple(elts=[], ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 2.27μs -> 1.26μs (79.5% faster)

def test_single_character_type_name():
    """Test that single-character type names are handled correctly."""
    node = ast.Name(id='T', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 521ns -> 1.05μs (50.5% slower)

def test_type_name_with_numbers():
    """Test that type names with numbers are handled correctly."""
    node = ast.Name(id='Type123', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 541ns -> 972ns (44.3% slower)

def test_type_name_with_underscore():
    """Test that type names with underscores are handled correctly."""
    node = ast.Name(id='My_Type', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 541ns -> 982ns (44.9% slower)

def test_deeply_nested_subscripts():
    """Test deeply nested subscripts like Dict[str, List[Tuple[int, str]]]."""
    # Build Tuple[int, str]
    int_type = ast.Name(id='int', ctx=ast.Load())
    str_type = ast.Name(id='str', ctx=ast.Load())
    tuple_inner = ast.Tuple(elts=[int_type, str_type], ctx=ast.Load())
    
    # Build List[Tuple[int, str]]
    list_base = ast.Name(id='List', ctx=ast.Load())
    list_subscript = ast.Subscript(
        value=list_base, slice=tuple_inner, ctx=ast.Load()
    )
    
    # Build Dict[str, List[Tuple[int, str]]]
    dict_base = ast.Name(id='Dict', ctx=ast.Load())
    str_key = ast.Name(id='str', ctx=ast.Load())
    dict_slice = ast.Tuple(elts=[str_key, list_subscript], ctx=ast.Load())
    node = ast.Subscript(value=dict_base, slice=dict_slice, ctx=ast.Load())
    
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 5.90μs -> 3.79μs (55.8% faster)

def test_union_with_subscripts():
    """Test union type combining subscripts: List[int] | Dict[str, float]."""
    # Build List[int]
    list_base = ast.Name(id='List', ctx=ast.Load())
    int_type = ast.Name(id='int', ctx=ast.Load())
    list_sub = ast.Subscript(value=list_base, slice=int_type, ctx=ast.Load())
    
    # Build Dict[str, float]
    dict_base = ast.Name(id='Dict', ctx=ast.Load())
    str_type = ast.Name(id='str', ctx=ast.Load())
    float_type = ast.Name(id='float', ctx=ast.Load())
    dict_slice = ast.Tuple(elts=[str_type, float_type], ctx=ast.Load())
    dict_sub = ast.Subscript(value=dict_base, slice=dict_slice, ctx=ast.Load())
    
    # Build List[int] | Dict[str, float]
    node = ast.BinOp(left=list_sub, op=ast.BitOr(), right=dict_sub)
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 5.27μs -> 4.11μs (28.3% faster)

def test_union_preserves_duplicate_types():
    """Test that union with duplicate types still returns set (no duplicates)."""
    # Represents int | int (which is redundant but valid)
    left = ast.Name(id='int', ctx=ast.Load())
    right = ast.Name(id='int', ctx=ast.Load())
    node = ast.BinOp(left=left, op=ast.BitOr(), right=right)
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 1.66μs -> 2.11μs (21.3% slower)

def test_binary_or_with_nested_union():
    """Test nested union: (int | str) | (float | bytes)."""
    # Build int | str
    int_type = ast.Name(id='int', ctx=ast.Load())
    str_type = ast.Name(id='str', ctx=ast.Load())
    union1 = ast.BinOp(left=int_type, op=ast.BitOr(), right=str_type)
    
    # Build float | bytes
    float_type = ast.Name(id='float', ctx=ast.Load())
    bytes_type = ast.Name(id='bytes', ctx=ast.Load())
    union2 = ast.BinOp(left=float_type, op=ast.BitOr(), right=bytes_type)
    
    # Build (int | str) | (float | bytes)
    node = ast.BinOp(left=union1, op=ast.BitOr(), right=union2)
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 2.91μs -> 3.25μs (10.5% slower)

def test_binary_or_ignores_other_operators():
    """Test that BinOp with non-BitOr operators returns empty set."""
    # Represents int + str (invalid type annotation, but test coverage)
    left = ast.Name(id='int', ctx=ast.Load())
    right = ast.Name(id='str', ctx=ast.Load())
    node = ast.BinOp(left=left, op=ast.Add(), right=right)
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 1.05μs -> 1.31μs (19.8% slower)

def test_result_is_always_set_type():
    """Test that the result is always a set (not list, tuple, etc.)."""
    node = ast.Name(id='int', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 551ns -> 1.02μs (46.1% slower)

def test_result_contains_only_strings():
    """Test that result set contains only strings."""
    elts = [
        ast.Name(id='int', ctx=ast.Load()),
        ast.Name(id='str', ctx=ast.Load()),
        ast.Name(id='float', ctx=ast.Load()),
    ]
    node = ast.Tuple(elts=elts, ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 3.59μs -> 2.23μs (60.6% faster)

def test_callable_type_name():
    """Test handling of 'Callable' type name."""
    node = ast.Name(id='Callable', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 550ns -> 1.03μs (46.7% slower)

def test_any_type_name():
    """Test handling of 'Any' type name."""
    node = ast.Name(id='Any', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 550ns -> 1.02μs (46.2% slower)

def test_generic_type_parameters():
    """Test generic parameters like TypeVar."""
    node = ast.Name(id='TypeVar', ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 511ns -> 1.06μs (51.9% slower)

def test_complex_nested_structure_with_all_node_types():
    """Test complex structure: Union[Optional[List[int]], Dict[str, Tuple[float, bool]]]."""
    # Build Optional[List[int]]
    int_type = ast.Name(id='int', ctx=ast.Load())
    list_base = ast.Name(id='List', ctx=ast.Load())
    list_sub = ast.Subscript(value=list_base, slice=int_type, ctx=ast.Load())
    optional_base = ast.Name(id='Optional', ctx=ast.Load())
    optional_sub = ast.Subscript(
        value=optional_base, slice=list_sub, ctx=ast.Load()
    )
    
    # Build Dict[str, Tuple[float, bool]]
    float_type = ast.Name(id='float', ctx=ast.Load())
    bool_type = ast.Name(id='bool', ctx=ast.Load())
    tuple_inner = ast.Tuple(elts=[float_type, bool_type], ctx=ast.Load())
    dict_base = ast.Name(id='Dict', ctx=ast.Load())
    str_type = ast.Name(id='str', ctx=ast.Load())
    dict_slice = ast.Tuple(elts=[str_type, tuple_inner], ctx=ast.Load())
    dict_sub = ast.Subscript(value=dict_base, slice=dict_slice, ctx=ast.Load())
    
    # Build Union[Optional[List[int]], Dict[str, Tuple[float, bool]]]
    union_base = ast.Name(id='Union', ctx=ast.Load())
    union_slice = ast.Tuple(elts=[optional_sub, dict_sub], ctx=ast.Load())
    node = ast.Subscript(value=union_base, slice=union_slice, ctx=ast.Load())
    
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 8.41μs -> 5.41μs (55.4% faster)

def test_subscript_with_none_slice():
    """Test subscript with None slice (edge case)."""
    base = ast.Name(id='List', ctx=ast.Load())
    node = ast.Subscript(value=base, slice=None, ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 1.33μs -> 1.74μs (23.5% slower)

def test_large_union_with_100_types():
    """Test union with 100 different types."""
    # Build a union of 100 types: Type0 | Type1 | ... | Type99
    types = [
        ast.Name(id=f'Type{i}', ctx=ast.Load())
        for i in range(100)
    ]
    
    # Build nested BinOp structure
    node = types[0]
    for type_node in types[1:]:
        node = ast.BinOp(left=node, op=ast.BitOr(), right=type_node)
    
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 135μs -> 42.7μs (216% faster)
    expected = {f'Type{i}' for i in range(100)}

def test_large_tuple_with_50_elements():
    """Test tuple with 50 type elements."""
    elts = [
        ast.Name(id=f'T{i}', ctx=ast.Load())
        for i in range(50)
    ]
    node = ast.Tuple(elts=elts, ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 15.8μs -> 9.76μs (61.6% faster)
    expected = {f'T{i}' for i in range(50)}

def test_deeply_nested_subscripts_100_levels():
    """Test deeply nested subscripts up to 100 levels deep."""
    # Build List[List[List[...[int]...]]]
    node = ast.Name(id='int', ctx=ast.Load())
    for i in range(100):
        list_base = ast.Name(id='List', ctx=ast.Load())
        node = ast.Subscript(value=list_base, slice=node, ctx=ast.Load())
    
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 71.8μs -> 31.2μs (130% faster)

def test_large_mixed_structure_with_many_nodes():
    """Test large mixed structure with subscripts, tuples, and unions."""
    # Create a structure like: Union[Tuple[Type0, Type1, ...], List[Type50], Dict[Type60, Type70]]
    # This tests multiple collection points in a large structure
    
    # Build large tuple with 50 elements
    tuple_elts = [
        ast.Name(id=f'Type{i}', ctx=ast.Load())
        for i in range(50)
    ]
    large_tuple = ast.Tuple(elts=tuple_elts, ctx=ast.Load())
    union_part1 = ast.Name(id='Union', ctx=ast.Load())
    union_sub1 = ast.Subscript(
        value=union_part1, slice=large_tuple, ctx=ast.Load()
    )
    
    # Build List[Type50]
    list_base = ast.Name(id='List', ctx=ast.Load())
    list_elem = ast.Name(id='Type50', ctx=ast.Load())
    list_sub = ast.Subscript(value=list_base, slice=list_elem, ctx=ast.Load())
    
    # Build Dict[Type60, Type70]
    dict_base = ast.Name(id='Dict', ctx=ast.Load())
    dict_key = ast.Name(id='Type60', ctx=ast.Load())
    dict_val = ast.Name(id='Type70', ctx=ast.Load())
    dict_slice = ast.Tuple(elts=[dict_key, dict_val], ctx=ast.Load())
    dict_sub = ast.Subscript(value=dict_base, slice=dict_slice, ctx=ast.Load())
    
    # Combine all parts with unions
    combined1 = ast.BinOp(left=union_sub1, op=ast.BitOr(), right=list_sub)
    node = ast.BinOp(left=combined1, op=ast.BitOr(), right=dict_sub)
    
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 22.7μs -> 13.0μs (74.2% faster)
    # Should contain all types from 0-70
    expected = (
        {'Union', 'List', 'Dict'} |
        {f'Type{i}' for i in range(71)}
    )

def test_200_type_names_with_various_structures():
    """Test collecting 200 unique type names across various structures."""
    # Create a complex structure with many different type names
    name_list = [
        ast.Name(id=f'CustomType{i}', ctx=ast.Load())
        for i in range(200)
    ]
    
    # Build alternating tuple and union structures
    current = name_list[0]
    for i in range(1, len(name_list)):
        if i % 2 == 0:
            # Create tuple
            current = ast.Tuple(elts=[current, name_list[i]], ctx=ast.Load())
        else:
            # Create union
            current = ast.BinOp(
                left=current, op=ast.BitOr(), right=name_list[i]
            )
    
    codeflash_output = collect_type_names_from_annotation(current); result = codeflash_output # 355μs -> 84.7μs (320% faster)
    expected = {f'CustomType{i}' for i in range(200)}

def test_performance_with_balanced_tree_of_unions():
    """Test performance with a balanced binary tree of union operations."""
    # Create a balanced tree structure like ((A|B)|(C|D))|((E|F)|(G|H))
    # This tests efficient recursive handling
    
    def build_balanced_union_tree(start_idx, depth):
        """Recursively build a balanced union tree."""
        if depth == 0:
            return ast.Name(id=f'Type{start_idx}', ctx=ast.Load())
        
        left = build_balanced_union_tree(start_idx, depth - 1)
        right = build_balanced_union_tree(start_idx + 2 ** (depth - 1), depth - 1)
        return ast.BinOp(left=left, op=ast.BitOr(), right=right)
    
    # Build a tree with depth 8 (2^8 = 256 types)
    node = build_balanced_union_tree(0, 8)
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 181μs -> 113μs (60.5% faster)
    expected = {f'Type{i}' for i in range(256)}

def test_repeated_types_in_large_structure():
    """Test large structure where many type names repeat."""
    # Build a structure where the same type names repeat many times
    # e.g., int | str | int | float | int | str ... (500 total with only 4 unique)
    
    base_types = ['int', 'str', 'float', 'bool']
    types = [
        ast.Name(id=base_types[i % len(base_types)], ctx=ast.Load())
        for i in range(500)
    ]
    
    # Build union structure
    node = types[0]
    for type_node in types[1:]:
        node = ast.BinOp(left=node, op=ast.BitOr(), right=type_node)
    
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 371μs -> 186μs (99.3% faster)
    expected = set(base_types)

def test_wide_tuple_with_1000_elements():
    """Test tuple with 1000 elements."""
    elts = [
        ast.Name(id=f'T{i}', ctx=ast.Load())
        for i in range(1000)
    ]
    node = ast.Tuple(elts=elts, ctx=ast.Load())
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 230μs -> 135μs (70.0% faster)
    expected = {f'T{i}' for i in range(1000)}

def test_union_chain_1000_types():
    """Test union chain with 1000 types."""
    # Build A | B | C | ... (1000 types)
    types = [
        ast.Name(id=f'Type{i}', ctx=ast.Load())
        for i in range(1000)
    ]
    
    node = types[0]
    for type_node in types[1:]:
        node = ast.BinOp(left=node, op=ast.BitOr(), right=type_node)
    
    codeflash_output = collect_type_names_from_annotation(node); result = codeflash_output # 3.37ms -> 391μs (759% faster)
    expected = {f'Type{i}' for i in range(1000)}
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run git merge codeflash/optimize-pr1707-2026-03-03T04.25.05

Click to see suggested changes
Suggested change
    # Original (removed)
    if isinstance(node, ast.Name):
        return {node.id}
    if isinstance(node, ast.Subscript):
        names = collect_type_names_from_annotation(node.value)
        names |= collect_type_names_from_annotation(node.slice)
        return names
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return collect_type_names_from_annotation(node.left) | collect_type_names_from_annotation(node.right)
    if isinstance(node, ast.Tuple):
        names = set[str]()
        for elt in node.elts:
            names |= collect_type_names_from_annotation(elt)
        return names
    return set()

    # Suggested replacement
    names = set()
    stack = [node]
    while stack:
        n = stack.pop()
        if n is None:
            continue
        if isinstance(n, ast.Name):
            names.add(n.id)
            continue
        if isinstance(n, ast.Subscript):
            # push slice and value to be processed
            stack.append(n.slice)
            stack.append(n.value)
            continue
        if isinstance(n, ast.BinOp) and isinstance(n.op, ast.BitOr):
            # push both sides of the union
            stack.append(n.right)
            stack.append(n.left)
            continue
        if isinstance(n, ast.Tuple):
            # process all elements
            stack.extend(n.elts)
            continue
        # all other node types contribute no names
    return names
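
To see the two traversal styles side by side, the following self-contained sketch wraps each in its own function and runs both on a parsed annotation. The function names `collect_names_recursive` and `collect_names_iterative` are hypothetical, the leading `None` check in the recursive version is inferred from the generated tests rather than shown in the diff, and the sketch assumes Python 3.9+, where `ast.Subscript.slice` is the slice expression itself.

```python
import ast


def collect_names_recursive(node) -> set:
    # Recursive form, mirroring the original lines quoted in the review.
    if node is None:
        return set()
    if isinstance(node, ast.Name):
        return {node.id}
    if isinstance(node, ast.Subscript):
        return collect_names_recursive(node.value) | collect_names_recursive(node.slice)
    if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr):
        return collect_names_recursive(node.left) | collect_names_recursive(node.right)
    if isinstance(node, ast.Tuple):
        names = set()
        for elt in node.elts:
            names |= collect_names_recursive(elt)
        return names
    return set()


def collect_names_iterative(node) -> set:
    # Explicit-stack form: one shared result set, no nested calls.
    names = set()
    stack = [node]
    while stack:
        n = stack.pop()
        if isinstance(n, ast.Name):
            names.add(n.id)
        elif isinstance(n, ast.Subscript):
            stack.append(n.slice)
            stack.append(n.value)
        elif isinstance(n, ast.BinOp) and isinstance(n.op, ast.BitOr):
            stack.append(n.right)
            stack.append(n.left)
        elif isinstance(n, ast.Tuple):
            stack.extend(n.elts)
        # None and all other node types contribute no names
    return names


# Parse an annotation expression and compare both implementations.
annotation = ast.parse("Dict[str, List[int]] | None", mode="eval").body
assert collect_names_recursive(annotation) == collect_names_iterative(annotation)
```

Because the result is a set, the order in which the stack visits nodes does not affect the output.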


Comment on lines +766 to +785
    # Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
    q: deque[ast.AST] = deque([module_tree])
    while q:
        candidate = q.popleft()
        if isinstance(candidate, ast.ClassDef) and candidate.name == class_name:
            class_node = candidate
            break
        q.extend(ast.iter_child_nodes(candidate))

    if class_node is None:
        return None

    lines = module_source.splitlines()
    relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
    for item in class_node.body:
        if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
            is_relevant = False
            if item.name in ("__init__", "__post_init__"):
                is_relevant = True
            else:
Contributor Author


⚡️Codeflash found 146% (1.46x) speedup for extract_init_stub_from_class in codeflash/languages/python/context/code_context_extractor.py

⏱️ Runtime : 1.49 milliseconds 604 microseconds (best of 140 runs)

📝 Explanation and details

The optimization replaced a deque-based BFS traversal (which visits every node in the entire AST) with a two-tier check: first iterating only module_tree.body for top-level classes, then falling back to ast.walk only if the class isn't found there. Line profiler shows the original q.extend(ast.iter_child_nodes(candidate)) consumed 73.8% of runtime, while the optimized two-tier search drops this to 33.6% in worst-case nested scenarios and eliminates it entirely for the common case of top-level classes. The 146% speedup stems from avoiding traversal of irrelevant nested scopes when the target class is at module scope, which test data confirms is the dominant pattern (e.g., the 1000-class benchmark improved 508%).
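
The two-tier lookup described above could be sketched as follows. This is an illustration only: `find_class` is a hypothetical helper name, and the actual change inside `extract_init_stub_from_class` may be structured differently.

```python
import ast
from typing import Optional


def find_class(module_tree: ast.Module, class_name: str) -> Optional[ast.ClassDef]:
    # Tier 1: scan only the module's top-level statements. This is the
    # common case and never descends into function or class bodies.
    for stmt in module_tree.body:
        if isinstance(stmt, ast.ClassDef) and stmt.name == class_name:
            return stmt
    # Tier 2: fall back to a full traversal for nested classes.
    for node in ast.walk(module_tree):
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return node
    return None


src = (
    "class Outer:\n"
    "    class Target:\n"
    "        pass\n"
    "class Target:\n"
    "    pass\n"
)
tree = ast.parse(src)
top_level = find_class(tree, "Target")
```

With this source, the fast path returns the module-level `Target` (the one at column 0), since tier 1 only ever looks at `module_tree.body`.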

Correctness verification report:

| Test | Status |
|---|---|
| ⚙️ Existing Unit Tests | 🔘 None Found |
| 🌀 Generated Regression Tests | 56 Passed |
| ⏪ Replay Tests | 🔘 None Found |
| 🔎 Concolic Coverage Tests | 🔘 None Found |
| 📊 Tests Coverage | 100.0% |
🌀 Click to see Generated Regression Tests
import ast

import pytest  # used for our unit tests
from codeflash.languages.python.context.code_context_extractor import \
    extract_init_stub_from_class

def test_returns_none_when_class_not_found():
    # Simple module with a different class; searching for a missing class should return None.
    src = "class Present:\n    pass\n"
    tree = ast.parse(src)
    codeflash_output = extract_init_stub_from_class("Missing", src, tree); result = codeflash_output # 6.17μs -> 8.99μs (31.3% slower)

def test_extracts_simple_init_only():
    # Module with one class that has a straightforward __init__.
    # We craft the source so lines and indentation are predictable.
    src = (
        "class Simple:\n"
        "    def __init__(self, x):\n"
        "        self.x = x\n"
    )
    tree = ast.parse(src)

    # The expected output must start with "class Simple:\n" followed by the exact
    # lines from the start of the __init__ to its end (indentation preserved).
    expected = (
        "class Simple:\n"
        "    def __init__(self, x):\n"
        "        self.x = x"
    )

    codeflash_output = extract_init_stub_from_class("Simple", src, tree); result = codeflash_output # 6.73μs -> 3.91μs (72.3% faster)

def test_includes_property_decorator_and_post_init_and_order():
    # Class contains __init__, a non-relevant method, a @property, and __post_init__.
    # The function should return only __init__, the property, and __post_init__
    # in the same order they appear in the class body.
    src = (
        "class WithProps:\n"
        "    def __init__(self, a):\n"
        "        self.a = a\n"
        "    def helper(self):\n"
        "        return 42\n"
        "    @property\n"
        "    def name(self):\n"
        "        return self.a\n"
        "    def __post_init__(self):\n"
        "        self._initialized = True\n"
    )
    tree = ast.parse(src)

    codeflash_output = extract_init_stub_from_class("WithProps", src, tree); result = codeflash_output # 10.1μs -> 7.10μs (42.7% faster)

def test_returns_none_if_class_has_no_relevant_methods():
    # Class exists but has no __init__, no __post_init__, and no @property-decorated methods.
    src = (
        "class NoRelevant:\n"
        "    def method(self):\n"
        "        return 'x'\n"
    )
    tree = ast.parse(src)
    codeflash_output = extract_init_stub_from_class("NoRelevant", src, tree) # 5.05μs -> 2.27μs (122% faster)

def test_handles_empty_module_source_and_tree():
    # Empty source should parse to an ast.Module with empty body and return None.
    src = ""
    tree = ast.parse(src)
    codeflash_output = extract_init_stub_from_class("Anything", src, tree) # 3.05μs -> 6.08μs (49.9% slower)

def test_decorator_attribute_named_property_is_recognized():
    # Use an attribute decorator like @mod.property which should be detected because .attr == "property".
    src = (
        "class AttrProp:\n"
        "    @mymod.property\n"
        "    def val(self):\n"
        "        return 1\n"
    )
    tree = ast.parse(src)
    codeflash_output = extract_init_stub_from_class("AttrProp", src, tree); res = codeflash_output # 7.44μs -> 4.79μs (55.4% faster)

def test_multiple_decorators_uses_earliest_decorator_lineno():
    # When multiple decorators exist, the snippet should start at the minimum decorator lineno.
    src = (
        "class MultiDec:\n"
        "    @first\n"
        "    @property\n"
        "    def p(self):\n"
        "        return 10\n"
    )
    tree = ast.parse(src)
    codeflash_output = extract_init_stub_from_class("MultiDec", src, tree); res = codeflash_output # 7.87μs -> 5.05μs (55.8% faster)

def test_async_post_init_is_recognized():
    # Ensure AsyncFunctionDef nodes are handled (async def __post_init__ should be included).
    src = (
        "class AsyncInit:\n"
        "    async def __post_init__(self):\n"
        "        return None\n"
    )
    tree = ast.parse(src)
    codeflash_output = extract_init_stub_from_class("AsyncInit", src, tree); res = codeflash_output # 6.69μs -> 3.96μs (69.1% faster)

def test_large_number_of_classes_performance_and_correctness():
    # Build a module with 1000 small classes to evaluate scalability and correctness.
    # Each class named C{i} has a simple __init__ that sets a single attribute.
    N = 1000
    target_index = 500  # pick the middle class to extract
    lines = []
    for i in range(N):
        lines.append(f"class C{i}:")
        lines.append(f"    def __init__(self, x):")
        lines.append(f"        self.x = {i}")
        # add a small separator (blank line) between classes to be realistic
        lines.append("")
    src = "\n".join(lines)
    tree = ast.parse(src)

    # Extract stub for the target class
    class_name = f"C{target_index}"
    codeflash_output = extract_init_stub_from_class(class_name, src, tree); res = codeflash_output # 910μs -> 149μs (508% faster)

    # Build expected snippet for C{target_index} exactly as we constructed the source.
    expected_snippet_lines = [
        f"class {class_name}:",
        f"    def __init__(self, x):",
        f"        self.x = {target_index}",
    ]
    expected = "\n".join(expected_snippet_lines)

def test_class_found_when_nested_and_bfs_order_applies():
    # Construct a module where a class with the target name appears nested inside another class
    # and also appears later at top level. A breadth-first search visits every top-level node
    # before descending, so the first matching ClassDef is the top-level one, not the nested one.
    src = (
        "class Outer:\n"
        "    class Target:\n"
        "        def __init__(self):\n"
        "            self.x = 1\n"
        "class Target:\n"
        "    def __init__(self):\n"
        "        self.x = 2\n"
    )
    tree = ast.parse(src)
    # The function does a deque BFS starting from module_tree. Breadth-first order visits both
    # top-level classes (Outer, then the top-level Target) before any child of Outer, so the
    # top-level Target (self.x = 2) is the definition we expect to be used.
    codeflash_output = extract_init_stub_from_class("Target", src, tree); res = codeflash_output # 9.48μs -> 4.36μs (117% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast

# imports
import pytest
from codeflash.languages.python.context.code_context_extractor import \
    extract_init_stub_from_class

def test_extract_simple_init():
    """Test extracting a simple __init__ method from a class."""
    source = """\
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.61μs -> 3.44μs (92.4% faster)

def test_extract_init_with_body():
    """Test extracting __init__ with actual implementation."""
    source = """\
class MyClass:
    def __init__(self, x):
        self.x = x
        self.y = x * 2
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.59μs -> 3.70μs (78.3% faster)

def test_extract_post_init():
    """Test extracting __post_init__ method (used in dataclasses)."""
    source = """\
class MyClass:
    def __post_init__(self):
        self.computed = self.a + self.b
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.54μs -> 3.59μs (82.4% faster)

def test_extract_property():
    """Test extracting a property decorated method."""
    source = """\
class MyClass:
    @property
    def value(self):
        return self._value
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.61μs -> 4.60μs (65.6% faster)

def test_extract_init_and_property():
    """Test extracting both __init__ and @property methods."""
    source = """\
class MyClass:
    def __init__(self, x):
        self._x = x
    
    @property
    def x(self):
        return self._x
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 8.67μs -> 5.72μs (51.5% faster)

def test_nonexistent_class_returns_none():
    """Test that extracting from a non-existent class returns None."""
    source = """\
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("NonExistent", source, tree); result = codeflash_output # 11.2μs -> 14.6μs (23.0% slower)

def test_class_without_relevant_methods_returns_none():
    """Test that a class without __init__, __post_init__, or @property returns None."""
    source = """\
class MyClass:
    def some_method(self):
        return 42
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 5.11μs -> 2.31μs (122% faster)

def test_class_with_init_and_unrelated_methods():
    """Test that only relevant methods (__init__, __post_init__, @property) are extracted."""
    source = """\
class MyClass:
    def __init__(self):
        self.x = 1
    
    def helper_method(self):
        return self.x * 2
    
    def another_method(self):
        pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.54μs -> 4.77μs (58.2% faster)

def test_nested_class():
    """Test extracting from a nested class."""
    source = """\
class Outer:
    class Inner:
        def __init__(self):
            pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("Inner", source, tree); result = codeflash_output # 7.97μs -> 13.9μs (42.4% slower)

def test_multiple_classes():
    """Test extracting from correct class when multiple classes exist."""
    source = """\
class FirstClass:
    def __init__(self):
        self.a = 1

class SecondClass:
    def __init__(self):
        self.b = 2
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("SecondClass", source, tree); result = codeflash_output # 8.89μs -> 4.06μs (119% faster)

def test_decorated_property_with_custom_decorator():
    """Test extracting property with attribute-based decorator like 'obj.property'."""
    source = """\
class MyClass:
    @some_obj.property
    def value(self):
        return 42
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.44μs -> 4.73μs (57.4% faster)

def test_init_with_multiline_implementation():
    """Test extracting __init__ with multiple lines of code."""
    source = """\
class MyClass:
    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c
        self.sum = a + b + c
        self.product = a * b * c
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.10μs -> 4.11μs (72.9% faster)

def test_property_with_setter():
    """Test extracting a property that also defines a setter (decorated with @x.setter)."""
    source = """\
class MyClass:
    @property
    def x(self):
        return self._x
    
    @x.setter
    def x(self, value):
        self._x = value
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 8.62μs -> 5.71μs (50.9% faster)

def test_async_init():
    """Test that an async __init__ is treated as relevant (both FunctionDef and AsyncFunctionDef named __init__/__post_init__ qualify)."""
    source = """\
class MyClass:
    async def __init__(self):
        self.x = 1
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.56μs -> 3.85μs (70.5% faster)

def test_class_with_only_async_property():
    """Test extracting from class with async method decorated as property."""
    source = """\
class MyClass:
    @property
    async def value(self):
        return 42
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.46μs -> 4.92μs (51.7% faster)

def test_empty_class_body():
    """Test handling of a class whose body is only `pass` (edge case: no methods)."""
    source = """\
class MyClass:
    pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 4.53μs -> 1.89μs (139% faster)

def test_class_name_case_sensitive():
    """Test that class name matching is case-sensitive."""
    source = """\
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("myclass", source, tree); result = codeflash_output # 10.9μs -> 14.3μs (23.5% slower)

def test_init_with_decorator():
    """Test extracting __init__ that has decorators (though uncommon)."""
    source = """\
class MyClass:
    @some_decorator
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.00μs -> 4.23μs (65.6% faster)

def test_multiple_decorators_on_property():
    """Test property with multiple decorators."""
    source = """\
class MyClass:
    @decorator1
    @property
    def value(self):
        return 42
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.75μs -> 5.18μs (49.7% faster)

def test_class_with_docstring():
    """Test extracting from class with a docstring."""
    source = '''\
class MyClass:
    """This is a docstring."""
    def __init__(self):
        self.x = 1
'''
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.74μs -> 3.93μs (71.7% faster)

def test_class_with_class_variables():
    """Test class with class-level variables (should not affect extraction)."""
    source = """\
class MyClass:
    class_var = 10
    
    def __init__(self):
        self.x = 1
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.69μs -> 3.76μs (78.1% faster)

def test_init_with_default_arguments():
    """Test __init__ with default argument values."""
    source = """\
class MyClass:
    def __init__(self, x=10, y=20):
        self.x = x
        self.y = y
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.56μs -> 3.67μs (79.0% faster)

def test_init_with_args_kwargs():
    """Test __init__ with *args and **kwargs."""
    source = """\
class MyClass:
    def __init__(self, *args, **kwargs):
        self.args = args
        self.kwargs = kwargs
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.56μs -> 3.54μs (85.6% faster)

def test_post_init_without_init():
    """Test extracting __post_init__ when __init__ doesn't exist."""
    source = """\
class MyClass:
    def __post_init__(self):
        self.computed = 42
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.29μs -> 3.47μs (81.5% faster)

def test_property_with_multiple_lines():
    """Test property with multi-line implementation."""
    source = """\
class MyClass:
    @property
    def computed_value(self):
        x = self.a + self.b
        y = x * 2
        return y
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.69μs -> 4.90μs (57.1% faster)

def test_deeply_nested_classes():
    """Test extracting from deeply nested classes."""
    source = """\
class Outer:
    class Middle:
        class Inner:
            def __init__(self):
                self.x = 1
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("Inner", source, tree); result = codeflash_output # 9.65μs -> 15.5μs (37.9% slower)

def test_class_with_comments():
    """Test that comments in code are preserved."""
    source = """\
class MyClass:
    def __init__(self):
        # This is a comment
        self.x = 1  # inline comment
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.48μs -> 3.70μs (75.3% faster)

def test_class_with_type_hints():
    """Test extracting __init__ with type hints."""
    source = """\
class MyClass:
    def __init__(self, x: int, y: str) -> None:
        self.x = x
        self.y = y
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.63μs -> 3.69μs (79.9% faster)

def test_property_with_type_hints():
    """Test extracting property with type hints."""
    source = """\
class MyClass:
    @property
    def value(self) -> int:
        return 42
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.36μs -> 4.44μs (65.9% faster)

def test_init_with_multiline_string_literal():
    """Test __init__ containing multiline string literals."""
    source = '''\
class MyClass:
    def __init__(self):
        self.text = """
        This is a multiline
        string literal.
        """
'''
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.50μs -> 3.77μs (72.6% faster)

def test_class_with_staticmethod():
    """Test that staticmethod without __init__ returns None."""
    source = """\
class MyClass:
    @staticmethod
    def static_method():
        return 42
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 5.37μs -> 2.63μs (104% faster)

def test_class_with_classmethod():
    """Test that classmethod without __init__ returns None."""
    source = """\
class MyClass:
    @classmethod
    def from_string(cls, s):
        return cls()
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 5.26μs -> 2.50μs (110% faster)

def test_result_includes_class_declaration():
    """Test that result always includes the class declaration line."""
    source = """\
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.12μs -> 3.41μs (79.7% faster)

def test_multiple_init_like_methods():
    """Test class with both __init__ and __post_init__."""
    source = """\
class MyClass:
    def __init__(self):
        self.x = 1
    
    def __post_init__(self):
        self.y = self.x + 1
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.59μs -> 4.67μs (62.7% faster)

def test_indentation_preserved():
    """Test that indentation in extracted code is preserved."""
    source = """\
class MyClass:
    def __init__(self):
        if True:
            self.x = 1
        else:
            self.x = 0
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.67μs -> 3.87μs (72.5% faster)
    # Check that indentation structure is present
    lines = result.split('\n')

def test_empty_string_class_name():
    """Test with empty string as class name."""
    source = """\
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("", source, tree); result = codeflash_output # 10.8μs -> 14.1μs (23.3% slower)

def test_special_characters_in_class_body():
    """Test __init__ with special characters in strings."""
    source = r'''\
class MyClass:
    def __init__(self):
        self.path = "C:\\Users\\test"
        self.regex = r"\d+"
'''
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.88μs -> 3.95μs (74.4% faster)

def test_lambda_in_init():
    """Test __init__ with lambda expressions."""
    source = """\
class MyClass:
    def __init__(self):
        self.func = lambda x: x * 2
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.43μs -> 3.60μs (78.9% faster)

def test_many_unrelated_methods():
    """Test class with many unrelated methods (only relevant ones extracted)."""
    # Generate a class with 100 unrelated methods
    lines = ["class MyClass:"]
    for i in range(100):
        lines.append(f"    def method_{i}(self):")
        lines.append(f"        return {i}")
    lines.append("    def __init__(self):")
    lines.append("        self.x = 1")
    
    source = "\n".join(lines)
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 36.2μs -> 31.9μs (13.7% faster)

def test_many_properties():
    """Test class with many property methods."""
    # Generate a class with 50 properties
    lines = ["class MyClass:"]
    for i in range(50):
        lines.append(f"    @property")
        lines.append(f"    def prop_{i}(self):")
        lines.append(f"        return {i}")
    
    source = "\n".join(lines)
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 54.4μs -> 50.6μs (7.53% faster)

def test_large_init_method():
    """Test extracting a very large __init__ method with 100+ statements."""
    lines = ["class MyClass:"]
    lines.append("    def __init__(self):")
    for i in range(100):
        lines.append(f"        self.var_{i} = {i}")
    
    source = "\n".join(lines)
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 13.4μs -> 10.1μs (31.8% faster)

def test_many_nested_classes():
    """Test finding correct class among many nested classes."""
    lines = ["class Outer:"]
    for i in range(50):
        lines.append(f"    class Inner_{i}:")
        lines.append(f"        def method_{i}(self):")
        lines.append(f"            pass")
    
    lines.append("    class Target:")
    lines.append("        def __init__(self):")
    lines.append("            self.target = True")
    
    source = "\n".join(lines)
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("Target", source, tree); result = codeflash_output # 92.5μs -> 89.5μs (3.39% faster)

def test_complex_source_with_multiple_features():
    """Test extraction from complex source with imports, functions, and multiple classes."""
    source = """\
import os
from typing import Optional

def standalone_function():
    return 42

class FirstClass:
    def regular_method(self):
        return 1

class SecondClass:
    def __init__(self, value: int):
        self.value = value
        self.computed = value * 2
    
    @property
    def double(self) -> int:
        return self.value * 2
    
    def helper(self):
        pass

class ThirdClass:
    pass

def another_function():
    pass
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("SecondClass", source, tree); result = codeflash_output # 17.5μs -> 8.63μs (103% faster)

def test_preserves_exact_source_text():
    """Test that extracted code preserves exact source formatting."""
    source = """\
class MyClass:
    def __init__(self,   x,   y):
        self.x=x
        self.y   =   y
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 6.89μs -> 3.91μs (76.4% faster)

def test_class_with_very_long_method_name():
    """Test extraction with very long method names."""
    method_name = "a" * 200
    source = f"""\
class MyClass:
    def {method_name}(self):
        pass
    
    def __init__(self):
        self.x = 1
"""
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 7.33μs -> 4.39μs (67.1% faster)

def test_lineno_end_lineno_consistency():
    """Test that line number extraction handles end_lineno correctly for large methods."""
    # Create a method that spans many lines
    lines = ["class MyClass:"]
    lines.append("    def __init__(self):")
    for i in range(50):
        lines.append(f"        self.x_{i} = {i}")
    
    source = "\n".join(lines)
    tree = ast.parse(source)
    codeflash_output = extract_init_stub_from_class("MyClass", source, tree); result = codeflash_output # 9.68μs -> 6.71μs (44.2% faster)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.

To test or edit this optimization locally, run: git merge codeflash/optimize-pr1707-2026-03-03T04.30.55

Suggested change
-    # Use a deque-based BFS to find the first matching ClassDef (preserves ast.walk order)
-    q: deque[ast.AST] = deque([module_tree])
-    while q:
-        candidate = q.popleft()
-        if isinstance(candidate, ast.ClassDef) and candidate.name == class_name:
-            class_node = candidate
-            break
-        q.extend(ast.iter_child_nodes(candidate))
-    if class_node is None:
-        return None
-    lines = module_source.splitlines()
-    relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
-    for item in class_node.body:
-        if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
-            is_relevant = False
-            if item.name in ("__init__", "__post_init__"):
-                is_relevant = True
-            else:
+    for node in module_tree.body:
+        if isinstance(node, ast.ClassDef) and node.name == class_name:
+            class_node = node
+            break
+    else:
+        # Check nested classes only if not found at top level
+        for node in ast.walk(module_tree):
+            if isinstance(node, ast.ClassDef) and node.name == class_name:
+                class_node = node
+                break
+    if class_node is None:
+        return None
+    lines = module_source.splitlines()
+    relevant_nodes: list[ast.FunctionDef | ast.AsyncFunctionDef] = []
+    for item in class_node.body:
+        if isinstance(item, (ast.FunctionDef, ast.AsyncFunctionDef)):
+            is_relevant = item.name in ("__init__", "__post_init__")
+            if not is_relevant:
+                # Single pass through decorators with early exit
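
The class lookup in the suggested change can be sketched in isolation as follows. This is a minimal illustration, not the code from the PR: the helper name `find_class_node` and the sample sources are hypothetical, and only the two-phase strategy (scan `module_tree.body` first, fall back to `ast.walk` for nested classes) is taken from the change above.

```python
import ast
from typing import Optional


def find_class_node(module_tree: ast.Module, class_name: str) -> Optional[ast.ClassDef]:
    """Hypothetical helper: top-level scan first, full walk only as a fallback."""
    # Fast path: only look at the module's direct children.
    for node in module_tree.body:
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return node
    # Slow path: walk the whole tree (breadth-first) to reach nested classes.
    for node in ast.walk(module_tree):
        if isinstance(node, ast.ClassDef) and node.name == class_name:
            return node
    return None


# A name that exists both nested (line 2) and at top level (line 4).
tree = ast.parse(
    "class Outer:\n"
    "    class Target:\n"
    "        x = 1\n"
    "class Target:\n"
    "    x = 2\n"
)
top_level = find_class_node(tree, "Target")

# A name that exists only as a nested class, so the fallback walk is needed.
nested_tree = ast.parse("class Outer:\n    class Inner:\n        pass\n")
nested_only = find_class_node(nested_tree, "Inner")

missing = find_class_node(tree, "Missing")
```

Under this strategy a miss or a nested-only class pays for both passes, which is consistent with the slower timings reported above for the nested-class and nonexistent-class tests, while top-level hits avoid visiting every child node.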


Comment on lines +980 to +995
for node in module_tree.body:
if isinstance(node, ast.Assign):
for target in node.targets:
if isinstance(target, ast.Name) and target.id == name:
value = node.value
if isinstance(value, ast.Call):
func = value.func
if isinstance(func, ast.Name):
return func.id
if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
return func.value.id
elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == name:
ann = node.annotation
if isinstance(ann, ast.Name):
return ann.id
if isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name):

⚡️Codeflash found 21% (0.21x) speedup for resolve_instance_class_name in codeflash/languages/python/context/code_context_extractor.py

⏱️ Runtime : 5.01 milliseconds → 4.13 milliseconds (best of 43 runs)

📝 Explanation and details

The optimization hoists six frequently used AST class references (ast.Assign, ast.Name, etc.) out of the hot loop into local variables, eliminating ~79,000 global attribute lookups per invocation (based on 27,523 iterations × ~3 isinstance checks each). Line profiler shows the absolute time of the two most expensive operations dropping: isinstance(node, ast.Assign) (22.9% of runtime) by ~14% and for target in node.targets (30.4% of runtime) by ~5%. Individual test cases regressed by 1–39% on tiny modules (sub-2 µs), where setup overhead dominates, but large-module tests improved by 17–31%, confirming that the win scales with AST size, which is the target use case for this function.
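
The hoisting technique described above can be sketched roughly as follows. This is an illustration under stated assumptions rather than the PR's actual code: the function name `resolve_instance_class_name_sketch` is hypothetical, the choice of which six classes to bind locally is a guess, and the final `return ann.value.id` branch is inferred from the generated tests because the quoted snippet is cut off at that point.

```python
import ast
from typing import Optional


def resolve_instance_class_name_sketch(name: str, module_tree: ast.Module) -> Optional[str]:
    # Bind the AST classes to locals once, so the loop body does a fast local
    # lookup instead of a global lookup of `ast` plus an attribute lookup.
    Assign, AnnAssign, Name = ast.Assign, ast.AnnAssign, ast.Name
    Call, Attribute, Subscript = ast.Call, ast.Attribute, ast.Subscript

    for node in module_tree.body:
        if isinstance(node, Assign):
            for target in node.targets:
                if isinstance(target, Name) and target.id == name:
                    value = node.value
                    if isinstance(value, Call):
                        func = value.func
                        if isinstance(func, Name):
                            return func.id  # foo = Bar()
                        if isinstance(func, Attribute) and isinstance(func.value, Name):
                            return func.value.id  # foo = pkg.Bar()
        elif isinstance(node, AnnAssign) and isinstance(node.target, Name) and node.target.id == name:
            ann = node.annotation
            if isinstance(ann, Name):
                return ann.id  # foo: Bar
            if isinstance(ann, Subscript) and isinstance(ann.value, Name):
                return ann.value.id  # foo: List[Bar] (assumed from the tests)
    return None


plain = resolve_instance_class_name_sketch("foo", ast.parse("foo = Bar()\n"))
dotted = resolve_instance_class_name_sketch("foo", ast.parse("foo = pkg.Bar()\n"))
generic = resolve_instance_class_name_sketch("foo", ast.parse("foo: List[Bar]\n"))
absent = resolve_instance_class_name_sketch("foo", ast.parse(""))
```

The six local bindings are a fixed cost paid on every call, which would explain why lookups against one-line modules get slower while lookups against modules with hundreds of statements get faster.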

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 255 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 🔘 None Found
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
from __future__ import annotations

import ast  # used to build and parse Python source into an AST for testing
import ast as _ast  # alias to avoid shadowing the test module's ast name

# imports
import pytest  # used for our unit tests
from codeflash.languages.python.context.code_context_extractor import \
    resolve_instance_class_name

def test_simple_assign_returns_constructor_name():
    # Build a simple module: foo = Bar()
    src = "foo = Bar()\n"
    module = ast.parse(src)  # parse source into an AST Module instance
    # Expect the function to detect the Call and return the function name 'Bar'
    codeflash_output = resolve_instance_class_name("foo", module) # 1.26μs -> 1.53μs (17.7% slower)

def test_assign_with_attribute_func_returns_value_name():
    # Build a module where the call is pkg.Bar() -- function should return the Name of the value (pkg)
    src = "foo = pkg.Bar()\n"
    module = ast.parse(src)  # AST where Call.func is Attribute(value=Name('pkg'), attr='Bar')
    # The implementation returns func.value.id which should be 'pkg'
    codeflash_output = resolve_instance_class_name("foo", module) # 1.54μs -> 1.68μs (8.32% slower)

def test_multiple_targets_assignments_match_each_target():
    # Build a chained assignment: a = b = MyClass()
    src = "a = b = MyClass()\n"
    module = ast.parse(src)
    # Both 'a' and 'b' should resolve to the constructor name 'MyClass'
    codeflash_output = resolve_instance_class_name("a", module) # 1.22μs -> 1.33μs (8.25% slower)
    codeflash_output = resolve_instance_class_name("b", module) # 892ns -> 932ns (4.29% slower)

def test_annassign_with_simple_annotation_returns_annotation_name():
    # Annotated assignment: foo: Bar = Bar()
    src = "foo: Bar = Bar()\n"
    module = ast.parse(src)
    # For AnnAssign the function should prefer the annotation 'Bar'
    codeflash_output = resolve_instance_class_name("foo", module) # 1.33μs -> 1.45μs (8.33% slower)

def test_annassign_with_subscript_annotation_returns_subscript_value_name():
    # Annotated with a subscript whose value is a Name: foo: List[Bar]
    src = "from typing import List\nfoo: List[Bar]\n"
    module = ast.parse(src)
    # The annotation is a Subscript with value Name('List'), so expect 'List'
    codeflash_output = resolve_instance_class_name("foo", module) # 1.68μs -> 1.73μs (2.94% slower)

def test_name_not_present_returns_none():
    # Empty module or no matching name should yield None
    module = ast.parse("")  # no body
    codeflash_output = resolve_instance_class_name("nonexistent", module) # 521ns -> 852ns (38.8% slower)

def test_assign_with_non_name_target_is_ignored():
    # Assign to an attribute (obj.attr = MyClass()) - target is Attribute, not Name
    src = "obj.attr = MyClass()\n"
    module = ast.parse(src)
    # The code only checks Name targets, so should return None for name 'attr'
    codeflash_output = resolve_instance_class_name("attr", module) # 1.04μs -> 1.14μs (8.76% slower)
    # Also ensure that asking for 'obj' returns None since the target wasn't a Name('obj')
    codeflash_output = resolve_instance_class_name("obj", module) # 460ns -> 611ns (24.7% slower)

def test_deep_attribute_call_value_does_not_match():
    # Call where func is Attribute(value=Attribute(...)) e.g., pkg.subpkg.Class()
    src = "foo = pkg.subpkg.Class()\n"
    module = ast.parse(src)
    # The implementation only returns func.value.id when func.value is a Name.
    # Here func.value is another Attribute, so it should not match and return None.
    codeflash_output = resolve_instance_class_name("foo", module) # 1.52μs -> 1.70μs (10.6% slower)

def test_assign_value_not_a_call_returns_none():
    # Assignment where the value is not a Call, e.g., foo = 123
    src = "foo = 123\n"
    module = ast.parse(src)
    # Since value isn't ast.Call, it should not identify a constructor name
    codeflash_output = resolve_instance_class_name("foo", module) # 1.16μs -> 1.36μs (14.7% slower)

def test_annassign_subscript_with_attribute_value_returns_none():
    # Annotation using an Attribute as the subscript value: foo: typing.List[Bar]
    src = "import typing\nfoo: typing.List[Bar]\n"
    module = ast.parse(src)
    # The Subscript.value here is an Attribute (typing.List), not a Name, so implementation should return None
    codeflash_output = resolve_instance_class_name("foo", module) # 1.67μs -> 1.76μs (5.10% slower)

def test_first_occurrence_is_preferred_when_multiple_declarations():
    # Multiple declarations for the same name: the function should return the first match encountered.
    src = "foo = A()\nfoo = B()\n"
    module = ast.parse(src)
    # Since the walker inspects nodes in order and returns on first match, expect 'A'
    codeflash_output = resolve_instance_class_name("foo", module) # 1.24μs -> 1.36μs (8.88% slower)

def test_empty_string_name_returns_none():
    # Passing an empty string as the name should not match any variable identifiers.
    # " = 1" is not valid Python and would fail to parse, so parse a valid assignment
    # and query it with an empty name instead.
    module = ast.parse("x = 1\n")
    codeflash_output = resolve_instance_class_name("", module) # 1.01μs -> 1.18μs (14.4% slower)

def test_name_none_type_behaves_like_non_matching():
    # Although the function is annotated to accept str, it doesn't type-check at runtime.
    # Passing None should simply not match any Name.id and therefore return None.
    module = ast.parse("foo = Bar()\n")
    codeflash_output = resolve_instance_class_name(None, module) # 1.13μs -> 1.27μs (11.0% slower)

def test_large_module_many_assignments_resolves_target_at_end():
    # Build a large module with 1000 assignments; the last assignment uses SpecialClass()
    num = 1000  # exercise the function with a large but bounded number of nodes
    lines = [f"v{i} = X()\n" for i in range(num - 1)]
    lines.append("target_var = SpecialClass()\n")  # target of interest at the end
    src = "".join(lines)
    module = ast.parse(src)
    # Ensure that the function still resolves the final variable correctly in a large AST
    codeflash_output = resolve_instance_class_name("target_var", module) # 184μs -> 169μs (8.85% faster)

def test_repeated_calls_on_large_module_are_consistent():
    # Parse a module with many declarations once and call the resolver once per declared name (200 lookups).
    # This verifies deterministic behavior and that the function does not mutate the AST.
    src = "\n".join([f"item{i} = Factory()" for i in range(200)])  # create 200 items to keep runtime reasonable
    module = ast.parse(src)
    # Call resolve_instance_class_name repeatedly and verify consistent results
    for i in range(200):
        name = f"item{i}"
        # Every lookup should return 'Factory'
        codeflash_output = resolve_instance_class_name(name, module) # 3.46ms -> 2.81ms (23.1% faster)

def test_mixed_assign_and_annassign_prefers_assign_call_when_before_annotation():
    # When both an Assign (call) and later an AnnAssign exist for same name, the first encountered should be returned.
    src = "mix = Maker()\nmix: Annot = None\n"
    module = ast.parse(src)
    # The Assign appears before the AnnAssign and should be returned ('Maker').
    codeflash_output = resolve_instance_class_name("mix", module) # 1.24μs -> 1.49μs (16.8% slower)

def test_mixed_annassign_before_assign_prefers_annotation():
    # When an AnnAssign comes before an Assign, the annotation should be used.
    src = "mix: Annot = None\nmix = Maker()\n"
    module = ast.parse(src)
    # The AnnAssign is encountered first and returns its annotation name ('Annot')
    codeflash_output = resolve_instance_class_name("mix", module) # 1.41μs -> 1.54μs (8.43% slower)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
import ast

# imports
import pytest
from codeflash.languages.python.context.code_context_extractor import \
    resolve_instance_class_name

def test_simple_assignment_with_function_call():
    # Test that a simple assignment like `obj = ClassName()` returns the class name
    code = "obj = ClassName()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.30μs -> 1.44μs (9.64% slower)

def test_assignment_with_module_attribute_call():
    # Test that an assignment like `obj = module.ClassName()` returns the module name
    code = "obj = module.ClassName()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.51μs -> 1.64μs (7.91% slower)

def test_annotated_assignment_with_simple_type():
    # Test that an annotated assignment like `obj: MyClass = ...` returns the class name
    code = "obj: MyClass = MyClass()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.28μs -> 1.41μs (9.27% slower)

def test_annotated_assignment_with_generic_type():
    # Test that an annotated assignment with a generic like `obj: List[int] = ...` returns the generic base
    code = "obj: List[int] = []"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.43μs -> 1.45μs (1.38% slower)

def test_variable_not_found():
    # Test that looking for a variable that doesn't exist returns None
    code = "obj = MyClass()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("nonexistent", module_tree); result = codeflash_output # 1.00μs -> 1.20μs (16.6% slower)

def test_empty_module():
    # Test that an empty module returns None
    code = ""
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 561ns -> 772ns (27.3% slower)

def test_multiple_assignments_find_first_match():
    # Test that when multiple assignments exist, the first matching one is returned
    code = """
obj = FirstClass()
obj = SecondClass()
"""
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.35μs -> 1.40μs (3.64% slower)

def test_multiple_targets_in_single_assignment():
    # Test assignment with multiple targets: `a = b = MyClass()`
    code = "a = b = MyClass()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("a", module_tree); result = codeflash_output # 1.22μs -> 1.42μs (14.1% slower)

def test_annotation_without_value():
    # Test annotated variable without assignment: `obj: MyClass`
    code = "obj: MyClass"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.29μs -> 1.43μs (9.78% slower)

def test_assignment_with_non_call_value():
    # Test assignment where value is not a function call: `obj = some_variable`
    code = "obj = some_variable"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.11μs -> 1.28μs (13.3% slower)

def test_assignment_with_literal_value():
    # Test assignment with literal: `obj = 42`
    code = "obj = 42"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.14μs -> 1.27μs (10.2% slower)

def test_call_with_no_attribute_access():
    # Test a call whose callee is itself a call expression (neither Name nor Attribute)
    code = "obj = (get_class())()  # Nested call"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.31μs -> 1.44μs (9.01% slower)

def test_annotation_with_complex_subscript():
    # Test annotation like `obj: Dict[str, List[int]]` returns the base type
    code = "obj: Dict[str, List[int]] = {}"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.49μs -> 1.56μs (4.48% slower)

def test_annotation_with_nested_subscript():
    # Test annotation with nested subscript
    code = "obj: Optional[List[MyClass]] = None"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.33μs -> 1.50μs (11.3% slower)

def test_call_with_arguments():
    # Test function call with arguments: `obj = MyClass(arg1, arg2)`
    code = "obj = MyClass(arg1, arg2)"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.24μs -> 1.38μs (10.1% slower)

def test_call_with_keyword_arguments():
    # Test function call with keyword arguments
    code = "obj = MyClass(key='value', count=42)"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.23μs -> 1.36μs (9.54% slower)

def test_module_path_with_single_attribute():
    # Test module.ClassName() returns module
    code = "obj = pkg.MyClass()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.55μs -> 1.55μs (0.064% slower)

def test_deeply_nested_module_attribute():
    # Test a nested attribute call like `a.b.c()`: func.value is an Attribute rather than a Name, so the resolver's Attribute branch (which requires a Name base) does not match
    code = "obj = a.b.c()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.44μs -> 1.51μs (4.69% slower)

def test_other_statements_mixed_in():
    # Test that other types of statements don't interfere
    code = """
x = 10
def func():
    pass
obj = MyClass()
if True:
    pass
"""
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.87μs -> 1.97μs (5.12% slower)

def test_annotation_that_is_not_name_or_subscript():
    # Test annotation with complex expression
    code = "obj: some_func() = None"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.28μs -> 1.48μs (13.6% slower)

def test_variable_name_with_underscores():
    # Test variable names with underscores: `_obj = MyClass()`
    code = "_obj = MyClass()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("_obj", module_tree); result = codeflash_output # 1.24μs -> 1.31μs (5.33% slower)

def test_variable_name_with_numbers():
    # Test variable names with numbers: `obj2 = MyClass()`
    code = "obj2 = MyClass()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj2", module_tree); result = codeflash_output # 1.18μs -> 1.25μs (5.59% slower)

def test_class_name_with_underscores():
    # Test class names with underscores: `obj = My_Class()`
    code = "obj = My_Class()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.14μs -> 1.25μs (8.79% slower)

def test_case_sensitive_lookup():
    # Test that variable name lookup is case-sensitive
    code = "Obj = MyClass()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 932ns -> 1.19μs (21.8% slower)

def test_case_sensitive_class_name():
    # Test that returned class name preserves case
    code = "obj = myclass()"
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.18μs -> 1.29μs (8.51% slower)

def test_many_statements_before_target():
    # Test performance with many statements before the target variable
    # Create a module with 1000 statements before the target
    code_lines = ["x{} = Value{}()".format(i, i) for i in range(1000)]
    code_lines.append("obj = TargetClass()")
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 200μs -> 157μs (26.8% faster)

def test_many_statements_with_no_match():
    # Test that searching through many statements with no match still returns None
    code_lines = ["x{} = Value{}()".format(i, i) for i in range(1000)]
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("nonexistent", module_tree); result = codeflash_output # 196μs -> 159μs (23.8% faster)

def test_many_multiple_targets_in_assignments():
    # Test performance with assignments having multiple targets
    code_lines = ["a{} = b{} = c{} = Class{}()".format(i, i, i, i) for i in range(500)]
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("a100", module_tree); result = codeflash_output # 38.3μs -> 31.5μs (21.8% faster)

def test_many_annotated_assignments():
    # Test performance with many annotated assignments
    code_lines = ["var{}: Type{} = Type{}()".format(i, i, i) for i in range(1000)]
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("var500", module_tree); result = codeflash_output # 105μs -> 87.5μs (20.7% faster)

def test_mixed_statements_large_module():
    # Test with a large module mixing different statement types
    code_lines = []
    for i in range(500):
        code_lines.append("regular{} = Class{}()".format(i, i))
        code_lines.append("annotated{}: Type{} = Type{}()".format(i, i, i))
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("regular250", module_tree); result = codeflash_output # 111μs -> 85.5μs (31.0% faster)

def test_large_module_find_annotated_far_in():
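    # Test looking up an annotated name in a large module; range(500) only defines annotated0..annotated499, so "annotated750" has no match and the whole body is scanned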
    code_lines = []
    for i in range(500):
        code_lines.append("regular{} = Class{}()".format(i, i))
        code_lines.append("annotated{}: Type{} = Type{}()".format(i, i, i))
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("annotated750", module_tree); result = codeflash_output # 201μs -> 171μs (17.4% faster)

def test_many_module_qualified_calls():
    # Test performance with many module-qualified function calls
    code_lines = ["obj{} = module.Class{}()".format(i, i) for i in range(1000)]
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj999", module_tree); result = codeflash_output # 198μs -> 167μs (18.5% faster)

def test_large_complex_subscript_annotations():
    # Test performance with many complex generic type annotations
    code_lines = ["var{}: Dict[str, List[Optional[Type{}]]] = None".format(i, i) for i in range(500)]
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("var250", module_tree); result = codeflash_output # 58.7μs -> 50.1μs (17.1% faster)

def test_many_irrelevant_statements():
    # Test that function correctly ignores many irrelevant statements
    code_lines = []
    for i in range(500):
        code_lines.append("x{} = {}".format(i, i))  # Non-call assignments
        code_lines.append("y{} = some_var".format(i))
    code_lines.append("obj = MyClass()")
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 192μs -> 167μs (15.0% faster)

def test_all_variables_named_same_in_different_scopes():
    # Test that top-level assignments are found (even if different scopes exist)
    code = """
def func():
    obj = LocalClass()
obj = GlobalClass()
"""
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 1.73μs -> 1.83μs (5.46% slower)

def test_attributes_not_top_level():
    # Test that assignments inside classes/functions are not found
    code = """
class MyClass:
    obj = NestedClass()
"""
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("obj", module_tree); result = codeflash_output # 962ns -> 1.14μs (15.8% slower)

def test_performance_early_termination():
    # Test that the function returns immediately upon finding a match
    code_lines = ["target = TargetClass()"]
    code_lines.extend(["x{} = Class{}()".format(i, i) for i in range(1000)])
    code = "\n".join(code_lines)
    module_tree = ast.parse(code)
    codeflash_output = resolve_instance_class_name("target", module_tree); result = codeflash_output # 1.78μs -> 1.99μs (10.5% slower)
# codeflash_output is used to check that the output of the original code is the same as that of the optimized code.
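
Several of the generated tests above (`pkg.MyClass()`, `a.b.c()`, `(get_class())()`) differ only in the shape of the `ast.Call.func` node. The following is a minimal sketch for inspecting that shape with the standard `ast` module; `call_func_shape` is a hypothetical helper written for illustration and is not part of the PR.

```python
import ast
from typing import Optional, Tuple


def call_func_shape(src: str) -> Tuple[str, Optional[str]]:
    """Return the node type of the callee and, for attribute callees, the type of its base."""
    # Assumes src is a single assignment whose right-hand side is a call.
    call = ast.parse(src).body[0].value
    func = call.func
    base = type(func.value).__name__ if isinstance(func, ast.Attribute) else None
    return type(func).__name__, base


print(call_func_shape("obj = MyClass()"))        # ('Name', None)
print(call_func_shape("obj = pkg.MyClass()"))    # ('Attribute', 'Name')
print(call_func_shape("obj = a.b.c()"))          # ('Attribute', 'Attribute')
print(call_func_shape("obj = (get_class())()"))  # ('Call', None)
```

Only the first two shapes satisfy the `isinstance(func, ast.Name)` and `isinstance(func.value, ast.Name)` checks shown in the suggested change.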

To test or edit this optimization locally, run: git merge codeflash/optimize-pr1707-2026-03-03T04.40.52

Suggested change
-    for node in module_tree.body:
-        if isinstance(node, ast.Assign):
-            for target in node.targets:
-                if isinstance(target, ast.Name) and target.id == name:
-                    value = node.value
-                    if isinstance(value, ast.Call):
-                        func = value.func
-                        if isinstance(func, ast.Name):
-                            return func.id
-                        if isinstance(func, ast.Attribute) and isinstance(func.value, ast.Name):
-                            return func.value.id
-        elif isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name) and node.target.id == name:
-            ann = node.annotation
-            if isinstance(ann, ast.Name):
-                return ann.id
-            if isinstance(ann, ast.Subscript) and isinstance(ann.value, ast.Name):
+    AstAssign = ast.Assign
+    AstAnnAssign = ast.AnnAssign
+    AstName = ast.Name
+    AstCall = ast.Call
+    AstAttribute = ast.Attribute
+    AstSubscript = ast.Subscript
+    for node in module_tree.body:
+        if isinstance(node, AstAssign):
+            # iterate targets only when this is an Assign
+            for target in node.targets:
+                if isinstance(target, AstName) and target.id == name:
+                    value = node.value
+                    if isinstance(value, AstCall):
+                        func = value.func
+                        if isinstance(func, AstName):
+                            return func.id
+                        if isinstance(func, AstAttribute) and isinstance(func.value, AstName):
+                            return func.value.id
+        elif isinstance(node, AstAnnAssign) and isinstance(node.target, AstName) and node.target.id == name:
+            ann = node.annotation
+            if isinstance(ann, AstName):
+                return ann.id
+            if isinstance(ann, AstSubscript) and isinstance(ann.value, AstName):
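
The suggested change binds the `ast` node classes to local names before the loop, so each `isinstance` check avoids a global lookup plus an attribute lookup. Below is a rough sketch of that technique on a reduced version of the logic (only the `Assign` with a plain `Name` callee). Both function names are hypothetical and the code is an illustration, not the PR's implementation.

```python
import ast
import timeit
from typing import Optional


def lookup_with_attribute_access(name: str, tree: ast.Module) -> Optional[str]:
    # Baseline: every isinstance check resolves `ast` and then `ast.<Class>` again.
    for node in tree.body:
        if isinstance(node, ast.Assign):
            for target in node.targets:
                if isinstance(target, ast.Name) and target.id == name:
                    value = node.value
                    if isinstance(value, ast.Call) and isinstance(value.func, ast.Name):
                        return value.func.id
    return None


def lookup_with_local_aliases(name: str, tree: ast.Module) -> Optional[str]:
    # Same logic, but the classes are bound to locals once, before the loop.
    AstAssign = ast.Assign
    AstName = ast.Name
    AstCall = ast.Call
    for node in tree.body:
        if isinstance(node, AstAssign):
            for target in node.targets:
                if isinstance(target, AstName) and target.id == name:
                    value = node.value
                    if isinstance(value, AstCall) and isinstance(value.func, AstName):
                        return value.func.id
    return None


if __name__ == "__main__":
    # Large module with the target at the end, similar to the generated tests.
    source = "\n".join(f"x{i} = Value{i}()" for i in range(1000)) + "\nobj = TargetClass()"
    tree = ast.parse(source)
    assert lookup_with_attribute_access("obj", tree) == lookup_with_local_aliases("obj", tree)
    for fn in (lookup_with_attribute_access, lookup_with_local_aliases):
        elapsed = timeit.timeit(lambda: fn("obj", tree), number=200)
        print(f"{fn.__name__}: {elapsed / 200 * 1e6:.1f} µs per call")
```

The alias assignments are a fixed up-front cost. That matches the timings recorded in the generated tests: the single-statement cases are mostly a few percent to 27% slower, while the 500 to 1000 statement cases are roughly 15% to 31% faster.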

@codeflash-ai
Contributor Author

codeflash-ai Bot commented Mar 3, 2026

⚡️ Codeflash found optimizations for this PR

📄 244,678% (2,446.78x) speedup for is_subagent_mode in codeflash/lsp/helpers.py

⏱️ Runtime : 979 microseconds → 400 nanoseconds (best of 158 runs)

A new Optimization Review has been created.

Labels

⚡️ codeflash: Optimization PR opened by Codeflash AI
🎯 Quality: High: Optimization Quality according to Codeflash
workflow-modified: This PR modifies GitHub Actions workflows

5 participants