Skip to content

⚡️ Speed up method InitDecorator.visit_ClassDef by 15% in PR #1714 (testgen-review)#1771

Closed
codeflash-ai[bot] wants to merge 2 commits into
testgen-reviewfrom
codeflash/optimize-pr1714-2026-03-04T22.46.36
Closed

⚡️ Speed up method InitDecorator.visit_ClassDef by 15% in PR #1714 (testgen-review)#1771
codeflash-ai[bot] wants to merge 2 commits into
testgen-reviewfrom
codeflash/optimize-pr1714-2026-03-04T22.46.36

Conversation

@codeflash-ai

@codeflash-ai codeflash-ai Bot commented Mar 4, 2026

Copy link
Copy Markdown
Contributor

⚡️ This pull request contains optimizations for PR #1714

If you approve this dependent PR, these changes will be merged into the original PR branch testgen-review.

This PR will be automatically closed if the original PR is merged.


📄 15% (0.15x) speedup for InitDecorator.visit_ClassDef in codeflash/verification/instrument_codeflash_capture.py

⏱️ Runtime : 36.6 microseconds 31.7 microseconds (best of 32 runs)

📝 Explanation and details

The optimization defers decorator AST construction until it is actually needed by replacing eager allocation with a None check. In the original code, ast.Call(...) for the decorator was built unconditionally for every target class (245 µs, 12.4% of runtime per profiler line 9), but in ~5% of code paths that decorator is never inserted (e.g. when __init__ already has the decorator or the class is a dataclass). The optimized version sets decorator = None upfront (37 µs) and constructs the decorator inside the if decorator is None: branch only when insertion is confirmed, eliminating wasted allocations. Runtime improved 15% overall (36.6 → 31.7 µs) with no correctness regressions across all test cases.

Correctness verification report:

Test Status
⚙️ Existing Unit Tests 🔘 None Found
🌀 Generated Regression Tests 31 Passed
⏪ Replay Tests 🔘 None Found
🔎 Concolic Coverage Tests 2 Passed
📊 Tests Coverage 100.0%
🌀 Click to see Generated Regression Tests
import ast
from pathlib import Path

import pytest  # used for our unit tests
from codeflash.verification.instrument_codeflash_capture import InitDecorator

# function to test
def _get_class_node(source: str) -> ast.ClassDef:
    """
    Helper: parse Python source and return the first ClassDef node.
    Assumes the source defines at least one class as the top-level first statement.
    """
    module = ast.parse(source)
    # Return the first top-level class definition
    node = module.body[0]
    return node

def _find_init_func(node: ast.ClassDef) -> ast.FunctionDef | None:
    """Return the __init__ FunctionDef from the class body if present, else None."""
    for item in node.body:
        if isinstance(item, ast.FunctionDef) and item.name == "__init__":
            return item
    return None

def _has_decorator_call_named(func_def: ast.FunctionDef, name: str) -> bool:
    """
    Returns True if any decorator in func_def.decorator_list is a Call whose func is a Name equal to `name`.
    This mirrors the detection logic used in the transformer.
    """
    for d in func_def.decorator_list:
        if isinstance(d, ast.Call) and isinstance(d.func, ast.Name) and d.func.id == name:
            return True
    return False

def _get_decorator_keyword(call_node: ast.Call, key: str):
    """Return value node for keyword `key` in ast.Call or None if not present."""
    for kw in call_node.keywords:
        if kw.arg == key:
            return kw.value
    return None

def test_add_decorator_to_existing_init():
    # Basic case: class has an __init__(self) with no decorators -> decorator should be inserted
    src = """
class Target:
    def __init__(self):
        pass
"""
    node = _get_class_node(src)
    # Create a transformer targeting "Target"
    transformer = InitDecorator({"Target"}, "fto", "/tmp/mytmp", Path("tests/root"))
    # Apply the transformation to the ClassDef node
    codeflash_output = transformer.visit_ClassDef(node); new_node = codeflash_output # 4.12μs -> 3.99μs (3.29% faster)

    # Find the __init__ function after transformation
    init_func = _find_init_func(new_node)
    first_dec = init_func.decorator_list[0]

    # Validate function_name keyword value in decorator is "Target.__init__"
    fn_kw = _get_decorator_keyword(first_dec, "function_name")
    # Validate tmp_dir_path and tests_root keywords match the transformer's provided values
    td_kw = _get_decorator_keyword(first_dec, "tmp_dir_path")
    tr_kw = _get_decorator_keyword(first_dec, "tests_root")
    is_fto_kw = _get_decorator_keyword(first_dec, "is_fto")

