Skip to content

Commit 532c745

Browse files
authored
Merge pull request #38 from BrunoV21/improved-chunk-logger
Improved chunk logger
2 parents ef020a2 + 0bcb1b1 commit 532c745

File tree

13 files changed

+1435
-216
lines changed

13 files changed

+1435
-216
lines changed

codetide/agents/tide/agent.py

Lines changed: 48 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from codetide import CodeTide
22
from ...mcp.tools.patch_code import file_exists, open_file, process_patch, remove_file, write_file, parse_patch_blocks
3-
from ...core.defaults import DEFAULT_ENCODING, DEFAULT_STORAGE_PATH
3+
from ...core.defaults import DEFAULT_STORAGE_PATH
44
from ...parsers import SUPPORTED_LANGUAGES
55
from ...autocomplete import AutoComplete
66
from .models import Steps
@@ -13,7 +13,8 @@
1313

1414
try:
1515
from aicore.llm import Llm
16-
from aicore.logger import _logger, SPECIAL_TOKENS
16+
from aicore.logger import _logger
17+
from .streaming.service import custom_logger_fn
1718
except ImportError as e:
1819
raise ImportError(
1920
"The 'codetide.agents' module requires the 'aicore' package. "
@@ -29,18 +30,10 @@
2930
from datetime import date
3031
from pathlib import Path
3132
from ulid import ulid
32-
import aiofiles
3333
import asyncio
3434
import pygit2
3535
import os
3636

37-
async def custom_logger_fn(message :str, session_id :str, filepath :str):
38-
if message not in SPECIAL_TOKENS:
39-
async with aiofiles.open(filepath, 'a', encoding=DEFAULT_ENCODING) as f:
40-
await f.write(message)
41-
42-
await _logger.log_chunk_to_queue(message, session_id)
43-
4437
class AgentTide(BaseModel):
4538
llm :Llm
4639
tide :CodeTide
@@ -60,6 +53,11 @@ class AgentTide(BaseModel):
6053
_has_patch :bool=False
6154
_direct_mode :bool=False
6255

56+
# Number of previous interactions to remember for context identifiers
57+
CONTEXT_WINDOW_SIZE: int = 3
58+
# Rolling window of identifier sets from previous N interactions
59+
_context_identifier_window: Optional[list] = None
60+
6361
model_config = ConfigDict(arbitrary_types_allowed=True)
6462

6563
@model_validator(mode="after")
@@ -134,23 +132,43 @@ async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None):
134132
await self.tide.check_for_updates(serialize=True, include_cached_ids=True)
135133
self._clean_history()
136134

