diff --git a/simple_test.py b/simple_test.py new file mode 100644 index 000000000..207ea279e --- /dev/null +++ b/simple_test.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python3 +""" +Simple test to verify thread safety fixes work correctly. +""" + +import sys +import os +import contextvars + +# Add the src directory to the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'praisonai-agents')) + +from praisonaiagents.main import error_logs, sync_display_callbacks + +def test_basic_context_isolation(): + """Test basic context isolation""" + print("Testing basic context isolation...") + + # Clear any existing state + error_logs.clear() + sync_display_callbacks.clear() + + print(f"Initial error_logs length: {len(error_logs)}") + + # Add an error in main context + error_logs.append("Main context error") + print(f"After adding main error: {len(error_logs)}") + + # Create a new context + ctx = contextvars.copy_context() + + def in_new_context(): + print(f"In new context, error_logs length: {len(error_logs)}") + error_logs.append("New context error") + print(f"After adding new context error: {len(error_logs)}") + return list(error_logs) + + result = ctx.run(in_new_context) + print(f"Result from new context: {result}") + print(f"Back in main context, error_logs length: {len(error_logs)}") + print(f"Main context errors: {list(error_logs)}") + +if __name__ == "__main__": + test_basic_context_isolation() \ No newline at end of file diff --git a/src/praisonai-agents/praisonaiagents/main.py b/src/praisonai-agents/praisonaiagents/main.py index cc8340cb1..754269467 100644 --- a/src/praisonai-agents/praisonaiagents/main.py +++ b/src/praisonai-agents/praisonaiagents/main.py @@ -2,6 +2,8 @@ import time import json import logging +import contextvars +import threading from typing import List, Optional, Dict, Any, Union, Literal, Type from pydantic import BaseModel, ConfigDict import asyncio @@ -23,15 +25,190 @@ def _rich(): # Logging is already configured in _logging.py via __init__.py -# Global list to store error logs -error_logs = [] +# Thread-safe context variables for multi-agent safety +# Each agent context gets isolated state via context variables +# Note: Default values should be immutable to avoid sharing state between contexts +_error_logs: contextvars.ContextVar[List[Any]] = contextvars.ContextVar('error_logs') +_sync_display_callbacks: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar('sync_display_callbacks') +_async_display_callbacks: contextvars.ContextVar[Dict[str, Any]] = contextvars.ContextVar('async_display_callbacks') +_approval_callback: contextvars.ContextVar[Optional[Any]] = contextvars.ContextVar('approval_callback') + +# Backward compatibility wrappers that maintain existing API +def _get_error_logs(): + try: + return _error_logs.get() + except LookupError: + # Initialize with empty list for this context + _error_logs.set([]) + return _error_logs.get() + +def _set_error_logs(logs): + _error_logs.set(logs) + +def _get_sync_display_callbacks(): + try: + return _sync_display_callbacks.get() + except LookupError: + # Initialize with empty dict for this context + _sync_display_callbacks.set({}) + return _sync_display_callbacks.get() + +def _set_sync_display_callbacks(callbacks): + _sync_display_callbacks.set(callbacks) + +def _get_async_display_callbacks(): + try: + return _async_display_callbacks.get() + except LookupError: + # Initialize with empty dict for this context + _async_display_callbacks.set({}) + return _async_display_callbacks.get() + +def _set_async_display_callbacks(callbacks): + _async_display_callbacks.set(callbacks) + +def _get_approval_callback(): + try: + return _approval_callback.get() + except LookupError: + # No default callback + return None + +def _set_approval_callback(callback): + _approval_callback.set(callback) -# Separate registries for sync and async callbacks -sync_display_callbacks = {} -async_display_callbacks = {} +# Thread-safe wrapper classes for backward compatibility +class _ContextList: + """Thread-safe list wrapper using contextvars""" + + def __init__(self, context_var): + self._context_var = context_var + + def _get_current(self): + """Get current list, initializing if needed""" + if self._context_var is _error_logs: + return _get_error_logs() + else: + try: + return self._context_var.get() + except LookupError: + self._context_var.set([]) + return self._context_var.get() + + def append(self, item): + current = self._get_current() + # Create a new list to avoid modifying shared state + new_list = current.copy() + new_list.append(item) + self._context_var.set(new_list) + + def extend(self, items): + current = self._get_current() + # Create a new list to avoid modifying shared state + new_list = current.copy() + new_list.extend(items) + self._context_var.set(new_list) + + def clear(self): + self._context_var.set([]) + + def __iter__(self): + return iter(self._get_current()) + + def __len__(self): + return len(self._get_current()) + + def __getitem__(self, index): + return self._get_current()[index] + + def __setitem__(self, index, value): + current = self._get_current() + # Create a new list to avoid modifying shared state + new_list = current.copy() + new_list[index] = value + self._context_var.set(new_list) + +class _ContextDict: + """Thread-safe dict wrapper using contextvars""" + + def __init__(self, context_var): + self._context_var = context_var + + def _get_current(self): + """Get current dict, initializing if needed""" + if self._context_var is _sync_display_callbacks: + return _get_sync_display_callbacks() + elif self._context_var is _async_display_callbacks: + return _get_async_display_callbacks() + else: + try: + return self._context_var.get() + except LookupError: + self._context_var.set({}) + return self._context_var.get() + + def __getitem__(self, key): + return self._get_current()[key] + + def __setitem__(self, key, value): + current = self._get_current() + # Create a new dict to avoid modifying shared state + new_dict = current.copy() + new_dict[key] = value + self._context_var.set(new_dict) + + def __contains__(self, key): + return key in self._get_current() + + def get(self, key, default=None): + return self._get_current().get(key, default) + + def keys(self): + return self._get_current().keys() + + def values(self): + return self._get_current().values() + + def items(self): + return self._get_current().items() + + def clear(self): + self._context_var.set({}) + +class _ContextVar: + """Thread-safe variable wrapper using contextvars""" + + def __init__(self, context_var): + self._context_var = context_var + + def _get_current(self): + """Get current value""" + if self._context_var is _approval_callback: + return _get_approval_callback() + else: + try: + return self._context_var.get() + except LookupError: + return None + + def __call__(self, *args, **kwargs): + # Allow calling the callback if it exists + callback = self._get_current() + if callback: + return callback(*args, **kwargs) + return None + + def __bool__(self): + return self._get_current() is not None + + def set(self, value): + self._context_var.set(value) -# Global approval callback registry -approval_callback = None +# For backward compatibility, expose wrapped globals that maintain thread safety +error_logs = _ContextList(_error_logs) +sync_display_callbacks = _ContextDict(_sync_display_callbacks) +async_display_callbacks = _ContextDict(_async_display_callbacks) +approval_callback = _ContextVar(_approval_callback) # ───────────────────────────────────────────────────────────────────────────── # PraisonAI Unique Color Palette: "Elegant Intelligence" @@ -139,8 +316,7 @@ def register_approval_callback(callback_fn): Args: callback_fn: Function that takes (function_name, arguments, risk_level) and returns ApprovalDecision """ - global approval_callback - approval_callback = callback_fn + approval_callback.set(callback_fn) # Simplified aliases (consistent naming convention) @@ -420,6 +596,7 @@ def display_error(message: str, console=None): title="⚠ Error", border_style=PRAISON_COLORS["error"] )) + # Use thread-safe context-aware error logging error_logs.append(message) def display_generating(content: str = "", start_time: Optional[float] = None): @@ -617,6 +794,7 @@ async def adisplay_error(message: str, console=None): await execute_callback('error', message=message) console.print(Panel.fit(Text(message, style="bold red"), title="Error", border_style="red")) + # Use thread-safe context-aware error logging error_logs.append(message) async def adisplay_generating(content: str = "", start_time: Optional[float] = None): diff --git a/test_thread_safety_main.py b/test_thread_safety_main.py new file mode 100644 index 000000000..aa337253c --- /dev/null +++ b/test_thread_safety_main.py @@ -0,0 +1,213 @@ +#!/usr/bin/env python3 +""" +Test script to verify thread safety fixes for main.py globals. +Tests concurrent access to error_logs, callbacks, and approval_callback. +""" + +import sys +import os +import threading +import time +import contextvars +from concurrent.futures import ThreadPoolExecutor, as_completed + +# Add the src directory to the path +sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'praisonai-agents')) + +from praisonaiagents.main import ( + error_logs, + sync_display_callbacks, + async_display_callbacks, + approval_callback, + register_display_callback, + register_approval_callback +) + +def test_error_logs_thread_safety(): + """Test that error_logs are properly isolated per context""" + print("Testing error_logs thread safety...") + + results = [] + + def worker(worker_id): + """Each worker should see its own error logs in its context""" + # Create a new context for this worker + ctx = contextvars.copy_context() + + def in_context(): + # Add some errors specific to this worker + for i in range(5): + error_logs.append(f"Worker {worker_id} error {i}") + time.sleep(0.001) # Small delay to encourage race conditions + + # Check that only this worker's errors are visible + worker_errors = [log for log in error_logs if f"Worker {worker_id}" in log] + total_errors = len(error_logs) + + return { + 'worker_id': worker_id, + 'worker_errors': len(worker_errors), + 'total_errors': total_errors, + 'errors': list(error_logs) + } + + return ctx.run(in_context) + + # Run multiple workers concurrently + with ThreadPoolExecutor(max_workers=10) as executor: + futures = [executor.submit(worker, i) for i in range(10)] + results = [future.result() for future in as_completed(futures)] + + # Analyze results + print(f"Completed {len(results)} workers") + for result in results: + print(f"Worker {result['worker_id']}: {result['worker_errors']} own errors, {result['total_errors']} total") + # Each worker should only see its own 5 errors + assert result['worker_errors'] == 5, f"Worker {result['worker_id']} saw {result['worker_errors']} own errors, expected 5" + assert result['total_errors'] == 5, f"Worker {result['worker_id']} saw {result['total_errors']} total errors, expected 5" + + print("✅ error_logs thread safety test passed!") + +def test_callbacks_thread_safety(): + """Test that callbacks are properly isolated per context""" + print("Testing callbacks thread safety...") + + def test_callback(message=None, **kwargs): + return f"Processed: {message}" + + results = [] + + def worker(worker_id): + """Each worker should see its own callbacks in its context""" + ctx = contextvars.copy_context() + + def in_context(): + # Register a callback specific to this worker + callback_name = f"worker_{worker_id}_callback" + register_display_callback(callback_name, test_callback) + + # Check that only this worker's callback is visible + worker_callbacks = [k for k in sync_display_callbacks.keys() if f"worker_{worker_id}" in k] + all_callbacks = list(sync_display_callbacks.keys()) + + return { + 'worker_id': worker_id, + 'worker_callbacks': len(worker_callbacks), + 'total_callbacks': len(all_callbacks), + 'callback_names': all_callbacks + } + + return ctx.run(in_context) + + # Run multiple workers concurrently + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(worker, i) for i in range(5)] + results = [future.result() for future in as_completed(futures)] + + # Analyze results + for result in results: + print(f"Worker {result['worker_id']}: {result['worker_callbacks']} own callbacks, {result['total_callbacks']} total") + # Each worker should only see its own 1 callback + assert result['worker_callbacks'] == 1, f"Worker {result['worker_id']} saw {result['worker_callbacks']} own callbacks, expected 1" + assert result['total_callbacks'] == 1, f"Worker {result['worker_id']} saw {result['total_callbacks']} total callbacks, expected 1" + + print("✅ callbacks thread safety test passed!") + +def test_approval_callback_thread_safety(): + """Test that approval_callback is properly isolated per context""" + print("Testing approval_callback thread safety...") + + results = [] + + def worker(worker_id): + """Each worker should have its own approval callback in its context""" + ctx = contextvars.copy_context() + + def in_context(): + # Set an approval callback specific to this worker + def worker_approval_callback(func_name, args, risk_level): + return f"Worker {worker_id} approved {func_name}" + + register_approval_callback(worker_approval_callback) + + # Test the callback + if approval_callback: + result_msg = approval_callback("test_func", {}, "low") + has_own_callback = f"Worker {worker_id}" in result_msg + else: + has_own_callback = False + result_msg = "No callback" + + return { + 'worker_id': worker_id, + 'has_callback': bool(approval_callback), + 'has_own_callback': has_own_callback, + 'result': result_msg + } + + return ctx.run(in_context) + + # Run multiple workers concurrently + with ThreadPoolExecutor(max_workers=3) as executor: + futures = [executor.submit(worker, i) for i in range(3)] + results = [future.result() for future in as_completed(futures)] + + # Analyze results + for result in results: + print(f"Worker {result['worker_id']}: has_callback={result['has_callback']}, own_callback={result['has_own_callback']}") + print(f" Result: {result['result']}") + assert result['has_callback'], f"Worker {result['worker_id']} has no callback" + assert result['has_own_callback'], f"Worker {result['worker_id']} doesn't have its own callback" + + print("✅ approval_callback thread safety test passed!") + +def test_basic_functionality(): + """Test that the basic functionality still works after thread safety changes""" + print("Testing basic functionality...") + + # Test error logging + error_logs.clear() + error_logs.append("Test error 1") + error_logs.append("Test error 2") + assert len(error_logs) == 2 + assert error_logs[0] == "Test error 1" + assert error_logs[1] == "Test error 2" + + # Test sync callbacks + sync_display_callbacks.clear() + + def test_sync_callback(message=None): + return f"Sync: {message}" + + register_display_callback("test", test_sync_callback) + assert "test" in sync_display_callbacks + assert sync_display_callbacks["test"] == test_sync_callback + + # Test approval callback + def test_approval(func_name, args, risk_level): + return f"Approved {func_name}" + + register_approval_callback(test_approval) + assert approval_callback + result = approval_callback("test_func", {}, "low") + assert result == "Approved test_func" + + print("✅ Basic functionality test passed!") + +if __name__ == "__main__": + print("Starting thread safety tests for main.py globals...") + + try: + test_basic_functionality() + test_error_logs_thread_safety() + test_callbacks_thread_safety() + test_approval_callback_thread_safety() + + print("\n🎉 All thread safety tests passed!") + print("The main.py globals are now properly isolated per context and thread-safe.") + + except Exception as e: + print(f"\n❌ Test failed with error: {e}") + import traceback + traceback.print_exc() + sys.exit(1) \ No newline at end of file