def test_no_duplicate_if_already_decorated():
    # If __init__ already has a codeflash_capture decorator (simple Name.Call), do not add another
    src = """
class Target:
    @codeflash_capture(function_name='Target.__init__', tmp_dir_path='/tmp', tests_root='t', is_fto=False)
    def __init__(self):
        pass
"""
    node = _get_class_node(src)
    transformer = InitDecorator({"Target"}, "fto", "/tmp", Path("t"))
    codeflash_output = transformer.visit_ClassDef(node); new_node = codeflash_output # 4.44μs -> 2.32μs (91.0% faster)

    init_func = _find_init_func(new_node)
    # The existing decorator should be the codeflash_capture call (so detection worked)
    existing = init_func.decorator_list[0]

def test_skip_non_target_class():
    # If the class name is not in target_classes, the node must be returned unmodified (no decorator insertion)
    src = """
class Other:
    def __init__(self):
        pass
"""
    node = _get_class_node(src)
    # Transformer targets "Target" only, not "Other"
    transformer = InitDecorator({"Target"}, "fto", "/tmp", Path("t"))
    codeflash_output = transformer.visit_ClassDef(node); new_node = codeflash_output # 752ns -> 721ns (4.30% faster)

    # No decorator should have been added
    init_func = _find_init_func(new_node)

def test_add_synthetic_init_when_missing():
    # When class has no __init__ and is not a dataclass, a synthetic __init__ with super().__init__(*args, **kwargs)
    # and the decorator should be inserted.
    src = """
class Target:
    x = 1
"""
    node = _get_class_node(src)
    transformer = InitDecorator({"Target"}, "fto", "/tmp/tmpdir", Path("testsroot"))
    codeflash_output = transformer.visit_ClassDef(node); new_node = codeflash_output # 8.04μs -> 7.99μs (0.751% faster)

    # Synthetic __init__ should be inserted at the top of the class body
    init_func = _find_init_func(new_node)
    call = init_func.body[0].value
    # Check function_name keyword value
    first_dec = init_func.decorator_list[0]
    fn_kw = _get_decorator_keyword(first_dec, "function_name")

def test_skip_dataclass_decorator_no_insertion():
    # If the class has a dataclass decorator (Name or Call or Attribute), transformer must not add synthetic __init__
    src = """
@dataclass
class Target:
    x: int
"""
    node = _get_class_node(src)
    transformer = InitDecorator({"Target"}, "fto", "/tmp", Path("tests"))
    codeflash_output = transformer.visit_ClassDef(node); new_node = codeflash_output # 4.07μs -> 1.54μs (164% faster)

    # No __init__ should have been added (dataclass should be skipped)
    init_func = _find_init_func(new_node)

def test_init_first_arg_not_self_counts_as_missing():
    # If an existing __init__ has a first argument not named "self", the code treats it as missing and inserts a synthetic one.
    src = """
class Target:
    def __init__(not_self):
        pass
"""
    node = _get_class_node(src)
    transformer = InitDecorator({"Target"}, "fto", "/tmp", Path("tests"))
    codeflash_output = transformer.visit_ClassDef(node); new_node = codeflash_output # 7.86μs -> 7.71μs (1.81% faster)
    # New top-of-body must be the synthetic __init__
    top = new_node.body[0]
    # The original should still exist and have its non-self arg
    original = new_node.body[1]

def test_attribute_style_decorator_not_recognized_and_gets_duplicate():
    # If the existing decorator is an Attribute (e.g., mod.codeflash_capture(...)), the transformer's
    # detection will not consider it a Name and will insert another decorator (leading to a duplicate)
    src = """
class Target:
    @mod.codeflash_capture(function_name='Target.__init__')
    def __init__(self):
        pass
"""
    node = _get_class_node(src)
    transformer = InitDecorator({"Target"}, "fto", "/tmp", Path("tests"))
    codeflash_output = transformer.visit_ClassDef(node); new_node = codeflash_output # 4.26μs -> 4.50μs (5.34% slower)

    init_func = _find_init_func(new_node)
    first = init_func.decorator_list[0]
    second = init_func.decorator_list[1]
import ast
from pathlib import Path

# imports
import pytest
from codeflash.verification.instrument_codeflash_capture import InitDecorator

