Skip to content

Commit f7d3b89

Browse files
fix: Address 3 critical production robustness gaps
Fixes #1370

## Gap 1: Multi-Agent Concurrency Safety
- 1a: Added thread-safe cost tracking with _cost_lock in Agent class
- 1b: Fixed parallel workflow shared state by using copy.deepcopy() instead of shallow copy
- 1c: Confirmed SessionDeduplicationCache already has proper thread safety

## Gap 2: Session Isolation
- 2a: Fixed session key collision by including session_id in agent keys
- 2b: Added session TTL support with is_expired(), close(), time_to_expiry() methods
- 2c: Routed chat history through SessionStore first, memory as fallback

## Gap 3: Error Propagation
- 3a: Wrapped tool exceptions in ToolExecutionError for better observability
- 3b: Added exponential backoff and is_retryable checks to workflow retry logic
- 3c: Implemented cross-step handoff cycle detection with visited-set tracking

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-authored-by: MervinPraison <MervinPraison@users.noreply.github.com>
1 parent 6693a75 commit f7d3b89

File tree

5 files changed

+153
-26
lines changed

5 files changed

+153
-26
lines changed

src/praisonai-agents/praisonaiagents/agent/agent.py

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1563,6 +1563,8 @@ def __init__(
15631563
# Token budget guard (zero overhead when _max_budget is None)
15641564
self._max_budget = _max_budget
15651565
self._on_budget_exceeded = _on_budget_exceeded
1566+
# Thread-safe cost/token tracking (Gap 1a fix)
1567+
self._cost_lock = threading.Lock()
15661568
self._total_cost = 0.0
15671569
self._total_tokens_in = 0
15681570
self._total_tokens_out = 0
@@ -1906,7 +1908,9 @@ def thinking_budget(self, value: Optional[int]) -> None:
19061908
@property
19071909
def total_cost(self) -> float:
19081910
"""Cumulative USD cost of all LLM calls in this agent run."""
1909-
return self._total_cost
1911+
# Thread-safe cost reading (Gap 1a fix)
1912+
with self._cost_lock:
1913+
return self._total_cost
19101914

19111915
@property
19121916
def cost_summary(self) -> dict:
@@ -1915,12 +1919,14 @@ def cost_summary(self) -> dict:
19151919
Returns:
19161920
dict with keys: tokens_in, tokens_out, cost, llm_calls
19171921
"""
1918-
return {
1919-
"tokens_in": self._total_tokens_in,
1920-
"tokens_out": self._total_tokens_out,
1921-
"cost": self._total_cost,
1922-
"llm_calls": self._llm_call_count,
1923-
}
1922+
# Thread-safe cost reading (Gap 1a fix)
1923+
with self._cost_lock:
1924+
return {
1925+
"tokens_in": self._total_tokens_in,
1926+
"tokens_out": self._total_tokens_out,
1927+
"cost": self._total_cost,
1928+
"llm_calls": self._llm_call_count,
1929+
}
19241930

19251931
@property
19261932
def context_manager(self) -> Optional[Any]:

src/praisonai-agents/praisonaiagents/agent/chat_mixin.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -677,26 +677,31 @@ def _chat_completion(self, messages, temperature=1.0, tools=None, stream=True, r
677677
)
678678

679679
# Budget tracking & enforcement (zero overhead when _max_budget is None)
680-
self._total_cost += _cost_usd
681-
self._total_tokens_in += _prompt_tokens
682-
self._total_tokens_out += _completion_tokens
683-
self._llm_call_count += 1
684-
if self._max_budget and self._total_cost >= self._max_budget:
680+
# Thread-safe cost tracking (Gap 1a fix)
681+
with self._cost_lock:
682+
self._total_cost += _cost_usd
683+
self._total_tokens_in += _prompt_tokens
684+
self._total_tokens_out += _completion_tokens
685+
self._llm_call_count += 1
686+
budget_exceeded = self._max_budget and self._total_cost >= self._max_budget
687+
current_cost = self._total_cost
688+
689+
if budget_exceeded:
685690
if self._on_budget_exceeded == "stop":
686691
raise BudgetExceededError(
687-
f"Agent '{self.name}' exceeded budget: ${self._total_cost:.4f} >= ${self._max_budget:.4f}",
692+
f"Agent '{self.name}' exceeded budget: ${current_cost:.4f} >= ${self._max_budget:.4f}",
688693
budget_type="cost",
689694
limit=self._max_budget,
690-
used=self._total_cost,
695+
used=current_cost,
691696
agent_id=self.name
692697
)
693698
elif self._on_budget_exceeded == "warn":
694699
logging.warning(
695-
f"[budget] {self.name}: ${self._total_cost:.4f} exceeded "
700+
f"[budget] {self.name}: ${current_cost:.4f} exceeded "
696701
f"${self._max_budget:.4f} budget"
697702
)
698703
elif callable(self._on_budget_exceeded):
699-
self._on_budget_exceeded(self._total_cost, self._max_budget)
704+
self._on_budget_exceeded(current_cost, self._max_budget)
700705

701706
# Trigger AFTER_LLM hook
702707
from ..hooks import HookEvent, AfterLLMInput

src/praisonai-agents/praisonaiagents/agent/tool_execution.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -290,8 +290,15 @@ def _execute_tool_with_context(self, function_name, arguments, state, tool_call_
290290
_duration_ms = (_time.time() - _tool_start_time) * 1000
291291
_trace_emitter.tool_call_end(self.name, function_name, None, _duration_ms, str(e))
292292

293-
# Trigger OnError hook if needed (optional future step)
294-
raise
293+
# Gap 3a fix: Wrap exceptions in ToolExecutionError for better observability
294+
from ..errors import ToolExecutionError
295+
is_retryable = not isinstance(e, (ValueError, TypeError, AttributeError))
296+
raise ToolExecutionError(
297+
f"Tool '{function_name}' failed: {e}",
298+
tool_name=function_name,
299+
agent_id=self.name,
300+
is_retryable=is_retryable,
301+
) from e
295302

296303
def _trigger_after_agent_hook(self, prompt, response, start_time, tools_used=None):
297304
"""Trigger AFTER_AGENT hook and return response."""

src/praisonai-agents/praisonaiagents/session.py

Lines changed: 68 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,8 @@ def __init__(
5151
agent_url: Optional[str] = None,
5252
memory_config: Optional[Dict[str, Any]] = None,
5353
knowledge_config: Optional[Dict[str, Any]] = None,
54-
timeout: int = 30
54+
timeout: int = 30,
55+
session_ttl: Optional[int] = None # Gap 2b: TTL in seconds
5556
):
5657
"""
5758
Initialize a new session with optional persistence or remote agent connectivity.
@@ -63,6 +64,7 @@ def __init__(
6364
memory_config: Configuration for memory system (defaults to RAG)
6465
knowledge_config: Configuration for knowledge base system
6566
timeout: HTTP timeout for remote agent calls (default: 30 seconds)
67+
session_ttl: Time-to-live in seconds after which session expires (Gap 2b)
6668
"""
6769
self.session_id = session_id or str(uuid.uuid4())[:8]
6870
self.user_id = user_id or "default_user"
@@ -110,6 +112,10 @@ def __init__(
110112
self._agents_instance = None
111113
self._agents = {} # Track agents and their chat histories
112114

115+
# Gap 2b: Session TTL and cleanup support
116+
self.session_ttl = session_ttl
117+
self._created_at = time.time() # Track creation time for TTL
118+
113119
def _get_session_dir(self):
114120
"""Return session-specific directory using paths.py."""
115121
from pathlib import Path
@@ -186,8 +192,8 @@ def Agent(
186192

187193
agent = Agent(**agent_kwargs)
188194

189-
# Create a unique key for this agent (using name and role)
190-
agent_key = f"{name}:{role}"
195+
# Create a unique key for this agent (Gap 2a fix: include session_id for proper isolation)
196+
agent_key = f"{self.session_id}:{name}:{role}"
191197

192198
# Restore chat history if it exists from previous sessions
193199
if agent_key in self._agents:
@@ -270,7 +276,7 @@ def restore_state(self) -> Dict[str, Any]:
270276

271277
def _restore_agent_chat_history(self, agent_key: str) -> List[Dict[str, Any]]:
272278
"""
273-
Restore agent chat history from memory.
279+
Restore agent chat history from SessionStore first, then memory fallback (Gap 2c fix).
274280
275281
Args:
276282
agent_key: Unique identifier for the agent
@@ -281,7 +287,18 @@ def _restore_agent_chat_history(self, agent_key: str) -> List[Dict[str, Any]]:
281287
if self.is_remote:
282288
return []
283289

284-
# Search for agent chat history in memory
290+
# Gap 2c: Try SessionStore first for clean separation
291+
try:
292+
from .session.store import get_default_session_store
293+
session_store = get_default_session_store()
294+
session_id = f"{self.session_id}_{agent_key}"
295+
chat_history = session_store.get_chat_history(session_id)
296+
if chat_history:
297+
return chat_history
298+
except ImportError:
299+
pass
300+
301+
# Fallback: Search for agent chat history in memory (backward compatibility)
285302
results = self.memory.search_short_term(
286303
query="Agent chat history for",
287304
limit=10
@@ -350,10 +367,10 @@ def _save_agent_chat_histories(self) -> None:
350367
chat_history = agent_data.get("chat_history")
351368

352369
if chat_history is not None:
353-
# G-2 FIX: Try SessionStore first for clean separation
370+
# G-2 FIX: Use SessionStore for clean separation (Gap 2c fix)
354371
session_store = None
355372
try:
356-
from .session import get_default_session_store
373+
from .session.store import get_default_session_store
357374
session_store = get_default_session_store()
358375
except ImportError:
359376
pass
@@ -580,6 +597,50 @@ def send_message(self, message: str, **kwargs) -> str:
580597
"""
581598
return self.chat(message, **kwargs)
582599

600+
def is_expired(self) -> bool:
601+
"""Check if the session has expired based on TTL (Gap 2b fix)."""
602+
if self.session_ttl is None:
603+
return False
604+
return time.time() - self._created_at > self.session_ttl
605+
606+
def close(self) -> None:
607+
"""Close and cleanup the session (Gap 2b fix)."""
608+
if self.is_remote:
609+
return # No cleanup needed for remote sessions
610+
611+
# Clear memory
612+
if self._memory:
613+
try:
614+
# Drop this session's reference to the memory backend (stored short-term and long-term entries are not purged)
615+
# Note: This is a basic implementation - specific memory backends
616+
# might need more sophisticated cleanup
617+
self._memory = None
618+
except Exception:
619+
pass # Ignore cleanup errors
620+
621+
# Clear knowledge
622+
if self._knowledge:
623+
try:
624+
self._knowledge = None
625+
except Exception:
626+
pass # Ignore cleanup errors
627+
628+
# Clear agents
629+
self._agents.clear()
630+
631+
# Clean up session directory (optional - commented out for safety)
632+
# import shutil
633+
# session_dir = self._get_session_dir()
634+
# if session_dir.exists():
635+
# shutil.rmtree(session_dir)
636+
637+
def time_to_expiry(self) -> Optional[float]:
638+
"""Get seconds until session expires, or None if no TTL set (Gap 2b)."""
639+
if self.session_ttl is None:
640+
return None
641+
elapsed = time.time() - self._created_at
642+
return max(0, self.session_ttl - elapsed)
643+
583644
def __str__(self) -> str:
584645
if self.is_remote:
585646
return f"Session(id='{self.session_id}', user='{self.user_id}', remote_agent='{self.agent_url}')"

src/praisonai-agents/praisonaiagents/workflows/workflows.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -585,6 +585,8 @@ class AgentFlow:
585585
_reflection_config: Optional[Any] = field(default=None, repr=False)
586586
# Execution history for debugging (only populated when history=True)
587587
_execution_history: List[Dict[str, Any]] = field(default_factory=list, repr=False)
588+
# Gap 3c: Cross-step handoff cycle detection
589+
_handoff_chain: List[str] = field(default_factory=list, repr=False)
588590

589591
def __post_init__(self):
590592
"""Resolve consolidated params to internal values."""
@@ -899,6 +901,30 @@ def from_template(
899901
"Install with: pip install praisonai"
900902
)
901903

904+
def _check_handoff_cycle(self, step: Any) -> None:
905+
"""Check for cross-step handoff cycles (Gap 3c fix)."""
906+
step_id = getattr(step, 'name', str(step))
907+
908+
# If step involves handoff (basic check for agent steps)
909+
if hasattr(step, 'agent') and step.agent:
910+
# Track this step in handoff chain
911+
if step_id in self._handoff_chain:
912+
# Cycle detected!
913+
cycle_path = self._handoff_chain[self._handoff_chain.index(step_id):] + [step_id]
914+
from ..errors import HandoffCycleError
915+
raise HandoffCycleError(
916+
f"Cross-step handoff cycle detected: {' -> '.join(cycle_path)}",
917+
cycle_path=cycle_path
918+
)
919+
920+
# Add to chain (limit chain length to prevent memory issues)
921+
self._handoff_chain.append(step_id)
922+
if len(self._handoff_chain) > 100: # Reasonable limit
923+
self._handoff_chain = self._handoff_chain[-50:] # Keep last 50
924+
else:
925+
# Non-agent step, reset chain
926+
self._handoff_chain.clear()
927+
902928
def run(
903929
self,
904930
input: str = "",
@@ -1096,6 +1122,9 @@ def run(
10961122
except Exception as e:
10971123
logger.error(f"should_run failed for {step.name}: {e}")
10981124

1125+
# Gap 3c: Check for cross-step handoff cycles
1126+
self._check_handoff_cycle(step)
1127+
10991128
# Execute step with retry and guardrail support
11001129
output = None
11011130
stop = False
@@ -1273,6 +1302,25 @@ def run(
12731302
except Exception as e:
12741303
step_error = e
12751304
output = f"Error: {e}"
1305+
1306+
# Gap 3b fix: Check if error is retryable and implement exponential backoff
1307+
is_retryable = getattr(e, 'is_retryable', True) # Default to retryable
1308+
if not is_retryable:
1309+
# Non-retryable error - break out of retry loop immediately
1310+
if verbose:
1311+
print(f"❌ {step.name} failed with non-retryable error: {e}")
1312+
break
1313+
1314+
retry_count += 1
1315+
if retry_count <= max_retries:
1316+
# Exponential backoff: wait 2^(retry_count-1) seconds
1317+
import time
1318+
backoff_seconds = 2 ** (retry_count - 1)
1319+
if verbose:
1320+
print(f"🔄 {step.name} failed (attempt {retry_count}/{max_retries}), retrying in {backoff_seconds}s: {e}")
1321+
time.sleep(backoff_seconds)
1322+
continue # Retry
1323+
12761324
if self.on_step_error:
12771325
try:
12781326
self.on_step_error(step.name, e)
@@ -2306,7 +2354,7 @@ def execute_with_branch(step=step, idx=idx, opt_prev=optimized_previous):
23062354
emitter.set_branch(f"parallel_{idx}")
23072355
try:
23082356
return self._execute_single_step_internal(
2309-
step, opt_prev, input, all_variables.copy(), model, False, idx, stream, depth=depth+1
2357+
step, opt_prev, input, copy.deepcopy(all_variables), model, False, idx, stream, depth=depth+1
23102358
)
23112359
finally:
23122360
emitter.clear_branch()

0 commit comments

Comments
 (0)