Skip to content

Commit f015ac7

Browse files
Merge pull request #1366 from MervinPraison/claude/issue-1365-20260412-0930
fix: address critical concurrency, memory, and resource lifecycle gaps
2 parents a5add02 + 1277a7a commit f015ac7

8 files changed

Lines changed: 169 additions & 87 deletions

File tree

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4501,6 +4501,39 @@ def close(self) -> None:
45014501
except Exception as e:
45024502
logger.warning(f"Memory cleanup failed: {e}")
45034503

4504+
# LLM client cleanup - target actual live clients, not model strings
4505+
try:
4506+
# Primary cleanup targets - actual live clients
4507+
if hasattr(self, 'llm_instance') and self.llm_instance:
4508+
if hasattr(self.llm_instance, 'aclose'):
4509+
# Try async close first
4510+
try:
4511+
import asyncio
4512+
if asyncio.iscoroutinefunction(self.llm_instance.aclose):
4513+
# We're in sync context, so use asyncio.run() for the cleanup
4514+
asyncio.run(self.llm_instance.aclose())
4515+
else:
4516+
self.llm_instance.aclose()
4517+
except Exception:
4518+
# Fall back to sync close if async fails
4519+
if hasattr(self.llm_instance, 'close'):
4520+
self.llm_instance.close()
4521+
elif hasattr(self.llm_instance, 'close'):
4522+
self.llm_instance.close()
4523+
4524+
# Check for OpenAI client (common pattern in agents)
4525+
if hasattr(self, '_Agent__openai_client') and self._Agent__openai_client:
4526+
if hasattr(self._Agent__openai_client, 'close'):
4527+
self._Agent__openai_client.close()
4528+
4529+
# Legacy fallback - check self.llm._client (but less likely to work)
4530+
if hasattr(self, 'llm') and self.llm and not isinstance(self.llm, str):
4531+
llm_client = getattr(self.llm, '_client', None)
4532+
if llm_client and hasattr(llm_client, 'close'):
4533+
llm_client.close()
4534+
except Exception as e:
4535+
logger.warning(f"LLM client cleanup failed: {e}")
4536+
45044537
# MCP cleanup
45054538
try:
45064539
if hasattr(self, '_mcp_clients') and self._mcp_clients:

src/praisonai-agents/praisonaiagents/agent/async_safety.py

Lines changed: 10 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -35,26 +35,8 @@ class DualLock:
3535
"""
3636

3737
def __init__(self):
38-
self._thread_lock = threading.Lock()
39-
self._async_lock: Optional[asyncio.Lock] = None
40-
self._loop_id: Optional[int] = None
41-
42-
def _get_async_lock(self) -> asyncio.Lock:
43-
"""Get or create asyncio.Lock for current event loop."""
44-
try:
45-
current_loop = asyncio.get_running_loop()
46-
current_loop_id = id(current_loop)
47-
48-
# Create new lock if loop changed or first time
49-
if self._loop_id != current_loop_id:
50-
self._async_lock = asyncio.Lock()
51-
self._loop_id = current_loop_id
52-
53-
return self._async_lock
54-
except RuntimeError:
55-
# No event loop running, fall back to thread lock in a new loop
56-
self._async_lock = asyncio.Lock()
57-
return self._async_lock
38+
"""Initialize with unified thread-safe locking."""
39+
self._thread_lock = threading.Lock() # Single canonical lock for all contexts
5840

5941
@contextmanager
6042
def sync(self):
@@ -64,10 +46,13 @@ def sync(self):
6446

6547
@asynccontextmanager
6648
async def async_lock(self):
67-
"""Acquire lock in asynchronous context using asyncio.Lock."""
68-
async_lock = self._get_async_lock()
69-
async with async_lock:
49+
"""Acquire lock in asynchronous context using threading.Lock via asyncio.to_thread()."""
50+
# Use asyncio.to_thread to acquire the thread lock without blocking the event loop
51+
await asyncio.to_thread(self._thread_lock.acquire)
52+
try:
7053
yield
54+
finally:
55+
self._thread_lock.release()
7156

7257
def is_async_context(self) -> bool:
7358
"""Check if we're currently in an async context."""
@@ -129,14 +114,12 @@ def __exit__(self, exc_type, exc_val, exc_tb):
129114

130115
async def __aenter__(self):
131116
"""Support for asynchronous context manager protocol."""
132-
async_lock = self._lock._get_async_lock()
133-
await async_lock.acquire()
117+
await asyncio.to_thread(self._lock._thread_lock.acquire)
134118
return self.value
135119

136120
async def __aexit__(self, exc_type, exc_val, exc_tb):
137121
"""Support for asynchronous context manager protocol."""
138-
async_lock = self._lock._get_async_lock()
139-
async_lock.release()
122+
self._lock._thread_lock.release()
140123
return None
141124

