|
1 | 1 | from codetide import CodeTide |
2 | 2 | from ...mcp.tools.patch_code import file_exists, open_file, process_patch, remove_file, write_file, parse_patch_blocks |
3 | 3 | from ...core.defaults import DEFAULT_ENCODING, DEFAULT_STORAGE_PATH |
4 | | -from ...search.code_search import SmartCodeSearch |
5 | 4 | from ...parsers import SUPPORTED_LANGUAGES |
6 | 5 | from ...autocomplete import AutoComplete |
7 | 6 | from .models import Steps |
8 | 7 | from .prompts import ( |
9 | | - AGENT_TIDE_SYSTEM_PROMPT, CALMNESS_SYSTEM_PROMPT, CMD_BRAINSTORM_PROMPT, CMD_CODE_REVIEW_PROMPT, CMD_TRIGGER_PLANNING_STEPS, CMD_WRITE_TESTS_PROMPT, GET_CODE_IDENTIFIERS_SYSTEM_PROMPT, README_CONTEXT_PROMPT, REJECT_PATCH_FEEDBACK_TEMPLATE, |
| 8 | + AGENT_TIDE_SYSTEM_PROMPT, CALMNESS_SYSTEM_PROMPT, CMD_BRAINSTORM_PROMPT, CMD_CODE_REVIEW_PROMPT, CMD_TRIGGER_PLANNING_STEPS, CMD_WRITE_TESTS_PROMPT, GET_CODE_IDENTIFIERS_UNIFIED_PROMPT, README_CONTEXT_PROMPT, REJECT_PATCH_FEEDBACK_TEMPLATE, |
10 | 9 | REPO_TREE_CONTEXT_PROMPT, STAGED_DIFFS_TEMPLATE, STEPS_SYSTEM_PROMPT, WRITE_PATCH_SYSTEM_PROMPT |
11 | 10 | ) |
12 | 11 | from .utils import delete_file, parse_blocks, parse_steps_markdown, trim_to_patch_section |
@@ -67,31 +66,19 @@ def pass_custom_logger_fn(self)->Self: |
67 | 66 | self.llm.logger_fn = partial(custom_logger_fn, session_id=self.session_id, filepath=self.patch_path) |
68 | 67 | return self |
69 | 68 |
|
async def get_repo_tree_from_user_prompt(self, history :list, include_modules :bool=False, expand_paths :Optional[List[str]]=None)->str:
    """Render the codebase tree view used as repo context for the agent.

    Args:
        history: chat history entries; expected to already be plain strings
            (dict messages are flattened by ``_clean_history`` upstream).
        include_modules: forwarded to ``get_tree_view`` to include module nodes.
        expand_paths: node paths passed to ``_build_tree_dict`` to control
            which parts of the tree are expanded; ``None`` keeps the default.

    Returns:
        The tree view as a string (type nodes always included).
    """
    history_str = "\n\n".join(history)
    for cmd_prompt in [CMD_TRIGGER_PLANNING_STEPS, CMD_WRITE_TESTS_PROMPT, CMD_BRAINSTORM_PROMPT, CMD_CODE_REVIEW_PROMPT]:
        # str.replace returns a new string (strings are immutable); the
        # previous code discarded the result, so the command prompts were
        # never actually stripped from the history text.
        history_str = history_str.replace(cmd_prompt, "")
    # NOTE(review): history_str is not consumed below — it looks like a
    # leftover from the removed SmartCodeSearch flow; confirm before deleting.

    self.tide.codebase._build_tree_dict(expand_paths)

    return self.tide.codebase.get_tree_view(
        include_modules=include_modules,
        include_types=True
    )
95 | 82 |
|
96 | 83 | def approve(self): |
97 | 84 | self._has_patch = False |
@@ -126,56 +113,122 @@ def trim_messages(messages, tokenizer_fn, max_tokens :Optional[int]=None): |
126 | 113 | while messages and sum(len(tokenizer_fn(str(msg))) for msg in messages) > max_tokens: |
127 | 114 | messages.pop(0) # Remove from the beginning |
128 | 115 |
|
@staticmethod
def get_valid_identifier(autocomplete :AutoComplete, identifier:str)->Optional[str]:
    """Resolve *identifier* against the autocomplete index.

    Returns the identifier unchanged when it validates, otherwise the first
    fuzzy match proposed by the validator, or ``None`` when nothing matches.
    """
    validation = autocomplete.validate_code_identifier(identifier)
    if validation.get("is_valid"):
        return identifier
    candidates = validation.get("matching_identifiers")
    return candidates[0] if candidates else None
