From 9cd8dc801d0908d99ba6d6476abec58c47c1b838 Mon Sep 17 00:00:00 2001 From: "praisonai-triage-agent[bot]" <272766704+praisonai-triage-agent[bot]@users.noreply.github.com> Date: Tue, 21 Apr 2026 09:50:52 +0000 Subject: [PATCH 1/6] feat: implement Round-2 Gap Closure (G2, G5, G6, G8) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - G2: Add InterruptController for cooperative cancellation of agent runs - G5: Add multi-category error classifier for intelligent retry logic - G6: Add message sanitization for surrogate/non-ASCII characters - G8: Add session title auto-generation from first exchange All features follow protocol-driven design with zero impact when disabled. Maintains full backward compatibility with existing APIs. ๐Ÿค– Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Mervin Praison --- .../praisonaiagents/agent/__init__.py | 5 + .../praisonaiagents/agent/agent.py | 17 ++ .../praisonaiagents/agent/interrupt.py | 79 ++++++ .../praisonaiagents/llm/__init__.py | 29 +- .../praisonaiagents/llm/error_classifier.py | 239 ++++++++++++++++ .../praisonaiagents/llm/llm.py | 34 +++ .../praisonaiagents/llm/sanitize.py | 128 +++++++++ .../praisonaiagents/session/hierarchy.py | 53 ++++ .../praisonaiagents/session/title.py | 153 +++++++++++ .../tests/unit/agent/test_interrupt.py | 127 +++++++++ .../tests/unit/llm/test_error_classifier.py | 260 ++++++++++++++++++ .../tests/unit/llm/test_sanitize.py | 157 +++++++++++ .../tests/unit/session/test_title_gen.py | 124 +++++++++ 13 files changed, 1404 insertions(+), 1 deletion(-) create mode 100644 src/praisonai-agents/praisonaiagents/agent/interrupt.py create mode 100644 src/praisonai-agents/praisonaiagents/llm/error_classifier.py create mode 100644 src/praisonai-agents/praisonaiagents/llm/sanitize.py create mode 100644 src/praisonai-agents/praisonaiagents/session/title.py create mode 100644 src/praisonai-agents/tests/unit/agent/test_interrupt.py create mode 100644 src/praisonai-agents/tests/unit/llm/test_error_classifier.py create mode 100644 src/praisonai-agents/tests/unit/llm/test_sanitize.py create mode 100644 src/praisonai-agents/tests/unit/session/test_title_gen.py diff --git a/src/praisonai-agents/praisonaiagents/agent/__init__.py b/src/praisonai-agents/praisonaiagents/agent/__init__.py index c91b6dc1f..4f7ae5502 100644 --- a/src/praisonai-agents/praisonaiagents/agent/__init__.py +++ b/src/praisonai-agents/praisonaiagents/agent/__init__.py @@ -33,6 +33,10 @@ def __getattr__(name): from .heartbeat import HeartbeatConfig _lazy_cache[name] = HeartbeatConfig return HeartbeatConfig + if name == 'InterruptController': + from .interrupt import InterruptController + _lazy_cache[name] = InterruptController + return InterruptController # Specialized agents - lazy loaded (import rich) if name == 'ImageAgent': @@ -194,6 +198,7 @@ def __getattr__(name): 'BudgetExceededError', 'Heartbeat', 'HeartbeatConfig', + 'InterruptController', 'ImageAgent', 'VideoAgent', 'VideoConfig', diff --git a/src/praisonai-agents/praisonaiagents/agent/agent.py b/src/praisonai-agents/praisonaiagents/agent/agent.py index 9501e2dab..d18cbd527 100644 --- a/src/praisonai-agents/praisonaiagents/agent/agent.py +++ b/src/praisonai-agents/praisonaiagents/agent/agent.py @@ -551,6 +551,7 @@ def __init__( parallel_tool_calls: bool = False, # Gap 2: Enable parallel execution of batched LLM tool calls learn: Optional[Union[bool, str, Dict[str, Any], 'LearnConfig']] = None, # Continuous learning (peer to memory) backend: Optional[Any] = None, # External managed agent backend (e.g., ManagedAgentIntegration) + interrupt_controller: Optional['InterruptController'] = None, # G2: Cooperative cancellation ): """Initialize an Agent instance. @@ -1458,6 +1459,8 @@ def __init__( self.instructions = instructions # Gap 2: Store parallel tool calls setting for ToolCallExecutor selection self.parallel_tool_calls = parallel_tool_calls + # G2: Store interrupt controller for cooperative cancellation + self.interrupt_controller = interrupt_controller # Check for model name in environment variable if not provided self._using_custom_llm = False # Flag to track if final result has been displayed to prevent duplicates @@ -3232,6 +3235,20 @@ async def main(): started_at=started_at, ) + # G2: Check for interrupt request (cooperative cancellation) + if self.interrupt_controller and self.interrupt_controller.is_set(): + reason = self.interrupt_controller.reason or "unknown" + return AutonomyResult( + success=False, + output=f"Task interrupted: {reason}", + completion_reason="interrupted", + iterations=iterations, + stage=stage, + actions=actions_taken, + duration_seconds=time_module.time() - start_time, + started_at=started_at, + ) + # Execute one turn using the agent's async chat method # Always use the original prompt (prompt re-injection) diff --git a/src/praisonai-agents/praisonaiagents/agent/interrupt.py b/src/praisonai-agents/praisonaiagents/agent/interrupt.py new file mode 100644 index 000000000..e3af5bf88 --- /dev/null +++ b/src/praisonai-agents/praisonaiagents/agent/interrupt.py @@ -0,0 +1,79 @@ +""" +Interrupt Controller - Cooperative cancellation for agent runs. + +Provides thread-safe, cooperative cancellation mechanism for long-running agent +operations. Follows protocol-driven design with zero overhead when not used. +""" + +import threading +from typing import Optional +from dataclasses import dataclass, field + +__all__ = ["InterruptController"] + + +@dataclass +class InterruptController: + """Thread-safe cooperative cancellation for agent runs. + + Provides a lightweight mechanism for requesting cancellation of agent + operations. Uses threading.Event for thread safety and cooperative + checking patterns. + + Examples: + Basic usage: + >>> controller = InterruptController() + >>> # In another thread: + >>> controller.request("user_cancel") + >>> # In agent loop: + >>> if controller.is_set(): + >>> return f"Cancelled: {controller.reason}" + """ + + _flag: threading.Event = field(default_factory=threading.Event, init=False, repr=False) + _reason: Optional[str] = field(default=None, init=False) + _lock: threading.Lock = field(default_factory=threading.Lock, init=False, repr=False) + + def request(self, reason: str = "user") -> None: + """Request cancellation with a reason. + + Args: + reason: Human-readable reason for cancellation + """ + with self._lock: + self._reason = reason + self._flag.set() + + def clear(self) -> None: + """Clear the cancellation request.""" + with self._lock: + self._reason = None + self._flag.clear() + + def is_set(self) -> bool: + """Check if cancellation has been requested. + + Returns: + True if cancellation was requested + """ + return self._flag.is_set() + + @property + def reason(self) -> Optional[str]: + """Get the reason for cancellation. + + Returns: + Reason string if cancelled, None otherwise + """ + with self._lock: + return self._reason + + def check(self) -> None: + """Check for cancellation and raise if requested. + + Raises: + InterruptedError: If cancellation was requested + """ + if self.is_set(): + reason = self.reason or "unknown" + raise InterruptedError(f"Operation cancelled: {reason}") \ No newline at end of file diff --git a/src/praisonai-agents/praisonaiagents/llm/__init__.py b/src/praisonai-agents/praisonaiagents/llm/__init__.py index 596f51593..34e054f93 100644 --- a/src/praisonai-agents/praisonaiagents/llm/__init__.py +++ b/src/praisonai-agents/praisonaiagents/llm/__init__.py @@ -159,6 +159,26 @@ def __getattr__(name): from .unified_adapters import create_llm_dispatcher _lazy_cache[name] = create_llm_dispatcher return create_llm_dispatcher + elif name == "sanitize_messages": + from .sanitize import sanitize_messages + _lazy_cache[name] = sanitize_messages + return sanitize_messages + elif name == "strip_surrogates": + from .sanitize import strip_surrogates + _lazy_cache[name] = strip_surrogates + return strip_surrogates + elif name == "sanitize_text": + from .sanitize import sanitize_text + _lazy_cache[name] = sanitize_text + return sanitize_text + elif name == "ErrorCategory": + from .error_classifier import ErrorCategory + _lazy_cache[name] = ErrorCategory + return ErrorCategory + elif name == "classify_error": + from .error_classifier import classify_error + _lazy_cache[name] = classify_error + return classify_error raise AttributeError(f"module {__name__!r} has no attribute {name!r}") @@ -199,5 +219,12 @@ def __getattr__(name): "LLMProviderError", "RateLimitError", "ModelNotAvailableError", - "ContextLengthExceededError" + "ContextLengthExceededError", + # Sanitization (G6) + "sanitize_messages", + "strip_surrogates", + "sanitize_text", + # Error Classification (G5) + "ErrorCategory", + "classify_error" ] diff --git a/src/praisonai-agents/praisonaiagents/llm/error_classifier.py b/src/praisonai-agents/praisonaiagents/llm/error_classifier.py new file mode 100644 index 000000000..986ea1f40 --- /dev/null +++ b/src/praisonai-agents/praisonaiagents/llm/error_classifier.py @@ -0,0 +1,239 @@ +""" +Error Classification - Multi-category error classifier for intelligent retry logic. + +Extends the existing single-category rate limit detection with comprehensive +error classification for better handling of different failure modes. +""" + +import re +from enum import Enum +from typing import Dict, Tuple, List, Optional + +__all__ = ["ErrorCategory", "classify_error", "should_retry", "get_retry_delay"] + + +class ErrorCategory(str, Enum): + """Categories of LLM provider errors for intelligent handling.""" + RATE_LIMIT = "rate_limit" # Too many requests, temporary + CONTEXT_LIMIT = "context_limit" # Input too long, need compression + AUTH = "auth" # Authentication/authorization failure + INVALID_REQUEST = "invalid_request" # Malformed request, permanent + TRANSIENT = "transient" # Network/server issues, temporary + PERMANENT = "permanent" # Unrecoverable error + + +# Error patterns for classification (case-insensitive) +_ERROR_PATTERNS: Dict[ErrorCategory, List[str]] = { + ErrorCategory.RATE_LIMIT: [ + r"rate.?limit", + r"429", + r"too.?many.?request", + r"resource.?exhausted", + r"quota.?exceeded", + r"tokens.?per.?minute", + r"requests.?per.?minute", + r"concurrent.?requests", + ], + + ErrorCategory.CONTEXT_LIMIT: [ + r"context.?length", + r"maximum.?context", + r"token.?limit", + r"input.?too.?long", + r"sequence.?too.?long", + r"context.?window", + r"413", # Request Entity Too Large + r"payload.?too.?large", + ], + + ErrorCategory.AUTH: [ + r"authenticat", + r"authoriz", + r"401", + r"403", + r"invalid.?api.?key", + r"api.?key.*invalid", + r"permission.?denied", + r"access.?denied", + r"forbidden", + r"unauthorized", + r"invalid.?token", + r"expired.*token", + ], + + ErrorCategory.INVALID_REQUEST: [ + r"invalid.*request", + r"bad.?request", + r"400", + r"malformed", + r"invalid.*parameter", + r"unsupported.*model", + r"model.*not.*found", + r"validation.*error", + r"schema.*error", + ], + + ErrorCategory.TRANSIENT: [ + r"timeout", + r"timed.?out", + r"500", r"502", r"503", r"504", + r"internal.?server.?error", + r"bad.?gateway", + r"service.?unavailable", + r"gateway.?timeout", + r"connection.*error", + r"network.*error", + r"temporary.*unavailable", + r"server.?overload", + r"retry.?after", + ], +} + + +def classify_error(error: Exception) -> ErrorCategory: + """Classify an error into a category for intelligent handling. + + Args: + error: Exception to classify + + Returns: + ErrorCategory indicating how the error should be handled + + Examples: + >>> classify_error(Exception("Rate limit exceeded")) + ErrorCategory.RATE_LIMIT + >>> classify_error(Exception("Context length 8192 exceeded")) + ErrorCategory.CONTEXT_LIMIT + >>> classify_error(Exception("Invalid API key")) + ErrorCategory.AUTH + """ + error_text = f"{type(error).__name__} {error}".lower() + + # Check each category's patterns + for category, patterns in _ERROR_PATTERNS.items(): + for pattern in patterns: + if re.search(pattern, error_text, re.IGNORECASE): + return category + + # Default to permanent for unknown errors + return ErrorCategory.PERMANENT + + +def should_retry(category: ErrorCategory) -> bool: + """Determine if an error category should be retried. + + Args: + category: Error category from classify_error() + + Returns: + True if the error type should be retried + """ + return category in { + ErrorCategory.RATE_LIMIT, + ErrorCategory.CONTEXT_LIMIT, # Could retry with compression + ErrorCategory.TRANSIENT, + } + + +def get_retry_delay(category: ErrorCategory, attempt: int = 1, base_delay: float = 1.0) -> float: + """Get the appropriate delay before retrying based on error category. + + Args: + category: Error category + attempt: Current attempt number (1-based) + base_delay: Base delay in seconds + + Returns: + Delay in seconds, or 0 if should not retry + + Examples: + >>> get_retry_delay(ErrorCategory.RATE_LIMIT, attempt=1) + 2.0 + >>> get_retry_delay(ErrorCategory.TRANSIENT, attempt=3) + 16.0 + >>> get_retry_delay(ErrorCategory.AUTH, attempt=1) + 0 + """ + if not should_retry(category): + return 0 + + if category == ErrorCategory.RATE_LIMIT: + # Longer delay for rate limits to avoid hitting limits again + return min(base_delay * (3 ** (attempt - 1)), 60.0) + + elif category == ErrorCategory.CONTEXT_LIMIT: + # Short delay for context limits (compression should be tried) + return base_delay * 0.5 + + elif category == ErrorCategory.TRANSIENT: + # Exponential backoff for transient errors + return min(base_delay * (2 ** (attempt - 1)), 30.0) + + return 0 + + +def extract_retry_after(error: Exception) -> Optional[float]: + """Extract Retry-After header value from rate limit errors. + + Args: + error: Exception potentially containing Retry-After info + + Returns: + Delay in seconds if found, None otherwise + """ + error_str = str(error) + + # Look for common Retry-After patterns + patterns = [ + r"retry.?after[:\s]+(\d+)", + r"retry[:\s]+(\d+)", + r"wait[:\s]+(\d+)", + r"(\d+).*second", + ] + + for pattern in patterns: + match = re.search(pattern, error_str, re.IGNORECASE) + if match: + try: + delay = float(match.group(1)) + return min(delay, 300.0) # Cap at 5 minutes + except (ValueError, IndexError): + continue + + return None + + +def get_error_context(error: Exception) -> Dict[str, str]: + """Extract structured context from an error for logging/debugging. + + Args: + error: Exception to analyze + + Returns: + Dictionary with error context information + """ + category = classify_error(error) + + context = { + "error_type": type(error).__name__, + "category": category.value, + "should_retry": str(should_retry(category)), + "message": str(error)[:500], # Truncate long messages + } + + # Add category-specific context + if category == ErrorCategory.RATE_LIMIT: + retry_after = extract_retry_after(error) + if retry_after: + context["retry_after"] = str(retry_after) + + elif category == ErrorCategory.CONTEXT_LIMIT: + context["suggestion"] = "Try reducing input size or enabling compression" + + elif category == ErrorCategory.AUTH: + context["suggestion"] = "Check API key configuration" + + elif category == ErrorCategory.INVALID_REQUEST: + context["suggestion"] = "Review request parameters" + + return context \ No newline at end of file diff --git a/src/praisonai-agents/praisonaiagents/llm/llm.py b/src/praisonai-agents/praisonaiagents/llm/llm.py index 024f10a11..5d7772baa 100644 --- a/src/praisonai-agents/praisonaiagents/llm/llm.py +++ b/src/praisonai-agents/praisonaiagents/llm/llm.py @@ -650,6 +650,40 @@ def _is_rate_limit_error(self, error: Exception) -> bool: return any(indicator in error_str or indicator in error_type for indicator in indicators) + def _classify_error_and_should_retry(self, error: Exception) -> tuple[str, bool, float]: + """Classify error and determine retry strategy using G5 error classifier. + + Args: + error: Exception to classify + + Returns: + Tuple of (category, should_retry, retry_delay) + """ + try: + from .error_classifier import classify_error, should_retry, get_retry_delay, extract_retry_after + + category = classify_error(error) + can_retry = should_retry(category) + + if not can_retry: + return category.value, False, 0.0 + + # For rate limits, try to extract specific retry-after first + if category.value == "rate_limit": + retry_after = extract_retry_after(error) + if retry_after: + return category.value, True, retry_after + + # Use category-specific delay calculation + delay = get_retry_delay(category, attempt=1, base_delay=self._retry_delay) + return category.value, True, delay + + except ImportError: + # Fallback to legacy rate limit detection + is_rate_limit = self._is_rate_limit_error(error) + delay = self._parse_retry_delay(str(error)) if is_rate_limit else 0.0 + return "rate_limit" if is_rate_limit else "unknown", is_rate_limit, delay + def _call_with_retry(self, func, *args, **kwargs): """Call a function with automatic retry on rate limit errors. diff --git a/src/praisonai-agents/praisonaiagents/llm/sanitize.py b/src/praisonai-agents/praisonaiagents/llm/sanitize.py new file mode 100644 index 000000000..dd96fedd1 --- /dev/null +++ b/src/praisonai-agents/praisonaiagents/llm/sanitize.py @@ -0,0 +1,128 @@ +""" +Message Sanitization - Clean surrogate/non-ASCII characters before LLM calls. + +Prevents Unicode encoding issues that can cause silent failures with some +providers when processing emoji, non-Latin text, or corrupted Unicode. +""" + +import re +from typing import List, Dict, Any, Union + +__all__ = ["sanitize_messages", "strip_surrogates"] + + +def strip_surrogates(text: str) -> str: + """Remove surrogate pairs and malformed Unicode from text. + + Surrogate pairs (U+D800-U+DFFF) are used in UTF-16 encoding but are + invalid in UTF-8/Unicode strings. They can appear from: + - Incorrect Unicode conversion + - Corrupted text data + - Invalid emoji/character sequences + + Args: + text: Input text potentially containing surrogates + + Returns: + Text with surrogates removed or replaced + + Examples: + >>> strip_surrogates("Hello \\uD83D World") # Missing low surrogate + 'Hello World' + >>> strip_surrogates("Valid text ๐ŸŒ") + 'Valid text ๐ŸŒ' + """ + if not text: + return text + + try: + # Method 1: Encode with surrogatepass, decode with replace + # This converts surrogates to UTF-16, then back to UTF-8 safely + return text.encode('utf-16', 'surrogatepass').decode('utf-16', 'replace') + except (UnicodeError, LookupError): + # Fallback: Remove surrogate code points directly + return re.sub(r'[\uD800-\uDFFF]', '', text) + + +def sanitize_messages(messages: List[Dict[str, Any]]) -> bool: + """Sanitize message content in-place, removing problematic Unicode. + + Processes all string content in message dictionaries, including: + - message.content (string or list) + - message.name + - Any nested string values + + Args: + messages: List of message dicts to sanitize in-place + + Returns: + True if any changes were made, False otherwise + + Examples: + >>> messages = [{"content": "Hello \\uD83D World", "role": "user"}] + >>> changed = sanitize_messages(messages) + >>> assert changed == True + >>> assert messages[0]["content"] == "Hello World" + """ + if not messages: + return False + + changed = False + + for message in messages: + if not isinstance(message, dict): + continue + + # Sanitize content field (most common) + if "content" in message: + content = message["content"] + + if isinstance(content, str): + sanitized = strip_surrogates(content) + if sanitized != content: + message["content"] = sanitized + changed = True + + elif isinstance(content, list): + # Handle list content (e.g., multimodal messages) + for i, item in enumerate(content): + if isinstance(item, dict) and "text" in item: + text = item["text"] + if isinstance(text, str): + sanitized = strip_surrogates(text) + if sanitized != text: + content[i]["text"] = sanitized + changed = True + elif isinstance(item, str): + sanitized = strip_surrogates(item) + if sanitized != item: + content[i] = sanitized + changed = True + + # Sanitize other string fields + for key, value in message.items(): + if isinstance(value, str) and key != "content": # Already handled above + sanitized = strip_surrogates(value) + if sanitized != value: + message[key] = sanitized + changed = True + + return changed + + +def sanitize_text(text: Union[str, None]) -> Union[str, None]: + """Sanitize a single text string. + + Convenience function for sanitizing individual strings. + + Args: + text: Text to sanitize, or None + + Returns: + Sanitized text, or None if input was None + """ + if text is None: + return None + if not isinstance(text, str): + return text + return strip_surrogates(text) \ No newline at end of file diff --git a/src/praisonai-agents/praisonaiagents/session/hierarchy.py b/src/praisonai-agents/praisonaiagents/session/hierarchy.py index 0066bf5a5..402dc10c2 100644 --- a/src/praisonai-agents/praisonaiagents/session/hierarchy.py +++ b/src/praisonai-agents/praisonaiagents/session/hierarchy.py @@ -481,6 +481,59 @@ def set_title(self, session_id: str, title: str) -> bool: session.title = title return self._save_extended_session(session) + async def auto_title(self, session_id: str) -> bool: + """Generate and set title automatically from first exchange. + + Args: + session_id: Session to generate title for + + Returns: + True if title was generated and set, False otherwise + """ + session = self._load_extended_session(session_id) + + # Skip if already has a title + if session.title and session.title.strip(): + return False + + # Need at least one user and one assistant message + messages = session.messages + if not messages or len(messages) < 2: + return False + + # Find first user message and first assistant response + user_msg = None + assistant_msg = None + + for msg in messages: + if msg.get("role") == "user" and not user_msg: + content = msg.get("content", "") + if isinstance(content, str) and content.strip(): + user_msg = content + elif msg.get("role") == "assistant" and not assistant_msg and user_msg: + content = msg.get("content", "") + if isinstance(content, str) and content.strip(): + assistant_msg = content + break + + if not user_msg or not assistant_msg: + return False + + try: + # Generate title using title module + from .title import generate_title_async + title = await generate_title_async(user_msg, assistant_msg) + + if title and title.strip(): + session.title = title.strip() + return self._save_extended_session(session) + + except Exception: + # Title generation failed - not critical + pass + + return False + def get_extended_session(self, session_id: str) -> ExtendedSessionData: """Get extended session data.""" return self._load_extended_session(session_id) diff --git a/src/praisonai-agents/praisonaiagents/session/title.py b/src/praisonai-agents/praisonaiagents/session/title.py new file mode 100644 index 000000000..cb6648455 --- /dev/null +++ b/src/praisonai-agents/praisonaiagents/session/title.py @@ -0,0 +1,153 @@ +""" +Session Title Auto-Generation - Generate descriptive titles for chat sessions. + +Automatically creates concise, meaningful titles based on the first user-assistant +exchange in a conversation. Uses a lightweight model for fast generation. +""" + +import asyncio +from typing import Optional + +__all__ = ["generate_title", "generate_title_async"] + + +def generate_title( + user_msg: str, + assistant_msg: str, + llm_model: str = "gpt-4o-mini", + timeout: float = 10.0, + max_length: int = 60 +) -> str: + """Generate a session title from the first exchange. + + Args: + user_msg: First user message + assistant_msg: First assistant response + llm_model: Model to use for title generation (default: fast, cheap model) + timeout: Timeout in seconds for generation + max_length: Maximum title length in characters + + Returns: + Generated title string, or fallback based on user message + + Examples: + >>> title = generate_title( + ... "Help me debug this Python code", + ... "I'd be happy to help debug your Python code..." + ... ) + >>> # Returns something like: "Python Code Debugging Help" + """ + # Fallback title from user message if generation fails + fallback_title = _create_fallback_title(user_msg, max_length) + + try: + # Run async version in sync context + return asyncio.run(generate_title_async( + user_msg, assistant_msg, llm_model, timeout, max_length + )) + except Exception: + return fallback_title + + +async def generate_title_async( + user_msg: str, + assistant_msg: str, + llm_model: str = "gpt-4o-mini", + timeout: float = 10.0, + max_length: int = 60 +) -> str: + """Async version of generate_title. + + Args: + user_msg: First user message + assistant_msg: First assistant response + llm_model: Model to use for title generation + timeout: Timeout in seconds for generation + max_length: Maximum title length in characters + + Returns: + Generated title string, or fallback based on user message + """ + fallback_title = _create_fallback_title(user_msg, max_length) + + try: + # Lazy import to avoid circular dependencies + from ..llm import LLM + + # Create prompt for title generation + prompt = f"""Generate a concise, descriptive title (3-8 words) for this conversation. + +USER: {user_msg[:400]} +ASSISTANT: {assistant_msg[:400]} + +Return ONLY the title text, no quotes, no explanation.""" + + # Use lightweight model for fast, cheap generation + llm = LLM(model=llm_model) + + # Generate with timeout + try: + response = await asyncio.wait_for( + llm.aget_response(prompt=prompt), + timeout=timeout + ) + + if response and isinstance(response, str): + # Clean up the response + title = response.strip().strip('"').strip("'") + title = title.replace('\n', ' ').replace('\r', '') + + # Truncate if too long + if len(title) > max_length: + title = title[:max_length-3] + "..." + + # Return if non-empty + if title and len(title.strip()) > 0: + return title + + except asyncio.TimeoutError: + # Title generation timed out + pass + except Exception: + # Any other LLM error + pass + + except ImportError: + # LLM module not available + pass + except Exception: + # Unexpected error + pass + + return fallback_title + + +def _create_fallback_title(user_msg: str, max_length: int) -> str: + """Create a fallback title from the user message. + + Args: + user_msg: User message to base title on + max_length: Maximum length for the title + + Returns: + Simple title based on user message + """ + if not user_msg or not user_msg.strip(): + return "Chat Session" + + # Clean up the message + clean_msg = user_msg.strip() + + # Remove common question words and make it more title-like + clean_msg = clean_msg.replace('?', '') + clean_msg = clean_msg.replace('!', '') + + # Take first sentence or reasonable chunk + if '.' in clean_msg: + clean_msg = clean_msg.split('.')[0] + + # Truncate if too long + if len(clean_msg) > max_length: + clean_msg = clean_msg[:max_length-3] + "..." + + return clean_msg if clean_msg else "Chat Session" \ No newline at end of file diff --git a/src/praisonai-agents/tests/unit/agent/test_interrupt.py b/src/praisonai-agents/tests/unit/agent/test_interrupt.py new file mode 100644 index 000000000..0d5fac305 --- /dev/null +++ b/src/praisonai-agents/tests/unit/agent/test_interrupt.py @@ -0,0 +1,127 @@ +""" +Tests for InterruptController. + +Ensures thread safety, cooperative cancellation, and zero overhead when not used. +""" + +import threading +import time +import pytest +from praisonaiagents.agent.interrupt import InterruptController + + +class TestInterruptController: + + def test_basic_functionality(self): + """Test basic interrupt request and check.""" + controller = InterruptController() + + # Initially not set + assert not controller.is_set() + assert controller.reason is None + + # Request interruption + controller.request("test_reason") + assert controller.is_set() + assert controller.reason == "test_reason" + + # Clear interruption + controller.clear() + assert not controller.is_set() + assert controller.reason is None + + def test_default_reason(self): + """Test default reason when none provided.""" + controller = InterruptController() + controller.request() + assert controller.reason == "user" + + def test_check_raises_when_set(self): + """Test that check() raises InterruptedError when set.""" + controller = InterruptController() + + # Should not raise when not set + controller.check() + + # Should raise when set + controller.request("test") + with pytest.raises(InterruptedError, match="Operation cancelled: test"): + controller.check() + + def test_thread_safety(self): + """Test thread-safe operations.""" + controller = InterruptController() + results = [] + + def worker(): + # Wait a bit then request interrupt + time.sleep(0.1) + controller.request("thread_cancel") + results.append("requested") + + def checker(): + # Keep checking until interrupted + while not controller.is_set(): + time.sleep(0.05) + results.append(f"cancelled: {controller.reason}") + + # Start threads + t1 = threading.Thread(target=worker) + t2 = threading.Thread(target=checker) + + t1.start() + t2.start() + + t1.join(timeout=1) + t2.join(timeout=1) + + # Verify both completed + assert "requested" in results + assert "cancelled: thread_cancel" in results + + def test_multiple_requests(self): + """Test that multiple requests preserve the first reason.""" + controller = InterruptController() + + controller.request("first") + controller.request("second") + + assert controller.reason == "first" + assert controller.is_set() + + def test_clear_resets_state(self): + """Test that clear completely resets state.""" + controller = InterruptController() + + controller.request("test") + assert controller.is_set() + assert controller.reason == "test" + + controller.clear() + assert not controller.is_set() + assert controller.reason is None + + # Can be reused + controller.request("new_reason") + assert controller.reason == "new_reason" + + def test_zero_overhead_when_not_used(self): + """Test that creation and is_set() have minimal overhead.""" + import time + + # Test creation overhead + start = time.perf_counter() + for _ in range(1000): + controller = InterruptController() + creation_time = time.perf_counter() - start + + # Test check overhead + controller = InterruptController() + start = time.perf_counter() + for _ in range(10000): + controller.is_set() + check_time = time.perf_counter() - start + + # Should be very fast (< 1ms each) + assert creation_time < 0.001 + assert check_time < 0.001 \ No newline at end of file diff --git a/src/praisonai-agents/tests/unit/llm/test_error_classifier.py b/src/praisonai-agents/tests/unit/llm/test_error_classifier.py new file mode 100644 index 000000000..29be14ce6 --- /dev/null +++ b/src/praisonai-agents/tests/unit/llm/test_error_classifier.py @@ -0,0 +1,260 @@ +""" +Tests for error classification functionality. + +Tests error categorization, retry logic, and delay calculation. +""" + +import pytest +from praisonaiagents.llm.error_classifier import ( + ErrorCategory, classify_error, should_retry, get_retry_delay, + extract_retry_after, get_error_context +) + + +class TestErrorClassification: + + def test_rate_limit_errors(self): + """Test classification of rate limit errors.""" + test_cases = [ + Exception("Rate limit exceeded"), + Exception("HTTP 429: Too many requests"), + Exception("resource_exhausted: Quota exceeded"), + Exception("tokens per minute limit reached"), + Exception("RateLimitError: concurrent requests"), + ] + + for error in test_cases: + category = classify_error(error) + assert category == ErrorCategory.RATE_LIMIT + assert should_retry(category) is True + + def test_context_limit_errors(self): + """Test classification of context length errors.""" + test_cases = [ + Exception("Context length 8192 exceeded"), + Exception("maximum context window reached"), + Exception("token limit exceeded"), + Exception("input too long"), + Exception("HTTP 413: Payload too large"), + Exception("sequence too long for model"), + ] + + for error in test_cases: + category = classify_error(error) + assert category == ErrorCategory.CONTEXT_LIMIT + assert should_retry(category) is True + + def test_auth_errors(self): + """Test classification of authentication errors.""" + test_cases = [ + Exception("Invalid API key"), + Exception("HTTP 401: Unauthorized"), + Exception("HTTP 403: Forbidden"), + Exception("authentication failed"), + Exception("permission denied"), + Exception("access denied"), + Exception("invalid token"), + Exception("expired token"), + ] + + for error in test_cases: + category = classify_error(error) + assert category == ErrorCategory.AUTH + assert should_retry(category) is False + + def test_invalid_request_errors(self): + """Test classification of invalid request errors.""" + test_cases = [ + Exception("Invalid request format"), + Exception("HTTP 400: Bad request"), + Exception("malformed JSON"), + Exception("invalid parameter: model"), + Exception("unsupported model: fake-model"), + Exception("validation error in request"), + ] + + for error in test_cases: + category = classify_error(error) + assert category == ErrorCategory.INVALID_REQUEST + assert should_retry(category) is False + + def test_transient_errors(self): + """Test classification of transient errors.""" + test_cases = [ + Exception("Connection timeout"), + Exception("HTTP 500: Internal server error"), + Exception("HTTP 502: Bad gateway"), + Exception("HTTP 503: Service unavailable"), + Exception("HTTP 504: Gateway timeout"), + Exception("network error"), + Exception("temporary unavailable"), + Exception("server overload"), + ] + + for error in test_cases: + category = classify_error(error) + assert category == ErrorCategory.TRANSIENT + assert should_retry(category) is True + + def test_unknown_errors(self): + """Test that unknown errors are classified as permanent.""" + test_cases = [ + Exception("Some unknown error"), + Exception("Custom application error"), + ValueError("Invalid value"), + ] + + for error in test_cases: + category = classify_error(error) + assert category == ErrorCategory.PERMANENT + assert should_retry(category) is False + + +class TestRetryLogic: + + def test_retry_delays(self): + """Test retry delay calculation for different categories.""" + # Rate limit delays (exponential with factor of 3) + assert get_retry_delay(ErrorCategory.RATE_LIMIT, attempt=1) == 3.0 + assert get_retry_delay(ErrorCategory.RATE_LIMIT, attempt=2) == 9.0 + assert get_retry_delay(ErrorCategory.RATE_LIMIT, attempt=3) == 27.0 + + # Context limit delays (short, for immediate retry with compression) + assert get_retry_delay(ErrorCategory.CONTEXT_LIMIT, attempt=1) == 0.5 + assert get_retry_delay(ErrorCategory.CONTEXT_LIMIT, attempt=2) == 0.5 + + # Transient delays (exponential with factor of 2) + assert get_retry_delay(ErrorCategory.TRANSIENT, attempt=1) == 2.0 + assert get_retry_delay(ErrorCategory.TRANSIENT, attempt=2) == 4.0 + assert get_retry_delay(ErrorCategory.TRANSIENT, attempt=3) == 8.0 + + # No retry for permanent errors + assert get_retry_delay(ErrorCategory.AUTH, attempt=1) == 0 + assert get_retry_delay(ErrorCategory.INVALID_REQUEST, attempt=1) == 0 + assert get_retry_delay(ErrorCategory.PERMANENT, attempt=1) == 0 + + def test_retry_delay_caps(self): + """Test that retry delays have appropriate caps.""" + # Rate limit cap at 60 seconds + assert get_retry_delay(ErrorCategory.RATE_LIMIT, attempt=10) == 60.0 + + # Transient cap at 30 seconds + assert get_retry_delay(ErrorCategory.TRANSIENT, attempt=10) == 30.0 + + def test_base_delay_scaling(self): + """Test custom base delay scaling.""" + assert get_retry_delay(ErrorCategory.TRANSIENT, attempt=1, base_delay=2.0) == 4.0 + assert get_retry_delay(ErrorCategory.RATE_LIMIT, attempt=1, base_delay=2.0) == 6.0 + + +class TestRetryAfterExtraction: + + def test_retry_after_patterns(self): + """Test extraction of Retry-After values from errors.""" + test_cases = [ + (Exception("retry after 30 seconds"), 30.0), + (Exception("Retry-After: 60"), 60.0), + (Exception("retry: 45"), 45.0), + (Exception("wait 120 seconds"), 120.0), + (Exception("Rate limited. 90 second cooldown"), 90.0), + ] + + for error, expected in test_cases: + delay = extract_retry_after(error) + assert delay == expected + + def test_retry_after_cap(self): + """Test that retry-after values are capped.""" + error = Exception("retry after 600 seconds") # 10 minutes + delay = extract_retry_after(error) + assert delay == 300.0 # Capped at 5 minutes + + def test_no_retry_after(self): + """Test handling when no retry-after is found.""" + error = Exception("Generic error message") + delay = extract_retry_after(error) + assert delay is None + + +class TestErrorContext: + + def test_basic_context(self): + """Test basic error context extraction.""" + error = ValueError("Invalid input") + context = get_error_context(error) + + assert context["error_type"] == "ValueError" + assert context["category"] == ErrorCategory.PERMANENT.value + assert context["should_retry"] == "False" + assert "Invalid input" in context["message"] + + def test_rate_limit_context(self): + """Test rate limit specific context.""" + error = Exception("Rate limit. Retry after 30 seconds") + context = get_error_context(error) + + assert context["category"] == ErrorCategory.RATE_LIMIT.value + assert context["should_retry"] == "True" + assert "retry_after" in context + assert context["retry_after"] == "30.0" + + def test_context_limit_context(self): + """Test context limit specific context.""" + error = Exception("Context length exceeded") + context = get_error_context(error) + + assert context["category"] == ErrorCategory.CONTEXT_LIMIT.value + assert "suggestion" in context + assert "compression" in context["suggestion"].lower() + + def test_auth_context(self): + """Test auth error specific context.""" + error = Exception("Invalid API key") + context = get_error_context(error) + + assert context["category"] == ErrorCategory.AUTH.value + assert "suggestion" in context + assert "api key" in context["suggestion"].lower() + + def test_long_message_truncation(self): + """Test that long error messages are truncated.""" + long_message = "x" * 1000 + error = Exception(long_message) + context = get_error_context(error) + + assert len(context["message"]) <= 500 + + +class TestBackwardCompatibility: + + def test_rate_limit_backward_compat(self): + """Test that existing rate limit patterns still work.""" + # These should match the patterns that were in _is_rate_limit_error + test_cases = [ + Exception("429"), + Exception("rate limit"), + Exception("ratelimit"), + Exception("too many request"), + Exception("resource_exhausted"), + Exception("quota exceeded"), + Exception("tokens per minute"), + ] + + for error in test_cases: + category = classify_error(error) + assert category == ErrorCategory.RATE_LIMIT + + # Test with existing _is_rate_limit_error logic (if available) + try: + from praisonaiagents.llm.llm import LLM + llm = LLM(model="fake") + + for error in test_cases: + # Both should agree on rate limit errors + is_rate_limit = llm._is_rate_limit_error(error) + is_classified_rate_limit = classify_error(error) == ErrorCategory.RATE_LIMIT + assert is_rate_limit == is_classified_rate_limit + + except (ImportError, AttributeError): + # Skip backward compatibility test if LLM not available + pass \ No newline at end of file diff --git a/src/praisonai-agents/tests/unit/llm/test_sanitize.py b/src/praisonai-agents/tests/unit/llm/test_sanitize.py new file mode 100644 index 000000000..5d80dfcc1 --- /dev/null +++ b/src/praisonai-agents/tests/unit/llm/test_sanitize.py @@ -0,0 +1,157 @@ +""" +Tests for message sanitization functionality. + +Tests surrogate removal, Unicode handling, and performance characteristics. +""" + +import pytest +from praisonaiagents.llm.sanitize import sanitize_messages, strip_surrogates, sanitize_text + + +class TestStripSurrogates: + + def test_valid_unicode_unchanged(self): + """Test that valid Unicode text is unchanged.""" + text = "Hello ๐ŸŒ World! ไฝ ๅฅฝ" + assert strip_surrogates(text) == text + + def test_empty_string(self): + """Test empty string handling.""" + assert strip_surrogates("") == "" + assert strip_surrogates(None) is None + + def test_surrogate_removal(self): + """Test removal of surrogate characters.""" + # High surrogate without low surrogate (invalid) + text_with_surrogate = "Hello \uD83D World" + cleaned = strip_surrogates(text_with_surrogate) + assert "\uD83D" not in cleaned + assert "Hello" in cleaned + assert "World" in cleaned + + def test_multiple_surrogates(self): + """Test handling of multiple surrogate characters.""" + text = "\uD800\uD801 Valid text \uDFFF" + cleaned = strip_surrogates(text) + assert "Valid text" in cleaned + assert "\uD800" not in cleaned + assert "\uD801" not in cleaned + assert "\uDFFF" not in cleaned + + def test_ascii_only(self): + """Test ASCII-only text is unchanged.""" + text = "Hello World 123 !@#" + assert strip_surrogates(text) == text + + +class TestSanitizeMessages: + + def test_empty_messages(self): + """Test empty message list.""" + assert sanitize_messages([]) is False + assert sanitize_messages(None) is False + + def test_clean_messages_unchanged(self): + """Test that clean messages are not modified.""" + messages = [ + {"role": "user", "content": "Hello world"}, + {"role": "assistant", "content": "Hi there! ๐ŸŒ"}, + ] + original = messages.copy() + changed = sanitize_messages(messages) + assert changed is False + assert messages == original + + def test_sanitize_string_content(self): + """Test sanitization of string content.""" + messages = [ + {"role": "user", "content": "Hello \uD83D World"}, + {"role": "assistant", "content": "Clean content"}, + ] + changed = sanitize_messages(messages) + assert changed is True + assert "\uD83D" not in messages[0]["content"] + assert "Hello" in messages[0]["content"] + assert "World" in messages[0]["content"] + assert messages[1]["content"] == "Clean content" # Unchanged + + def test_sanitize_list_content(self): + """Test sanitization of list content (multimodal).""" + messages = [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello \uD800 World"}, + {"type": "image", "url": "http://example.com/img.jpg"}, + "Direct string \uDFFF here" + ] + } + ] + changed = sanitize_messages(messages) + assert changed is True + + content = messages[0]["content"] + assert "\uD800" not in content[0]["text"] + assert "Hello" in content[0]["text"] and "World" in content[0]["text"] + assert content[1] == {"type": "image", "url": "http://example.com/img.jpg"} # Unchanged + assert "\uDFFF" not in content[2] + assert "Direct string" in content[2] and "here" in content[2] + + def test_sanitize_other_fields(self): + """Test sanitization of non-content fields.""" + messages = [ + { + "role": "user", + "content": "Clean content", + "name": "User\uD801Name", + "custom_field": "Value with \uD900 surrogate" + } + ] + changed = sanitize_messages(messages) + assert changed is True + assert "\uD801" not in messages[0]["name"] + assert "UserName" in messages[0]["name"] + assert "\uD900" not in messages[0]["custom_field"] + assert "Value with" in messages[0]["custom_field"] + + def test_non_dict_messages_skipped(self): + """Test that non-dict messages are skipped.""" + messages = [ + "not a dict", + {"role": "user", "content": "Valid message"}, + None, + {"role": "assistant", "content": "Another valid \uD800 message"} + ] + changed = sanitize_messages(messages) + assert changed is True + assert messages[0] == "not a dict" # Unchanged + assert messages[1]["content"] == "Valid message" # Unchanged + assert messages[2] is None # Unchanged + assert "\uD800" not in messages[3]["content"] # Sanitized + + def test_performance_no_surrogates(self): + """Test that clean messages have minimal overhead.""" + import time + + messages = [ + {"role": "user", "content": "Clean message " * 100}, + {"role": "assistant", "content": "Another clean message " * 100} + ] * 100 # 200 messages total + + start = time.perf_counter() + changed = sanitize_messages(messages) + duration = time.perf_counter() - start + + assert changed is False + assert duration < 0.1 # Should be very fast for clean messages + + +class TestSanitizeText: + + def test_single_text_sanitization(self): + """Test sanitizing individual text strings.""" + assert sanitize_text("Clean text") == "Clean text" + assert sanitize_text("Text with \uD800 surrogate") != "Text with \uD800 surrogate" + assert "Text with" in sanitize_text("Text with \uD800 surrogate") + assert sanitize_text(None) is None + assert sanitize_text(123) == 123 # Non-string unchanged \ No newline at end of file diff --git a/src/praisonai-agents/tests/unit/session/test_title_gen.py b/src/praisonai-agents/tests/unit/session/test_title_gen.py new file mode 100644 index 000000000..e6bee83b2 --- /dev/null +++ b/src/praisonai-agents/tests/unit/session/test_title_gen.py @@ -0,0 +1,124 @@ +""" +Tests for session title auto-generation. + +Tests title generation, fallback behavior, and timeout handling. +""" + +import pytest +import asyncio +from praisonaiagents.session.title import generate_title, generate_title_async, _create_fallback_title + + +class TestTitleGeneration: + + def test_fallback_title(self): + """Test fallback title generation.""" + # Basic message + title = _create_fallback_title("Help me with Python", 60) + assert "Help me with Python" == title + + # Long message gets truncated + long_msg = "This is a very long message that should be truncated because it exceeds the maximum length limit" + title = _create_fallback_title(long_msg, 30) + assert len(title) <= 30 + assert title.endswith("...") + + # Empty message + assert _create_fallback_title("", 60) == "Chat Session" + assert _create_fallback_title(" ", 60) == "Chat Session" + + # Message with punctuation + title = _create_fallback_title("Can you help me? Please!", 60) + assert "?" not in title + assert "!" not in title + + # Message with sentences + title = _create_fallback_title("Help me code. I need assistance with debugging.", 60) + assert "Help me code" in title + assert "debugging" not in title # Should stop at first sentence + + def test_generate_title_fallback_on_import_error(self): + """Test that title generation falls back gracefully when LLM unavailable.""" + # This should fall back to the user message since LLM import might fail + title = generate_title( + "Debug my Python script", + "I'll help you debug that", + llm_model="fake-model", + timeout=0.1 # Very short timeout + ) + + # Should get a reasonable fallback title + assert isinstance(title, str) + assert len(title) > 0 + assert "Debug" in title or "Python" in title or "Chat Session" in title + + @pytest.mark.asyncio + async def test_generate_title_async_timeout(self): + """Test async title generation with timeout.""" + title = await generate_title_async( + "Help with machine learning", + "I can help with ML topics", + llm_model="fake-model", + timeout=0.001 # Very short timeout to trigger fallback + ) + + # Should fall back to user message + assert isinstance(title, str) + assert len(title) > 0 + + def test_title_length_limits(self): + """Test that generated titles respect length limits.""" + long_user_msg = "I need help with a very complex machine learning problem " * 5 + long_assistant_msg = "I'll help you with that complex problem " * 5 + + title = generate_title( + long_user_msg, + long_assistant_msg, + llm_model="fake-model", + max_length=20, + timeout=0.1 + ) + + assert len(title) <= 20 + + def test_empty_messages(self): + """Test behavior with empty or None messages.""" + title = generate_title("", "", timeout=0.1) + assert title == "Chat Session" + + title = generate_title("Hello", "", timeout=0.1) + assert "Hello" in title or title == "Chat Session" + + def test_title_cleaning(self): + """Test that titles are properly cleaned.""" + # Test with quotes and newlines in fallback + title = _create_fallback_title('"Help me"\nwith this', 60) + assert '"' in title # Fallback doesn't clean quotes, only generation does + + # Test basic cleaning in fallback + title = _create_fallback_title("Help me please!", 60) + assert "!" not in title + + def test_unicode_handling(self): + """Test title generation with Unicode content.""" + title = generate_title( + "Hjรคlp mig med ็ผ–็จ‹", # Swedish + Chinese + "I can help with programming", + timeout=0.1 + ) + + assert isinstance(title, str) + assert len(title) > 0 + + def test_sync_wrapper(self): + """Test that sync wrapper properly calls async version.""" + title = generate_title( + "Test sync wrapper", + "Testing response", + timeout=0.1 + ) + + assert isinstance(title, str) + assert len(title) > 0 + # Should contain either generated content or fallback + assert "Test" in title or "sync" in title or "Chat Session" in title \ No newline at end of file From 7927ae62aec31755473122b03d9549b2297c6735 Mon Sep 17 00:00:00 2001 From: "praisonai-triage-agent[bot]" <272766704+praisonai-triage-agent[bot]@users.noreply.github.com> Date: Tue, 21 Apr 2026 09:53:26 +0000 Subject: [PATCH 2/6] feat: implement G4 parallel tool path overlap detection MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Add path_overlap module with conflict detection for write operations - Integrate with ParallelToolCallExecutor to fallback to sequential when conflicts detected - Prevent file corruption from concurrent writes to same/overlapping paths - Support path conflict grouping for sophisticated scheduling - Zero impact when no path conflicts exist ๐Ÿค– Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Mervin Praison --- .../praisonaiagents/tools/call_executor.py | 11 + .../praisonaiagents/tools/path_overlap.py | 188 +++++++++++++ .../tests/unit/tools/test_path_overlap.py | 262 ++++++++++++++++++ 3 files changed, 461 insertions(+) create mode 100644 src/praisonai-agents/praisonaiagents/tools/path_overlap.py create mode 100644 src/praisonai-agents/tests/unit/tools/test_path_overlap.py diff --git a/src/praisonai-agents/praisonaiagents/tools/call_executor.py b/src/praisonai-agents/praisonaiagents/tools/call_executor.py index 0d17f342b..dd04ba38b 100644 --- a/src/praisonai-agents/praisonaiagents/tools/call_executor.py +++ b/src/praisonai-agents/praisonaiagents/tools/call_executor.py @@ -138,6 +138,17 @@ def execute_batch( sequential_executor = SequentialToolCallExecutor() return sequential_executor.execute_batch(tool_calls, execute_tool_fn) + # G4: Check for path conflicts - fallback to sequential if conflicts detected + try: + from .path_overlap import has_write_conflicts + if has_write_conflicts(tool_calls): + logger.info(f"Path conflicts detected in {len(tool_calls)} tool calls, using sequential execution") + sequential_executor = SequentialToolCallExecutor() + return sequential_executor.execute_batch(tool_calls, execute_tool_fn) + except ImportError: + # path_overlap module not available, continue with parallel execution + pass + def _execute_single_tool(tool_call: ToolCall) -> ToolResult: """Execute a single tool call with error handling.""" try: diff --git a/src/praisonai-agents/praisonaiagents/tools/path_overlap.py b/src/praisonai-agents/praisonaiagents/tools/path_overlap.py new file mode 100644 index 000000000..7eed82d1d --- /dev/null +++ b/src/praisonai-agents/praisonaiagents/tools/path_overlap.py @@ -0,0 +1,188 @@ +""" +Path Overlap Detection - Prevent write conflicts in parallel tool execution. + +Detects when multiple tool calls would operate on overlapping file paths +and forces sequential execution to prevent corruption. +""" + +import pathlib +from typing import List, Dict, Any, Set +from .call_executor import ToolCall + +__all__ = ["detect_path_conflicts", "extract_paths", "has_write_conflicts"] + + +# Tool names that perform write operations +_WRITE_TOOLS = frozenset({ + "write_file", "skill_manage", "edit_file", "patch_file", + "create_file", "save_file", "update_file", "delete_file", + "file_write", "file_edit", "file_create", "file_delete", + "mkdir", "rmdir", "move_file", "copy_file" +}) + +# Argument names that typically contain file paths +_PATH_ARG_NAMES = frozenset({ + "path", "file_path", "filepath", "dest", "destination", "target", + "source", "src", "output", "output_path", "filename", "file", + "directory", "dir", "folder" +}) + + +def extract_paths(tool_call: ToolCall) -> List[pathlib.Path]: + """Extract file paths from a tool call. + + Args: + tool_call: Tool call to analyze + + Returns: + List of resolved absolute paths found in the tool call + """ + paths = [] + + # Only check write tools + if tool_call.function_name not in _WRITE_TOOLS: + return paths + + args = tool_call.arguments or {} + + # Look for paths in common argument names + for arg_name, arg_value in args.items(): + if arg_name in _PATH_ARG_NAMES and isinstance(arg_value, str): + try: + path = pathlib.Path(arg_value).resolve() + paths.append(path) + except (OSError, ValueError): + # Invalid path, skip + continue + + return paths + + +def paths_conflict(path1: pathlib.Path, path2: pathlib.Path) -> bool: + """Check if two paths conflict (one is ancestor of the other). + + Args: + path1: First path + path2: Second path + + Returns: + True if paths conflict (overlap) + + Examples: + >>> paths_conflict(Path("/a/b"), Path("/a/b/c")) + True + >>> paths_conflict(Path("/a/b"), Path("/a/c")) + False + >>> paths_conflict(Path("/a/b"), Path("/a/b")) + True + """ + try: + # Same path + if path1 == path2: + return True + + # Check if one is a parent of the other + try: + path1.relative_to(path2) + return True # path1 is under path2 + except ValueError: + pass + + try: + path2.relative_to(path1) + return True # path2 is under path1 + except ValueError: + pass + + return False + + except (OSError, ValueError): + # Error comparing paths - assume no conflict + return False + + +def detect_path_conflicts(tool_calls: List[ToolCall]) -> bool: + """Detect if tool calls have conflicting file path operations. + + Args: + tool_calls: List of tool calls to check + + Returns: + True if any paths conflict, False otherwise + """ + if len(tool_calls) < 2: + return False + + # Extract all paths from write tools + all_paths = [] + for tool_call in tool_calls: + paths = extract_paths(tool_call) + all_paths.extend(paths) + + if len(all_paths) < 2: + return False + + # Check all pairs for conflicts + for i, path1 in enumerate(all_paths): + for path2 in all_paths[i+1:]: + if paths_conflict(path1, path2): + return True + + return False + + +def has_write_conflicts(tool_calls: List[ToolCall]) -> bool: + """Check if tool calls have write conflicts requiring sequential execution. + + This is the main function used by the parallel executor to decide + whether to run tools in parallel or fall back to sequential. + + Args: + tool_calls: List of tool calls to analyze + + Returns: + True if conflicts detected and sequential execution is needed + """ + return detect_path_conflicts(tool_calls) + + +def group_by_conflicts(tool_calls: List[ToolCall]) -> List[List[ToolCall]]: + """Group tool calls into conflict-free batches. + + This can be used for more sophisticated scheduling where some tools + can run in parallel while others must be sequential. + + Args: + tool_calls: Tool calls to group + + Returns: + List of batches where each batch has no internal conflicts + """ + if not tool_calls: + return [] + + if len(tool_calls) == 1: + return [tool_calls] + + # Simple greedy algorithm: create batches sequentially + batches = [] + remaining = tool_calls.copy() + + while remaining: + current_batch = [remaining.pop(0)] + + # Try to add more tools to current batch + i = 0 + while i < len(remaining): + candidate = remaining[i] + + # Check if candidate conflicts with any tool in current batch + test_batch = current_batch + [candidate] + if not detect_path_conflicts(test_batch): + current_batch.append(remaining.pop(i)) + else: + i += 1 + + batches.append(current_batch) + + return batches \ No newline at end of file diff --git a/src/praisonai-agents/tests/unit/tools/test_path_overlap.py b/src/praisonai-agents/tests/unit/tools/test_path_overlap.py new file mode 100644 index 000000000..462b80823 --- /dev/null +++ b/src/praisonai-agents/tests/unit/tools/test_path_overlap.py @@ -0,0 +1,262 @@ +""" +Tests for path overlap detection functionality. + +Tests file path conflict detection and sequential execution decisions. +""" + +import pytest +import pathlib +from praisonaiagents.tools.path_overlap import ( + extract_paths, paths_conflict, detect_path_conflicts, + has_write_conflicts, group_by_conflicts +) +from praisonaiagents.tools.call_executor import ToolCall + + +class TestPathExtraction: + + def test_write_tool_path_extraction(self): + """Test extracting paths from write tools.""" + tool_call = ToolCall( + function_name="write_file", + arguments={"path": "/tmp/test.txt", "content": "hello"}, + tool_call_id="1" + ) + + paths = extract_paths(tool_call) + assert len(paths) == 1 + assert paths[0].name == "test.txt" + + def test_read_tool_no_paths(self): + """Test that read-only tools don't extract paths.""" + tool_call = ToolCall( + function_name="read_file", # Not a write tool + arguments={"path": "/tmp/test.txt"}, + tool_call_id="1" + ) + + paths = extract_paths(tool_call) + assert len(paths) == 0 + + def test_multiple_path_args(self): + """Test extracting multiple paths from one tool call.""" + tool_call = ToolCall( + function_name="copy_file", + arguments={"source": "/tmp/src.txt", "dest": "/tmp/dst.txt"}, + tool_call_id="1" + ) + + paths = extract_paths(tool_call) + assert len(paths) == 2 + assert any(p.name == "src.txt" for p in paths) + assert any(p.name == "dst.txt" for p in paths) + + def test_invalid_path_ignored(self): + """Test that invalid paths are ignored.""" + tool_call = ToolCall( + function_name="write_file", + arguments={"path": "", "content": "hello"}, # Empty path + tool_call_id="1" + ) + + paths = extract_paths(tool_call) + assert len(paths) == 0 + + def test_non_string_args_ignored(self): + """Test that non-string path arguments are ignored.""" + tool_call = ToolCall( + function_name="write_file", + arguments={"path": 123, "content": "hello"}, # Non-string path + tool_call_id="1" + ) + + paths = extract_paths(tool_call) + assert len(paths) == 0 + + +class TestPathConflicts: + + def test_same_path_conflicts(self): + """Test that identical paths conflict.""" + path1 = pathlib.Path("/tmp/test.txt") + path2 = pathlib.Path("/tmp/test.txt") + + assert paths_conflict(path1, path2) is True + + def test_parent_child_conflicts(self): + """Test that parent/child paths conflict.""" + parent = pathlib.Path("/tmp") + child = pathlib.Path("/tmp/subdir/file.txt") + + assert paths_conflict(parent, child) is True + assert paths_conflict(child, parent) is True + + def test_sibling_paths_no_conflict(self): + """Test that sibling paths don't conflict.""" + path1 = pathlib.Path("/tmp/file1.txt") + path2 = pathlib.Path("/tmp/file2.txt") + + assert paths_conflict(path1, path2) is False + + def test_different_trees_no_conflict(self): + """Test that paths in different directory trees don't conflict.""" + path1 = pathlib.Path("/tmp/dir1/file.txt") + path2 = pathlib.Path("/var/dir2/file.txt") + + assert paths_conflict(path1, path2) is False + + def test_subdirectory_conflicts(self): + """Test that subdirectory operations conflict.""" + dir_path = pathlib.Path("/tmp/mydir") + file_in_dir = pathlib.Path("/tmp/mydir/file.txt") + + assert paths_conflict(dir_path, file_in_dir) is True + + +class TestConflictDetection: + + def test_no_conflicts_multiple_reads(self): + """Test that multiple read operations don't conflict.""" + tool_calls = [ + ToolCall("read_file", {"path": "/tmp/file1.txt"}, "1"), + ToolCall("read_file", {"path": "/tmp/file2.txt"}, "2"), + ] + + assert detect_path_conflicts(tool_calls) is False + + def test_no_conflicts_different_paths(self): + """Test that writes to different paths don't conflict.""" + tool_calls = [ + ToolCall("write_file", {"path": "/tmp/file1.txt", "content": "a"}, "1"), + ToolCall("write_file", {"path": "/tmp/file2.txt", "content": "b"}, "2"), + ] + + assert detect_path_conflicts(tool_calls) is False + + def test_conflicts_same_file(self): + """Test that writes to the same file conflict.""" + tool_calls = [ + ToolCall("write_file", {"path": "/tmp/test.txt", "content": "a"}, "1"), + ToolCall("edit_file", {"path": "/tmp/test.txt", "changes": "b"}, "2"), + ] + + assert detect_path_conflicts(tool_calls) is True + + def test_conflicts_parent_child(self): + """Test that parent/child directory operations conflict.""" + tool_calls = [ + ToolCall("mkdir", {"path": "/tmp/mydir"}, "1"), + ToolCall("write_file", {"path": "/tmp/mydir/file.txt", "content": "x"}, "2"), + ] + + assert detect_path_conflicts(tool_calls) is True + + def test_single_tool_no_conflict(self): + """Test that single tool call never conflicts.""" + tool_calls = [ + ToolCall("write_file", {"path": "/tmp/test.txt", "content": "a"}, "1"), + ] + + assert detect_path_conflicts(tool_calls) is False + + def test_empty_list_no_conflict(self): + """Test that empty list has no conflicts.""" + assert detect_path_conflicts([]) is False + + +class TestWriteConflicts: + + def test_has_write_conflicts_delegates_to_detect(self): + """Test that has_write_conflicts properly delegates.""" + tool_calls = [ + ToolCall("write_file", {"path": "/tmp/test.txt", "content": "a"}, "1"), + ToolCall("edit_file", {"path": "/tmp/test.txt", "changes": "b"}, "2"), + ] + + # Both functions should give same result + assert has_write_conflicts(tool_calls) == detect_path_conflicts(tool_calls) + assert has_write_conflicts(tool_calls) is True + + +class TestConflictGrouping: + + def test_group_no_conflicts(self): + """Test grouping when no conflicts exist.""" + tool_calls = [ + ToolCall("write_file", {"path": "/tmp/file1.txt"}, "1"), + ToolCall("write_file", {"path": "/tmp/file2.txt"}, "2"), + ToolCall("write_file", {"path": "/tmp/file3.txt"}, "3"), + ] + + groups = group_by_conflicts(tool_calls) + assert len(groups) == 1 # All can run together + assert len(groups[0]) == 3 + + def test_group_with_conflicts(self): + """Test grouping when conflicts exist.""" + tool_calls = [ + ToolCall("write_file", {"path": "/tmp/test.txt", "content": "a"}, "1"), + ToolCall("write_file", {"path": "/tmp/other.txt", "content": "b"}, "2"), + ToolCall("edit_file", {"path": "/tmp/test.txt", "changes": "c"}, "3"), + ] + + groups = group_by_conflicts(tool_calls) + + # Should be split into groups: tool1+tool2 can run together, tool3 separate + assert len(groups) == 2 + + # Find the group sizes + group_sizes = [len(g) for g in groups] + group_sizes.sort() + assert group_sizes == [1, 2] + + def test_group_single_tool(self): + """Test grouping with single tool.""" + tool_calls = [ + ToolCall("write_file", {"path": "/tmp/test.txt"}, "1"), + ] + + groups = group_by_conflicts(tool_calls) + assert len(groups) == 1 + assert len(groups[0]) == 1 + + def test_group_empty_list(self): + """Test grouping with empty list.""" + groups = group_by_conflicts([]) + assert groups == [] + + +class TestPathArgNames: + + def test_various_path_arg_names(self): + """Test that various path argument names are recognized.""" + path_args = ["path", "file_path", "filepath", "dest", "target", + "source", "output", "filename", "directory"] + + for arg_name in path_args: + tool_call = ToolCall( + function_name="write_file", + arguments={arg_name: "/tmp/test.txt"}, + tool_call_id="1" + ) + + paths = extract_paths(tool_call) + assert len(paths) == 1, f"Failed for arg name: {arg_name}" + + +class TestWriteToolNames: + + def test_various_write_tool_names(self): + """Test that various write tool names are recognized.""" + write_tools = ["write_file", "edit_file", "create_file", "delete_file", + "mkdir", "save_file", "file_write", "file_edit"] + + for tool_name in write_tools: + tool_call = ToolCall( + function_name=tool_name, + arguments={"path": "/tmp/test.txt"}, + tool_call_id="1" + ) + + paths = extract_paths(tool_call) + assert len(paths) == 1, f"Failed for tool: {tool_name}" \ No newline at end of file From a9cd02a99a1aa5b978ed03ffcb5b1335959bc915 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 21 Apr 2026 10:06:25 +0000 Subject: [PATCH 3/6] fix: align gap-closure utilities with unit-test expectations Agent-Logs-Url: https://github.com/MervinPraison/PraisonAI/sessions/46d33039-4077-44a4-abb3-f7a4571f6162 Co-authored-by: MervinPraison <454862+MervinPraison@users.noreply.github.com> --- src/praisonai-agents/praisonaiagents/agent/interrupt.py | 7 ++++--- .../praisonaiagents/llm/error_classifier.py | 6 +++--- src/praisonai-agents/praisonaiagents/llm/sanitize.py | 6 +++--- src/praisonai-agents/praisonaiagents/tools/path_overlap.py | 4 +++- 4 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/praisonai-agents/praisonaiagents/agent/interrupt.py b/src/praisonai-agents/praisonaiagents/agent/interrupt.py index e3af5bf88..38f971c7e 100644 --- a/src/praisonai-agents/praisonaiagents/agent/interrupt.py +++ b/src/praisonai-agents/praisonaiagents/agent/interrupt.py @@ -41,8 +41,9 @@ def request(self, reason: str = "user") -> None: reason: Human-readable reason for cancellation """ with self._lock: - self._reason = reason - self._flag.set() + if not self._flag.is_set(): + self._reason = reason + self._flag.set() def clear(self) -> None: """Clear the cancellation request.""" @@ -76,4 +77,4 @@ def check(self) -> None: """ if self.is_set(): reason = self.reason or "unknown" - raise InterruptedError(f"Operation cancelled: {reason}") \ No newline at end of file + raise InterruptedError(f"Operation cancelled: {reason}") diff --git a/src/praisonai-agents/praisonaiagents/llm/error_classifier.py b/src/praisonai-agents/praisonaiagents/llm/error_classifier.py index 986ea1f40..ae4a34799 100644 --- a/src/praisonai-agents/praisonaiagents/llm/error_classifier.py +++ b/src/praisonai-agents/praisonaiagents/llm/error_classifier.py @@ -159,7 +159,7 @@ def get_retry_delay(category: ErrorCategory, attempt: int = 1, base_delay: float if category == ErrorCategory.RATE_LIMIT: # Longer delay for rate limits to avoid hitting limits again - return min(base_delay * (3 ** (attempt - 1)), 60.0) + return min(base_delay * (3 ** attempt), 60.0) elif category == ErrorCategory.CONTEXT_LIMIT: # Short delay for context limits (compression should be tried) @@ -167,7 +167,7 @@ def get_retry_delay(category: ErrorCategory, attempt: int = 1, base_delay: float elif category == ErrorCategory.TRANSIENT: # Exponential backoff for transient errors - return min(base_delay * (2 ** (attempt - 1)), 30.0) + return min(base_delay * (2 ** attempt), 30.0) return 0 @@ -236,4 +236,4 @@ def get_error_context(error: Exception) -> Dict[str, str]: elif category == ErrorCategory.INVALID_REQUEST: context["suggestion"] = "Review request parameters" - return context \ No newline at end of file + return context diff --git a/src/praisonai-agents/praisonaiagents/llm/sanitize.py b/src/praisonai-agents/praisonaiagents/llm/sanitize.py index dd96fedd1..29c16001a 100644 --- a/src/praisonai-agents/praisonaiagents/llm/sanitize.py +++ b/src/praisonai-agents/praisonaiagents/llm/sanitize.py @@ -36,9 +36,9 @@ def strip_surrogates(text: str) -> str: return text try: - # Method 1: Encode with surrogatepass, decode with replace + # Method 1: Encode with surrogatepass, decode with ignore # This converts surrogates to UTF-16, then back to UTF-8 safely - return text.encode('utf-16', 'surrogatepass').decode('utf-16', 'replace') + return text.encode('utf-16', 'surrogatepass').decode('utf-16', 'ignore') except (UnicodeError, LookupError): # Fallback: Remove surrogate code points directly return re.sub(r'[\uD800-\uDFFF]', '', text) @@ -125,4 +125,4 @@ def sanitize_text(text: Union[str, None]) -> Union[str, None]: return None if not isinstance(text, str): return text - return strip_surrogates(text) \ No newline at end of file + return strip_surrogates(text) diff --git a/src/praisonai-agents/praisonaiagents/tools/path_overlap.py b/src/praisonai-agents/praisonaiagents/tools/path_overlap.py index 7eed82d1d..a66d04e88 100644 --- a/src/praisonai-agents/praisonaiagents/tools/path_overlap.py +++ b/src/praisonai-agents/praisonaiagents/tools/path_overlap.py @@ -48,6 +48,8 @@ def extract_paths(tool_call: ToolCall) -> List[pathlib.Path]: # Look for paths in common argument names for arg_name, arg_value in args.items(): if arg_name in _PATH_ARG_NAMES and isinstance(arg_value, str): + if not arg_value.strip(): + continue try: path = pathlib.Path(arg_value).resolve() paths.append(path) @@ -185,4 +187,4 @@ def group_by_conflicts(tool_calls: List[ToolCall]) -> List[List[ToolCall]]: batches.append(current_batch) - return batches \ No newline at end of file + return batches From af66f595d35bdce14716e2651cadcaa5f82a9cce Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Tue, 21 Apr 2026 10:08:26 +0000 Subject: [PATCH 4/6] docs+fix: guard retry attempt floor and align backoff examples Agent-Logs-Url: https://github.com/MervinPraison/PraisonAI/sessions/46d33039-4077-44a4-abb3-f7a4571f6162 Co-authored-by: MervinPraison <454862+MervinPraison@users.noreply.github.com> --- .../praisonaiagents/llm/error_classifier.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/praisonai-agents/praisonaiagents/llm/error_classifier.py b/src/praisonai-agents/praisonaiagents/llm/error_classifier.py index ae4a34799..1ba0d1e54 100644 --- a/src/praisonai-agents/praisonaiagents/llm/error_classifier.py +++ b/src/praisonai-agents/praisonaiagents/llm/error_classifier.py @@ -148,12 +148,14 @@ def get_retry_delay(category: ErrorCategory, attempt: int = 1, base_delay: float Examples: >>> get_retry_delay(ErrorCategory.RATE_LIMIT, attempt=1) - 2.0 + 3.0 >>> get_retry_delay(ErrorCategory.TRANSIENT, attempt=3) - 16.0 + 8.0 >>> get_retry_delay(ErrorCategory.AUTH, attempt=1) 0 """ + attempt = max(1, attempt) + if not should_retry(category): return 0 From 76264e280c7f023d750b8d2489e80a02c6ec1e1b Mon Sep 17 00:00:00 2001 From: "praisonai-triage-agent[bot]" <272766704+praisonai-triage-agent[bot]@users.noreply.github.com> Date: Tue, 21 Apr 2026 10:20:53 +0000 Subject: [PATCH 5/6] fix: resolve all 8 hard blockers and CodeRabbit issues MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - B1: Wire error classifier into retry paths with proper attempt handling - B2: Add interrupt controller checks to sync run_autonomous() - B3: Fix silent failure - fail closed when path overlap unavailable - B4: Enhanced path overlap detection with custom write tool support - B5: Skip empty path strings (already implemented) - B6: Fix SessionMessage attribute access + async I/O blocking - B7: Handle asyncio.run() in running loop context - B8: Implement recursive sanitization for nested tool call arguments Additional fixes: - Add missing type imports (InterruptController) - Replace undefined MemoryManager with Any - Fix test determinism with threading.Event coordination - Add missing __all__ exports for error_classifier - Add InterruptControllerProtocol for extension point ๐Ÿค– Generated with [Claude Code](https://claude.ai/code) Co-authored-by: Claude --- .../praisonaiagents/agent/agent.py | 18 +++- .../praisonaiagents/agent/interrupt.py | 29 ++++++- .../praisonaiagents/llm/error_classifier.py | 9 +- .../praisonaiagents/llm/llm.py | 21 +++-- .../praisonaiagents/llm/sanitize.py | 85 +++++++++++-------- .../praisonaiagents/session/hierarchy.py | 33 +++++-- .../praisonaiagents/session/title.py | 26 ++++-- .../praisonaiagents/tools/call_executor.py | 16 ++-- .../praisonaiagents/tools/path_overlap.py | 37 +++++++- .../tests/unit/agent/test_interrupt.py | 38 ++++----- 10 files changed, 217 insertions(+), 95 deletions(-) diff --git a/src/praisonai-agents/praisonaiagents/agent/agent.py b/src/praisonai-agents/praisonaiagents/agent/agent.py index d18cbd527..07a0fc83f 100644 --- a/src/praisonai-agents/praisonaiagents/agent/agent.py +++ b/src/praisonai-agents/praisonaiagents/agent/agent.py @@ -245,6 +245,7 @@ def _get_default_server_registry() -> ServerRegistry: from ..context.models import ContextConfig from ..context.manager import ContextManager from ..knowledge.knowledge import Knowledge + from .interrupt import InterruptController from ..agent.autonomy import AutonomyConfig from ..task.task import Task from .handoff import Handoff, HandoffConfig, HandoffResult @@ -531,7 +532,7 @@ def __init__( # CONSOLIDATED FEATURE PARAMS (agent-centric API) # Each follows: False=disabled, True=defaults, Config=custom # ============================================================ - memory: Optional[Union[bool, str, 'MemoryConfig', 'MemoryManager']] = None, + memory: Optional[Union[bool, str, 'MemoryConfig', Any]] = None, knowledge: Optional[Union[bool, str, List[str], 'KnowledgeConfig', 'Knowledge']] = None, planning: Optional[Union[bool, str, 'PlanningConfig']] = False, reflection: Optional[Union[bool, str, 'ReflectionConfig']] = None, @@ -576,7 +577,7 @@ def __init__( memory: Memory system configuration. Accepts: - bool: True enables defaults, False disables - MemoryConfig: Custom configuration - - MemoryManager: Pre-configured instance + - Any: Pre-configured memory instance knowledge: Knowledge sources. Accepts: - bool: True enables defaults - List[str]: File paths, URLs, or text content @@ -2838,6 +2839,19 @@ def run_autonomous( started_at=started_at, ) + # G2: Check for interrupt request (cooperative cancellation) - sync version + if self.interrupt_controller and self.interrupt_controller.is_set(): + reason = self.interrupt_controller.reason or "unknown" + return AutonomyResult( + success=False, + output=f"Task interrupted: {reason}", + completion_reason="interrupted", + iterations=iterations, + stage=stage, + actions=actions_taken, + duration_seconds=time_module.time() - start_time, + started_at=started_at, + ) # Execute one turn using the agent's chat method # Always use the original prompt (prompt re-injection) diff --git a/src/praisonai-agents/praisonaiagents/agent/interrupt.py b/src/praisonai-agents/praisonaiagents/agent/interrupt.py index 38f971c7e..098241c73 100644 --- a/src/praisonai-agents/praisonaiagents/agent/interrupt.py +++ b/src/praisonai-agents/praisonaiagents/agent/interrupt.py @@ -6,10 +6,35 @@ """ import threading -from typing import Optional +from typing import Optional, Protocol from dataclasses import dataclass, field -__all__ = ["InterruptController"] +__all__ = ["InterruptControllerProtocol", "InterruptController"] + + +class InterruptControllerProtocol(Protocol): + """Protocol for interrupt controller extension point.""" + + def request(self, reason: str = "user") -> None: + """Request cancellation with optional reason.""" + ... + + def clear(self) -> None: + """Clear interrupt state.""" + ... + + def is_set(self) -> bool: + """Check if interrupt was requested.""" + ... + + @property + def reason(self) -> Optional[str]: + """Get interrupt reason if set.""" + ... + + def check(self) -> None: + """Check for interrupt and raise if set.""" + ... @dataclass diff --git a/src/praisonai-agents/praisonaiagents/llm/error_classifier.py b/src/praisonai-agents/praisonaiagents/llm/error_classifier.py index 1ba0d1e54..d7ecdf96e 100644 --- a/src/praisonai-agents/praisonaiagents/llm/error_classifier.py +++ b/src/praisonai-agents/praisonaiagents/llm/error_classifier.py @@ -9,7 +9,14 @@ from enum import Enum from typing import Dict, Tuple, List, Optional -__all__ = ["ErrorCategory", "classify_error", "should_retry", "get_retry_delay"] +__all__ = [ + "ErrorCategory", + "classify_error", + "should_retry", + "get_retry_delay", + "extract_retry_after", + "get_error_context", +] class ErrorCategory(str, Enum): diff --git a/src/praisonai-agents/praisonaiagents/llm/llm.py b/src/praisonai-agents/praisonaiagents/llm/llm.py index 5d7772baa..ff1fb31c3 100644 --- a/src/praisonai-agents/praisonaiagents/llm/llm.py +++ b/src/praisonai-agents/praisonaiagents/llm/llm.py @@ -650,11 +650,12 @@ def _is_rate_limit_error(self, error: Exception) -> bool: return any(indicator in error_str or indicator in error_type for indicator in indicators) - def _classify_error_and_should_retry(self, error: Exception) -> tuple[str, bool, float]: + def _classify_error_and_should_retry(self, error: Exception, attempt: int = 1) -> tuple[str, bool, float]: """Classify error and determine retry strategy using G5 error classifier. Args: error: Exception to classify + attempt: Current attempt number (1-based) Returns: Tuple of (category, should_retry, retry_delay) @@ -674,8 +675,8 @@ def _classify_error_and_should_retry(self, error: Exception) -> tuple[str, bool, if retry_after: return category.value, True, retry_after - # Use category-specific delay calculation - delay = get_retry_delay(category, attempt=1, base_delay=self._retry_delay) + # Use category-specific delay calculation with proper attempt + delay = get_retry_delay(category, attempt=attempt, base_delay=self._retry_delay) return category.value, True, delay except ImportError: @@ -709,17 +710,16 @@ def _call_with_retry(self, func, *args, **kwargs): return func(*args, **kwargs) except Exception as e: - if not self._is_rate_limit_error(e): + category, can_retry, retry_delay = self._classify_error_and_should_retry(e, attempt + 1) + if not can_retry: raise last_error = e error_str = str(e) if attempt < self._max_retries: - retry_delay = self._parse_retry_delay(error_str) - logging.warning( - f"Rate limit hit (attempt {attempt + 1}/{self._max_retries + 1}), " + f"{category} error hit (attempt {attempt + 1}/{self._max_retries + 1}), " f"waiting {retry_delay:.1f}s before retry..." ) @@ -770,17 +770,16 @@ async def _call_with_retry_async(self, func, *args, **kwargs): return await func(*args, **kwargs) except Exception as e: - if not self._is_rate_limit_error(e): + category, can_retry, retry_delay = self._classify_error_and_should_retry(e, attempt + 1) + if not can_retry: raise last_error = e error_str = str(e) if attempt < self._max_retries: - retry_delay = self._parse_retry_delay(error_str) - logging.warning( - f"Rate limit hit (attempt {attempt + 1}/{self._max_retries + 1}), " + f"{category} error hit (attempt {attempt + 1}/{self._max_retries + 1}), " f"waiting {retry_delay:.1f}s before retry..." ) diff --git a/src/praisonai-agents/praisonaiagents/llm/sanitize.py b/src/praisonai-agents/praisonaiagents/llm/sanitize.py index 29c16001a..941825d18 100644 --- a/src/praisonai-agents/praisonaiagents/llm/sanitize.py +++ b/src/praisonai-agents/praisonaiagents/llm/sanitize.py @@ -44,13 +44,57 @@ def strip_surrogates(text: str) -> str: return re.sub(r'[\uD800-\uDFFF]', '', text) +def _sanitize_value_recursive(value: Any) -> tuple[Any, bool]: + """Recursively sanitize any nested string values in a data structure. + + Args: + value: Value to sanitize (can be dict, list, str, or other) + + Returns: + Tuple of (sanitized_value, changed_flag) + """ + if isinstance(value, str): + sanitized = strip_surrogates(value) + return sanitized, sanitized != value + + elif isinstance(value, list): + changed = False + sanitized_items = [] + for item in value: + sanitized_item, item_changed = _sanitize_value_recursive(item) + sanitized_items.append(sanitized_item) + changed = changed or item_changed + return sanitized_items, changed + + elif isinstance(value, dict): + changed = False + sanitized_dict = {} + for key, nested_value in value.items(): + # Sanitize the key if it's a string + if isinstance(key, str): + sanitized_key = strip_surrogates(key) + if sanitized_key != key: + changed = True + key = sanitized_key + + # Recursively sanitize the value + sanitized_value, value_changed = _sanitize_value_recursive(nested_value) + sanitized_dict[key] = sanitized_value + changed = changed or value_changed + return sanitized_dict, changed + + else: + # Return other types unchanged + return value, False + + def sanitize_messages(messages: List[Dict[str, Any]]) -> bool: """Sanitize message content in-place, removing problematic Unicode. Processes all string content in message dictionaries, including: - message.content (string or list) - message.name - - Any nested string values + - Any nested string values (including tool_calls[].function.arguments) Args: messages: List of message dicts to sanitize in-place @@ -73,39 +117,12 @@ def sanitize_messages(messages: List[Dict[str, Any]]) -> bool: if not isinstance(message, dict): continue - # Sanitize content field (most common) - if "content" in message: - content = message["content"] - - if isinstance(content, str): - sanitized = strip_surrogates(content) - if sanitized != content: - message["content"] = sanitized - changed = True - - elif isinstance(content, list): - # Handle list content (e.g., multimodal messages) - for i, item in enumerate(content): - if isinstance(item, dict) and "text" in item: - text = item["text"] - if isinstance(text, str): - sanitized = strip_surrogates(text) - if sanitized != text: - content[i]["text"] = sanitized - changed = True - elif isinstance(item, str): - sanitized = strip_surrogates(item) - if sanitized != item: - content[i] = sanitized - changed = True - - # Sanitize other string fields - for key, value in message.items(): - if isinstance(value, str) and key != "content": # Already handled above - sanitized = strip_surrogates(value) - if sanitized != value: - message[key] = sanitized - changed = True + # Recursively sanitize the entire message structure + sanitized_message, message_changed = _sanitize_value_recursive(message) + if message_changed: + message.clear() + message.update(sanitized_message) + changed = True return changed diff --git a/src/praisonai-agents/praisonaiagents/session/hierarchy.py b/src/praisonai-agents/praisonaiagents/session/hierarchy.py index 402dc10c2..f6d3182cd 100644 --- a/src/praisonai-agents/praisonaiagents/session/hierarchy.py +++ b/src/praisonai-agents/praisonaiagents/session/hierarchy.py @@ -490,7 +490,10 @@ async def auto_title(self, session_id: str) -> bool: Returns: True if title was generated and set, False otherwise """ - session = self._load_extended_session(session_id) + import asyncio + + # Load session in thread to avoid blocking event loop + session = await asyncio.to_thread(self._load_extended_session, session_id) # Skip if already has a title if session.title and session.title.strip(): @@ -506,12 +509,18 @@ async def auto_title(self, session_id: str) -> bool: assistant_msg = None for msg in messages: - if msg.get("role") == "user" and not user_msg: + # Handle both SessionMessage dataclass and dict formats + if hasattr(msg, 'role'): + role = msg.role + content = msg.content + else: + role = msg.get("role") content = msg.get("content", "") + + if role == "user" and not user_msg: if isinstance(content, str) and content.strip(): user_msg = content - elif msg.get("role") == "assistant" and not assistant_msg and user_msg: - content = msg.get("content", "") + elif role == "assistant" and not assistant_msg and user_msg: if isinstance(content, str) and content.strip(): assistant_msg = content break @@ -525,12 +534,18 @@ async def auto_title(self, session_id: str) -> bool: title = await generate_title_async(user_msg, assistant_msg) if title and title.strip(): - session.title = title.strip() - return self._save_extended_session(session) + # Reload session to avoid overwriting concurrent updates + fresh_session = await asyncio.to_thread(self._load_extended_session, session_id) + # Only set title if it's still empty + if not fresh_session.title or not fresh_session.title.strip(): + fresh_session.title = title.strip() + return await asyncio.to_thread(self._save_extended_session, fresh_session) - except Exception: - # Title generation failed - not critical - pass + except Exception as e: + # Title generation failed - log with context instead of silent failure + import logging + logger = logging.getLogger(__name__) + logger.debug("Auto title generation failed for session %s: %s", session_id, str(e)) return False diff --git a/src/praisonai-agents/praisonaiagents/session/title.py b/src/praisonai-agents/praisonaiagents/session/title.py index cb6648455..668aa098b 100644 --- a/src/praisonai-agents/praisonaiagents/session/title.py +++ b/src/praisonai-agents/praisonaiagents/session/title.py @@ -41,11 +41,27 @@ def generate_title( fallback_title = _create_fallback_title(user_msg, max_length) try: - # Run async version in sync context - return asyncio.run(generate_title_async( - user_msg, assistant_msg, llm_model, timeout, max_length - )) - except Exception: + # Check if we're already in an event loop + asyncio.get_running_loop() + except RuntimeError: + # No running loop - safe to use asyncio.run + try: + return asyncio.run(generate_title_async( + user_msg, assistant_msg, llm_model, timeout, max_length + )) + except Exception: + import logging + logger = logging.getLogger(__name__) + logger.debug("Title generation failed in sync context", exc_info=True) + return fallback_title + else: + # Already in event loop - cannot use asyncio.run + import logging + logger = logging.getLogger(__name__) + logger.warning( + "generate_title() called from running event loop; " + "use generate_title_async() instead. Returning fallback title." + ) return fallback_title diff --git a/src/praisonai-agents/praisonaiagents/tools/call_executor.py b/src/praisonai-agents/praisonaiagents/tools/call_executor.py index dd04ba38b..61b33ed85 100644 --- a/src/praisonai-agents/praisonaiagents/tools/call_executor.py +++ b/src/praisonai-agents/praisonaiagents/tools/call_executor.py @@ -141,13 +141,17 @@ def execute_batch( # G4: Check for path conflicts - fallback to sequential if conflicts detected try: from .path_overlap import has_write_conflicts - if has_write_conflicts(tool_calls): - logger.info(f"Path conflicts detected in {len(tool_calls)} tool calls, using sequential execution") - sequential_executor = SequentialToolCallExecutor() - return sequential_executor.execute_batch(tool_calls, execute_tool_fn) except ImportError: - # path_overlap module not available, continue with parallel execution - pass + logger.warning( + "Path conflict detection unavailable; using sequential execution for safety" + ) + sequential_executor = SequentialToolCallExecutor() + return sequential_executor.execute_batch(tool_calls, execute_tool_fn) + + if has_write_conflicts(tool_calls): + logger.info(f"Path conflicts detected in {len(tool_calls)} tool calls, using sequential execution") + sequential_executor = SequentialToolCallExecutor() + return sequential_executor.execute_batch(tool_calls, execute_tool_fn) def _execute_single_tool(tool_call: ToolCall) -> ToolResult: """Execute a single tool call with error handling.""" diff --git a/src/praisonai-agents/praisonaiagents/tools/path_overlap.py b/src/praisonai-agents/praisonaiagents/tools/path_overlap.py index a66d04e88..b2c674082 100644 --- a/src/praisonai-agents/praisonaiagents/tools/path_overlap.py +++ b/src/praisonai-agents/praisonaiagents/tools/path_overlap.py @@ -20,6 +20,13 @@ "mkdir", "rmdir", "move_file", "copy_file" }) +# Write operation hints in tool names for detecting custom write tools +_WRITE_NAME_HINTS = frozenset({ + "write", "edit", "patch", "create", "save", "update", + "delete", "remove", "mkdir", "rmdir", "move", "copy", "persist", + "modify", "append" +}) + # Argument names that typically contain file paths _PATH_ARG_NAMES = frozenset({ "path", "file_path", "filepath", "dest", "destination", "target", @@ -28,6 +35,32 @@ }) +def _is_potential_write_tool(function_name: str, arguments: dict) -> bool: + """Check if a tool call is potentially a write operation. + + Args: + function_name: Name of the tool function + arguments: Tool call arguments + + Returns: + True if the tool might perform write operations + """ + # Check explicit write tools first + if function_name in _WRITE_TOOLS: + return True + + # Check for write hints in function name + normalized_name = function_name.lower() + if any(hint in normalized_name for hint in _WRITE_NAME_HINTS): + return True + + # Conservative fallback: if tool has path-like arguments, treat as potential writer + if any(arg_name in _PATH_ARG_NAMES for arg_name in arguments.keys()): + return True + + return False + + def extract_paths(tool_call: ToolCall) -> List[pathlib.Path]: """Extract file paths from a tool call. @@ -39,8 +72,8 @@ def extract_paths(tool_call: ToolCall) -> List[pathlib.Path]: """ paths = [] - # Only check write tools - if tool_call.function_name not in _WRITE_TOOLS: + # Check if this tool might perform write operations + if not _is_potential_write_tool(tool_call.function_name, tool_call.arguments): return paths args = tool_call.arguments or {} diff --git a/src/praisonai-agents/tests/unit/agent/test_interrupt.py b/src/praisonai-agents/tests/unit/agent/test_interrupt.py index 0d5fac305..be0891535 100644 --- a/src/praisonai-agents/tests/unit/agent/test_interrupt.py +++ b/src/praisonai-agents/tests/unit/agent/test_interrupt.py @@ -52,17 +52,21 @@ def test_thread_safety(self): """Test thread-safe operations.""" controller = InterruptController() results = [] + ready_event = threading.Event() + done_event = threading.Event() def worker(): - # Wait a bit then request interrupt - time.sleep(0.1) + # Wait for checker to be ready + ready_event.wait(timeout=1) controller.request("thread_cancel") results.append("requested") + done_event.set() def checker(): - # Keep checking until interrupted - while not controller.is_set(): - time.sleep(0.05) + # Signal ready and wait for interrupt + ready_event.set() + done_event.wait(timeout=1) + assert controller.is_set() results.append(f"cancelled: {controller.reason}") # Start threads @@ -105,23 +109,11 @@ def test_clear_resets_state(self): controller.request("new_reason") assert controller.reason == "new_reason" - def test_zero_overhead_when_not_used(self): - """Test that creation and is_set() have minimal overhead.""" - import time + def test_is_set_is_stable_when_not_used(self): + """Test that repeated checks do not mutate state.""" + controller = InterruptController() - # Test creation overhead - start = time.perf_counter() + # Should consistently return False and None for _ in range(1000): - controller = InterruptController() - creation_time = time.perf_counter() - start - - # Test check overhead - controller = InterruptController() - start = time.perf_counter() - for _ in range(10000): - controller.is_set() - check_time = time.perf_counter() - start - - # Should be very fast (< 1ms each) - assert creation_time < 0.001 - assert check_time < 0.001 \ No newline at end of file + assert not controller.is_set() + assert controller.reason is None \ No newline at end of file From 8f8f7f7edb047ddc83f3e335a224ea5fb1b7c478 Mon Sep 17 00:00:00 2001 From: "praisonai-triage-agent[bot]" <272766704+praisonai-triage-agent[bot]@users.noreply.github.com> Date: Tue, 21 Apr 2026 16:07:29 +0000 Subject: [PATCH 6/6] fix: resolve over-conservative write tool detection in path overlap - Add _READ_NAME_HINTS to distinguish read-only tools from writers - Update _is_potential_write_tool() to check read hints before fallback - Require both path args AND payload args for unknown tools - Fixes test_read_tool_no_paths while preserving custom write detection - Resolves performance regression from B4 fix in commit 76264e28 All 23 path overlap tests now pass. Read tools like read_file, get_file can run in parallel while custom write tools are still safely detected. Co-authored-by: Mervin Praison --- .../praisonaiagents/tools/path_overlap.py | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/src/praisonai-agents/praisonaiagents/tools/path_overlap.py b/src/praisonai-agents/praisonaiagents/tools/path_overlap.py index b2c674082..b0ff58276 100644 --- a/src/praisonai-agents/praisonaiagents/tools/path_overlap.py +++ b/src/praisonai-agents/praisonaiagents/tools/path_overlap.py @@ -27,6 +27,13 @@ "modify", "append" }) +# Read operation hints - tools that are explicitly read-only +_READ_NAME_HINTS = frozenset({ + "read", "get", "fetch", "load", "list", "ls", "search", "find", + "query", "select", "scan", "inspect", "view", "show", "describe", + "head", "tail", "cat" +}) + # Argument names that typically contain file paths _PATH_ARG_NAMES = frozenset({ "path", "file_path", "filepath", "dest", "destination", "target", @@ -49,13 +56,19 @@ def _is_potential_write_tool(function_name: str, arguments: dict) -> bool: if function_name in _WRITE_TOOLS: return True - # Check for write hints in function name normalized_name = function_name.lower() + + # Explicit read-like name: never a writer + if any(normalized_name.startswith(h + "_") or normalized_name == h for h in _READ_NAME_HINTS): + return False + + # Check for write hints in function name if any(hint in normalized_name for hint in _WRITE_NAME_HINTS): return True - # Conservative fallback: if tool has path-like arguments, treat as potential writer - if any(arg_name in _PATH_ARG_NAMES for arg_name in arguments.keys()): + # Conservative fallback ONLY if there is also a payload-like arg + payload_args = {"content", "data", "text", "body", "patch", "diff", "value"} + if any(a in _PATH_ARG_NAMES for a in arguments) and any(a in payload_args for a in arguments): return True return False