135+
# Initialize the context identifier window if not present
136+
if self._context_identifier_window is None:
137+
self._context_identifier_window = []
138+
137139
codeContext = None
138140
if self._skip_context_retrieval:
139141
...
140142
else:
141143
autocomplete = AutoComplete(self.tide.cached_ids)
142144
if self._direct_mode:
143145
self.contextIdentifiers = None
144-
exact_matches = autocomplete.extract_words_from_text(self.history[-1], max_matches_per_word=1)["all_found_words"]
146+
# Only extract matches from the last message
147+
last_message = self.history[-1] if self.history else ""
148+
exact_matches = autocomplete.extract_words_from_text(last_message, max_matches_per_word=1)["all_found_words"]
145149
self.modifyIdentifiers = self.tide._as_file_paths(exact_matches)
146150
codeIdentifiers = self.modifyIdentifiers
147151
self._direct_mode = False
148-
152+
# Update the context identifier window
153+
self._context_identifier_window.append(set(exact_matches))
154+
if len(self._context_identifier_window) > self.CONTEXT_WINDOW_SIZE:
155+
self._context_identifier_window.pop(0)
149156
else:
150-
matches = autocomplete.extract_words_from_text("\n\n".join(self.history), max_matches_per_word=1)
151-
152-
# --- Begin Unified Identifier Retrieval ---
153-
identifiers_accum = set(matches["all_found_words"]) if codeIdentifiers is None else set(codeIdentifiers + matches["all_found_words"])
157+
# Only extract matches from the last message
158+
last_message = self.history[-1] if self.history else ""
159+
matches = autocomplete.extract_words_from_text(last_message, max_matches_per_word=1)["all_found_words"]
160+
print(f"{matches=}")
161+
# Update the context identifier window
162+
self._context_identifier_window.append(set(matches))
163+
if len(self._context_identifier_window) > self.CONTEXT_WINDOW_SIZE:
164+
self._context_identifier_window.pop(0)
165+
# Combine identifiers from the last N interactions
166+
window_identifiers = set()
167+
for s in self._context_identifier_window:
168+
window_identifiers.update(s)
169+
# If codeIdentifiers is passed, include them as well
170+
identifiers_accum = set(codeIdentifiers) if codeIdentifiers else set()
171+
identifiers_accum.update(window_identifiers)
154172
modify_accum = set()
155173
reasoning_accum = []
156174
repo_tree = None
@@ -166,57 +184,55 @@ async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None):
166184
repo_history = self.history
167185
if previous_reason:
168186
repo_history += [previous_reason]
169-
187+
170188
repo_tree = await self.get_repo_tree_from_user_prompt(self.history, include_modules=bool(smart_search_attempts), expand_paths=expand_paths)
171-
189+
172190
# 2. Single LLM call with unified prompt
173191
# Pass accumulated identifiers for context if this isn't the first iteration
174192
accumulated_context = "\n".join(
175193
sorted((identifiers_accum or set()) | (modify_accum or set()))
176194
) if (identifiers_accum or modify_accum) else ""
177-
195+
178196
unified_response = await self.llm.acomplete(
179197
self.history,
180198
system_prompt=[GET_CODE_IDENTIFIERS_UNIFIED_PROMPT.format(
181-
DATE=TODAY,
199+
DATE=TODAY,
182200
SUPPORTED_LANGUAGES=SUPPORTED_LANGUAGES,
183201
IDENTIFIERS=accumulated_context
184202
)],
185203
prefix_prompt=repo_tree,
186204
stream=False
187205
)
188-
print(f"{unified_response=}")
189206

190207
# Parse the unified response
191208
contextIdentifiers = parse_blocks(unified_response, block_word="Context Identifiers", multiple=False)
192209
modifyIdentifiers = parse_blocks(unified_response, block_word="Modify Identifiers", multiple=False)
193-
expandPaths = parse_blocks(unified_response, block_word="Expand Paths", multiple=False)
194-
210+
expandPaths = parse_blocks(unified_response, block_word="Expand Paths", multiple=False)
211+
195212
# Extract reasoning (everything before the first "*** Begin")
196213
reasoning_parts = unified_response.split("*** Begin")
197214
if reasoning_parts:
198215
reasoning_accum.append(reasoning_parts[0].strip())
199216
previous_reason = reasoning_accum[-1]
200-
217+
201218
# Accumulate identifiers
202219
if contextIdentifiers:
203220
if smart_search_attempts == 0:
204-
### clean wrongly mismtatched idenitifers
205221
identifiers_accum = set()
206222
for ident in contextIdentifiers.splitlines():
207-
if ident := self.get_valid_identifier(autocomplete, ident.strip()):
223+
if ident := self.get_valid_identifier(autocomplete, ident.strip()):
208224
identifiers_accum.add(ident)
209-
225+
210226
if modifyIdentifiers:
211227
for ident in modifyIdentifiers.splitlines():
212228
if ident := self.get_valid_identifier(autocomplete, ident.strip()):
213229
modify_accum.add(ident.strip())
214-
230+
215231
if expandPaths:
216232
expand_paths = [
217233
path for ident in expandPaths if (path := self.get_valid_identifier(autocomplete, ident.strip()))
218234
]
219-
235+
220236
# Check if we have enough identifiers (unified prompt includes this decision)
221237
if "ENOUGH_IDENTIFIERS: TRUE" in unified_response.upper():
222238
done = True
@@ -235,7 +251,7 @@ async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None):
235251
self.modifyIdentifiers = self.tide._as_file_paths(self.modifyIdentifiers)
236252
codeIdentifiers.extend(self.modifyIdentifiers)
237253
# TODO preserve passed identifiers by the user
238-
codeIdentifiers += matches["all_found_words"]
254+
codeIdentifiers += matches
239255

