Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 32 additions & 6 deletions codeflash/context/code_context_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,11 @@
from codeflash.cli_cmds.console import logger
from codeflash.code_utils.code_extractor import add_needed_imports_from_module, find_preexisting_objects
from codeflash.code_utils.code_utils import encoded_tokens_len, get_qualified_name, path_belongs_to_site_packages
from codeflash.context.unused_definition_remover import remove_unused_definitions_by_function_names
from codeflash.context.unused_definition_remover import (
collect_top_level_defs_with_usages,
extract_names_from_targets,
remove_unused_definitions_by_function_names,
)
from codeflash.discovery.functions_to_optimize import FunctionToOptimize # noqa: TC001
from codeflash.models.models import (
CodeContextType,
Expand All @@ -29,6 +33,8 @@
from jedi.api.classes import Name
from libcst import CSTNode

from codeflash.context.unused_definition_remover import UsageInfo


def get_code_optimization_context(
function_to_optimize: FunctionToOptimize,
Expand Down Expand Up @@ -498,8 +504,10 @@ def parse_code_and_prune_cst(
) -> str:
"""Create a read-only version of the code by parsing and filtering the code to keep only class contextual information, and other module scoped variables."""
module = cst.parse_module(code)
defs_with_usages = collect_top_level_defs_with_usages(module, target_functions | helpers_of_helper_functions)

if code_context_type == CodeContextType.READ_WRITABLE:
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions)
filtered_node, found_target = prune_cst_for_read_writable_code(module, target_functions, defs_with_usages)
elif code_context_type == CodeContextType.READ_ONLY:
filtered_node, found_target = prune_cst_for_read_only_code(
module, target_functions, helpers_of_helper_functions, remove_docstrings=remove_docstrings
Expand All @@ -524,7 +532,7 @@ def parse_code_and_prune_cst(


def prune_cst_for_read_writable_code( # noqa: PLR0911
node: cst.CSTNode, target_functions: set[str], prefix: str = ""
node: cst.CSTNode, target_functions: set[str], defs_with_usages: dict[str, UsageInfo], prefix: str = ""
) -> tuple[cst.CSTNode | None, bool]:
"""Recursively filter the node and its children to build the read-writable codeblock. This contains nodes that lead to target functions.

Expand Down Expand Up @@ -569,6 +577,21 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911

return node.with_changes(body=cst.IndentedBlock(body=new_body)), found_target

if isinstance(node, cst.Assign):
for target in node.targets:
names = extract_names_from_targets(target.target)
for name in names:
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
return node, True
return None, False

if isinstance(node, (cst.AnnAssign, cst.AugAssign)):
names = extract_names_from_targets(node.target)
for name in names:
if name in defs_with_usages and defs_with_usages[name].used_by_qualified_function:
return node, True
return None, False

# For other nodes, we preserve them only if they contain target functions in their children.
section_names = get_section_names(node)
if not section_names:
Expand All @@ -583,7 +606,9 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
new_children = []
section_found_target = False
for child in original_content:
filtered, found_target = prune_cst_for_read_writable_code(child, target_functions, prefix)
filtered, found_target = prune_cst_for_read_writable_code(
child, target_functions, defs_with_usages, prefix
)
if filtered:
new_children.append(filtered)
section_found_target |= found_target
Expand All @@ -592,15 +617,16 @@ def prune_cst_for_read_writable_code( # noqa: PLR0911
found_any_target = True
updates[section] = new_children
elif original_content is not None:
filtered, found_target = prune_cst_for_read_writable_code(original_content, target_functions, prefix)
filtered, found_target = prune_cst_for_read_writable_code(
original_content, target_functions, defs_with_usages, prefix
)
if found_target:
found_any_target = True
if filtered:
updates[section] = filtered

if not found_any_target:
return None, False

return (node.with_changes(**updates) if updates else node), True


Expand Down
56 changes: 39 additions & 17 deletions codeflash/context/unused_definition_remover.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass, field
from itertools import chain
from pathlib import Path
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

import libcst as cst

Expand Down Expand Up @@ -122,6 +122,8 @@ def get_section_names(node: cst.CSTNode) -> list[str]:
class DependencyCollector(cst.CSTVisitor):
"""Collects dependencies between definitions using the visitor pattern with depth tracking."""

METADATA_DEPENDENCIES = (cst.metadata.ParentNodeProvider,)

def __init__(self, definitions: dict[str, UsageInfo]) -> None:
super().__init__()
self.definitions = definitions
Expand Down Expand Up @@ -259,8 +261,12 @@ def visit_Name(self, node: cst.Name) -> None:
if self.processing_variable and name in self.current_variable_names:
return

# Check if name is a top-level definition we're tracking
if name in self.definitions and name != self.current_top_level_name:
# skip if we are refrencing a class attribute and not a top-level definition
if self.class_depth > 0:
parent = self.get_metadata(cst.metadata.ParentNodeProvider, node)
if parent is not None and isinstance(parent, cst.Attribute):
return
self.definitions[self.current_top_level_name].dependencies.add(name)


Expand Down Expand Up @@ -293,13 +299,19 @@ def _expand_qualified_functions(self) -> set[str]:

def mark_used_definitions(self) -> None:
"""Find all qualified functions and mark them and their dependencies as used."""
# First identify all specified functions (including expanded ones)
functions_to_mark = [name for name in self.expanded_qualified_functions if name in self.definitions]
# Avoid list comprehension for set intersection
expanded_names = self.expanded_qualified_functions
defs = self.definitions
functions_to_mark = (
expanded_names & defs.keys()
if isinstance(expanded_names, set)
else [name for name in expanded_names if name in defs]
)

# For each specified function, mark it and all its dependencies as used
for func_name in functions_to_mark:
self.definitions[func_name].used_by_qualified_function = True
for dep in self.definitions[func_name].dependencies:
defs[func_name].used_by_qualified_function = True
for dep in defs[func_name].dependencies:
self.mark_as_used_recursively(dep)

def mark_as_used_recursively(self, name: str) -> None:
Expand Down Expand Up @@ -457,6 +469,25 @@ def remove_unused_definitions_recursively( # noqa: PLR0911
return node, False


def collect_top_level_defs_with_usages(
code: Union[str, cst.Module], qualified_function_names: set[str]
) -> dict[str, UsageInfo]:
"""Collect all top level definitions (classes, variables or functions) and their usages."""
module = code if isinstance(code, cst.Module) else cst.parse_module(code)
# Collect all definitions (top level classes, variables or function)
definitions = collect_top_level_definitions(module)

# Collect dependencies between definitions using the visitor pattern
wrapper = cst.MetadataWrapper(module)
dependency_collector = DependencyCollector(definitions)
wrapper.visit(dependency_collector)

# Mark definitions used by specified functions, and their dependencies recursively
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
usage_marker.mark_used_definitions()
return definitions


def remove_unused_definitions_by_function_names(code: str, qualified_function_names: set[str]) -> str:
"""Analyze a file and remove top level definitions not used by specified functions.

Expand All @@ -476,19 +507,10 @@ def remove_unused_definitions_by_function_names(code: str, qualified_function_na
return code

try:
# Collect all definitions (top level classes, variables or function)
definitions = collect_top_level_definitions(module)

# Collect dependencies between definitions using the visitor pattern
dependency_collector = DependencyCollector(definitions)
module.visit(dependency_collector)

# Mark definitions used by specified functions, and their dependencies recursively
usage_marker = QualifiedFunctionUsageMarker(definitions, qualified_function_names)
usage_marker.mark_used_definitions()
defs_with_usages = collect_top_level_defs_with_usages(module, qualified_function_names)

# Apply the recursive removal transformation
modified_module, _ = remove_unused_definitions_recursively(module, definitions)
modified_module, _ = remove_unused_definitions_recursively(module, defs_with_usages)

return modified_module.code if modified_module else "" # noqa: TRY300
except Exception as e:
Expand Down
Loading
Loading