Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
107 changes: 107 additions & 0 deletions src/praisonai-agents/praisonaiagents/_server_registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
"""
Centralized server registry for thread-safe server state management.
Unified solution for both Agent and Agents classes to share server resources safely.
"""

import threading
import logging
from typing import Dict, Optional, Any

logger = logging.getLogger(__name__)


class ServerRegistry:
"""
Thread-safe centralized registry for managing shared FastAPI servers.

This singleton class manages server state across Agent and Agents classes
to prevent port conflicts and ensure thread safety.
"""

_instance: Optional['ServerRegistry'] = None
_lock = threading.Lock()

def __new__(cls) -> 'ServerRegistry':
"""Ensure singleton pattern for global server state."""
with cls._lock:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._initialized = False
return cls._instance

def __init__(self):
"""Initialize the server registry (only once)."""
if self._initialized:
return

self._server_lock = threading.Lock()
self._server_started: Dict[int, bool] = {} # port -> started boolean
self._registered_endpoints: Dict[int, Dict[str, str]] = {} # port -> {path: endpoint_id}
self._shared_apps: Dict[int, Any] = {} # port -> FastAPI app
self._initialized = True

def is_server_started(self, port: int) -> bool:
"""Check if server is started on given port (thread-safe)."""
with self._server_lock:
return self._server_started.get(port, False)

def mark_server_started(self, port: int) -> None:
"""Mark server as started on given port (thread-safe)."""
with self._server_lock:
self._server_started[port] = True

def get_registered_endpoints(self, port: int) -> Dict[str, str]:
"""Get registered endpoints for a port (thread-safe)."""
with self._server_lock:
return self._registered_endpoints.get(port, {}).copy()

def register_endpoint(self, port: int, path: str, endpoint_id: str) -> bool:
"""
Register an endpoint for a port (thread-safe).

Returns:
bool: True if registered successfully, False if path already exists
"""
with self._server_lock:
if port not in self._registered_endpoints:
self._registered_endpoints[port] = {}

if path in self._registered_endpoints[port]:
logger.warning(f"Path '{path}' is already registered on port {port}")
return False

self._registered_endpoints[port][path] = endpoint_id
return True

def get_shared_app(self, port: int) -> Optional[Any]:
"""Get the shared FastAPI app for a port (thread-safe)."""
with self._server_lock:
return self._shared_apps.get(port)

def set_shared_app(self, port: int, app: Any) -> None:
"""Set the shared FastAPI app for a port (thread-safe)."""
with self._server_lock:
self._shared_apps[port] = app

def initialize_port(self, port: int) -> None:
"""Initialize collections for a port if needed (thread-safe)."""
with self._server_lock:
if port not in self._registered_endpoints:
self._registered_endpoints[port] = {}

def get_server_info(self, port: int) -> Dict[str, Any]:
"""Get complete server info for a port (thread-safe)."""
with self._server_lock:
return {
'started': self._server_started.get(port, False),
'endpoints': self._registered_endpoints.get(port, {}).copy(),
'has_app': port in self._shared_apps
}


# Global singleton instance
_registry = ServerRegistry()

def get_server_registry() -> ServerRegistry:
"""Get the global server registry instance."""
return _registry
45 changes: 24 additions & 21 deletions src/praisonai-agents/praisonaiagents/agent/__init__.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,33 @@
"""Agent module for AI agents - uses lazy loading for performance"""
import threading

# Lazy loading cache
# Thread-safe lazy loading cache
_lazy_cache = {}
_cache_lock = threading.Lock()

def __getattr__(name):
"""Lazy load agent classes to avoid importing rich at startup."""
if name in _lazy_cache:
return _lazy_cache[name]

# Core Agent - always needed
if name == 'Agent':
from .agent import Agent
_lazy_cache[name] = Agent
return Agent
if name == 'BudgetExceededError':
from .agent import BudgetExceededError
_lazy_cache[name] = BudgetExceededError
return BudgetExceededError
if name == 'Heartbeat':
from .heartbeat import Heartbeat
_lazy_cache[name] = Heartbeat
return Heartbeat
if name == 'HeartbeatConfig':
from .heartbeat import HeartbeatConfig
_lazy_cache[name] = HeartbeatConfig
return HeartbeatConfig
with _cache_lock:
if name in _lazy_cache:
return _lazy_cache[name]

# Core Agent - always needed
if name == 'Agent':
from .agent import Agent
_lazy_cache[name] = Agent
return Agent
if name == 'BudgetExceededError':
from .agent import BudgetExceededError
_lazy_cache[name] = BudgetExceededError
return BudgetExceededError
if name == 'Heartbeat':
from .heartbeat import Heartbeat
_lazy_cache[name] = Heartbeat
return Heartbeat
if name == 'HeartbeatConfig':
from .heartbeat import HeartbeatConfig
_lazy_cache[name] = HeartbeatConfig
return HeartbeatConfig

# Specialized agents - lazy loaded (import rich)
if name == 'ImageAgent':
Expand Down
56 changes: 27 additions & 29 deletions src/praisonai-agents/praisonaiagents/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,11 +158,8 @@ def _is_file_path(value: str) -> bool:
# Applied even when context management is disabled to prevent runaway tool outputs
DEFAULT_TOOL_OUTPUT_LIMIT = 16000

