Skip to content

Commit 1277a7a

Browse files
fix: resolve critical architecture issues identified in code reviews
- Fix checkpoint pruning logic reversal (use newest-last semantics) - Add CHECKPOINTS_PRUNED event type to replace ERROR for normal operations - Fix ThreadPoolExecutor timeout bypass with explicit executor lifecycle - Unify AsyncSafeState to use single thread lock across sync/async contexts - Fix agent cleanup to target actual live clients (llm_instance, openai_client) - Sync memory fallback logic across async/structured STM entry points - Move contextvars import to module level for better performance Addresses critical concurrency, security, and data integrity issues. Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent c13b9a7 commit 1277a7a

File tree

6 files changed

+95
-66
lines changed

6 files changed

+95
-66
lines changed

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4501,9 +4501,33 @@ def close(self) -> None:
45014501
except Exception as e:
45024502
logger.warning(f"Memory cleanup failed: {e}")
45034503

4504-
# LLM client cleanup
4504+
# LLM client cleanup - target actual live clients, not model strings
45054505
try:
4506-
if hasattr(self, 'llm') and self.llm:
4506+
# Primary cleanup targets - actual live clients
4507+
if hasattr(self, 'llm_instance') and self.llm_instance:
4508+
if hasattr(self.llm_instance, 'aclose'):
4509+
# Try async close first
4510+
try:
4511+
import asyncio
4512+
if asyncio.iscoroutinefunction(self.llm_instance.aclose):
4513+
# We're in sync context, so use asyncio.run() for the cleanup
4514+
asyncio.run(self.llm_instance.aclose())
4515+
else:
4516+
self.llm_instance.aclose()
4517+
except Exception:
4518+
# Fall back to sync close if async fails
4519+
if hasattr(self.llm_instance, 'close'):
4520+
self.llm_instance.close()
4521+
elif hasattr(self.llm_instance, 'close'):
4522+
self.llm_instance.close()
4523+
4524+
# Check for OpenAI client (common pattern in agents)
4525+
if hasattr(self, '_Agent__openai_client') and self._Agent__openai_client:
4526+
if hasattr(self._Agent__openai_client, 'close'):
4527+
self._Agent__openai_client.close()
4528+
4529+
# Legacy fallback - check self.llm._client (but less likely to work)
4530+
if hasattr(self, 'llm') and self.llm and not isinstance(self.llm, str):
45074531
llm_client = getattr(self.llm, '_client', None)
45084532
if llm_client and hasattr(llm_client, 'close'):
45094533
llm_client.close()

src/praisonai-agents/praisonaiagents/agent/async_safety.py

Lines changed: 10 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -35,30 +35,8 @@ class DualLock:
3535
"""
3636

3737
def __init__(self):
38-
self._thread_lock = threading.Lock()
39-
self._async_lock: Optional[asyncio.Lock] = None
40-
self._loop_id: Optional[int] = None
41-
42-
def _get_async_lock(self) -> asyncio.Lock:
43-
"""Get or create asyncio.Lock for current event loop."""
44-
try:
45-
current_loop = asyncio.get_running_loop()
46-
current_loop_id = id(current_loop)
47-
48-
# Atomic check and create: use thread lock to protect async lock creation
49-
with self._thread_lock:
50-
# Create new lock if loop changed or first time
51-
if self._loop_id != current_loop_id:
52-
self._async_lock = asyncio.Lock()
53-
self._loop_id = current_loop_id
54-
55-
return self._async_lock
56-
except RuntimeError:
57-
# No event loop running, fall back to thread lock in a new loop
58-
with self._thread_lock:
59-
if self._async_lock is None:
60-
self._async_lock = asyncio.Lock()
61-
return self._async_lock
38+
"""Initialize with unified thread-safe locking."""
39+
self._thread_lock = threading.Lock() # Single canonical lock for all contexts
6240

6341
@contextmanager
6442
def sync(self):
@@ -68,10 +46,13 @@ def sync(self):
6846

6947
@asynccontextmanager
7048
async def async_lock(self):
71-
"""Acquire lock in asynchronous context using asyncio.Lock."""
72-
async_lock = self._get_async_lock()
73-
async with async_lock:
49+
"""Acquire lock in asynchronous context using threading.Lock via asyncio.to_thread()."""
50+
# Use asyncio.to_thread to acquire the thread lock without blocking the event loop
51+
await asyncio.to_thread(self._thread_lock.acquire)
52+
try:
7453
yield
54+
finally:
55+
self._thread_lock.release()
7556

7657
def is_async_context(self) -> bool:
7758
"""Check if we're currently in an async context."""
@@ -133,14 +114,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
133114

