diff --git a/src/praisonai-agents/praisonaiagents/main.py b/src/praisonai-agents/praisonaiagents/main.py index 42003ca7c..e57645602 100644 --- a/src/praisonai-agents/praisonaiagents/main.py +++ b/src/praisonai-agents/praisonaiagents/main.py @@ -2,6 +2,7 @@ import time import json import logging +import threading from praisonaiagents._logging import get_logger from typing import List, Optional, Dict, Any, Union, Literal, Type from pydantic import BaseModel, ConfigDict @@ -23,15 +24,64 @@ def _rich(): # Logging is already configured in _logging.py via __init__.py -# Global list to store error logs -error_logs = [] - -# Separate registries for sync and async callbacks -sync_display_callbacks = {} -async_display_callbacks = {} +# Thread-safe global state protection +_main_globals_lock = threading.Lock() + +# Global list to store error logs (protected by _main_globals_lock) +_error_logs = [] + +# Separate registries for sync and async callbacks (protected by _main_globals_lock) +_sync_display_callbacks = {} +_async_display_callbacks = {} + +# Global approval callback registry (protected by _main_globals_lock) +_approval_callback = None + +# Thread-safe accessor functions +def _get_error_logs(): + """Thread-safe access to error logs.""" + with _main_globals_lock: + return _error_logs.copy() + +def _add_error_log(error): + """Thread-safe addition to error logs.""" + with _main_globals_lock: + _error_logs.append(error) + +def _get_sync_display_callbacks(): + """Thread-safe access to sync display callbacks.""" + with _main_globals_lock: + return _sync_display_callbacks.copy() + +def _get_async_display_callbacks(): + """Thread-safe access to async display callbacks.""" + with _main_globals_lock: + return _async_display_callbacks.copy() + +def _get_approval_callback(): + """Thread-safe access to approval callback.""" + with _main_globals_lock: + return _approval_callback + +def _set_approval_callback(callback): + """Thread-safe setting of approval callback.""" + with _main_globals_lock: + global _approval_callback + _approval_callback = callback + +def _register_display_callback(display_type: str, callback_fn, is_async: bool = False): + """Thread-safe registration of display callback.""" + with _main_globals_lock: + if is_async: + _async_display_callbacks[display_type] = callback_fn + else: + _sync_display_callbacks[display_type] = callback_fn -# Global approval callback registry -approval_callback = None +# Backward compatibility - expose as module-level references +error_logs = _error_logs # Direct reference for read access, use helper functions for writes +sync_display_callbacks = _sync_display_callbacks # Direct reference for read access, use helper functions for writes +async_display_callbacks = _async_display_callbacks # Direct reference for read access, use helper functions for writes +approval_callback = _approval_callback # Direct reference for read access, use helper functions for writes # ───────────────────────────────────────────────────────────────────────────── # PraisonAI Unique Color Palette: "Elegant Intelligence" @@ -128,10 +178,7 @@ def register_display_callback(display_type: str, callback_fn, is_async: bool = F callback_fn: The callback function to register is_async (bool): Whether the callback is asynchronous """ - if is_async: - async_display_callbacks[display_type] = callback_fn - else: - sync_display_callbacks[display_type] = callback_fn + _register_display_callback(display_type, callback_fn, is_async) def register_approval_callback(callback_fn): """Register a global approval callback function for dangerous tool operations. @@ -139,8 +186,7 @@ def register_approval_callback(callback_fn): Args: callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision """ - global approval_callback - approval_callback = callback_fn + _set_approval_callback(callback_fn) # Simplified aliases (consistent naming convention) add_display_callback = register_display_callback @@ -155,8 +201,9 @@ def execute_sync_callback(display_type: str, **kwargs): display_type (str): Type of display event **kwargs: Arguments to pass to the callback function """ - if display_type in sync_display_callbacks: - callback = sync_display_callbacks[display_type] + sync_callbacks = _get_sync_display_callbacks() + if display_type in sync_callbacks: + callback = sync_callbacks[display_type] import inspect sig = inspect.signature(callback) @@ -180,8 +227,9 @@ async def execute_callback(display_type: str, **kwargs): import inspect # Execute synchronous callback if registered - if display_type in sync_display_callbacks: - callback = sync_display_callbacks[display_type] + sync_callbacks = _get_sync_display_callbacks() + if display_type in sync_callbacks: + callback = sync_callbacks[display_type] sig = inspect.signature(callback) # Filter kwargs to what the callback accepts to maintain backward compatibility @@ -196,8 +244,9 @@ async def execute_callback(display_type: str, **kwargs): await loop.run_in_executor(None, lambda: callback(**supported_kwargs)) # Execute asynchronous callback if registered - if display_type in async_display_callbacks: - callback = async_display_callbacks[display_type] + async_callbacks = _get_async_display_callbacks() + if display_type in async_callbacks: + callback = async_callbacks[display_type] sig = inspect.signature(callback) # Filter kwargs to what the callback accepts to maintain backward compatibility @@ -418,7 +467,7 @@ def display_error(message: str, console=None): title="⚠ Error", border_style=PRAISON_COLORS["error"] )) - error_logs.append(message) + _add_error_log(message) def display_generating(content: str = "", start_time: Optional[float] = None): if not content or not str(content).strip(): @@ -615,7 +664,7 @@ async def adisplay_error(message: str, console=None): await execute_callback('error', message=message) console.print(Panel.fit(Text(message, style="bold red"), title="Error", border_style="red")) - error_logs.append(message) + _add_error_log(message) async def adisplay_generating(content: str = "", start_time: Optional[float] = None): """Async version of display_generating.""" diff --git a/src/praisonai-agents/test_thread_safety.py b/src/praisonai-agents/test_thread_safety.py new file mode 100644 index 000000000..364e96f7c --- /dev/null +++ b/src/praisonai-agents/test_thread_safety.py @@ -0,0 +1,69 @@ +#!/usr/bin/env python +""" +Test thread safety of main.py global state fixes. +""" +import threading +import time +import concurrent.futures + +def test_thread_safe_callback_registration(): + """Test that display callback registration is thread-safe.""" + from praisonaiagents.main import register_display_callback, _get_sync_display_callbacks, _get_async_display_callbacks + + def register_callbacks_concurrently(worker_id): + """Register both sync and async callbacks from multiple threads.""" + register_display_callback(f'sync_test_{worker_id}', lambda: f'sync_{worker_id}', is_async=False) + time.sleep(0.001) + register_display_callback(f'async_test_{worker_id}', lambda: f'async_{worker_id}', is_async=True) + return worker_id + + # Run 20 concurrent workers + with concurrent.futures.ThreadPoolExecutor(max_workers=20) as executor: + futures = [executor.submit(register_callbacks_concurrently, i) for i in range(20)] + concurrent.futures.wait(futures) + + # Verify all callbacks were registered successfully + sync_callbacks = _get_sync_display_callbacks() + async_callbacks = _get_async_display_callbacks() + + sync_count = len([k for k in sync_callbacks.keys() if k.startswith('sync_test_')]) + async_count = len([k for k in async_callbacks.keys() if k.startswith('async_test_')]) + + assert sync_count == 20, f"Expected 20 sync callbacks, got {sync_count}" + assert async_count == 20, f"Expected 20 async callbacks, got {async_count}" + print(f"✅ Thread-safe callback registration: {sync_count} sync + {async_count} async callbacks") + + +def test_thread_safe_error_logging(): + """Test that error logging is thread-safe.""" + from praisonaiagents.main import _add_error_log, _get_error_logs + + def log_errors_concurrently(worker_id): + """Log multiple errors from different threads.""" + for i in range(10): + _add_error_log(f"Thread-{worker_id}: Error message #{i}") + time.sleep(0.0001) + return worker_id + + # Run 15 concurrent workers + with concurrent.futures.ThreadPoolExecutor(max_workers=15) as executor: + futures = [executor.submit(log_errors_concurrently, i) for i in range(15)] + concurrent.futures.wait(futures) + + # Verify all errors were logged + all_errors = _get_error_logs() + thread_errors = [e for e in all_errors if e.startswith('Thread-')] + + assert len(thread_errors) == 150, f"Expected 150 error logs, got {len(thread_errors)}" + print(f"✅ Thread-safe error logging: {len(thread_errors)} errors from 15 threads") + + +if __name__ == '__main__': + print("Testing thread safety fixes for main.py global state...") + print("=" * 60) + + test_thread_safe_callback_registration() + test_thread_safe_error_logging() + + print("=" * 60) + print("🎉 All thread safety tests passed!") \ No newline at end of file