142125
def get(self) -> Any:

src/praisonai-agents/praisonaiagents/agent/tool_execution.py

Lines changed: 32 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -12,6 +12,7 @@
1212
import logging
1313
import asyncio
1414
import inspect
15+
import contextvars
1516
import concurrent.futures
1617
from typing import List, Optional, Any, Dict, Union, TYPE_CHECKING
1718

@@ -190,18 +191,37 @@ def _execute_tool_with_context(self, function_name, arguments, state, tool_call_
190191
if res.output and res.output.modified_data:
191192
arguments.update(res.output.modified_data)
192193

193-
with with_injection_context(state):
194-
# P8/G11: Apply tool timeout if configured
195-
tool_timeout = getattr(self, '_tool_timeout', None)
196-
if tool_timeout and tool_timeout > 0:
197-
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
198-
future = executor.submit(self._execute_tool_impl, function_name, arguments)
199-
try:
200-
result = future.result(timeout=tool_timeout)
201-
except concurrent.futures.TimeoutError:
202-
logging.warning(f"Tool {function_name} timed out after {tool_timeout}s")
203-
result = {"error": f"Tool timed out after {tool_timeout}s", "timeout": True}
204-
else:
194+
# P8/G11: Apply tool timeout if configured
195+
tool_timeout = getattr(self, '_tool_timeout', None)
196+
if tool_timeout and tool_timeout > 0:
197+
# Use copy_context to preserve injection context in executor thread
198+
ctx = contextvars.copy_context()
199+
200+
def execute_with_context():
201+
with with_injection_context(state):
202+
return self._execute_tool_impl(function_name, arguments)
203+
204+
# Use explicit executor lifecycle to actually bound execution time
205+
executor = concurrent.futures.ThreadPoolExecutor(max_workers=1)
206+
try:
207+
future = executor.submit(ctx.run, execute_with_context)
208+
try:
209+
result = future.result(timeout=tool_timeout)
210+
except concurrent.futures.TimeoutError:
211+
# Cancel and shutdown immediately to avoid blocking
212+
future.cancel()
213+
executor.shutdown(wait=False, cancel_futures=True)
214+
logging.warning(f"Tool {function_name} timed out after {tool_timeout}s")
215+
result = {"error": f"Tool timed out after {tool_timeout}s", "timeout": True}
216+
else:
217+
# Normal completion - shutdown gracefully
218+
executor.shutdown(wait=False)
219+
finally:
220+
# Ensure executor is always cleaned up
221+
if not executor._shutdown:
222+
executor.shutdown(wait=False)
223+
else:
224+
with with_injection_context(state):
205225
result = self._execute_tool_impl(function_name, arguments)
206226

207227
# Apply tool output truncation to prevent context overflow

src/praisonai-agents/praisonaiagents/checkpoints/service.py

Lines changed: 11 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -485,9 +485,17 @@ async def _prune_checkpoints(self):
485485
if len(self._checkpoints) <= self.config.max_checkpoints:
486486
return
487487

488-
# Keep only the most recent checkpoints
489-
# Note: This doesn't actually delete git history, just our tracking
490-
self._checkpoints = self._checkpoints[:self.config.max_checkpoints]
488+
# Calculate how many to remove
489+
num_to_remove = len(self._checkpoints) - self.config.max_checkpoints
490+
491+
# Keep only the most recent checkpoints in memory (newest-last semantics)
492+
# Since save() appends (newest last), keep the last N entries
493+
self._checkpoints = self._checkpoints[-self.config.max_checkpoints:]
494+
495+
logger.info(f"Pruned {num_to_remove} old checkpoints to stay under limit of {self.config.max_checkpoints}")
496+
497+
# Emit pruning event for any cleanup hooks
498+
self._emit(CheckpointEvent.CHECKPOINTS_PRUNED, {"action": "pruned", "removed_count": num_to_remove})
491499

492500
async def get_checkpoint(self, checkpoint_id: str) -> Optional[Checkpoint]:
493501
"""Get a specific checkpoint by ID."""

src/praisonai-agents/praisonaiagents/checkpoints/types.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@ class CheckpointEvent(str, Enum):
2626
INITIALIZED = "initialized"
2727
CHECKPOINT_CREATED = "checkpoint_created"
2828
CHECKPOINT_RESTORED = "checkpoint_restored"
29+
CHECKPOINTS_PRUNED = "checkpoints_pruned"
2930
ERROR = "error"
3031

3132

src/praisonai-agents/praisonaiagents/memory/core.py

Lines changed: 45 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -62,16 +62,14 @@ def store_short_term(self, content: str, metadata: Optional[Dict] = None, qualit
6262
except Exception as e:
6363
self._log_verbose(f"Failed to store in {self.provider} STM: {e}", logging.WARNING)
6464

65-
# Backward compatibility: Also store in SQLite if not using SQLite adapter
66-
if hasattr(self, '_sqlite_adapter') and self._sqlite_adapter != getattr(self, 'memory_adapter', None):
65+
# Only use SQLite fallback if primary storage failed completely
66+
if not memory_id and hasattr(self, '_sqlite_adapter') and self._sqlite_adapter != getattr(self, 'memory_adapter', None):
6767
try:
68-
fallback_id = self._sqlite_adapter.store_short_term(content, metadata=clean_metadata, **kwargs)
69-
if not memory_id:
70-
memory_id = fallback_id
68+
memory_id = self._sqlite_adapter.store_short_term(content, metadata=clean_metadata, **kwargs)
69+
self._log_verbose(f"Stored in SQLite STM as fallback: {content[:100]}...")
7170
except Exception as e:
7271
logging.error(f"Failed to store in SQLite STM fallback: {e}")
73-
if not memory_id:
74-
return ""
72+
return ""
7573

7674
# Auto-promote to long-term memory if quality is high
7775
if auto_promote and quality_score >= 7.5: # High quality threshold
@@ -122,38 +120,39 @@ def store_short_term_structured(self, content: str, metadata: Optional[Dict] = N
122120
clean_metadata = self._sanitize_metadata(metadata)
123121

124122
# Protocol-driven storage: Try primary adapter first
123+
memory_id = ""
125124
primary_error = None
126-
memory_id = None
127-
128125
try:
129126
if hasattr(self, 'memory_adapter') and self.memory_adapter:
130127
memory_id = self.memory_adapter.store_short_term(content, metadata=clean_metadata, **kwargs)
131128
self._log_verbose(f"Stored in {self.provider} STM via adapter: {content[:100]}...")
132-
133-
# Auto-promote to long-term memory if quality is high
134-
if auto_promote and quality_score >= 7.5:
135-
try:
136-
self.store_long_term(content, clean_metadata, quality_score, user_id, **kwargs)
137-
self._log_verbose(f"Auto-promoted STM content to LTM (score: {quality_score:.2f})")
138-
except Exception as e:
139-
# Auto-promotion failure doesn't affect the primary storage result
140-
logging.warning(f"Failed to auto-promote to LTM: {e}")
141-
142-
# Emit memory event for successful storage
143-
self._emit_memory_event("store", "short_term", content, clean_metadata)
144-
145-
return MemoryResult.success_result(
146-
memory_id=memory_id,
147-
adapter_used=self.provider,
148-
context={
149-
"quality_score": quality_score,
150-
"auto_promoted": auto_promote and quality_score >= 7.5
151-
}
152-
)
153129
except Exception as e:
154130
primary_error = str(e)
155131
self._log_verbose(f"Failed to store in {self.provider} STM: {e}", logging.WARNING)
156132

133+
# Only proceed with success if we got a valid memory_id
134+
if memory_id:
135+
# Auto-promote to long-term memory if quality is high
136+
if auto_promote and quality_score >= 7.5:
137+
try:
138+
self.store_long_term(content, clean_metadata, quality_score, user_id, **kwargs)
139+
self._log_verbose(f"Auto-promoted STM content to LTM (score: {quality_score:.2f})")
140+
except Exception as e:
141+
# Auto-promotion failure doesn't affect the primary storage result
142+
logging.warning(f"Failed to auto-promote to LTM: {e}")
143+
144+
# Emit memory event for successful storage
145+
self._emit_memory_event("store", "short_term", content, clean_metadata)
146+
147+
return MemoryResult.success_result(
148+
memory_id=memory_id,
149+
adapter_used=self.provider,
150+
context={
151+
"quality_score": quality_score,
152+
"auto_promoted": auto_promote and quality_score >= 7.5
153+
}
154+
)
155+
157156
# Fallback to SQLite if available and different from primary adapter
158157
fallback_error = None
159158
if hasattr(self, '_sqlite_adapter') and self._sqlite_adapter != getattr(self, 'memory_adapter', None):
@@ -450,13 +449,25 @@ async def store_short_term_async(self, content: str, metadata: Optional[Dict] =
450449
raw_metadata["user_id"] = user_id
451450
clean_metadata = self._sanitize_metadata(raw_metadata)
452451

453-
# Store in SQLite STM
452+
# Try primary adapter first (async version)
454453
memory_id = ""
455454
try:
456-
memory_id = await asyncio.to_thread(self._store_sqlite_stm, content, clean_metadata, quality_score)
455+
if hasattr(self, 'memory_adapter') and self.memory_adapter:
456+
memory_id = await asyncio.to_thread(
457+
self.memory_adapter.store_short_term, content, metadata=clean_metadata, **kwargs
458+
)
459+
self._log_verbose(f"Stored in {self.provider} async STM via adapter: {content[:100]}...")
457460
except Exception as e:
458-
logging.error(f"Failed to store in SQLite STM: {e}")
459-
return ""
461+
self._log_verbose(f"Failed to store in {self.provider} async STM: {e}", logging.WARNING)
462+
463+
# Only use SQLite fallback if primary storage failed completely
464+
if not memory_id and hasattr(self, '_sqlite_adapter') and self._sqlite_adapter != getattr(self, 'memory_adapter', None):
465+
try:
466+
memory_id = await asyncio.to_thread(self._store_sqlite_stm, content, clean_metadata, quality_score)
467+
self._log_verbose(f"Stored in SQLite async STM as fallback: {content[:100]}...")
468+
except Exception as e:
469+
logging.error(f"Failed to store in SQLite async STM fallback: {e}")
470+
return ""
460471

461472
# Auto-promote to long-term memory if quality is high (async)
462473
if auto_promote and quality_score >= 7.5: # High quality threshold

src/praisonai-agents/praisonaiagents/process/process.py

Lines changed: 22 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -45,6 +45,8 @@ def __init__(
4545
self.workflow_timeout = workflow_timeout
4646
self.task_retry_counter: Dict[str, int] = {} # Initialize retry counter
4747
self.workflow_finished = False # ADDED: Workflow finished flag
48+
self.workflow_cancelled = False # ADDED: Workflow cancellation flag for timeout
49+
self._state_lock_init = threading.Lock() # Thread lock for async lock creation
4850
self._state_lock = None # Lazy-initialized async lock for shared state protection
4951

5052
# Resolve verbose from output= param (takes precedence) or legacy verbose= param
@@ -255,7 +257,10 @@ def _find_next_not_started_task(self) -> Optional[Task]:
255257
continue # Skip if no valid path exists
256258

257259
if self.task_retry_counter.get(task_candidate.id, 0) < self.max_retries:
258-
self.task_retry_counter[task_candidate.id] = self.task_retry_counter.get(task_candidate.id, 0) + 1
260+
# Atomic increment using thread lock to prevent race conditions
261+
with self._state_lock_init:
262+
current_count = self.task_retry_counter.get(task_candidate.id, 0)
263+
self.task_retry_counter[task_candidate.id] = current_count + 1
259264
temp_current_task = task_candidate
260265
logging.debug(f"Fallback attempt {fallback_attempts}: Found 'not started' task: {temp_current_task.name}, retry count: {self.task_retry_counter[temp_current_task.id]}")
261266
return temp_current_task # Return the found task immediately
@@ -429,13 +434,19 @@ async def aworkflow(self) -> AsyncGenerator[str, None]:
429434
if self.workflow_timeout is not None:
430435
elapsed = time.monotonic() - workflow_start
431436
if elapsed > self.workflow_timeout:
432-
logging.warning(f"Workflow timeout ({self.workflow_timeout}s) exceeded after {elapsed:.1f}s, stopping.")
437+
logging.warning(f"Workflow timeout ({self.workflow_timeout}s) exceeded after {elapsed:.1f}s, cancelling workflow.")
438+
self.workflow_cancelled = True
433439
break
434440

435441
# ADDED: Check workflow finished flag at the start of each cycle
436442
if self.workflow_finished:
437443
logging.info("Workflow finished early as all tasks are completed.")
438444
break
445+
446+
# ADDED: Check workflow cancellation flag
447+
if self.workflow_cancelled:
448+
logging.warning("Workflow has been cancelled, stopping task execution.")
449+
break
439450

440451
# Add task summary at start of each cycle
441452
logging.debug(f"""
@@ -597,8 +608,11 @@ async def aworkflow(self) -> AsyncGenerator[str, None]:
597608
break
598609

599610
# Reset completed task to "not started" so it can run again (atomic operation)
611+
# Atomic state lock initialization
600612
if self._state_lock is None:
601-
self._state_lock = asyncio.Lock()
613+
with self._state_lock_init:
614+
if self._state_lock is None: # Double-checked locking pattern
615+
self._state_lock = asyncio.Lock()
602616
async with self._state_lock:
603617
if self.tasks[task_id].status == "completed":
604618
# Never reset loop tasks, decision tasks, or their subtasks if rerun is False
@@ -1031,6 +1045,11 @@ def workflow(self):
10311045
if self.workflow_finished:
10321046
logging.info("Workflow finished early as all tasks are completed.")
10331047
break
1048+
1049+
# ADDED: Check workflow cancellation flag
1050+
if self.workflow_cancelled:
1051+
logging.warning("Workflow has been cancelled, stopping task execution.")
1052+
break
10341053

10351054
# Add task summary at start of each cycle
10361055
logging.debug(f"""

0 commit comments

Comments (0)