134115
async def __aenter__(self):
135116
"""Support for asynchronous context manager protocol."""
136-
async_lock = self._lock._get_async_lock()
137-
await async_lock.acquire()
117+
await asyncio.to_thread(self._lock._thread_lock.acquire)
138118
return self.value
139119

140120
async def __aexit__(self, exc_type, exc_val, exc_tb):
141121
"""Support for asynchronous context manager protocol."""
142-
async_lock = self._lock._get_async_lock()
143-
async_lock.release()
122+
self._lock._thread_lock.release()
144123
return None
145124

146125
def get(self) -> Any:

src/praisonai-agents/praisonaiagents/agent/tool_execution.py

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
import asyncio
1414
import inspect
15+
import contextvars
1516
import concurrent.futures
1617
from typing import List, Optional, Any, Dict, Union, TYPE_CHECKING
1718

@@ -194,20 +195,31 @@ def _execute_tool_with_context(self, function_name, arguments, state, tool_call_
194195
tool_timeout = getattr(self, '_tool_timeout', None)
195196
if tool_timeout and tool_timeout > 0:
196197
# Use copy_context to preserve injection context in executor thread
197-
import contextvars
198198
ctx = contextvars.copy_context()
199199

200200
def execute_with_context():
201201
with with_injection_context(state):
202202
return self._execute_tool_impl(function_name, arguments)
203203

204-
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
204+
# Use explicit executor lifecycle to actually bound execution time
205+
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
206+
try:
205207
future = executor.submit(ctx.run, execute_with_context)
206208
try:
207209
result = future.result(timeout=tool_timeout)
208210
except concurrent.futures.TimeoutError:
211+
# Cancel and shutdown immediately to avoid blocking
212+
future.cancel()
213+
executor.shutdown(wait=False, cancel_futures=True)
209214
logging.warning(f"Tool {function_name} timed out after {tool_timeout}s")
210215
result = {"error": f"Tool timed out after {tool_timeout}s", "timeout": True}
216+
else:
217+
# Normal completion - shutdown gracefully
218+
executor.shutdown(wait=False)
219+
finally:
220+
# Ensure executor is always cleaned up
221+
if not executor._shutdown:
222+
executor.shutdown(wait=False)
211223
else:
212224
with with_injection_context(state):
213225
result = self._execute_tool_impl(function_name, arguments)

src/praisonai-agents/praisonaiagents/checkpoints/service.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -487,15 +487,15 @@ async def _prune_checkpoints(self):
487487

488488
# Calculate how many to remove
489489
num_to_remove = len(self._checkpoints) - self.config.max_checkpoints
490-
checkpoints_to_remove = self._checkpoints[-num_to_remove:] # Remove oldest ones
491490

492-
# Keep only the most recent checkpoints in memory
493-
self._checkpoints = self._checkpoints[:self.config.max_checkpoints]
491+
# Keep only the most recent checkpoints in memory (newest-last semantics)
492+
# Since save() appends (newest last), keep the last N entries
493+
self._checkpoints = self._checkpoints[-self.config.max_checkpoints:]
494494

495495
logger.info(f"Pruned {num_to_remove} old checkpoints to stay under limit of {self.config.max_checkpoints}")
496496

497497
# Emit pruning event for any cleanup hooks
498-
self._emit(CheckpointEvent.ERROR, {"action": "pruned", "removed_count": num_to_remove})
498+
self._emit(CheckpointEvent.CHECKPOINTS_PRUNED, {"action": "pruned", "removed_count": num_to_remove})
499499

500500
async def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
501501
"""Get a specific checkpoint by ID."""

src/praisonai-agents/praisonaiagents/checkpoints/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ class CheckpointEvent(str, Enum):
2626
INITIALIZED = "initialized"
2727
CHECKPOINT_CREATED = "checkpoint_created"
2828
CHECKPOINT_RESTORED = "checkpoint_restored"
29+
CHECKPOINTS_PRUNED = "checkpoints_pruned"
2930
ERROR = "error"
3031

3132

src/praisonai-agents/praisonaiagents/memory/core.py

Lines changed: 40 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -120,38 +120,39 @@ def store_short_term_structured(self, content: str, metadata: Optional[Dict] = N
120120
clean_metadata = self._sanitize_metadata(metadata)
121121

122122
# Protocol-driven storage: Try primary adapter first
123+
memory_id = ""
123124
primary_error = None
124-
memory_id = None
125-
126125
try:
127126
if hasattr(self, 'memory_adapter') and self.memory_adapter:
128127
memory_id = self.memory_adapter.store_short_term(content, metadata=clean_metadata, **kwargs)
129128
self._log_verbose(f"Stored in {self.provider} STM via adapter: {content[:100]}...")
130-
131-
# Auto-promote to long-term memory if quality is high
132-
if auto_promote and quality_score >= 7.5:
133-
try:
134-
self.store_long_term(content, clean_metadata, quality_score, user_id, **kwargs)
135-
self._log_verbose(f"Auto-promoted STM content to LTM (score: {quality_score:.2f})")
136-
except Exception as e:
137-
# Auto-promotion failure doesn't affect the primary storage result
138-
logging.warning(f"Failed to auto-promote to LTM: {e}")
139-
140-
# Emit memory event for successful storage
141-
self._emit_memory_event("store", "short_term", content, clean_metadata)
142-
143-
return MemoryResult.success_result(
144-
memory_id=memory_id,
145-
adapter_used=self.provider,
146-
context={
147-
"quality_score": quality_score,
148-
"auto_promoted": auto_promote and quality_score >= 7.5
149-
}
150-
)
151129
except Exception as e:
152130
primary_error = str(e)
153131
self._log_verbose(f"Failed to store in {self.provider} STM: {e}", logging.WARNING)
154132

133+
# Only proceed with success if we got a valid memory_id
134+
if memory_id:
135+
# Auto-promote to long-term memory if quality is high
136+
if auto_promote and quality_score >= 7.5:
137+
try:
138+
self.store_long_term(content, clean_metadata, quality_score, user_id, **kwargs)
139+
self._log_verbose(f"Auto-promoted STM content to LTM (score: {quality_score:.2f})")
140+
except Exception as e:
141+
# Auto-promotion failure doesn't affect the primary storage result
142+
logging.warning(f"Failed to auto-promote to LTM: {e}")
143+
144+
# Emit memory event for successful storage
145+
self._emit_memory_event("store", "short_term", content, clean_metadata)
146+
147+
return MemoryResult.success_result(
148+
memory_id=memory_id,
149+
adapter_used=self.provider,
150+
context={
151+
"quality_score": quality_score,
152+
"auto_promoted": auto_promote and quality_score >= 7.5
153+
}
154+
)
155+
155156
# Fallback to SQLite if available and different from primary adapter
156157
fallback_error = None
157158
if hasattr(self, '_sqlite_adapter') and self._sqlite_adapter != getattr(self, 'memory_adapter', None):
@@ -448,13 +449,25 @@ async def store_short_term_async(self, content: str, metadata: Optional[Dict] =
448449
raw_metadata["user_id"] = user_id
449450
clean_metadata = self._sanitize_metadata(raw_metadata)
450451

451-
# Store in SQLite STM
452+
# Try primary adapter first (async version)
452453
memory_id = ""
453454
try:
454-
memory_id = await asyncio.to_thread(self._store_sqlite_stm, content, clean_metadata, quality_score)
455+
if hasattr(self, 'memory_adapter') and self.memory_adapter:
456+
memory_id = await asyncio.to_thread(
457+
self.memory_adapter.store_short_term, content, metadata=clean_metadata, **kwargs
458+
)
459+
self._log_verbose(f"Stored in {self.provider} async STM via adapter: {content[:100]}...")
455460
except Exception as e:
456-
logging.error(f"Failed to store in SQLite STM: {e}")
457-
return ""
461+
self._log_verbose(f"Failed to store in {self.provider} async STM: {e}", logging.WARNING)
462+
463+
# Only use SQLite fallback if primary storage failed completely
464+
if not memory_id and hasattr(self, '_sqlite_adapter') and self._sqlite_adapter != getattr(self, 'memory_adapter', None):
465+
try:
466+
memory_id = await asyncio.to_thread(self._store_sqlite_stm, content, clean_metadata, quality_score)
467+
self._log_verbose(f"Stored in SQLite async STM as fallback: {content[:100]}...")
468+
except Exception as e:
469+
logging.error(f"Failed to store in SQLite async STM fallback: {e}")
470+
return ""
458471

459472
# Auto-promote to long-term memory if quality is high (async)
460473
if auto_promote and quality_score >= 7.5: # High quality threshold

0 commit comments

Comments
 (0)