# Applied even when context management is disabled to prevent runaway tool outputs
DEFAULT_TOOL_OUTPUT_LIMIT = 16000

# DEPRECATED: legacy module-level API-server state.
# The canonical state now lives in the shared ServerRegistry; the names below
# are kept so that existing references to them keep resolving.
from ..core.server_registry import get_server_registry

_server_registry = get_server_registry()


# Backward compatibility wrappers
def _get_server_started():
    """Build a fresh ``{port: started}`` dict from the registry.

    Every call re-queries the registry, so the result reflects its state at
    call time; the returned dict itself is a detached copy.
    """
    started_by_port = {}
    for port in _server_registry.list_ports():
        started_by_port[port] = _server_registry.is_started(port)
    return started_by_port


def _get_registered_agents():
    """Build a fresh ``{port: {path: agent_id}}`` dict from the registry.

    Ports that have no registered agents are left out of the result.
    """
    per_port = (
        (port, _server_registry.get_agents(port))
        for port in _server_registry.list_ports()
    )
    return {port: agents for port, agents in per_port if agents}


def _get_shared_apps():
    """Build a fresh ``{port: app}`` dict from the registry.

    Ports whose app is ``None`` are left out of the result.
    """
    apps_by_port = {}
    for port in _server_registry.list_ports():
        candidate = _server_registry.get_app(port)
        if candidate is None:
            continue
        apps_by_port[port] = candidate
    return apps_by_port


# Legacy variables.
# NOTE(review): these three dicts are computed once, when this module is
# imported. They are ordinary dicts and do not track later registry changes;
# call the _get_*() helpers above for current data. Confirm whether any
# remaining code still reads or writes them directly.
_server_started = _get_server_started()
_registered_agents = _get_registered_agents()
_shared_apps = _get_shared_apps()
# Same lock object the registry uses internally.
# NOTE(review): calling a registry method while holding this lock relies on
# the registry lock being re-entrant - verify against ServerRegistry.
_server_lock = _server_registry._lock

# Don't import FastAPI dependencies here
# - use lazy loading instead

# Set up logger
logger = logging.getLogger(__name__)

# DEPRECATED: legacy module-level server state for the multi-agent API server.
# The lock is shared with the unified ServerRegistry so that this module and
# the registry serialise on the same object.
from ..core.server_registry import get_server_registry

_server_registry = get_server_registry()

# Backward compatibility names.
# NOTE(review): the three dicts start empty and are plain module-level dicts;
# nothing visible here fills them from the registry - confirm which code is
# expected to populate them.
_agents_server_lock = _server_registry._lock  # the registry's own lock object
_agents_server_started = dict()  # port -> started boolean
_agents_registered_endpoints = dict()  # port -> {path: endpoint_id}
_agents_shared_apps = dict()  # port -> FastAPI app
"""
Unified thread-safe server registry for managing API server state.

This module replaces the duplicate server state management found in:
- agent/agent.py (_server_started, _registered_agents, _shared_apps)
- agents/agents.py (_agents_server_started, _agents_registered_endpoints, _agents_shared_apps)

Design follows AGENTS.md principle of "Multi-agent + async safe by default".
"""

import threading
from typing import Dict, Any, Optional, Set
from dataclasses import dataclass, field


@dataclass
class ServerInfo:
    """Information about a registered server instance."""
    port: int
    started: bool = False
    app: Optional[Any] = None  # FastAPI app instance
    endpoints: Dict[str, str] = field(default_factory=dict)  # path -> agent/endpoint_id
    agents: Dict[str, str] = field(default_factory=dict)  # path -> agent_id (for backward compatibility)


