22from .defaults import BREAKLINE
33from .logs import logger
44
5- from pydantic import BaseModel , Field , computed_field , field_validator
5+ from pydantic import BaseModel , Field , RootModel , computed_field , field_validator
66from typing import Any , Dict , List , Optional , Literal , Union
77from collections import defaultdict
88import json
@@ -269,7 +269,7 @@ def raw(self)->str:
269269class CodeContextStructure (BaseModel ):
270270 imports :Dict [str , ImportStatement ] = Field (default_factory = dict )
271271 variables :Dict [str , VariableDeclaration ] = Field (default_factory = dict )
272- functions :Dict [str , ClassDefinition ] = Field (default_factory = dict )
272+ functions :Dict [str , FunctionDefinition ] = Field (default_factory = dict )
273273 classes :Dict [str , ClassDefinition ] = Field (default_factory = dict )
274274 classes_headers :Dict [str , ClassDefinition ] = Field (default_factory = dict )
275275 class_attributes :Dict [str , ClassAttribute ] = Field (default_factory = dict )
@@ -349,7 +349,7 @@ def add_class_header(self, cls: ClassDefinition):
349349 def add_preloaded (self , preloaded :Dict [str , str ]):
350350 self .preloaded .update (preloaded )
351351
352- def as_list_str (self )-> List [List [str ]]:
352+ def as_list_str (self , slim : Optional [ bool ] = True )-> List [List [str ]]:
353353
354354 partially_filled_classes :Dict [str , PartialClasses ]= {}
355355
@@ -363,13 +363,18 @@ def as_list_str(self)->List[List[str]]:
363363 raw_elements_by_file [entry .file_path ].append (f"\n { entry .raw } " )
364364
365365 for entry in self .functions .values ():
366- raw_elements_by_file [entry .file_path ].append (f"\n { entry .raw } " )
366+ if slim and entry .docstring :
367+ content = entry .docstring
368+ else :
369+ content = entry .raw
370+
371+ raw_elements_by_file [entry .file_path ].append (f"\n { content } " )
367372
368373 for entry in self .classes .values ():
369374 raw_elements_by_file [entry .file_path ].append (f"\n { entry .raw } " )
370375
371376 for entry in self .classes_headers .values ():
372- raw_elements_by_file [entry .file_path ].append (f"\n { self .trim (entry .raw )} " )
377+ raw_elements_by_file [entry .file_path ].append (f"\n { self .trim (entry .raw ) if slim else entry . raw } " )
373378
374379 unique_class_elements_not_in_classes = set (self ._unique_class_elements_ids ) - set (self .classes .keys ()) - set (self .classes_headers .keys ()) - set (self .requested_elements )
375380
@@ -384,16 +389,22 @@ def as_list_str(self)->List[List[str]]:
384389
385390 for class_attribute in self .class_attributes .values ():
386391 if class_attribute .class_id in unique_class_elements_not_in_classes :
387- partially_filled_classes [classObj .unique_id ].attributes .append (class_attribute .raw )
392+ partially_filled_classes [classObj .unique_id ].attributes .append (f" \n { class_attribute .raw } " )
388393
389394 for class_method in self .class_methods .values ():
390395 if class_method .class_id in unique_class_elements_not_in_classes :
391396 if not partially_filled_classes [classObj .unique_id ].methods :
392- partially_filled_classes [classObj .unique_id ].methods .append ("\n ...\n " )
393- partially_filled_classes [classObj .unique_id ].methods .append (class_method .raw )
397+ partially_filled_classes [classObj .unique_id ].methods .append (" ..." )
398+
399+ if slim and class_method .docstring :
400+ content = class_method .docstring
401+ else :
402+ content = class_method .raw
403+
404+ partially_filled_classes [classObj .unique_id ].methods .append (content )
394405
395406 for partial in partially_filled_classes .values ():
396- raw_elements_by_file [partial .filepath ].append (partial .raw )
407+ raw_elements_by_file [partial .filepath ].append (f" \n { partial .raw } " )
397408
398409 for requested_elemtent in self .requested_elements .values ():
399410 if isinstance (requested_elemtent , (ClassAttribute , MethodDefinition )):
@@ -416,7 +427,12 @@ def as_list_str(self)->List[List[str]]:
416427 return wrapped_list
417428
418429 @classmethod
419- def from_list_of_elements (cls , elements : list , retrieved_elements_reference_type :List [Optional [str ]]= None , requested_element_index :List [int ]= [0 ], preloaded_files :Optional [List [Dict [str , str ]]]= None ) -> 'CodeContextStructure' :
430+ def from_list_of_elements (cls ,
431+ elements : list ,
432+ retrieved_elements_reference_type :List [Optional [str ]]= None ,
433+ requested_element_index :List [int ]= [0 ],
434+ preloaded_files :Optional [List [Dict [str , str ]]]= None ) -> 'CodeContextStructure' :
435+
420436 instance = cls ()
421437 # Normalize negative indices to positive
422438 normalized_indices = [
@@ -430,6 +446,11 @@ def from_list_of_elements(cls, elements: list, retrieved_elements_reference_type
430446 if 0 <= idx < len (elements )
431447 ]
432448
449+ if not retrieved_elements_reference_type :
450+ retrieved_elements_reference_type = [
451+ None for _ in elements
452+ ]
453+
433454 for i , (element , reference_type ) in enumerate (zip (elements , retrieved_elements_reference_type )):
434455 if i in requested_element_index :
435456 instance .add_requested_element (element )
@@ -456,7 +477,7 @@ def from_list_of_elements(cls, elements: list, retrieved_elements_reference_type
456477
457478 return instance
458479
459- class CodeBase (BaseModel ):
480+ class CodeBase (RootModel ):
460481 """Root model representing a complete codebase"""
461482 root : List [CodeFileModel ] = Field (default_factory = list )
462483 _cached_elements :Dict [str , Any ] = dict ()
@@ -715,8 +736,6 @@ def _render_class_contents(self, class_def: 'ClassDefinition', prefix: str,
715736 lines .append (f"{ prefix } { current_prefix } { name } " )
716737
717738 def get (self , unique_id :Union [str , List [str ]], degree :int = 1 , as_string :bool = False , as_list_str :bool = False , slim :Optional [bool ]= False , preloaded_files :Optional [Dict [str , str ]]= None )-> Union [CodeContextStructure , str , List [str ]]:
718- # TODO slim mode is still not perfect as in codecontext it should still heck for individual methods / dependencies which it does not do so far
719- # need to refine it further
720739 if not self ._cached_elements :
721740 logger .debug ("Building cached elements for the first time" )
722741 self ._build_cached_elements ()
@@ -737,6 +756,7 @@ def get(self, unique_id :Union[str, List[str]], degree :int=1, as_string :bool=F
737756 logger .debug (f"Current degree level: { degree } , processing { len (references_ids )} references" )
738757
739758 for i , reference in enumerate (references_ids ):
759+ # print(reference)
740760 element = self ._cached_elements .get (reference )
741761 processed = False
742762 if (
@@ -799,7 +819,7 @@ def get(self, unique_id :Union[str, List[str]], degree :int=1, as_string :bool=F
799819 codeContext ._cached_elements = self ._cached_elements
800820
801821 if as_string :
802- context = codeContext .as_list_str ()
822+ context = codeContext .as_list_str (slim )
803823 if context [0 ]:
804824 context .insert (0 , [CONTEXT_INTRUCTION ])
805825 context .insert (- 1 , [TARGET_INSTRUCTION ])
0 commit comments