2323 recurse_sections ,
2424 remove_unused_definitions_by_function_names ,
2525)
26- from codeflash .languages .python .static_analysis .code_extractor import (
27- add_needed_imports_from_module ,
28- find_preexisting_objects ,
29- )
26+ from codeflash .languages .python .static_analysis .code_extractor import add_needed_imports_from_module , find_preexisting_objects
3027from codeflash .models .models import (
3128 CodeContextType ,
3229 CodeOptimizationContext ,
@@ -550,6 +547,35 @@ def get_function_sources_from_jedi(
550547 file_path_to_function_source [definition_path ].add (function_source )
551548 function_source_list .append (function_source )
552549
550+ if definition .type == "statement" :
551+ try :
552+ for inferred in name .infer ():
553+ if (
554+ inferred .type == "class"
555+ and inferred .full_name
556+ and inferred .module_path
557+ and is_project_path (inferred .module_path , project_root_path )
558+ and inferred .full_name .startswith (inferred .module_name )
559+ ):
560+ class_fqn = f"{ inferred .full_name } .__init__"
561+ class_qname = get_qualified_name (inferred .module_name , class_fqn )
562+ if len (class_qname .split ("." )) <= 2 :
563+ class_path = inferred .module_path
564+ rel = safe_relative_to (class_path , project_root_path )
565+ if not rel .is_absolute ():
566+ class_path = project_root_path / rel
567+ class_source = FunctionSource (
568+ file_path = class_path ,
569+ qualified_name = class_qname ,
570+ fully_qualified_name = class_fqn ,
571+ only_function_name = "__init__" ,
572+ source_code = inferred .get_line_code (),
573+ )
574+ file_path_to_function_source [class_path ].add (class_source )
575+ function_source_list .append (class_source )
576+ except Exception :
577+ logger .debug (f"Error inferring type for statement { definition .full_name } " )
578+
553579 return file_path_to_function_source , function_source_list
554580
555581
@@ -750,6 +776,16 @@ def extract_class_and_bases(
750776
751777 extract_class_and_bases (name , module_path , module_source , module_tree )
752778
779+ if (module_path , name ) not in extracted_classes :
780+ for ast_node in module_tree .body :
781+ if isinstance (ast_node , ast .Assign ):
782+ for target in ast_node .targets :
783+ if isinstance (target , ast .Name ) and target .id == name :
784+ if isinstance (ast_node .value , ast .Call ) and isinstance (ast_node .value .func , ast .Name ):
785+ class_name = ast_node .value .func .id
786+ if class_name not in existing_classes :
787+ extract_class_and_bases (class_name , module_path , module_source , module_tree )
788+
753789 except Exception :
754790 logger .debug (f"Error extracting class definition for { name } from { module_name } " )
755791 continue
@@ -759,7 +795,7 @@ def extract_class_and_bases(
759795 for cls , name in resolve_classes_from_modules (external_base_classes ):
760796 if name in emitted_class_names :
761797 continue
762- stub = extract_init_stub (cls , name , require_site_packages = False )
798+ stub = extract_class_stub (cls , name , require_site_packages = False )
763799 if stub is not None :
764800 code_strings .append (stub )
765801 emitted_class_names .add (name )
@@ -778,7 +814,7 @@ def extract_class_and_bases(
778814 continue
779815 processed_classes .add (cls )
780816
781- stub = extract_init_stub (cls , class_name )
817+ stub = extract_class_stub (cls , class_name )
782818 if stub is None :
783819 continue
784820
@@ -888,22 +924,24 @@ def resolve_transitive_type_deps(cls: type) -> list[type]:
888924 return deps
889925
890926
891- def extract_init_stub (cls : type , class_name : str , require_site_packages : bool = True ) -> CodeString | None :
892- """Extract a stub containing the class definition with only its __init__ method.
927+ def extract_class_stub (cls : type , class_name : str , require_site_packages : bool = True ) -> CodeString | None :
928+ """Extract the full class source, falling back to an __init__-only stub.
929+
930+ Attempts ``inspect.getsource(cls)`` first so the LLM sees every method and
931+ attribute. Falls back to extracting just ``__init__`` when the full source
932+ is unavailable (C extensions, dynamically generated classes). Classes whose
933+ ``__init__`` is inherited from ``object`` are still included when the full
934+ source can be retrieved.
893935
894936 Args:
895- cls: The class object to extract __init__ from
937+ cls: The class object to extract from
896938 class_name: Name to use for the class in the stub
897939 require_site_packages: If True, only extract from site-packages. If False, include stdlib too.
898940
899941 """
900942 import inspect
901943 import textwrap
902944
903- init_method = getattr (cls , "__init__" , None )
904- if init_method is None or init_method is object .__init__ :
905- return None
906-
907945 try :
908946 class_file = Path (inspect .getfile (cls ))
909947 except (OSError , TypeError ):
@@ -912,17 +950,30 @@ def extract_init_stub(cls: type, class_name: str, require_site_packages: bool =
912950 if require_site_packages and not path_belongs_to_site_packages (class_file ):
913951 return None
914952
953+ parts = class_file .parts
954+ if "site-packages" in parts :
955+ idx = parts .index ("site-packages" )
956+ class_file = Path (* parts [idx + 1 :])
957+
958+ # Try full class source first
959+ try :
960+ class_source = inspect .getsource (cls )
961+ class_source = textwrap .dedent (class_source )
962+ return CodeString (code = class_source , file_path = class_file )
963+ except (OSError , TypeError ):
964+ pass
965+
966+ # Fallback: __init__-only stub
967+ init_method = getattr (cls , "__init__" , None )
968+ if init_method is None or init_method is object .__init__ :
969+ return None
970+
915971 try :
916972 init_source = inspect .getsource (init_method )
917973 init_source = textwrap .dedent (init_source )
918974 except (OSError , TypeError ):
919975 return None
920976
921- parts = class_file .parts
922- if "site-packages" in parts :
923- idx = parts .index ("site-packages" )
924- class_file = Path (* parts [idx + 1 :])
925-
926977 class_source = f"class { class_name } :\n " + textwrap .indent (init_source , " " )
927978 return CodeString (code = class_source , file_path = class_file )
928979
@@ -1080,6 +1131,7 @@ def parse_code_and_prune_cst(
10801131 filtered_node , found_target = prune_cst (
10811132 module ,
10821133 target_functions ,
1134+ defs_with_usages = defs_with_usages ,
10831135 helpers = helpers_of_helper_functions ,
10841136 remove_docstrings = remove_docstrings ,
10851137 include_dunder_methods = True ,
@@ -1219,7 +1271,7 @@ def prune_cst(
12191271 stmt ,
12201272 target_functions ,
12211273 class_prefix ,
1222- defs_with_usages = defs_with_usages ,
1274+ defs_with_usages = None ,
12231275 helpers = helpers ,
12241276 remove_docstrings = remove_docstrings ,
12251277 include_target_in_output = include_target_in_output ,
0 commit comments