# Global variables for API server (protected by _server_lock for thread safety)
_server_lock = threading.Lock()
_server_started = {} # Dict of port -> started boolean
_registered_agents = {} # Dict of port -> Dict of path -> agent_id
_shared_apps = {} # Dict of port -> FastAPI app
# Use centralized server registry for thread-safe server state management
from .._server_registry import get_server_registry

# Don't import FastAPI dependencies here - use lazy loading instead

Expand Down Expand Up @@ -8602,7 +8599,8 @@ def launch(self, path: str = '/', port: int = 8000, host: str = '0.0.0.0', debug
None
"""
if protocol == "http":
global _server_started, _registered_agents, _shared_apps, _server_lock
# Get centralized server registry
registry = get_server_registry()

# Try to import FastAPI dependencies - lazy loading
try:
Expand All @@ -8629,53 +8627,53 @@ class AgentQuery(BaseModel):
print("pip install 'praisonaiagents[api]'")
return None

with _server_lock:
# Registry methods are already thread-safe
# Initialize port-specific collections if needed
if port not in _registered_agents:
_registered_agents[port] = {}
registry.initialize_port(port)

# Initialize shared FastAPI app if not already created for this port
if _shared_apps.get(port) is None:
_shared_apps[port] = FastAPI(
if registry.get_shared_app(port) is None:
app = FastAPI(
title=f"PraisonAI Agents API (Port {port})",
description="API for interacting with PraisonAI Agents"
)
registry.set_shared_app(port, app)

# Add a root endpoint with a welcome message
@_shared_apps[port].get("/")
@registry.get_shared_app(port).get("/")
async def root():
return {
"message": f"Welcome to PraisonAI Agents API on port {port}. See /docs for usage.",
"endpoints": list(_registered_agents[port].keys())
"endpoints": list(registry.get_registered_endpoints(port).keys())
}

# Add healthcheck endpoint
@_shared_apps[port].get("/health")
@registry.get_shared_app(port).get("/health")
async def healthcheck():
return {
"status": "ok",
"endpoints": list(_registered_agents[port].keys())
"endpoints": list(registry.get_registered_endpoints(port).keys())
}

# Normalize path to ensure it starts with /
if not path.startswith('/'):
path = f'/{path}'

# Check if path is already registered for this port
if path in _registered_agents[port]:
logging.warning(f"Path '{path}' is already registered on port {port}. Please use a different path.")
# Check if path is already registered for this port and register
if not registry.register_endpoint(port, path, self.agent_id):
print(f"⚠️ Warning: Path '{path}' is already registered on port {port}.")
# Use a modified path to avoid conflicts
original_path = path
path = f"{path}_{self.agent_id[:6]}"
logging.warning(f"Using '{path}' instead of '{original_path}'")
print(f"🔄 Using '{path}' instead")

# Register the agent to this path
_registered_agents[port][path] = self.agent_id
if not registry.register_endpoint(port, path, self.agent_id):
# If even that fails, just continue (shouldn't happen normally)
logging.error(f"Failed to register any path for agent {self.agent_id} on port {port}")
else:
logging.warning(f"Using '{path}' instead of '{original_path}'")
print(f"🔄 Using '{path}' instead")

# Define the endpoint handler
@_shared_apps[port].post(path)
@registry.get_shared_app(port).post(path)
async def handle_agent_query(request: Request, query_data: Optional[AgentQuery] = None):
# Handle both direct JSON with query field and form data
if query_data is None:
Expand Down Expand Up @@ -8714,9 +8712,9 @@ async def handle_agent_query(request: Request, query_data: Optional[AgentQuery]
print(f"🚀 Agent '{self.name}' available at http://{host}:{port}")

# Check and mark server as started atomically to prevent race conditions
should_start = not _server_started.get(port, False)
should_start = not registry.is_server_started(port)
if should_start:
_server_started[port] = True
registry.mark_server_started(port)

# Server start/wait outside the lock to avoid holding it during sleep
if should_start:
Expand All @@ -8725,8 +8723,8 @@ def run_server():
try:
print(f"✅ FastAPI server started at http://{host}:{port}")
print(f"📚 API documentation available at http://{host}:{port}/docs")
print(f"🔌 Available endpoints: {', '.join(list(_registered_agents[port].keys()))}")
uvicorn.run(_shared_apps[port], host=host, port=port, log_level="debug" if debug else "info")
print(f"🔌 Available endpoints: {', '.join(list(registry.get_registered_endpoints(port).keys()))}")
uvicorn.run(registry.get_shared_app(port), host=host, port=port, log_level="debug" if debug else "info")
except Exception as e:
logging.error(f"Error starting server: {str(e)}", exc_info=True)
print(f"❌ Error starting server: {str(e)}")
Expand All @@ -8740,7 +8738,7 @@ def run_server():
else:
# If server is already running, wait a moment to make sure the endpoint is registered
time.sleep(0.1)
print(f"🔌 Available endpoints on port {port}: {', '.join(list(_registered_agents[port].keys()))}")
print(f"🔌 Available endpoints on port {port}: {', '.join(list(registry.get_registered_endpoints(port).keys()))}")

# Get the stack frame to check if this is the last launch() call in the script
import inspect
Expand Down
Loading