def test_visit_classdef_ignores_non_target_classes():
    """Test that non-target classes are not modified."""
    source = """
class TargetClass:
    def __init__(self):
        pass

class NonTargetClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"TargetClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # Check TargetClass is modified
    target_class = result.body[0]
    
    # Check NonTargetClass is not modified
    non_target_class = result.body[1]

def test_visit_classdef_skips_dataclass():
    """Test that dataclass decorated classes are skipped."""
    source = """
@dataclass
class MyClass:
    field: int
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"MyClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # Extract the class
    class_def = result.body[0]

def test_visit_classdef_with_existing_decorator():
    """Test that existing codeflash_capture decorator is not duplicated."""
    source = """
class MyClass:
    @codeflash_capture()
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"MyClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # Extract the class
    class_def = result.body[0]
    init_method = class_def.body[0]

def test_visit_classdef_with_multiple_target_classes():
    """Test that multiple target classes are all decorated."""
    source = """
class ClassA:
    def __init__(self):
        pass

class ClassB:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"ClassA", "ClassB"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # Check both classes are decorated
    class_a = result.body[0]
    class_b = result.body[1]

def test_visit_classdef_returns_ast_classdef():
    """Test that the method returns an ast.ClassDef node."""
    source = """
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"MyClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    class_def = result.body[0]

def test_visit_classdef_with_empty_target_classes():
    """Test that no classes are modified when target_classes is empty."""
    source = """
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes=set(),
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # Check that __init__ was not decorated
    class_def = result.body[0]
    init_method = class_def.body[0]

def test_visit_classdef_with_empty_string_class_name():
    """Test behavior with empty string in target_classes."""
    source = """
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={""},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # MyClass should not be decorated
    class_def = result.body[0]
    init_method = class_def.body[0]

def test_visit_classdef_with_dataclass_with_call_decorator():
    """Test skipping dataclass when it's a Call decorator."""
    source = """
@dataclass()
class MyClass:
    field: int
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"MyClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)

def test_visit_classdef_with_dataclass_as_attribute():
    """Test skipping dataclass when accessed as attribute."""
    source = """
@module.dataclass
class MyClass:
    field: int
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"MyClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)

def test_visit_classdef_case_sensitive_target_class():
    """Test that class name matching is case-sensitive."""
    source = """
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"myclass"},  # lowercase
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # Should not match because of case sensitivity
    class_def = result.body[0]
    init_method = class_def.body[0]

def test_visit_classdef_inserted_decorator_flag_unchanged_for_non_target():
    """Test that inserted_decorator flag remains False for non-target classes."""
    source = """
class NonTarget:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"OtherClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    initial_flag = decorator.inserted_decorator
    decorator.visit(tree)

def test_visit_classdef_multiple_classes_some_targets():
    """Test with multiple classes where only some are targets."""
    source = """
class ClassA:
    def __init__(self):
        pass

class ClassB:
    pass

class ClassC:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"ClassA", "ClassC"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # ClassA should have decorator on existing __init__
    class_a = result.body[0]
    
    # ClassB should have synthetic __init__ with decorator
    class_b = result.body[1]
    
    # ClassC should have decorator on existing __init__
    class_c = result.body[2]

def test_visit_classdef_preserves_other_decorators_on_class():
    """Test that class-level decorators are preserved."""
    source = """
@decorator_one
@decorator_two
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"MyClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # Class decorators should be unchanged
    class_def = result.body[0]

