Skip to content

Commit 3c45e6f

Browse files
fix: Agent thread safety - eager lock initialization and LRU cache bounds (#1175)
* fix: implement thread safety for Agent class

  - Replace lazy lock initialization with eager initialization for chat_history and cache locks
  - Add thread-safe methods for chat_history operations: _add_to_chat_history(), _get_chat_history_length(), _truncate_chat_history()
  - Implement proper LRU cache eviction using OrderedDict with _cache_put() and _cache_get() methods
  - Add snapshot_lock to protect undo/redo stacks from TOCTOU race conditions
  - Update critical chat_history and cache operations to use thread-safe methods
  - Enforce _max_cache_size limits with LRU eviction instead of simple size checks

  Fixes race conditions in multi-agent concurrent workflows where:
  - chat_history entries could be silently dropped during tool execution
  - cache dictionaries could have interleaved reads/writes
  - snapshot/redo stacks had TOCTOU vulnerabilities in pop() operations
  - caches could grow without bound in long-running agents

  Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>

* fix: resolve thread safety issues in Agent class

  - Move OrderedDict import to module level for better performance
  - Add module-level logger to fix undefined logger errors in undo/redo
  - Fix critical TOCTOU race in duplicate message check with atomic method
  - Initialize snapshot lock unconditionally to prevent AttributeError
  - Protect diff() method snapshot stack reads with lock

  Addresses reviewer feedback from Qodo, CodeRabbit, and Copilot.

  🤖 Generated with [Claude Code](https://claude.ai/code)

  Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>

---------

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Mervin Praison <MervinPraison@users.noreply.github.com>
1 parent 9b85a22 commit 3c45e6f

File tree

1 file changed

+164
-61
lines changed
  • src/praisonai-agents/praisonaiagents/agent

1 file changed

+164
-61
lines changed

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 164 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,12 @@
66
import contextlib
77
import threading
88
from typing import List, Optional, Any, Dict, Union, Literal, TYPE_CHECKING, Callable, Generator
9+
from collections import OrderedDict
910
import inspect
1011

12+
# Module-level logger for thread safety errors and debugging
13+
logger = logging.getLogger(__name__)
14+
1115
# ============================================================================
1216
# Performance: Lazy imports for heavy dependencies
1317
# Rich, LLM, and display utilities are only imported when needed (output=verbose)
@@ -1511,9 +1515,12 @@ def __init__(
15111515
self.embedder_config = embedder_config
15121516
self.knowledge = knowledge
15131517
self.use_system_prompt = use_system_prompt
1514-
# Thread-safe chat_history with lazy lock for concurrent access
1518+
# Thread-safe chat_history with eager lock initialization
15151519
self.chat_history = []
1516-
self.__history_lock = None # Lazy initialized
1520+
self.__history_lock = threading.Lock() # Eager initialization to prevent race conditions
1521+
1522+
# Thread-safe snapshot/redo stack lock - always available even when autonomy is disabled
1523+
self.__snapshot_lock = threading.Lock()
15171524
self.markdown = markdown
15181525
self.stream = stream
15191526
self.metrics = metrics
@@ -1632,10 +1639,11 @@ def __init__(
16321639
# P8/G11: Tool timeout - prevent slow tools from blocking
16331640
self._tool_timeout = tool_timeout
16341641

1635-
# Cache for system prompts and formatted tools with lazy thread-safe lock
1636-
self._system_prompt_cache = {}
1637-
self._formatted_tools_cache = {}
1638-
self.__cache_lock = None # Lazy initialized RLock
1642+
# Cache for system prompts and formatted tools with eager thread-safe lock
1643+
# Use OrderedDict for LRU behavior
1644+
self._system_prompt_cache = OrderedDict()
1645+
self._formatted_tools_cache = OrderedDict()
1646+
self.__cache_lock = threading.RLock() # Eager initialization to prevent race conditions
16391647
# Limit cache size to prevent unbounded growth
16401648
self._max_cache_size = 100
16411649

@@ -1747,20 +1755,106 @@ def _telemetry(self):
17471755

17481756
@property
17491757
def _history_lock(self):
1750-
"""Lazy-loaded history lock for thread-safe chat history access."""
1751-
if self.__history_lock is None:
1752-
import threading
1753-
self.__history_lock = threading.Lock()
1758+
"""Thread-safe chat history lock."""
17541759
return self.__history_lock
17551760

17561761
@property
17571762
def _cache_lock(self):
1758-
"""Lazy-loaded cache lock for thread-safe cache access."""
1759-
if self.__cache_lock is None:
1760-
import threading
1761-
self.__cache_lock = threading.RLock()
1763+
"""Thread-safe cache lock."""
17621764
return self.__cache_lock
17631765

1766+
@property
def _snapshot_lock(self):
    """Return the lock that guards the snapshot and redo stacks.

    The lock is created eagerly in ``__init__`` as a ``threading.Lock``,
    so it is available even when autonomy is disabled.
    """
    return self.__snapshot_lock
1770+
1771+
def _cache_put(self, cache_dict, key, value):
    """Insert or refresh an entry in a bounded LRU cache.

    The whole operation runs under ``self._cache_lock`` so that the
    insert and any resulting eviction are seen as one step by other
    threads using the same lock.

    Args:
        cache_dict: The cache to update. Expected to be an ``OrderedDict``
            ordered from least to most recently used; eviction relies on
            ``popitem(last=False)``.
        key: Cache key.
        value: Value to store under ``key``.
    """
    with self._cache_lock:
        # Drop any existing entry first: plain assignment to an existing
        # key would keep its old position instead of marking it as the
        # most recently used entry.
        cache_dict.pop(key, None)
        cache_dict[key] = value

        # Evict least-recently-used entries from the front until the cache
        # fits. Clamp the limit at zero so a negative _max_cache_size
        # empties the cache instead of calling popitem() on an empty dict.
        limit = max(self._max_cache_size, 0)
        while len(cache_dict) > limit:
            cache_dict.popitem(last=False)
1790+
1791+
def _add_to_chat_history(self, role, content):
    """Append one message to ``self.chat_history`` under the history lock.

    Args:
        role: Message role ("user", "assistant", "system").
        content: Message content.
    """
    entry = {"role": role, "content": content}
    with self._history_lock:
        self.chat_history.append(entry)
1800+
1801+
def _add_to_chat_history_if_not_duplicate(self, role, content):
    """Append a message unless it repeats the most recent history entry.

    The duplicate check and the append happen inside a single critical
    section on ``self._history_lock``, so no other holder of that lock can
    slip a message in between the check and the write.

    Args:
        role: Message role ("user", "assistant", "system").
        content: Message content.

    Returns:
        bool: True if the message was appended, False if the last entry
        already has the same role and content.
    """
    with self._history_lock:
        history = self.chat_history
        if history:
            latest = history[-1]
            # Only the most recent entry is compared, not the whole history.
            if latest.get("role") == role and latest.get("content") == content:
                return False

        history.append({"role": role, "content": content})
        return True
1823+
1824+
def _get_chat_history_length(self):
    """Return the number of chat history entries, read under the history lock."""
    with self._history_lock:
        entry_count = len(self.chat_history)
    return entry_count
1828+
1829+
def _truncate_chat_history(self, length):
    """Cut the chat history down to its first ``length`` entries.

    Runs under the history lock. ``self.chat_history`` is rebound to a new
    list; the previous list object is left unmodified.

    Args:
        length: Target length for chat history (used as a slice bound).
    """
    with self._history_lock:
        kept = self.chat_history[:length]
        self.chat_history = kept
1837+
1838+
def _cache_get(self, cache_dict, key, default=None):
    """Look up a cache entry and mark it as most recently used.

    Runs under ``self._cache_lock``.

    Args:
        cache_dict: The cache to read; entries are kept ordered from least
            to most recently used.
        key: Cache key.
        default: Value returned on a miss. Defaults to ``None``; pass a
            private sentinel when a cached ``None`` must be told apart
            from a missing key.

    Returns:
        The cached value if ``key`` is present, otherwise ``default``.
    """
    with self._cache_lock:
        try:
            # Remove and re-insert so the entry moves to the end of the
            # ordering, where the most recently used entries live.
            value = cache_dict.pop(key)
        except KeyError:
            return default
        cache_dict[key] = value
        return value
1857+
17641858
@property
17651859
def auto_memory(self):
17661860
"""AutoMemory instance for automatic memory extraction."""
@@ -2220,19 +2314,23 @@ def undo(self) -> bool:
22202314
result = agent.start("Refactor utils.py")
22212315
agent.undo() # Restore original files
22222316
"""
2223-
if self._file_snapshot is None or not self._snapshot_stack:
2224-
return False
2225-
try:
2226-
target_hash = self._snapshot_stack.pop()
2227-
# Get current hash before restore (for redo)
2228-
current_hash = self._file_snapshot.get_current_hash()
2229-
if current_hash:
2230-
self._redo_stack.append(current_hash)
2231-
self._file_snapshot.restore(target_hash)
2232-
return True
2233-
except Exception as e:
2234-
logger.debug(f"Undo failed: {e}")
2317+
if self._file_snapshot is None:
22352318
return False
2319+
2320+
with self._snapshot_lock:
2321+
if not self._snapshot_stack:
2322+
return False
2323+
try:
2324+
target_hash = self._snapshot_stack.pop()
2325+
# Get current hash before restore (for redo)
2326+
current_hash = self._file_snapshot.get_current_hash()
2327+
if current_hash:
2328+
self._redo_stack.append(current_hash)
2329+
self._file_snapshot.restore(target_hash)
2330+
return True
2331+
except Exception as e:
2332+
logger.debug(f"Undo failed: {e}")
2333+
return False
22362334

22372335
def redo(self) -> bool:
22382336
"""Redo a previously undone set of file changes.
@@ -2242,18 +2340,22 @@ def redo(self) -> bool:
22422340
Returns:
22432341
True if redo was successful, False if nothing to redo.
22442342
"""
2245-
if self._file_snapshot is None or not self._redo_stack:
2246-
return False
2247-
try:
2248-
target_hash = self._redo_stack.pop()
2249-
current_hash = self._file_snapshot.get_current_hash()
2250-
if current_hash:
2251-
self._snapshot_stack.append(current_hash)
2252-
self._file_snapshot.restore(target_hash)
2253-
return True
2254-
except Exception as e:
2255-
logger.debug(f"Redo failed: {e}")
2343+
if self._file_snapshot is None:
22562344
return False
2345+
2346+
with self._snapshot_lock:
2347+
if not self._redo_stack:
2348+
return False
2349+
try:
2350+
target_hash = self._redo_stack.pop()
2351+
current_hash = self._file_snapshot.get_current_hash()
2352+
if current_hash:
2353+
self._snapshot_stack.append(current_hash)
2354+
self._file_snapshot.restore(target_hash)
2355+
return True
2356+
except Exception as e:
2357+
logger.debug(f"Redo failed: {e}")
2358+
return False
22572359

22582360
def diff(self, from_hash: Optional[str] = None):
22592361
"""Get file diffs from autonomous execution.
@@ -2279,8 +2381,11 @@ def diff(self, from_hash: Optional[str] = None):
22792381
return []
22802382
try:
22812383
base = from_hash
2282-
if base is None and self._snapshot_stack:
2283-
base = self._snapshot_stack[0]
2384+
if base is None:
2385+
# Protect snapshot stack read with lock to prevent TOCTOU with undo/redo
2386+
with self._snapshot_lock:
2387+
if self._snapshot_stack:
2388+
base = self._snapshot_stack[0]
22842389
if base is None:
22852390
return []
22862391
return self._file_snapshot.diff(base)
@@ -2477,8 +2582,9 @@ def run_autonomous(
24772582
if self._file_snapshot is not None and self.autonomy_config.get("snapshot", False):
24782583
try:
24792584
snap_info = self._file_snapshot.track(message="pre-autonomous")
2480-
self._snapshot_stack.append(snap_info.commit_hash)
2481-
self._redo_stack.clear()
2585+
with self._snapshot_lock:
2586+
self._snapshot_stack.append(snap_info.commit_hash)
2587+
self._redo_stack.clear()
24822588
except Exception as e:
24832589
logging.debug(f"Pre-autonomous snapshot failed: {e}")
24842590

@@ -4304,8 +4410,9 @@ def _build_system_prompt(self, tools=None):
43044410
tools_key = self._get_tools_cache_key(tools)
43054411
cache_key = f"{self.role}:{self.goal}:{tools_key}"
43064412

4307-
if cache_key in self._system_prompt_cache:
4308-
return self._system_prompt_cache[cache_key]
4413+
cached_prompt = self._cache_get(self._system_prompt_cache, cache_key)
4414+
if cached_prompt is not None:
4415+
return cached_prompt
43094416
else:
43104417
cache_key = None # Don't cache when memory is enabled
43114418

@@ -4371,9 +4478,9 @@ def _build_system_prompt(self, tools=None):
43714478
system_prompt += "\n\nExplain Before Acting: Before calling a tool, provide a brief one-sentence explanation of what you are about to do and why. Skip explanations only for repetitive low-level operations where narration would be noisy. When performing a batch of similar operations (e.g. searching for multiple items), explain the group once rather than narrating each call individually."
43724479

43734480
# Cache the generated system prompt (only if cache_key is set, i.e., memory not enabled)
4374-
# Simple cache size limit to prevent unbounded growth
4375-
if cache_key and len(self._system_prompt_cache) < self._max_cache_size:
4376-
self._system_prompt_cache[cache_key] = system_prompt
4481+
# Use LRU eviction to prevent unbounded growth
4482+
if cache_key:
4483+
self._cache_put(self._system_prompt_cache, cache_key, system_prompt)
43774484
return system_prompt
43784485

43794486
def _build_response_format(self, schema_model):
@@ -4567,8 +4674,9 @@ def _format_tools_for_completion(self, tools=None):
45674674

45684675
# Check cache first
45694676
tools_key = self._get_tools_cache_key(tools)
4570-
if tools_key in self._formatted_tools_cache:
4571-
return self._formatted_tools_cache[tools_key]
4677+
cached_tools = self._cache_get(self._formatted_tools_cache, tools_key)
4678+
if cached_tools is not None:
4679+
return cached_tools
45724680

45734681
formatted_tools = []
45744682
for tool in tools:
@@ -4619,10 +4727,8 @@ def _format_tools_for_completion(self, tools=None):
46194727
logging.error(f"Tools are not JSON serializable: {e}")
46204728
return []
46214729

4622-
# Cache the formatted tools
4623-
# Simple cache size limit to prevent unbounded growth
4624-
if len(self._formatted_tools_cache) < self._max_cache_size:
4625-
self._formatted_tools_cache[tools_key] = formatted_tools
4730+
# Cache the formatted tools with LRU eviction
4731+
self._cache_put(self._formatted_tools_cache, tools_key, formatted_tools)
46264732
return formatted_tools
46274733

46284734
def generate_task(self) -> 'Task':
@@ -6279,12 +6385,9 @@ def _chat_impl(self, prompt, temperature, tools, output_json, output_pydantic, r
62796385
# Extract text from multimodal prompts
62806386
normalized_content = next((item["text"] for item in prompt if item.get("type") == "text"), "")
62816387

6282-
# Prevent duplicate messages
6283-
if not (self.chat_history and
6284-
self.chat_history[-1].get("role") == "user" and
6285-
self.chat_history[-1].get("content") == normalized_content):
6286-
# Add user message to chat history BEFORE LLM call so handoffs can access it
6287-
self.chat_history.append({"role": "user", "content": normalized_content})
6388+
# Add user message to chat history BEFORE LLM call so handoffs can access it
6389+
# Use atomic check-then-act to prevent TOCTOU race conditions
6390+
if self._add_to_chat_history_if_not_duplicate("user", normalized_content):
62886391
# Persist user message to DB
62896392
self._persist_message("user", normalized_content)
62906393

@@ -6334,7 +6437,7 @@ def _chat_impl(self, prompt, temperature, tools, output_json, output_pydantic, r
63346437

63356438
response_text = self.llm_instance.get_response(**llm_kwargs)
63366439

6337-
self.chat_history.append({"role": "assistant", "content": response_text})
6440+
self._add_to_chat_history("assistant", response_text)
63386441
# Persist assistant message to DB
63396442
self._persist_message("assistant", response_text)
63406443

@@ -8595,12 +8698,12 @@ async def handle_agent_query(request: Request, query_data: Optional[AgentQuery]
85958698

85968699
print(f"🚀 Agent '{self.name}' available at http://{host}:{port}")
85978700

8598-
# Start the server if it's not already running for this port
8701+
# Check and mark server as started atomically to prevent race conditions
85998702
should_start = not _server_started.get(port, False)
86008703
if should_start:
86018704
_server_started[port] = True
86028705

8603-
# Server start/wait outside the lock to avoid holding it during sleep
8706+
# Server start/wait outside the lock to avoid holding it during sleep
86048707
if should_start:
86058708
# Start the server in a separate thread
86068709
def run_server():

0 commit comments

Comments
 (0)