240256
# --- End Unified Identifier Retrieval ---
241257
if codeIdentifiers:
@@ -244,7 +260,8 @@ async def agent_loop(self, codeIdentifiers :Optional[List[str]]=None):
244260

245261
if not codeContext:
246262
codeContext = REPO_TREE_CONTEXT_PROMPT.format(REPO_TREE=self.tide.codebase.get_tree_view())
247-
readmeFile = self.tide.get(["README.md"] + matches["all_found_words"] , as_string_list=True)
263+
# Use matches from the last message for README context
264+
readmeFile = self.tide.get(["README.md"] + (matches if 'matches' in locals() else []), as_string_list=True)
248265
if readmeFile:
249266
codeContext = "\n".join([codeContext, README_CONTEXT_PROMPT.format(README=readmeFile)])
250267

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .background_flusher import BackgroundFlusher
2+
from .chunk_logger import ChunkLogger
3+
4+
__all__ = [
5+
"BackgroundFlusher",
6+
"ChunkLogger"
7+
]
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
from .chunk_logger import ChunkLogger
2+
from typing import Optional
3+
import asyncio
4+
5+
class BackgroundFlusher:
6+
"""
7+
# For very high throughput, you can use the background flusher:
8+
background_flusher = BackgroundFlusher(_optimized_logger, flush_interval=0.05)
9+
await background_flusher.start()
10+
11+
# ... your application code ...
12+
13+
# Clean shutdown
14+
await background_flusher.stop()
15+
await _optimized_logger.shutdown()
16+
"""
17+
def __init__(self, logger: ChunkLogger, flush_interval: float = 0.1):
18+
self.logger = logger
19+
self.flush_interval = flush_interval
20+
self._task: Optional[asyncio.Task] = None
21+
self._running = False
22+
23+
async def start(self):
24+
"""Start background flushing task"""
25+
if self._task and not self._task.done():
26+
return
27+
28+
self._running = True
29+
self._task = asyncio.create_task(self._flush_loop())
30+
self.logger._background_tasks.add(self._task)
31+
32+
async def stop(self):
33+
"""Stop background flushing"""
34+
self._running = False
35+
if self._task:
36+
self._task.cancel()
37+
try:
38+
await self._task
39+
except asyncio.CancelledError:
40+
pass
41+
42+
async def _flush_loop(self):
43+
"""Background flush loop"""
44+
try:
45+
while self._running:
46+
await asyncio.sleep(self.flush_interval)
47+
if not self._running:
48+
break
49+
50+
# Flush all file buffers
51+
flush_tasks = []
52+
for filepath in list(self.logger._file_buffers.keys()):
53+
if self.logger._file_buffers[filepath]:
54+
flush_tasks.append(self.logger._flush_file_buffer(filepath))
55+
56+
if flush_tasks:
57+
await asyncio.gather(*flush_tasks, return_exceptions=True)
58+
except asyncio.CancelledError:
59+
raise
60+
except Exception:
61+
pass # Ignore errors in background task
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
from ....core.defaults import DEFAULT_ENCODING
2+
from aicore.logger import SPECIAL_TOKENS
3+
4+
from typing import List, Dict, AsyncGenerator
5+
from collections import defaultdict, deque
6+
from pathlib import Path
7+
import portalocker
8+
import asyncio
9+
import time
10+
11+
class ChunkLogger:
12+
def __init__(self, buffer_size: int = 1024, flush_interval: float = 0.1):
13+
self.buffer_size = buffer_size
14+
self.flush_interval = flush_interval
15+
self._session_buffers: Dict[str, deque] = defaultdict(deque)
16+
self._session_subscribers: Dict[str, List] = defaultdict(list)
17+
self._file_buffers: Dict[str, List[str]] = defaultdict(list)
18+
self._last_flush_time: Dict[str, float] = defaultdict(float)
19+
self._background_tasks: set = set()
20+
self._shutdown = False
21+
22+
async def log_chunk(self, message: str, session_id: str, filepath: str):
23+
"""Optimized chunk logging with batched file writes and direct streaming"""
24+
if message not in SPECIAL_TOKENS:
25+
# Add to file buffer for batched writing
26+
self._file_buffers[filepath].append(message)
27+
current_time = time.time()
28+
29+
# Check if we should flush based on buffer size or time
30+
should_flush = (
31+
len(self._file_buffers[filepath]) >= self.buffer_size or
32+
current_time - self._last_flush_time[filepath] >= self.flush_interval
33+
)
34+
35+
if should_flush:
36+
await self._flush_file_buffer(filepath)
37+
self._last_flush_time[filepath] = current_time
38+
39+
# Directly notify subscribers without queue overhead
40+
await self._notify_subscribers(session_id, message)
41+
42+
async def _flush_file_buffer(self, filepath: str):
43+
"""Flush buffer to file with file locking"""
44+
if not self._file_buffers[filepath]:
45+
return
46+
47+
messages_to_write = self._file_buffers[filepath].copy()
48+
self._file_buffers[filepath].clear()
49+
50+
# Create directory if it doesn't exist
51+
Path(filepath).parent.mkdir(parents=True, exist_ok=True)
52+
53+
try:
54+
# Use portalocker for safe concurrent file access
55+
with open(filepath, 'a', encoding=DEFAULT_ENCODING) as f:
56+
portalocker.lock(f, portalocker.LOCK_EX)
57+
try:
58+
f.writelines(messages_to_write)
59+
f.flush() # Ensure data is written to disk
60+
finally:
61+
portalocker.unlock(f)
62+
except Exception as e:
63+
# Re-add messages to buffer if write failed
64+
self._file_buffers[filepath].extendleft(reversed(messages_to_write))
65+
raise e
66+
67+
async def _notify_subscribers(self, session_id: str, message: str):
68+
"""Directly notify subscribers without queue overhead"""
69+
if session_id in self._session_subscribers:
70+
# Use a list copy to avoid modification during iteration
71+
subscribers = list(self._session_subscribers[session_id])
72+
for queue in subscribers:
73+
try:
74+
queue.put_nowait(message)
75+
except asyncio.QueueFull:
76+
# Remove full queues (slow consumers)
77+
self._session_subscribers[session_id].remove(queue)
78+
except Exception:
79+
# Remove invalid queues
80+
if queue in self._session_subscribers[session_id]:
81+
self._session_subscribers[session_id].remove(queue)
82+
83+
async def get_session_logs(self, session_id: str) -> AsyncGenerator[str, None]:
84+
"""Get streaming logs for a session without separate distributor task"""
85+
# Create a queue for this subscriber
86+
queue = asyncio.Queue(maxsize=1000) # Prevent memory issues
87+
88+
# Add to subscribers
89+
self._session_subscribers[session_id].append(queue)
90+
91+
try:
92+
while not self._shutdown:
93+
try:
94+
# Use a timeout to allow for cleanup checks
95+
chunk = await asyncio.wait_for(queue.get(), timeout=1.0)
96+
yield chunk
97+
except asyncio.TimeoutError:
98+
# Check if we should continue or if there are no more publishers
99+
continue
100+
except asyncio.CancelledError:
101+
break
102+
finally:
103+
# Cleanup subscriber
104+
if queue in self._session_subscribers[session_id]:
105+
self._session_subscribers[session_id].remove(queue)
106+
107+
# Clean up empty session entries
108+
if not self._session_subscribers[session_id]:
109+
del self._session_subscribers[session_id]
110+
111+
async def ensure_all_flushed(self):
112+
"""Ensure all buffers are flushed - call before shutdown"""
113+
flush_tasks = []
114+
for filepath in list(self._file_buffers.keys()):
115+
if self._file_buffers[filepath]:
116+
flush_tasks.append(self._flush_file_buffer(filepath))
117+
118+
if flush_tasks:
119+
await asyncio.gather(*flush_tasks, return_exceptions=True)
120+
121+
async def shutdown(self):
122+
"""Graceful shutdown"""
123+
self._shutdown = True
124+
await self.ensure_all_flushed()
125+
126+
# Cancel any background tasks
127+
for task in self._background_tasks:
128+
if not task.done():
129+
task.cancel()
130+
131+
if self._background_tasks:
132+
await asyncio.gather(*self._background_tasks, return_exceptions=True)

0 commit comments

Comments
 (0)