Skip to content

Commit b5740cf

Browse files
committed
Refactor CodeContextStructure and CodeBase for improved handling of function definitions and slim mode output
1 parent 6255d12 commit b5740cf

File tree

1 file changed

+34
-14
lines changed

1 file changed

+34
-14
lines changed

codetide/core/models.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from .defaults import BREAKLINE
33
from .logs import logger
44

5-
from pydantic import BaseModel, Field, computed_field, field_validator
5+
from pydantic import BaseModel, Field, RootModel, computed_field, field_validator
66
from typing import Any, Dict, List, Optional, Literal, Union
77
from collections import defaultdict
88
import json
@@ -269,7 +269,7 @@ def raw(self)->str:
269269
class CodeContextStructure(BaseModel):
270270
imports :Dict[str, ImportStatement] = Field(default_factory=dict)
271271
variables :Dict[str, VariableDeclaration] = Field(default_factory=dict)
272-
functions :Dict[str, ClassDefinition] = Field(default_factory=dict)
272+
functions :Dict[str, FunctionDefinition] = Field(default_factory=dict)
273273
classes :Dict[str, ClassDefinition] = Field(default_factory=dict)
274274
classes_headers :Dict[str, ClassDefinition] = Field(default_factory=dict)
275275
class_attributes :Dict[str, ClassAttribute] = Field(default_factory=dict)
@@ -349,7 +349,7 @@ def add_class_header(self, cls: ClassDefinition):
349349
def add_preloaded(self, preloaded :Dict[str, str]):
350350
self.preloaded.update(preloaded)
351351

352-
def as_list_str(self)->List[List[str]]:
352+
def as_list_str(self, slim :Optional[bool]=True)->List[List[str]]:
353353

354354
partially_filled_classes :Dict[str, PartialClasses]= {}
355355

@@ -363,13 +363,18 @@ def as_list_str(self)->List[List[str]]:
363363
raw_elements_by_file[entry.file_path].append(f"\n{entry.raw}")
364364

365365
for entry in self.functions.values():
366-
raw_elements_by_file[entry.file_path].append(f"\n{entry.raw}")
366+
if slim and entry.docstring:
367+
content = entry.docstring
368+
else:
369+
content = entry.raw
370+
371+
raw_elements_by_file[entry.file_path].append(f"\n{content}")
367372

368373
for entry in self.classes.values():
369374
raw_elements_by_file[entry.file_path].append(f"\n{entry.raw}")
370375

371376
for entry in self.classes_headers.values():
372-
raw_elements_by_file[entry.file_path].append(f"\n{self.trim(entry.raw)}")
377+
raw_elements_by_file[entry.file_path].append(f"\n{self.trim(entry.raw) if slim else entry.raw}")
373378

374379
unique_class_elements_not_in_classes = set(self._unique_class_elements_ids) - set(self.classes.keys()) - set(self.classes_headers.keys()) - set(self.requested_elements)
375380

@@ -384,16 +389,22 @@ def as_list_str(self)->List[List[str]]:
384389

385390
for class_attribute in self.class_attributes.values():
386391
if class_attribute.class_id in unique_class_elements_not_in_classes:
387-
partially_filled_classes[classObj.unique_id].attributes.append(class_attribute.raw)
392+
partially_filled_classes[classObj.unique_id].attributes.append(f"\n{class_attribute.raw}")
388393

389394
for class_method in self.class_methods.values():
390395
if class_method.class_id in unique_class_elements_not_in_classes:
391396
if not partially_filled_classes[classObj.unique_id].methods:
392-
partially_filled_classes[classObj.unique_id].methods.append("\n ...\n")
393-
partially_filled_classes[classObj.unique_id].methods.append(class_method.raw)
397+
partially_filled_classes[classObj.unique_id].methods.append(" ...")
398+
399+
if slim and class_method.docstring:
400+
content = class_method.docstring
401+
else:
402+
content = class_method.raw
403+
404+
partially_filled_classes[classObj.unique_id].methods.append(content)
394405

395406
for partial in partially_filled_classes.values():
396-
raw_elements_by_file[partial.filepath].append(partial.raw)
407+
raw_elements_by_file[partial.filepath].append(f"\n{partial.raw}")
397408

398409
for requested_elemtent in self.requested_elements.values():
399410
if isinstance(requested_elemtent, (ClassAttribute, MethodDefinition)):
@@ -416,7 +427,12 @@ def as_list_str(self)->List[List[str]]:
416427
return wrapped_list
417428

418429
@classmethod
419-
def from_list_of_elements(cls, elements: list, retrieved_elements_reference_type :List[Optional[str]]=None, requested_element_index :List[int]=[0], preloaded_files :Optional[List[Dict[str, str]]]=None) -> 'CodeContextStructure':
430+
def from_list_of_elements(cls,
431+
elements: list,
432+
retrieved_elements_reference_type :List[Optional[str]]=None,
433+
requested_element_index :List[int]=[0],
434+
preloaded_files :Optional[List[Dict[str, str]]]=None) -> 'CodeContextStructure':
435+
420436
instance = cls()
421437
# Normalize negative indices to positive
422438
normalized_indices = [
@@ -430,6 +446,11 @@ def from_list_of_elements(cls, elements: list, retrieved_elements_reference_type
430446
if 0 <= idx < len(elements)
431447
]
432448

449+
if not retrieved_elements_reference_type:
450+
retrieved_elements_reference_type = [
451+
None for _ in elements
452+
]
453+
433454
for i, (element, reference_type) in enumerate(zip(elements, retrieved_elements_reference_type)):
434455
if i in requested_element_index:
435456
instance.add_requested_element(element)
@@ -456,7 +477,7 @@ def from_list_of_elements(cls, elements: list, retrieved_elements_reference_type
456477

457478
return instance
458479

459-
class CodeBase(BaseModel):
480+
class CodeBase(RootModel):
460481
"""Root model representing a complete codebase"""
461482
root: List[CodeFileModel] = Field(default_factory=list)
462483
_cached_elements :Dict[str, Any] = dict()
@@ -715,8 +736,6 @@ def _render_class_contents(self, class_def: 'ClassDefinition', prefix: str,
715736
lines.append(f"{prefix}{current_prefix}{name}")
716737

717738
def get(self, unique_id :Union[str, List[str]], degree :int=1, as_string :bool=False, as_list_str :bool=False, slim :Optional[bool]=False, preloaded_files :Optional[Dict[str, str]]=None)->Union[CodeContextStructure, str, List[str]]:
718-
# TODO slim mode is still not perfect as in codecontext it should still heck for individual methods / dependencies which it does not do so far
719-
# need to refine it further
720739
if not self._cached_elements:
721740
logger.debug("Building cached elements for the first time")
722741
self._build_cached_elements()
@@ -737,6 +756,7 @@ def get(self, unique_id :Union[str, List[str]], degree :int=1, as_string :bool=F
737756
logger.debug(f"Current degree level: {degree}, processing {len(references_ids)} references")
738757

739758
for i, reference in enumerate(references_ids):
759+
# print(reference)
740760
element = self._cached_elements.get(reference)
741761
processed = False
742762
if (
@@ -799,7 +819,7 @@ def get(self, unique_id :Union[str, List[str]], degree :int=1, as_string :bool=F
799819
codeContext._cached_elements = self._cached_elements
800820

801821
if as_string:
802-
context = codeContext.as_list_str()
822+
context = codeContext.as_list_str(slim)
803823
if context[0]:
804824
context.insert(0, [CONTEXT_INTRUCTION])
805825
context.insert(-1, [TARGET_INSTRUCTION])

0 commit comments

Comments
 (0)