| 124 | + |
| 125 | + def _clean_history(self): |
| 126 | + for i in range(len(self.history)): |
| 127 | + message = self.history[i] |
| 128 | + if isinstance(message, dict): |
| 129 | + self.history[i] = message.get("content" ,"") |
| 130 | + |
129 | 131 | async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None): |
130 | 132 | TODAY = date.today() |
131 | | - |
132 | | - # update codetide with the latest changes made by the human and agent |
133 | 133 | await self.tide.check_for_updates(serialize=True, include_cached_ids=True) |
| 134 | + self._clean_history() |
134 | 135 |
|
135 | 136 | codeContext = None |
136 | 137 | if self._skip_context_retrieval: |
137 | 138 | ... |
138 | 139 | else: |
139 | | - if codeIdentifiers is None: |
140 | | - repo_tree = await self.get_repo_tree_from_user_prompt(self.history) |
141 | | - context_response = await self.llm.acomplete( |
| 140 | + autocomplete = AutoComplete(self.tide.cached_ids) |
| 141 | + matches = autocomplete.extract_words_from_text("\n\n".join(self.history)) |
| 142 | + |
| 143 | + # --- Begin Unified Identifier Retrieval --- |
| 144 | + identifiers_accum = set(matches["all_found_words"]) if codeIdentifiers is None else set(codeIdentifiers + matches["all_found_words"]) |
| 145 | + modify_accum = set() |
| 146 | + reasoning_accum = [] |
| 147 | + repo_tree = None |
| 148 | + smart_search_attempts = 0 |
| 149 | + max_smart_search_attempts = 3 |
| 150 | + done = False |
| 151 | + previous_reason = None |
| 152 | + |
| 153 | + while not done: |
| 154 | + expand_paths = ["./"] |
| 155 | + # 1. SmartCodeSearch to filter repo tree |
| 156 | + if repo_tree is None or smart_search_attempts > 0: |
| 157 | + repo_history = self.history |
| 158 | + if previous_reason: |
| 159 | + repo_history += [previous_reason] |
| 160 | + |
| 161 | + repo_tree = await self.get_repo_tree_from_user_prompt(self.history, include_modules=bool(smart_search_attempts), expand_paths=expand_paths) |
| 162 | + |
| 163 | + # 2. Single LLM call with unified prompt |
| 164 | + # Pass accumulated identifiers for context if this isn't the first iteration |
| 165 | + accumulated_context = "\n".join( |
| 166 | + sorted((identifiers_accum or set()) | (modify_accum or set())) |
| 167 | + ) if (identifiers_accum or modify_accum) else "" |
| 168 | + |
| 169 | + unified_response = await self.llm.acomplete( |
142 | 170 | self.history, |
143 | | - system_prompt=[GET_CODE_IDENTIFIERS_SYSTEM_PROMPT.format(DATE=TODAY, SUPPORTED_LANGUAGES=SUPPORTED_LANGUAGES)], # TODO improve this prompt to handle generic scenarios liek what does my porject do and so on |
| 171 | + system_prompt=[GET_CODE_IDENTIFIERS_UNIFIED_PROMPT.format( |
| 172 | + DATE=TODAY, |
| 173 | + SUPPORTED_LANGUAGES=SUPPORTED_LANGUAGES, |
| 174 | + IDENTIFIERS=accumulated_context |
| 175 | + )], |
144 | 176 | prefix_prompt=repo_tree, |
145 | 177 | stream=False |
146 | | - # json_output=True |
147 | 178 | ) |
148 | 179 |
|
149 | | - contextIdentifiers = parse_blocks(context_response, block_word="Context Identifiers", multiple=False) |
150 | | - modifyIdentifiers = parse_blocks(context_response, block_word="Modify Identifiers", multiple=False) |
151 | | - |
152 | | - reasoning = context_response.split("*** Begin") |
153 | | - if not reasoning: |
154 | | - reasoning = [context_response] |
155 | | - self.reasoning = reasoning[0].strip() |
156 | | - |
157 | | - self.contextIdentifiers = contextIdentifiers.splitlines() if isinstance(contextIdentifiers, str) else None |
158 | | - self.modifyIdentifiers = modifyIdentifiers.splitlines() if isinstance(modifyIdentifiers, str) else None |
159 | | - codeIdentifiers = self.contextIdentifiers or [] |
| 180 | + # Parse the unified response |
| 181 | + contextIdentifiers = parse_blocks(unified_response, block_word="Context Identifiers", multiple=False) |
| 182 | + modifyIdentifiers = parse_blocks(unified_response, block_word="Modify Identifiers", multiple=False) |
| 183 | + expandPaths = parse_blocks(unified_response, block_word="Expand Paths", multiple=False) |
160 | 184 |
|
161 | | - if self.modifyIdentifiers: |
162 | | - codeIdentifiers.extend(self.tide._as_file_paths(self.modifyIdentifiers)) |
163 | | - |
| 185 | + # Extract reasoning (everything before the first "*** Begin") |
| 186 | + reasoning_parts = unified_response.split("*** Begin") |
| 187 | + if reasoning_parts: |
| 188 | + reasoning_accum.append(reasoning_parts[0].strip()) |
| 189 | + previous_reason = reasoning_accum[-1] |
| 190 | + |
| 191 | + # Accumulate identifiers |
| 192 | + if contextIdentifiers: |
| 193 | + if smart_search_attempts == 0: |
| 194 | + ### clean wrongly mismatched identifiers |
| 195 | + identifiers_accum = set() |
| 196 | + for ident in contextIdentifiers.splitlines(): |
| 197 | + if ident := self.get_valid_identifier(autocomplete, ident.strip()): |
| 198 | + identifiers_accum.add(ident) |
| 199 | + |
| 200 | + if modifyIdentifiers: |
| 201 | + for ident in modifyIdentifiers.splitlines(): |
| 202 | + if ident := self.get_valid_identifier(autocomplete, ident.strip()): |
| 203 | + modify_accum.add(ident.strip()) |
| 204 | + |
| 205 | + if expandPaths: |
| 206 | + expand_paths = [ |
| 207 | + path for ident in expandPaths if (path := self.get_valid_identifier(autocomplete, ident.strip())) |
| 208 | + ] |
| 209 | + |
| 210 | + # Check if we have enough identifiers (unified prompt includes this decision) |
| 211 | + if "ENOUGH_IDENTIFIERS: TRUE" in unified_response.upper(): |
| 212 | + done = True |
| 213 | + else: |
| 214 | + smart_search_attempts += 1 |
| 215 | + if smart_search_attempts >= max_smart_search_attempts: |
| 216 | + done = True |
| 217 | + |
| 218 | + # Finalize identifiers |
| 219 | + self.reasoning = "\n\n".join(reasoning_accum) |
| 220 | + self.contextIdentifiers = list(identifiers_accum) if identifiers_accum else None |
| 221 | + self.modifyIdentifiers = list(modify_accum) if modify_accum else None |
| 222 | + |
| 223 | + codeIdentifiers = self.contextIdentifiers or [] |
| 224 | + if self.modifyIdentifiers: |
| 225 | + codeIdentifiers.extend(self.tide._as_file_paths(self.modifyIdentifiers)) |
| 226 | + |
| 227 | + # --- End Unified Identifier Retrieval --- |
164 | 228 | if codeIdentifiers: |
165 | | - autocomplete = AutoComplete(self.tide.cached_ids) |
166 | | - # Validate each code identifier |
167 | | - validatedCodeIdentifiers = [] |
168 | | - for codeId in codeIdentifiers: |
169 | | - result = autocomplete.validate_code_identifier(codeId) |
170 | | - if result.get("is_valid"): |
171 | | - validatedCodeIdentifiers.append(codeId) |
172 | | - |
173 | | - elif result.get("matching_identifiers"): |
174 | | - validatedCodeIdentifiers.append(result.get("matching_identifiers")[0]) |
| 229 | + self._last_code_identifers = set(codeIdentifiers) |
| 230 | + codeContext = self.tide.get(codeIdentifiers, as_string=True) |
175 | 231 |
|
176 | | - self._last_code_identifers = set(validatedCodeIdentifiers) |
177 | | - codeContext = self.tide.get(validatedCodeIdentifiers, as_string=True) |
178 | | - |
179 | 232 | if not codeContext: |
180 | 233 | codeContext = REPO_TREE_CONTEXT_PROMPT.format(REPO_TREE=self.tide.codebase.get_tree_view()) |
181 | 234 | readmeFile = self.tide.get("README.md", as_string_list=True) |
@@ -241,12 +294,19 @@ async def get_git_diff_staged_simple(directory: str) -> str: |
241 | 294 |
|
242 | 295 | return stdout.decode() |
243 | 296 |
|
def _has_staged(self)->bool:
    """Return True if the index already holds staged modifications.

    Used by ``_stage`` to avoid re-adding paths when the user has already
    staged changes manually.
    """
    status = self.tide.repo.status()
    # pygit2 status values are bitmask flags; a file can carry several at
    # once (e.g. INDEX_MODIFIED | WT_MODIFIED), so test membership with
    # `&` rather than `==`, which misses combined states.
    # NOTE(review): newly added (INDEX_NEW) or deleted (INDEX_DELETED)
    # entries are not counted as staged here — confirm that is intended.
    result = any(
        file_status & pygit2.GIT_STATUS_INDEX_MODIFIED
        for file_status in status.values()
    )
    _logger.logger.debug(f"_has_staged {result=}")
    return bool(result)
| 302 | + |
244 | 303 | async def _stage(self)->str: |
245 | 304 | index = self.tide.repo.index |
246 | | - for path in self.changed_paths: |
247 | | - index.add(path) |
| 305 | + if not self._has_staged(): |
| 306 | + for path in self.changed_paths: |
| 307 | + index.add(path) |
248 | 308 |
|
249 | | - index.write() |
| 309 | + index.write() |
250 | 310 |
|
251 | 311 | staged_diff = await self.get_git_diff_staged_simple(self.tide.rootpath) |
252 | 312 | staged_diff = staged_diff.strip() |
|
0 commit comments