def test_visit_classdef_with_many_target_classes():
    """Test with large number of target classes."""
    # Create source with 200 classes
    classes = "\n".join([f"class Class{i}:\n    def __init__(self):\n        pass" for i in range(200)])
    source = classes
    
    tree = ast.parse(source)
    target_classes = {f"Class{i}" for i in range(200)}
    
    decorator = InitDecorator(
        target_classes=target_classes,
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # All classes should be decorated
    for i, class_def in enumerate(result.body):
        if isinstance(class_def, ast.ClassDef):
            init_method = class_def.body[0]

def test_visit_classdef_with_unicode_class_names():
    """Test with unicode characters in class names."""
    source = """
class ClassÀ:
    def __init__(self):
        pass

class Class中文:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    decorator = InitDecorator(
        target_classes={"ClassÀ", "Class中文"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result = decorator.visit(tree)
    
    # Both should be decorated
    for class_def in result.body:
        if isinstance(class_def, ast.ClassDef):
            init_method = class_def.body[0]

def test_visit_classdef_repeatedly_on_same_tree():
    """Test calling visit multiple times on same AST tree."""
    source = """
class MyClass:
    def __init__(self):
        pass
"""
    tree = ast.parse(source)
    
    # First decorator instance
    decorator1 = InitDecorator(
        target_classes={"MyClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result1 = decorator1.visit(tree)
    
    # Second decorator instance on new tree
    tree2 = ast.parse(source)
    decorator2 = InitDecorator(
        target_classes={"MyClass"},
        fto_name="test_fto",
        tmp_dir_path="/tmp",
        tests_root=Path("/tests"),
        is_fto=False
    )
    result2 = decorator2.visit(tree2)
from ast import ClassDef
from codeflash.verification.instrument_codeflash_capture import InitDecorator
from pathlib import Path
import pytest

def test_InitDecorator_visit_ClassDef():
    with pytest.raises(AttributeError, match="'ClassDef'\\ object\\ has\\ no\\ attribute\\ 'name'"):
        InitDecorator.visit_ClassDef(InitDecorator({''}, '', '', Path(), is_fto=True), ClassDef())
🔎 Click to see Concolic Coverage Tests

To edit these changes git checkout codeflash/optimize-pr1714-2026-03-04T22.46.36 and push.

Codeflash Static Badge

The optimization defers decorator AST construction until it is actually needed by replacing eager allocation with a `None` check. In the original code, `ast.Call(...)` for the decorator was built unconditionally for every target class (245 µs, 12.4% of runtime per profiler line 9), but in ~5% of code paths that decorator is never inserted (e.g. when `__init__` already has the decorator or the class is a dataclass). The optimized version sets `decorator = None` upfront (37 µs) and constructs the decorator inside the `if decorator is None:` branch only when insertion is confirmed, eliminating wasted allocations. Runtime improved 15% overall (36.6 → 31.7 µs) with no correctness regressions across all test cases.
@codeflash-ai codeflash-ai Bot added ⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash labels Mar 4, 2026
@claude

claude Bot commented Mar 4, 2026

Copy link
Copy Markdown
Contributor

Claude finished @codeflash-ai[bot]'s task in 8m 14s —— View job


PR Review Summary

  • Triage PR scope
  • Lint and typecheck
  • Resolve stale threads (none found)
  • Review diff
  • Duplicate detection
  • Test coverage
  • Post summary
  • Check optimization PRs

Prek Checks

  • ruff format: Fixed a trailing blank line at instrument_codeflash_capture.py:158 — committed and pushed (style: auto-fix ruff formatting).
  • ruff check: Passed.
  • mypy: Pre-existing error at line 202 (list[Call] vs list[expr] for decorator_list) — present in origin/testgen-review before this PR, not introduced here.

Code Review

The optimization is correct. The change defers ast.Call(...) construction until the decorator is actually needed (lazy initialization), reducing wasted allocations in the ~5% case where insertion is skipped (class already decorated or is a dataclass).

One minor observation: the ast.Call(...) construction is now duplicated in two locations:

These two paths are structurally separate (loop vs post-loop), so the duplication is an acceptable trade-off for the measured 15% speedup. No bugs found.

Duplicate Detection

No duplicates detected. The decorator construction logic only exists in InitDecorator.visit_ClassDef.

Test Coverage

98% on codeflash/verification/instrument_codeflash_capture.py — unchanged from base branch.

Missing lines 171–172 (elif isinstance(dec, ast.Attribute): dec_name = dec.attr) — this is the @module.dataclass attribute-style decorator path, untested in the unit suite but covered by the generated regression test test_visit_classdef_with_dataclass_as_attribute. Pre-existing gap.

Optimization PR Check

This is the only open optimization PR. CI is partially complete — 7 runs still in progress. The one completed failure (E2E - JS ESM Async) is unrelated to the changed Python file.


Last updated: 2026-03-04T22:51 UTC
| Branch

@claude

claude Bot commented Mar 5, 2026

Copy link
Copy Markdown
Contributor

Closing optimization PR: CI is failing (js-esm-async-optimization, js-ts-class-optimization).

@claude claude Bot closed this Mar 5, 2026
@claude claude Bot deleted the codeflash/optimize-pr1714-2026-03-04T22.46.36 branch March 5, 2026 08:25
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

⚡️ codeflash Optimization PR opened by Codeflash AI 🎯 Quality: High Optimization Quality according to Codeflash

Projects

None yet

Development

Successfully merging this pull request may close these issues.

0 participants