class ServerRegistry:
    """
    Process-wide singleton registry of API server state, keyed by port.

    All reads and writes of the per-port ``ServerInfo`` records go through a
    single lock. The lock is re-entrant because the legacy modules re-export
    it (``_server_registry._lock``) and may call registry methods while
    already holding it; a plain ``Lock`` would deadlock in that situation.
    """

    _instance = None
    _creation_lock = threading.Lock()

    def __new__(cls):
        """Return the singleton, creating and initialising it on first use.

        The instance state is built while ``_creation_lock`` is held and the
        instance is published to ``cls._instance`` only afterwards, so no
        caller can observe a partially initialised registry and the state can
        never be re-created by a concurrent first call.
        """
        if cls._instance is None:
            with cls._creation_lock:
                if cls._instance is None:
                    instance = super().__new__(cls)
                    instance._lock = threading.RLock()
                    instance._servers = {}  # port -> ServerInfo
                    instance._initialized = True
                    cls._instance = instance
        return cls._instance

    def __init__(self):
        """Do nothing: all state is set up exactly once in ``__new__``."""

    def _ensure_server(self, port: int) -> ServerInfo:
        """Return the record for ``port``, creating an empty one if missing.

        Callers must hold ``self._lock``.
        """
        info = self._servers.get(port)
        if info is None:
            info = ServerInfo(port)
            self._servers[port] = info
        return info

    def register_server(self, port: int, app: Any = None) -> ServerInfo:
        """
        Register a server on the given port.

        Registering an already known port is allowed: the existing record is
        returned, and its app is filled in when ``app`` is given.

        Args:
            port: Port number for the server
            app: Optional FastAPI app instance

        Returns:
            The registry's own ServerInfo record for the port (not a copy).

        Raises:
            ValueError: If port is already registered with a different app
        """
        with self._lock:
            existing = self._servers.get(port)
            if existing is None:
                created = ServerInfo(port=port, app=app)
                self._servers[port] = created
                return created

            if app is not None:
                if existing.app is not None and existing.app is not app:
                    raise ValueError(f"Port {port} already registered with different app")
                existing.app = app
            return existing

    def mark_started(self, port: int) -> None:
        """Mark a registered server as started; unknown ports are ignored."""
        with self._lock:
            info = self._servers.get(port)
            if info is not None:
                info.started = True

    def mark_stopped(self, port: int) -> None:
        """Mark a registered server as stopped; unknown ports are ignored."""
        with self._lock:
            info = self._servers.get(port)
            if info is not None:
                info.started = False

    def is_started(self, port: int) -> bool:
        """Return True if the port is registered and marked started."""
        with self._lock:
            info = self._servers.get(port)
            return info.started if info is not None else False

    def get_app(self, port: int) -> Optional[Any]:
        """Return the app stored for the port, or None if there is none."""
        with self._lock:
            info = self._servers.get(port)
            return info.app if info is not None else None

    def register_endpoint(self, port: int, path: str, endpoint_id: str) -> None:
        """Record ``path -> endpoint_id`` for the port, creating the port record if needed."""
        with self._lock:
            self._ensure_server(port).endpoints[path] = endpoint_id

    def register_agent(self, port: int, path: str, agent_id: str) -> None:
        """Record ``path -> agent_id`` for the port (backward compatibility).

        The mapping is stored in both ``agents`` and ``endpoints`` so that it
        is visible through either accessor.
        """
        with self._lock:
            info = self._ensure_server(port)
            info.agents[path] = agent_id
            # Also register as endpoint for unified access
            info.endpoints[path] = agent_id

    def get_endpoints(self, port: int) -> Dict[str, str]:
        """Return a copy of the port's endpoint map (empty if unknown port)."""
        with self._lock:
            info = self._servers.get(port)
            return info.endpoints.copy() if info is not None else {}

    def get_agents(self, port: int) -> Dict[str, str]:
        """Return a copy of the port's agent map (empty if unknown port)."""
        with self._lock:
            info = self._servers.get(port)
            return info.agents.copy() if info is not None else {}

    def unregister_server(self, port: int) -> None:
        """Remove the port's record, including its endpoints; unknown ports are ignored."""
        with self._lock:
            self._servers.pop(port, None)

    def list_ports(self) -> Set[int]:
        """Return a new set containing every registered port number."""
        with self._lock:
            return set(self._servers)

    def get_server_info(self, port: int) -> Optional[ServerInfo]:
        """Return a detached copy of the port's record, or None if unknown.

        The ``endpoints`` and ``agents`` dicts are copied so that mutating the
        result does not affect the registry; ``app`` is the same object.
        """
        with self._lock:
            info = self._servers.get(port)
            if info is None:
                return None
            return ServerInfo(
                port=info.port,
                started=info.started,
                app=info.app,
                endpoints=info.endpoints.copy(),
                agents=info.agents.copy(),
            )

    def clear(self) -> None:
        """Clear all registered servers (useful for testing)."""
        with self._lock:
            self._servers.clear()


