1515from codeflash .models .models import CodeString , CodeStringsMarkdown
1616
1717if TYPE_CHECKING :
18+ from collections .abc import Callable
19+
1820 from codeflash .discovery .functions_to_optimize import FunctionToOptimize
1921 from codeflash .models .models import CodeOptimizationContext , FunctionSource
2022
@@ -49,6 +51,69 @@ def extract_names_from_targets(target: cst.CSTNode) -> list[str]:
4951 return names
5052
5153
54+ def is_assignment_used (node : cst .CSTNode , definitions : dict [str , UsageInfo ], name_prefix : str = "" ) -> bool :
55+ if isinstance (node , cst .Assign ):
56+ for target in node .targets :
57+ names = extract_names_from_targets (target .target )
58+ for name in names :
59+ lookup = f"{ name_prefix } { name } " if name_prefix else name
60+ if lookup in definitions and definitions [lookup ].used_by_qualified_function :
61+ return True
62+ return False
63+ if isinstance (node , (cst .AnnAssign , cst .AugAssign )):
64+ names = extract_names_from_targets (node .target )
65+ for name in names :
66+ lookup = f"{ name_prefix } { name } " if name_prefix else name
67+ if lookup in definitions and definitions [lookup ].used_by_qualified_function :
68+ return True
69+ return False
70+ return False
71+
72+
73+ def recurse_sections (
74+ node : cst .CSTNode ,
75+ section_names : list [str ],
76+ prune_fn : Callable [[cst .CSTNode ], tuple [cst .CSTNode | None , bool ]],
77+ keep_non_target_children : bool = False ,
78+ ) -> tuple [cst .CSTNode | None , bool ]:
79+ updates : dict [str , list [cst .CSTNode ] | cst .CSTNode ] = {}
80+ found_any_target = False
81+ for section in section_names :
82+ original_content = getattr (node , section , None )
83+ if isinstance (original_content , (list , tuple )):
84+ new_children = []
85+ section_found_target = False
86+ for child in original_content :
87+ filtered , found_target = prune_fn (child )
88+ if filtered :
89+ new_children .append (filtered )
90+ section_found_target |= found_target
91+ if keep_non_target_children :
92+ if section_found_target or new_children :
93+ found_any_target |= section_found_target
94+ updates [section ] = new_children
95+ elif section_found_target :
96+ found_any_target = True
97+ updates [section ] = new_children
98+ elif original_content is not None :
99+ filtered , found_target = prune_fn (original_content )
100+ if keep_non_target_children :
101+ found_any_target |= found_target
102+ if filtered :
103+ updates [section ] = filtered
104+ elif found_target :
105+ found_any_target = True
106+ if filtered :
107+ updates [section ] = filtered
108+ if keep_non_target_children :
109+ if updates :
110+ return node .with_changes (** updates ), found_any_target
111+ return None , False
112+ if not found_any_target :
113+ return None , False
114+ return (node .with_changes (** updates ) if updates else node ), True
115+
116+
52117def collect_top_level_definitions (
53118 node : cst .CSTNode , definitions : Optional [dict [str , UsageInfo ]] = None
54119) -> dict [str , UsageInfo ]:
@@ -423,27 +488,9 @@ def remove_unused_definitions_recursively(
423488 elif isinstance (statement , (cst .Assign , cst .AnnAssign , cst .AugAssign )):
424489 var_used = False
425490
426- # Check if any variable in this assignment is used
427- if isinstance (statement , cst .Assign ):
428- for target in statement .targets :
429- names = extract_names_from_targets (target .target )
430- for name in names :
431- class_var_name = f"{ class_name } .{ name } "
432- if (
433- class_var_name in definitions
434- and definitions [class_var_name ].used_by_qualified_function
435- ):
436- var_used = True
437- method_or_var_used = True
438- break
439- elif isinstance (statement , (cst .AnnAssign , cst .AugAssign )):
440- names = extract_names_from_targets (statement .target )
441- for name in names :
442- class_var_name = f"{ class_name } .{ name } "
443- if class_var_name in definitions and definitions [class_var_name ].used_by_qualified_function :
444- var_used = True
445- method_or_var_used = True
446- break
491+ if is_assignment_used (statement , definitions , name_prefix = f"{ class_name } ." ):
492+ var_used = True
493+ method_or_var_used = True
447494
448495 if var_used or class_has_dependencies :
449496 new_statements .append (statement )
@@ -459,56 +506,19 @@ def remove_unused_definitions_recursively(
459506
460507 return node , method_or_var_used or class_has_dependencies
461508
462- # Handle assignments (Assign and AnnAssign)
463- if isinstance (node , cst .Assign ):
464- for target in node .targets :
465- names = extract_names_from_targets (target .target )
466- for name in names :
467- if name in definitions and definitions [name ].used_by_qualified_function :
468- return node , True
469- return None , False
470-
471- if isinstance (node , (cst .AnnAssign , cst .AugAssign )):
472- names = extract_names_from_targets (node .target )
473- for name in names :
474- if name in definitions and definitions [name ].used_by_qualified_function :
475- return node , True
509+ # Handle assignments (Assign, AnnAssign, AugAssign)
510+ if isinstance (node , (cst .Assign , cst .AnnAssign , cst .AugAssign )):
511+ if is_assignment_used (node , definitions ):
512+ return node , True
476513 return None , False
477514
478515 # For other nodes, recursively process children
479516 section_names = get_section_names (node )
480517 if not section_names :
481518 return node , False
482-
483- updates = {}
484- found_used = False
485-
486- for section in section_names :
487- original_content = getattr (node , section , None )
488- if isinstance (original_content , (list , tuple )):
489- new_children = []
490- section_found_used = False
491-
492- for child in original_content :
493- filtered , used = remove_unused_definitions_recursively (child , definitions )
494- if filtered :
495- new_children .append (filtered )
496- section_found_used |= used
497-
498- if new_children or section_found_used :
499- found_used |= section_found_used
500- updates [section ] = new_children
501- elif original_content is not None :
502- filtered , used = remove_unused_definitions_recursively (original_content , definitions )
503- found_used |= used
504- if filtered :
505- updates [section ] = filtered
506- if not found_used :
507- return None , False
508- if updates :
509- return node .with_changes (** updates ), found_used
510-
511- return node , False
519+ return recurse_sections (
520+ node , section_names , lambda child : remove_unused_definitions_recursively (child , definitions )
521+ )
512522
513523
514524def collect_top_level_defs_with_usages (
0 commit comments