# Global registry instance
_registry = ServerRegistry()


# Convenience functions for easy access
def get_server_registry() -> ServerRegistry:
    """Get the global server registry instance."""
    return _registry


def register_server(port: int, app: Any = None) -> ServerInfo:
    """Register a server on the given port."""
    return _registry.register_server(port, app)


def is_server_started(port: int) -> bool:
    """Check if a server is started."""
    return _registry.is_started(port)


def mark_server_started(port: int) -> None:
    """Mark a server as started."""
    _registry.mark_started(port)


def mark_server_stopped(port: int) -> None:
    """Mark a server as stopped."""
    _registry.mark_stopped(port)


def get_server_app(port: int) -> Optional[Any]:
    """Get the FastAPI app for a server."""
    return _registry.get_app(port)


def register_agent_endpoint(port: int, path: str, agent_id: str) -> None:
    """Register an agent endpoint on a server."""
    _registry.register_agent(port, path, agent_id)


def register_endpoint(port: int, path: str, endpoint_id: str) -> None:
    """Register a generic endpoint on a server."""
    _registry.register_endpoint(port, path, endpoint_id)
import contextvars
import threading

# Public, backward-compatible registries.
# These objects are the shared state seen by every context that has not bound
# its own container, so code that reads or mutates them directly stays in sync
# with the context-aware accessors below.
error_logs = []
sync_display_callbacks = {}
async_display_callbacks = {}
approval_callback = None

# Lock guarding rebinding/reading of the module-level approval callback.
_global_lock = threading.Lock()

# Fallback names refer to the very same containers as the public globals, so
# there is exactly one shared list/dict per kind of state.
_global_error_logs = error_logs
_global_sync_callbacks = sync_display_callbacks
_global_async_callbacks = async_display_callbacks
_global_approval_callback = None

# Context variables. Because each one is created with a ``default``, ``get()``
# never raises LookupError: a context that has not called ``set()`` receives
# the shared default object. Code that wants isolated state must ``set()`` a
# fresh container in its own context.
_error_logs_var = contextvars.ContextVar('error_logs', default=error_logs)
_sync_callbacks_var = contextvars.ContextVar('sync_display_callbacks', default=sync_display_callbacks)
_async_callbacks_var = contextvars.ContextVar('async_display_callbacks', default=async_display_callbacks)
_approval_callback_var = contextvars.ContextVar('approval_callback', default=None)


def _get_error_logs():
    """Return the error-log list for the current context.

    This is the list bound in the current context, or the shared public
    ``error_logs`` list when the context has not bound one.
    """
    return _error_logs_var.get()


def _add_error_log(message: str):
    """Append ``message`` to the current context's error-log list.

    Args:
        message: The error text to record.
    """
    _error_logs_var.get().append(message)


def _get_sync_callbacks():
    """Return the sync display-callback dict for the current context.

    Falls back to the shared public ``sync_display_callbacks`` dict when the
    context has not bound its own.
    """
    return _sync_callbacks_var.get()


def _get_async_callbacks():
    """Return the async display-callback dict for the current context.

    Falls back to the shared public ``async_display_callbacks`` dict when the
    context has not bound its own.
    """
    return _async_callbacks_var.get()


def _get_approval_callback():
    """Return the approval callback visible from the current context.

    Lookup order: the value bound in the current context, then the
    module-level ``_global_approval_callback``, then the legacy public
    ``approval_callback`` global. Returns None when none of them is set.

    NOTE(review): a value stored with ``ContextVar.set`` is only visible in
    the context that set it (and contexts copied from it); confirm callbacks
    registered that way are expected to reach worker threads.
    """
    bound = _approval_callback_var.get()
    if bound is not None:
        return bound
    with _global_lock:
        if _global_approval_callback is not None:
            return _global_approval_callback
        return approval_callback

# ─────────────────────────────────────────────────────────────────────────────
# PraisonAI
Unique Color Palette: "Elegant Intelligence" @@ -128,10 +184,20 @@ def register_display_callback(display_type: str, callback_fn, is_async: bool = F callback_fn: The callback function to register is_async (bool): Whether the callback is asynchronous """ - if is_async: - async_display_callbacks[display_type] = callback_fn - else: - sync_display_callbacks[display_type] = callback_fn + try: + if is_async: + callbacks = _async_callbacks_var.get() + callbacks[display_type] = callback_fn + else: + callbacks = _sync_callbacks_var.get() + callbacks[display_type] = callback_fn + except LookupError: + # No context variable set, use thread-safe global fallback + with _global_lock: + if is_async: + _global_async_callbacks[display_type] = callback_fn + else: + _global_sync_callbacks[display_type] = callback_fn def register_approval_callback(callback_fn): """Register a global approval callback function for dangerous tool operations. @@ -139,8 +205,13 @@ def register_approval_callback(callback_fn): Args: callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision """ - global approval_callback - approval_callback = callback_fn + try: + _approval_callback_var.set(callback_fn) + except LookupError: + # No context variable set, use thread-safe global fallback + with _global_lock: + global _global_approval_callback + _global_approval_callback = callback_fn # Simplified aliases (consistent naming convention) @@ -157,8 +228,10 @@ def execute_sync_callback(display_type: str, **kwargs): display_type (str): Type of display event **kwargs: Arguments to pass to the callback function """ - if display_type in sync_display_callbacks: - callback = sync_display_callbacks[display_type] + # Get callbacks from current context or thread-safe global fallback + callbacks = _get_sync_callbacks() + if display_type in callbacks: + callback = callbacks[display_type] import inspect sig = inspect.signature(callback) @@ -181,9 +254,13 @@ async def 
execute_callback(display_type: str, **kwargs): """ import inspect + # Get callbacks from current context or thread-safe global fallback + sync_callbacks = _get_sync_callbacks() + async_callbacks = _get_async_callbacks() + # Execute synchronous callback if registered - if display_type in sync_display_callbacks: - callback = sync_display_callbacks[display_type] + if display_type in sync_callbacks: + callback = sync_callbacks[display_type] sig = inspect.signature(callback) # Filter kwargs to what the callback accepts to maintain backward compatibility @@ -198,8 +275,8 @@ async def execute_callback(display_type: str, **kwargs): await loop.run_in_executor(None, lambda: callback(**supported_kwargs)) # Execute asynchronous callback if registered - if display_type in async_display_callbacks: - callback = async_display_callbacks[display_type] + if display_type in async_callbacks: + callback = async_callbacks[display_type] sig = inspect.signature(callback) # Filter kwargs to what the callback accepts to maintain backward compatibility @@ -420,7 +497,7 @@ def display_error(message: str, console=None): title="⚠ Error", border_style=PRAISON_COLORS["error"] )) - error_logs.append(message) + _add_error_log(message) def display_generating(content: str = "", start_time: Optional[float] = None): if not content or not str(content).strip(): @@ -617,7 +694,7 @@ async def adisplay_error(message: str, console=None): await execute_callback('error', message=message) console.print(Panel.fit(Text(message, style="bold red"), title="Error", border_style="red")) - error_logs.append(message) + _add_error_log(message) async def adisplay_generating(content: str = "", start_time: Optional[float] = None): """Async version of display_generating."""